Repository: gofiber/fiber Branch: main Commit: d6eb4973338e Files: 382 Total size: 3.6 MB Directory structure: gitextract_9m90gl79/ ├── .cspell.json ├── .editorconfig ├── .gitattributes ├── .github/ │ ├── .editorconfig │ ├── .hound.yml │ ├── CODEOWNERS │ ├── CODE_OF_CONDUCT.md │ ├── CONTRIBUTING.md │ ├── FUNDING.yml │ ├── ISSUE_TEMPLATE/ │ │ ├── bug-report.yaml │ │ ├── config.yml │ │ ├── feature-request.yaml │ │ ├── maintenance-task.yaml │ │ └── question.yaml │ ├── README.md │ ├── SECURITY.md │ ├── codecov.yml │ ├── config.yml │ ├── copilot-instructions.md │ ├── copilot-setup-steps.yml │ ├── dependabot.yml │ ├── index.html │ ├── labeler.yml │ ├── pull_request_template.md │ ├── release-drafter.yml │ ├── release.yml │ ├── scripts/ │ │ └── sync_docs.sh │ ├── testdata/ │ │ ├── ca-chain.cert.pem │ │ ├── fs/ │ │ │ ├── css/ │ │ │ │ ├── style.css │ │ │ │ └── test/ │ │ │ │ └── style2.css │ │ │ ├── img/ │ │ │ │ ├── fiberpng │ │ │ │ └── fiberpng.notvalidext │ │ │ └── index.html │ │ ├── hello_world.tmpl │ │ ├── index.html │ │ ├── index.tmpl │ │ ├── main.tmpl │ │ ├── ssl.key │ │ ├── ssl.pem │ │ ├── template-invalid.html │ │ ├── template.tmpl │ │ └── testRoutes.json │ ├── testdata2/ │ │ └── bruh.tmpl │ ├── testdata3/ │ │ └── hello_world.tmpl │ └── workflows/ │ ├── auto-labeler.yml │ ├── benchmark.yml │ ├── codeql-analysis.yml │ ├── dependabot_automerge.yml │ ├── linter.yml │ ├── manual-dependabot.yml │ ├── markdown.yml │ ├── modernize.yml │ ├── move-closed-milestone-items.yml │ ├── release-drafter.yml │ ├── spell-check.yml │ ├── sync-docs.yml │ ├── test.yml │ ├── v3-label-automation.yml │ └── vulncheck.yml ├── .gitignore ├── .golangci.yml ├── .markdownlint.yml ├── AGENTS.md ├── LICENSE ├── Makefile ├── adapter.go ├── adapter_test.go ├── addon/ │ └── retry/ │ ├── README.md │ ├── config.go │ ├── config_test.go │ ├── exponential_backoff.go │ └── exponential_backoff_test.go ├── app.go ├── app_integration_test.go ├── app_test.go ├── bind.go ├── bind_test.go ├── binder/ │ ├── 
README.md │ ├── binder.go │ ├── binder_test.go │ ├── cbor.go │ ├── cbor_test.go │ ├── cookie.go │ ├── cookie_test.go │ ├── form.go │ ├── form_test.go │ ├── header.go │ ├── header_test.go │ ├── json.go │ ├── json_test.go │ ├── mapping.go │ ├── mapping_test.go │ ├── msgpack.go │ ├── msgpack_test.go │ ├── query.go │ ├── query_test.go │ ├── resp_header.go │ ├── resp_header_test.go │ ├── uri.go │ ├── uri_test.go │ ├── xml.go │ └── xml_test.go ├── client/ │ ├── README.md │ ├── client.go │ ├── client_test.go │ ├── cookiejar.go │ ├── cookiejar_test.go │ ├── core.go │ ├── core_test.go │ ├── errors.go │ ├── helper_test.go │ ├── hooks.go │ ├── hooks_test.go │ ├── request.go │ ├── request_bench_test.go │ ├── request_test.go │ ├── response.go │ ├── response_test.go │ ├── transport.go │ └── transport_test.go ├── color.go ├── constants.go ├── ctx.go ├── ctx_interface.go ├── ctx_interface_gen.go ├── ctx_test.go ├── docs/ │ ├── addon/ │ │ ├── _category_.json │ │ └── retry.md │ ├── api/ │ │ ├── _category_.json │ │ ├── app.md │ │ ├── bind.md │ │ ├── constants.md │ │ ├── ctx.md │ │ ├── fiber.md │ │ ├── hooks.md │ │ ├── log.md │ │ ├── redirect.md │ │ ├── services.md │ │ └── state.md │ ├── client/ │ │ ├── _category_.json │ │ ├── examples.md │ │ ├── hooks.md │ │ ├── request.md │ │ ├── response.md │ │ └── rest.md │ ├── extra/ │ │ ├── _category_.json │ │ ├── benchmarks.md │ │ ├── faq.md │ │ ├── internal.md │ │ └── learning-resources.md │ ├── guide/ │ │ ├── _category_.json │ │ ├── advance-format.md │ │ ├── context.md │ │ ├── error-handling.md │ │ ├── extractors.md │ │ ├── faster-fiber.md │ │ ├── grouping.md │ │ ├── reverse-proxy.md │ │ ├── routing.md │ │ ├── templates.md │ │ ├── utils.md │ │ └── validation.md │ ├── intro.md │ ├── middleware/ │ │ ├── _category_.json │ │ ├── adaptor.md │ │ ├── basicauth.md │ │ ├── cache.md │ │ ├── compress.md │ │ ├── cors.md │ │ ├── csrf.md │ │ ├── earlydata.md │ │ ├── encryptcookie.md │ │ ├── envvar.md │ │ ├── etag.md │ │ ├── expvar.md │ │ ├── favicon.md │ │ 
├── healthcheck.md │ │ ├── helmet.md │ │ ├── idempotency.md │ │ ├── keyauth.md │ │ ├── limiter.md │ │ ├── logger.md │ │ ├── paginate.md │ │ ├── pprof.md │ │ ├── proxy.md │ │ ├── recover.md │ │ ├── redirect.md │ │ ├── requestid.md │ │ ├── responsetime.md │ │ ├── rewrite.md │ │ ├── session.md │ │ ├── skip.md │ │ ├── static.md │ │ └── timeout.md │ ├── partials/ │ │ └── routing/ │ │ └── handler.md │ └── whats_new.md ├── error.go ├── error_test.go ├── errors_internal.go ├── extractors/ │ ├── README.md │ ├── extractors.go │ └── extractors_test.go ├── go.mod ├── go.sum ├── group.go ├── helpers.go ├── helpers_fuzz_test.go ├── helpers_test.go ├── hooks.go ├── hooks_test.go ├── internal/ │ ├── memory/ │ │ ├── memory.go │ │ └── memory_test.go │ ├── storage/ │ │ └── memory/ │ │ ├── config.go │ │ ├── memory.go │ │ └── memory_test.go │ └── tlstest/ │ └── tls.go ├── listen.go ├── listen_test.go ├── log/ │ ├── default.go │ ├── default_test.go │ ├── fiberlog.go │ ├── fiberlog_test.go │ └── log.go ├── middleware/ │ ├── adaptor/ │ │ ├── adaptor.go │ │ └── adaptor_test.go │ ├── basicauth/ │ │ ├── basicauth.go │ │ ├── basicauth_test.go │ │ └── config.go │ ├── cache/ │ │ ├── cache.go │ │ ├── cache_test.go │ │ ├── config.go │ │ ├── heap.go │ │ ├── manager.go │ │ ├── manager_msgp.go │ │ ├── manager_msgp_test.go │ │ └── manager_test.go │ ├── compress/ │ │ ├── compress.go │ │ ├── compress_test.go │ │ └── config.go │ ├── cors/ │ │ ├── config.go │ │ ├── cors.go │ │ ├── cors_test.go │ │ ├── utils.go │ │ └── utils_test.go │ ├── csrf/ │ │ ├── config.go │ │ ├── config_test.go │ │ ├── csrf.go │ │ ├── csrf_test.go │ │ ├── helpers.go │ │ ├── helpers_test.go │ │ ├── session_manager.go │ │ ├── storage_manager.go │ │ ├── storage_manager_msgp.go │ │ ├── storage_manager_msgp_test.go │ │ └── token.go │ ├── earlydata/ │ │ ├── config.go │ │ ├── earlydata.go │ │ └── earlydata_test.go │ ├── encryptcookie/ │ │ ├── config.go │ │ ├── config_test.go │ │ ├── encryptcookie.go │ │ ├── encryptcookie_test.go │ │ └── 
utils.go │ ├── envvar/ │ │ ├── config.go │ │ ├── envvar.go │ │ └── envvar_test.go │ ├── etag/ │ │ ├── config.go │ │ ├── etag.go │ │ └── etag_test.go │ ├── expvar/ │ │ ├── config.go │ │ ├── expvar.go │ │ └── expvar_test.go │ ├── favicon/ │ │ ├── config.go │ │ ├── favicon.go │ │ └── favicon_test.go │ ├── healthcheck/ │ │ ├── config.go │ │ ├── healthcheck.go │ │ └── healthcheck_test.go │ ├── helmet/ │ │ ├── config.go │ │ ├── helmet.go │ │ └── helmet_test.go │ ├── idempotency/ │ │ ├── config.go │ │ ├── idempotency.go │ │ ├── idempotency_test.go │ │ ├── locker.go │ │ ├── locker_test.go │ │ ├── response.go │ │ ├── response_msgp.go │ │ ├── response_msgp_test.go │ │ └── stub_test.go │ ├── keyauth/ │ │ ├── config.go │ │ ├── config_test.go │ │ ├── keyauth.go │ │ └── keyauth_test.go │ ├── limiter/ │ │ ├── config.go │ │ ├── limiter.go │ │ ├── limiter_fixed.go │ │ ├── limiter_sliding.go │ │ ├── limiter_test.go │ │ ├── manager.go │ │ ├── manager_msgp.go │ │ └── manager_msgp_test.go │ ├── logger/ │ │ ├── config.go │ │ ├── data.go │ │ ├── default_logger.go │ │ ├── errors.go │ │ ├── format.go │ │ ├── logger.go │ │ ├── logger_test.go │ │ ├── tags.go │ │ ├── template_chain.go │ │ └── utils.go │ ├── paginate/ │ │ ├── config.go │ │ ├── page_info.go │ │ ├── paginate.go │ │ └── paginate_test.go │ ├── pprof/ │ │ ├── config.go │ │ ├── pprof.go │ │ └── pprof_test.go │ ├── proxy/ │ │ ├── config.go │ │ ├── proxy.go │ │ └── proxy_test.go │ ├── recover/ │ │ ├── config.go │ │ ├── recover.go │ │ └── recover_test.go │ ├── redirect/ │ │ ├── config.go │ │ ├── redirect.go │ │ └── redirect_test.go │ ├── requestid/ │ │ ├── config.go │ │ ├── requestid.go │ │ └── requestid_test.go │ ├── responsetime/ │ │ ├── config.go │ │ ├── responsetime.go │ │ └── responsetime_test.go │ ├── rewrite/ │ │ ├── config.go │ │ ├── rewrite.go │ │ └── rewrite_test.go │ ├── session/ │ │ ├── config.go │ │ ├── config_test.go │ │ ├── data.go │ │ ├── data_msgp.go │ │ ├── data_msgp_test.go │ │ ├── data_test.go │ │ ├── middleware.go 
│ │ ├── middleware_test.go │ │ ├── session.go │ │ ├── session_test.go │ │ ├── store.go │ │ └── store_test.go │ ├── skip/ │ │ ├── skip.go │ │ └── skip_test.go │ ├── static/ │ │ ├── config.go │ │ ├── static.go │ │ └── static_test.go │ └── timeout/ │ ├── config.go │ ├── timeout.go │ └── timeout_test.go ├── mount.go ├── mount_test.go ├── path.go ├── path_test.go ├── path_testcases_test.go ├── prefork.go ├── prefork_test.go ├── readonly.go ├── readonly_strict.go ├── redirect.go ├── redirect_msgp.go ├── redirect_msgp_test.go ├── redirect_test.go ├── register.go ├── req.go ├── req_interface_gen.go ├── res.go ├── res_interface_gen.go ├── router.go ├── router_test.go ├── services.go ├── services_test.go ├── state.go ├── state_test.go └── storage_interface.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .cspell.json ================================================ { "version": "0.2", "language": "en, en-gb, en-us", "useGitignore": true, "caseSensitive": false, "import": [ "@cspell/dict-en_us/cspell-ext.json", "@cspell/dict-en-gb/cspell-ext.json", "@cspell/dict-software-terms/cspell-ext.json", "@cspell/dict-golang/cspell-ext.json", "@cspell/dict-fullstack/cspell-ext.json", "@cspell/dict-docker/cspell-ext.json", "@cspell/dict-k8s/cspell-ext.json", "@cspell/dict-node/cspell-ext.json", "@cspell/dict-npm/cspell-ext.json", "@cspell/dict-typescript/cspell-ext.json", "@cspell/dict-html/cspell-ext.json", "@cspell/dict-css/cspell-ext.json", "@cspell/dict-shell/cspell-ext.json", "@cspell/dict-python/cspell-ext.json", "@cspell/dict-redis/cspell-ext.json", "@cspell/dict-sql/cspell-ext.json", "@cspell/dict-filetypes/cspell-ext.json", "@cspell/dict-companies/cspell-ext.json", "@cspell/dict-markdown/cspell-ext.json", "@cspell/dict-en-common-misspellings/cspell-ext.json", "@cspell/dict-people-names/cspell-ext.json" ], "dictionaries": [ "en_us", "en-gb", "softwareTerms", 
"web-services", "networking-terms", "software-term-suggestions", "software-services", "software-terms", "software-tools", "coding-compound-terms", "golang", "fullstack", "docker", "k8s", "node", "npm", "typescript", "html", "css", "shell", "python", "redis", "sql", "filetypes", "companies", "markdown", "en-common-misspellings", "people-names", "data-science", "data-science-models", "data-science-tools" ], "ignorePaths": [ ".git", "node_modules", "vendor", "internal", ".github", "**/*.svg", "**/*.png", "**/*.jpg", "**/*.jpeg", "**/*.gif", "**/*.ico", "**/*.lock", "**/*_gen.go", "**/*_msgp.go", "**/*_msgp_test.go", "go.mod", "go.sum", ".golangci.yml", ".markdownlint.yml", "AGENTS.md" ] } ================================================ FILE: .editorconfig ================================================ ; This file is for unifying the coding style for different editors and IDEs. ; More information at http://editorconfig.org ; This style originates from https://github.com/fewagency/best-practices root = true [*] charset = utf-8 end_of_line = lf insert_final_newline = true trim_trailing_whitespace = true [*.go] indent_style = tab indent_size = 4 tab_width = 4 [Makefile] indent_style = tab [*.{yml,yaml,json,md}] indent_style = space indent_size = 2 ================================================ FILE: .gitattributes ================================================ # Handle line endings automatically for files detected as text # and leave all files detected as binary untouched. * text=auto eol=lf # Force batch scripts to always use CRLF line endings so that if a repo is accessed # in Windows via a file share from Linux, the scripts will work. *.{cmd,[cC][mM][dD]} text eol=crlf *.{bat,[bB][aA][tT]} text eol=crlf # Force bash scripts to always use LF line endings so that if a repo is accessed # in Unix via a file share from Windows, the scripts will work. 
*.sh text eol=lf ================================================ FILE: .github/.editorconfig ================================================ ; https://editorconfig.org/ root = true [*] insert_final_newline = true charset = utf-8 trim_trailing_whitespace = true indent_style = space indent_size = 2 [{Makefile,go.mod,go.sum,*.go,.gitmodules}] indent_style = tab indent_size = 8 [*.md] indent_size = 4 trim_trailing_whitespace = false eclint_indent_style = unset [Dockerfile] indent_size = 4 ================================================ FILE: .github/.hound.yml ================================================ golint: enabled: false ================================================ FILE: .github/CODEOWNERS ================================================ * @gofiber/maintainers ================================================ FILE: .github/CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 
## Our Standards Examples of behavior that contributes to a positive environment for our community include: - Demonstrating empathy and kindness toward other people - Being respectful of differing opinions, viewpoints, and experiences - Giving and gracefully accepting constructive feedback - Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience - Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: - The use of sexualized language or imagery, and sexual attention or advances of any kind - Trolling, insulting or derogatory comments, and personal or political attacks - Public or private harassment - Publishing others' private information, such as a physical or email address, without their explicit permission - Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at [Discord](https://gofiber.io/discord). All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. 
Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html](https://www.contributor-covenant.org/version/2/0/code_of_conduct.html). Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq). Translations are available at [https://www.contributor-covenant.org/translations](https://www.contributor-covenant.org/translations). ================================================ FILE: .github/CONTRIBUTING.md ================================================ # Contributing Before making any changes to this repository, we kindly ask you to start a discussion about any proposed change that does not yet have an associated [issue](https://github.com/gofiber/fiber/issues). Please use our [Discord](https://gofiber.io/discord) server for these discussions. For [issues](https://github.com/gofiber/fiber/issues) that already exist, you may continue the discussion in our [issue](https://github.com/gofiber/fiber/issues) tracker or through any other suitable channel, in consultation with the repository owners. Your collaboration is greatly appreciated. Please note that we have a [code of conduct](https://github.com/gofiber/fiber/blob/main/.github/CODE_OF_CONDUCT.md); please follow it in all your interactions with the `Fiber` project.
## Pull Requests or Commits Titles must always use one of the following prefixes: > 🔥 Feature, ♻️ Refactor, 🩹 Fix, 🚨 Test, 📚 Doc, 🎨 Style - 🔥 Feature: Add flow to add person - ♻️ Refactor: Rename file X to Y - 🩹 Fix: Improve flow - 🚨 Test: Validate to add a new person - 📚 Doc: Translate to Portuguese middleware redirect - 🎨 Style: Respected pattern Golint All pull requests that contain a feature or fix must include unit tests. Your PR will only be merged if you follow this flow. ## 👍 Contribute If you want to say **thank you** and/or support the active development of `Fiber`: 1. Add a [GitHub Star](https://github.com/gofiber/fiber/stargazers) to the project. 2. Tweet about the project [on your 𝕏 (Twitter)](https://x.com/intent/tweet?text=%F0%9F%9A%80%20Fiber%20%E2%80%94%20is%20an%20Express.js%20inspired%20web%20framework%20built%20on%20Fasthttp%20for%20%23Go%20https%3A%2F%2Fgithub.com%2Fgofiber%2Ffiber). 3. Write a review or tutorial on [Medium](https://medium.com/), [Dev.to](https://dev.to/) or a personal blog. 4. Support the project by donating a [cup of coffee](https://buymeacoff.ee/fenny). ================================================ FILE: .github/FUNDING.yml ================================================ # These are supported funding model platforms github: [gofiber] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] patreon: # Replace with a single Patreon username open_collective: # Replace with a single Open Collective username ko_fi: # Replace with a single Ko-fi username tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel custom: https://github.com/sponsors/gofiber ================================================ FILE: .github/ISSUE_TEMPLATE/bug-report.yaml ================================================ name: "\U0001F41B Bug Report" title: "\U0001F41B [Bug]: " description: Create a bug report to help us fix it.
labels: ["☢️ Bug"] body: - type: markdown id: notice attributes: value: | ### Notice **This repository is not related to external or third-party Fiber modules. If you have a problem with them, open an issue under their repos. If you think the problem is related to Fiber, open the issue here.** - Don't forget you can ask your questions in our [Discord server](https://gofiber.io/discord). - If you have a suggestion for a Fiber feature you would like to see, open the issue with the **📝 Feature Proposal** template. - Write your issue in clear and understandable English. - type: textarea id: description attributes: label: "Bug Description" description: "A clear and detailed description of what the bug is." placeholder: "Explain your problem clearly and in detail." validations: required: true - type: textarea id: how-to-reproduce attributes: label: How to Reproduce description: "Steps to reproduce the behavior and what should be observed in the end." placeholder: "Tell us step by step how we can replicate your problem and what we should see in the end." value: | Steps to reproduce the behavior: 1. Go to '....' 2. Click on '....' 3. Do '....' 4. See '....' validations: required: true - type: textarea id: expected-behavior attributes: label: Expected Behavior description: "A clear and detailed description of what you think should happen." placeholder: "Tell us what Fiber should normally do." validations: required: true - type: input id: version attributes: label: "Fiber Version" description: "Some bugs may be fixed in future Fiber releases, so we have to know your Fiber version." placeholder: "Write your Fiber version. (v2.33.0, v2.34.1...)" validations: required: true - type: textarea id: snippet attributes: label: "Code Snippet (optional)" description: "For some issues, we need to know some parts of your code." placeholder: "Share a code snippet that you think is related to the issue."
render: go value: | package main import "github.com/gofiber/fiber/v3" import "log" func main() { app := fiber.New() // Steps to reproduce log.Fatal(app.Listen(":3000")) } - type: checkboxes id: terms attributes: label: "Checklist:" description: "By submitting this issue, you confirm that:" options: - label: "I agree to follow Fiber's [Code of Conduct](https://github.com/gofiber/fiber/blob/main/.github/CODE_OF_CONDUCT.md)." required: true - label: "I have checked for existing issues that describe my problem prior to opening this one." required: true - label: "I understand that improperly formatted bug reports may be closed without explanation." required: true ================================================ FILE: .github/ISSUE_TEMPLATE/config.yml ================================================ blank_issues_enabled: false ================================================ FILE: .github/ISSUE_TEMPLATE/feature-request.yaml ================================================ name: "📝 Feature Proposal" title: "📝 [Proposal]: " description: Propose a feature or improvement for Fiber. labels: ["📝 Proposal", "✏️ Feature", "v3"] body: - type: markdown id: notice attributes: value: | ### Notice - For questions, join our [Discord server](https://gofiber.io/discord). - Please write in clear, understandable English. - Ensure your proposal aligns with Express design principles and HTTP RFC standards. - Describe features expected to remain stable and not require changes in the foreseeable future. - type: textarea id: description attributes: label: "Feature Proposal Description" description: "A clear and detailed description of the feature you are proposing for Fiber v3. How should it work, and what API endpoints and methods would it involve?" placeholder: "Describe your feature proposal clearly and in detail, including API endpoints and methods." 
validations: required: true - type: textarea id: express-alignment attributes: label: "Alignment with Express API" description: "Explain how your proposal aligns with the design and API of Express.js. Provide comparative examples if possible." placeholder: "Outline how the feature aligns with Express.js design principles and API standards." validations: required: true - type: textarea id: standards-compliance attributes: label: "HTTP RFC Standards Compliance" description: "Confirm that the feature complies with HTTP RFC standards, and describe any relevant aspects." placeholder: "Detail how the feature adheres to HTTP RFC standards." validations: required: true - type: textarea id: stability attributes: label: "API Stability" description: "Discuss the expected stability of the feature and its API. How do you ensure that it will not require changes or deprecations in the near future?" placeholder: "Describe measures taken to ensure the feature's API stability over time." validations: required: true - type: textarea id: examples attributes: label: "Feature Examples" description: "Provide concrete examples and code snippets to illustrate how the proposed feature should function." placeholder: "Share code snippets that exemplify the proposed feature and its usage." render: go validations: required: true - type: checkboxes id: terms attributes: label: "Checklist:" description: "By submitting this issue, you confirm that:" options: - label: "I agree to follow Fiber's [Code of Conduct](https://github.com/gofiber/fiber/blob/main/.github/CODE_OF_CONDUCT.md)." required: true - label: "I have searched for existing issues that describe my proposal before opening this one." required: true - label: "I understand that a proposal that does not meet these guidelines may be closed without explanation." 
required: true ================================================ FILE: .github/ISSUE_TEMPLATE/maintenance-task.yaml ================================================ name: "🧹 Maintenance Task" title: "🧹 [Maintenance]: " description: Describe a maintenance task for the v3 of the Fiber project. labels: ["🧹 Updates", "v3"] body: - type: markdown id: notice attributes: value: | ### Notice - Before submitting a maintenance task, please check if a similar task has already been filed. - Clearly outline the purpose of the maintenance task and its impact on the project. - Use clear and understandable English. - type: textarea id: task-description attributes: label: "Maintenance Task Description" description: "Provide a detailed description of the maintenance task. Include any specific areas of the codebase that require attention, and the desired outcomes of this task." placeholder: "Detail the maintenance task, specifying what needs to be done and why it is necessary." validations: required: true - type: textarea id: impact attributes: label: "Impact on the Project" description: "Explain the impact this maintenance will have on the project. Include benefits and potential risks if applicable." placeholder: "Describe how completing this task will benefit the project, or the risks of not addressing it." validations: required: false - type: textarea id: additional-context attributes: label: "Additional Context (optional)" description: "Any additional information or context regarding the maintenance task that might be helpful." placeholder: "Provide any additional information that may be relevant to the task at hand." validations: required: false - type: checkboxes id: terms attributes: label: "Checklist:" description: "Please confirm the following:" options: - label: "I have confirmed that this maintenance task is currently not being addressed." required: true - label: "I understand that this task will be evaluated by the maintainers and prioritized accordingly." 
required: true - label: "I am available to provide further information if needed." required: true ================================================ FILE: .github/ISSUE_TEMPLATE/question.yaml ================================================ name: "🤔 Question" title: "\U0001F917 [Question]: " description: Ask a question so we can help you easily. labels: ["🤔 Question"] body: - type: markdown id: notice attributes: value: | ### Notice - Don't forget you can ask your questions in our [Discord server](https://gofiber.io/discord). - If you think this is just a bug, open the issue with the **🐛 Bug Report** template. - If you have a suggestion for a Fiber feature you would like to see, open the issue with the **📝 Feature Proposal** template. - Write your issue in clear and understandable English. - type: textarea id: description attributes: label: "Question Description" description: "A clear and detailed description of the question." placeholder: "Explain your question clearly, and in detail." validations: required: true - type: textarea id: snippet attributes: label: "Code Snippet (optional)" description: "Code snippet may be really helpful to describe some features." placeholder: "Share a code snippet to explain the feature better." render: go value: | package main import "github.com/gofiber/fiber/v3" import "log" func main() { app := fiber.New() // An example to describe the question log.Fatal(app.Listen(":3000")) } - type: checkboxes id: terms attributes: label: "Checklist:" description: "By submitting this issue, you confirm that:" options: - label: "I agree to follow Fiber's [Code of Conduct](https://github.com/gofiber/fiber/blob/main/.github/CODE_OF_CONDUCT.md)." required: true - label: "I have checked for existing issues that describe my questions prior to opening this one." required: true - label: "I understand that improperly formatted questions may be closed without explanation."
required: true ================================================ FILE: .github/README.md ================================================

Fiber
Codecov

Fiber is an Express-inspired web framework built on top of Fasthttp, the fastest HTTP engine for Go. It is designed to make development fast and easy, with zero memory allocation and performance in mind.

--- ## ⚙️ Installation Fiber requires **Go version `1.25` or higher** to run. If you need to install or upgrade Go, visit the [official Go download page](https://go.dev/dl/). To start setting up your project, create a new directory for your project and navigate into it. Then, initialize your project with Go modules by executing the following command in your terminal: ```bash go mod init github.com/your/repo ``` To learn more about Go modules and how they work, you can check out the [Using Go Modules](https://go.dev/blog/using-go-modules) blog post. After setting up your project, you can install Fiber with the `go get` command: ```bash go get -u github.com/gofiber/fiber/v3 ``` This command fetches the Fiber package and adds it to your project's dependencies, allowing you to start building your web applications with Fiber. ## ⚡️ Quickstart Getting started with Fiber is easy. Here's a basic example to create a simple web server that responds with "Hello, World 👋!" on the root path. This example demonstrates initializing a new Fiber app, setting up a route, and starting the server. ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" ) func main() { // Initialize a new Fiber app app := fiber.New() // Define a route for the GET method on the root path '/' app.Get("/", func(c fiber.Ctx) error { // Send a string response to the client return c.SendString("Hello, World 👋!") }) // Start the server on port 3000 log.Fatal(app.Listen(":3000")) } ``` This simple server is easy to set up and run. It introduces the core concepts of Fiber: app initialization, route definition, and starting the server. Just run this Go program, and visit `http://localhost:3000` in your browser to see the message. ## Zero Allocation Fiber is optimized for **high-performance**, meaning values returned from **fiber.Ctx** are **not** immutable by default and **will** be re-used across requests. 
As a rule of thumb, you **must** only use context values within the handler and **must not** keep any references. Once you return from the handler, any values obtained from the context will be re-used in future requests. Visit our [documentation](https://docs.gofiber.io/#zero-allocation) to learn more. ## 🤖 Benchmarks These tests are performed by [TechEmpower](https://www.techempower.com/benchmarks/#section=data-r19&hw=ph&test=plaintext). If you want to see all the results, please visit our [Wiki](https://docs.gofiber.io/extra/benchmarks).

## 🎯 Features - Robust [Routing](https://docs.gofiber.io/guide/routing) - Serve [Static Files](https://docs.gofiber.io/api/app#static) - Extreme [Performance](https://docs.gofiber.io/extra/benchmarks) - [Low Memory](https://docs.gofiber.io/extra/benchmarks) footprint - [API Endpoints](https://docs.gofiber.io/api/ctx) - [Middleware](https://docs.gofiber.io/category/-middleware) & [Next](https://docs.gofiber.io/api/ctx#next) support - [Rapid](https://dev.to/koddr/welcome-to-fiber-an-express-js-styled-fastest-web-framework-written-with-on-golang-497) server-side programming - [Template Engines](https://github.com/gofiber/template) - [WebSocket Support](https://github.com/gofiber/contrib/tree/main/websocket) - [Socket.io Support](https://github.com/gofiber/contrib/tree/main/socketio) - [Server-Sent Events](https://github.com/gofiber/recipes/tree/master/sse) - [Rate Limiter](https://docs.gofiber.io/api/middleware/limiter) - And much more, [explore Fiber](https://docs.gofiber.io/) ## 💡 Philosophy New gophers who make the switch from [Node.js](https://nodejs.org/en/about/) to [Go](https://go.dev/doc/) face a learning curve before they can start building their web applications or microservices. Fiber, as a **web framework**, was created with the idea of **minimalism** and follows the **UNIX way**, so that new gophers can quickly enter the world of Go with a warm and trusted welcome. Fiber is **inspired** by Express, the most popular web framework on the Internet. We combined the **ease** of Express and **raw performance** of Go. If you have ever implemented a web application in Node.js (_using Express or similar_), then many methods and principles will seem **very familiar** to you. We **listen** to our users in [issues](https://github.com/gofiber/fiber/issues), Discord [channel](https://gofiber.io/discord) _and all over the Internet_ to create a **fast**, **flexible** and **friendly** Go web framework for **any** task, **deadline** and developer **skill**! 
Just like Express does in the JavaScript world. ## ⚠️ Limitations - Due to Fiber's usage of unsafe, the library may not always be compatible with the latest Go version. Fiber v3 has been tested with Go version 1.25 or higher. - Fiber automatically adapts common `net/http` handler shapes when you register them on the router, and you can still use the [adaptor middleware](https://docs.gofiber.io/next/middleware/adaptor/) when you need to bridge entire apps or `net/http` middleware. ### net/http compatibility Fiber can run side by side with the standard library. The router accepts existing `net/http` handlers directly and even works with native `fasthttp.RequestHandler` callbacks, so you can plug in legacy endpoints without wrapping them manually: ```go package main import ( "log" "net/http" "github.com/gofiber/fiber/v3" ) func main() { httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("served by net/http")); err != nil { panic(err) } }) app := fiber.New() app.Get("/", httpHandler) // Start the server on port 3000 log.Fatal(app.Listen(":3000")) } ``` When you need to convert entire applications or re-use `net/http` middleware chains, rely on the [adaptor middleware](https://docs.gofiber.io/next/middleware/adaptor/). It converts handlers and middlewares in both directions and even lets you mount a Fiber app in a `net/http` server. ### Express-style handlers Fiber also adapts Express-style callbacks that operate on the lightweight `fiber.Req` and `fiber.Res` helper interfaces. 
This lets you port middleware and route handlers from Express-inspired codebases while keeping Fiber's router features: ```go // Request/response handlers (2-argument) app.Get("/", func(req fiber.Req, res fiber.Res) error { return res.SendString("Hello from Express-style handlers!") }) // Middleware with an error-returning next callback (3-argument) app.Use(func(req fiber.Req, res fiber.Res, next func() error) error { if req.IP() == "192.168.1.254" { return res.SendStatus(fiber.StatusForbidden) } return next() }) // Middleware with a no-arg next callback (3-argument) app.Use(func(req fiber.Req, res fiber.Res, next func()) { if req.Get("X-Skip") == "true" { return // stop the chain without calling next } next() }) ``` > **Note:** Adapted `net/http` handlers continue to operate with the standard-library semantics. They don't get access to `fiber.Ctx` features and incur the overhead of the compatibility layer, so native `fiber.Handler` callbacks still provide the best performance. ## 👀 Examples Listed below are some of the common examples. If you want to see more code examples, please visit our [Recipes repository](https://github.com/gofiber/recipes) or visit our hosted [API documentation](https://docs.gofiber.io). 
### 📖 [**Basic Routing**](https://docs.gofiber.io/#basic-routing) ```go title="Example" package main import ( "fmt" "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() // GET /api/register app.Get("/api/*", func(c fiber.Ctx) error { msg := fmt.Sprintf("✋ %s", c.Params("*")) return c.SendString(msg) // => ✋ register }) // GET /flights/LAX-SFO app.Get("/flights/:from-:to", func(c fiber.Ctx) error { msg := fmt.Sprintf("💸 From: %s, To: %s", c.Params("from"), c.Params("to")) return c.SendString(msg) // => 💸 From: LAX, To: SFO }) // GET /dictionary.txt app.Get("/:file.:ext", func(c fiber.Ctx) error { msg := fmt.Sprintf("📃 %s.%s", c.Params("file"), c.Params("ext")) return c.SendString(msg) // => 📃 dictionary.txt }) // GET /john/75 app.Get("/:name/:age/:gender?", func(c fiber.Ctx) error { msg := fmt.Sprintf("👴 %s is %s years old", c.Params("name"), c.Params("age")) return c.SendString(msg) // => 👴 john is 75 years old }) // GET /john app.Get("/:name", func(c fiber.Ctx) error { msg := fmt.Sprintf("Hello, %s 👋!", c.Params("name")) return c.SendString(msg) // => Hello, john 👋!
}) log.Fatal(app.Listen(":3000")) } ``` #### 📖 [**Route Naming**](https://docs.gofiber.io/api/app#name) ```go title="Example" package main import ( "encoding/json" "fmt" "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Get("/api/*", func(c fiber.Ctx) error { msg := fmt.Sprintf("✋ %s", c.Params("*")) return c.SendString(msg) // => ✋ register }).Name("api") route := app.GetRoute("api") data, _ := json.MarshalIndent(route, "", " ") fmt.Println(string(data)) // Prints: // { // "method": "GET", // "name": "api", // "path": "/api/*", // "params": [ // "*1" // ] // } log.Fatal(app.Listen(":3000")) } ``` #### 📖 [**Serving Static Files**](https://docs.gofiber.io/api/app#static) ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/static" ) func main() { app := fiber.New() // Serve static files from the "./public" directory app.Get("/*", static.New("./public")) // => http://localhost:3000/js/script.js // => http://localhost:3000/css/style.css app.Get("/prefix*", static.New("./public")) // => http://localhost:3000/prefix/js/script.js // => http://localhost:3000/prefix/css/style.css // Serve a single file for any unmatched routes app.Get("*", static.New("./public/index.html")) // => http://localhost:3000/any/path/shows/index.html log.Fatal(app.Listen(":3000")) } ``` #### 📖 [**Middleware & Next**](https://docs.gofiber.io/api/ctx#next) ```go title="Example" package main import ( "fmt" "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() // Middleware that matches any route app.Use(func(c fiber.Ctx) error { fmt.Println("🥇 First handler") return c.Next() }) // Middleware that matches all routes starting with /api app.Use("/api", func(c fiber.Ctx) error { fmt.Println("🥈 Second handler") return c.Next() }) // GET /api/list app.Get("/api/list", func(c fiber.Ctx) error { fmt.Println("🥉 Last handler") return c.SendString("Hello, World 👋!") }) log.Fatal(app.Listen(":3000")) } 
```
📚 Show more code examples ### Views Engines 📖 [Config](https://docs.gofiber.io/api/fiber#config) 📖 [Engines](https://github.com/gofiber/template) 📖 [Render](https://docs.gofiber.io/api/ctx#render) Fiber defaults to the [html/template](https://pkg.go.dev/html/template/) when no view engine is set. If you want to execute partials or use a different engine like [amber](https://github.com/eknkc/amber), [handlebars](https://github.com/aymerick/raymond), [mustache](https://github.com/cbroglie/mustache), or [pug](https://github.com/Joker/jade), etc., check out our [Template](https://github.com/gofiber/template) package that supports multiple view engines. ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" "github.com/gofiber/template/pug" ) func main() { // Initialize a new Fiber app with Pug template engine app := fiber.New(fiber.Config{ Views: pug.New("./views", ".pug"), }) // Define a route that renders the "home.pug" template app.Get("/", func(c fiber.Ctx) error { return c.Render("home", fiber.Map{ "title": "Homepage", "year": 1999, }) }) log.Fatal(app.Listen(":3000")) } ``` ### Grouping Routes into Chains 📖 [Group](https://docs.gofiber.io/api/app#group) ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" ) func middleware(c fiber.Ctx) error { log.Println("Middleware executed") return c.Next() } func handler(c fiber.Ctx) error { return c.SendString("Handler response") } func main() { app := fiber.New() // Root API group with middleware api := app.Group("/api", middleware) // /api // API v1 routes v1 := api.Group("/v1", middleware) // /api/v1 v1.Get("/list", handler) // /api/v1/list v1.Get("/user", handler) // /api/v1/user // API v2 routes v2 := api.Group("/v2", middleware) // /api/v2 v2.Get("/list", handler) // /api/v2/list v2.Get("/user", handler) // /api/v2/user log.Fatal(app.Listen(":3000")) } ``` ### Middleware Logger 📖 [Logger](https://docs.gofiber.io/api/middleware/logger) ```go title="Example" package 
main import ( "log" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/logger" ) func main() { app := fiber.New() // Use Logger middleware app.Use(logger.New()) // Define routes app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, Logger!") }) log.Fatal(app.Listen(":3000")) } ``` ### Cross-Origin Resource Sharing (CORS) 📖 [CORS](https://docs.gofiber.io/api/middleware/cors) ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/cors" ) func main() { app := fiber.New() // Use CORS middleware with default settings app.Use(cors.New()) // Define routes app.Get("/", func(c fiber.Ctx) error { return c.SendString("CORS enabled!") }) log.Fatal(app.Listen(":3000")) } ``` Check CORS by passing any domain in `Origin` header: ```bash curl -H "Origin: http://example.com" --verbose http://localhost:3000 ``` ### Custom 404 Response 📖 [HTTP Methods](https://docs.gofiber.io/api/ctx#status) ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/static" ) func main() { app := fiber.New() // Define routes app.Get("/", static.New("./public")) app.Get("/demo", func(c fiber.Ctx) error { return c.SendString("This is a demo page!") }) app.Post("/register", func(c fiber.Ctx) error { return c.SendString("Registration successful!") }) // Middleware to handle 404 Not Found app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusNotFound) // => 404 "Not Found" }) log.Fatal(app.Listen(":3000")) } ``` ### JSON Response 📖 [JSON](https://docs.gofiber.io/api/ctx#json) ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" ) type User struct { Name string `json:"name"` Age int `json:"age"` } func main() { app := fiber.New() // Route that returns a JSON object app.Get("/user", func(c fiber.Ctx) error { return c.JSON(&User{"John", 20}) // => {"name":"John", "age":20} }) // Route that returns a JSON map app.Get("/json", func(c fiber.Ctx) error { return 
c.JSON(fiber.Map{ "success": true, "message": "Hi John!", }) // => {"success":true, "message":"Hi John!"} }) log.Fatal(app.Listen(":3000")) } ``` ### WebSocket Upgrade 📖 [Websocket](https://github.com/gofiber/websocket) ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/websocket" ) func main() { app := fiber.New() // WebSocket route app.Get("/ws", websocket.New(func(c *websocket.Conn) { defer c.Close() for { // Read message from client mt, msg, err := c.ReadMessage() if err != nil { log.Println("read:", err) break } log.Printf("recv: %s", msg) // Write message back to client err = c.WriteMessage(mt, msg) if err != nil { log.Println("write:", err) break } } })) log.Fatal(app.Listen(":3000")) // Connect via WebSocket at ws://localhost:3000/ws } ``` ### Server-Sent Events 📖 [More Info](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events) ```go title="Example" package main import ( "bufio" "fmt" "log" "time" "github.com/gofiber/fiber/v3" "github.com/valyala/fasthttp" ) func main() { app := fiber.New() // Server-Sent Events route app.Get("/sse", func(c fiber.Ctx) error { c.Set("Content-Type", "text/event-stream") c.Set("Cache-Control", "no-cache") c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") c.RequestCtx().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { var i int for { i++ msg := fmt.Sprintf("%d - the time is %v", i, time.Now()) fmt.Fprintf(w, "data: Message: %s\n\n", msg) fmt.Println(msg) if err := w.Flush(); err != nil { // Stop streaming once the client has disconnected return } time.Sleep(5 * time.Second) } })) return nil }) log.Fatal(app.Listen(":3000")) } ``` ### Recover Middleware 📖 [Recover](https://docs.gofiber.io/api/middleware/recover) ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/recover" ) func main() { app := fiber.New() // Use Recover middleware to handle panics gracefully app.Use(recover.New()) // Route that intentionally panics 
app.Get("/", func(c fiber.Ctx) error { panic("normally this would crash your app") }) log.Fatal(app.Listen(":3000")) } ``` ### Using Trusted Proxy 📖 [Config](https://docs.gofiber.io/api/fiber#config) ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New(fiber.Config{ // Configure trusted proxies - WARNING: Only trust proxies you control // Using TrustProxy: true with unrestricted IPs can lead to IP spoofing TrustProxy: true, TrustProxyConfig: fiber.TrustProxyConfig{ Proxies: []string{"10.0.0.0/8", "172.16.0.0/12"}, // Example: Internal network ranges only }, ProxyHeader: fiber.HeaderXForwardedFor, }) // Define routes app.Get("/", func(c fiber.Ctx) error { return c.SendString("Trusted Proxy Configured!") }) log.Fatal(app.Listen(":3000")) } ```
## 🧬 Internal Middleware Here is a list of middleware that are included within the Fiber framework. | Middleware | Description | |--------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | [adaptor](https://github.com/gofiber/fiber/tree/main/middleware/adaptor) | Converter for net/http handlers to/from Fiber request handlers. | | [basicauth](https://github.com/gofiber/fiber/tree/main/middleware/basicauth) | Provides HTTP basic authentication. It calls the next handler for valid credentials and 401 Unauthorized for missing or invalid credentials. | | [cache](https://github.com/gofiber/fiber/tree/main/middleware/cache) | Intercept and cache HTTP responses. | | [compress](https://github.com/gofiber/fiber/tree/main/middleware/compress) | Compression middleware for Fiber, with support for `deflate`, `gzip`, `brotli` and `zstd`. | | [cors](https://github.com/gofiber/fiber/tree/main/middleware/cors) | Enable cross-origin resource sharing (CORS) with various options. | | [csrf](https://github.com/gofiber/fiber/tree/main/middleware/csrf) | Protect from CSRF exploits. | | [earlydata](https://github.com/gofiber/fiber/tree/main/middleware/earlydata) | Adds support for TLS 1.3's early data ("0-RTT") feature. | | [encryptcookie](https://github.com/gofiber/fiber/tree/main/middleware/encryptcookie) | Encrypt middleware which encrypts cookie values. | | [envvar](https://github.com/gofiber/fiber/tree/main/middleware/envvar) | Expose environment variables with providing an optional config. | | [etag](https://github.com/gofiber/fiber/tree/main/middleware/etag) | Allows for caches to be more efficient and save bandwidth, as a web server does not need to resend a full response if the content has not changed. 
| | [expvar](https://github.com/gofiber/fiber/tree/main/middleware/expvar) | Serves runtime-exposed variables in JSON format via its HTTP server. | | [favicon](https://github.com/gofiber/fiber/tree/main/middleware/favicon) | Ignore favicon from logs or serve from memory if a file path is provided. | | [healthcheck](https://github.com/gofiber/fiber/tree/main/middleware/healthcheck) | Liveness and Readiness probes for Fiber. | | [helmet](https://github.com/gofiber/fiber/tree/main/middleware/helmet) | Helps secure your apps by setting various HTTP headers. | | [idempotency](https://github.com/gofiber/fiber/tree/main/middleware/idempotency) | Allows for fault-tolerant APIs where duplicate requests do not erroneously cause the same action to be performed multiple times on the server side. | | [keyauth](https://github.com/gofiber/fiber/tree/main/middleware/keyauth) | Adds support for key-based authentication. | | [limiter](https://github.com/gofiber/fiber/tree/main/middleware/limiter) | Adds rate-limiting support to Fiber. Use to limit repeated requests to public APIs and/or endpoints such as password reset. | | [logger](https://github.com/gofiber/fiber/tree/main/middleware/logger) | HTTP request/response logger. | | [paginate](https://github.com/gofiber/fiber/tree/main/middleware/paginate) | Extracts pagination parameters from query strings. Supports page-based, offset-based, and cursor-based pagination with multi-field sorting. | | [pprof](https://github.com/gofiber/fiber/tree/main/middleware/pprof) | Serves runtime profiling data in pprof format. | | [proxy](https://github.com/gofiber/fiber/tree/main/middleware/proxy) | Allows you to proxy requests to multiple servers. | | [recover](https://github.com/gofiber/fiber/tree/main/middleware/recover) | Recovers from panics anywhere in the stack chain and hands control to the centralized ErrorHandler. | | [redirect](https://github.com/gofiber/fiber/tree/main/middleware/redirect) | Redirect middleware. 
| | [requestid](https://github.com/gofiber/fiber/tree/main/middleware/requestid) | Adds a request ID to every request. | | [responsetime](https://github.com/gofiber/fiber/tree/main/middleware/responsetime) | Measures request handling duration and writes it to a configurable response header. | | [rewrite](https://github.com/gofiber/fiber/tree/main/middleware/rewrite) | Rewrites the URL path based on provided rules. It can be helpful for backward compatibility or just creating cleaner and more descriptive links. | | [session](https://github.com/gofiber/fiber/tree/main/middleware/session) | Session middleware. NOTE: This middleware uses our Storage package. | | [skip](https://github.com/gofiber/fiber/tree/main/middleware/skip) | Skip middleware that skips a wrapped handler if a predicate is true. | | [static](https://github.com/gofiber/fiber/tree/main/middleware/static) | Static middleware for Fiber that serves static files such as **images**, **CSS**, and **JavaScript**. | | [timeout](https://github.com/gofiber/fiber/tree/main/middleware/timeout) | Adds a max time for a request and forwards to ErrorHandler if it is exceeded. | ## 🧬 External Middleware List of externally hosted middleware modules and maintained by the [Fiber team](https://github.com/orgs/gofiber/people). | Middleware | Description | | :------------------------------------------------ | :-------------------------------------------------------------------------------------------------------------------- | | [contrib](https://github.com/gofiber/contrib) | Third-party middlewares | | [storage](https://github.com/gofiber/storage) | Premade storage drivers that implement the Storage interface, designed to be used with various Fiber middlewares. | | [template](https://github.com/gofiber/template) | This package contains 9 template engines that can be used with Fiber. | ## 🕶️ Awesome List For more articles, middlewares, examples, or tools, check our [awesome list](https://github.com/gofiber/awesome-fiber). 
## 👍 Contribute If you want to say **Thank You** and/or support the active development of `Fiber`: 1. Add a [GitHub Star](https://github.com/gofiber/fiber/stargazers) to the project. 2. Tweet about the project [on your 𝕏 (Twitter)](https://x.com/intent/tweet?text=Fiber%20is%20an%20Express%20inspired%20%23web%20%23framework%20built%20on%20top%20of%20Fasthttp%2C%20the%20fastest%20HTTP%20engine%20for%20%23Go.%20Designed%20to%20ease%20things%20up%20for%20%23fast%20development%20with%20zero%20memory%20allocation%20and%20%23performance%20in%20mind%20%F0%9F%9A%80%20https%3A%2F%2Fgithub.com%2Fgofiber%2Ffiber). 3. Write a review or tutorial on [Medium](https://medium.com/), [Dev.to](https://dev.to/) or your personal blog. 4. Support the project by donating a [cup of coffee](https://buymeacoff.ee/fenny). ## 💻 Development To ensure your contributions are ready for a Pull Request, please use the following `Makefile` commands. These tools help maintain code quality and consistency. - **make help**: Display available commands. - **make audit**: Conduct quality checks. - **make benchmark**: Benchmark code performance. - **make coverage**: Generate test coverage report. - **make format**: Automatically format code. - **make lint**: Run lint checks. - **make test**: Execute all tests. - **make tidy**: Tidy dependencies. Run these commands to ensure your code adheres to project standards and best practices. ## ☕ Supporters Fiber is an open-source project that runs on donations to pay the bills, e.g., our domain name, GitBook, Netlify, and serverless hosting. If you want to support Fiber, you can ☕ [**buy a coffee here**](https://buymeacoff.ee/fenny). 
| | User | Donation | | ---------------------------------------------------------- | ------------------------------------------------ | -------- | | ![](https://avatars.githubusercontent.com/u/204341?s=25) | [@destari](https://github.com/destari) | ☕ x 10 | | ![](https://avatars.githubusercontent.com/u/63164982?s=25) | [@dembygenesis](https://github.com/dembygenesis) | ☕ x 5 | | thomasvvugt | [@thomasvvugt](https://github.com/thomasvvugt) | ☕ x 5 | | ![](https://avatars.githubusercontent.com/u/27820675?s=25) | [@hendratommy](https://github.com/hendratommy) | ☕ x 5 | | ![](https://avatars.githubusercontent.com/u/1094221?s=25) | [@ekaputra07](https://github.com/ekaputra07) | ☕ x 5 | | ![](https://avatars.githubusercontent.com/u/194590?s=25) | [@jorgefuertes](https://github.com/jorgefuertes) | ☕ x 5 | | ![](https://avatars.githubusercontent.com/u/186637?s=25) | [@candidosales](https://github.com/candidosales) | ☕ x 5 | | ![](https://avatars.githubusercontent.com/u/29659953?s=25) | [@l0nax](https://github.com/l0nax) | ☕ x 3 | | ![](https://avatars.githubusercontent.com/u/635852?s=25) | [@bihe](https://github.com/bihe) | ☕ x 3 | | ![](https://avatars.githubusercontent.com/u/307334?s=25) | [@justdave](https://github.com/justdave) | ☕ x 3 | | ![](https://avatars.githubusercontent.com/u/11155743?s=25) | [@koddr](https://github.com/koddr) | ☕ x 1 | | ![](https://avatars.githubusercontent.com/u/29042462?s=25) | [@lapolinar](https://github.com/lapolinar) | ☕ x 1 | | ![](https://avatars.githubusercontent.com/u/2978730?s=25) | [@diegowifi](https://github.com/diegowifi) | ☕ x 1 | | ![](https://avatars.githubusercontent.com/u/44171355?s=25) | [@ssimk0](https://github.com/ssimk0) | ☕ x 1 | | ![](https://avatars.githubusercontent.com/u/5638101?s=25) | [@raymayemir](https://github.com/raymayemir) | ☕ x 1 | | ![](https://avatars.githubusercontent.com/u/619996?s=25) | [@melkorm](https://github.com/melkorm) | ☕ x 1 | | ![](https://avatars.githubusercontent.com/u/31022056?s=25) | 
[@marvinjwendt](https://github.com/marvinjwendt) | ☕ x 1 | | ![](https://avatars.githubusercontent.com/u/31921460?s=25) | [@toishy](https://github.com/toishy) | ☕ x 1 | ## 💻 Code Contributors Code Contributors ## ⭐️ Stargazers ## 🧾 License Copyright (c) 2019-present [Fenny](https://github.com/fenny) and [Contributors](https://github.com/gofiber/fiber/graphs/contributors). `Fiber` is free and open-source software licensed under the [MIT License](https://github.com/gofiber/fiber/blob/main/LICENSE). Official logo was created by [Vic Shóstak](https://github.com/koddr) and distributed under [Creative Commons](https://creativecommons.org/licenses/by-sa/4.0/) license (CC BY-SA 4.0 International). ================================================ FILE: .github/SECURITY.md ================================================ # Security Policy 1. [Supported Versions](#versions) 2. [Reporting security problems to Fiber](#reporting) 3. [Security Points of Contact](#contact) 4. [Incident Response Process](#process) ## Supported Versions The table below shows the supported versions for Fiber which include security updates. | Version | Supported | | --------- | ------------------ | | >= 1.12.6 | :white_check_mark: | | < 1.12.6 | :x: | ## Reporting security problems to Fiber **DO NOT CREATE AN ISSUE** to report a security problem. Instead, please send us an e-mail at `team@gofiber.io` or join our discord server via [this invite link](https://gofiber.io/discord) and send a private message to any of the maintainers. ## Security Points of Contact For security-related matters, please contact any of the [@maintainers](https://github.com/orgs/gofiber/teams/maintainers). The maintainers are the only persons with administrative access to Fiber's source code and respond to security incident reports as fast as possible, within one business day at the latest. 
## Incident Response Process In case an incident is discovered or reported, we will use the following process to contain, respond and remediate: ### 1. Containment The first step is to find out the root cause, nature and scope of the incident. - Is the incident still ongoing? If yes, first priority is to stop it. - Is the incident outside of our influence? If yes, first priority is to contain it. - Find out who knows about the incident and who is affected. - Find out what data was potentially exposed. ### 2. Response After the initial assessment and containment to the best of our abilities, we will document all actions taken in a response plan. We will post a message in the official `#announcements` channel to inform users about the incident and what actions we took to contain it. ### 3. Remediation Once the incident is confirmed to be resolved, we will summarize the lessons learned from the incident and create a list of actions we will take to prevent it from happening again. ### Secure accounts with access The [Fiber Organization](https://github.com/gofiber) requires two-factor authentication (2FA) for all of its members. ### Critical Updates And Security Notices We learn about critical software updates and security threats from these sources: 1. GitHub Security Alerts 2. 
GitHub: [https://status.github.com/](https://status.github.com/) & [@githubstatus](https://twitter.com/githubstatus) ================================================ FILE: .github/codecov.yml ================================================ coverage: status: project: default: target: auto threshold: 0.5% base: auto patch: default: target: auto threshold: 0.5% base: auto ignore: # Ignore generated root files - "*_msgp.go" - "*_msgp_test.go" - "*_gen.go" # Ignore generated files below root - "**/*_msgp.go" - "**/*_msgp_test.go" - "**/*_gen.go" # Ignore internal and docs - "internal/**" - "docs/**" ================================================ FILE: .github/config.yml ================================================ # Configuration for new-issue-welcome - https://github.com/behaviorbot/new-issue-welcome # Comment to be posted to on first time issues newIssueWelcomeComment: > Thanks for opening your first issue here! 🎉 Be sure to follow the issue template! If you need help or want to chat with us, join us on Discord https://gofiber.io/discord # Configuration for new-pr-welcome - https://github.com/behaviorbot/new-pr-welcome # Comment to be posted to on PRs from first time contributors in your repository newPRWelcomeComment: > Thanks for opening this pull request! 🎉 Please check out our contributing guidelines. If you need help or want to chat with us, join us on Discord https://gofiber.io/discord # Configuration for first-pr-merge - https://github.com/behaviorbot/first-pr-merge # Comment to be posted to on pull requests merged by a first time user firstPRMergeComment: > Congrats on merging your first pull request! 🎉 We here at Fiber are proud of you! If you need help or want to chat with us, join us on Discord https://gofiber.io/discord ================================================ FILE: .github/copilot-instructions.md ================================================ # Copilot Usage When modifying code, always perform these steps: 1. 
**Ensure code quality** - `make format` to format the project. - `make lint` for static analysis. - `make test` to run the test suite. 2. **Maintain documentation** Review and update the contents of the `docs` folder if necessary. 3. **Check Markdown** - Finish by running `make markdown` to lint all Markdown files. ================================================ FILE: .github/copilot-setup-steps.yml ================================================ steps: - run: | if [ -d vendor ] || go list -m -mod=readonly all; then echo "Dependencies already present" else go mod tidy && go mod download && go mod vendor fi - run: | go install gotest.tools/gotestsum@latest go install golang.org/x/vuln/cmd/govulncheck@latest go install mvdan.cc/gofumpt@latest go install github.com/tinylib/msgp@latest go install github.com/vburenin/ifacemaker@975a95966976eeb2d4365a7fb236e274c54da64c go install github.com/dkorunic/betteralign/cmd/betteralign@latest - run: go mod tidy ================================================ FILE: .github/dependabot.yml ================================================ # https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates version: 2 updates: - package-ecosystem: "gomod" directory: "/" # Location of package manifests labels: - "🤖 Dependencies" schedule: interval: "daily" allow: # Allow both direct and indirect updates for all packages. 
- dependency-type: "all" groups: fasthttp-modules: patterns: - "github.com/valyala/fasthttp" - "github.com/valyala/fasthttp/**" golang-modules: patterns: - "golang.org/x/**" valyala-utils-modules: patterns: - "github.com/valyala/bytebufferpool" - "github.com/valyala/tcplisten" google-modules: patterns: - "github.com/google/**" - "google.golang.org/**" - package-ecosystem: "github-actions" directory: "/" schedule: interval: daily labels: - "🤖 Dependencies" ================================================ FILE: .github/index.html ================================================ Test file Hello, World! ================================================ FILE: .github/labeler.yml ================================================ _extends: .github labels: - label: 'v3' matcher: baseBranch: 'main' title: '/(v3)/i' body: '/(v3)/i' - label: 'v2' matcher: baseBranch: 'v2' title: '/(v2)/i' ================================================ FILE: .github/pull_request_template.md ================================================ # Description Please provide a clear and concise description of the changes you've made and the problem they address. Include the purpose of the change, any relevant issues it solves, and the benefits it brings to the project. If this change introduces new features or adjustments, highlight them here. Fixes # (issue) ## Changes introduced List the new features or adjustments introduced in this pull request. Provide details on benchmarks, documentation updates, changelog entries, and if applicable, the migration guide. - [ ] Benchmarks: Describe any performance benchmarks and improvements related to the changes. - [ ] Documentation Update: Detail the updates made to the documentation and links to the changed files. - [ ] Changelog/What's New: Include a summary of the additions for the upcoming release notes. - [ ] Migration Guide: If necessary, provide a guide or steps for users to migrate their existing code to accommodate these changes. 
- [ ] API Alignment with Express: Explain how the changes align with the Express API. - [ ] API Longevity: Discuss the steps taken to ensure that the new or updated APIs are consistent and not prone to breaking changes. - [ ] Examples: Provide examples demonstrating the new features or changes in action. ## Type of change Please delete options that are not relevant. - [ ] New feature (non-breaking change which adds functionality) - [ ] Enhancement (improvement to existing features and functionality) - [ ] Documentation update (changes to documentation) - [ ] Performance improvement (non-breaking change which improves efficiency) - [ ] Code consistency (non-breaking change which improves code reliability and robustness) ## Checklist Before you submit your pull request, please make sure you meet these requirements: - [ ] Followed the conventions of the Express.js framework for new functionality, keeping usage similar. - [ ] Conducted a self-review of the code and provided comments for complex or critical parts. - [ ] Updated the documentation in the `/docs/` directory for [Fiber's documentation](https://docs.gofiber.io/). - [ ] Added or updated unit tests to validate the effectiveness of the changes or new features. - [ ] Ensured that new and existing unit tests pass locally with the changes. - [ ] Verified that any new dependencies are essential and have been agreed upon by the maintainers/community. - [ ] Aimed for optimal performance with minimal allocations in the new code. - [ ] Provided benchmarks for the new code to analyze and improve upon. ## Commit formatting Please use emojis in commit messages so that the purpose or intention of each commit is easy to identify. 
Check out the emoji cheatsheet here: [CONTRIBUTING.md](https://github.com/gofiber/fiber/blob/main/.github/CONTRIBUTING.md#pull-requests-or-commits) ================================================ FILE: .github/release-drafter.yml ================================================ name-template: 'v$RESOLVED_VERSION' tag-template: 'v$RESOLVED_VERSION' commitish: main filter-by-commitish: true include-labels: - 'v3' exclude-labels: - 'v2' categories: - title: '❗ Breaking Changes' labels: - '❗ BreakingChange' - title: '🚀 New' labels: - '✏️ Feature' - '📝 Proposal' - title: '🧹 Updates' labels: - '🧹 Updates' - '⚡️ Performance' - title: '🐛 Fixes' labels: - '☢️ Bug' - title: '🛠️ Maintenance' labels: - '🤖 Dependencies' - title: '📚 Documentation' labels: - '📒 Documentation' change-template: '- $TITLE (#$NUMBER)' change-title-escapes: '\<*_&' # You can add # and @ to disable mentions, and add ` to disable code blocks. exclude-contributors: - dependabot - dependabot[bot] version-resolver: major: labels: - '❗ BreakingChange' minor: labels: - '✏️ Feature' patch: labels: - '📒 Documentation' - '☢️ Bug' - '🤖 Dependencies' - '🧹 Updates' - '⚡️ Performance' default: patch template: | $CHANGES **📒 Documentation**: https://docs.gofiber.io/next/ **💬 Discord**: https://gofiber.io/discord **Full Changelog**: https://github.com/$OWNER/$REPOSITORY/compare/$PREVIOUS_TAG...v$RESOLVED_VERSION Thank you $CONTRIBUTORS for making this release possible. 
================================================ FILE: .github/release.yml ================================================ # .github/release.yml changelog: categories: - title: '❗ Breaking Changes' labels: - '❗ BreakingChange' - title: '🚀 New Features' labels: - '✏️ Feature' - '📝 Proposal' - title: '🧹 Updates' labels: - '🧹 Updates' - '⚡️ Performance' - title: '🐛 Bug Fixes' labels: - '☢️ Bug' - title: '🛠️ Maintenance' labels: - '🤖 Dependencies' - title: '📚 Documentation' labels: - '📒 Documentation' - title: 'Other Changes' labels: - '*' ================================================ FILE: .github/scripts/sync_docs.sh ================================================ #!/usr/bin/env bash set -e # Some env variables BRANCH="main" REPO_URL="github.com/gofiber/docs.git" AUTHOR_EMAIL="github-actions[bot]@users.noreply.github.com" AUTHOR_USERNAME="github-actions[bot]" VERSION_FILE="versions.json" REPO_DIR="core" COMMIT_URL="https://github.com/gofiber/fiber" DOCUSAURUS_COMMAND="npm run docusaurus -- docs:version" # Set commit author git config --global user.email "${AUTHOR_EMAIL}" git config --global user.name "${AUTHOR_USERNAME}" git clone https://${TOKEN}@${REPO_URL} fiber-docs # Handle push event if [ "$EVENT" == "push" ]; then latest_commit=$(git rev-parse --short HEAD) #log_output=$(git log --oneline ${BRANCH} HEAD~1..HEAD --name-status -- docs/) #if [[ $log_output != "" ]]; then cp -a docs/* fiber-docs/docs/${REPO_DIR} #fi # Handle release event elif [ "$EVENT" == "release" ]; then major_version="${TAG_NAME%%.*}" echo "Major version: $major_version" # Form new version name new_version="${major_version}.x" echo "New version: $new_version" cd fiber-docs/ || true npm ci # If $VERSION_FILE (versions.json) exists, remove any existing entry equal to the new version before re-versioning if [[ -f $VERSION_FILE ]]; then echo "Modifying version file: $VERSION_FILE" jq --arg new_version "$new_version" 'del(.[] | select(. 
== $new_version))' $VERSION_FILE > temp.json && mv temp.json $VERSION_FILE fi # Run docusaurus versioning command $DOCUSAURUS_COMMAND "${new_version}" if [[ -f $VERSION_FILE ]]; then echo "Sorting version file: $VERSION_FILE" jq 'sort | reverse' ${VERSION_FILE} > temp.json && mv temp.json ${VERSION_FILE} fi fi # Push changes cd fiber-docs/ || true git add . if [[ $EVENT == "push" ]]; then git commit -m "Add docs from ${COMMIT_URL}/commit/${latest_commit}" elif [[ $EVENT == "release" ]]; then git commit -m "Sync docs for release ${COMMIT_URL}/releases/tag/${TAG_NAME}" fi MAX_RETRIES=5 DELAY=5 retry=0 while ((retry < MAX_RETRIES)); do git push https://${TOKEN}@${REPO_URL} && break retry=$((retry + 1)) git pull --rebase sleep $DELAY done if ((retry == MAX_RETRIES)); then echo "Failed to push after $MAX_RETRIES attempts. Exiting with 1." exit 1 fi ================================================ FILE: .github/testdata/ca-chain.cert.pem ================================================ -----BEGIN CERTIFICATE----- MIIFeTCCA2GgAwIBAgIDEAISMA0GCSqGSIb3DQEBCwUAMFYxCzAJBgNVBAYTAlVT MQ8wDQYDVQQIDAZEZW5pYWwxFDASBgNVBAcMC1NwcmluZ2ZpZWxkMQwwCgYDVQQK DANEaXMxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0yMjAyMDgyMDA3MjlaFw0zMjAy MDYyMDA3MjlaMEAxCzAJBgNVBAYTAlVTMQ8wDQYDVQQIDAZEZW5pYWwxDDAKBgNV BAoMA0RpczESMBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEFAAOC Ag8AMIICCgKCAgEA5Cho0kbBDi1cy8bURStc95hK2RzjBQMd2hN5gFxZdF5knBfC LSiPMxtAn9zJYzYc9+Cq7hIOK19cgG4yKk9GFZaUe+mU4yWxRg1ViSu/jzQ04sVc JRSbSklXY1RNyxpUtGelxnluUvdvuXXlCPmKob4IsUtI1FTcumG1mzIO+cAzBd1J KQtNTUO9XfSHYusV/FQO2hIbaXcFgSAg50JJfYZaUw51J07j3vdb6lb1x4rRmIaq 8txrdHo0Y2tXHsq6jry1QrOZfoz4WbYcoID3JU1MC5f1HyR5uYiCA1RJVGnQ3iSX 3yM+gRy3SFPeaASs2useSzGkMr/pDlbcSVmsbXsasBxZq85T1FE8vuY6K4XlU2sN PyiPrNjDgVkQ8Lbj1B9oKEYKkmSieBx9YwRLarfru1kt+g3kdXuel7DyHpm+j/13 vqjyF9DAyx4wAEZC+DzeqBsbuiDdRkzwFMcKPxYpgSTLawnCjlFapPvE5kGN+O/j To2qWbWUU/upzBvHu4tnICSapJJ0VqA+7M7yaBAsIWK/yjNTzpHfx0oHudl8wBOG ySfOE52uouFsp2vs06YpEg2nGn/7Iu0Rbbwt4iFcSZlEnSk0cQlyMZxdj3M2fMKa 
/nrRQm7guPbVmBJOFHZuTTiilNSduSsDwCjJkGdJkSVYbj3+eJzKwYstnT8CAwEA AaNmMGQwHQYDVR0OBBYEFHwli1hTCVJLHPTHWW8O8BCaHci9MB8GA1UdIwQYMBaA FDlb/7rpDA2ZzsLZmqbW/krUtmGOMBIGA1UdEwEB/wQIMAYBAf8CAQAwDgYDVR0P AQH/BAQDAgGGMA0GCSqGSIb3DQEBCwUAA4ICAQBtMxa2/w6kGF9cmqpTdQ1La8nY R4Zoewnn+SCmcSOwCyBC32g0Ry6nKKUpJBpJEid5lBzWveIw4K7pdWvmuqmeMuWI ilvlCLzqYPigmnEIW96hc6XiQvl9NC5j+SAZSC+4uCNhEUx5pEbE1FU1gIX+szdJ tLdPwwg63Ce/us6QZ7Tx8qLIr+XU+DrCgjIheQFShtoNYDw0GxEtjeo8vHynj8EZ +p0OZgqoNlnRbQbllruDFPXDJVI23DVhNpJhT86iQDMtMV53ypMu62LXmdQIKa7l ITnEMGO626RKqw2kDHt7yinBlt1nHskaeeLya6K/08uJkqRCjzOshJgsjQ3e62vQ Mht9QvGBCAoY09fIGxRihtTWCFDe7MEnbh1PPYB7cZTOMnL3wxRPzLLYhclX+pt0 bBf7Dn84b3tdC5BFXBJeZMs5QSCvn4yrTew+NvvX3oL6Ny1JDZYaG5PhKf00J6iy TkXzK2n9U9RX+krPk8fU9Ae1nayD0vrmGaVcBdRQPn4XUuS3LhdlkHfr28z1nF9m ffd0WBrJlNX9SoKtsMj8VJFZ/nJ0EcCcY1mG3k/IGAY1HUeo4A+C5E/UO3h4+tOL uqUa8rkl9HoE4fIWdQVxtQjEdATSuJusaK7CFpWH8A0w9VchDx74saiwwGhVyXYk yBwSA5U88ymkQ7qNJg== -----END CERTIFICATE----- -----BEGIN CERTIFICATE----- MIIFnTCCA4WgAwIBAgIUQnfvDIm6z+973AYRTLorZHEQA30wDQYJKoZIhvcNAQEL BQAwVjELMAkGA1UEBhMCVVMxDzANBgNVBAgMBkRlbmlhbDEUMBIGA1UEBwwLU3By aW5nZmllbGQxDDAKBgNVBAoMA0RpczESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIy MDIwODIwMDcyOVoXDTQyMDIwMzIwMDcyOVowVjELMAkGA1UEBhMCVVMxDzANBgNV BAgMBkRlbmlhbDEUMBIGA1UEBwwLU3ByaW5nZmllbGQxDDAKBgNVBAoMA0RpczES MBAGA1UEAwwJbG9jYWxob3N0MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKC AgEA+Cf/fKPvI5+Nh81wpxghLrpAjM1MyhdHUDXc8bu7NTNYZ+6ArMqDeKSszTWT gLV9EeJs57KwjwXIoYZDTcLvpjanrZ2s7JDEqsGl7S6Xr67qzYghlF/GaB3f3lAi GsvhmDgC4jkdCvkVBKOB4tE0dy6fnNCmIKhhJDje51p90LWFuX5sIKO2trgte6b/ P1PW8rOjedPd5Z+QCG4Mi8JnbJicX1YaOySGMcXHm/eM2wy5I3pEdUreZDFbjB7l CKa1kFAnDXBiQAoErggMlcXe6C+avB13wYCfi+R+9m2X0svmerSz+oHDCOhvnD52 EE7fBv89VS6pR8mykz1eHiBKkVT1qdmONThUQ8mqwxlo06bZyoykKSxG6ffJnGbM GkTWlaNEdZuY0pITQLxEX2SwLJuZKfTPFheno83bLCqTOTuWo7h1mLe0ogJMxl// NHzyoMJ0b1bbrRgsLcMaDn1MsI+gVdRY89+cZ2uedEgrr7fl3KFWiF2S57bxX1fJ P8HC0bzMny4jMtIf6YUqDtpGDPjZq3PGqcrcO0dVYkuaC96H3xLcvQxgkMDK0sm0 
pUbWlECzAag/lxeeC22VnedqgCpiq/9z1b4j884ZkhIJyht0HR7L4I0gO/R/mWUY 8bO9XCMYmP0CjO7u93IlzQ7aSIpWprTHxjpiPelmcO001jkCAwEAAaNjMGEwHQYD VR0OBBYEFDlb/7rpDA2ZzsLZmqbW/krUtmGOMB8GA1UdIwQYMBaAFDlb/7rpDA2Z zsLZmqbW/krUtmGOMA8GA1UdEwEB/wQFMAMBAf8wDgYDVR0PAQH/BAQDAgGGMA0G CSqGSIb3DQEBCwUAA4ICAQDk02Tu0ypqnS897WBx98l2nYIrEhcrJg8ZMmSwEa9J 7TANofzsP9931YoQMh6BI6hB3OkyL6FYTUDykpGMMasojtL/F2iXEsjema2ilZ/7 hNAZ+j5mBemMwXfkfmRguXvnl7EWaZETgEoxhcOTYoYUYqDcyzuwK63fOs+YA5ke O8E3F1aLHzLpqVpiG7t740L7LdibNPko9JOd31Gqcq3nhXMf6/rOdL8VSj/F+4BG opgJBruJV9NxWRI/b0G6eImvaYL/Ljfd8wzwNpmYkNkHbhAiaHeXJQ05mebmr2Dr wne9QeSJkXCs/K5A/8+0CYNN4homt8xNNN02SnJ5e6nv1A1ntMW9n6n2KYo87tz9 VmqWXg7Y1BqXj287WRaWPJsBa2RBP1W2d3BQfHKJfu15blyXaczTi87WayEsBnQq TXy+1QP0IwQerSTOxdW25UoJmH18SRbLEIQs9Htvcpz2AncTYjeiLFa15FO2r5hP LYc9QOKn6yIZP9lYztleEqOLTmHnRnFcupDol+/x88d+kVLqmXDiKmWbVIz7C735 xgImsyrCPPYYiEA7/yaP5G1o5XU93kRPrtg/7jjyF+uBZ70fcbED3prpuiJYrL0O gvQUgmGUU30mPHjAKkEACeRXtoqucRDxvIkBb5zUvZG8RmSFae5siAWwLD7D7VJa IA== -----END CERTIFICATE----- ================================================ FILE: .github/testdata/fs/css/style.css ================================================ h1 { color: red; text-align: center; } ================================================ FILE: .github/testdata/fs/css/test/style2.css ================================================ h1 { color: black; } ================================================ FILE: .github/testdata/fs/index.html ================================================ Document

Hello, World!

================================================ FILE: .github/testdata/hello_world.tmpl ================================================

Hello {{ .Name }}!

================================================ FILE: .github/testdata/index.html ================================================

Hello, Fiber!

================================================ FILE: .github/testdata/index.tmpl ================================================

{{.Title}}

================================================ FILE: .github/testdata/main.tmpl ================================================

I'm main

================================================ FILE: .github/testdata/ssl.key ================================================ -----BEGIN PRIVATE KEY----- MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD4IQusAs8PJdnG 3mURt/AXtgC+ceqLOatJ49JJE1VPTkMAy+oE1f1XvkMrYsHqmDf6GWVzgVXryL4U wq2/nJSm56ddhN55nI8oSN3dtywUB8/ShelEN73nlN77PeD9tl6NksPwWaKrqxq0 FlabRPZSQCfmgZbhDV8Sa8mfCkFU0G0lit6kLGceCKMvmW+9Bz7ebsYmVdmVMxmf IJStFD44lWFTdUc65WISKEdW2ELcUefb0zOLw+0PCbXFGJH5x5ktksW8+BBk2Hkg GeQRL/qPCccthbScO0VgNj3zJ3ZZL0ObSDAbvNDG85joeNjDNq5DT/BAZ0bOSbEF sh+f9BAzAgMBAAECggEBAJWv2cq7Jw6MVwSRxYca38xuD6TUNBopgBvjREixURW2 sNUaLuMb9Omp7fuOaE2N5rcJ+xnjPGIxh/oeN5MQctz9gwn3zf6vY+15h97pUb4D uGvYPRDaT8YVGS+X9NMZ4ZCmqW2lpWzKnCFoGHcy8yZLbcaxBsRdvKzwOYGoPiFb K2QuhXZ/1UPmqK9i2DFKtj40X6vBszTNboFxOVpXrPu0FJwLVSDf2hSZ4fMM0DH3 YqwKcYf5te+hxGKgrqRA3tn0NCWii0in6QIwXMC+kMw1ebg/tZKqyDLMNptAK8J+ DVw9m5X1seUHS5ehU/g2jrQrtK5WYn7MrFK4lBzlRwECgYEA/d1TeANYECDWRRDk B0aaRZs87Rwl/J9PsvbsKvtU/bX+OfSOUjOa9iQBqn0LmU8GqusEET/QVUfocVwV Bggf/5qDLxz100Rj0ags/yE/kNr0Bb31kkkKHFMnCT06YasR7qKllwrAlPJvQv9x IzBKq+T/Dx08Wep9bCRSFhzRCnsCgYEA+jdeZXTDr/Vz+D2B3nAw1frqYFfGnEVY wqmoK3VXMDkGuxsloO2rN+SyiUo3JNiQNPDub/t7175GH5pmKtZOlftePANsUjBj wZ1D0rI5Bxu/71ibIUYIRVmXsTEQkh/ozoh3jXCZ9+bLgYiYx7789IUZZSokFQ3D FICUT9KJ36kCgYAGoq9Y1rWJjmIrYfqj2guUQC+CfxbbGIrrwZqAsRsSmpwvhZ3m tiSZxG0quKQB+NfSxdvQW5ulbwC7Xc3K35F+i9pb8+TVBdeaFkw+yu6vaZmxQLrX fQM/pEjD7A7HmMIaO7QaU5SfEAsqdCTP56Y8AftMuNXn/8IRfo2KuGwaWwKBgFpU ILzJoVdlad9E/Rw7LjYhZfkv1uBVXIyxyKcfrkEXZSmozDXDdxsvcZCEfVHM6Ipk K/+7LuMcqp4AFEAEq8wTOdq6daFaHLkpt/FZK6M4TlruhtpFOPkoNc3e45eM83OT 6mziKINJC1CQ6m65sQHpBtjxlKMRG8rL/D6wx9s5AoGBAMRlqNPMwglT3hvDmsAt 9Lf9pdmhERUlHhD8bj8mDaBj2Aqv7f6VRJaYZqP403pKKQexuqcn80mtjkSAPFkN Cj7BVt/RXm5uoxDTnfi26RF9F6yNDEJ7UU9+peBr99aazF/fTgW/1GcMkQnum8uV c257YgaWmjK9uB0Y2r2VxS0G -----END PRIVATE KEY----- ================================================ FILE: .github/testdata/ssl.pem ================================================ -----BEGIN CERTIFICATE----- 
MIICujCCAaKgAwIBAgIJAMbXnKZ/cikUMA0GCSqGSIb3DQEBCwUAMBUxEzARBgNV BAMTCnVidW50dS5uYW4wHhcNMTUwMjA0MDgwMTM5WhcNMjUwMjAxMDgwMTM5WjAV MRMwEQYDVQQDEwp1YnVudHUubmFuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB CgKCAQEA+CELrALPDyXZxt5lEbfwF7YAvnHqizmrSePSSRNVT05DAMvqBNX9V75D K2LB6pg3+hllc4FV68i+FMKtv5yUpuenXYTeeZyPKEjd3bcsFAfP0oXpRDe955Te +z3g/bZejZLD8Fmiq6satBZWm0T2UkAn5oGW4Q1fEmvJnwpBVNBtJYrepCxnHgij L5lvvQc+3m7GJlXZlTMZnyCUrRQ+OJVhU3VHOuViEihHVthC3FHn29Mzi8PtDwm1 xRiR+ceZLZLFvPgQZNh5IBnkES/6jwnHLYW0nDtFYDY98yd2WS9Dm0gwG7zQxvOY 6HjYwzauQ0/wQGdGzkmxBbIfn/QQMwIDAQABow0wCzAJBgNVHRMEAjAAMA0GCSqG SIb3DQEBCwUAA4IBAQBQjKm/4KN/iTgXbLTL3i7zaxYXFLXsnT1tF+ay4VA8aj98 L3JwRTciZ3A5iy/W4VSCt3eASwOaPWHKqDBB5RTtL73LoAqsWmO3APOGQAbixcQ2 45GXi05OKeyiYRi1Nvq7Unv9jUkRDHUYVPZVSAjCpsXzPhFkmZoTRxmx5l0ZF7Li K91lI5h+eFq0dwZwrmlPambyh1vQUi70VHv8DNToVU29kel7YLbxGbuqETfhrcy6 X+Mha6RYITkAn5FqsZcKMsc9eYGEF4l3XV+oS7q6xfTxktYJMFTI18J0lQ2Lv/CI whdMnYGntDQBE/iFCrJEGNsKGc38796GBOb5j+zd -----END CERTIFICATE----- ================================================ FILE: .github/testdata/template-invalid.html ================================================

{{.Title}

================================================ FILE: .github/testdata/template.tmpl ================================================

{{.Title}} {{.Summary}}

================================================ FILE: .github/testdata/testRoutes.json ================================================ { "test_routes": [{ "method": "GET", "path": "/authorizations" }, { "method": "GET", "path": "/authorizations/1337" }, { "method": "POST", "path": "/authorizations" }, { "method": "PUT", "path": "/authorizations/clients/inf1nd873nf8912g9t" }, { "method": "PATCH", "path": "/authorizations/1337" }, { "method": "DELETE", "path": "/authorizations/1337" }, { "method": "GET", "path": "/applications/2nds981mng6azl127y/tokens/sn108hbe1geheibf13f" }, { "method": "DELETE", "path": "/applications/2nds981mng6azl127y/tokens" }, { "method": "DELETE", "path": "/applications/2nds981mng6azl127y/tokens/sn108hbe1geheibf13f" }, { "method": "GET", "path": "/events" }, { "method": "GET", "path": "/repos/fenny/fiber/events" }, { "method": "GET", "path": "/networks/fenny/fiber/events" }, { "method": "GET", "path": "/orgs/gofiber/events" }, { "method": "GET", "path": "/users/fenny/received_events" }, { "method": "GET", "path": "/users/fenny/received_events/public" }, { "method": "GET", "path": "/users/fenny/events" }, { "method": "GET", "path": "/users/fenny/events/public" }, { "method": "GET", "path": "/users/fenny/events/orgs/gofiber" }, { "method": "GET", "path": "/feeds" }, { "method": "GET", "path": "/notifications" }, { "method": "GET", "path": "/repos/fenny/fiber/notifications" }, { "method": "PUT", "path": "/notifications" }, { "method": "PUT", "path": "/repos/fenny/fiber/notifications" }, { "method": "GET", "path": "/notifications/threads/1337" }, { "method": "PATCH", "path": "/notifications/threads/1337" }, { "method": "GET", "path": "/notifications/threads/1337/subscription" }, { "method": "PUT", "path": "/notifications/threads/1337/subscription" }, { "method": "DELETE", "path": "/notifications/threads/1337/subscription" }, { "method": "GET", "path": "/repos/fenny/fiber/stargazers" }, { "method": "GET", "path": "/users/fenny/starred" }, { 
"method": "GET", "path": "/user/starred" }, { "method": "GET", "path": "/user/starred/fenny/fiber" }, { "method": "PUT", "path": "/user/starred/fenny/fiber" }, { "method": "DELETE", "path": "/user/starred/fenny/fiber" }, { "method": "GET", "path": "/repos/fenny/fiber/subscribers" }, { "method": "GET", "path": "/users/fenny/subscriptions" }, { "method": "GET", "path": "/user/subscriptions" }, { "method": "GET", "path": "/repos/fenny/fiber/subscription" }, { "method": "PUT", "path": "/repos/fenny/fiber/subscription" }, { "method": "DELETE", "path": "/repos/fenny/fiber/subscription" }, { "method": "GET", "path": "/user/subscriptions/fenny/fiber" }, { "method": "PUT", "path": "/user/subscriptions/fenny/fiber" }, { "method": "DELETE", "path": "/user/subscriptions/fenny/fiber" }, { "method": "GET", "path": "/users/fenny/gists" }, { "method": "GET", "path": "/gists" }, { "method": "GET", "path": "/gists/public" }, { "method": "GET", "path": "/gists/starred" }, { "method": "GET", "path": "/gists/1337" }, { "method": "POST", "path": "/gists" }, { "method": "PATCH", "path": "/gists/1337" }, { "method": "PUT", "path": "/gists/1337/star" }, { "method": "DELETE", "path": "/gists/1337/star" }, { "method": "GET", "path": "/gists/1337/star" }, { "method": "POST", "path": "/gists/1337/forks" }, { "method": "DELETE", "path": "/gists/1337" }, { "method": "GET", "path": "/repos/fenny/fiber/git/blobs/v948b24g98ubngw9082bn02giub" }, { "method": "POST", "path": "/repos/fenny/fiber/git/blobs" }, { "method": "GET", "path": "/repos/fenny/fiber/git/commits/v948b24g98ubngw9082bn02giub" }, { "method": "POST", "path": "/repos/fenny/fiber/git/commits" }, { "method": "GET", "path": "/repos/fenny/fiber/git/refs/im/a/wildcard" }, { "method": "GET", "path": "/repos/fenny/fiber/git/refs" }, { "method": "POST", "path": "/repos/fenny/fiber/git/refs" }, { "method": "PATCH", "path": "/repos/fenny/fiber/git/refs/im/a/wildcard" }, { "method": "DELETE", "path": "/repos/fenny/fiber/git/refs/im/a/wildcard" }, 
{ "method": "GET", "path": "/repos/fenny/fiber/git/tags/v948b24g98ubngw9082bn02giub" }, { "method": "POST", "path": "/repos/fenny/fiber/git/tags" }, { "method": "GET", "path": "/repos/fenny/fiber/git/trees/v948b24g98ubngw9082bn02giub" }, { "method": "POST", "path": "/repos/fenny/fiber/git/trees" }, { "method": "GET", "path": "/issues" }, { "method": "GET", "path": "/user/issues" }, { "method": "GET", "path": "/orgs/gofiber/issues" }, { "method": "GET", "path": "/repos/fenny/fiber/issues" }, { "method": "GET", "path": "/repos/fenny/fiber/issues/1000" }, { "method": "POST", "path": "/repos/fenny/fiber/issues" }, { "method": "PATCH", "path": "/repos/fenny/fiber/issues/1000" }, { "method": "GET", "path": "/repos/fenny/fiber/assignees" }, { "method": "GET", "path": "/repos/fenny/fiber/assignees/nic" }, { "method": "GET", "path": "/repos/fenny/fiber/issues/1000/comments" }, { "method": "GET", "path": "/repos/fenny/fiber/issues/comments" }, { "method": "GET", "path": "/repos/fenny/fiber/issues/comments/1337" }, { "method": "POST", "path": "/repos/fenny/fiber/issues/1000/comments" }, { "method": "PATCH", "path": "/repos/fenny/fiber/issues/comments/1337" }, { "method": "DELETE", "path": "/repos/fenny/fiber/issues/comments/1337" }, { "method": "GET", "path": "/repos/fenny/fiber/issues/1000/events" }, { "method": "GET", "path": "/repos/fenny/fiber/issues/events" }, { "method": "GET", "path": "/repos/fenny/fiber/issues/events/1337" }, { "method": "GET", "path": "/repos/fenny/fiber/labels" }, { "method": "GET", "path": "/repos/fenny/fiber/labels/john" }, { "method": "POST", "path": "/repos/fenny/fiber/labels" }, { "method": "PATCH", "path": "/repos/fenny/fiber/labels/john" }, { "method": "DELETE", "path": "/repos/fenny/fiber/labels/john" }, { "method": "GET", "path": "/repos/fenny/fiber/issues/1000/labels" }, { "method": "POST", "path": "/repos/fenny/fiber/issues/1000/labels" }, { "method": "DELETE", "path": "/repos/fenny/fiber/issues/1000/labels/john" }, { "method": "PUT", 
"path": "/repos/fenny/fiber/issues/1000/labels" }, { "method": "DELETE", "path": "/repos/fenny/fiber/issues/1000/labels" }, { "method": "GET", "path": "/repos/fenny/fiber/milestones/1000/labels" }, { "method": "GET", "path": "/repos/fenny/fiber/milestones" }, { "method": "GET", "path": "/repos/fenny/fiber/milestones/1000" }, { "method": "POST", "path": "/repos/fenny/fiber/milestones" }, { "method": "PATCH", "path": "/repos/fenny/fiber/milestones/1000" }, { "method": "DELETE", "path": "/repos/fenny/fiber/milestones/1000" }, { "method": "GET", "path": "/emojis" }, { "method": "GET", "path": "/gitignore/templates" }, { "method": "GET", "path": "/gitignore/templates/john" }, { "method": "POST", "path": "/markdown" }, { "method": "POST", "path": "/markdown/raw" }, { "method": "GET", "path": "/meta" }, { "method": "GET", "path": "/rate_limit" }, { "method": "GET", "path": "/users/fenny/orgs" }, { "method": "GET", "path": "/user/orgs" }, { "method": "GET", "path": "/orgs/gofiber" }, { "method": "PATCH", "path": "/orgs/gofiber" }, { "method": "GET", "path": "/orgs/gofiber/members" }, { "method": "GET", "path": "/orgs/gofiber/members/fenny" }, { "method": "DELETE", "path": "/orgs/gofiber/members/fenny" }, { "method": "GET", "path": "/orgs/gofiber/public_members" }, { "method": "GET", "path": "/orgs/gofiber/public_members/fenny" }, { "method": "PUT", "path": "/orgs/gofiber/public_members/fenny" }, { "method": "DELETE", "path": "/orgs/gofiber/public_members/fenny" }, { "method": "GET", "path": "/orgs/gofiber/teams" }, { "method": "GET", "path": "/teams/1337" }, { "method": "POST", "path": "/orgs/gofiber/teams" }, { "method": "PATCH", "path": "/teams/1337" }, { "method": "DELETE", "path": "/teams/1337" }, { "method": "GET", "path": "/teams/1337/members" }, { "method": "GET", "path": "/teams/1337/members/fenny" }, { "method": "PUT", "path": "/teams/1337/members/fenny" }, { "method": "DELETE", "path": "/teams/1337/members/fenny" }, { "method": "GET", "path": "/teams/1337/repos" 
}, { "method": "GET", "path": "/teams/1337/repos/fenny/fiber" }, { "method": "PUT", "path": "/teams/1337/repos/fenny/fiber" }, { "method": "DELETE", "path": "/teams/1337/repos/fenny/fiber" }, { "method": "GET", "path": "/user/teams" }, { "method": "GET", "path": "/repos/fenny/fiber/pulls" }, { "method": "GET", "path": "/repos/fenny/fiber/pulls/1000" }, { "method": "POST", "path": "/repos/fenny/fiber/pulls" }, { "method": "PATCH", "path": "/repos/fenny/fiber/pulls/1000" }, { "method": "GET", "path": "/repos/fenny/fiber/pulls/1000/commits" }, { "method": "GET", "path": "/repos/fenny/fiber/pulls/1000/files" }, { "method": "GET", "path": "/repos/fenny/fiber/pulls/1000/merge" }, { "method": "PUT", "path": "/repos/fenny/fiber/pulls/1000/merge" }, { "method": "GET", "path": "/repos/fenny/fiber/pulls/1000/comments" }, { "method": "GET", "path": "/repos/fenny/fiber/pulls/comments" }, { "method": "GET", "path": "/repos/fenny/fiber/pulls/comments/1000" }, { "method": "PUT", "path": "/repos/fenny/fiber/pulls/1000/comments" }, { "method": "PATCH", "path": "/repos/fenny/fiber/pulls/comments/1000" }, { "method": "DELETE", "path": "/repos/fenny/fiber/pulls/comments/1000" }, { "method": "GET", "path": "/user/repos" }, { "method": "GET", "path": "/users/fenny/repos" }, { "method": "GET", "path": "/orgs/gofiber/repos" }, { "method": "GET", "path": "/repositories" }, { "method": "POST", "path": "/user/repos" }, { "method": "POST", "path": "/orgs/gofiber/repos" }, { "method": "GET", "path": "/repos/fenny/fiber" }, { "method": "PATCH", "path": "/repos/fenny/fiber" }, { "method": "GET", "path": "/repos/fenny/fiber/contributors" }, { "method": "GET", "path": "/repos/fenny/fiber/languages" }, { "method": "GET", "path": "/repos/fenny/fiber/teams" }, { "method": "GET", "path": "/repos/fenny/fiber/tags" }, { "method": "GET", "path": "/repos/fenny/fiber/branches" }, { "method": "GET", "path": "/repos/fenny/fiber/branches/master" }, { "method": "DELETE", "path": "/repos/fenny/fiber" }, { 
"method": "GET", "path": "/repos/fenny/fiber/collaborators" }, { "method": "GET", "path": "/repos/fenny/fiber/collaborators/fenny" }, { "method": "PUT", "path": "/repos/fenny/fiber/collaborators/fenny" }, { "method": "DELETE", "path": "/repos/fenny/fiber/collaborators/fenny" }, { "method": "GET", "path": "/repos/fenny/fiber/comments" }, { "method": "GET", "path": "/repos/fenny/fiber/commits/v948b24g98ubngw9082bn02giub/comments" }, { "method": "POST", "path": "/repos/fenny/fiber/commits/v948b24g98ubngw9082bn02giub/comments" }, { "method": "GET", "path": "/repos/fenny/fiber/comments/1337" }, { "method": "PATCH", "path": "/repos/fenny/fiber/comments/1337" }, { "method": "DELETE", "path": "/repos/fenny/fiber/comments/1337" }, { "method": "GET", "path": "/repos/fenny/fiber/commits" }, { "method": "GET", "path": "/repos/fenny/fiber/commits/v948b24g98ubngw9082bn02giub" }, { "method": "GET", "path": "/repos/fenny/fiber/readme" }, { "method": "GET", "path": "/repos/fenny/fiber/contents/im/a/wildcard" }, { "method": "PUT", "path": "/repos/fenny/fiber/contents/im/a/wildcard" }, { "method": "DELETE", "path": "/repos/fenny/fiber/contents/im/a/wildcard" }, { "method": "GET", "path": "/repos/fenny/fiber/gzip/google" }, { "method": "GET", "path": "/repos/fenny/fiber/keys" }, { "method": "GET", "path": "/repos/fenny/fiber/keys/1337" }, { "method": "POST", "path": "/repos/fenny/fiber/keys" }, { "method": "PATCH", "path": "/repos/fenny/fiber/keys/1337" }, { "method": "DELETE", "path": "/repos/fenny/fiber/keys/1337" }, { "method": "GET", "path": "/repos/fenny/fiber/downloads" }, { "method": "GET", "path": "/repos/fenny/fiber/downloads/1337" }, { "method": "DELETE", "path": "/repos/fenny/fiber/downloads/1337" }, { "method": "GET", "path": "/repos/fenny/fiber/forks" }, { "method": "POST", "path": "/repos/fenny/fiber/forks" }, { "method": "GET", "path": "/repos/fenny/fiber/hooks" }, { "method": "GET", "path": "/repos/fenny/fiber/hooks/1337" }, { "method": "POST", "path": 
"/repos/fenny/fiber/hooks" }, { "method": "PATCH", "path": "/repos/fenny/fiber/hooks/1337" }, { "method": "POST", "path": "/repos/fenny/fiber/hooks/1337/tests" }, { "method": "DELETE", "path": "/repos/fenny/fiber/hooks/1337" }, { "method": "POST", "path": "/repos/fenny/fiber/merges" }, { "method": "GET", "path": "/repos/fenny/fiber/releases" }, { "method": "GET", "path": "/repos/fenny/fiber/releases/1337" }, { "method": "POST", "path": "/repos/fenny/fiber/releases" }, { "method": "PATCH", "path": "/repos/fenny/fiber/releases/1337" }, { "method": "DELETE", "path": "/repos/fenny/fiber/releases/1337" }, { "method": "GET", "path": "/repos/fenny/fiber/releases/1337/assets" }, { "method": "GET", "path": "/repos/fenny/fiber/stats/contributors" }, { "method": "GET", "path": "/repos/fenny/fiber/stats/commit_activity" }, { "method": "GET", "path": "/repos/fenny/fiber/stats/code_frequency" }, { "method": "GET", "path": "/repos/fenny/fiber/stats/participation" }, { "method": "GET", "path": "/repos/fenny/fiber/stats/punch_card" }, { "method": "GET", "path": "/repos/fenny/fiber/statuses/google" }, { "method": "POST", "path": "/repos/fenny/fiber/statuses/google" }, { "method": "GET", "path": "/search/repositories" }, { "method": "GET", "path": "/search/code" }, { "method": "GET", "path": "/search/issues" }, { "method": "GET", "path": "/search/users" }, { "method": "GET", "path": "/legacy/issues/search/fenny/fibersitory/locked/finish" }, { "method": "GET", "path": "/legacy/repos/search/finish" }, { "method": "GET", "path": "/legacy/user/search/finish" }, { "method": "GET", "path": "/legacy/user/email/info@gofiber.io" }, { "method": "GET", "path": "/users/fenny" }, { "method": "GET", "path": "/user" }, { "method": "PATCH", "path": "/user" }, { "method": "GET", "path": "/users" }, { "method": "GET", "path": "/user/emails" }, { "method": "POST", "path": "/user/emails" }, { "method": "DELETE", "path": "/user/emails" }, { "method": "GET", "path": "/users/fenny/followers" }, { "method": 
"GET", "path": "/user/followers" }, { "method": "GET", "path": "/users/fenny/following" }, { "method": "GET", "path": "/user/following" }, { "method": "GET", "path": "/user/following/fenny" }, { "method": "GET", "path": "/users/fenny/following/renan" }, { "method": "PUT", "path": "/user/following/fenny" }, { "method": "DELETE", "path": "/user/following/fenny" }, { "method": "GET", "path": "/users/fenny/keys" }, { "method": "GET", "path": "/user/keys" }, { "method": "GET", "path": "/user/keys/1337" }, { "method": "POST", "path": "/user/keys" }, { "method": "PATCH", "path": "/user/keys/1337" }, { "method": "DELETE", "path": "/user/keys/1337" } ], "github_api": [{ "method": "GET", "path": "/authorizations" }, { "method": "GET", "path": "/authorizations/:id" }, { "method": "POST", "path": "/authorizations" }, { "method": "PUT", "path": "/authorizations/clients/:client_id" }, { "method": "PATCH", "path": "/authorizations/:id" }, { "method": "DELETE", "path": "/authorizations/:id" }, { "method": "GET", "path": "/applications/:client_id/tokens/:access_token" }, { "method": "DELETE", "path": "/applications/:client_id/tokens" }, { "method": "DELETE", "path": "/applications/:client_id/tokens/:access_token" }, { "method": "GET", "path": "/events" }, { "method": "GET", "path": "/repos/:owner/:repo/events" }, { "method": "GET", "path": "/networks/:owner/:repo/events" }, { "method": "GET", "path": "/orgs/:org/events" }, { "method": "GET", "path": "/users/:user/received_events" }, { "method": "GET", "path": "/users/:user/received_events/public" }, { "method": "GET", "path": "/users/:user/events" }, { "method": "GET", "path": "/users/:user/events/public" }, { "method": "GET", "path": "/users/:user/events/orgs/:org" }, { "method": "GET", "path": "/feeds" }, { "method": "GET", "path": "/notifications" }, { "method": "GET", "path": "/repos/:owner/:repo/notifications" }, { "method": "PUT", "path": "/notifications" }, { "method": "PUT", "path": "/repos/:owner/:repo/notifications" }, { 
"method": "GET", "path": "/notifications/threads/:id" }, { "method": "PATCH", "path": "/notifications/threads/:id" }, { "method": "GET", "path": "/notifications/threads/:id/subscription" }, { "method": "PUT", "path": "/notifications/threads/:id/subscription" }, { "method": "DELETE", "path": "/notifications/threads/:id/subscription" }, { "method": "GET", "path": "/repos/:owner/:repo/stargazers" }, { "method": "GET", "path": "/users/:user/starred" }, { "method": "GET", "path": "/user/starred" }, { "method": "GET", "path": "/user/starred/:owner/:repo" }, { "method": "PUT", "path": "/user/starred/:owner/:repo" }, { "method": "DELETE", "path": "/user/starred/:owner/:repo" }, { "method": "GET", "path": "/repos/:owner/:repo/subscribers" }, { "method": "GET", "path": "/users/:user/subscriptions" }, { "method": "GET", "path": "/user/subscriptions" }, { "method": "GET", "path": "/repos/:owner/:repo/subscription" }, { "method": "PUT", "path": "/repos/:owner/:repo/subscription" }, { "method": "DELETE", "path": "/repos/:owner/:repo/subscription" }, { "method": "GET", "path": "/user/subscriptions/:owner/:repo" }, { "method": "PUT", "path": "/user/subscriptions/:owner/:repo" }, { "method": "DELETE", "path": "/user/subscriptions/:owner/:repo" }, { "method": "GET", "path": "/users/:user/gists" }, { "method": "GET", "path": "/gists" }, { "method": "GET", "path": "/gists/public" }, { "method": "GET", "path": "/gists/starred" }, { "method": "GET", "path": "/gists/:id" }, { "method": "POST", "path": "/gists" }, { "method": "PATCH", "path": "/gists/:id" }, { "method": "PUT", "path": "/gists/:id/star" }, { "method": "DELETE", "path": "/gists/:id/star" }, { "method": "GET", "path": "/gists/:id/star" }, { "method": "POST", "path": "/gists/:id/forks" }, { "method": "DELETE", "path": "/gists/:id" }, { "method": "GET", "path": "/repos/:owner/:repo/git/blobs/:sha" }, { "method": "POST", "path": "/repos/:owner/:repo/git/blobs" }, { "method": "GET", "path": "/repos/:owner/:repo/git/commits/:sha" 
}, { "method": "POST", "path": "/repos/:owner/:repo/git/commits" }, { "method": "GET", "path": "/repos/:owner/:repo/git/refs/*" }, { "method": "GET", "path": "/repos/:owner/:repo/git/refs" }, { "method": "POST", "path": "/repos/:owner/:repo/git/refs" }, { "method": "PATCH", "path": "/repos/:owner/:repo/git/refs/*" }, { "method": "DELETE", "path": "/repos/:owner/:repo/git/refs/*" }, { "method": "GET", "path": "/repos/:owner/:repo/git/tags/:sha" }, { "method": "POST", "path": "/repos/:owner/:repo/git/tags" }, { "method": "GET", "path": "/repos/:owner/:repo/git/trees/:sha" }, { "method": "POST", "path": "/repos/:owner/:repo/git/trees" }, { "method": "GET", "path": "/issues" }, { "method": "GET", "path": "/user/issues" }, { "method": "GET", "path": "/orgs/:org/issues" }, { "method": "GET", "path": "/repos/:owner/:repo/issues" }, { "method": "GET", "path": "/repos/:owner/:repo/issues/:number" }, { "method": "POST", "path": "/repos/:owner/:repo/issues" }, { "method": "PATCH", "path": "/repos/:owner/:repo/issues/:number" }, { "method": "GET", "path": "/repos/:owner/:repo/assignees" }, { "method": "GET", "path": "/repos/:owner/:repo/assignees/:assignee" }, { "method": "GET", "path": "/repos/:owner/:repo/issues/:number/comments" }, { "method": "GET", "path": "/repos/:owner/:repo/issues/comments" }, { "method": "GET", "path": "/repos/:owner/:repo/issues/comments/:id" }, { "method": "POST", "path": "/repos/:owner/:repo/issues/:number/comments" }, { "method": "PATCH", "path": "/repos/:owner/:repo/issues/comments/:id" }, { "method": "DELETE", "path": "/repos/:owner/:repo/issues/comments/:id" }, { "method": "GET", "path": "/repos/:owner/:repo/issues/:number/events" }, { "method": "GET", "path": "/repos/:owner/:repo/issues/events" }, { "method": "GET", "path": "/repos/:owner/:repo/issues/events/:id" }, { "method": "GET", "path": "/repos/:owner/:repo/labels" }, { "method": "GET", "path": "/repos/:owner/:repo/labels/:name" }, { "method": "POST", "path": "/repos/:owner/:repo/labels" 
}, { "method": "PATCH", "path": "/repos/:owner/:repo/labels/:name" }, { "method": "DELETE", "path": "/repos/:owner/:repo/labels/:name" }, { "method": "GET", "path": "/repos/:owner/:repo/issues/:number/labels" }, { "method": "POST", "path": "/repos/:owner/:repo/issues/:number/labels" }, { "method": "DELETE", "path": "/repos/:owner/:repo/issues/:number/labels/:name" }, { "method": "PUT", "path": "/repos/:owner/:repo/issues/:number/labels" }, { "method": "DELETE", "path": "/repos/:owner/:repo/issues/:number/labels" }, { "method": "GET", "path": "/repos/:owner/:repo/milestones/:number/labels" }, { "method": "GET", "path": "/repos/:owner/:repo/milestones" }, { "method": "GET", "path": "/repos/:owner/:repo/milestones/:number" }, { "method": "POST", "path": "/repos/:owner/:repo/milestones" }, { "method": "PATCH", "path": "/repos/:owner/:repo/milestones/:number" }, { "method": "DELETE", "path": "/repos/:owner/:repo/milestones/:number" }, { "method": "GET", "path": "/emojis" }, { "method": "GET", "path": "/gitignore/templates" }, { "method": "GET", "path": "/gitignore/templates/:name" }, { "method": "POST", "path": "/markdown" }, { "method": "POST", "path": "/markdown/raw" }, { "method": "GET", "path": "/meta" }, { "method": "GET", "path": "/rate_limit" }, { "method": "GET", "path": "/users/:user/orgs" }, { "method": "GET", "path": "/user/orgs" }, { "method": "GET", "path": "/orgs/:org" }, { "method": "PATCH", "path": "/orgs/:org" }, { "method": "GET", "path": "/orgs/:org/members" }, { "method": "GET", "path": "/orgs/:org/members/:user" }, { "method": "DELETE", "path": "/orgs/:org/members/:user" }, { "method": "GET", "path": "/orgs/:org/public_members" }, { "method": "GET", "path": "/orgs/:org/public_members/:user" }, { "method": "PUT", "path": "/orgs/:org/public_members/:user" }, { "method": "DELETE", "path": "/orgs/:org/public_members/:user" }, { "method": "GET", "path": "/orgs/:org/teams" }, { "method": "GET", "path": "/teams/:id" }, { "method": "POST", "path": 
"/orgs/:org/teams" }, { "method": "PATCH", "path": "/teams/:id" }, { "method": "DELETE", "path": "/teams/:id" }, { "method": "GET", "path": "/teams/:id/members" }, { "method": "GET", "path": "/teams/:id/members/:user" }, { "method": "PUT", "path": "/teams/:id/members/:user" }, { "method": "DELETE", "path": "/teams/:id/members/:user" }, { "method": "GET", "path": "/teams/:id/repos" }, { "method": "GET", "path": "/teams/:id/repos/:owner/:repo" }, { "method": "PUT", "path": "/teams/:id/repos/:owner/:repo" }, { "method": "DELETE", "path": "/teams/:id/repos/:owner/:repo" }, { "method": "GET", "path": "/user/teams" }, { "method": "GET", "path": "/repos/:owner/:repo/pulls" }, { "method": "GET", "path": "/repos/:owner/:repo/pulls/:number" }, { "method": "POST", "path": "/repos/:owner/:repo/pulls" }, { "method": "PATCH", "path": "/repos/:owner/:repo/pulls/:number" }, { "method": "GET", "path": "/repos/:owner/:repo/pulls/:number/commits" }, { "method": "GET", "path": "/repos/:owner/:repo/pulls/:number/files" }, { "method": "GET", "path": "/repos/:owner/:repo/pulls/:number/merge" }, { "method": "PUT", "path": "/repos/:owner/:repo/pulls/:number/merge" }, { "method": "GET", "path": "/repos/:owner/:repo/pulls/:number/comments" }, { "method": "GET", "path": "/repos/:owner/:repo/pulls/comments" }, { "method": "GET", "path": "/repos/:owner/:repo/pulls/comments/:number" }, { "method": "PUT", "path": "/repos/:owner/:repo/pulls/:number/comments" }, { "method": "PATCH", "path": "/repos/:owner/:repo/pulls/comments/:number" }, { "method": "DELETE", "path": "/repos/:owner/:repo/pulls/comments/:number" }, { "method": "GET", "path": "/user/repos" }, { "method": "GET", "path": "/users/:user/repos" }, { "method": "GET", "path": "/orgs/:org/repos" }, { "method": "GET", "path": "/repositories" }, { "method": "POST", "path": "/user/repos" }, { "method": "POST", "path": "/orgs/:org/repos" }, { "method": "GET", "path": "/repos/:owner/:repo" }, { "method": "PATCH", "path": "/repos/:owner/:repo" }, 
{ "method": "GET", "path": "/repos/:owner/:repo/contributors" }, { "method": "GET", "path": "/repos/:owner/:repo/languages" }, { "method": "GET", "path": "/repos/:owner/:repo/teams" }, { "method": "GET", "path": "/repos/:owner/:repo/tags" }, { "method": "GET", "path": "/repos/:owner/:repo/branches" }, { "method": "GET", "path": "/repos/:owner/:repo/branches/:branch" }, { "method": "DELETE", "path": "/repos/:owner/:repo" }, { "method": "GET", "path": "/repos/:owner/:repo/collaborators" }, { "method": "GET", "path": "/repos/:owner/:repo/collaborators/:user" }, { "method": "PUT", "path": "/repos/:owner/:repo/collaborators/:user" }, { "method": "DELETE", "path": "/repos/:owner/:repo/collaborators/:user" }, { "method": "GET", "path": "/repos/:owner/:repo/comments" }, { "method": "GET", "path": "/repos/:owner/:repo/commits/:sha/comments" }, { "method": "POST", "path": "/repos/:owner/:repo/commits/:sha/comments" }, { "method": "GET", "path": "/repos/:owner/:repo/comments/:id" }, { "method": "PATCH", "path": "/repos/:owner/:repo/comments/:id" }, { "method": "DELETE", "path": "/repos/:owner/:repo/comments/:id" }, { "method": "GET", "path": "/repos/:owner/:repo/commits" }, { "method": "GET", "path": "/repos/:owner/:repo/commits/:sha" }, { "method": "GET", "path": "/repos/:owner/:repo/readme" }, { "method": "GET", "path": "/repos/:owner/:repo/contents/*" }, { "method": "PUT", "path": "/repos/:owner/:repo/contents/*" }, { "method": "DELETE", "path": "/repos/:owner/:repo/contents/*" }, { "method": "GET", "path": "/repos/:owner/:repo/:archive_format/:ref" }, { "method": "GET", "path": "/repos/:owner/:repo/keys" }, { "method": "GET", "path": "/repos/:owner/:repo/keys/:id" }, { "method": "POST", "path": "/repos/:owner/:repo/keys" }, { "method": "PATCH", "path": "/repos/:owner/:repo/keys/:id" }, { "method": "DELETE", "path": "/repos/:owner/:repo/keys/:id" }, { "method": "GET", "path": "/repos/:owner/:repo/downloads" }, { "method": "GET", "path": "/repos/:owner/:repo/downloads/:id" 
}, { "method": "DELETE", "path": "/repos/:owner/:repo/downloads/:id" }, { "method": "GET", "path": "/repos/:owner/:repo/forks" }, { "method": "POST", "path": "/repos/:owner/:repo/forks" }, { "method": "GET", "path": "/repos/:owner/:repo/hooks" }, { "method": "GET", "path": "/repos/:owner/:repo/hooks/:id" }, { "method": "POST", "path": "/repos/:owner/:repo/hooks" }, { "method": "PATCH", "path": "/repos/:owner/:repo/hooks/:id" }, { "method": "POST", "path": "/repos/:owner/:repo/hooks/:id/tests" }, { "method": "DELETE", "path": "/repos/:owner/:repo/hooks/:id" }, { "method": "POST", "path": "/repos/:owner/:repo/merges" }, { "method": "GET", "path": "/repos/:owner/:repo/releases" }, { "method": "GET", "path": "/repos/:owner/:repo/releases/:id" }, { "method": "POST", "path": "/repos/:owner/:repo/releases" }, { "method": "PATCH", "path": "/repos/:owner/:repo/releases/:id" }, { "method": "DELETE", "path": "/repos/:owner/:repo/releases/:id" }, { "method": "GET", "path": "/repos/:owner/:repo/releases/:id/assets" }, { "method": "GET", "path": "/repos/:owner/:repo/stats/contributors" }, { "method": "GET", "path": "/repos/:owner/:repo/stats/commit_activity" }, { "method": "GET", "path": "/repos/:owner/:repo/stats/code_frequency" }, { "method": "GET", "path": "/repos/:owner/:repo/stats/participation" }, { "method": "GET", "path": "/repos/:owner/:repo/stats/punch_card" }, { "method": "GET", "path": "/repos/:owner/:repo/statuses/:ref" }, { "method": "POST", "path": "/repos/:owner/:repo/statuses/:ref" }, { "method": "GET", "path": "/search/repositories" }, { "method": "GET", "path": "/search/code" }, { "method": "GET", "path": "/search/issues" }, { "method": "GET", "path": "/search/users" }, { "method": "GET", "path": "/legacy/issues/search/:owner/:repository/:state/:keyword" }, { "method": "GET", "path": "/legacy/repos/search/:keyword" }, { "method": "GET", "path": "/legacy/user/search/:keyword" }, { "method": "GET", "path": "/legacy/user/email/:email" }, { "method": "GET", 
"path": "/users/:user" }, { "method": "GET", "path": "/user" }, { "method": "PATCH", "path": "/user" }, { "method": "GET", "path": "/users" }, { "method": "GET", "path": "/user/emails" }, { "method": "POST", "path": "/user/emails" }, { "method": "DELETE", "path": "/user/emails" }, { "method": "GET", "path": "/users/:user/followers" }, { "method": "GET", "path": "/user/followers" }, { "method": "GET", "path": "/users/:user/following" }, { "method": "GET", "path": "/user/following" }, { "method": "GET", "path": "/user/following/:user" }, { "method": "GET", "path": "/users/:user/following/:target_user" }, { "method": "PUT", "path": "/user/following/:user" }, { "method": "DELETE", "path": "/user/following/:user" }, { "method": "GET", "path": "/users/:user/keys" }, { "method": "GET", "path": "/user/keys" }, { "method": "GET", "path": "/user/keys/:id" }, { "method": "POST", "path": "/user/keys" }, { "method": "PATCH", "path": "/user/keys/:id" }, { "method": "DELETE", "path": "/user/keys/:id" } ] } ================================================ FILE: .github/testdata2/bruh.tmpl ================================================

I'm Bruh

================================================ FILE: .github/testdata3/hello_world.tmpl ================================================

Hello {{ .Name }}!

================================================ FILE: .github/workflows/auto-labeler.yml ================================================ name: auto-labeler on: issues: types: [opened, edited, milestoned] pull_request_target: types: [opened, edited, reopened, synchronize] jobs: auto-labeler: uses: gofiber/.github/.github/workflows/auto-labeler.yml@main secrets: github-token: ${{ secrets.ISSUE_PR_TOKEN }} with: config-path: .github/labeler.yml config-repository: gofiber/fiber ================================================ FILE: .github/workflows/benchmark.yml ================================================ on: workflow_dispatch: push: branches: - main paths-ignore: - "**/*.md" pull_request: paths-ignore: - "**/*.md" permissions: # deployments permission to deploy GitHub pages website deployments: write # contents permission to update benchmark contents in gh-pages branch contents: write # allow posting comments to pull request pull-requests: write name: Benchmark jobs: Compare: runs-on: ubuntu-latest steps: - name: Fetch Repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 # to be able to retrieve the last commit in main - name: Install Go uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: # NOTE: Keep this in sync with the version from go.mod go-version: "1.25.x" - name: Run Benchmark run: set -o pipefail; go test ./... -benchmem -run=^$ -bench . 
| tee output.txt - name: Remove _Parallel Benchmarks run: | awk '!/^Benchmark.*_Parallel/' output.txt > output_filtered.txt mv output_filtered.txt output.txt # NOTE: Benchmarks could change with different CPU types - name: Get GitHub Runner System Information uses: kenchan0130/actions-system-info@59699597e84e80085a750998045983daa49274c4 # v1.4.0 id: system-info - name: Get Main branch SHA id: get-main-branch-sha run: | SHA=$(git rev-parse origin/main) echo "sha=$SHA" >> $GITHUB_OUTPUT - name: Get Benchmark Results from main branch id: cache uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 with: path: ./cache key: ${{ steps.get-main-branch-sha.outputs.sha }}-${{ runner.os }}-${{ steps.system-info.outputs.cpu-model }}-benchmark # This will only run if we have Benchmark Results from main branch - name: Compare PR Benchmark Results with main branch uses: benchmark-action/github-action-benchmark@a7bc2366eda11037936ea57d811a43b3418d3073 # v1.21.0 if: steps.cache.outputs.cache-hit == 'true' with: tool: 'go' output-file-path: output.txt external-data-json-path: ./cache/benchmark-data.json # Do not save the data (This allows comparing benchmarks) save-data-file: false fail-on-alert: true # Comment on the PR if the branch is not a fork comment-on-alert: ${{ github.event.pull_request.head.repo.fork == false }} github-token: ${{ secrets.GITHUB_TOKEN }} summary-always: true alert-threshold: "150%" go-force-package-suffix: true - name: Store Benchmark Results for main branch uses: benchmark-action/github-action-benchmark@a7bc2366eda11037936ea57d811a43b3418d3073 # v1.21.0 if: ${{ github.ref_name == 'main' }} with: tool: 'go' output-file-path: output.txt external-data-json-path: ./cache/benchmark-data.json # Save the data to external file (cache) save-data-file: true fail-on-alert: false github-token: ${{ secrets.GITHUB_TOKEN }} summary-always: true alert-threshold: "150%" go-force-package-suffix: true - name: Publish Benchmark Results to GitHub Pages 
uses: benchmark-action/github-action-benchmark@a7bc2366eda11037936ea57d811a43b3418d3073 # v1.21.0 if: ${{ github.ref_name == 'main' }} with: tool: 'go' output-file-path: output.txt benchmark-data-dir-path: "benchmarks" fail-on-alert: false github-token: ${{ secrets.GITHUB_TOKEN }} comment-on-alert: true summary-always: true # Save the data to an external file (GitHub Pages) save-data-file: true alert-threshold: "150%" auto-push: ${{ github.event_name == 'push' || github.event_name == 'workflow_dispatch' }} go-force-package-suffix: true - name: Update Benchmark Results cache uses: actions/cache/save@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 if: ${{ github.ref_name == 'main' }} with: path: ./cache key: ${{ steps.get-main-branch-sha.outputs.sha }}-${{ runner.os }}-${{ steps.system-info.outputs.cpu-model }}-benchmark ================================================ FILE: .github/workflows/codeql-analysis.yml ================================================ name: "CodeQL" on: workflow_dispatch: push: branches: - main paths-ignore: - "**/*.md" pull_request: paths-ignore: - "**/*.md" schedule: - cron: "0 3 * * 6" jobs: analyse: name: Analyse runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: # We must fetch at least the immediate parents so that if this is # a pull request then we can check out the head. fetch-depth: 2 # If this run was triggered by a pull request event, then check out # the head of the pull request instead of the merge commit. - run: git checkout HEAD^2 if: ${{ github.event_name == 'pull_request' }} # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL uses: github/codeql-action/init@b1bff81932f5cdfc8695c7752dcee935dcd061c8 # v4.33.0 # Explicitly select the languages to analyze (overrides automatic language detection) with: languages: go # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild uses: github/codeql-action/autobuild@b1bff81932f5cdfc8695c7752dcee935dcd061c8 # v4.33.0 # ℹ️ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines # and modify them (or add more) to build your code if your project # uses a compiled language #- run: | # make bootstrap # make release - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@b1bff81932f5cdfc8695c7752dcee935dcd061c8 # v4.33.0 ================================================ FILE: .github/workflows/dependabot_automerge.yml ================================================ name: Dependabot auto-merge on: workflow_dispatch: pull_request_target: permissions: contents: write pull-requests: write jobs: wait_for_checks: runs-on: ubuntu-latest if: ${{ github.actor == 'dependabot[bot]' }} steps: - name: Wait for check is finished uses: lewagon/wait-on-check-action@v1.5.0 id: wait_for_checks with: ref: ${{ github.event.pull_request.head.sha || github.sha }} running-workflow-name: wait_for_checks check-regexp: unit repo-token: ${{ secrets.PR_TOKEN }} wait-interval: 10 dependabot: needs: [wait_for_checks] name: Dependabot auto-merge runs-on: ubuntu-latest if: ${{ github.actor == 'dependabot[bot]' }} steps: - name: Dependabot metadata id: metadata uses: dependabot/fetch-metadata@v2.5.0 with: github-token: "${{ secrets.PR_TOKEN }}" - name: Enable auto-merge for Dependabot PRs if: ${{steps.metadata.outputs.update-type == 'version-update:semver-minor' || steps.metadata.outputs.update-type == 'version-update:semver-patch'}} run: | gh pr review --approve "$PR_URL" gh pr merge --auto --merge "$PR_URL" env: PR_URL: ${{github.event.pull_request.html_url}} GITHUB_TOKEN: ${{secrets.PR_TOKEN}} ================================================ FILE: .github/workflows/linter.yml 
================================================ name: golangci-lint on: workflow_dispatch: push: branches: - main paths-ignore: - "**/*.md" pull_request: paths-ignore: - "**/*.md" permissions: # Required: allow read access to the content for analysis. contents: read # Optional: allow read access to pull request. Use with `only-new-issues` option. pull-requests: read # Optional: Allow write access to checks to allow the action to annotate code in the PR. checks: write jobs: golangci: name: lint runs-on: ubuntu-latest steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: # NOTE: Keep this in sync with the version from go.mod go-version: "1.25.x" cache: false - name: golangci-lint uses: golangci/golangci-lint-action@1e7e51e771db61008b38414a730f564565cf7c20 # v9.2.0 with: # NOTE: Keep this in sync with the version from .golangci.yml version: v2.5.0 install-mode: goinstall ================================================ FILE: .github/workflows/manual-dependabot.yml ================================================ # https://github.com/dependabot/dependabot-script/blob/main/manual-github-actions.yaml # https://github.com/dependabot/dependabot-script?tab=readme-ov-file#github-actions-standalone name: ManualDependabot on: workflow_dispatch: inputs: package-manager: description: 'The package manager to use' required: true default: 'gomod' directory: description: 'The directory to scan' required: true default: '/' permissions: contents: read jobs: dependabot: permissions: contents: write # for Git to git push pull-requests: write # for repo-sync/pull-request to create pull requests runs-on: ubuntu-latest steps: - name: Checkout repo uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Checkout dependabot run: | cd /tmp/ git clone https://github.com/dependabot/dependabot-script - name: Build image run: | cd /tmp/dependabot-script docker 
build -t "dependabot/dependabot-script" -f Dockerfile . - name: Run dependabot env: PACKAGE_MANAGER: ${{ github.event.inputs.package-manager }} DIRECTORY: ${{ github.event.inputs.directory }} GITHUB_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | docker run -v $PWD:/src -e PROJECT_PATH=$GITHUB_REPOSITORY -e PACKAGE_MANAGER=$PACKAGE_MANAGER -e DIRECTORY=$DIRECTORY -e GITHUB_ACCESS_TOKEN=$GITHUB_ACCESS_TOKEN -e OPTIONS="$OPTIONS" dependabot/dependabot-script ================================================ FILE: .github/workflows/markdown.yml ================================================ name: markdownlint on: workflow_dispatch: push: branches: - main paths: - "**/*.md" pull_request: paths: - "**/*.md" jobs: markdownlint: runs-on: ubuntu-latest steps: - name: Fetch Repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Run markdownlint-cli2 uses: DavidAnson/markdownlint-cli2-action@07035fd053f7be764496c0f8d8f9f41f98305101 # v22.0.0 with: globs: | **/*.md #vendor ================================================ FILE: .github/workflows/modernize.yml ================================================ name: Modernize Lint on: workflow_dispatch: push: branches: - main paths-ignore: - "**/*.md" - "**/*_msgp*.go" pull_request: paths-ignore: - "**/*.md" - "**/*_msgp*.go" permissions: contents: read pull-requests: write checks: write jobs: modernize: runs-on: ubuntu-latest steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: # NOTE: Keep this in sync with the version from go.mod go-version: "1.25.x" cache: false - name: modernize run: go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -test=false ./... 
================================================ FILE: .github/workflows/move-closed-milestone-items.yml ================================================ name: Move closed milestone items on: workflow_dispatch: inputs: source_milestone: description: Milestone that currently owns the closed items required: true type: string target_milestone: description: Milestone that should receive the closed items required: true type: string permissions: contents: read issues: write pull-requests: write jobs: move-closed-items: runs-on: ubuntu-latest steps: - name: Move closed items to target milestone uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ secrets.ISSUE_PR_TOKEN }} script: | const dispatchInputs = context.payload.inputs ?? {}; const sourceTitle = (dispatchInputs.source_milestone ?? '').trim(); const targetTitle = (dispatchInputs.target_milestone ?? '').trim(); if (sourceTitle.length === 0 || targetTitle.length === 0) { throw new Error('Both source_milestone and target_milestone must be non-empty.'); } if (sourceTitle === targetTitle) { throw new Error('source_milestone and target_milestone must be different.'); } const owner = context.repo.owner; const repo = context.repo.repo; async function listMilestones() { return github.paginate(github.rest.issues.listMilestones, { owner, repo, state: 'all', per_page: 100, }); } function findMilestoneByTitle(milestones, title) { return milestones.find((milestone) => milestone.title === title); } let milestones = await listMilestones(); const sourceMilestone = findMilestoneByTitle(milestones, sourceTitle); if (!sourceMilestone) { throw new Error(`Source milestone "${sourceTitle}" was not found.`); } let targetMilestone = findMilestoneByTitle(milestones, targetTitle); if (!targetMilestone) { const createdMilestone = await github.rest.issues.createMilestone({ owner, repo, title: targetTitle, }); targetMilestone = createdMilestone.data; core.info(`Created target milestone 
"${targetTitle}" (#${targetMilestone.number}).`); } else if (targetMilestone.state !== 'open') { const reopenedMilestone = await github.rest.issues.updateMilestone({ owner, repo, milestone_number: targetMilestone.number, state: 'open', }); targetMilestone = reopenedMilestone.data; core.info(`Reopened target milestone "${targetTitle}" (#${targetMilestone.number}).`); } const closedItems = await github.paginate(github.rest.issues.listForRepo, { owner, repo, state: 'closed', milestone: String(sourceMilestone.number), per_page: 100, }); if (closedItems.length === 0) { core.notice(`No closed items were found in milestone "${sourceTitle}".`); return; } for (const item of closedItems) { await github.rest.issues.update({ owner, repo, issue_number: item.number, milestone: targetMilestone.number, }); const itemType = item.pull_request ? 'pull request' : 'issue'; core.info(`Moved ${itemType} #${item.number} to milestone "${targetTitle}".`); } core.notice( `Moved ${closedItems.length} closed item(s) from "${sourceTitle}" to "${targetTitle}".`, ); ================================================ FILE: .github/workflows/release-drafter.yml ================================================ name: Release Drafter on: push: branches: - main workflow_dispatch: permissions: contents: read jobs: update_release_draft: permissions: # write permission is required to create a GitHub release contents: write # write permission is required for the autolabeler; # otherwise, at least read permission is required pull-requests: read runs-on: ubuntu-latest steps: - uses: release-drafter/release-drafter@139054aeaa9adc52ab36ddf67437541f039b88e2 # v7.1.1 with: disable-autolabeler: true env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .github/workflows/spell-check.yml ================================================ name: Spell check on: workflow_dispatch: pull_request: types: - opened - synchronize - reopened - ready_for_review push: branches: - main
permissions: contents: read pull-requests: read jobs: cspell: name: cspell runs-on: ubuntu-latest steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Set up Node.js uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 with: node-version: "20.x" - name: Install cspell dictionaries run: | npm install --no-save \ @cspell/dict-en_us \ @cspell/dict-en-gb \ @cspell/dict-software-terms \ @cspell/dict-golang \ @cspell/dict-fullstack \ @cspell/dict-docker \ @cspell/dict-k8s \ @cspell/dict-node \ @cspell/dict-npm \ @cspell/dict-typescript \ @cspell/dict-html \ @cspell/dict-css \ @cspell/dict-shell \ @cspell/dict-python \ @cspell/dict-redis \ @cspell/dict-sql \ @cspell/dict-filetypes \ @cspell/dict-companies \ @cspell/dict-markdown \ @cspell/dict-en-common-misspellings \ @cspell/dict-people-names \ @cspell/dict-data-science - name: Run cspell uses: streetsidesoftware/cspell-action@9cd41bb518a24fefdafd9880cbab8f0ceba04d28 # v8.3.0 with: incremental_files_only: false check_dot_files: explicit report: typos verbose: true - name: Run codespell uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2 with: skip: ./.git,./node_modules,./**/*.go,./*.go,./.github/workflows/spell-check.yml ignore_words_list: TE,te ================================================ FILE: .github/workflows/sync-docs.yml ================================================ name: "Sync docs" on: workflow_dispatch: push: branches: - main paths: - "docs/**" release: types: [published] jobs: sync-docs: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: ref: ${{ github.event.pull_request.head.sha }} fetch-depth: 2 - name: Setup Node.js environment uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 with: node-version: "22.x" - name: Sync docs run: ./.github/scripts/sync_docs.sh env: EVENT: ${{ github.event_name }} TAG_NAME: ${{ 
github.ref_name }} TOKEN: ${{ secrets.DOC_SYNC_TOKEN }} ================================================ FILE: .github/workflows/test.yml ================================================ name: Test on: workflow_dispatch: push: branches: - main paths-ignore: - "**/*.md" pull_request: paths-ignore: - "**/*.md" jobs: unit: strategy: matrix: go-version: [1.25.x, 1.26.x] platform: [ubuntu-latest, windows-latest, macos-latest] runs-on: ${{ matrix.platform }} steps: - name: Fetch Repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install Go uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: go-version: ${{ matrix.go-version }} - name: Test run: go run gotest.tools/gotestsum@latest -f testname -- ./... -race -count=1 -coverprofile=coverage.txt -covermode=atomic -shuffle=on - name: Upload coverage reports to Codecov if: ${{ matrix.platform == 'ubuntu-latest' && matrix.go-version == '1.25.x' }} uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3 with: token: ${{ secrets.CODECOV_TOKEN }} flags: unittests slug: gofiber/fiber verbose: true repeated: runs-on: ubuntu-latest steps: - name: Fetch Repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install Go uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: go-version: stable - name: Test run: go run gotest.tools/gotestsum@latest -f testname -- ./... 
-race -count=15 -shuffle=on ================================================ FILE: .github/workflows/v3-label-automation.yml ================================================ name: Assign v3 project and milestone on: issues: types: - labeled pull_request_target: types: - labeled permissions: contents: read issues: write pull-requests: write jobs: assign-v3: if: ${{ github.event.label && github.event.label.name == 'v3' }} runs-on: ubuntu-latest steps: - name: Add item to v3 project uses: actions/add-to-project@244f685bbc3b7adfa8466e08b698b5577571133e # v1.0.2 with: project-url: https://github.com/orgs/gofiber/projects/1 github-token: ${{ secrets.ISSUE_PR_TOKEN }} - name: Assign v3 milestone uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ secrets.ISSUE_PR_TOKEN }} script: | const payload = context.eventName === 'issues' ? context.payload.issue : context.payload.pull_request; const issueNumber = payload.number; const milestones = await github.paginate(github.rest.issues.listMilestones, { owner: context.repo.owner, repo: context.repo.repo, state: 'open', per_page: 100, }); const milestone = milestones.find((item) => item.title === 'v3'); if (!milestone) { throw new Error('Milestone "v3" was not found.'); } await github.rest.issues.update({ owner: context.repo.owner, repo: context.repo.repo, issue_number: issueNumber, milestone: milestone.number, }); ================================================ FILE: .github/workflows/vulncheck.yml ================================================ name: Run govulncheck on: workflow_dispatch: push: branches: - main paths-ignore: - "**/*.md" pull_request: paths-ignore: - "**/*.md" jobs: govulncheck-check: runs-on: ubuntu-latest env: GO111MODULE: on steps: - name: Fetch Repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Install Go uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: go-version: "stable" check-latest: 
true cache: false - name: Install Govulncheck run: go install golang.org/x/vuln/cmd/govulncheck@latest - name: Run Govulncheck run: govulncheck ./... ================================================ FILE: .gitignore ================================================ # Binaries for programs and plugins *.exe *.exe~ *.dll *.so *.dylib # Test binary, built with `go test -c` *.test *.tmp # Output of the go coverage tool **/*.out .cache .gocache # IDE files .vscode .DS_Store .idea .claude # Misc *.fiber.gz *.fiber.zst *.fiber.br *.fasthttp.gz *.fasthttp.zst *.fasthttp.br *.test.gz *.test.zst *.test.br *.pprof *.workspace # Dependencies /vendor/ vendor/ vendor /Godeps/ # Local tools bin/ ================================================ FILE: .golangci.yml ================================================ version: "2" run: modules-download-mode: readonly allow-serial-runners: true linters: enable: - asasalint - asciicheck - bidichk - bodyclose - containedctx - contextcheck - copyloopvar - decorder - depguard - dogsled - dupword - durationcheck - err113 - errchkjson - errname - errorlint - exhaustive - forbidigo - forcetypeassert - ginkgolinter - gochecksumtype - goconst - gocritic - gomoddirectives - goprintffuncname - gosec - grouper - loggercheck - makezero - mirror - misspell - musttag - nakedret - nilerr - nilnil # - noctx # TODO: enable this once the codebase is migrated to context aware APIs - nolintlint - nonamedreturns - nosprintfhostport - perfsprint - predeclared - promlinter - protogetter - reassign - revive - rowserrcheck - sloglint - spancheck - sqlclosecheck - staticcheck - tagliatelle - testableexamples - testifylint - thelper - tparallel - unconvert - unparam - usestdlibvars - whitespace - wrapcheck - zerologlint settings: depguard: rules: all: list-mode: lax deny: - pkg: flag desc: "`flag` package is only allowed in main.go" - pkg: log desc: logging is provided by `pkg/log` - pkg: io/ioutil desc: "`io/ioutil` package is deprecated, use the `io` and `os` 
package instead" errcheck: disable-default-exclusions: true check-type-assertions: true check-blank: true exclude-functions: - (*bytes.Buffer).Write - (*github.com/valyala/bytebufferpool.ByteBuffer).Write - (*github.com/valyala/bytebufferpool.ByteBuffer).WriteByte - (*github.com/valyala/bytebufferpool.ByteBuffer).WriteString errchkjson: report-no-exported: true exhaustive: default-signifies-exhaustive: true forbidigo: forbid: - pattern: ^print(ln)?$ - pattern: ^fmt\.Print(f|ln)?$ - pattern: ^http\.Default(Client|ServeMux|Transport)$ analyze-types: true goconst: numbers: true gocritic: enabled-tags: - diagnostic - style - performance settings: captLocal: paramsOnly: false elseif: skipBalanced: false underef: skipRecvDeref: false gosec: excludes: - G104 config: global: audit: true govet: enable-all: true grouper: import-require-single-import: true import-require-grouping: true loggercheck: require-string-key: true no-printf-like: true misspell: locale: US nolintlint: require-explanation: true require-specific: true nonamedreturns: report-error-in-defer: true perfsprint: err-error: true predeclared: qualified-name: true promlinter: strict: true revive: enable-all-rules: true rules: - name: add-constant disabled: true - name: argument-limit disabled: true - name: banned-characters disabled: true - name: cognitive-complexity disabled: true - name: confusing-results disabled: true - name: comment-spacings arguments: - nolint disabled: true - name: cyclomatic disabled: true - name: enforce-slice-style arguments: - make disabled: true - name: exported disabled: true - name: file-header disabled: true - name: function-result-limit arguments: - 3 - name: function-length disabled: true - name: line-length-limit disabled: true - name: max-public-structs disabled: true - name: modifies-parameter disabled: true - name: nested-structs disabled: true - name: package-comments disabled: true - name: optimize-operands-order disabled: true - name: unchecked-type-assertion disabled: 
true - name: unhandled-error disabled: true staticcheck: checks: - all - -ST1000 - -ST1020 - -ST1021 - -ST1022 tagalign: strict: true tagliatelle: case: rules: json: snake testifylint: enable-all: true testpackage: skip-regexp: ^$ unparam: check-exported: false unused: field-writes-are-uses: true exported-fields-are-used: true usestdlibvars: http-method: true http-status-code: true time-weekday: false time-month: false time-layout: false crypto-hash: true default-rpc-path: true sql-isolation-level: true tls-signature-scheme: true constant-kind: true wrapcheck: ignore-package-globs: - github.com/gofiber/fiber/* - github.com/valyala/fasthttp exclusions: generated: lax rules: - text: (?i)do not define dynamic errors, use wrapped static errors instead* linters: - err113 - path: log/.*\.go linters: - depguard - path: _test\.go linters: - bodyclose - err113 - goconst # disabling goconst in test files only - source: (?i)fmt.Fprintf? linters: - errcheck - revive paths: - _msgp\.go - _msgp_test\.go - third_party$ - builtin$ - examples$ issues: max-issues-per-linter: 0 max-same-issues: 0 formatters: enable: - gofmt - gofumpt - goimports settings: gci: sections: - standard - prefix(github.com/gofiber/fiber) - default - blank - dot custom-order: true gofumpt: module-path: github.com/gofiber/fiber extra-rules: true exclusions: generated: lax paths: - _msgp\.go - _msgp_test\.go - third_party$ - builtin$ - examples$ ================================================ FILE: .markdownlint.yml ================================================ # Example markdownlint configuration with all properties set to their default value # Default state for all rules default: true # Path to configuration file to extend extends: null # MD001/heading-increment : Heading levels should only increment by one level at a time : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md001.md # NOTE: The docs intentionally jump heading levels for anchor stability, so skip this rule globally. 
MD001: false # MD003/heading-style : Heading style : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md003.md MD003: # Heading style style: "consistent" # MD004/ul-style : Unordered list style : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md004.md MD004: # List style style: "consistent" # MD005/list-indent : Inconsistent indentation for list items at the same level : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md005.md MD005: true # MD007/ul-indent : Unordered list indentation : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md007.md MD007: # Spaces for indent indent: # Whether to indent the first level of the list start_indented: false # Spaces for first level indent (when start_indented is set) start_indent: 2 # MD009/no-trailing-spaces : Trailing spaces : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md009.md MD009: # Spaces for line break br_spaces: 2 # Allow spaces for empty lines in list items list_item_empty_lines: false # Include unnecessary breaks strict: true # MD010/no-hard-tabs : Hard tabs : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md010.md MD010: # Include code blocks code_blocks: true # Fenced code languages to ignore ignore_code_languages: [] # Number of spaces for each hard tab spaces_per_tab: 4 # MD011/no-reversed-links : Reversed link syntax : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md011.md MD011: true # MD012/no-multiple-blanks : Multiple consecutive blank lines : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md012.md MD012: # Consecutive blank lines maximum: 1 # MD013/line-length : Line length : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md013.md MD013: false # MD014/commands-show-output : Dollar signs used before commands without showing output : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md014.md MD014: true # MD018/no-missing-space-atx : No space after hash on atx style heading : 
https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md018.md MD018: true # MD019/no-multiple-space-atx : Multiple spaces after hash on atx style heading : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md019.md MD019: true # MD020/no-missing-space-closed-atx : No space inside hashes on closed atx style heading : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md020.md MD020: true # MD021/no-multiple-space-closed-atx : Multiple spaces inside hashes on closed atx style heading : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md021.md MD021: true # MD022/blanks-around-headings : Headings should be surrounded by blank lines : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md022.md MD022: # Blank lines above heading lines_above: 1 # Blank lines below heading lines_below: 1 # MD023/heading-start-left : Headings must start at the beginning of the line : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md023.md MD023: true # MD024/no-duplicate-heading : Multiple headings with the same content : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md024.md MD024: false # MD025/single-title/single-h1 : Multiple top-level headings in the same document : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md025.md MD025: # Heading level level: 1 # RegExp for matching title in front matter front_matter_title: "^\\s*title\\s*[:=]" # MD026/no-trailing-punctuation : Trailing punctuation in heading : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md026.md MD026: # Punctuation characters punctuation: ".,;:!。,;:!" 
# MD027/no-multiple-space-blockquote : Multiple spaces after blockquote symbol : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md027.md MD027: true # MD028/no-blanks-blockquote : Blank line inside blockquote : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md028.md MD028: true # MD029/ol-prefix : Ordered list item prefix : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md029.md MD029: # List style style: "one_or_ordered" # MD030/list-marker-space : Spaces after list markers : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md030.md MD030: # Spaces for single-line unordered list items ul_single: 1 # Spaces for single-line ordered list items ol_single: 1 # Spaces for multi-line unordered list items ul_multi: 1 # Spaces for multi-line ordered list items ol_multi: 1 # MD031/blanks-around-fences : Fenced code blocks should be surrounded by blank lines : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md031.md MD031: # Include list items list_items: true # MD032/blanks-around-lists : Lists should be surrounded by blank lines : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md032.md MD032: true # MD033/no-inline-html : Inline HTML : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md033.md MD033: false # MD034/no-bare-urls : Bare URL used : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md034.md MD034: true # MD035/hr-style : Horizontal rule style : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md035.md MD035: # Horizontal rule style style: "consistent" # MD036/no-emphasis-as-heading : Emphasis used instead of a heading : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md036.md MD036: # Punctuation characters punctuation: ".,;:!?。,;:!?" 
# MD037/no-space-in-emphasis : Spaces inside emphasis markers : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md037.md MD037: true # MD038/no-space-in-code : Spaces inside code span elements : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md038.md MD038: true # MD039/no-space-in-links : Spaces inside link text : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md039.md MD039: true # MD040/fenced-code-language : Fenced code blocks should have a language specified : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md040.md MD040: # List of languages allowed_languages: [] # Require language only language_only: false # MD041/first-line-heading/first-line-h1 : First line in a file should be a top-level heading : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md041.md MD041: # Heading level level: 1 # RegExp for matching title in front matter front_matter_title: "^\\s*title\\s*[:=]" # MD042/no-empty-links : No empty links : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md042.md MD042: true # MD043/required-headings : Required heading structure : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md043.md MD043: false # MD044/proper-names : Proper names should have the correct capitalization : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md044.md MD044: # List of proper names names: [] # Include code blocks code_blocks: true # Include HTML elements html_elements: true # MD045/no-alt-text : Images should have alternate text (alt text) : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md045.md MD045: false # MD046/code-block-style : Code block style : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md046.md MD046: # Block style style: "fenced" # MD047/single-trailing-newline : Files should end with a single newline character : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md047.md MD047: true # MD048/code-fence-style : Code fence style : 
https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md048.md MD048: # Code fence style style: "backtick" # MD049/emphasis-style : Emphasis style : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md049.md MD049: # Emphasis style style: "consistent" # MD050/strong-style : Strong style : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md050.md MD050: # Strong style style: "consistent" # MD051/link-fragments : Link fragments should be valid : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md051.md MD051: true # MD052/reference-links-images : Reference links and images should use a label that is defined : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md052.md MD052: # Include shortcut syntax shortcut_syntax: false # MD053/link-image-reference-definitions : Link and image reference definitions should be needed : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md053.md MD053: # Ignored definitions ignored_definitions: - "//" # MD054/link-image-style : Link and image style : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md054.md MD054: # Allow autolinks autolink: false # Allow inline links and images inline: true # Allow full reference links and images full: true # Allow collapsed reference links and images collapsed: true # Allow shortcut reference links and images shortcut: true # Allow URLs as inline links url_inline: true # MD055/table-pipe-style : Table pipe style : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md055.md MD055: # Table pipe style style: "consistent" # MD060/table-column-style : Table column style : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md060.md # NOTE: Legacy docs rely on hand-crafted alignment, so disable this rule to avoid noisy warnings. 
MD060: false # MD056/table-column-count : Table column count : https://github.com/DavidAnson/markdownlint/blob/v0.34.0/doc/md056.md MD056: true ================================================ FILE: AGENTS.md ================================================ # AGENTS.md ## Agent Instructions > **Note for the AI agent:** Before running the startup script, verify that all dependencies and CLI tools are already available. If a `vendor/` directory exists **or** `go list -m -mod=readonly all` succeeds (indicating modules are cached), you may skip the entire startup script—including all `go mod` commands and `go install` steps. Otherwise, run the script as listed below. ### General coding practices - When adding Go tests, always invoke `t.Parallel()` at the start of each test and subtest to maximize concurrency. - Prefer `github.com/gofiber/utils/v2` helpers (for example, `utils.Trim`) when performing common operations such as string manipulation, whenever it is practical and appropriate for the surrounding code. - Keep all protocol behavior RFC-compliant (e.g., HTTP/1.1 requirements) and document any intentional deviations. - Protect hot paths from regressions: profile changes. - Apply secure-by-default choices (validation, timeouts, sanitization) and ensure new code hardens attack surfaces. 
--- ## Startup script (reference only – do not run) - Fetch dependencies: ```bash go mod tidy && go mod download && go mod vendor ``` - Install CLI tools referenced in Makefile: ```bash go install gotest.tools/gotestsum@latest # test runner go install golang.org/x/vuln/cmd/govulncheck@latest # vulnerability scanner go install mvdan.cc/gofumpt@latest # code formatter go install github.com/tinylib/msgp@latest # msgp codegen go install github.com/vburenin/ifacemaker@f30b6f9bdbed4b5c4804ec9ba4a04a999525c202 # interface impls go install github.com/dkorunic/betteralign/cmd/betteralign@latest # struct alignment go install golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest go mod tidy # clean up go.mod & go.sum ``` ## Makefile commands Use `make help` to list all available commands. Common targets include: - **audit**: run `go mod verify`, `go vet`, and `govulncheck` for quality checks. - **benchmark**: run benchmarks with `go test`. - **coverage**: generate a coverage report. - **format**: apply formatting using `gofumpt`. - **lint**: execute `golangci-lint`. - **test**: run the test suite with `gotestsum`. - **longtest**: run the test suite 15 times with shuffling enabled. - **tidy**: clean and tidy dependencies. - **betteralign**: optimize struct field alignment. - **generate**: run `go generate` after installing msgp and ifacemaker. - **modernize**: run gopls modernize These targets can be invoked by running `make` followed by the target name (for example, `make test`) as needed during development and testing. ## Pull request guidelines - PR titles must start with a category prefix describing the change: `🐛 bug:`, `🔥 feat:`, `📒 docs:`, or `🧹 chore:`. - Generated PR titles and bodies must summarize the *entire* set of changes on the branch (for example, based on `git log --oneline main..HEAD` or the full diff), **not** just the latest commit. The Summary section should reflect all modifications that will be merged.
## Programmatic checks Before presenting final changes or submitting a pull request, run each of the following commands and ensure they succeed. Include the command outputs in your final response to confirm they were executed: ```bash make audit make generate make betteralign make modernize make format make lint make test ``` All checks must pass before the generated code can be merged. After completing the programmatic checks above, confirm that any relevant documentation has been updated to reflect the changes made, including PR instructions when applicable. ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2019-present Fenny and Contributors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: Makefile ================================================ GOVERSION ?= $(shell go env GOVERSION) ## help: 💡 Display available commands .PHONY: help help: @echo '⚡️ GoFiber/Fiber Development:' @sed -n 's/^##//p' ${MAKEFILE_LIST} | column -t -s ':' | sed -e 's/^/ /' ## audit: 🚀 Conduct quality checks .PHONY: audit audit: go mod verify go vet ./... GOTOOLCHAIN=$(GOVERSION) go run golang.org/x/vuln/cmd/govulncheck@latest ./... ## benchmark: 📈 Benchmark code performance .PHONY: benchmark benchmark: go test ./... -benchmem -bench=. -run=^Benchmark_$ ## coverage: ☂️ Generate coverage report .PHONY: coverage coverage: GOTOOLCHAIN=$(GOVERSION) go run gotest.tools/gotestsum@latest -f testname -- ./... -race -count=1 -coverprofile=/tmp/coverage.out -covermode=atomic go tool cover -html=/tmp/coverage.out ## format: 🎨 Fix code format issues .PHONY: format format: GOTOOLCHAIN=$(GOVERSION) go run mvdan.cc/gofumpt@latest -w -l . ## markdown: 🎨 Find markdown format issues (Requires markdownlint-cli2) .PHONY: markdown markdown: @which markdownlint-cli2 > /dev/null || npm install -g markdownlint-cli2 markdownlint-cli2 "**/*.md" "#vendor" ## lint: 🚨 Run lint checks .PHONY: lint lint: GOTOOLCHAIN=$(GOVERSION) go run github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.5.0 run ./... ## modernize: 🛠 Run gopls modernize .PHONY: modernize modernize: GOTOOLCHAIN=$(GOVERSION) go run golang.org/x/tools/gopls/internal/analysis/modernize/cmd/modernize@latest -fix -test=false ./... ## test: 🚦 Execute all tests .PHONY: test test: GOTOOLCHAIN=$(GOVERSION) go run gotest.tools/gotestsum@latest -f testname -- ./... -race -count=1 -shuffle=on ## longtest: 🚦 Execute all tests 15x .PHONY: longtest longtest: GOTOOLCHAIN=$(GOVERSION) go run gotest.tools/gotestsum@latest -f testname -- ./...
-race -count=15 -shuffle=on ## tidy: 📌 Clean and tidy dependencies .PHONY: tidy tidy: go mod tidy -v ## betteralign: 📐 Optimize alignment of fields in structs .PHONY: betteralign betteralign: GOTOOLCHAIN=$(GOVERSION) go run github.com/dkorunic/betteralign/cmd/betteralign@v0.8.0 -test_files -generated_files -apply ./... ## generate: ⚡️ Generate msgp && interface implementations .PHONY: generate generate: go install github.com/tinylib/msgp@latest go install github.com/vburenin/ifacemaker@f30b6f9bdbed4b5c4804ec9ba4a04a999525c202 go generate ./... # actionspin: 🤖 Bulk replace GitHub actions references from version tags to commit hashes .PHONY: actionspin actionspin: GOTOOLCHAIN=$(GOVERSION) go run github.com/mashiike/actionspin/cmd/actionspin@latest ================================================ FILE: adapter.go ================================================ package fiber import ( "fmt" "net/http" "reflect" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttpadaptor" ) // toFiberHandler converts a supported handler type to a Fiber handler. 
func toFiberHandler(handler any) (Handler, bool) { if handler == nil { return nil, false } switch handler.(type) { case Handler, func(Ctx): // (1)-(2) Fiber handlers return adaptFiberHandler(handler) case func(Req, Res) error, func(Req, Res), func(Req, Res, func() error) error, func(Req, Res, func() error), func(Req, Res, func()) error, func(Req, Res, func()), func(Req, Res, func(error)), func(Req, Res, func(error)) error, func(Req, Res, func(error) error), func(Req, Res, func(error) error) error: // (3)-(12) Express-style request handlers return adaptExpressHandler(handler) case http.HandlerFunc, http.Handler, func(http.ResponseWriter, *http.Request): // (13)-(15) net/http handlers return adaptHTTPHandler(handler) case fasthttp.RequestHandler, func(*fasthttp.RequestCtx) error: // (16)-(17) fasthttp handlers return adaptFastHTTPHandler(handler) default: // (18) unsupported handler type return nil, false } } func adaptFiberHandler(handler any) (Handler, bool) { switch h := handler.(type) { case Handler: // (1) direct Fiber handler if h == nil { return nil, false } return h, true case func(Ctx): // (2) Fiber handler without error return if h == nil { return nil, false } return func(c Ctx) error { h(c) return nil }, true default: return nil, false } } func adaptExpressHandler(handler any) (Handler, bool) { switch h := handler.(type) { case func(Req, Res) error: // (3) Express-style handler with error return if h == nil { return nil, false } return func(c Ctx) error { return h(c.Req(), c.Res()) }, true case func(Req, Res): // (4) Express-style handler without error return if h == nil { return nil, false } return func(c Ctx) error { h(c.Req(), c.Res()) return nil }, true case func(Req, Res, func() error) error: // (5) Express-style handler with error-returning next callback and error return if h == nil { return nil, false } return func(c Ctx) error { return h(c.Req(), c.Res(), func() error { return c.Next() }) }, true case func(Req, Res, func() error): // (6) 
Express-style handler with error-returning next callback if h == nil { return nil, false } return func(c Ctx) error { var nextErr error h(c.Req(), c.Res(), func() error { nextErr = c.Next() return nextErr }) return nextErr }, true case func(Req, Res, func()) error: // (7) Express-style handler with no-arg next callback and error return if h == nil { return nil, false } return func(c Ctx) error { var nextErr error err := h(c.Req(), c.Res(), func() { nextErr = c.Next() }) if err != nil { return err } return nextErr }, true case func(Req, Res, func()): // (8) Express-style handler with no-arg next callback if h == nil { return nil, false } return func(c Ctx) error { var nextErr error h(c.Req(), c.Res(), func() { nextErr = c.Next() }) return nextErr }, true case func(Req, Res, func(error)): // (9) Express-style handler with error-accepting next callback if h == nil { return nil, false } return func(c Ctx) error { var nextErr error h(c.Req(), c.Res(), func(err error) { if err != nil { nextErr = err return } nextErr = c.Next() }) return nextErr }, true case func(Req, Res, func(error)) error: // (10) Express-style handler with error-accepting next callback and error return if h == nil { return nil, false } return func(c Ctx) error { var nextErr error err := h(c.Req(), c.Res(), func(nextErrArg error) { if nextErrArg != nil { nextErr = nextErrArg return } nextErr = c.Next() }) if err != nil { return err } return nextErr }, true case func(Req, Res, func(error) error): // (11) Express-style handler with error-accepting next callback that returns an error if h == nil { return nil, false } return func(c Ctx) error { var nextErr error h(c.Req(), c.Res(), func(nextErrArg error) error { if nextErrArg != nil { nextErr = nextErrArg return nextErrArg } nextErr = c.Next() return nextErr }) return nextErr }, true case func(Req, Res, func(error) error) error: // (12) Express-style handler with error-accepting next callback that returns an error and error return if h == nil { return nil, 
false } return func(c Ctx) error { var nextErr error err := h(c.Req(), c.Res(), func(nextErrArg error) error { if nextErrArg != nil { nextErr = nextErrArg return nextErrArg } nextErr = c.Next() return nextErr }) if err != nil { return err } return nextErr }, true default: return nil, false } } func adaptHTTPHandler(handler any) (Handler, bool) { switch h := handler.(type) { case http.HandlerFunc: // (13) net/http HandlerFunc if h == nil { return nil, false } return wrapHTTPHandler(h), true case http.Handler: // (14) net/http Handler implementation if h == nil { return nil, false } hv := reflect.ValueOf(h) if isNilableKind(hv.Kind()) && hv.IsNil() { return nil, false } return wrapHTTPHandler(h), true case func(http.ResponseWriter, *http.Request): // (15) net/http function handler if h == nil { return nil, false } return wrapHTTPHandler(http.HandlerFunc(h)), true default: return nil, false } } func isNilableKind(kind reflect.Kind) bool { switch kind { case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.Interface, reflect.Slice, reflect.UnsafePointer: return true default: return false } } func adaptFastHTTPHandler(handler any) (Handler, bool) { switch h := handler.(type) { case fasthttp.RequestHandler: // (16) fasthttp handler if h == nil { return nil, false } return func(c Ctx) error { h(c.RequestCtx()) return nil }, true case func(*fasthttp.RequestCtx) error: // (17) fasthttp handler with error return if h == nil { return nil, false } return func(c Ctx) error { return h(c.RequestCtx()) }, true default: return nil, false } } // wrapHTTPHandler adapts a net/http handler to a Fiber handler. func wrapHTTPHandler(handler http.Handler) Handler { if handler == nil { return nil } adapted := fasthttpadaptor.NewFastHTTPHandler(handler) return func(c Ctx) error { adapted(c.RequestCtx()) return nil } } // collectHandlers converts a slice of handler arguments to Fiber handlers. 
// The context string is used to provide informative panic messages when an // unsupported handler type is encountered. func collectHandlers(context string, args ...any) []Handler { handlers := make([]Handler, 0, len(args)) for i, arg := range args { handler, ok := toFiberHandler(arg) if !ok { panic(fmt.Sprintf("%s: invalid handler #%d (%T)\n", context, i, arg)) } handlers = append(handlers, handler) } return handlers } ================================================ FILE: adapter_test.go ================================================ package fiber import ( "errors" "fmt" "io" "net/http" "net/http/httptest" "os" "reflect" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func TestToFiberHandler_Nil(t *testing.T) { t.Parallel() var handler Handler converted, ok := toFiberHandler(handler) require.False(t, ok) require.Nil(t, converted) } func TestToFiberHandler_FiberHandler(t *testing.T) { t.Parallel() fiberHandler := func(c Ctx) error { return c.SendStatus(http.StatusAccepted) } converted, ok := toFiberHandler(fiberHandler) require.True(t, ok) require.NotNil(t, converted) require.Equal(t, reflect.ValueOf(fiberHandler).Pointer(), reflect.ValueOf(converted).Pointer()) } func TestToFiberHandler_FiberHandlerNoErrorReturn(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(c Ctx) { require.Equal(t, app, c.App()) c.Set("X-Handler", "ok") } converted, ok := toFiberHandler(handler) require.True(t, ok) require.NotNil(t, converted) require.NoError(t, converted(ctx)) require.Equal(t, "ok", string(ctx.Response().Header.Peek("X-Handler"))) } func TestNewTestCtx_ReturnsDefaultCtx(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) require.NotNil(t, app) require.NotNil(t, ctx) require.Equal(t, app, ctx.App()) } func newTestCtx(t *testing.T) (*App, *DefaultCtx) { t.Helper() app := New() fasthttpCtx := &fasthttp.RequestCtx{} customCtx := app.AcquireCtx(fasthttpCtx) ctx, ok := 
customCtx.(*DefaultCtx) require.True(t, ok) t.Cleanup(func() { app.ReleaseCtx(customCtx) }) return app, ctx } func withRouteHandlers(t *testing.T, ctx *DefaultCtx, handlers ...Handler) { t.Helper() ctx.route = &Route{Handlers: handlers} ctx.indexHandler = 0 t.Cleanup(func() { ctx.route = nil ctx.indexHandler = 0 }) } func TestToFiberHandler_ExpressTwoParamsWithError(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res) error { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) return res.SendString("express") } converted, ok := toFiberHandler(handler) require.True(t, ok) require.NoError(t, converted(ctx)) require.Equal(t, "express", string(ctx.Response().Body())) } func TestToFiberHandler_ExpressTwoParamsWithoutError(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res) { assert.Equal(t, app, req.App()) require.NoError(t, res.SendStatus(http.StatusCreated)) } converted, ok := toFiberHandler(handler) require.True(t, ok) require.NoError(t, converted(ctx)) require.Equal(t, http.StatusCreated, ctx.Response().StatusCode()) } func TestToFiberHandler_ExpressThreeParamsWithError(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, next func() error) error { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) return next() } converted, ok := toFiberHandler(handler) require.True(t, ok) nextErr := errors.New("next") nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return nextErr } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.ErrorIs(t, err, nextErr) require.True(t, nextCalled) } func TestToFiberHandler_ExpressThreeParamsWithoutError(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, _ Res, next func() error) { assert.Equal(t, app, req.App()) err := next() require.Error(t, err) assert.EqualError(t, err, "next without error") } converted, ok := 
toFiberHandler(handler) require.True(t, ok) nextHandler := func(_ Ctx) error { return errors.New("next without error") } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.EqualError(t, err, "next without error") } func TestToFiberHandler_ExpressNextNoArgWithErrorReturn(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, next func()) error { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) next() return nil } converted, ok := toFiberHandler(handler) require.True(t, ok) nextErr := errors.New("next without return value") nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return nextErr } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.ErrorIs(t, err, nextErr) require.True(t, nextCalled) } func TestToFiberHandler_ExpressNextWithErrorContinuesOnNil(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, next func(error)) { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) next(nil) } converted, ok := toFiberHandler(handler) require.True(t, ok) nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return nil } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.NoError(t, err) require.True(t, nextCalled) } func TestToFiberHandler_ExpressNextWithErrorShortCircuitsOnError(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, next func(error)) { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) next(errors.New("next error")) } converted, ok := toFiberHandler(handler) require.True(t, ok) nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return nil } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.EqualError(t, err, "next error") require.False(t, nextCalled) } func TestToFiberHandler_ExpressNextWithErrorReturn_ShortCircuitsOnNextError(t *testing.T) 
{ t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, next func(error)) error { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) next(errors.New("next error")) return nil } converted, ok := toFiberHandler(handler) require.True(t, ok) nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return nil } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.EqualError(t, err, "next error") require.False(t, nextCalled) } func TestToFiberHandler_ExpressNextWithErrorReturnCallback_PropagatesNextError(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, next func(error) error) { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) require.EqualError(t, next(nil), "next error") } converted, ok := toFiberHandler(handler) require.True(t, ok) nextErr := errors.New("next error") nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return nextErr } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.ErrorIs(t, err, nextErr) require.True(t, nextCalled) } func TestToFiberHandler_ExpressNextWithErrorReturnCallback_ShortCircuitsOnNextError(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, next func(error) error) { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) require.EqualError(t, next(errors.New("next error")), "next error") } converted, ok := toFiberHandler(handler) require.True(t, ok) nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return nil } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.EqualError(t, err, "next error") require.False(t, nextCalled) } func TestToFiberHandler_ExpressNextWithErrorReturn_PrefersHandlerErrorOverNextError(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, next func(error) error) error { assert.Equal(t, app, req.App()) 
assert.Equal(t, app, res.App()) require.EqualError(t, next(errors.New("next error")), "next error") return errors.New("handler error") } converted, ok := toFiberHandler(handler) require.True(t, ok) nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return nil } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.EqualError(t, err, "handler error") require.False(t, nextCalled) } func TestToFiberHandler_ExpressNextWithErrorReturn_PropagatesNextErrorWhenNoReturnError(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, next func(error) error) error { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) return next(errors.New("next error")) } converted, ok := toFiberHandler(handler) require.True(t, ok) nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return nil } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.EqualError(t, err, "next error") require.False(t, nextCalled) } func TestToFiberHandler_ExpressNextWithErrorReturnCallback_StopsChainWithoutNextCall(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, _ func(error) error) { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) // Intentionally do not call next. 
} converted, ok := toFiberHandler(handler) require.True(t, ok) nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return errors.New("should not be called") } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.NoError(t, err) require.False(t, nextCalled) } func TestAdapter_MixedHandlerIntegration(t *testing.T) { app := New() app.Use(func(c Ctx) error { c.Set("X-Middleware", "fiber") return c.Next() }) app.Use(func(_ Req, res Res, next func() error) error { res.Set("X-Express", "middleware") return next() }) app.Get("/fiber", func(c Ctx) error { c.Set("X-Route", "fiber") return c.SendString("fiber handler") }) app.Post("/express", func(_ Req, res Res) error { res.Set("X-Route", "express") return res.SendString("express handler") }) var httpHandlerWriteErr error app.Put("/http", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("X-Route", "http") w.WriteHeader(http.StatusAccepted) _, httpHandlerWriteErr = w.Write([]byte("http handler")) }) app.Delete("/fasthttp", func(ctx *fasthttp.RequestCtx) error { ctx.Response.Header.Set("X-Route", "fasthttp") ctx.SetStatusCode(http.StatusCreated) ctx.SetBodyString("fasthttp handler") return nil }) run := func(name string, buildRequest func() *http.Request, expectStatus int, expectBody, expectRoute string) { t.Run(name, func(t *testing.T) { req := buildRequest() resp, err := app.Test(req) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, resp.Body.Close()) }) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, expectStatus, resp.StatusCode) require.Equal(t, expectBody, string(body)) require.Equal(t, "fiber", resp.Header.Get("X-Middleware")) require.Equal(t, "middleware", resp.Header.Get("X-Express")) require.Equal(t, expectRoute, resp.Header.Get("X-Route")) }) } run("fiber", func() *http.Request { return httptest.NewRequest(http.MethodGet, "/fiber", http.NoBody) }, http.StatusOK, "fiber handler", "fiber") run("express", func() 
*http.Request { return httptest.NewRequest(http.MethodPost, "/express", http.NoBody) }, http.StatusOK, "express handler", "express") run("net/http", func() *http.Request { return httptest.NewRequest(http.MethodPut, "/http", http.NoBody) }, http.StatusAccepted, "http handler", "http") require.NoError(t, httpHandlerWriteErr) run("fasthttp", func() *http.Request { return httptest.NewRequest(http.MethodDelete, "/fasthttp", http.NoBody) }, http.StatusCreated, "fasthttp handler", "fasthttp") } func TestToFiberHandler_ExpressNextNoArgPropagatesError(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, next func()) { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) next() } converted, ok := toFiberHandler(handler) require.True(t, ok) nextErr := errors.New("next without return value") nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return nextErr } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.ErrorIs(t, err, nextErr) require.True(t, nextCalled) } func TestToFiberHandler_ExpressNextNoArgStopsChain(t *testing.T) { t.Parallel() app, ctx := newTestCtx(t) handler := func(req Req, res Res, _ func()) { assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) // Intentionally do not call next(). 
} converted, ok := toFiberHandler(handler) require.True(t, ok) nextCalled := false nextHandler := func(_ Ctx) error { nextCalled = true return errors.New("should not be called") } withRouteHandlers(t, ctx, converted, nextHandler) err := converted(ctx) require.NoError(t, err) require.False(t, nextCalled) } func TestToFiberHandler_ExpressNextNoArgMiddleware(t *testing.T) { t.Parallel() app := New() t.Cleanup(func() { require.NoError(t, app.Shutdown()) }) callOrder := make([]string, 0, 2) app.Use(func(req Req, res Res, next func()) { callOrder = append(callOrder, "middleware") next() assert.Equal(t, app, req.App()) assert.Equal(t, app, res.App()) }) app.Get("/", func(c Ctx) error { callOrder = append(callOrder, "handler") return c.SendStatus(http.StatusOK) }) req := httptest.NewRequest(http.MethodGet, "/", http.NoBody) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.NoError(t, resp.Body.Close()) require.Equal(t, []string{"middleware", "handler"}, callOrder) } func TestCollectHandlers_HTTPHandler(t *testing.T) { t.Parallel() var writeErr error httpHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("X-HTTP", "ok") w.WriteHeader(http.StatusTeapot) _, writeErr = w.Write([]byte("http")) }) handlers := collectHandlers("test", httpHandler) require.Len(t, handlers, 1) converted := handlers[0] require.NotNil(t, converted) app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) err := converted(ctx) require.NoError(t, err) require.Equal(t, http.StatusTeapot, ctx.Response().StatusCode()) require.Equal(t, "ok", string(ctx.Response().Header.Peek("X-HTTP"))) require.Equal(t, "http", string(ctx.Response().Body())) require.NoError(t, writeErr) } func TestToFiberHandler_HTTPHandler(t *testing.T) { t.Parallel() var writeErr error var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("X-HTTP", 
"handler") _, writeErr = w.Write([]byte("through")) }) converted, ok := toFiberHandler(handler) require.True(t, ok) require.NotNil(t, converted) app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) err := converted(ctx) require.NoError(t, err) require.Equal(t, "handler", string(ctx.Response().Header.Peek("X-HTTP"))) require.Equal(t, "through", string(ctx.Response().Body())) require.NoError(t, writeErr) } func TestToFiberHandler_FasthttpHandlerWithError(t *testing.T) { t.Parallel() _, ctx := newTestCtx(t) fasthttpHandler := func(fctx *fasthttp.RequestCtx) error { fctx.Response.Header.Set("X-FASTHTTP", "error") return errors.New("fasthttp error") } converted, ok := toFiberHandler(fasthttpHandler) require.True(t, ok) require.NotNil(t, converted) err := converted(ctx) require.EqualError(t, err, "fasthttp error") require.Equal(t, "error", string(ctx.Response().Header.Peek("X-FASTHTTP"))) } func TestToFiberHandler_HTTPHandler_Flush(t *testing.T) { t.Parallel() var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("X-HTTP", "handler") _, err := w.Write([]byte("through")) flusher, ok := w.(http.Flusher) assert.True(t, ok, "w does not implement http.Flusher") flusher.Flush() assert.NoError(t, err) }) converted, ok := toFiberHandler(handler) require.True(t, ok) require.NotNil(t, converted) app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) err := converted(ctx) require.NoError(t, err) require.Equal(t, "handler", string(ctx.Response().Header.Peek("X-HTTP"))) require.Equal(t, "through", string(ctx.Response().Body())) } func TestWrapHTTPHandler_Flush_App_Test(t *testing.T) { t.Parallel() app := New() app.Get("/", func(w http.ResponseWriter, _ *http.Request) { flusher, ok := w.(http.Flusher) if !ok { t.Fatal("w does not implement http.Flusher") } w.WriteHeader(StatusOK) fmt.Fprintf(w, "Hello ") flusher.Flush() fmt.Fprintf(w, 
"World!") }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // not needed require.Equal(t, StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "Hello World!", string(body)) } func Test_HTTPHandler_App_Test_Interrupted(t *testing.T) { t.Parallel() app := New() app.Get("/", func(w http.ResponseWriter, _ *http.Request) { flusher, ok := w.(http.Flusher) if !ok { t.Fatal("w does not implement http.Flusher") } w.WriteHeader(StatusOK) fmt.Fprintf(w, "Hello ") flusher.Flush() time.Sleep(500 * time.Millisecond) fmt.Fprintf(w, "World!") }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody), TestConfig{ Timeout: 200 * time.Millisecond, FailOnTimeout: true, // Changed to true to test interrupted behavior }) // With FailOnTimeout: true, we should get a timeout error require.ErrorIs(t, err, os.ErrDeadlineExceeded) require.Nil(t, resp) } func TestToFiberHandler_HTTPHandlerFunc(t *testing.T) { t.Parallel() httpFunc := func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusNoContent) } converted, ok := toFiberHandler(httpFunc) require.True(t, ok) require.NotNil(t, converted) app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) err := converted(ctx) require.NoError(t, err) require.Equal(t, http.StatusNoContent, ctx.Response().StatusCode()) } func TestWrapHTTPHandler_Nil(t *testing.T) { t.Parallel() require.Nil(t, wrapHTTPHandler(nil)) } func TestCollectHandlers_InvalidType(t *testing.T) { t.Parallel() require.PanicsWithValue(t, "context: invalid handler #0 (int)\n", func() { collectHandlers("context", 42) }) } func TestCollectHandlers_TypedNilHTTPHandlers(t *testing.T) { t.Parallel() var handlerFunc http.HandlerFunc var handler http.Handler var raw func(http.ResponseWriter, *http.Request) tests := []struct { handler any name string }{ { name: "HandlerFunc", 
handler: handlerFunc, }, { name: "Handler", handler: handler, }, { name: "Function", handler: raw, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() expected := fmt.Sprintf("context: invalid handler #0 (%T)\n", tt.handler) require.PanicsWithValue(t, expected, func() { collectHandlers("context", tt.handler) }) }) } } type dummyHandler struct{} func (dummyHandler) ServeHTTP(http.ResponseWriter, *http.Request) {} type dummyFuncHandler func(http.ResponseWriter, *http.Request) func (handler dummyFuncHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if handler == nil { return } handler(w, r) } func TestCollectHandlers_TypedNilPointerHTTPHandler(t *testing.T) { t.Parallel() var handler http.Handler = (*dummyHandler)(nil) require.PanicsWithValue(t, "context: invalid handler #0 (*fiber.dummyHandler)\n", func() { collectHandlers("context", handler) }) } func TestCollectHandlers_TypedNilFuncHTTPHandler(t *testing.T) { t.Parallel() var handler http.Handler = dummyFuncHandler(nil) expected := fmt.Sprintf("context: invalid handler #0 (%T)\n", handler) require.PanicsWithValue(t, expected, func() { collectHandlers("context", handler) }) } func TestCollectHandlers_TypedNilFasthttpHandlers(t *testing.T) { t.Parallel() var requestHandler fasthttp.RequestHandler var requestHandlerWithError func(*fasthttp.RequestCtx) error tests := []struct { handler any name string }{ { name: "RequestHandler", handler: requestHandler, }, { name: "RequestHandlerWithError", handler: requestHandlerWithError, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() expected := fmt.Sprintf("context: invalid handler #0 (%T)\n", tt.handler) require.PanicsWithValue(t, expected, func() { collectHandlers("context", tt.handler) }) }) } } func TestCollectHandlers_FasthttpHandler(t *testing.T) { t.Parallel() before := func(c Ctx) error { c.Set("X-Before", "fiber") return nil } fasthttpHandler := fasthttp.RequestHandler(func(ctx *fasthttp.RequestCtx) 
{ ctx.Response.Header.Set("X-FASTHTTP", "ok") ctx.SetBody([]byte("done")) }) handlers := collectHandlers("fasthttp", before, fasthttpHandler) require.Len(t, handlers, 2) app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) for _, handler := range handlers { require.NoError(t, handler(ctx)) } require.Equal(t, "fiber", string(ctx.Response().Header.Peek("X-Before"))) require.Equal(t, "ok", string(ctx.Response().Header.Peek("X-FASTHTTP"))) require.Equal(t, "done", string(ctx.Response().Body())) } func TestCollectHandlers_FiberHandlerNoErrorReturn(t *testing.T) { t.Parallel() noError := func(c Ctx) { c.Set("X-Handler", "fiber") } handlers := collectHandlers("ctx", noError) require.Len(t, handlers, 1) app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) require.NoError(t, handlers[0](ctx)) require.Equal(t, "fiber", string(ctx.Response().Header.Peek("X-Handler"))) } func TestCollectHandlers_MixedHandlers(t *testing.T) { t.Parallel() before := func(c Ctx) error { c.Set("X-Before", "fiber") return nil } var writeErr error httpHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, writeErr = w.Write([]byte("done")) }) handlers := collectHandlers("test", before, httpHandler) require.Len(t, handlers, 2) require.Equal(t, reflect.ValueOf(before).Pointer(), reflect.ValueOf(handlers[0]).Pointer()) require.NotNil(t, handlers[1]) app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) err := handlers[0](ctx) require.NoError(t, err) err = handlers[1](ctx) require.NoError(t, err) require.Equal(t, "done", string(ctx.Response().Body())) require.Equal(t, "fiber", string(ctx.Response().Header.Peek("X-Before"))) require.NoError(t, writeErr) } func TestCollectHandlers_Nil(t *testing.T) { t.Parallel() require.PanicsWithValue(t, "nil: invalid handler #0 ()\n", func() { collectHandlers("nil", nil) }) } 
================================================
FILE: addon/retry/README.md
================================================
# Retry Addon

Retry addon for [Fiber](https://github.com/gofiber/fiber), designed to apply a retry mechanism to unsuccessful network operations. This addon uses an exponential backoff algorithm with jitter. It calls the function repeatedly until a call succeeds; if every call fails, it returns the error from the last attempt. A random jitter is added to each wait because jitter breaks the synchronization between clients and helps them avoid colliding retries.

## Table of Contents

- [Retry Addon](#retry-addon)
- [Table of Contents](#table-of-contents)
- [Signatures](#signatures)
- [Examples](#examples)
- [Default Config](#default-config)
- [Custom Config](#custom-config)
- [Config](#config)
- [Default Config Example](#default-config-example)

## Signatures

```go
func NewExponentialBackoff(config ...retry.Config) *retry.ExponentialBackoff
```

## Examples

```go
package main

import (
	"fmt"

	"github.com/gofiber/fiber/v3/addon/retry"
	"github.com/gofiber/fiber/v3/client"
)

func main() {
	expBackoff := retry.NewExponentialBackoff(retry.Config{})

	// Local variables that will be used inside of Retry
	var resp *client.Response
	var err error

	// Retry a network request; returning an error signals that another attempt should be made
	err = expBackoff.Retry(func() error {
		cli := client.New()
		resp, err = cli.Get("https://gofiber.io")
		if err != nil {
			return fmt.Errorf("GET gofiber.io failed: %w", err)
		}
		if resp.StatusCode() != 200 {
			return fmt.Errorf("GET gofiber.io did not return OK 200")
		}
		return nil
	})

	// If all retries failed, panic
	if err != nil {
		panic(err)
	}
	fmt.Printf("GET gofiber.io succeeded with status code %d\n", resp.StatusCode())
}
```

## Default Config

```go
retry.NewExponentialBackoff()
```

## Custom Config

```go
retry.NewExponentialBackoff(retry.Config{
	InitialInterval: 2 * time.Second,
	MaxBackoffTime:  64 * time.Second,
	Multiplier:      2.0,
	MaxRetryCount:   15,
})
```
## Config

```go
// Config defines the configuration for NewExponentialBackoff.
type Config struct {
	// InitialInterval is the base wait before the first retry. After each
	// retry the interval is multiplied by Multiplier.
	//
	// Optional. Default: 1 * time.Second
	InitialInterval time.Duration

	// MaxBackoffTime caps a single wait. Once a computed wait (interval plus
	// jitter) reaches this value, that retry and every later one waits
	// exactly MaxBackoffTime.
	//
	// Optional. Default: 32 * time.Second
	MaxBackoffTime time.Duration

	// Multiplier is the factor applied to the interval after each retry.
	//
	// Optional. Default: 2.0
	Multiplier float64

	// MaxRetryCount is the maximum number of times the function is called,
	// including the first attempt.
	//
	// Optional. Default: 10
	MaxRetryCount int
}
```

## Default Config Example

```go
// DefaultConfig is the default config for retry.
var DefaultConfig = Config{
	InitialInterval: 1 * time.Second,
	MaxBackoffTime:  32 * time.Second,
	Multiplier:      2.0,
	MaxRetryCount:   10,
	currentInterval: 1 * time.Second,
}
```

================================================
FILE: addon/retry/config.go
================================================
package retry

import (
	"time"
)

// Config defines the configuration for NewExponentialBackoff.
type Config struct {
	// InitialInterval is the base wait before the first retry. After each
	// retry the interval is multiplied by Multiplier.
	//
	// Optional. Default: 1 * time.Second
	InitialInterval time.Duration

	// MaxBackoffTime caps a single wait. Once a computed wait (interval plus
	// jitter) reaches this value, that retry and every later one waits
	// exactly MaxBackoffTime.
	//
	// Optional. Default: 32 * time.Second
	MaxBackoffTime time.Duration

	// Multiplier is the factor applied to the interval after each retry.
	//
	// Optional. Default: 2.0
	Multiplier float64

	// MaxRetryCount is the maximum number of times the function is called,
	// including the first attempt.
	//
	// Optional. Default: 10
	MaxRetryCount int

	// currentInterval tracks the interval used for the next wait. It is not
	// settable by users; when it is zero, configDefault seeds it from
	// InitialInterval.
	currentInterval time.Duration
}

// DefaultConfig holds the values NewExponentialBackoff uses when it is called
// without a Config; configDefault also takes unset fields from it.
var DefaultConfig = Config{ InitialInterval: 1 * time.Second, MaxBackoffTime: 32 * time.Second, Multiplier: 2.0, MaxRetryCount: 10, currentInterval: 1 * time.Second, } // configDefault sets the config values if they are not set. func configDefault(config ...Config) Config { if len(config) == 0 { return DefaultConfig } cfg := config[0] if cfg.InitialInterval == 0 { cfg.InitialInterval = DefaultConfig.InitialInterval } if cfg.MaxBackoffTime == 0 { cfg.MaxBackoffTime = DefaultConfig.MaxBackoffTime } if cfg.Multiplier <= 0 { cfg.Multiplier = DefaultConfig.Multiplier } if cfg.MaxRetryCount <= 0 { cfg.MaxRetryCount = DefaultConfig.MaxRetryCount } if cfg.currentInterval == 0 { cfg.currentInterval = cfg.InitialInterval } return cfg } ================================================ FILE: addon/retry/config_test.go ================================================ package retry import ( "testing" "time" "github.com/stretchr/testify/require" ) func TestConfigDefault_NoConfig(t *testing.T) { t.Parallel() cfg := configDefault() require.Equal(t, DefaultConfig, cfg) } func TestConfigDefault_Custom(t *testing.T) { t.Parallel() custom := Config{ InitialInterval: 2 * time.Second, MaxBackoffTime: 64 * time.Second, Multiplier: 3.0, MaxRetryCount: 5, currentInterval: 2 * time.Second, } cfg := configDefault(custom) require.Equal(t, custom, cfg) } func TestConfigDefault_PartialAndNegative(t *testing.T) { t.Parallel() cfg := configDefault(Config{Multiplier: -1, MaxRetryCount: 0}) require.Equal(t, DefaultConfig, cfg) } func TestConfigDefault_CustomInitialInterval(t *testing.T) { t.Parallel() cfg := configDefault(Config{InitialInterval: 5 * time.Second}) require.Equal(t, 5*time.Second, cfg.currentInterval) require.Equal(t, 5*time.Second, cfg.InitialInterval) } func TestConfigDefault_CustomCurrentInterval(t *testing.T) { t.Parallel() cfg := configDefault(Config{currentInterval: 3 * time.Second}) require.Equal(t, 3*time.Second, cfg.currentInterval) require.Equal(t, 
DefaultConfig.InitialInterval, cfg.InitialInterval) } func TestConfigDefault_CurrentIntervalAndInitialDiffer(t *testing.T) { t.Parallel() cfg := configDefault(Config{InitialInterval: 5 * time.Second, currentInterval: 3 * time.Second}) require.Equal(t, 5*time.Second, cfg.InitialInterval) require.Equal(t, 3*time.Second, cfg.currentInterval) } func TestNewExponentialBackoff_Config(t *testing.T) { t.Parallel() backoff := NewExponentialBackoff(Config{InitialInterval: 4 * time.Second}) require.Equal(t, 4*time.Second, backoff.InitialInterval) require.Equal(t, 4*time.Second, backoff.currentInterval) } ================================================ FILE: addon/retry/exponential_backoff.go ================================================ package retry import ( "crypto/rand" "math/big" "time" ) // ExponentialBackoff is a retry mechanism for retrying some calls. type ExponentialBackoff struct { // InitialInterval is the initial time interval for backoff algorithm. InitialInterval time.Duration // MaxBackoffTime is the maximum time duration for backoff algorithm. It limits // the maximum sleep time. MaxBackoffTime time.Duration // Multiplier is a multiplier number of the backoff algorithm. Multiplier float64 // MaxRetryCount is the maximum number of retry count. MaxRetryCount int // currentInterval tracks the current sleep time. currentInterval time.Duration } // NewExponentialBackoff creates a ExponentialBackoff with default values. func NewExponentialBackoff(config ...Config) *ExponentialBackoff { cfg := configDefault(config...) return &ExponentialBackoff{ InitialInterval: cfg.InitialInterval, MaxBackoffTime: cfg.MaxBackoffTime, Multiplier: cfg.Multiplier, MaxRetryCount: cfg.MaxRetryCount, currentInterval: cfg.currentInterval, } } // Retry is the core logic of the retry mechanism. If the calling function returns // nil as an error, then the Retry method is terminated with returning nil. 
Otherwise, // if all function calls are returned error, then the method returns this error. func (e *ExponentialBackoff) Retry(f func() error) error { if e.currentInterval <= 0 { e.currentInterval = e.InitialInterval } var err error for i := 0; i < e.MaxRetryCount; i++ { err = f() if err == nil { return nil } if i < e.MaxRetryCount-1 { next := e.next() time.Sleep(next) } } return err } // next calculates the next sleeping time interval. func (e *ExponentialBackoff) next() time.Duration { // generate a random value between [0, 1000) n, err := rand.Int(rand.Reader, big.NewInt(1000)) if err != nil { return e.MaxBackoffTime } t := e.currentInterval + (time.Duration(n.Int64()) * time.Millisecond) e.currentInterval = time.Duration(float64(e.currentInterval) * e.Multiplier) if t >= e.MaxBackoffTime { e.currentInterval = e.MaxBackoffTime return e.MaxBackoffTime } return t } ================================================ FILE: addon/retry/exponential_backoff_test.go ================================================ package retry import ( "crypto/rand" "errors" "testing" "time" "github.com/stretchr/testify/require" ) func Test_ExponentialBackoff_Retry(t *testing.T) { t.Parallel() tests := []struct { expErr error expBackoff *ExponentialBackoff f func() error name string }{ { name: "With default values - successful", expBackoff: NewExponentialBackoff(), f: func() error { return nil }, }, { name: "Successful function", expBackoff: &ExponentialBackoff{ InitialInterval: 1 * time.Millisecond, MaxBackoffTime: 100 * time.Millisecond, Multiplier: 2.0, MaxRetryCount: 5, }, f: func() error { return nil }, }, { name: "Unsuccessful function", expBackoff: &ExponentialBackoff{ InitialInterval: 2 * time.Millisecond, MaxBackoffTime: 100 * time.Millisecond, Multiplier: 2.0, MaxRetryCount: 5, }, f: func() error { return errors.New("failed function") }, expErr: errors.New("failed function"), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() err := 
tt.expBackoff.Retry(tt.f) require.Equal(t, tt.expErr, err) }) } } func Test_ExponentialBackoff_Retry_NoSleepAfterLastAttempt(t *testing.T) { t.Parallel() const ( largeInterval = 5 * time.Second // would be used for sleep if bug existed maxAcceptable = 2 * time.Second // Retry must return well before largeInterval ) eb := &ExponentialBackoff{ InitialInterval: largeInterval, MaxBackoffTime: largeInterval * 2, Multiplier: 2.0, MaxRetryCount: 1, } start := time.Now() err := eb.Retry(func() error { return errors.New("only attempt") }) elapsed := time.Since(start) require.Error(t, err) require.Equal(t, "only attempt", err.Error()) require.Less(t, elapsed, maxAcceptable, "Retry must not sleep after the last failed attempt; took %v (expected < %v)", elapsed, maxAcceptable) } func Test_ExponentialBackoff_Next(t *testing.T) { t.Parallel() tests := []struct { name string expBackoff *ExponentialBackoff expNextTimeIntervals []time.Duration }{ { name: "With default values", expBackoff: NewExponentialBackoff(), expNextTimeIntervals: []time.Duration{ 1 * time.Second, 2 * time.Second, 4 * time.Second, 8 * time.Second, 16 * time.Second, 32 * time.Second, 32 * time.Second, 32 * time.Second, 32 * time.Second, 32 * time.Second, }, }, { name: "Custom values", expBackoff: &ExponentialBackoff{ InitialInterval: 2.0 * time.Second, MaxBackoffTime: 64 * time.Second, Multiplier: 3.0, MaxRetryCount: 8, currentInterval: 2.0 * time.Second, }, expNextTimeIntervals: []time.Duration{ 2 * time.Second, 6 * time.Second, 18 * time.Second, 54 * time.Second, 64 * time.Second, 64 * time.Second, 64 * time.Second, 64 * time.Second, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() for i := range tt.expBackoff.MaxRetryCount { next := tt.expBackoff.next() if next < tt.expNextTimeIntervals[i] || next > tt.expNextTimeIntervals[i]+1*time.Second { t.Errorf("wrong next time:\n"+ "actual:%v\n"+ "expected range:%v-%v\n", next, tt.expNextTimeIntervals[i], 
tt.expNextTimeIntervals[i]+1*time.Second) } } }) } } func Test_ExponentialBackoff_NextRandFailure(t *testing.T) { // Backup original reader and restore at the end original := rand.Reader defer func() { rand.Reader = original }() rand.Reader = failingReader{} expBackoff := &ExponentialBackoff{ InitialInterval: 1 * time.Second, MaxBackoffTime: 10 * time.Second, Multiplier: 2, MaxRetryCount: 3, currentInterval: 1 * time.Second, } next := expBackoff.next() require.Equal(t, expBackoff.MaxBackoffTime, next) // currentInterval should not change when random fails require.Equal(t, 1*time.Second, expBackoff.currentInterval) } type failingReader struct{} func (failingReader) Read(_ []byte) (int, error) { return 0, errors.New("fail") } ================================================ FILE: app.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 🤖 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io // Package fiber is an Express inspired web framework built on top of Fasthttp, // the fastest HTTP engine for Go. Designed to ease things up for fast // development with zero memory allocation and performance in mind. package fiber import ( "bufio" "context" "encoding/json" "encoding/xml" "errors" "fmt" "io" "net" "net/http" "net/http/httputil" "os" "reflect" "strconv" "strings" "sync" "time" "unsafe" "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3/binder" "github.com/gofiber/fiber/v3/log" ) // Version of current fiber package const Version = "3.1.0" // Handler defines a function to serve HTTP requests. 
type Handler = func(Ctx) error

// Map is a shortcut for map[string]any, useful for JSON responses.
type Map map[string]any

// ErrorHandler defines a function that will process all errors
// returned from any handlers in the stack.
//
//	cfg := fiber.Config{}
//	cfg.ErrorHandler = func(c Ctx, err error) error {
//		code := StatusInternalServerError
//		var e *fiber.Error
//		if errors.As(err, &e) {
//			code = e.Code
//		}
//		c.Set(HeaderContentType, MIMETextPlainCharsetUTF8)
//		return c.Status(code).SendString(err.Error())
//	}
//	app := fiber.New(cfg)
type ErrorHandler = func(Ctx, error) error

// Error represents an error that occurred while handling a request.
type Error struct {
	Message string `json:"message"`
	Code    int    `json:"code"`
}

// App denotes the Fiber application.
type App struct {
	// App config, with defaults filled in by New
	config Config
	// Snapshot of the config exactly as passed to New (taken before defaults
	// are applied); indicates whether a value was explicitly configured
	configured Config
	// Ctx pool
	pool sync.Pool
	// Fasthttp server
	server *fasthttp.Server
	// Converts string to a byte slice
	toBytes func(s string) (b []byte)
	// Converts byte slice to a string
	toString func(b []byte) string
	// Hooks
	hooks *Hooks
	// Latest route & group
	latestRoute *Route
	// newCtxFunc is the optional factory for custom request contexts (nil means the default context)
	newCtxFunc func(app *App) CustomCtx
	// TLS handler
	tlsHandler *TLSHandler
	// Mount fields
	mountFields *mountFields
	// state management
	state *State
	// Route stack divided by HTTP methods
	stack [][]*Route
	// customConstraints is a list of external constraints
	customConstraints []CustomConstraint
	// sendfiles stores configurations for handling ctx.SendFile operations
	sendfiles []*sendFileStore
	// custom binders
	customBinders []CustomBinder
	// Route stack divided by HTTP methods and route prefixes
	treeStack []map[int][]*Route
	// sendfilesMutex is a mutex used for sendfile operations
	sendfilesMutex sync.RWMutex
	// mutex serializes app-level mutations such as Name, SetTLSHandler and ReloadViews
	mutex sync.Mutex
	// Amount of registered handlers
	handlersCount uint32
	// routesRefreshed records whether the route stack has changed and the
	// optimized routing tree needs to be (re)built
	routesRefreshed bool
	// hasCustomCtx tracks whether app uses a custom context implementation
	hasCustomCtx bool
}

// Config is a struct holding the server settings.
type Config struct { //nolint:govet // Aligning the struct fields is not necessary. betteralign:ignore
	// Enables the "Server: value" HTTP header.
	//
	// Default: ""
	ServerHeader string `json:"server_header"`

	// When set to true, the router treats "/foo" and "/foo/" as different.
	// By default this is disabled and both "/foo" and "/foo/" will execute the same handler.
	//
	// Default: false
	StrictRouting bool `json:"strict_routing"`

	// When set to true, enables case-sensitive routing.
	// E.g. "/FoO" and "/foo" are treated as different routes.
	// By default this is disabled and both "/FoO" and "/foo" will execute the same handler.
	//
	// Default: false
	CaseSensitive bool `json:"case_sensitive"`

	// When set to true, disables automatic registration of HEAD routes for
	// every GET route.
	//
	// Default: false
	DisableHeadAutoRegister bool `json:"disable_head_auto_register"`

	// When set to true, this relinquishes the 0-allocation promise in certain
	// cases in order to access the handler values (e.g. request bodies) in an
	// immutable fashion so that these values are available even if you return
	// from handler.
	//
	// Default: false
	Immutable bool `json:"immutable"`

	// When set to true, converts all encoded characters in the route back
	// before setting the path for the context, so that the routing,
	// the returning of the current url from the context `ctx.Path()`
	// and the parameters `ctx.Params(%key%)` with decoded characters will work.
	//
	// Default: false
	UnescapePath bool `json:"unescape_path"`

	// Max body size that the server accepts.
	// Zero or negative values fall back to the default limit.
	//
	// Default: 4 * 1024 * 1024
	BodyLimit int `json:"body_limit"`

	// MaxRanges sets the maximum number of ranges parsed from a Range header.
	// Zero or negative values fall back to the default limit.
	//
	// Default: 16
	MaxRanges int `json:"max_ranges"`

	// Maximum number of concurrent connections.
	//
	// Default: 256 * 1024
	Concurrency int `json:"concurrency"`

	// Views is the interface that wraps the Render function.
	//
	// Default: nil
	Views Views `json:"-"`

	// Views Layout is the global layout for all template render until override on Render function.
	//
	// Default: ""
	ViewsLayout string `json:"views_layout"`

	// PassLocalsToViews Enables passing of the locals set on a fiber.Ctx to the template engine
	//
	// Default: false
	PassLocalsToViews bool `json:"pass_locals_to_views"`

	// PassLocalsToContext controls whether StoreInContext also propagates values to
	// the request context.Context for Fiber-backed contexts.
	//
	// ValueFromContext for Fiber-backed contexts always reads from c.Locals().
	//
	// Default: false
	PassLocalsToContext bool `json:"pass_locals_to_context"`

	// The amount of time allowed to read the full request including body.
	// It is reset after the request handler has returned.
	// The connection's read deadline is reset when the connection opens.
	//
	// Default: unlimited
	ReadTimeout time.Duration `json:"read_timeout"`

	// The maximum duration before timing out writes of the response.
	// It is reset after the request handler has returned.
	//
	// Default: unlimited
	WriteTimeout time.Duration `json:"write_timeout"`

	// The maximum amount of time to wait for the next request when keep-alive is enabled.
	// If IdleTimeout is zero, the value of ReadTimeout is used.
	//
	// Default: unlimited
	IdleTimeout time.Duration `json:"idle_timeout"`

	// Per-connection buffer size for requests' reading.
	// This also limits the maximum header size.
	// Increase this buffer if your clients send multi-KB RequestURIs
	// and/or multi-KB headers (for example, BIG cookies).
	//
	// Default: 4096
	ReadBufferSize int `json:"read_buffer_size"`

	// Per-connection buffer size for responses' writing.
	//
	// Default: 4096
	WriteBufferSize int `json:"write_buffer_size"`

	// CompressedFileSuffixes adds suffix to the original file name and
	// tries saving the resulting compressed file under the new file name.
	//
	// Default: map[string]string{"gzip": ".fiber.gz", "br": ".fiber.br", "zstd": ".fiber.zst"}
	CompressedFileSuffixes map[string]string `json:"compressed_file_suffixes"`

	// ProxyHeader will enable c.IP() to return the value of the given header key.
	// By default c.IP() will return the Remote IP from the TCP connection.
	// This property can be useful if you are behind a load balancer: X-Forwarded-*
	// NOTE: headers are easily spoofed and the detected IP addresses are unreliable.
	//
	// Default: ""
	ProxyHeader string `json:"proxy_header"`

	// GETOnly rejects all non-GET requests if set to true.
	// This option is useful as anti-DoS protection for servers
	// accepting only GET requests. The request size is limited
	// by ReadBufferSize if GETOnly is set.
	//
	// Default: false
	GETOnly bool `json:"get_only"`

	// ErrorHandler is executed when an error is returned from fiber.Handler.
	//
	// Default: DefaultErrorHandler
	ErrorHandler ErrorHandler `json:"-"`

	// When set to true, disables keep-alive connections.
	// The server will close incoming connections after sending the first response to client.
	//
	// Default: false
	DisableKeepalive bool `json:"disable_keepalive"`

	// When set to true, causes the default date header to be excluded from the response.
	//
	// Default: false
	DisableDefaultDate bool `json:"disable_default_date"`

	// When set to true, causes the default Content-Type header to be excluded from the response.
	//
	// Default: false
	DisableDefaultContentType bool `json:"disable_default_content_type"`

	// When set to true, disables header normalization.
	// By default all header names are normalized: conteNT-tYPE -> Content-Type.
	//
	// Default: false
	DisableHeaderNormalizing bool `json:"disable_header_normalizing"`

	// AppName sets the name of the application.
	//
	// Default: ""
	AppName string `json:"app_name"`

	// StreamRequestBody enables request body streaming,
	// and calls the handler sooner when given body is
	// larger than the current limit.
	//
	// Default: false
	StreamRequestBody bool

	// Will not pre parse Multipart Form data if set to true.
	//
	// This option is useful for servers that desire to treat
	// multipart form data as a binary blob, or choose when to parse the data.
	//
	// Server pre parses multipart form data by default.
	//
	// Default: false
	DisablePreParseMultipartForm bool

	// Aggressively reduces memory usage at the cost of higher CPU usage
	// if set to true.
	//
	// Try enabling this option only if the server consumes too much memory
	// serving mostly idle keep-alive connections. This may reduce memory
	// usage by more than 50%.
	//
	// Default: false
	ReduceMemoryUsage bool `json:"reduce_memory_usage"`

	// When set by an external client of Fiber it will use the provided implementation of a
	// JSONMarshal
	//
	// Allowing for flexibility in using another json library for encoding
	// Default: json.Marshal
	JSONEncoder utils.JSONMarshal `json:"-"`

	// When set by an external client of Fiber it will use the provided implementation of a
	// JSONUnmarshal
	//
	// Allowing for flexibility in using another json library for decoding
	// Default: json.Unmarshal
	JSONDecoder utils.JSONUnmarshal `json:"-"`

	// When set by an external client of Fiber it will use the provided implementation of a
	// MsgPackMarshal
	//
	// Allowing for flexibility in using another msgpack library for encoding
	// Default: binder.UnimplementedMsgpackMarshal
	MsgPackEncoder utils.MsgPackMarshal `json:"-"`

	// When set by an external client of Fiber it will use the provided implementation of a
	// MsgPackUnmarshal
	//
	// Allowing for flexibility in using another msgpack library for decoding
	// Default: binder.UnimplementedMsgpackUnmarshal
	MsgPackDecoder utils.MsgPackUnmarshal `json:"-"`

	// When set by an external client of Fiber it will use the provided implementation of a
	// CBORMarshal
	//
	// Allowing for flexibility in using another cbor library for encoding
	// Default: binder.UnimplementedCborMarshal
	CBOREncoder utils.CBORMarshal `json:"-"`

	// When set by an external client of Fiber it will use the provided implementation of a
	// CBORUnmarshal
	//
	// Allowing for flexibility in using another cbor library for decoding
	// Default: binder.UnimplementedCborUnmarshal
	CBORDecoder utils.CBORUnmarshal `json:"-"`

	// XMLEncoder set by an external client of Fiber it will use the provided implementation of a
	// XMLMarshal
	//
	// Allowing for flexibility in using another XML library for encoding
	// Default: xml.Marshal
	XMLEncoder utils.XMLMarshal `json:"-"`

	// XMLDecoder set by an external client of Fiber it will use the provided implementation of a
	// XMLUnmarshal
	//
	// Allowing for flexibility in using another XML library for decoding
	// Default: xml.Unmarshal
	XMLDecoder utils.XMLUnmarshal `json:"-"`

	// If you find yourself behind some sort of proxy, like a load balancer,
	// then certain header information may be sent to you using special X-Forwarded-* headers or the Forwarded header.
	// For example, the Host HTTP header is usually used to return the requested host.
	// But when you’re behind a proxy, the actual host may be stored in an X-Forwarded-Host header.
	//
	// If you are behind a proxy, you should enable TrustProxy to prevent header spoofing.
	// If you enable TrustProxy and do not provide a TrustProxyConfig, Fiber will skip
	// all headers that could be spoofed.
	// If the request IP is in the TrustProxyConfig.Proxies allowlist, then:
	//   1. c.Scheme() gets its value from the X-Forwarded-Proto, X-Forwarded-Protocol, X-Forwarded-Ssl or X-Url-Scheme header
	//   2. c.IP() gets its value from the ProxyHeader header.
	//   3. c.Host() and c.Hostname() get their value from the X-Forwarded-Host header
	// But if the request IP is NOT in the TrustProxyConfig.Proxies allowlist, then:
	//   1. c.Scheme() WON'T get value from X-Forwarded-Proto, X-Forwarded-Protocol, X-Forwarded-Ssl or X-Url-Scheme header,
	//      will return https when a TLS connection is handled by the app, or http otherwise.
	//   2. c.IP() WON'T get value from ProxyHeader header, will return RemoteIP() from fasthttp context
	//   3. c.Host() and c.Hostname() WON'T get value from X-Forwarded-Host header, fasthttp.Request.URI().Host()
	//      will be used to get the hostname.
	//
	// To automatically trust all loopback, link-local, or private IP addresses,
	// without manually adding them to the TrustProxyConfig.Proxies allowlist,
	// you can set TrustProxyConfig.Loopback, TrustProxyConfig.LinkLocal, or TrustProxyConfig.Private to true.
	//
	// Default: false
	TrustProxy bool `json:"trust_proxy"`

	// Read TrustProxy doc.
	//
	// Default: DefaultTrustProxyConfig
	TrustProxyConfig TrustProxyConfig `json:"trust_proxy_config"`

	// If set to true, c.IP() and c.IPs() will validate IP addresses before returning them.
	// Also, c.IP() will return only the first valid IP rather than just the raw header
	// WARNING: this has a performance cost associated with it.
	//
	// Default: false
	EnableIPValidation bool `json:"enable_ip_validation"`

	// You can define custom color scheme. They'll be used for startup message, route list and some middlewares.
	//
	// Optional. Default: DefaultColors
	ColorScheme Colors `json:"color_scheme"`

	// If you want to validate header/form/query... automatically when to bind, you can define struct validator.
	// Fiber doesn't have default validator, so it'll skip validator step if you don't use any validator.
	//
	// Default: nil
	StructValidator StructValidator

	// RequestMethods provides customizability for HTTP methods. You can add/remove methods as you wish.
	//
	// Optional. Default: DefaultMethods
	RequestMethods []string

	// EnableSplittingOnParsers splits the query/body/header parameters by comma when it's true.
	// For example, you can use it to parse multiple values from a query parameter like this:
	//   /api?foo=bar,baz == foo[]=bar&foo[]=baz
	//
	// Optional. Default: false
	EnableSplittingOnParsers bool `json:"enable_splitting_on_parsers"`

	// Services is a list of services that are used by the app (e.g. databases, caches, etc.)
	//
	// Optional. Default: a zero value slice
	Services []Service

	// ServicesStartupContextProvider is a context provider for the startup of the services.
	//
	// Optional. Default: a provider that returns context.Background()
	ServicesStartupContextProvider func() context.Context

	// ServicesShutdownContextProvider is a context provider for the shutdown of the services.
	//
	// Optional. Default: a provider that returns context.Background()
	ServicesShutdownContextProvider func() context.Context
}

// DefaultTrustProxyConfig is the default value of Config.TrustProxyConfig:
// the zero TrustProxyConfig (no listed proxies, no automatically trusted ranges).
var DefaultTrustProxyConfig = TrustProxyConfig{}

// TrustProxyConfig is a struct for configuring trusted proxies if Config.TrustProxy is true.
type TrustProxyConfig struct {
	// ips holds the single-address entries of Proxies; populated by New.
	ips map[string]struct{}

	// Proxies is a list of trusted proxy IP addresses or CIDR ranges.
	//
	// Default: []string
	Proxies []string `json:"proxies"`

	// ranges holds the CIDR entries of Proxies; populated by New.
	ranges []*net.IPNet

	// LinkLocal enables trusting all link-local IP ranges (e.g., 169.254.0.0/16, fe80::/10).
	//
	// Default: false
	LinkLocal bool `json:"link_local"`

	// Loopback enables trusting all loopback IP ranges (e.g., 127.0.0.0/8, ::1/128).
	//
	// Default: false
	Loopback bool `json:"loopback"`

	// Private enables trusting all private IP ranges (e.g., 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7).
	//
	// Default: false
	Private bool `json:"private"`

	// UnixSocket enables trusting Unix domain socket connections.
	// When enabled, requests from Unix sockets are treated as trusted proxies.
	//
	// Default: false
	UnixSocket bool `json:"unix_socket"`
}

// RouteMessage describes a single route entry to be printed when the server starts.
type RouteMessage struct {
	name     string
	method   string
	path     string
	handlers string
}

// Default Config values. New substitutes each of these when the
// corresponding Config field is zero or negative.
const (
	DefaultBodyLimit       = 4 * 1024 * 1024
	DefaultMaxRanges       = 16
	DefaultConcurrency     = 256 * 1024
	DefaultReadBufferSize  = 4096
	DefaultWriteBufferSize = 4096
)

// Positions of the built-in HTTP methods within DefaultMethods.
const (
	methodGet = iota
	methodHead
	methodPost
	methodPut
	methodDelete
	methodConnect
	methodOptions
	methodTrace
	methodPatch
)

// DefaultMethods lists the HTTP methods enabled by default; New uses it when
// Config.RequestMethods is empty.
var DefaultMethods = []string{
	methodGet:     MethodGet,
	methodHead:    MethodHead,
	methodPost:    MethodPost,
	methodPut:     MethodPut,
	methodDelete:  MethodDelete,
	methodConnect: MethodConnect,
	methodOptions: MethodOptions,
	methodTrace:   MethodTrace,
	methodPatch:   MethodPatch,
}

// httpReadResponse - Used for test mocking http.ReadResponse
var httpReadResponse = http.ReadResponse

// DefaultErrorHandler is the default Config.ErrorHandler. It responds with
// err's message and Content-Type MIMETextPlainCharsetUTF8. The status code is
// taken from a *Error found in err's chain (via errors.As), falling back to
// StatusInternalServerError.
func DefaultErrorHandler(c Ctx, err error) error {
	code := StatusInternalServerError
	var e *Error
	if errors.As(err, &e) {
		code = e.Code
	}
	c.Set(HeaderContentType, MIMETextPlainCharsetUTF8)
	return c.Status(code).SendString(err.Error())
}

// New creates a new Fiber named instance.
// // app := fiber.New() // // You can pass optional configuration options by passing a Config struct: // // app := fiber.New(fiber.Config{ // ServerHeader: "Fiber", // }) func New(config ...Config) *App { // Create a new app app := &App{ // Create config config: Config{}, toBytes: utils.UnsafeBytes, toString: utils.UnsafeString, latestRoute: &Route{}, customBinders: []CustomBinder{}, sendfiles: []*sendFileStore{}, } // Create Ctx pool app.pool = sync.Pool{ New: func() any { if app.newCtxFunc != nil { return app.newCtxFunc(app) } return NewDefaultCtx(app) }, } // Define hooks app.hooks = newHooks(app) // Define mountFields app.mountFields = newMountFields(app) // Define state app.state = newState() // Override config if provided if len(config) > 0 { app.config = config[0] } // Initialize configured before defaults are set app.configured = app.config if err := app.validateConfiguredServices(); err != nil { panic(err) } // Override default values if app.config.BodyLimit <= 0 { app.config.BodyLimit = DefaultBodyLimit } if app.config.MaxRanges <= 0 { app.config.MaxRanges = DefaultMaxRanges } if app.config.Concurrency <= 0 { app.config.Concurrency = DefaultConcurrency } if app.config.ReadBufferSize <= 0 { app.config.ReadBufferSize = DefaultReadBufferSize } if app.config.WriteBufferSize <= 0 { app.config.WriteBufferSize = DefaultWriteBufferSize } if app.config.CompressedFileSuffixes == nil { app.config.CompressedFileSuffixes = map[string]string{ "gzip": ".fiber.gz", "br": ".fiber.br", "zstd": ".fiber.zst", } } if app.config.Immutable { app.toBytes, app.toString = toBytesImmutable, toStringImmutable } if app.config.ErrorHandler == nil { app.config.ErrorHandler = DefaultErrorHandler } if app.config.JSONEncoder == nil { app.config.JSONEncoder = json.Marshal } if app.config.JSONDecoder == nil { app.config.JSONDecoder = json.Unmarshal } if app.config.MsgPackEncoder == nil { app.config.MsgPackEncoder = binder.UnimplementedMsgpackMarshal } if app.config.MsgPackDecoder == nil { 
app.config.MsgPackDecoder = binder.UnimplementedMsgpackUnmarshal } if app.config.CBOREncoder == nil { app.config.CBOREncoder = binder.UnimplementedCborMarshal } if app.config.CBORDecoder == nil { app.config.CBORDecoder = binder.UnimplementedCborUnmarshal } if app.config.XMLEncoder == nil { app.config.XMLEncoder = xml.Marshal } if app.config.XMLDecoder == nil { app.config.XMLDecoder = xml.Unmarshal } if len(app.config.RequestMethods) == 0 { app.config.RequestMethods = DefaultMethods } app.config.TrustProxyConfig.ips = make(map[string]struct{}, len(app.config.TrustProxyConfig.Proxies)) for _, ipAddress := range app.config.TrustProxyConfig.Proxies { app.handleTrustedProxy(ipAddress) } // Create router stack app.stack = make([][]*Route, len(app.config.RequestMethods)) app.treeStack = make([]map[int][]*Route, len(app.config.RequestMethods)) // Override colors app.config.ColorScheme = defaultColors(&app.config.ColorScheme) // Init app app.init() // Return app return app } // NewWithCustomCtx creates a new Fiber instance and applies the // provided function to generate a custom context type. It mirrors the behavior // of calling `New()` followed by `app.setCtxFunc(fn)`. func NewWithCustomCtx(newCtxFunc func(app *App) CustomCtx, config ...Config) *App { app := New(config...) app.setCtxFunc(newCtxFunc) return app } // GetString returns s unchanged when Immutable is off or s is read-only (rodata). // Otherwise, it returns a detached copy (strings.Clone). func (app *App) GetString(s string) string { if !app.config.Immutable || s == "" { return s } if isReadOnly(unsafe.Pointer(unsafe.StringData(s))) { //nolint:gosec // pointer check avoids unnecessary copy return s // literal / rodata → safe to return as-is } return strings.Clone(s) // heap-backed / aliased → detach } // GetBytes returns b unchanged when Immutable is off or b is read-only (rodata). // Otherwise, it returns a detached copy. 
func (app *App) GetBytes(b []byte) []byte { if !app.config.Immutable || len(b) == 0 { return b } if isReadOnly(unsafe.Pointer(unsafe.SliceData(b))) { //nolint:gosec // pointer check avoids unnecessary copy return b // rodata → safe to return as-is } return utils.CopyBytes(b) // detach when backed by request/response memory } // Adds an ip address to TrustProxyConfig.ranges or TrustProxyConfig.ips based on whether it is an IP range or not func (app *App) handleTrustedProxy(ipAddress string) { if strings.IndexByte(ipAddress, '/') >= 0 { _, ipNet, err := net.ParseCIDR(ipAddress) if err != nil { log.Warnf("IP range %q could not be parsed: %v", ipAddress, err) } else { app.config.TrustProxyConfig.ranges = append(app.config.TrustProxyConfig.ranges, ipNet) } } else { ip := net.ParseIP(ipAddress) if ip == nil { log.Warnf("IP address %q could not be parsed", ipAddress) } else { app.config.TrustProxyConfig.ips[ipAddress] = struct{}{} } } } // setCtxFunc applies the given context factory to the app. // It is used internally by NewWithCustomCtx. It doesn't allow adding new methods, // only customizing existing ones. func (app *App) setCtxFunc(function func(app *App) CustomCtx) { app.newCtxFunc = function app.hasCustomCtx = function != nil if app.server != nil { app.server.Handler = app.requestHandler } } // RegisterCustomConstraint allows to register custom constraint. func (app *App) RegisterCustomConstraint(constraint CustomConstraint) { app.customConstraints = append(app.customConstraints, constraint) } // RegisterCustomBinder Allows to register custom binders to use as Bind().Custom("name"). // They should be compatible with CustomBinder interface. func (app *App) RegisterCustomBinder(customBinder CustomBinder) { app.customBinders = append(app.customBinders, customBinder) } // ReloadViews reloads the configured view engine by invoking its Load method. // It returns an error if no view engine is configured or if reloading fails. 
func (app *App) ReloadViews() error { app.mutex.Lock() defer app.mutex.Unlock() apps := map[string]*App{"": app} if app.mountFields != nil { apps = app.mountFields.appList } var reloaded bool for _, targetApp := range apps { if targetApp == nil || targetApp.config.Views == nil { continue } if viewValue := reflect.ValueOf(targetApp.config.Views); viewValue.Kind() == reflect.Pointer && viewValue.IsNil() { continue } if err := targetApp.config.Views.Load(); err != nil { return fmt.Errorf("fiber: failed to reload views: %w", err) } reloaded = true } if !reloaded { return ErrNoViewEngineConfigured } return nil } // SetTLSHandler Can be used to set ClientHelloInfo when using TLS with Listener. func (app *App) SetTLSHandler(tlsHandler *TLSHandler) { // Attach the tlsHandler to the config app.mutex.Lock() app.tlsHandler = tlsHandler app.mutex.Unlock() } // Name Assign name to specific route. func (app *App) Name(name string) Router { app.mutex.Lock() defer app.mutex.Unlock() for _, routes := range app.stack { for _, route := range routes { isMethodValid := route.Method == app.latestRoute.Method || app.latestRoute.use || (app.latestRoute.Method == MethodGet && route.Method == MethodHead) if route.Path == app.latestRoute.Path && isMethodValid { route.Name = name if route.group != nil { route.Name = route.group.name + route.Name } } } } if err := app.hooks.executeOnNameHooks(app.latestRoute); err != nil { panic(err) } return app } // GetRoute Get route by name func (app *App) GetRoute(name string) Route { for _, routes := range app.stack { for _, route := range routes { if route.Name == name { return *route } } } return Route{} } // GetRoutes Get all routes. When filterUseOption equal to true, it will filter the routes registered by the middleware. 
func (app *App) GetRoutes(filterUseOption ...bool) []Route { var rs []Route var filterUse bool if len(filterUseOption) != 0 { filterUse = filterUseOption[0] } for _, routes := range app.stack { for _, route := range routes { if filterUse && route.use { continue } rs = append(rs, *route) } } return rs } // Use registers a middleware route that will match requests // with the provided prefix (which is optional and defaults to "/"). // Also, you can pass another app instance as a sub-router along a routing path. // It's very useful to split up a large API as many independent routers and // compose them as a single service using Use. The fiber's error handler and // any of the fiber's sub apps are added to the application's error handlers // to be invoked on errors that happen within the prefix route. // // app.Use(func(c fiber.Ctx) error { // return c.Next() // }) // app.Use("/api", func(c fiber.Ctx) error { // return c.Next() // }) // app.Use("/api", handler, func(c fiber.Ctx) error { // return c.Next() // }) // subApp := fiber.New() // app.Use("/mounted-path", subApp) // // This method will match all HTTP verbs: GET, POST, PUT, HEAD etc... func (app *App) Use(args ...any) Router { var prefix string var subApp *App var prefixes []string var handlers []Handler for i := range args { switch arg := args[i].(type) { case string: prefix = arg case *App: subApp = arg case []string: prefixes = arg default: handler, ok := toFiberHandler(arg) if !ok { panic(fmt.Sprintf("use: invalid handler %v\n", reflect.TypeOf(arg))) } handlers = append(handlers, handler) } } if len(prefixes) == 0 { prefixes = append(prefixes, prefix) } for _, prefix := range prefixes { if subApp != nil { return app.mount(prefix, subApp) } app.register([]string{methodUse}, prefix, nil, handlers...) } return app } // Get registers a route for GET methods that requests a representation // of the specified resource. Requests using GET should only retrieve data. 
func (app *App) Get(path string, handler any, handlers ...any) Router {
	methods := []string{MethodGet}
	return app.Add(methods, path, handler, handlers...)
}

// Head registers handler (followed by handlers) for HEAD requests on path.
// A HEAD request asks for a response identical to that of a GET request,
// but without the response body.
func (app *App) Head(path string, handler any, handlers ...any) Router {
	methods := []string{MethodHead}
	return app.Add(methods, path, handler, handlers...)
}

// Post registers handler (followed by handlers) for POST requests on path.
// POST submits an entity to the specified resource, often causing a change
// in state or side effects on the server.
func (app *App) Post(path string, handler any, handlers ...any) Router {
	methods := []string{MethodPost}
	return app.Add(methods, path, handler, handlers...)
}

// Put registers handler (followed by handlers) for PUT requests on path.
// PUT replaces all current representations of the target resource with the
// request payload.
func (app *App) Put(path string, handler any, handlers ...any) Router {
	methods := []string{MethodPut}
	return app.Add(methods, path, handler, handlers...)
}

// Delete registers handler (followed by handlers) for DELETE requests on
// path. DELETE deletes the specified resource.
func (app *App) Delete(path string, handler any, handlers ...any) Router {
	methods := []string{MethodDelete}
	return app.Add(methods, path, handler, handlers...)
}

// Connect registers handler (followed by handlers) for CONNECT requests on
// path. CONNECT establishes a tunnel to the server identified by the target
// resource.
func (app *App) Connect(path string, handler any, handlers ...any) Router {
	methods := []string{MethodConnect}
	return app.Add(methods, path, handler, handlers...)
}

// Options registers handler (followed by handlers) for OPTIONS requests on
// path. OPTIONS describes the communication options for the target resource.
func (app *App) Options(path string, handler any, handlers ...any) Router {
	methods := []string{MethodOptions}
	return app.Add(methods, path, handler, handlers...)
}

// Trace registers a route for TRACE methods that performs a message loop-back
// test along the path to the target resource.
func (app *App) Trace(path string, handler any, handlers ...any) Router {
	methods := []string{MethodTrace}
	return app.Add(methods, path, handler, handlers...)
}

// Patch registers handler (followed by handlers) for PATCH requests on path.
// PATCH applies partial modifications to a resource.
func (app *App) Patch(path string, handler any, handlers ...any) Router {
	methods := []string{MethodPatch}
	return app.Add(methods, path, handler, handlers...)
}

// Add registers one route for every method in methods. The handlers run in
// order: first `handler`, then each element of the variadic `handlers`.
func (app *App) Add(methods []string, path string, handler any, handlers ...any) Router {
	chain := make([]any, 0, len(handlers)+1)
	chain = append(chain, handler)
	chain = append(chain, handlers...)

	app.register(methods, path, nil, collectHandlers("add", chain...)...)
	return app
}

// All registers the handlers for every method in Config.RequestMethods.
func (app *App) All(path string, handler any, handlers ...any) Router {
	methods := app.config.RequestMethods
	return app.Add(methods, path, handler, handlers...)
}

// Group creates a sub-router for routes that share prefix. Any handlers given
// are registered as middleware on that prefix. The OnGroup hooks are run
// afterwards; a hook error causes a panic.
//
//	api := app.Group("/api")
//	api.Get("/users", handler)
func (app *App) Group(prefix string, handlers ...any) Router {
	grp := &Group{Prefix: prefix, app: app}

	if len(handlers) != 0 {
		middleware := collectHandlers("group", handlers...)
		app.register([]string{methodUse}, prefix, grp, middleware...)
	}

	if err := app.hooks.executeOnGroupHooks(*grp); err != nil {
		panic(err)
	}

	return grp
}

// RouteChain returns a Register bound to path, letting you declare a stack of
// handlers for that single route.
func (app *App) RouteChain(path string) Register {
	return &Registering{app: app, path: path}
}

// Route is used to define routes with a common prefix inside the supplied
// function. It mirrors the legacy helper and reuses the Group method to create
// a sub-router.
func (app *App) Route(prefix string, fn func(router Router), name ...string) Router { if fn == nil { panic("route handler 'fn' cannot be nil") } // Create new group group := app.Group(prefix) if len(name) > 0 { group.Name(name[0]) } // Define routes fn(group) return group } // Error makes it compatible with the `error` interface. func (e *Error) Error() string { return e.Message } // NewError creates a new Error instance with an optional message func NewError(code int, message ...string) *Error { err := &Error{ Code: code, Message: utils.StatusMessage(code), } if len(message) > 0 { err.Message = message[0] } return err } // NewErrorf creates a new Error instance with an optional message. // Additional arguments are formatted using fmt.Sprintf when provided. // If the first argument in the message slice is not a string, the function // falls back to using fmt.Sprint on the first element to generate the message. func NewErrorf(code int, message ...any) *Error { var msg string switch len(message) { case 0: // nothing to override msg = utils.StatusMessage(code) case 1: // One argument → treat it like fmt.Sprint(arg) if s, ok := message[0].(string); ok { msg = s } else { msg = fmt.Sprint(message[0]) } default: // Two or more → first must be a format string. if format, ok := message[0].(string); ok { msg = fmt.Sprintf(format, message[1:]...) } else { // If the first arg isn’t a string, fall back. msg = fmt.Sprint(message[0]) } } return &Error{Code: code, Message: msg} } // Config returns the app config as value ( read-only ). func (app *App) Config() Config { return app.config } // Handler returns the server handler. func (app *App) Handler() fasthttp.RequestHandler { //revive:disable-line:confusing-naming // Having both a Handler() (uppercase) and a handler() (lowercase) is fine. TODO: Use nolint:revive directive instead. 
See https://github.com/golangci/golangci-lint/issues/3476 // prepare the server for the start app.startupProcess() return app.requestHandler } // Stack returns the raw router stack. func (app *App) Stack() [][]*Route { return app.stack } // HandlersCount returns the amount of registered handlers. func (app *App) HandlersCount() uint32 { return app.handlersCount } // Shutdown gracefully shuts down the server without interrupting any active connections. // Shutdown works by first closing all open listeners and then waiting indefinitely for all connections to return to idle before shutting down. // // Make sure the program doesn't exit and waits instead for Shutdown to return. // // Important: app.Listen() must be called in a separate goroutine; otherwise, shutdown hooks will not work // as Listen() is a blocking operation. Example: // // go app.Listen(":3000") // // ... // app.Shutdown() // // Shutdown does not close keepalive connections so its recommended to set ReadTimeout to something else than 0. func (app *App) Shutdown() error { return app.ShutdownWithContext(context.Background()) } // ShutdownWithTimeout gracefully shuts down the server without interrupting any active connections. However, if the timeout is exceeded, // ShutdownWithTimeout will forcefully close any active connections. // ShutdownWithTimeout works by first closing all open listeners and then waiting for all connections to return to idle before shutting down. // // Make sure the program doesn't exit and waits instead for ShutdownWithTimeout to return. // // ShutdownWithTimeout does not close keepalive connections so its recommended to set ReadTimeout to something else than 0. func (app *App) ShutdownWithTimeout(timeout time.Duration) error { ctx, cancelFunc := context.WithTimeout(context.Background(), timeout) defer cancelFunc() return app.ShutdownWithContext(ctx) } // ShutdownWithContext shuts down the server including by force if the context's deadline is exceeded. 
// // Make sure the program doesn't exit and waits instead for ShutdownWithTimeout to return. // // ShutdownWithContext does not close keepalive connections so its recommended to set ReadTimeout to something else than 0. func (app *App) ShutdownWithContext(ctx context.Context) error { app.mutex.Lock() defer app.mutex.Unlock() var err error if app.server == nil { return ErrNotRunning } // Execute the Shutdown hook app.hooks.executeOnPreShutdownHooks() defer app.hooks.executeOnPostShutdownHooks(err) err = app.server.ShutdownWithContext(ctx) return err } // Server returns the underlying fasthttp server func (app *App) Server() *fasthttp.Server { return app.server } // Hooks returns the hook struct to register hooks. func (app *App) Hooks() *Hooks { return app.hooks } // State returns the state struct to store global data in order to share it between handlers. func (app *App) State() *State { return app.state } var ErrTestGotEmptyResponse = errors.New("test: got empty response") // TestConfig is a struct holding Test settings type TestConfig struct { // Timeout defines the maximum duration a // test can run before timing out. // Default: time.Second Timeout time.Duration // FailOnTimeout specifies whether the test // should return a timeout error if the HTTP response // exceeds the Timeout duration. // Default: true FailOnTimeout bool } // Test is used for internal debugging by passing a *http.Request. // Config is optional and defaults to a 1s error on timeout, // 0 timeout will disable it completely. 
func (app *App) Test(req *http.Request, config ...TestConfig) (*http.Response, error) { // Default config cfg := TestConfig{ Timeout: time.Second, FailOnTimeout: true, } // Override config if provided if len(config) > 0 { cfg = config[0] } // Add Content-Length if not provided with body if req.Body != http.NoBody && req.Header.Get(HeaderContentLength) == "" { req.Header.Add(HeaderContentLength, strconv.FormatInt(req.ContentLength, 10)) } // Dump raw http request dump, err := httputil.DumpRequest(req, true) if err != nil { return nil, fmt.Errorf("failed to dump request: %w", err) } // Create test connection conn := new(testConn) // Write raw http request if _, err = conn.r.Write(dump); err != nil { return nil, fmt.Errorf("failed to write: %w", err) } // prepare the server for the start app.startupProcess() // Serve conn to server channel := make(chan error, 1) go func() { var returned bool defer func() { if !returned { channel <- ErrHandlerExited } }() channel <- app.server.ServeConn(conn) returned = true }() // Wait for callback if cfg.Timeout > 0 { // With timeout select { case err = <-channel: case <-time.After(cfg.Timeout): if cfg.FailOnTimeout { conn.Close() //nolint:errcheck // It is fine to ignore the error here return nil, os.ErrDeadlineExceeded } // When FailOnTimeout is false, wait up to 1 additional second for the handler // to complete and write a response. This prevents indefinite blocking while // allowing slow handlers to finish. 
select { case err = <-channel: case <-time.After(time.Second): // Handler took too long even with extra time conn.Close() //nolint:errcheck // It is fine to ignore the error here } } } else { // Without timeout err = <-channel } // Check for errors if err != nil && !errors.Is(err, fasthttp.ErrGetOnly) && !errors.Is(err, errTestConnClosed) { return nil, err } // Read response(s) buffer := bufio.NewReader(&conn.w) var res *http.Response for { // Convert raw http response to *http.Response res, err = httpReadResponse(buffer, req) if err != nil { if errors.Is(err, io.ErrUnexpectedEOF) { return nil, ErrTestGotEmptyResponse } return nil, fmt.Errorf("failed to read response: %w", err) } // Break if this response is non-1xx or there are no more responses if res.StatusCode >= http.StatusOK || buffer.Buffered() == 0 { break } // Discard interim response body before reading the next one if res.Body != nil { if _, errCopy := io.Copy(io.Discard, res.Body); errCopy != nil { return nil, fmt.Errorf("failed to discard interim response body: %w", errCopy) } if errClose := res.Body.Close(); errClose != nil { return nil, fmt.Errorf("failed to close interim response body: %w", errClose) } } } return res, nil } type disableLogger struct{} // Printf implements the fasthttp Logger interface and discards log output. func (*disableLogger) Printf(string, ...any) { } func (app *App) init() *App { // lock application app.mutex.Lock() // Initialize Services when needed, // panics if there is an error starting them. 
app.initServices() // Only load templates if a view engine is specified if app.config.Views != nil { if err := app.config.Views.Load(); err != nil { log.Warnf("failed to load views: %v", err) } } // create fasthttp server app.server = &fasthttp.Server{ Logger: &disableLogger{}, LogAllErrors: false, ErrorHandler: app.serverErrorHandler, } // fasthttp server settings app.server.Handler = app.requestHandler app.server.Name = app.config.ServerHeader app.server.Concurrency = app.config.Concurrency app.server.NoDefaultDate = app.config.DisableDefaultDate app.server.NoDefaultContentType = app.config.DisableDefaultContentType app.server.DisableHeaderNamesNormalizing = app.config.DisableHeaderNormalizing app.server.DisableKeepalive = app.config.DisableKeepalive app.server.MaxRequestBodySize = app.config.BodyLimit app.server.NoDefaultServerHeader = app.config.ServerHeader == "" app.server.ReadTimeout = app.config.ReadTimeout app.server.WriteTimeout = app.config.WriteTimeout app.server.IdleTimeout = app.config.IdleTimeout app.server.ReadBufferSize = app.config.ReadBufferSize app.server.WriteBufferSize = app.config.WriteBufferSize app.server.GetOnly = app.config.GETOnly app.server.ReduceMemoryUsage = app.config.ReduceMemoryUsage app.server.StreamRequestBody = app.config.StreamRequestBody app.server.DisablePreParseMultipartForm = app.config.DisablePreParseMultipartForm // unlock application app.mutex.Unlock() // Register the Services shutdown handler once the app is initialized and unlocked. app.Hooks().OnPostShutdown(func(_ error) error { if err := app.shutdownServices(app.servicesShutdownCtx()); err != nil { log.Errorf("failed to shutdown services: %v", err) } return nil }) return app } // ErrorHandler is the application's method in charge of finding the // appropriate handler for the given request. It searches any mounted // sub fibers by their prefixes and if it finds a match, it uses that // error handler. 
Otherwise, it uses the configured error handler for // the app, which if not set is the DefaultErrorHandler. func (app *App) ErrorHandler(ctx Ctx, err error) error { var ( mountedErrHandler ErrorHandler mountedPrefixParts int ) normalizedPath := utils.AddTrailingSlashString(ctx.Path()) for _, prefix := range app.mountFields.appListKeys { subApp := app.mountFields.appList[prefix] normalizedPrefix := utils.AddTrailingSlashString(prefix) if prefix != "" && strings.HasPrefix(normalizedPath, normalizedPrefix) { // Count slashes instead of splitting - more efficient parts := strings.Count(prefix, "/") + 1 if mountedPrefixParts <= parts { if subApp.configured.ErrorHandler != nil { mountedErrHandler = subApp.config.ErrorHandler } mountedPrefixParts = parts } } } if mountedErrHandler != nil { return mountedErrHandler(ctx, err) } return app.config.ErrorHandler(ctx, err) } // serverErrorHandler is a wrapper around the application's error handler method // user for the fasthttp server configuration. It maps a set of fasthttp errors to fiber // errors before calling the application's error handler method. 
func (app *App) serverErrorHandler(fctx *fasthttp.RequestCtx, err error) { // Acquire Ctx with fasthttp request from pool c := app.AcquireCtx(fctx) defer app.ReleaseCtx(c) var ( errNetOP *net.OpError netErr net.Error ) switch { case errors.As(err, new(*fasthttp.ErrSmallBuffer)): err = ErrRequestHeaderFieldsTooLarge case errors.As(err, &errNetOP) && errNetOP.Timeout(): err = ErrRequestTimeout case errors.As(err, &netErr): err = ErrBadGateway case errors.Is(err, fasthttp.ErrBodyTooLarge): err = ErrRequestEntityTooLarge case errors.Is(err, fasthttp.ErrGetOnly): err = ErrMethodNotAllowed case strings.Contains(err.Error(), "unsupported http request method"): err = ErrNotImplemented case strings.Contains(err.Error(), "timeout"): err = ErrRequestTimeout default: err = NewError(StatusBadRequest, err.Error()) } if c.getMethodInt() != -1 { c.setSkipNonUseRoutes(true) defer c.setSkipNonUseRoutes(false) var nextErr error if d, isDefault := c.(*DefaultCtx); isDefault { _, nextErr = app.next(d) } else { _, nextErr = app.nextCustom(c) } if nextErr != nil && !errors.Is(nextErr, ErrNotFound) && !errors.Is(nextErr, ErrMethodNotAllowed) { log.Errorf("serverErrorHandler: middleware traversal failed: %v", nextErr) } } if catch := app.ErrorHandler(c, err); catch != nil { log.Errorf("serverErrorHandler: failed to call ErrorHandler: %v", catch) _ = c.SendStatus(StatusInternalServerError) //nolint:errcheck // It is fine to ignore the error here return } } // startupProcess Is the method which executes all the necessary processes just before the start of the server. func (app *App) startupProcess() { app.mutex.Lock() defer app.mutex.Unlock() app.ensureAutoHeadRoutesLocked() for prefix, subApp := range app.mountFields.appList { if prefix == "" { continue } subApp.ensureAutoHeadRoutes() } app.mountStartupProcess() // build route tree stack app.buildTree() } // Run onListen hooks. If they return an error, panic. 
func (app *App) runOnListenHooks(listenData *ListenData) { if err := app.hooks.executeOnListenHooks(listenData); err != nil { panic(err) } } ================================================ FILE: app_integration_test.go ================================================ package fiber_test import ( "bytes" "errors" "net" "net/http" "net/http/httptest" "strconv" "strings" "testing" "time" "github.com/stretchr/testify/require" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/basicauth" "github.com/gofiber/fiber/v3/middleware/cache" "github.com/gofiber/fiber/v3/middleware/compress" "github.com/gofiber/fiber/v3/middleware/cors" "github.com/gofiber/fiber/v3/middleware/csrf" "github.com/gofiber/fiber/v3/middleware/encryptcookie" "github.com/gofiber/fiber/v3/middleware/envvar" "github.com/gofiber/fiber/v3/middleware/helmet" "github.com/gofiber/fiber/v3/middleware/keyauth" "github.com/gofiber/fiber/v3/middleware/limiter" "github.com/gofiber/fiber/v3/middleware/recover" "github.com/gofiber/fiber/v3/middleware/requestid" "github.com/gofiber/fiber/v3/middleware/session" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" ) type integrationCustomCtx struct { *fiber.DefaultCtx } func newIntegrationCustomCtx(app *fiber.App) fiber.CustomCtx { return &integrationCustomCtx{DefaultCtx: fiber.NewDefaultCtx(app)} } func performOversizedRequest(t *testing.T, app *fiber.App, configure func(req *fasthttp.Request)) *fasthttp.Response { t.Helper() ln := fasthttputil.NewInmemoryListener() errCh := make(chan error, 1) go func() { errCh <- app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}) }() t.Cleanup(func() { require.NoError(t, app.Shutdown()) if err := <-errCh; err != nil && !errors.Is(err, net.ErrClosed) { require.NoError(t, err) } }) require.Eventually(t, func() bool { conn, err := ln.Dial() if err != nil { return false } if err := conn.Close(); err != nil { return false } return true }, time.Second, 10*time.Millisecond) req := 
fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() req.SetRequestURI("http://example.com/") req.Header.SetMethod(fiber.MethodPost) req.Header.Set(fiber.HeaderOrigin, "https://example.com") req.SetBody(bytes.Repeat([]byte{'a'}, 32)) if configure != nil { configure(req) } client := fasthttp.Client{ Dial: func(string) (net.Conn, error) { return ln.Dial() }, } require.NoError(t, client.Do(req, resp)) respCopy := fasthttp.AcquireResponse() resp.CopyTo(respCopy) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) t.Cleanup(func() { fasthttp.ReleaseResponse(respCopy) }) return respCopy } var integrationEncryptCookieKey = encryptcookie.GenerateKey(32) // middlewareCombinationTestCase describes a middleware stack that should keep its // headers intact even when the default error handler runs. Keeping it as a named // type (instead of an inline struct) makes the massive table below easier to // scan and extend. // //nolint:govet // field alignment is secondary to readability for this test table type middlewareCombinationTestCase struct { // betteralign:ignore - readability takes priority in tests name string setup func(app *fiber.App) configureRequest func(req *fasthttp.Request) handler func(c fiber.Ctx) error assertions func(t *testing.T, resp *fasthttp.Response) expectedStatus int } func (tc middlewareCombinationTestCase) statusOrDefault() int { if tc.expectedStatus == 0 { return fiber.StatusInternalServerError } return tc.expectedStatus } func (tc middlewareCombinationTestCase) handlerOrDefault() func(fiber.Ctx) error { if tc.handler != nil { return tc.handler } return func(fiber.Ctx) error { return fiber.NewError(fiber.StatusInternalServerError, "middleware combination failure") } } func Test_Integration_RequestID_ContextPropagationFlag(t *testing.T) { t.Parallel() t.Run("disabled by default", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(requestid.New(requestid.Config{Generator: func() string { return "rid-disabled" }})) app.Get("/", 
func(c fiber.Ctx) error { require.Equal(t, "rid-disabled", requestid.FromContext(c)) require.Empty(t, requestid.FromContext(c.Context())) return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) }) t.Run("enabled", func(t *testing.T) { t.Parallel() app := fiber.New(fiber.Config{PassLocalsToContext: true}) app.Use(requestid.New(requestid.Config{Generator: func() string { return "rid-enabled" }})) app.Get("/", func(c fiber.Ctx) error { require.Equal(t, "rid-enabled", requestid.FromContext(c)) require.Equal(t, "rid-enabled", requestid.FromContext(c.Context())) return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) }) } func Test_Integration_App_ServerErrorHandler_MiddlewareCombinationHeaders(t *testing.T) { t.Parallel() // This integration suite exercises representative middleware stacks to ensure their // response headers survive after Fiber's default error handler emits a failure. const ( // Origins used by the CORS stacks in this suite. 
corsHelmetOrigin = "https://cors-and-helmet.example" corsRequestIDOrigin = "https://cors-and-requestid.example" corsCSRForigin = "https://cors-and-csrf.example" corsCacheOrigin = "https://cors-and-cache.example" corsSessionOrigin = "https://cors-and-session.example" corsHelmetRequestID = "https://cors-helmet-requestid.example" csrfCookieName = "combo-csrf" generatedRequestID = "generated-combo-request-id" helmetLimiterMax = 7 helmetLimiterReset = 60 requestIDHeader = "combo-request-id" csrfTokenValue = "csrf-token" encryptedCookieName = "combo-encrypted" encryptedCookieVal = "unencrypted" envvarAllowHeader = fiber.MethodGet + ", " + fiber.MethodHead basicRealm = "combo-basic" keyAuthRealm = "combo-key" keyAuthErrorDesc = "missing-key" ) // Each entry wires up a different middleware stack so we can ensure response mutations // survive the hop through the default error handler. testCases := []middlewareCombinationTestCase{ // --- CORS-focused stacks keep cross-origin metadata on error responses. 
{ name: "cors+helmet", setup: func(app *fiber.App) { app.Use(cors.New(cors.Config{AllowOrigins: []string{corsHelmetOrigin}})) app.Use(helmet.New()) }, configureRequest: func(req *fasthttp.Request) { req.Header.Set(fiber.HeaderOrigin, corsHelmetOrigin) }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, corsHelmetOrigin, string(resp.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) require.Equal(t, "same-origin", string(resp.Header.Peek("Cross-Origin-Opener-Policy"))) require.Equal(t, "same-origin", string(resp.Header.Peek("Cross-Origin-Resource-Policy"))) require.Equal(t, "require-corp", string(resp.Header.Peek("Cross-Origin-Embedder-Policy"))) }, }, { name: "cors+requestid", setup: func(app *fiber.App) { app.Use(cors.New(cors.Config{AllowOrigins: []string{corsRequestIDOrigin}})) app.Use(requestid.New()) }, configureRequest: func(req *fasthttp.Request) { req.Header.Set(fiber.HeaderOrigin, corsRequestIDOrigin) req.Header.Set("X-Request-ID", requestIDHeader) }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, corsRequestIDOrigin, string(resp.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) require.Equal(t, requestIDHeader, string(resp.Header.Peek("X-Request-ID"))) }, }, { name: "cors+helmet+requestid", setup: func(app *fiber.App) { app.Use(cors.New(cors.Config{AllowOrigins: []string{corsHelmetRequestID}})) app.Use(helmet.New()) app.Use(requestid.New(requestid.Config{ Generator: func() string { return generatedRequestID }, })) }, configureRequest: func(req *fasthttp.Request) { req.Header.Set(fiber.HeaderOrigin, corsHelmetRequestID) }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, corsHelmetRequestID, string(resp.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) require.Equal(t, generatedRequestID, string(resp.Header.Peek("X-Request-ID"))) require.Equal(t, 
"nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) }, }, { name: "cors+cache", setup: func(app *fiber.App) { app.Use(cors.New(cors.Config{AllowOrigins: []string{corsCacheOrigin}})) app.Use(cache.New()) // Cache needs the default error handler to execute so it can emit X-Cache on failures. app.Use(func(c fiber.Ctx) error { if err := c.Next(); err != nil { if handlerErr := app.Config().ErrorHandler(c, err); handlerErr != nil { return handlerErr } c.Set(fiber.HeaderCacheControl, "no-store") return nil } c.Set(fiber.HeaderCacheControl, "no-store") return nil }) }, configureRequest: func(req *fasthttp.Request) { req.Header.Set(fiber.HeaderOrigin, corsCacheOrigin) req.Header.SetMethod(fiber.MethodGet) }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, corsCacheOrigin, string(resp.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) require.Equal(t, "unreachable", string(resp.Header.Peek("X-Cache"))) require.Equal(t, "no-store", string(resp.Header.Peek(fiber.HeaderCacheControl))) }, }, { name: "cors+session", setup: func(app *fiber.App) { app.Use(cors.New(cors.Config{ AllowOrigins: []string{corsSessionOrigin}, AllowCredentials: true, })) app.Use(session.New()) app.Use(func(c fiber.Ctx) error { if sm := session.FromContext(c); sm != nil { sm.Set("cors-session", "enabled") } return c.Next() }) }, configureRequest: func(req *fasthttp.Request) { req.Header.Set(fiber.HeaderOrigin, corsSessionOrigin) }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, corsSessionOrigin, string(resp.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) require.Equal(t, "true", string(resp.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Contains(t, string(resp.Header.Peek(fiber.HeaderSetCookie)), "session_id=") }, }, { name: "helmet+encryptcookie", setup: func(app *fiber.App) { app.Use(helmet.New()) app.Use(encryptcookie.New(encryptcookie.Config{Key: integrationEncryptCookieKey})) 
app.Use(func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{Name: encryptedCookieName, Value: encryptedCookieVal}) return c.Next() }) }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) cookieHeader := string(resp.Header.Peek(fiber.HeaderSetCookie)) require.Contains(t, cookieHeader, encryptedCookieName+"=") require.NotContains(t, cookieHeader, encryptedCookieVal) }, }, // --- Helmet anchored stacks validate security headers across other middleware. { name: "helmet+limiter", setup: func(app *fiber.App) { app.Use(helmet.New()) app.Use(limiter.New(limiter.Config{ Max: helmetLimiterMax, Expiration: time.Duration(helmetLimiterReset) * time.Second, KeyGenerator: func(fiber.Ctx) string { return "helmet+limiter" }, })) }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) require.Equal(t, strconv.Itoa(helmetLimiterMax), string(resp.Header.Peek("X-RateLimit-Limit"))) require.Equal(t, strconv.Itoa(helmetLimiterMax-1), string(resp.Header.Peek("X-RateLimit-Remaining"))) require.Equal(t, strconv.Itoa(helmetLimiterReset), string(resp.Header.Peek("X-RateLimit-Reset"))) }, }, { name: "cors+csrf", setup: func(app *fiber.App) { app.Use(cors.New(cors.Config{ AllowOrigins: []string{corsCSRForigin}, AllowCredentials: true, })) app.Use(csrf.New(csrf.Config{ CookieName: csrfCookieName, KeyGenerator: func() string { return csrfTokenValue }, })) }, configureRequest: func(req *fasthttp.Request) { req.Header.SetMethod(fiber.MethodGet) req.Header.Set(fiber.HeaderOrigin, corsCSRForigin) }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, corsCSRForigin, string(resp.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) require.Equal(t, "true", string(resp.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Contains(t, 
string(resp.Header.Peek(fiber.HeaderSetCookie)), csrfCookieName+"="+csrfTokenValue) }, }, { name: "helmet+session", setup: func(app *fiber.App) { app.Use(helmet.New()) app.Use(session.New()) app.Use(func(c fiber.Ctx) error { if sm := session.FromContext(c); sm != nil { sm.Set("combo-session", "enabled") } return c.Next() }) }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) require.Contains(t, string(resp.Header.Peek(fiber.HeaderSetCookie)), "session_id=") }, }, { name: "helmet+csrf", setup: func(app *fiber.App) { app.Use(helmet.New()) app.Use(csrf.New(csrf.Config{ CookieName: csrfCookieName, KeyGenerator: func() string { return csrfTokenValue }, })) }, configureRequest: func(req *fasthttp.Request) { req.Header.SetMethod(fiber.MethodGet) }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) require.Contains(t, string(resp.Header.Peek(fiber.HeaderSetCookie)), csrfCookieName+"="+csrfTokenValue) }, }, { name: "helmet+envvar", setup: func(app *fiber.App) { app.Use(helmet.New()) app.Use(envvar.New(envvar.Config{ExportVars: map[string]string{"COMBO_ENV": "configured"}})) }, expectedStatus: fiber.StatusMethodNotAllowed, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) require.Equal(t, envvarAllowHeader, string(resp.Header.Peek(fiber.HeaderAllow))) }, }, { name: "helmet+basicauth", setup: func(app *fiber.App) { app.Use(helmet.New()) app.Use(basicauth.New(basicauth.Config{ Realm: basicRealm, Unauthorized: func(c fiber.Ctx) error { c.Set(fiber.HeaderWWWAuthenticate, "Basic realm=\""+basicRealm+"\", charset=\"UTF-8\"") c.Set(fiber.HeaderCacheControl, "no-store") c.Set(fiber.HeaderVary, fiber.HeaderAuthorization) c.Status(fiber.StatusUnauthorized) return 
fiber.ErrUnauthorized }, })) }, expectedStatus: fiber.StatusUnauthorized, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) require.Equal(t, "Basic realm=\""+basicRealm+"\", charset=\"UTF-8\"", string(resp.Header.Peek(fiber.HeaderWWWAuthenticate))) require.Equal(t, "no-store", string(resp.Header.Peek(fiber.HeaderCacheControl))) require.Equal(t, fiber.HeaderAuthorization, string(resp.Header.Peek(fiber.HeaderVary))) }, }, { name: "helmet+keyauth", setup: func(app *fiber.App) { app.Use(helmet.New()) app.Use(keyauth.New(keyauth.Config{ Realm: keyAuthRealm, Error: keyauth.ErrorInvalidToken, ErrorDescription: keyAuthErrorDesc, Validator: func(fiber.Ctx, string) (bool, error) { return false, nil }, ErrorHandler: func(c fiber.Ctx, _ error) error { c.Status(fiber.StatusUnauthorized) return fiber.ErrUnauthorized }, })) }, expectedStatus: fiber.StatusUnauthorized, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) authenticate := string(resp.Header.Peek(fiber.HeaderWWWAuthenticate)) require.Contains(t, authenticate, "Bearer realm=\""+keyAuthRealm+"\"") require.Contains(t, authenticate, "error=\""+keyauth.ErrorInvalidToken+"\"") require.Contains(t, authenticate, "error_description=\""+keyAuthErrorDesc+"\"") }, }, { name: "helmet+compress", setup: func(app *fiber.App) { app.Use(helmet.New()) app.Use(compress.New()) app.Use(func(c fiber.Ctx) error { if err := c.Next(); err != nil { if handlerErr := app.Config().ErrorHandler(c, err); handlerErr != nil { return handlerErr } // Inflate the error body so the compress middleware has something to work with. 
if body := c.Response().Body(); len(body) > 0 { c.Response().SetBodyString(strings.Repeat(string(body), 32)) } } return nil }) }, configureRequest: func(req *fasthttp.Request) { req.Header.Set(fiber.HeaderAcceptEncoding, "gzip") }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) require.Equal(t, "gzip", string(resp.Header.Peek(fiber.HeaderContentEncoding))) require.Equal(t, fiber.HeaderAcceptEncoding, string(resp.Header.Peek(fiber.HeaderVary))) }, }, { name: "helmet+recover", setup: func(app *fiber.App) { app.Use(helmet.New()) app.Use(recover.New()) }, handler: func(fiber.Ctx) error { panic("panic for recover middleware") }, assertions: func(t *testing.T, resp *fasthttp.Response) { t.Helper() require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) // Recover writes a plain-text body; ensure we still return content to clients while // keeping Helmet's security headers intact. require.Contains(t, string(resp.Body()), "panic for recover middleware") }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() app := fiber.New() tc.setup(app) // Every stack shares the same route that always hits the default error handler so we // can verify which headers survive the error response. A few cases override the // handler to exercise panic recovery or other routes that still flow through the // default error path. 
app.All("/", tc.handlerOrDefault()) resp := performOversizedRequest(t, app, tc.configureRequest) require.Equal(t, tc.statusOrDefault(), resp.StatusCode()) tc.assertions(t, resp) }) } } func Test_Integration_App_ServerErrorHandler_PreservesCORSHeadersOnBodyLimit(t *testing.T) { app := fiber.New(fiber.Config{BodyLimit: 16}) app.Use(cors.New(cors.Config{ AllowOrigins: []string{"https://example.com"}, AllowCredentials: true, ExposeHeaders: []string{"X-Request-ID"}, })) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp := performOversizedRequest(t, app, nil) require.Equal(t, fiber.StatusRequestEntityTooLarge, resp.StatusCode()) require.Equal(t, "https://example.com", string(resp.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) require.Equal(t, "true", string(resp.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Equal(t, "X-Request-ID", string(resp.Header.Peek(fiber.HeaderAccessControlExposeHeaders))) require.Equal(t, "Origin", string(resp.Header.Peek(fiber.HeaderVary))) } func Test_Integration_App_ServerErrorHandler_PreservesHelmetHeadersOnBodyLimit(t *testing.T) { app := fiber.New(fiber.Config{BodyLimit: 16}) app.Use(helmet.New()) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp := performOversizedRequest(t, app, nil) require.Equal(t, fiber.StatusRequestEntityTooLarge, resp.StatusCode()) require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) require.Equal(t, "same-origin", string(resp.Header.Peek("Cross-Origin-Opener-Policy"))) require.Equal(t, "same-origin", string(resp.Header.Peek("Cross-Origin-Resource-Policy"))) require.Equal(t, "require-corp", string(resp.Header.Peek("Cross-Origin-Embedder-Policy"))) } func Test_Integration_App_ServerErrorHandler_PreservesRequestID(t *testing.T) { const expectedRequestID = "integration-request-id" app := fiber.New(fiber.Config{BodyLimit: 16}) app.Use(requestid.New()) app.Post("/", func(c fiber.Ctx) error { return 
c.SendStatus(fiber.StatusOK) }) resp := performOversizedRequest(t, app, func(req *fasthttp.Request) { req.Header.Set("X-Request-ID", expectedRequestID) }) require.Equal(t, fiber.StatusRequestEntityTooLarge, resp.StatusCode()) require.Equal(t, expectedRequestID, string(resp.Header.Peek("X-Request-ID"))) } func Test_Integration_App_ServerErrorHandler_GroupMiddlewareChain(t *testing.T) { app := fiber.New(fiber.Config{BodyLimit: 16}) app.Use(helmet.New()) api := app.Group("/api") api.Use(requestid.New()) api.Use(func(c fiber.Ctx) error { c.Set("X-Group-Middleware", "active") return c.Next() }) api.Post("/resource", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp := performOversizedRequest(t, app, func(req *fasthttp.Request) { req.SetRequestURI("http://example.com/api/resource") }) require.Equal(t, fiber.StatusRequestEntityTooLarge, resp.StatusCode()) require.Equal(t, "nosniff", string(resp.Header.Peek(fiber.HeaderXContentTypeOptions))) require.NotEmpty(t, resp.Header.Peek("X-Request-ID")) require.Equal(t, "active", string(resp.Header.Peek("X-Group-Middleware"))) } func Test_Integration_App_ServerErrorHandler_RetainsHeadersFromSubsequentMiddleware(t *testing.T) { app := fiber.New(fiber.Config{BodyLimit: 8}) app.Use(func(c fiber.Ctx) error { c.Set("X-Custom-Middleware", "ran") return c.Next() }) app.Use(cors.New()) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp := performOversizedRequest(t, app, nil) require.Equal(t, fiber.StatusRequestEntityTooLarge, resp.StatusCode()) require.Equal(t, "ran", string(resp.Header.Peek("X-Custom-Middleware"))) require.Equal(t, "*", string(resp.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } func Test_Integration_App_ServerErrorHandler_WithCustomCtx(t *testing.T) { app := fiber.NewWithCustomCtx(newIntegrationCustomCtx, fiber.Config{BodyLimit: 16}) app.Use(func(c fiber.Ctx) error { customCtx, ok := c.(*integrationCustomCtx) require.True(t, ok) customCtx.Set("X-Custom-Ctx", 
"true") return c.Next() }) app.Use(cors.New(cors.Config{AllowOrigins: []string{"https://example.org"}})) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp := performOversizedRequest(t, app, func(req *fasthttp.Request) { req.Header.Set(fiber.HeaderOrigin, "https://example.org") }) require.Equal(t, fiber.StatusRequestEntityTooLarge, resp.StatusCode()) require.Equal(t, "true", string(resp.Header.Peek("X-Custom-Ctx"))) require.Equal(t, "https://example.org", string(resp.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } ================================================ FILE: app_test.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 🤖 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io package fiber import ( "bufio" "bytes" "context" "crypto/tls" "encoding/json" "errors" "fmt" "io" "mime/multipart" "net" "net/http" "net/http/httptest" "os" "path/filepath" "reflect" "regexp" "runtime" "strconv" "strings" "sync" "testing" "time" "unsafe" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" ) type fileView struct { path string content string loads int } func (v *fileView) Load() error { contents, err := os.ReadFile(v.path) if err != nil { return fmt.Errorf("read template: %w", err) } v.content = string(contents) v.loads++ return nil } func (*fileView) Render(io.Writer, string, any, ...string) error { return nil } func testEmptyHandler(_ Ctx) error { return nil } func testStatus200(t *testing.T, app *App, url, method string) { t.Helper() req := httptest.NewRequest(method, url, http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") } func testErrorResponse(t *testing.T, err error, resp *http.Response, 
expectedBodyError string) { t.Helper() require.NoError(t, err, "app.Test(req)") require.Equal(t, 500, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, expectedBodyError, string(body), "Response body") } func Test_App_Test_Goroutine_Leak_Compare(t *testing.T) { t.Parallel() testCases := []struct { handler Handler name string timeout time.Duration sleepTime time.Duration expectLeak bool }{ { name: "With timeout (potential leak)", handler: func(c Ctx) error { time.Sleep(300 * time.Millisecond) // Simulate time-consuming operation return c.SendString("ok") }, timeout: 50 * time.Millisecond, // // Short timeout to ensure triggering sleepTime: 500 * time.Millisecond, // Wait time longer than handler execution time expectLeak: true, }, { name: "Without timeout (no leak)", handler: func(c Ctx) error { return c.SendString("ok") // Return immediately }, timeout: 0, // Disable timeout sleepTime: 100 * time.Millisecond, expectLeak: false, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() app := New() // Record initial goroutine count initialGoroutines := runtime.NumGoroutine() t.Logf("[%s] Initial goroutines: %d", tc.name, initialGoroutines) app.Get("/", tc.handler) // Send 10 requests numRequests := 10 for range numRequests { req := httptest.NewRequest(MethodGet, "/", http.NoBody) if tc.timeout > 0 { _, err := app.Test(req, TestConfig{ Timeout: tc.timeout, FailOnTimeout: true, }) require.Error(t, err) require.ErrorIs(t, err, os.ErrDeadlineExceeded) } else if resp, err := app.Test(req); err != nil { t.Errorf("unexpected error: %v", err) } else { require.Equal(t, 200, resp.StatusCode) } } // Wait for normal goroutines to complete time.Sleep(tc.sleepTime) // Check final goroutine count finalGoroutines := runtime.NumGoroutine() leakedGoroutines := finalGoroutines - initialGoroutines if leakedGoroutines < 0 { leakedGoroutines = 0 } t.Logf("[%s] Final goroutines: %d (leaked: %d)", tc.name, 
finalGoroutines, leakedGoroutines) if tc.expectLeak { // We allow up to 3x the request count to account for noise; zero is tolerated. maxLeak := numRequests * 3 if leakedGoroutines > maxLeak { t.Errorf("[%s] Expected at most %d leaked goroutines, but got %d", tc.name, maxLeak, leakedGoroutines) } return } // No-leak scenario: allow a small buffer for runtime noise. // Increase slack to reduce flakes from background goroutines. if leakedGoroutines > numRequests { t.Errorf("[%s] Expected at most %d leaked goroutines, but got %d", tc.name, numRequests, leakedGoroutines) } }) } } func Test_App_MethodNotAllowed(t *testing.T) { t.Parallel() app := New() app.Use(func(c Ctx) error { return c.Next() }) app.Post("/", testEmptyHandler) app.Options("/", testEmptyHandler) resp, err := app.Test(httptest.NewRequest(MethodPost, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) require.Empty(t, resp.Header.Get(HeaderAllow)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 405, resp.StatusCode) require.Equal(t, "POST, OPTIONS", resp.Header.Get(HeaderAllow)) resp, err = app.Test(httptest.NewRequest(MethodPatch, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 405, resp.StatusCode) require.Equal(t, "POST, OPTIONS", resp.Header.Get(HeaderAllow)) resp, err = app.Test(httptest.NewRequest(MethodPut, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 405, resp.StatusCode) require.Equal(t, "POST, OPTIONS", resp.Header.Get(HeaderAllow)) app.Get("/", testEmptyHandler) resp, err = app.Test(httptest.NewRequest(MethodTrace, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 405, resp.StatusCode) require.Equal(t, "GET, HEAD, POST, OPTIONS", resp.Header.Get(HeaderAllow)) resp, err = app.Test(httptest.NewRequest(MethodPatch, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 405, resp.StatusCode) require.Equal(t, "GET, HEAD, POST, OPTIONS", 
resp.Header.Get(HeaderAllow)) app.Head("/", testEmptyHandler) resp, err = app.Test(httptest.NewRequest(MethodPut, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 405, resp.StatusCode) require.Equal(t, "GET, HEAD, POST, OPTIONS", resp.Header.Get(HeaderAllow)) } func Test_App_RegisterNetHTTPHandler(t *testing.T) { t.Parallel() tests := []struct { name string register func(app *App, path string, handler any) method string expectBody bool }{ { name: "Get", register: func(app *App, path string, handler any) { app.Get(path, handler) }, method: http.MethodGet, expectBody: true, }, { name: "Head", register: func(app *App, path string, handler any) { app.Head(path, handler) }, method: http.MethodHead, }, { name: "Post", register: func(app *App, path string, handler any) { app.Post(path, handler) }, method: http.MethodPost, expectBody: true, }, { name: "Put", register: func(app *App, path string, handler any) { app.Put(path, handler) }, method: http.MethodPut, expectBody: true, }, { name: "Delete", register: func(app *App, path string, handler any) { app.Delete(path, handler) }, method: http.MethodDelete, expectBody: true, }, { name: "Connect", register: func(app *App, path string, handler any) { app.Connect(path, handler) }, method: http.MethodConnect, expectBody: true, }, { name: "Options", register: func(app *App, path string, handler any) { app.Options(path, handler) }, method: http.MethodOptions, expectBody: true, }, { name: "Trace", register: func(app *App, path string, handler any) { app.Trace(path, handler) }, method: http.MethodTrace, expectBody: true, }, { name: "Patch", register: func(app *App, path string, handler any) { app.Patch(path, handler) }, method: http.MethodPatch, expectBody: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() app := New() handler := func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Test", r.Method) w.WriteHeader(http.StatusAccepted) if r.Method == http.MethodHead { return } 
_, err := w.Write([]byte("hello from net/http " + r.Method)) assert.NoError(t, err) } tt.register(app, "/foo", http.HandlerFunc(handler)) req := httptest.NewRequest(tt.method, "/foo", http.NoBody) if tt.method == http.MethodConnect { req.URL.Scheme = "http" req.URL.Host = "example.com" } resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, http.StatusAccepted, resp.StatusCode) require.Equal(t, tt.method, resp.Header.Get("X-Test")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) if tt.expectBody { require.Equal(t, "hello from net/http "+tt.method, string(body)) } else { require.Empty(t, body) } }) } } func Test_App_Custom_Middleware_404_Should_Not_SetMethodNotAllowed(t *testing.T) { t.Parallel() app := New() app.Use(func(c Ctx) error { return c.SendStatus(404) }) app.Post("/", testEmptyHandler) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 404, resp.StatusCode) require.Equal(t, MIMETextPlainCharsetUTF8, resp.Header.Get(HeaderContentType)) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "Not Found", string(body)) require.Equal(t, strconv.Itoa(len("Not Found")), resp.Header.Get(HeaderContentLength)) g := app.Group("/with-next", func(c Ctx) error { return c.Status(404).Next() }) g.Post("/", testEmptyHandler) resp, err = app.Test(httptest.NewRequest(MethodGet, "/with-next", http.NoBody)) require.NoError(t, err) require.Equal(t, 404, resp.StatusCode) require.Equal(t, MIMETextPlainCharsetUTF8, resp.Header.Get(HeaderContentType)) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "Not Found", string(body)) require.Equal(t, strconv.Itoa(len("Not Found")), resp.Header.Get(HeaderContentLength)) } func Test_App_ServerErrorHandler_SmallReadBuffer(t *testing.T) { t.Parallel() expectedError := regexp.MustCompile( `error when reading request headers: small read buffer\. Increase ReadBufferSize\. 
Buffer size=4096, contents: "GET / HTTP/1.1\\r\\nHost: example\.com\\r\\nVery-Long-Header: -+`, ) app := New() app.Get("/", func(_ Ctx) error { panic(errors.New("should never called")) }) request := httptest.NewRequest(MethodGet, "/", http.NoBody) logHeaderSlice := make([]string, 5000) request.Header.Set("Very-Long-Header", strings.Join(logHeaderSlice, "-")) _, err := app.Test(request) if err == nil { t.Error("Expect an error at app.Test(request)") } require.Regexp(t, expectedError, err.Error()) } func Test_App_Errors(t *testing.T) { t.Parallel() app := New(Config{ BodyLimit: 4, }) app.Get("/", func(_ Ctx) error { return errors.New("hi, i'm an error") }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 500, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "hi, i'm an error", string(body)) _, err = app.Test(httptest.NewRequest(MethodGet, "/", strings.NewReader("big body"))) if err != nil { require.Equal(t, "body size exceeds the given limit", err.Error(), "app.Test(req)") } } func Test_App_BodyLimit_Negative(t *testing.T) { t.Parallel() limits := []int{-1, -512} for _, limit := range limits { app := New(Config{BodyLimit: limit}) app.Post("/", func(c Ctx) error { return c.SendStatus(StatusOK) }) largeBody := bytes.Repeat([]byte{'a'}, DefaultBodyLimit+1) req := httptest.NewRequest(MethodPost, "/", bytes.NewReader(largeBody)) _, err := app.Test(req) require.ErrorIs(t, err, fasthttp.ErrBodyTooLarge) smallBody := bytes.Repeat([]byte{'a'}, DefaultBodyLimit-1) req = httptest.NewRequest(MethodPost, "/", bytes.NewReader(smallBody)) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) } } func Test_App_BodyLimit_Zero(t *testing.T) { t.Parallel() app := New(Config{BodyLimit: 0}) app.Post("/", func(c Ctx) error { return c.SendStatus(StatusOK) }) largeBody := bytes.Repeat([]byte{'a'}, 
DefaultBodyLimit+1) req := httptest.NewRequest(MethodPost, "/", bytes.NewReader(largeBody)) _, err := app.Test(req) require.ErrorIs(t, err, fasthttp.ErrBodyTooLarge) smallBody := bytes.Repeat([]byte{'a'}, DefaultBodyLimit-1) req = httptest.NewRequest(MethodPost, "/", bytes.NewReader(smallBody)) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) } func Test_App_BodyLimit_LargerThanDefault(t *testing.T) { t.Parallel() limit := DefaultBodyLimit*2 + 1024 // slightly above double the default app := New(Config{BodyLimit: limit}) app.Post("/", func(c Ctx) error { return c.SendStatus(StatusOK) }) // Body larger than the default but within our custom limit should succeed midBody := bytes.Repeat([]byte{'a'}, DefaultBodyLimit+512) req := httptest.NewRequest(MethodPost, "/", bytes.NewReader(midBody)) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) // Body above the custom limit should fail largeBody := bytes.Repeat([]byte{'a'}, limit+1) req = httptest.NewRequest(MethodPost, "/", bytes.NewReader(largeBody)) _, err = app.Test(req) require.ErrorIs(t, err, fasthttp.ErrBodyTooLarge) } type customConstraint struct{} func (*customConstraint) Name() string { return "test" } func (*customConstraint) Execute(param string, args ...string) bool { if param == "test" && len(args) == 1 && args[0] == "test" { return true } if len(args) == 0 && param == "c" { return true } return false } func Test_App_CustomConstraint(t *testing.T) { t.Parallel() app := New() app.RegisterCustomConstraint(&customConstraint{}) app.Get("/test/:param", func(c Ctx) error { return c.SendString("test") }) app.Get("/test2/:param", func(c Ctx) error { return c.SendString("test") }) app.Get("/test3/:param", func(c Ctx) error { return c.SendString("test") }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/test/test", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") 
resp, err = app.Test(httptest.NewRequest(MethodGet, "/test/test2", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(MethodGet, "/test2/c", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(MethodGet, "/test2/cc", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(MethodGet, "/test3/cc", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") } func Test_App_ErrorHandler_Custom(t *testing.T) { t.Parallel() app := New(Config{ ErrorHandler: func(c Ctx, _ error) error { return c.Status(200).SendString("hi, i'm a custom error") }, }) app.Get("/", func(_ Ctx) error { return errors.New("hi, i'm an error") }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "hi, i'm a custom error", string(body)) } func Test_App_ErrorHandler_HandlerStack(t *testing.T) { t.Parallel() app := New(Config{ ErrorHandler: func(c Ctx, err error) error { require.Equal(t, "1: USE error", err.Error()) return DefaultErrorHandler(c, err) }, }) app.Use("/", func(c Ctx) error { err := c.Next() // call next USE require.Equal(t, "2: USE error", err.Error()) return errors.New("1: USE error") }, func(c Ctx) error { err := c.Next() // call [0] GET require.Equal(t, "0: GET error", err.Error()) return errors.New("2: USE error") }) app.Get("/", func(_ Ctx) error { return errors.New("0: GET error") }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 500, resp.StatusCode, "Status code") body, err 
:= io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "1: USE error", string(body)) } func Test_App_ErrorHandler_RouteStack(t *testing.T) { t.Parallel() app := New(Config{ ErrorHandler: func(c Ctx, err error) error { require.Equal(t, "1: USE error", err.Error()) return DefaultErrorHandler(c, err) }, }) app.Use("/", func(c Ctx) error { err := c.Next() require.Equal(t, "0: GET error", err.Error()) return errors.New("1: USE error") // [2] call ErrorHandler }) app.Get("/test", func(_ Ctx) error { return errors.New("0: GET error") // [1] return to USE }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 500, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "1: USE error", string(body)) } func Test_App_serverErrorHandler_Internal_Error(t *testing.T) { t.Parallel() app := New() msg := "test err" c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed app.serverErrorHandler(c.fasthttp, errors.New(msg)) require.Equal(t, string(c.fasthttp.Response.Body()), msg) require.Equal(t, StatusBadRequest, c.fasthttp.Response.StatusCode()) } func Test_App_serverErrorHandler_Network_Error(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed app.serverErrorHandler(c.fasthttp, &net.DNSError{ Err: "test error", Name: "test host", IsTimeout: false, }) require.Equal(t, string(c.fasthttp.Response.Body()), utils.StatusMessage(StatusBadGateway)) require.Equal(t, StatusBadGateway, c.fasthttp.Response.StatusCode()) } func Test_App_serverErrorHandler_Unsupported_Method_Error(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed app.serverErrorHandler(c.fasthttp, errors.New("unsupported http request method 
'FOO'")) require.Equal(t, utils.StatusMessage(StatusNotImplemented), string(c.fasthttp.Response.Body())) require.Equal(t, StatusNotImplemented, c.fasthttp.Response.StatusCode()) } func Test_App_serverErrorHandler_Unsupported_Method_Request(t *testing.T) { t.Parallel() app := New() app.Get("/bar", func(c Ctx) error { return c.SendString("bar") }) ln := fasthttputil.NewInmemoryListener() serverStarted := make(chan struct{}, 1) serverErr := make(chan error, 1) go func() { serverStarted <- struct{}{} if err := app.Listener(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { serverErr <- err return } serverErr <- nil }() <-serverStarted conn, err := ln.Dial() require.NoError(t, err) require.NoError(t, conn.SetDeadline(time.Now().Add(5*time.Second))) _, err = conn.Write([]byte("FOO /bar HTTP/1.1\r\nHost: example.com\r\n\r\n")) require.NoError(t, err) resp, err := http.ReadResponse(bufio.NewReader(conn), nil) require.NoError(t, err) require.Equal(t, StatusNotImplemented, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, utils.StatusMessage(StatusNotImplemented), string(body)) require.NoError(t, resp.Body.Close()) require.NoError(t, conn.Close()) require.NoError(t, app.Shutdown()) require.NoError(t, <-serverErr) } func Test_App_Nested_Params(t *testing.T) { t.Parallel() app := New() app.Get("/test", func(c Ctx) error { return c.Status(400).Send([]byte("Should move on")) }) app.Get("/test/:param", func(c Ctx) error { return c.Status(400).Send([]byte("Should move on")) }) app.Get("/test/:param/test", func(c Ctx) error { return c.Status(400).Send([]byte("Should move on")) }) app.Get("/test/:param/test/:param2", func(c Ctx) error { return c.Status(200).Send([]byte("Good job")) }) req := httptest.NewRequest(MethodGet, "/test/john/test/doe", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") } func Test_App_Use_Params(t *testing.T) { t.Parallel() 
app := New() app.Use("/prefix/:param", func(c Ctx) error { require.Equal(t, "john", c.Params("param")) return nil }) app.Use("/foo/:bar?", func(c Ctx) error { require.Equal(t, "foobar", c.Params("bar", "foobar")) return nil }) app.Use("/:param/*", func(c Ctx) error { require.Equal(t, "john", c.Params("param")) require.Equal(t, "doe", c.Params("*")) return nil }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/prefix/john", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(MethodGet, "/john/doe", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(MethodGet, "/foo", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.PanicsWithValue(t, "use: invalid handler func()\n", func() { app.Use("/:param/*", func() { // this should panic }) }) } func Test_App_Use_UnescapedPath(t *testing.T) { t.Parallel() app := New(Config{UnescapePath: true, CaseSensitive: true}) app.Use("/cRéeR/:param", func(c Ctx) error { require.Equal(t, "/cRéeR/اختبار", c.Path()) return c.SendString(c.Params("param")) }) app.Use("/abc", func(c Ctx) error { require.Equal(t, "/AbC", c.Path()) return nil }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/cR%C3%A9eR/%D8%A7%D8%AE%D8%AA%D8%A8%D8%A7%D8%B1", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") // check the param result require.Equal(t, "اختبار", app.toString(body)) // with lowercase letters resp, err = app.Test(httptest.NewRequest(MethodGet, "/cr%C3%A9er/%D8%A7%D8%AE%D8%AA%D8%A8%D8%A7%D8%B1", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusNotFound, resp.StatusCode, "Status code") } func 
Test_App_Use_CaseSensitive(t *testing.T) { t.Parallel() app := New(Config{CaseSensitive: true}) app.Use("/abc", func(c Ctx) error { return c.SendString(c.Path()) }) // wrong letters in the requested route -> 404 resp, err := app.Test(httptest.NewRequest(MethodGet, "/AbC", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusNotFound, resp.StatusCode, "Status code") // right letters in the requested route -> 200 resp, err = app.Test(httptest.NewRequest(MethodGet, "/abc", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") // check the detected path when the case-insensitive recognition is activated app.config.CaseSensitive = false // check the case-sensitive feature resp, err = app.Test(httptest.NewRequest(MethodGet, "/AbC", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") // check the detected path result require.Equal(t, "/AbC", app.toString(body)) } func Test_App_Not_Use_StrictRouting(t *testing.T) { t.Parallel() app := New() app.Use("/abc", func(c Ctx) error { return c.SendString(c.Path()) }) g := app.Group("/foo") g.Use("/", func(c Ctx) error { return c.SendString(c.Path()) }) // wrong path in the requested route -> 404 resp, err := app.Test(httptest.NewRequest(MethodGet, "/abc/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") // right path in the requested route -> 200 resp, err = app.Test(httptest.NewRequest(MethodGet, "/abc", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") // wrong path with group in the requested route -> 404 resp, err = app.Test(httptest.NewRequest(MethodGet, "/foo", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") // 
right path with group in the requested route -> 200 resp, err = app.Test(httptest.NewRequest(MethodGet, "/foo/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") } func Test_App_Use_MultiplePrefix(t *testing.T) { t.Parallel() app := New() app.Use([]string{"/john", "/doe"}, func(c Ctx) error { return c.SendString(c.Path()) }) g := app.Group("/test") g.Use([]string{"/john", "/doe"}, func(c Ctx) error { return c.SendString(c.Path()) }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/john", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "/john", string(body)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/doe", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "/doe", string(body)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/test/john", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "/test/john", string(body)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/test/doe", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "/test/doe", string(body)) } func Test_Group_Use_NoBoundary(t *testing.T) { t.Parallel() app := New() grp := app.Group("/api") grp.Use("/foo", func(c Ctx) error { return c.SendStatus(StatusOK) }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/foo/bar", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") resp, err = 
app.Test(httptest.NewRequest(MethodGet, "/api/foobar", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusNotFound, resp.StatusCode, "Status code") } func Test_App_Use_StrictRouting(t *testing.T) { t.Parallel() app := New(Config{StrictRouting: true}) app.Get("/abc", func(c Ctx) error { return c.SendString(c.Path()) }) g := app.Group("/foo") g.Get("/", func(c Ctx) error { return c.SendString(c.Path()) }) // wrong path in the requested route -> 404 resp, err := app.Test(httptest.NewRequest(MethodGet, "/abc/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusNotFound, resp.StatusCode, "Status code") // right path in the requested route -> 200 resp, err = app.Test(httptest.NewRequest(MethodGet, "/abc", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") // wrong path with group in the requested route -> 404 resp, err = app.Test(httptest.NewRequest(MethodGet, "/foo", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusNotFound, resp.StatusCode, "Status code") // right path with group in the requested route -> 200 resp, err = app.Test(httptest.NewRequest(MethodGet, "/foo/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") } func Test_App_Add_Method_Test(t *testing.T) { t.Parallel() methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here app := New(Config{ RequestMethods: methods, }) app.Add([]string{"JOHN"}, "/john", testEmptyHandler) resp, err := app.Test(httptest.NewRequest("JOHN", "/john", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(MethodGet, "/john", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusMethodNotAllowed, resp.StatusCode, "Status code") resp, err = 
app.Test(httptest.NewRequest("UNKNOWN", "/john", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusNotImplemented, resp.StatusCode, "Status code") // Add a new method require.Panics(t, func() { app.Add([]string{"JANE"}, "/jane", testEmptyHandler) }) } func Test_App_All_Method_Test(t *testing.T) { t.Parallel() methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here app := New(Config{ RequestMethods: methods, }) // Add a new method with All app.All("/doe", testEmptyHandler) resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") // Add a new method require.Panics(t, func() { app.Add([]string{"JANE"}, "/jane", testEmptyHandler) }) } // go test -run Test_App_GETOnly func Test_App_GETOnly(t *testing.T) { t.Parallel() app := New(Config{ GETOnly: true, }) app.Post("/", func(c Ctx) error { return c.SendString("Hello 👋!") }) req := httptest.NewRequest(MethodPost, "/", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusMethodNotAllowed, resp.StatusCode, "Status code") } func Test_App_Use_Params_Group(t *testing.T) { t.Parallel() app := New() group := app.Group("/prefix/:param/*") group.Use("/", func(c Ctx) error { return c.Next() }) group.Get("/test", func(c Ctx) error { require.Equal(t, "john", c.Params("param")) require.Equal(t, "doe", c.Params("*")) return nil }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/prefix/john/doe/test", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") } func Test_App_Chaining(t *testing.T) { t.Parallel() n := func(c Ctx) error { return c.Next() } app := New() app.Use("/john", n, n, n, n, func(c Ctx) error { return c.SendStatus(202) }) // check handler count for registered HEAD route require.Len(t, app.stack[app.methodInt(MethodHead)][0].Handlers, 5, 
"app.Test(req)") req := httptest.NewRequest(MethodPost, "/john", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 202, resp.StatusCode, "Status code") app.Get("/test", n, n, n, n, func(c Ctx) error { return c.SendStatus(203) }) req = httptest.NewRequest(MethodGet, "/test", http.NoBody) resp, err = app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 203, resp.StatusCode, "Status code") } func Test_App_Order(t *testing.T) { t.Parallel() app := New() app.Get("/test", func(c Ctx) error { _, err := c.WriteString("1") require.NoError(t, err) return c.Next() }) app.All("/test", func(c Ctx) error { _, err := c.WriteString("2") require.NoError(t, err) return c.Next() }) app.Use(func(c Ctx) error { _, err := c.WriteString("3") require.NoError(t, err) return c.SendStatus(StatusOK) }) req := httptest.NewRequest(MethodGet, "/test", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "123", string(body)) } func Test_App_AutoHead_Compliance(t *testing.T) { t.Parallel() app := New() app.Get("/hello", func(c Ctx) error { c.Set("X-Test", "string") return c.SendString("hello") }) app.startupProcess() getReq := httptest.NewRequest(MethodGet, "/hello", http.NoBody) getResp, err := app.Test(getReq) require.NoError(t, err, "app.Test(get)") defer func() { require.NoError(t, getResp.Body.Close()) }() body, err := io.ReadAll(getResp.Body) require.NoError(t, err) require.Equal(t, "hello", string(body)) require.Equal(t, "string", getResp.Header.Get("X-Test")) headReq := httptest.NewRequest(MethodHead, "/hello", http.NoBody) headResp, err := app.Test(headReq) require.NoError(t, err, "app.Test(head)") defer func() { require.NoError(t, headResp.Body.Close()) }() require.Equal(t, getResp.StatusCode, headResp.StatusCode) require.Equal(t, strconv.Itoa(len(body)), 
headResp.Header.Get(HeaderContentLength))
	require.Equal(t, getResp.Header.Get(HeaderContentType), headResp.Header.Get(HeaderContentType))
	require.Equal(t, getResp.Header.Get("X-Test"), headResp.Header.Get("X-Test"))

	headBody, err := io.ReadAll(headResp.Body)
	require.NoError(t, err)
	require.Empty(t, headBody)
}

// Test_App_AutoHead_Compliance_SendFile is the SendFile variant of the
// auto-HEAD check: HEAD must report the file's Content-Length and the GET
// headers without transferring the file body.
func Test_App_AutoHead_Compliance_SendFile(t *testing.T) {
	t.Parallel()

	if runtime.GOOS == "windows" {
		t.Skip("SendFile auto-HEAD test is skipped on Windows due to file locking semantics")
	}

	tmpDir := t.TempDir()
	filePath := filepath.Join(tmpDir, "hello.txt")
	fileContent := []byte("file-body")
	require.NoError(t, os.WriteFile(filePath, fileContent, 0o644)) //nolint:gosec // permissions match test fixtures

	app := New()
	app.Get("/file", func(c Ctx) error {
		c.Set("X-Test", "file")
		return c.SendFile(filePath)
	})

	app.startupProcess()

	getReq := httptest.NewRequest(MethodGet, "/file", http.NoBody)
	getResp, err := app.Test(getReq)
	require.NoError(t, err, "app.Test(get)")
	defer func() { require.NoError(t, getResp.Body.Close()) }()

	body, err := io.ReadAll(getResp.Body)
	require.NoError(t, err)
	require.Equal(t, fileContent, body)
	require.Equal(t, "file", getResp.Header.Get("X-Test"))

	headReq := httptest.NewRequest(MethodHead, "/file", http.NoBody)
	headResp, err := app.Test(headReq)
	require.NoError(t, err, "app.Test(head)")
	defer func() { require.NoError(t, headResp.Body.Close()) }()

	require.Equal(t, getResp.StatusCode, headResp.StatusCode)
	require.Equal(t, strconv.Itoa(len(fileContent)), headResp.Header.Get(HeaderContentLength))
	require.Equal(t, getResp.Header.Get(HeaderContentType), headResp.Header.Get(HeaderContentType))
	require.Equal(t, getResp.Header.Get("X-Test"), headResp.Header.Get("X-Test"))

	headBody, err := io.ReadAll(headResp.Body)
	require.NoError(t, err)
	require.Empty(t, headBody)
}

// Test_App_Methods registers an optional-parameter route through each
// method-specific helper (plus All and Use) and checks it answers 200.
func Test_App_Methods(t *testing.T) {
	t.Parallel()
	dummyHandler := testEmptyHandler

	app := New()

	// Use the MethodConnect constant like every other call below instead of
	// the raw "CONNECT" literal.
	app.Connect("/:john?/:doe?", dummyHandler)
	testStatus200(t, app, "/john/doe", MethodConnect)

	app.Put("/:john?/:doe?", dummyHandler)
	testStatus200(t, app, "/john/doe", MethodPut)

	app.Post("/:john?/:doe?", dummyHandler)
	testStatus200(t, app, "/john/doe", MethodPost)

	app.Delete("/:john?/:doe?", dummyHandler)
	testStatus200(t, app, "/john/doe", MethodDelete)

	app.Head("/:john?/:doe?", dummyHandler)
	testStatus200(t, app, "/john/doe", MethodHead)

	app.Patch("/:john?/:doe?", dummyHandler)
	testStatus200(t, app, "/john/doe", MethodPatch)

	app.Options("/:john?/:doe?", dummyHandler)
	testStatus200(t, app, "/john/doe", MethodOptions)

	app.Trace("/:john?/:doe?", dummyHandler)
	testStatus200(t, app, "/john/doe", MethodTrace)

	app.Get("/:john?/:doe?", dummyHandler)
	testStatus200(t, app, "/john/doe", MethodGet)

	app.All("/:john?/:doe?", dummyHandler)
	testStatus200(t, app, "/john/doe", MethodPost)

	app.Use("/:john?/:doe?", dummyHandler)
	testStatus200(t, app, "/john/doe", MethodGet)
}

// Test_App_Route_Naming checks that names set on routes, groups and nested
// groups combine into the names that GetRoute resolves.
func Test_App_Route_Naming(t *testing.T) {
	t.Parallel()

	app := New()
	handler := func(c Ctx) error {
		return c.SendStatus(StatusOK)
	}
	app.Get("/john", handler).Name("john")
	app.Delete("/doe", handler)
	app.Name("doe")

	jane := app.Group("/jane").Name("jane.")
	group := app.Group("/group")
	subGroup := jane.Group("/sub-group").Name("sub.")

	jane.Get("/test", handler).Name("test")
	jane.Trace("/trace", handler).Name("trace")

	group.Get("/test", handler).Name("test")

	app.Post("/post", handler).Name("post")

	subGroup.Get("/done", handler).Name("done")

	require.Equal(t, "post", app.GetRoute("post").Name)
	require.Equal(t, "john", app.GetRoute("john").Name)
	require.Equal(t, "jane.test", app.GetRoute("jane.test").Name)
	require.Equal(t, "jane.trace", app.GetRoute("jane.trace").Name)
	require.Equal(t, "jane.sub.done", app.GetRoute("jane.sub.done").Name)
	require.Equal(t, "test", app.GetRoute("test").Name)
}

// Test_App_New checks that New accepts both the default and a custom Config
// and that the custom Config is actually applied (this test previously made
// no assertions).
func Test_App_New(t *testing.T) {
	t.Parallel()

	app := New()
	app.Get("/", testEmptyHandler)
	require.False(t, app.Config().Immutable)

	appConfig := New(Config{
		Immutable: true,
	})
	appConfig.Get("/", testEmptyHandler)
	require.True(t, appConfig.Config().Immutable)
}

// Test_App_Config checks that App.Config exposes the configuration passed to New.
func Test_App_Config(t *testing.T) {
t.Parallel()
	app := New(Config{
		StrictRouting: true,
	})
	require.True(t, app.Config().StrictRouting)
}

// Test_App_GetString checks when App.GetString returns its input unchanged and
// when it returns a copy, by comparing the strings' backing-data pointers.
func Test_App_GetString(t *testing.T) {
	t.Parallel()

	heap := string([]byte("fiber"))

	appMutable := New()
	same := appMutable.GetString(heap)
	if unsafe.StringData(same) != unsafe.StringData(heap) { //nolint:gosec // compare pointer addresses
		t.Error("expected original string when immutable is disabled")
	}

	appImmutable := New(Config{Immutable: true})
	copied := appImmutable.GetString(heap)
	if unsafe.StringData(copied) == unsafe.StringData(heap) { //nolint:gosec // compare pointer addresses
		t.Error("expected a copy for heap-backed string when immutable is enabled")
	}

	literal := "fiber"
	sameLit := appImmutable.GetString(literal)
	if unsafe.StringData(sameLit) != unsafe.StringData(literal) { //nolint:gosec // compare pointer addresses
		t.Error("expected original literal when immutable is enabled")
	}
}

// Test_App_GetBytes checks when App.GetBytes returns its input unchanged and
// when it returns a copy, by comparing the slices' backing-array pointers.
func Test_App_GetBytes(t *testing.T) {
	t.Parallel()

	b := []byte("fiber")
	appMutable := New()
	same := appMutable.GetBytes(b)
	if unsafe.SliceData(same) != unsafe.SliceData(b) { //nolint:gosec // compare pointer addresses
		t.Error("expected original slice when immutable is disabled")
	}

	alias := make([]byte, 10)
	copy(alias, b)
	sub := alias[:5]

	appImmutable := New(Config{Immutable: true})
	copied := appImmutable.GetBytes(sub)
	if unsafe.SliceData(copied) == unsafe.SliceData(sub) { //nolint:gosec // compare pointer addresses
		t.Error("expected a copy for aliased slice when immutable is enabled")
	}

	full := make([]byte, 5)
	copy(full, b)
	detached := appImmutable.GetBytes(full)
	if unsafe.SliceData(detached) == unsafe.SliceData(full) { //nolint:gosec // compare pointer addresses
		t.Error("expected a copy even when cap==len")
	}
}

// Test_App_Shutdown checks Shutdown on an app created by New and on a bare
// App value whose shutdown reports that the server is not running.
func Test_App_Shutdown(t *testing.T) {
	t.Parallel()

	t.Run("success", func(t *testing.T) {
		t.Parallel()
		app := New()
		require.NoError(t, app.Shutdown())
	})

	t.Run("no server", func(t *testing.T) {
		t.Parallel()
		app := &App{}
		if err := app.Shutdown(); err != nil {
			require.ErrorContains(t, err, "shutdown: server is not running")
		}
	})
}

// Test_App_ShutdownWithTimeout checks that ShutdownWithTimeout gives up with
// context.DeadlineExceeded when an in-flight request (a 5s handler) outlives
// the 1s shutdown timeout.
func Test_App_ShutdownWithTimeout(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/", func(c Ctx) error {
		time.Sleep(5 * time.Second)
		return c.SendString("body")
	})

	ln := fasthttputil.NewInmemoryListener()

	serverReady := make(chan struct{}) // Signal that the server is ready to start

	go func() {
		serverReady <- struct{}{}
		err := app.Listener(ln)
		assert.NoError(t, err)
	}()

	<-serverReady // Waiting for the server to be ready

	// Create a connection and send a request
	connReady := make(chan struct{})
	go func() {
		// Signal from a defer so the main goroutine is released on every path;
		// a failed dial previously went on to call Write on a nil conn.
		defer func() { connReady <- struct{}{} }() // Signal that the request has been sent

		conn, err := ln.Dial()
		if !assert.NoError(t, err) {
			return
		}

		_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n"))
		assert.NoError(t, err)
	}()

	<-connReady // Waiting for the request to be sent

	// Buffered so the shutdown goroutine can still finish if the timer fires first.
	shutdownErr := make(chan error, 1)
	go func() {
		shutdownErr <- app.ShutdownWithTimeout(1 * time.Second)
	}()

	timer := time.NewTimer(time.Second * 5)
	defer timer.Stop()

	select {
	case <-timer.C:
		t.Fatal("idle connections not closed on shutdown")
	case err := <-shutdownErr:
		// errors.Is(nil, target) is false, so a nil error is rejected here as well.
		if !errors.Is(err, context.DeadlineExceeded) {
			t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
		}
	}
}

// Test_App_ShutdownWithContext covers ShutdownWithContext: a clean shutdown,
// pre/post shutdown hook ordering, and a deadline hit by a long-running request.
func Test_App_ShutdownWithContext(t *testing.T) {
	t.Parallel()

	t.Run("successful shutdown", func(t *testing.T) {
		t.Parallel()
		app := New()

		// Fast request that should complete
		app.Get("/", func(c Ctx) error {
			return c.SendString("OK")
		})

		ln := fasthttputil.NewInmemoryListener()

		serverStarted := make(chan bool, 1)
		go func() {
			serverStarted <- true
			if err := app.Listener(ln); err != nil {
				t.Errorf("Failed to start listener: %v", err)
			}
		}()

		<-serverStarted

		// Execute normal request
		conn, err := ln.Dial()
		require.NoError(t, err)
		_, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
		require.NoError(t, err)

		// Shutdown with sufficient timeout
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		err = app.ShutdownWithContext(ctx)
		require.NoError(t, err, "Expected successful shutdown")
	})

	t.Run("shutdown with hooks", func(t *testing.T) {
		t.Parallel()
		app := New()

		hookOrder := make([]string, 0)
		var hookMutex sync.Mutex

		app.Hooks().OnPreShutdown(func() error {
			hookMutex.Lock()
			hookOrder = append(hookOrder, "pre")
			hookMutex.Unlock()
			return nil
		})

		app.Hooks().OnPostShutdown(func(_ error) error {
			hookMutex.Lock()
			hookOrder = append(hookOrder, "post")
			hookMutex.Unlock()
			return nil
		})

		ln := fasthttputil.NewInmemoryListener()
		go func() {
			if err := app.Listener(ln); err != nil {
				t.Errorf("Failed to start listener: %v", err)
			}
		}()

		time.Sleep(100 * time.Millisecond)
		err := app.ShutdownWithContext(context.Background())
		require.NoError(t, err)

		require.Equal(t, []string{"pre", "post"}, hookOrder, "Hooks should execute in order")
	})

	t.Run("timeout with long running request", func(t *testing.T) {
		t.Parallel()
		app := New()

		requestStarted := make(chan struct{})
		requestProcessing := make(chan struct{})

		app.Get("/", func(c Ctx) error {
			close(requestStarted)
			// Wait for signal to continue processing the request
			<-requestProcessing
			time.Sleep(2 * time.Second)
			return c.SendString("OK")
		})
ln := fasthttputil.NewInmemoryListener()
		go func() {
			if err := app.Listener(ln); err != nil {
				t.Errorf("Failed to start listener: %v", err)
			}
		}()

		// Ensure server is fully started
		time.Sleep(100 * time.Millisecond)

		// Start a long-running request
		go func() {
			conn, err := ln.Dial()
			if err != nil {
				t.Errorf("Failed to dial: %v", err)
				return
			}

			if _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")); err != nil {
				t.Errorf("Failed to write: %v", err)
			}
		}()

		// Wait for request to start
		select {
		case <-requestStarted:
			// Request has started, signal to continue processing
			close(requestProcessing)
		case <-time.After(2 * time.Second):
			t.Fatal("Request did not start in time")
		}

		// Attempt shutdown, should timeout: the handler still sleeps for 2s
		// after requestProcessing is closed, longer than the 500ms deadline.
		ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
		defer cancel()

		err := app.ShutdownWithContext(ctx)
		require.ErrorIs(t, err, context.DeadlineExceeded)
	})
}

// Test_App_OptionsAsterisk checks that an "OPTIONS *" request is answered by
// the "*" OPTIONS route, while "OPTIONS /resource" still reaches its own route.
func Test_App_OptionsAsterisk(t *testing.T) {
	t.Parallel()

	app := New()
	app.Options("/resource", func(c Ctx) error {
		c.Set(HeaderAllow, "GET")
		c.Status(StatusNoContent)
		return nil
	})
	app.Options("*", func(c Ctx) error {
		c.Set(HeaderAllow, "GET, POST")
		c.Status(StatusOK)
		return nil
	})

	ln := fasthttputil.NewInmemoryListener()
	errCh := make(chan error, 1)
	serverReady := make(chan struct{})

	go func() {
		serverReady <- struct{}{}
		errCh <- app.Listener(ln)
	}()

	<-serverReady

	// Stop the server after the test and surface any error from app.Listener.
	t.Cleanup(func() {
		require.NoError(t, app.Shutdown())
		require.NoError(t, <-errCh)
	})

	// writeRequest writes a raw HTTP request so the literal "*" request
	// target is sent exactly as written.
	writeRequest := func(conn net.Conn, raw string) {
		t.Helper()
		_, err := conn.Write([]byte(raw))
		require.NoError(t, err)
	}

	conn, err := ln.Dial()
	require.NoError(t, err)
	writeRequest(conn, "OPTIONS * HTTP/1.1\r\nHost: example.com\r\n\r\n")

	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: http.MethodOptions})
	require.NoError(t, err)
	require.Equal(t, StatusOK, resp.StatusCode)
	require.Equal(t, "GET, POST", resp.Header.Get(HeaderAllow))
	require.Zero(t, resp.ContentLength)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Empty(t, body)
	require.NoError(t, resp.Body.Close())
	require.NoError(t, conn.Close())

	// Control request: a concrete path must still hit its own OPTIONS route.
	controlConn, err := ln.Dial()
	require.NoError(t, err)
	writeRequest(controlConn, "OPTIONS /resource HTTP/1.1\r\nHost: example.com\r\n\r\n")

	controlResp, err := http.ReadResponse(bufio.NewReader(controlConn), &http.Request{Method: http.MethodOptions})
	require.NoError(t, err)
	require.Equal(t, StatusNoContent, controlResp.StatusCode)
	require.Equal(t, "GET", controlResp.Header.Get(HeaderAllow))
	require.Zero(t, controlResp.ContentLength)
	controlBody, err := io.ReadAll(controlResp.Body)
	require.NoError(t, err)
	require.Empty(t, controlBody)
	require.NoError(t, controlResp.Body.Close())
	require.NoError(t, controlConn.Close())
}

// go test -run Test_App_Mixed_Routes_WithSameLen
// Test_App_Mixed_Routes_WithSameLen checks that two GET routes whose paths have
// the same length are told apart, and that middleware headers reach both.
func Test_App_Mixed_Routes_WithSameLen(t *testing.T) {
	t.Parallel()

	app := New()

	// middleware
	app.Use(func(c Ctx) error {
		c.Set("TestHeader", "TestValue")
		return c.Next()
	})
	// routes with the same length
	app.Get("/tesbar", func(c Ctx) error {
		c.Type("html")
		return c.Send([]byte("TEST_BAR"))
	})
	app.Get("/foobar", func(c Ctx) error {
		c.Type("html")
		return c.Send([]byte("FOO_BAR"))
	})

	// match get route
	req := httptest.NewRequest(MethodGet, "/foobar", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.NotEmpty(t, resp.Header.Get(HeaderContentLength))
	require.Equal(t, "TestValue", resp.Header.Get("TestHeader"))
	require.Equal(t, "text/html; charset=utf-8", resp.Header.Get(HeaderContentType))

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "FOO_BAR", string(body))

	// match static route
	req = httptest.NewRequest(MethodGet, "/tesbar", http.NoBody)
	resp, err = app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.NotEmpty(t, resp.Header.Get(HeaderContentLength))
	require.Equal(t, "TestValue", resp.Header.Get("TestHeader"))
	require.Equal(t, "text/html; charset=utf-8", resp.Header.Get(HeaderContentType))

	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Contains(t, string(body), "TEST_BAR")
}

// Test_App_Group_Invalid checks that passing a non-handler value to Group.Use
// panics with a message naming the offending type.
func Test_App_Group_Invalid(t *testing.T) {
	t.Parallel()

	require.PanicsWithValue(t, "use: invalid handler int\n", func() {
		New().Group("/").Use(1)
	})
}

// Test_App_Group registers routes through every Group helper (including
// nested groups) and checks each resolves under the group prefix.
func Test_App_Group(t *testing.T) {
	t.Parallel()
	dummyHandler := testEmptyHandler

	app := New()

	grp := app.Group("/test")
	grp.Get("/", dummyHandler)
	testStatus200(t, app, "/test", MethodGet)

	grp.Get("/:demo?", dummyHandler)
	testStatus200(t, app, "/test/john", MethodGet)

	grp.Connect("/CONNECT", dummyHandler)
	testStatus200(t, app, "/test/CONNECT", MethodConnect)

	grp.Put("/PUT", dummyHandler)
	testStatus200(t, app, "/test/PUT", MethodPut)

	grp.Post("/POST", dummyHandler)
	testStatus200(t, app, "/test/POST", MethodPost)

	grp.Delete("/DELETE", dummyHandler)
	testStatus200(t, app, "/test/DELETE", MethodDelete)

	grp.Head("/HEAD", dummyHandler)
	testStatus200(t, app, "/test/HEAD", MethodHead)

	grp.Patch("/PATCH", dummyHandler)
	testStatus200(t, app, "/test/PATCH", MethodPatch)

	grp.Options("/OPTIONS", dummyHandler)
	testStatus200(t, app, "/test/OPTIONS", MethodOptions)

	grp.Trace("/TRACE", dummyHandler)
	testStatus200(t, app, "/test/TRACE", MethodTrace)

	grp.All("/ALL", dummyHandler)
	testStatus200(t, app, "/test/ALL", MethodPost)

	grp.Use(dummyHandler)
	testStatus200(t, app, "/test/oke", MethodGet)

	grp.Use("/USE", dummyHandler)
	testStatus200(t, app, "/test/USE/oke", MethodGet)

	api := grp.Group("/v1")
	api.Post("/", dummyHandler)

	resp, err := app.Test(httptest.NewRequest(MethodPost, "/test/v1/", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	api.Get("/users", dummyHandler)
	// Mixed-case request path against "/users": expected to match with the default config.
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/test/v1/UsErS", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
}

// Test_App_RouteChain checks that chained method registrations on
// app.RouteChain, including nested RouteChain prefixes, all answer 200.
func
Test_App_RouteChain(t *testing.T) {
	t.Parallel()
	dummyHandler := testEmptyHandler

	app := New()

	register := app.RouteChain("/test").
		Get(dummyHandler).
		Head(dummyHandler).
		Post(dummyHandler).
		Put(dummyHandler).
		Delete(dummyHandler).
		Connect(dummyHandler).
		Options(dummyHandler).
		Trace(dummyHandler).
		Patch(dummyHandler)

	testStatus200(t, app, "/test", MethodGet)
	testStatus200(t, app, "/test", MethodHead)
	testStatus200(t, app, "/test", MethodPost)
	testStatus200(t, app, "/test", MethodPut)
	testStatus200(t, app, "/test", MethodDelete)
	testStatus200(t, app, "/test", MethodConnect)
	testStatus200(t, app, "/test", MethodOptions)
	testStatus200(t, app, "/test", MethodTrace)
	testStatus200(t, app, "/test", MethodPatch)

	register.RouteChain("/v1").Get(dummyHandler).Post(dummyHandler)

	resp, err := app.Test(httptest.NewRequest(MethodPost, "/test/v1", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	resp, err = app.Test(httptest.NewRequest(MethodGet, "/test/v1", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	register.RouteChain("/v1").RouteChain("/v2").RouteChain("/v3").Get(dummyHandler).Trace(dummyHandler)

	resp, err = app.Test(httptest.NewRequest(MethodTrace, "/test/v1/v2/v3", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	resp, err = app.Test(httptest.NewRequest(MethodGet, "/test/v1/v2/v3", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
}

// Test_App_Route checks that App.Route mounts nested routers under the given
// prefixes and joins their name prefixes ("test." + "bar." + "index").
func Test_App_Route(t *testing.T) {
	t.Parallel()
	app := New()

	app.Route("/test", func(api Router) {
		api.Get("/foo", testEmptyHandler).Name("foo")

		api.Route("/bar", func(bar Router) {
			bar.Get("/", testEmptyHandler).Name("index")
		}, "bar.")
	}, "test.")

	testStatus200(t, app, "/test/foo", MethodGet)

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test/bar/", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, http.StatusOK, resp.StatusCode, "Status code")

	require.Equal(t, "/test/foo", app.GetRoute("test.foo").Path)
	require.Equal(t, "/test/bar/", app.GetRoute("test.bar.index").Path)
}

// Test_App_Route_nilFuncPanics checks that App.Route panics on a nil callback.
func Test_App_Route_nilFuncPanics(t *testing.T) {
	t.Parallel()

	app := New()

	require.PanicsWithValue(t, "route handler 'fn' cannot be nil", func() {
		app.Route("/panic", nil)
	})
}

// Test_Group_Route_nilFuncPanics checks that Group.Route panics on a nil callback.
func Test_Group_Route_nilFuncPanics(t *testing.T) {
	t.Parallel()

	app := New()
	grp := app.Group("/api")

	require.PanicsWithValue(t, "route handler 'fn' cannot be nil", func() {
		grp.Route("/panic", nil)
	})
}

// Test_Group_RouteChain_All checks that a RouteChain(...).All handler on a group
// runs after the group's own middleware.
func Test_Group_RouteChain_All(t *testing.T) {
	t.Parallel()

	app := New()
	var calls []string

	grp := app.Group("/api", func(c Ctx) error {
		calls = append(calls, "group")
		return c.Next()
	})

	grp.RouteChain("/users").All(func(c Ctx) error {
		calls = append(calls, "routechain")
		return c.SendStatus(http.StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/users", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, http.StatusOK, resp.StatusCode, "Status code")
	require.Equal(t, []string{"group", "routechain"}, calls)
}

// Test_App_Deep_Group checks that the middleware of every nesting level of
// groups runs once before the final route handler (4 invocations in total).
func Test_App_Deep_Group(t *testing.T) {
	t.Parallel()
	runThroughCount := 0
	dummyHandler := func(c Ctx) error {
		runThroughCount++
		return c.Next()
	}

	app := New()
	gAPI := app.Group("/api", dummyHandler)
	gV1 := gAPI.Group("/v1", dummyHandler)
	gUser := gV1.Group("/user", dummyHandler)
	gUser.Get("/authenticate", func(c Ctx) error {
		runThroughCount++
		return c.SendStatus(200)
	})
	testStatus200(t, app, "/api/v1/user/authenticate", MethodGet)
	require.Equal(t, 4, runThroughCount, "Loop count")
}

// go test -run Test_App_Next_Method
// Test_App_Next_Method checks that c.Method() is the same before and after
// c.Next(); with no matching route the response is 404.
func Test_App_Next_Method(t *testing.T) {
	t.Parallel()
	app := New()

	app.Use(func(c Ctx) error {
		require.Equal(t, MethodGet, c.Method())
		err := c.Next()
		require.Equal(t, MethodGet, c.Method())
		return err
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 404, resp.StatusCode, "Status code")
}

// go test -v -run=^$ -bench=Benchmark_NewError -benchmem -count=4
func Benchmark_NewError(b *testing.B) {
	for b.Loop() {
		NewError(200, "test") //nolint:errcheck // not needed
	}
}

// Benchmark_NewError_Parallel measures NewError when called from parallel goroutines.
func Benchmark_NewError_Parallel(b *testing.B) {
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			NewError(200, "test") //nolint:errcheck // not needed
		}
	})
}

// go test -run Test_NewError
func Test_NewError(t *testing.T) {
	t.Parallel()
	e := NewError(StatusForbidden, "permission denied")
	require.Equal(t, StatusForbidden, e.Code)
	require.Equal(t, "permission denied", e.Message)
}

// go test -run Test_NewErrorf_Format
// (the hint previously named Test_NewError_Format, which does not match this test)
// Test_NewErrorf_Format pins how NewErrorf turns its variadic arguments into
// Error.Message: defaults, single values, format strings and extra args.
func Test_NewErrorf_Format(t *testing.T) {
	t.Parallel()

	type args []any

	tests := []struct {
		name string
		want string
		in   args
		code int
	}{
		{
			name: "no-args → default text",
			code: StatusNotFound,
			in:   nil,
			want: utils.StatusMessage(StatusNotFound),
		},
		{
			name: "single-string arg overrides",
			code: StatusBadRequest,
			in:   args{"custom bad request"},
			want: "custom bad request",
		},
		{
			name: "single non-string arg stringified",
			code: StatusInternalServerError,
			in:   args{errors.New("db down")},
			want: "db down",
		},
		{
			name: "single nil interface",
			code: StatusInternalServerError,
			in:   args{any(nil)},
			want: "",
		},
		{
			name: "format string + args",
			code: StatusBadRequest,
			in:   args{"invalid id %d", 10},
			want: "invalid id 10",
		},
		{
			name: "format string + excess args",
			code: StatusBadRequest,
			in:   args{"odd %d", 1, 2, 3},
			want: "odd 1%!(EXTRA int=2, int=3)",
		},
		{
			name: "≥2 args but first not string",
			code: StatusBadRequest,
			in:   args{errors.New("boom"), 42},
			want: "boom",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			e := NewErrorf(tt.code, tt.in...)
			require.Equal(t, tt.code, e.Code)
			require.Equal(t, tt.want, e.Message)
		})
	}
}

// go test -run Test_Test_Timeout
// Test_Test_Timeout checks that app.Test with Timeout 0 waits for the handler,
// and that FailOnTimeout makes a slow handler produce an error.
func Test_Test_Timeout(t *testing.T) {
	t.Parallel()

	app := New()

	app.Get("/", testEmptyHandler)

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody), TestConfig{
		Timeout: 0,
	})
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	// Leading slash added for consistency with the request path and every other route.
	app.Get("/timeout", func(_ Ctx) error {
		time.Sleep(200 * time.Millisecond)
		return nil
	})

	_, err = app.Test(httptest.NewRequest(MethodGet, "/timeout", http.NoBody), TestConfig{
		Timeout:       20 * time.Millisecond,
		FailOnTimeout: true,
	})
	require.Error(t, err, "app.Test(req)")
}

// errorReader is a request body whose Read always fails with errErrorReader.
type errorReader int

var errErrorReader = errors.New("errorReader")

func (errorReader) Read([]byte) (int, error) {
	return 0, errErrorReader
}

// go test -run Test_Test_DumpError
// Test_Test_DumpError checks that app.Test returns the request body's read
// error and a nil response.
func Test_Test_DumpError(t *testing.T) {
	t.Parallel()

	app := New()

	app.Get("/", testEmptyHandler)

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/", errorReader(0)))
	require.Nil(t, resp)
	require.ErrorIs(t, err, errErrorReader)
}

// go test -run Test_App_Handler
func Test_App_Handler(t *testing.T) {
	t.Parallel()
	h := New().Handler()
	require.Equal(t, "fasthttp.RequestHandler", reflect.TypeOf(h).String())
}

// invalidView is a view engine whose Load fails and whose Render panics.
type invalidView struct{}

func (invalidView) Load() error { return errors.New("invalid view") }

func (invalidView) Render(io.Writer, string, any, ...string) error { panic("implement me") }

// countingView is a view engine that counts Load calls and returns loadErr
// from each of them; Render is a no-op.
type countingView struct {
	loadErr error
	loads   int
}

func (v *countingView) Load() error {
	v.loads++
	return v.loadErr
}

func (*countingView) Render(io.Writer, string, any, ...string) error {
	return nil
}

// Test_App_ReloadViews_Success checks that each ReloadViews call invokes the
// view engine's Load exactly once.
func Test_App_ReloadViews_Success(t *testing.T) {
	t.Parallel()
	view := &countingView{}
	app := New(Config{Views: view})
	initialLoads := view.loads

	require.NoError(t, app.ReloadViews())
	require.Equal(t, initialLoads+1, view.loads)

	require.NoError(t, app.ReloadViews())
	require.Equal(t, initialLoads+2, view.loads)
}

// Test_App_ReloadViews_Error checks that an error returned by the view
// engine's Load is propagated by App.ReloadViews.
func
Test_App_ReloadViews_Error(t *testing.T) {
	t.Parallel()
	wantedErr := errors.New("boom")
	view := &countingView{loadErr: wantedErr}
	app := New(Config{Views: view})

	err := app.ReloadViews()
	require.Error(t, err)
	require.ErrorIs(t, err, wantedErr)
}

// Test_App_ReloadViews_NoEngine checks that ReloadViews reports
// ErrNoViewEngineConfigured when Config.Views is not set.
func Test_App_ReloadViews_NoEngine(t *testing.T) {
	t.Parallel()
	app := New()

	err := app.ReloadViews()
	require.ErrorIs(t, err, ErrNoViewEngineConfigured)
}

// Test_App_ReloadViews_InterfaceNilPointer checks that a typed nil pointer in
// Config.Views is treated the same as no view engine.
func Test_App_ReloadViews_InterfaceNilPointer(t *testing.T) {
	t.Parallel()
	var view *countingView
	// Config.Views is an interface holding a nil *countingView, which is a
	// non-nil interface value.
	app := &App{config: Config{Views: view}}

	err := app.ReloadViews()
	require.ErrorIs(t, err, ErrNoViewEngineConfigured)
}

// Test_App_ReloadViews_MountedViews checks that ReloadViews on the parent app
// also reloads the view engine of a mounted sub-app, picking up the template
// content rewritten on disk.
func Test_App_ReloadViews_MountedViews(t *testing.T) {
	t.Parallel()

	tempDir := t.TempDir()
	templatePath := filepath.Join(tempDir, "template.html")
	require.NoError(t, os.WriteFile(templatePath, []byte("before"), 0o600))

	// fileView is declared elsewhere in this package (not in this chunk).
	view := &fileView{path: templatePath}
	subApp := New(Config{Views: view})

	app := New()
	app.Use("/sub", subApp)

	require.NoError(t, view.Load())
	initialLoads := view.loads
	require.Equal(t, "before", view.content)

	require.NoError(t, os.WriteFile(templatePath, []byte("after"), 0o600))
	require.NoError(t, app.ReloadViews())

	require.Equal(t, "after", view.content)
	require.Greater(t, view.loads, initialLoads)
}

// Test_App_ReloadViews_MountedViews_Error checks that a Load error from a
// mounted sub-app's view engine is returned by the parent's ReloadViews.
func Test_App_ReloadViews_MountedViews_Error(t *testing.T) {
	t.Parallel()

	expectedErr := errors.New("sub view error")
	subView := &countingView{loadErr: expectedErr}
	subApp := New(Config{Views: subView})

	app := New()
	app.Use("/sub", subApp)

	err := app.ReloadViews()
	require.ErrorIs(t, err, expectedErr)
}

// Test_App_ReloadViews_MountedViews_MultipleApps checks that every mounted
// sub-app's view engine is reloaded exactly once.
func Test_App_ReloadViews_MountedViews_MultipleApps(t *testing.T) {
	t.Parallel()

	viewA := &countingView{}
	viewB := &countingView{}
	subAppA := New(Config{Views: viewA})
	subAppB := New(Config{Views: viewB})

	app := New()
	app.Use("/a", subAppA)
	app.Use("/b", subAppB)

	initialLoadsA := viewA.loads
	initialLoadsB := viewB.loads

	require.NoError(t, app.ReloadViews())
	require.Equal(t, initialLoadsA+1, viewA.loads)
	require.Equal(t, initialLoadsB+1, viewB.loads)
}

// Test_App_ReloadViews_MountedViews_WithParentViews checks that the parent's
// own view engine and the mounted sub-app's engine are each reloaded once.
func Test_App_ReloadViews_MountedViews_WithParentViews(t *testing.T) {
	t.Parallel()

	parentView := &countingView{}
	subView := &countingView{}
	subApp := New(Config{Views: subView})

	app := New(Config{Views: parentView})
	app.Use("/sub", subApp)

	initialParentLoads := parentView.loads
	initialSubLoads := subView.loads

	require.NoError(t, app.ReloadViews())
	require.Equal(t, initialParentLoads+1, parentView.loads)
	require.Equal(t, initialSubLoads+1, subView.loads)
}

// go test -run Test_App_Init_Error_View
// Test_App_Init_Error_View checks that an app can be built with invalidView
// and that calling its Render panics with "implement me".
func Test_App_Init_Error_View(t *testing.T) {
	t.Parallel()
	app := New(Config{Views: invalidView{}})

	require.PanicsWithValue(t, "implement me", func() {
		//nolint:errcheck // not needed
		_ = app.config.Views.Render(nil, "", nil)
	})
}

// go test -run Test_App_Stack
// Test_App_Stack checks the number of routes per method in App.Stack: the
// Use("/path0") route is counted under every method, GET and HEAD also
// include /path1 and /path2, and POST also includes /path3.
func Test_App_Stack(t *testing.T) {
	t.Parallel()
	app := New()

	app.Use("/path0", testEmptyHandler)
	app.Get("/path1", testEmptyHandler)
	app.Get("/path2", testEmptyHandler)
	app.Post("/path3", testEmptyHandler)

	app.startupProcess()

	stack := app.Stack()
	methodList := app.config.RequestMethods
	require.Len(t, methodList, len(stack))
	require.Len(t, stack[app.methodInt(MethodGet)], 3)
	require.Len(t, stack[app.methodInt(MethodHead)], 3)
	require.Len(t, stack[app.methodInt(MethodPost)], 2)
	require.Len(t, stack[app.methodInt(MethodPut)], 1)
	require.Len(t, stack[app.methodInt(MethodPatch)], 1)
	require.Len(t, stack[app.methodInt(MethodDelete)], 1)
	require.Len(t, stack[app.methodInt(MethodConnect)], 1)
	require.Len(t, stack[app.methodInt(MethodOptions)], 1)
	require.Len(t, stack[app.methodInt(MethodTrace)], 1)
}

// go test -run Test_App_HandlersCount
// Test_App_HandlersCount checks HandlersCount after registering three routes.
// NOTE(review): the expected 4 presumably includes the auto-registered HEAD
// handler for the GET route — confirm against HandlersCount.
func Test_App_HandlersCount(t *testing.T) {
	t.Parallel()
	app := New()

	app.Use("/path0", testEmptyHandler)
	app.Get("/path2", testEmptyHandler)
	app.Post("/path3", testEmptyHandler)

	app.startupProcess()

	count := app.HandlersCount()
	require.Equal(t, uint32(4), count)
}

// go test -run Test_App_ReadTimeout
// Test_App_ReadTimeout checks that with a 1ns ReadTimeout an incomplete
// request gets a "408 Request Timeout" response.
func Test_App_ReadTimeout(t *testing.T) {
	t.Parallel()
	app :=
New(Config{
		ReadTimeout:      time.Nanosecond,
		IdleTimeout:      time.Minute,
		DisableKeepalive: true,
	})

	ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	addr := ln.Addr().String()

	app.Get("/read-timeout", func(c Ctx) error {
		return c.SendString("I should not be sent")
	})

	go func() {
		time.Sleep(500 * time.Millisecond)

		conn, err := net.Dial(NetworkTCP4, addr)
		if !assert.NoError(t, err) {
			// Nothing to probe without a connection (the old code called Write on
			// a nil conn); still stop the server so app.Listener below returns.
			assert.NoError(t, app.Shutdown())
			return
		}
		defer func(conn net.Conn) {
			closeErr := conn.Close()
			assert.NoError(t, closeErr)
		}(conn)

		// Send a request line without the terminating blank line, so the
		// request never completes within the 1ns ReadTimeout.
		_, err = conn.Write([]byte("HEAD /read-timeout HTTP/1.1\r\n"))
		assert.NoError(t, err)

		buf := make([]byte, 1024)
		var n int
		n, err = conn.Read(buf)

		assert.NoError(t, err)
		assert.True(t, bytes.Contains(buf[:n], []byte("408 Request Timeout")))

		assert.NoError(t, app.Shutdown())
	}()

	require.NoError(t, app.Listener(ln, ListenConfig{DisableStartupMessage: true}))
}

// go test -run Test_App_BadRequest
// Test_App_BadRequest checks that a malformed request line gets a
// "400 Bad Request" response.
func Test_App_BadRequest(t *testing.T) {
	t.Parallel()
	app := New()

	app.Get("/bad-request", func(c Ctx) error {
		return c.SendString("I should not be sent")
	})

	ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	addr := ln.Addr().String()

	go func() {
		time.Sleep(500 * time.Millisecond)

		conn, err := net.Dial(NetworkTCP4, addr)
		if !assert.NoError(t, err) {
			// Nothing to probe without a connection; still stop the server so
			// app.Listener below returns.
			assert.NoError(t, app.Shutdown())
			return
		}
		defer func(conn net.Conn) {
			closeErr := conn.Close()
			assert.NoError(t, closeErr)
		}(conn)

		_, err = conn.Write([]byte("BadRequest\r\n"))
		assert.NoError(t, err)

		buf := make([]byte, 1024)
		var n int
		n, err = conn.Read(buf)

		assert.NoError(t, err)
		assert.True(t, bytes.Contains(buf[:n], []byte("400 Bad Request")))

		assert.NoError(t, app.Shutdown())
	}()

	require.NoError(t, app.Listener(ln, ListenConfig{DisableStartupMessage: true}))
}

// go test -run Test_App_SmallReadBuffer
// Test_App_SmallReadBuffer checks that a 1-byte ReadBufferSize makes a normal
// request fail with status 431.
func Test_App_SmallReadBuffer(t *testing.T) {
	t.Parallel()
	app := New(Config{
		ReadBufferSize: 1,
	})

	app.Get("/small-read-buffer", func(c Ctx) error {
		return c.SendString("I should not be sent")
	})

	ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0")
	require.NoError(t, err)
	addr := ln.Addr().String()

	go func() {
		time.Sleep(500 * time.Millisecond)
		// Always stop the server so app.Listener below returns, whatever happens here.
		defer func() { assert.NoError(t, app.Shutdown()) }()

		req, err := http.NewRequestWithContext(context.Background(), MethodGet, fmt.Sprintf("http://%s/small-read-buffer", addr), http.NoBody)
		if !assert.NoError(t, err) {
			return
		}

		var client http.Client
		resp, err := client.Do(req)
		// resp is nil when Do fails; the old code dereferenced it unconditionally.
		if !assert.NoError(t, err) {
			return
		}
		// Close the body (previously leaked); runs before the Shutdown defer above.
		defer func() { assert.NoError(t, resp.Body.Close()) }()

		assert.Equal(t, 431, resp.StatusCode)
	}()

	require.NoError(t, app.Listener(ln, ListenConfig{DisableStartupMessage: true}))
}

// Test_App_Server checks that App.Server returns a non-nil server.
func Test_App_Server(t *testing.T) {
	t.Parallel()

	app := New()

	require.NotNil(t, app.Server())
}

// Test_App_Error_In_Fasthttp_Server checks that a request the server rejects
// (GetOnly with a POST), combined with an ErrorHandler that itself returns an
// error, results in a 500 response.
func Test_App_Error_In_Fasthttp_Server(t *testing.T) {
	app := New()
	app.config.ErrorHandler = func(_ Ctx, _ error) error {
		return errors.New("fake error")
	}
	app.server.GetOnly = true

	resp, err := app.Test(httptest.NewRequest(MethodPost, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 500, resp.StatusCode)
}

// go test -race -run Test_App_New_Test_Parallel
// Test_App_New_Test_Parallel runs New + app.Test from two parallel subtests;
// it is meant to be run with -race.
func Test_App_New_Test_Parallel(t *testing.T) {
	t.Parallel()
	t.Run("Test_App_New_Test_Parallel_1", func(t *testing.T) {
		t.Parallel()
		app := New(Config{Immutable: true})
		_, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody))
		require.NoError(t, err)
	})
	t.Run("Test_App_New_Test_Parallel_2", func(t *testing.T) {
		t.Parallel()
		app := New(Config{Immutable: true})
		_, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody))
		require.NoError(t, err)
	})
}

// Test_App_ReadBodyStream checks that with StreamRequestBody the request is
// a body stream and c.Body() still returns its full content.
func Test_App_ReadBodyStream(t *testing.T) {
	t.Parallel()
	app := New(Config{StreamRequestBody: true})
	app.Post("/", func(c Ctx) error {
		// Calling c.Body() automatically reads the entire stream.
		return c.SendString(fmt.Sprintf("%v %s", c.Request().IsBodyStream(), c.Body()))
	})
	testString := "this is a test"
	resp, err := app.Test(httptest.NewRequest(MethodPost, "/", bytes.NewBufferString(testString)))
	require.NoError(t, err, "app.Test(req)")
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "io.ReadAll(resp.Body)")
	require.Equal(t, "true "+testString, string(body))
}

// Test_App_DisablePreParseMultipartForm checks that with
// DisablePreParseMultipartForm and StreamRequestBody the handler can parse
// the multipart form itself from the body stream and read an uploaded file.
func Test_App_DisablePreParseMultipartForm(t *testing.T) {
	t.Parallel()
	// Must be used with both otherwise there is no point.
	testString := "this is a test"

	app := New(Config{DisablePreParseMultipartForm: true, StreamRequestBody: true})
	app.Post("/", func(c Ctx) error {
		req := c.Request()
		mpf, err := req.MultipartForm()
		if err != nil {
			return err
		}
		if !req.IsBodyStream() {
			return errors.New("not a body stream")
		}
		// Guard the lookup: indexing [0] directly panicked when the part was missing.
		files := mpf.File["test"]
		if len(files) == 0 {
			return errors.New("missing test file")
		}
		file, err := files[0].Open()
		if err != nil {
			return fmt.Errorf("failed to open: %w", err)
		}
		// Release the opened part (it was previously never closed).
		defer func() {
			_ = file.Close() //nolint:errcheck // best-effort close; the read result is what the test checks
		}()
		buffer := make([]byte, len(testString))
		n, err := file.Read(buffer)
		if err != nil {
			return fmt.Errorf("failed to read: %w", err)
		}
		if n != len(testString) {
			return errors.New("bad read length")
		}
		return c.Send(buffer)
	})
	b := &bytes.Buffer{}
	w := multipart.NewWriter(b)
	writer, err := w.CreateFormFile("test", "test")
	require.NoError(t, err, "w.CreateFormFile")
	n, err := writer.Write([]byte(testString))
	require.NoError(t, err, "writer.Write")
	require.Len(t, testString, n, "writer n")
	require.NoError(t, w.Close(), "w.Close()")

	req := httptest.NewRequest(MethodPost, "/", b)
	req.Header.Set("Content-Type", w.FormDataContentType())
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "io.ReadAll(resp.Body)")
	require.Equal(t, testString, string(body))
}

// Test_App_Test_no_timeout_infinitely checks that app.Test with Timeout 0
// returns an error, instead of hanging, when the handler's goroutine exits
// via runtime.Goexit.
func Test_App_Test_no_timeout_infinitely(t *testing.T) {
	t.Parallel()
	var err error
	// Buffered so the goroutine can still finish if the 5s timer fires first.
	c := make(chan int, 1)

	go func() {
		// The send happens after err is assigned, so reading err after <-c is safe.
		defer func() { c <- 0 }()
		app := New()
		app.Get("/", func(_ Ctx) error {
			runtime.Goexit()
			return nil
		})

		req := httptest.NewRequest(MethodGet, "/", http.NoBody)
		_, err = app.Test(req, TestConfig{
			Timeout: 0,
		})
	}()

	tk := time.NewTimer(5 * time.Second)
	defer tk.Stop()

	select {
	case <-tk.C:
		t.Error("hanging test")
		t.FailNow()
	case <-c:
	}

	if err == nil {
		t.Error("unexpected success request")
		t.FailNow()
	}
}

// Test_App_Test_timeout checks that FailOnTimeout makes app.Test return
// os.ErrDeadlineExceeded for a handler slower than the timeout.
func Test_App_Test_timeout(t *testing.T) {
	t.Parallel()

	app := New()
	app.Get("/", func(_ Ctx) error {
		time.Sleep(1 * time.Second)
		return nil
	})

	_, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody), TestConfig{
		Timeout:       100 * time.Millisecond,
		FailOnTimeout: true,
	})
	require.ErrorIs(t, err, os.ErrDeadlineExceeded)
}

// Test_App_Test_timeout_empty_response checks that without FailOnTimeout a
// timeout yields a 200 response and no error.
func Test_App_Test_timeout_empty_response(t *testing.T) {
	t.Parallel()

	app := New()
	app.Get("/", func(_ Ctx) error {
		time.Sleep(50 * time.Millisecond)
		return nil
	})

	// When FailOnTimeout is false, the test should return whatever response is available
	// at timeout without failing
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody), TestConfig{
		Timeout:       10 * time.Millisecond,
		FailOnTimeout: false,
	})
	require.NoError(t, err)
	require.Equal(t, StatusOK, resp.StatusCode)
}

// Test_App_Test_drop_empty_response checks that a handler calling c.Drop()
// makes app.Test return ErrTestGotEmptyResponse.
func Test_App_Test_drop_empty_response(t *testing.T) {
	t.Parallel()

	app := New()
	app.Get("/", func(c Ctx) error {
		return c.Drop()
	})

	_, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody), TestConfig{
		Timeout:       0,
		FailOnTimeout: false,
	})
	require.ErrorIs(t, err, ErrTestGotEmptyResponse)
}

// Test_App_Test_response_error checks that an error returned by the
// httpReadResponse hook is propagated by app.Test.
func Test_App_Test_response_error(t *testing.T) {
	// Note: Test cannot run in parallel due to
	// overriding the httpReadResponse global variable.
// t.Parallel() // Override httpReadResponse temporarily oldHTTPReadResponse := httpReadResponse defer func() { httpReadResponse = oldHTTPReadResponse }() httpReadResponse = func(_ *bufio.Reader, _ *http.Request) (*http.Response, error) { return nil, errErrorReader } app := New() app.Get("/", func(c Ctx) error { return c.SendStatus(StatusOK) }) _, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody), TestConfig{ Timeout: 0, FailOnTimeout: false, }) require.ErrorIs(t, err, errErrorReader) } type errorReadCloser int var errInvalidReadOnBody = errors.New("test: invalid Read on body") func (errorReadCloser) Read(_ []byte) (int, error) { return 0, errInvalidReadOnBody } func (errorReadCloser) Close() error { return nil } func Test_App_Test_ReadFail(t *testing.T) { // Note: Test cannot run in parallel due to // overriding the httpReadResponse global variable. // t.Parallel() // Override httpReadResponse temporarily oldHTTPReadResponse := httpReadResponse defer func() { httpReadResponse = oldHTTPReadResponse }() httpReadResponse = func(r *bufio.Reader, req *http.Request) (*http.Response, error) { resp, err := http.ReadResponse(r, req) require.NoError(t, resp.Body.Close()) resp.Body = errorReadCloser(0) return resp, err //nolint:wrapcheck // unnecessary to wrap it } app := New() hints := []string{"; rel=preload; as=script"} app.Get("/early", func(c Ctx) error { err := c.SendEarlyHints(hints) require.NoError(t, err) return c.SendStatus(StatusOK) }) req := httptest.NewRequest(MethodGet, "/early", http.NoBody) _, err := app.Test(req) require.ErrorIs(t, err, errInvalidReadOnBody) } var errDoubleClose = errors.New("test: double close") type doubleCloseBody struct { isClosed bool } func (b *doubleCloseBody) Read(_ []byte) (int, error) { if b.isClosed { return 0, errInvalidReadOnBody } // Close after reading EOF _ = b.Close() //nolint:errcheck // It is fine to ignore the error here return 0, io.EOF } func (b *doubleCloseBody) Close() error { if b.isClosed { return 
errDoubleClose } b.isClosed = true return nil } func Test_App_Test_CloseFail(t *testing.T) { // Note: Test cannot run in parallel due to // overriding the httpReadResponse global variable. // t.Parallel() // Override httpReadResponse temporarily oldHTTPReadResponse := httpReadResponse defer func() { httpReadResponse = oldHTTPReadResponse }() httpReadResponse = func(r *bufio.Reader, req *http.Request) (*http.Response, error) { resp, err := http.ReadResponse(r, req) _ = resp.Body.Close() //nolint:errcheck // It is fine to ignore the error here resp.Body = &doubleCloseBody{} return resp, err //nolint:wrapcheck // unnecessary to wrap it } app := New() hints := []string{"; rel=preload; as=script"} app.Get("/early", func(c Ctx) error { err := c.SendEarlyHints(hints) require.NoError(t, err) return c.Status(StatusOK).SendString("done") }) req := httptest.NewRequest(MethodGet, "/early", http.NoBody) _, err := app.Test(req) require.ErrorIs(t, err, errDoubleClose) } func Test_App_SetTLSHandler(t *testing.T) { t.Parallel() tlsHandler := &TLSHandler{clientHelloInfo: &tls.ClientHelloInfo{ ServerName: "example.golang", }} app := New() app.SetTLSHandler(tlsHandler) c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) require.Equal(t, "example.golang", c.ClientHelloInfo().ServerName) } func Test_App_AddCustomRequestMethod(t *testing.T) { t.Parallel() methods := append(DefaultMethods, "TEST") //nolint:gocritic // We want a new slice here app := New(Config{ RequestMethods: methods, }) appMethods := app.config.RequestMethods // method name is always uppercase - https://datatracker.ietf.org/doc/html/rfc7231#section-4.1 require.Len(t, app.stack, len(appMethods)) require.Equal(t, "TEST", appMethods[len(appMethods)-1]) } func Test_App_GetRoutes(t *testing.T) { t.Parallel() app := New() app.Use(func(c Ctx) error { return c.Next() }) handler := func(c Ctx) error { return c.SendStatus(StatusOK) } app.Delete("/delete", handler).Name("delete") app.Post("/post", 
handler).Name("post") routes := app.GetRoutes(false) require.Len(t, routes, 2+len(app.config.RequestMethods)) methodMap := map[string]string{"/delete": "delete", "/post": "post"} for _, route := range routes { name, ok := methodMap[route.Path] if ok { require.Equal(t, name, route.Name) } } routes = app.GetRoutes(true) require.Len(t, routes, 2) for _, route := range routes { name, ok := methodMap[route.Path] require.True(t, ok) require.Equal(t, name, route.Name) } } func Test_Middleware_Route_Naming_With_Use(t *testing.T) { t.Parallel() named := "named" app := New() app.Get("/unnamed", func(c Ctx) error { return c.Next() }) app.Post("/named", func(c Ctx) error { return c.Next() }).Name(named) app.Use(func(c Ctx) error { return c.Next() }) // no name - logging MW app.Use(func(c Ctx) error { return c.Next() }).Name("corsMW") app.Use(func(c Ctx) error { return c.Next() }).Name("compressMW") app.Use(func(c Ctx) error { return c.Next() }) // no name - cache MW grp := app.Group("/pages").Name("pages.") grp.Use(func(c Ctx) error { return c.Next() }).Name("csrfMW") grp.Get("/home", func(c Ctx) error { return c.Next() }).Name("home") grp.Get("/unnamed", func(c Ctx) error { return c.Next() }) for _, route := range app.GetRoutes() { switch route.Path { case "/": require.Equal(t, "compressMW", route.Name) case "/unnamed", "/pages/unnamed": require.Empty(t, route.Name) case "/named": require.Equal(t, named, route.Name) case "/pages": require.Equal(t, "pages.csrfMW", route.Name) case "/pages/home": require.Equal(t, "pages.home", route.Name) default: t.Errorf("unknown route: %s", route.Path) } } } func Test_Route_Naming_Issue_2671_2685(t *testing.T) { t.Parallel() app := New() app.Get("/", emptyHandler).Name("index") require.Equal(t, "/", app.GetRoute("index").Path) app.Get("/a/:a_id", emptyHandler).Name("a") require.Equal(t, "/a/:a_id", app.GetRoute("a").Path) app.Post("/b/:bId", emptyHandler).Name("b") require.Equal(t, "/b/:bId", app.GetRoute("b").Path) c := app.Group("/c") 
c.Get("", emptyHandler).Name("c.get") require.Equal(t, "/c", app.GetRoute("c.get").Path) c.Post("", emptyHandler).Name("c.post") require.Equal(t, "/c", app.GetRoute("c.post").Path) c.Get("/d", emptyHandler).Name("c.get.d") require.Equal(t, "/c/d", app.GetRoute("c.get.d").Path) d := app.Group("/d/:d_id") d.Get("", emptyHandler).Name("d.get") require.Equal(t, "/d/:d_id", app.GetRoute("d.get").Path) d.Post("", emptyHandler).Name("d.post") require.Equal(t, "/d/:d_id", app.GetRoute("d.post").Path) e := app.Group("/e/:eId") e.Get("", emptyHandler).Name("e.get") require.Equal(t, "/e/:eId", app.GetRoute("e.get").Path) e.Post("", emptyHandler).Name("e.post") require.Equal(t, "/e/:eId", app.GetRoute("e.post").Path) e.Get("f", emptyHandler).Name("e.get.f") require.Equal(t, "/e/:eId/f", app.GetRoute("e.get.f").Path) postGroup := app.Group("/post/:postId") postGroup.Get("", emptyHandler).Name("post.get") require.Equal(t, "/post/:postId", app.GetRoute("post.get").Path) postGroup.Post("", emptyHandler).Name("post.update") require.Equal(t, "/post/:postId", app.GetRoute("post.update").Path) // Add testcase for routes use the same PATH on different methods app.Get("/users", emptyHandler).Name("get-users") app.Post("/users", emptyHandler).Name("add-user") getUsers := app.GetRoute("get-users") require.Equal(t, "/users", getUsers.Path) addUser := app.GetRoute("add-user") require.Equal(t, "/users", addUser.Path) // Add testcase for routes use the same PATH on different methods (for groups) newGrp := app.Group("/name-test") newGrp.Get("/users", emptyHandler).Name("grp-get-users") newGrp.Post("/users", emptyHandler).Name("grp-add-user") getUsers = app.GetRoute("grp-get-users") require.Equal(t, "/name-test/users", getUsers.Path) addUser = app.GetRoute("grp-add-user") require.Equal(t, "/name-test/users", addUser.Path) // Add testcase for HEAD route naming app.Get("/simple-route", emptyHandler).Name("simple-route") app.Head("/simple-route", emptyHandler).Name("simple-route2") sRoute := 
app.GetRoute("simple-route") require.Equal(t, "/simple-route", sRoute.Path) sRoute2 := app.GetRoute("simple-route2") require.Equal(t, "/simple-route", sRoute2.Path) } func Test_App_State(t *testing.T) { t.Parallel() app := New() app.State().Set("key", "value") str, ok := app.State().GetString("key") require.True(t, ok) require.Equal(t, "value", str) } // go test -v -run=^$ -bench=Benchmark_Communication_Flow -benchmem -count=4 func Benchmark_Communication_Flow(b *testing.B) { app := New() app.Get("/", func(c Ctx) error { return c.SendString("Hello, World!") }) h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(MethodGet) fctx.Request.SetRequestURI("/") b.ReportAllocs() for b.Loop() { h(fctx) } require.Equal(b, 200, fctx.Response.Header.StatusCode()) require.Equal(b, "Hello, World!", string(fctx.Response.Body())) } func Benchmark_Communication_Flow_Parallel(b *testing.B) { app := New() app.Get("/", func(c Ctx) error { return c.SendString("Hello, World!") }) h := app.Handler() b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(MethodGet) fctx.Request.SetRequestURI("/") for pb.Next() { h(fctx) } }) verifyCtx := &fasthttp.RequestCtx{} verifyCtx.Request.Header.SetMethod(MethodGet) verifyCtx.Request.SetRequestURI("/") h(verifyCtx) require.Equal(b, 200, verifyCtx.Response.Header.StatusCode()) require.Equal(b, "Hello, World!", string(verifyCtx.Response.Body())) } // go test -v -run=^$ -bench=Benchmark_Ctx_AcquireReleaseFlow -benchmem -count=4 func Benchmark_Ctx_AcquireReleaseFlow(b *testing.B) { app := New() fctx := &fasthttp.RequestCtx{} b.Run("withoutRequestCtx", func(b *testing.B) { b.ReportAllocs() for b.Loop() { c, _ := app.AcquireCtx(fctx).(*DefaultCtx) //nolint:errcheck // not needed app.ReleaseCtx(c) } }) b.Run("withRequestCtx", func(b *testing.B) { b.ReportAllocs() for b.Loop() { c, _ := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck // not needed 
app.ReleaseCtx(c) } }) } func acquireDefaultCtxForAppBenchmark(b *testing.B, app *App, fctx *fasthttp.RequestCtx) *DefaultCtx { b.Helper() ctx := app.AcquireCtx(fctx) defaultCtx, ok := ctx.(*DefaultCtx) if !ok { b.Fatal("AcquireCtx did not return *DefaultCtx") } return defaultCtx } func Benchmark_Ctx_AcquireReleaseFlow_Parallel(b *testing.B) { app := New() b.Run("withoutRequestCtx", func(b *testing.B) { b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { fctx := &fasthttp.RequestCtx{} for pb.Next() { c := acquireDefaultCtxForAppBenchmark(b, app, fctx) app.ReleaseCtx(c) } }) }) b.Run("withRequestCtx", func(b *testing.B) { b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { c := acquireDefaultCtxForAppBenchmark(b, app, &fasthttp.RequestCtx{}) app.ReleaseCtx(c) } }) }) } func TestErrorHandler_PicksRightOne(t *testing.T) { t.Parallel() // common handler to be used by all routes, // it will always fail by returning an error since // we need to test that the right ErrorHandler is invoked handler := func(_ Ctx) error { return errors.New("random error") } // subapp /api/v1/users [no custom error handler] appAPIV1Users := New() appAPIV1Users.Get("/", handler) // subapp /api/v1/use [with custom error handler] appAPIV1UseEH := func(c Ctx, _ error) error { return c.SendString("/api/v1/use error handler") } appAPIV1Use := New(Config{ErrorHandler: appAPIV1UseEH}) appAPIV1Use.Get("/", handler) // subapp: /api/v1 [with custom error handler] appV1EH := func(c Ctx, _ error) error { return c.SendString("/api/v1 error handler") } appV1 := New(Config{ErrorHandler: appV1EH}) appV1.Get("/", handler) appV1.Use("/users", appAPIV1Users) appV1.Use("/use", appAPIV1Use) // root app [no custom error handler] app := New() app.Get("/", handler) app.Use("/api/v1", appV1) testCases := []struct { path string // the endpoint url to test expected string // the expected error response }{ // /api/v1/users mount doesn't have custom ErrorHandler // so it should use the upper-nearest 
one (/api/v1) {"/api/v1/users", "/api/v1 error handler"}, // /api/v1/use mount has a custom ErrorHandler {"/api/v1/use", "/api/v1/use error handler"}, // /api/v1 mount has a custom ErrorHandler {"/api/v1", "/api/v1 error handler"}, // / mount doesn't have custom ErrorHandler, since is // the root path i will use Fiber's default Error Handler {"/", "random error"}, } for _, testCase := range testCases { t.Run(testCase.path, func(t *testing.T) { t.Parallel() resp, err := app.Test(httptest.NewRequest(MethodGet, testCase.path, http.NoBody)) if err != nil { t.Fatal(err) } body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } require.Equal(t, testCase.expected, string(body)) }) } } type groupIDResponse struct { GroupID string `json:"group_id"` } // Test for the reported bug where Test method returns "test: got empty response" error // when using a very small timeout value (e.g., 5 microseconds instead of 5 seconds). // With FailOnTimeout: false, the test should return whatever response is available without error. 
func Test_App_Test_SmallTimeout_WithFailOnTimeoutFalse(t *testing.T) { t.Parallel() app := New() app.Post("/admin/api/groups", func(c Ctx) error { // Add a small delay to ensure timeout is triggered time.Sleep(10 * time.Millisecond) groupID := "g.test123" return c.JSON(groupIDResponse{ GroupID: groupID, }) }) req := httptest.NewRequest(MethodPost, "/admin/api/groups", http.NoBody) // Using 5 microseconds which is too short for the handler to complete // But with FailOnTimeout: false, it should return whatever is available without error resp, err := app.Test(req, TestConfig{ Timeout: 5 * time.Microsecond, FailOnTimeout: false, }) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) var response groupIDResponse body, err := io.ReadAll(resp.Body) require.NoError(t, err) err = json.Unmarshal(body, &response) require.NoError(t, err) require.NotEmpty(t, response.GroupID) require.Equal(t, "g.test123", response.GroupID) } // Test that FailOnTimeout: true still works as expected func Test_App_Test_SmallTimeout_WithFailOnTimeoutTrue(t *testing.T) { t.Parallel() app := New() app.Post("/admin/api/groups", func(c Ctx) error { time.Sleep(100 * time.Millisecond) groupID := "g.test123" return c.JSON(groupIDResponse{ GroupID: groupID, }) }) req := httptest.NewRequest(MethodPost, "/admin/api/groups", http.NoBody) // With FailOnTimeout: true (default), it should fail fast with timeout error _, err := app.Test(req, TestConfig{ Timeout: 10 * time.Millisecond, FailOnTimeout: true, }) require.ErrorIs(t, err, os.ErrDeadlineExceeded) } ================================================ FILE: bind.go ================================================ package fiber import ( "encoding/json" "errors" "fmt" "reflect" "slices" "sync" "github.com/gofiber/fiber/v3/binder" "github.com/gofiber/schema" "github.com/gofiber/utils/v2" utilsbytes "github.com/gofiber/utils/v2/bytes" ) // CustomBinder An interface to register custom binders. 
type CustomBinder interface { Name() string MIMETypes() []string Parse(c Ctx, out any) error } // StructValidator is an interface to register custom struct validator for binding. type StructValidator interface { Validate(out any) error } var bindPool = sync.Pool{ New: func() any { return &Bind{ dontHandleErrs: true, } }, } // Bind provides helper methods for binding request data to Go values. // By default (manual mode), parsing failures are returned as *BindError; use errors.As to extract source and field details. // With WithAutoHandling(), parsing failures set HTTP 400 and return *Error instead. type Bind struct { ctx Ctx dontHandleErrs bool skipValidation bool } // BindError source constants for BindError.Source. const ( BindSourceURI = "uri" BindSourceQuery = "query" BindSourceHeader = "header" BindSourceCookie = "cookie" BindSourceBody = "body" BindSourceRespHeader = "respHeader" ) // BindError wraps a binding failure with the source and field that failed. // Use errors.As(err, &be) to extract it when you need to branch on source // (e.g. 404 for URI vs 400 for body). 
type BindError struct { Err error // underlying error; use errors.As to inspect Source string // binding source: uri, query, body, header, cookie, or respHeader (see BindSource* constants) Field string // struct field or tag key that failed (best-effort, may be empty) } func (e *BindError) Error() string { if e.Field != "" { return fmt.Sprintf("bind %q from %s: %v", e.Field, e.Source, e.Err) } return fmt.Sprintf("bind from %s: %v", e.Source, e.Err) } func (e *BindError) Unwrap() error { return e.Err } func extractFieldFromError(err error) string { var convErr schema.ConversionError if errors.As(err, &convErr) { return convErr.Key } var unknownKey schema.UnknownKeyError if errors.As(err, &unknownKey) { return unknownKey.Key } var emptyField schema.EmptyFieldError if errors.As(err, &emptyField) { return emptyField.Key } var multiErr schema.MultiError if errors.As(err, &multiErr) { for k := range multiErr { return k } } var unmarshalErr *json.UnmarshalTypeError if errors.As(err, &unmarshalErr) { return unmarshalErr.Field } return "" } func newBindError(source string, raw error) *BindError { return &BindError{Source: source, Field: extractFieldFromError(raw), Err: raw} } // AcquireBind returns Bind reference from bind pool. func AcquireBind() *Bind { b, ok := bindPool.Get().(*Bind) if !ok { panic(errBindPoolTypeAssertion) } return b } // ReleaseBind returns b acquired via Bind to bind pool. func ReleaseBind(b *Bind) { b.release() bindPool.Put(b) } // releasePooledBinder resets a binder and returns it to its pool. // It should be used with defer to ensure proper cleanup of pooled binders. func releasePooledBinder[T interface{ Reset() }](pool *sync.Pool, bind T) { bind.Reset() binder.PutToThePool(pool, bind) } func (b *Bind) release() { b.ctx = nil b.dontHandleErrs = true b.skipValidation = false } // WithoutAutoHandling If you want to handle binder errors manually, you can use `WithoutAutoHandling`. // It's default behavior of binder. 
func (b *Bind) WithoutAutoHandling() *Bind { b.dontHandleErrs = true return b } // WithAutoHandling If you want to handle binder errors automatically, you can use `WithAutoHandling`. // If there's an error, it will return the error and set HTTP status to `400 Bad Request`. // You must still return on error explicitly func (b *Bind) WithAutoHandling() *Bind { b.dontHandleErrs = false return b } // SkipValidation enables or disables struct validation for the current bind chain. func (b *Bind) SkipValidation(skip bool) *Bind { b.skipValidation = skip return b } // Check WithAutoHandling/WithoutAutoHandling errors and return it by usage. func (b *Bind) returnErr(err error) error { if err == nil || b.dontHandleErrs { return err } b.ctx.Status(StatusBadRequest) return NewError(StatusBadRequest, "Bad request: "+err.Error()) } // returnBindErr runs returnErr and, if the result is not a *Error, wraps it in *BindError. // Use for binding parse failures; use returnErr directly for Custom and validation errors. func (b *Bind) returnBindErr(err error, source string) error { if retErr := b.returnErr(err); retErr != nil { var fiberErr *Error if errors.As(retErr, &fiberErr) { return fiberErr } return newBindError(source, retErr) } return nil } // Struct validation. func (b *Bind) validateStruct(out any) error { if b.skipValidation { return nil } validator := b.ctx.App().config.StructValidator if validator == nil { return nil } t := reflect.TypeOf(out) if t == nil { return nil } // Unwrap pointers (e.g. *T, **T) to inspect the underlying destination type. for t.Kind() == reflect.Ptr { t = t.Elem() } if t.Kind() != reflect.Struct { return nil } return validator.Validate(out) } // Custom To use custom binders, you have to use this method. // You can register them from RegisterCustomBinder method of Fiber instance. // They're checked by name, if it's not found, it will return an error. // NOTE: WithAutoHandling/WithAutoHandling is still valid for Custom binders. 
func (b *Bind) Custom(name string, dest any) error { binders := b.ctx.App().customBinders for _, customBinder := range binders { if customBinder.Name() == name { if err := b.returnBindErr(customBinder.Parse(b.ctx, dest), name); err != nil { return err } return b.validateStruct(dest) } } return ErrCustomBinderNotFound } // Header binds the request header strings into the struct, map[string]string and map[string][]string. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). func (b *Bind) Header(out any) error { bind := binder.GetFromThePool[*binder.HeaderBinding](&binder.HeaderBinderPool) bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers defer releasePooledBinder(&binder.HeaderBinderPool, bind) if err := b.returnBindErr(bind.Bind(b.ctx.Request(), out), BindSourceHeader); err != nil { return err } return b.validateStruct(out) } // RespHeader binds the response header strings into the struct, map[string]string and map[string][]string. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). func (b *Bind) RespHeader(out any) error { bind := binder.GetFromThePool[*binder.RespHeaderBinding](&binder.RespHeaderBinderPool) bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers defer releasePooledBinder(&binder.RespHeaderBinderPool, bind) if err := b.returnBindErr(bind.Bind(b.ctx.Response(), out), BindSourceRespHeader); err != nil { return err } return b.validateStruct(out) } // Cookie binds the request cookie strings into the struct, map[string]string and map[string][]string. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). // NOTE: If your cookie is like key=val1,val2; they'll be bound as a slice if your map is map[string][]string. Else, it'll use last element of cookie. 
func (b *Bind) Cookie(out any) error { bind := binder.GetFromThePool[*binder.CookieBinding](&binder.CookieBinderPool) bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers defer releasePooledBinder(&binder.CookieBinderPool, bind) if err := b.returnBindErr(bind.Bind(&b.ctx.RequestCtx().Request, out), BindSourceCookie); err != nil { return err } return b.validateStruct(out) } // Query binds the query string into the struct, map[string]string and map[string][]string. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). func (b *Bind) Query(out any) error { bind := binder.GetFromThePool[*binder.QueryBinding](&binder.QueryBinderPool) bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers defer releasePooledBinder(&binder.QueryBinderPool, bind) if err := b.returnBindErr(bind.Bind(&b.ctx.RequestCtx().Request, out), BindSourceQuery); err != nil { return err } return b.validateStruct(out) } // JSON binds the body string into the struct. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). func (b *Bind) JSON(out any) error { bind := binder.GetFromThePool[*binder.JSONBinding](&binder.JSONBinderPool) bind.JSONDecoder = b.ctx.App().Config().JSONDecoder defer releasePooledBinder(&binder.JSONBinderPool, bind) if err := b.returnBindErr(bind.Bind(b.ctx.Body(), out), BindSourceBody); err != nil { return err } return b.validateStruct(out) } // CBOR binds the body string into the struct. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). 
func (b *Bind) CBOR(out any) error { bind := binder.GetFromThePool[*binder.CBORBinding](&binder.CBORBinderPool) bind.CBORDecoder = b.ctx.App().Config().CBORDecoder defer releasePooledBinder(&binder.CBORBinderPool, bind) if err := b.returnBindErr(bind.Bind(b.ctx.Body(), out), BindSourceBody); err != nil { return err } return b.validateStruct(out) } // XML binds the body string into the struct. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). func (b *Bind) XML(out any) error { bind := binder.GetFromThePool[*binder.XMLBinding](&binder.XMLBinderPool) bind.XMLDecoder = b.ctx.App().config.XMLDecoder defer releasePooledBinder(&binder.XMLBinderPool, bind) if err := b.returnBindErr(bind.Bind(b.ctx.Body(), out), BindSourceBody); err != nil { return err } return b.validateStruct(out) } // Form binds the form into the struct, map[string]string and map[string][]string. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). // If Content-Type is "application/x-www-form-urlencoded" or "multipart/form-data", it will bind the form values. // Multipart file fields are supported using *multipart.FileHeader, []*multipart.FileHeader, or *[]*multipart.FileHeader. func (b *Bind) Form(out any) error { bind := binder.GetFromThePool[*binder.FormBinding](&binder.FormBinderPool) bind.EnableSplitting = b.ctx.App().config.EnableSplittingOnParsers defer releasePooledBinder(&binder.FormBinderPool, bind) if err := b.returnBindErr(bind.Bind(&b.ctx.RequestCtx().Request, out), BindSourceBody); err != nil { return err } return b.validateStruct(out) } // URI binds the route parameters into the struct, map[string]string and map[string][]string. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). 
func (b *Bind) URI(out any) error { bind := binder.GetFromThePool[*binder.URIBinding](&binder.URIBinderPool) defer releasePooledBinder(&binder.URIBinderPool, bind) if err := b.returnBindErr(bind.Bind(b.ctx.Route().Params, b.ctx.Params, out), BindSourceURI); err != nil { return err } return b.validateStruct(out) } // MsgPack binds the body string into the struct. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). func (b *Bind) MsgPack(out any) error { bind := binder.GetFromThePool[*binder.MsgPackBinding](&binder.MsgPackBinderPool) bind.MsgPackDecoder = b.ctx.App().Config().MsgPackDecoder defer releasePooledBinder(&binder.MsgPackBinderPool, bind) if err := b.returnBindErr(bind.Bind(b.ctx.Body(), out), BindSourceBody); err != nil { return err } return b.validateStruct(out) } // Body binds the request body into the struct, map[string]string and map[string][]string. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). // It supports decoding the following content types based on the Content-Type header: // application/json, application/xml, application/x-www-form-urlencoded, multipart/form-data // If none of the content types above are matched, it'll take a look custom binders by checking the MIMETypes() method of custom binder. // If there is no custom binder for mime type of body, it will return a ErrUnprocessableEntity error. 
func (b *Bind) Body(out any) error { // Get content-type ctype := utils.UnsafeString(utilsbytes.UnsafeToLower(b.ctx.RequestCtx().Request.Header.ContentType())) ctype = binder.FilterFlags(utils.ParseVendorSpecificContentType(ctype)) // Check custom binders binders := b.ctx.App().customBinders for _, customBinder := range binders { if slices.Contains(customBinder.MIMETypes(), ctype) { if err := b.returnBindErr(customBinder.Parse(b.ctx, out), BindSourceBody); err != nil { return err } return b.validateStruct(out) } } // Parse body accordingly switch ctype { case MIMEApplicationJSON: return b.JSON(out) case MIMEApplicationMsgPack: return b.MsgPack(out) case MIMETextXML, MIMEApplicationXML: return b.XML(out) case MIMEApplicationCBOR: return b.CBOR(out) case MIMEApplicationForm, MIMEMultipartForm: return b.Form(out) } // No suitable content type found return ErrUnprocessableEntity } // All binds values from URI params, the request body, the query string, // headers, and cookies into the provided struct in precedence order. // Returns *BindError on parse failure (manual mode) or *Error with status 400 (auto-handling mode). 
func (b *Bind) All(out any) error { outVal := reflect.ValueOf(out) if outVal.Kind() != reflect.Ptr || outVal.Elem().Kind() != reflect.Struct { return ErrUnprocessableEntity } outElem := outVal.Elem() // Precedence: URL Params -> Body -> Query -> Headers -> Cookies sources := []func(any) error{b.URI} // Check if both Body and Content-Type are set if len(b.ctx.Request().Body()) > 0 && len(b.ctx.RequestCtx().Request.Header.ContentType()) > 0 { sources = append(sources, b.Body) } sources = append(sources, b.Query, b.Header, b.Cookie) prevSkip := b.skipValidation b.skipValidation = true // TODO: Support custom precedence with an optional binding_source tag // TODO: Create WithOverrideEmptyValues // Bind from each source, but only update unset fields for _, bindFunc := range sources { tempStruct := reflect.New(outElem.Type()).Interface() if err := bindFunc(tempStruct); err != nil { b.skipValidation = prevSkip return err } tempStructVal := reflect.ValueOf(tempStruct).Elem() mergeStruct(outElem, tempStructVal) } b.skipValidation = prevSkip return b.returnErr(b.validateStruct(out)) } func mergeStruct(dst, src reflect.Value) { dstFields := dst.NumField() for i := range dstFields { dstField := dst.Field(i) srcField := src.Field(i) // Skip if the destination field is already set if isZero(dstField.Interface()) { if dstField.CanSet() && srcField.IsValid() { dstField.Set(srcField) } } } } func isZero(value any) bool { v := reflect.ValueOf(value) return v.IsZero() } ================================================ FILE: bind_test.go ================================================ //nolint:wrapcheck,tagliatelle // We must not wrap errors in tests package fiber import ( "bytes" "compress/gzip" "encoding/json" "errors" "fmt" "mime/multipart" "net/http" "net/http/httptest" "reflect" "strings" "testing" "time" "github.com/fxamacker/cbor/v2" "github.com/gofiber/fiber/v3/binder" "github.com/gofiber/schema" "github.com/shamaton/msgpack/v3" "github.com/stretchr/testify/require" 
"github.com/valyala/fasthttp" ) const helloWorld = "hello world" // go test -run Test_returnErr -v func Test_returnErr(t *testing.T) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.Bind().WithAutoHandling().returnErr(nil) require.NoError(t, err) } // go test -run Test_AcquireReleaseBind -v func Test_AcquireReleaseBind(t *testing.T) { b := AcquireBind() b.dontHandleErrs = false b.skipValidation = true b.ctx = &DefaultCtx{} ReleaseBind(b) b2 := AcquireBind() require.Nil(t, b2.ctx) require.True(t, b2.dontHandleErrs) require.False(t, b2.skipValidation) ReleaseBind(b2) } // go test -run Test_BindError -v func Test_BindError_Unwrap(t *testing.T) { t.Parallel() inner := errors.New("inner") be := &BindError{Source: BindSourceQuery, Err: inner} require.ErrorIs(t, be, inner) require.Equal(t, inner, errors.Unwrap(be)) var extracted *BindError require.ErrorAs(t, be, &extracted) require.Equal(t, BindSourceQuery, extracted.Source) } func Test_BindError_ErrorFormat(t *testing.T) { t.Parallel() t.Run("with field", func(t *testing.T) { t.Parallel() be := &BindError{Source: BindSourceQuery, Field: "id", Err: errors.New("conversion failed")} require.Contains(t, be.Error(), `"id"`) require.Contains(t, be.Error(), "query") require.Contains(t, be.Error(), "conversion failed") }) t.Run("without field", func(t *testing.T) { t.Parallel() be := &BindError{Source: BindSourceBody, Field: "", Err: errors.New("parse failed")} require.Equal(t, "bind from body: parse failed", be.Error()) }) } func Test_BindError_FieldExtraction(t *testing.T) { t.Parallel() t.Run("QueryConversionError", func(t *testing.T) { t.Parallel() type Q struct { ID int `query:"id"` } app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) c.Request().URI().SetQueryString("id=notanint") err := c.Bind().Query(new(Q)) require.Error(t, err) var be *BindError require.ErrorAs(t, err, &be) require.Equal(t, BindSourceQuery, be.Source) require.Equal(t, "id", be.Field) 
require.ErrorAs(t, err, &MultiError{}) }) t.Run("ConversionError", func(t *testing.T) { t.Parallel() convErrBinder := &customBinderReturningError{ err: schema.ConversionError{Key: "count"}, mimeType: "application/x-conversion-error-test", } app := New() app.RegisterCustomBinder(convErrBinder) c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) c.Request().SetBody([]byte("{}")) c.Request().Header.SetContentType("application/x-conversion-error-test") type D struct{ Name string } err := c.Bind().Body(new(D)) require.Error(t, err) var be *BindError require.ErrorAs(t, err, &be) require.Equal(t, BindSourceBody, be.Source) require.Equal(t, "count", be.Field) }) t.Run("JSONUnmarshalTypeError", func(t *testing.T) { t.Parallel() type J struct { Count int `json:"count"` } app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) c.Request().SetBody([]byte(`{"count":"notanint"}`)) c.Request().Header.SetContentType(MIMEApplicationJSON) err := c.Bind().Body(new(J)) require.Error(t, err) var be *BindError require.ErrorAs(t, err, &be) require.Equal(t, BindSourceBody, be.Source) require.Equal(t, "count", be.Field) var unmarshalErr *json.UnmarshalTypeError require.ErrorAs(t, err, &unmarshalErr) }) t.Run("UnknownKeyError", func(t *testing.T) { t.Parallel() unknownKeyBinder := &customBinderReturningError{ err: schema.UnknownKeyError{Key: "extra_field"}, mimeType: "application/x-unknown-key-test", } app := New() app.RegisterCustomBinder(unknownKeyBinder) c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) c.Request().SetBody([]byte("{}")) c.Request().Header.SetContentType("application/x-unknown-key-test") type D struct{ Name string } err := c.Bind().Body(new(D)) require.Error(t, err) var be *BindError require.ErrorAs(t, err, &be) require.Equal(t, BindSourceBody, be.Source) require.Equal(t, "extra_field", be.Field) }) t.Run("EmptyFieldError", func(t *testing.T) { t.Parallel() 
emptyFieldBinder := &customBinderReturningError{
			err:      schema.EmptyFieldError{Key: "required_field"},
			mimeType: "application/x-empty-field-test",
		}
		app := New()
		app.RegisterCustomBinder(emptyFieldBinder)
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })

		c.Request().SetBody([]byte("{}"))
		c.Request().Header.SetContentType("application/x-empty-field-test")

		type D struct{ Name string }
		err := c.Bind().Body(new(D))
		require.Error(t, err)
		var be *BindError
		require.ErrorAs(t, err, &be)
		require.Equal(t, BindSourceBody, be.Source)
		require.Equal(t, "required_field", be.Field)
	})

	t.Run("MultiError", func(t *testing.T) {
		t.Parallel()
		type Q struct {
			A string `query:"a,required"`
			B string `query:"b,required"`
		}
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })

		c.Request().URI().SetQueryString("")
		err := c.Bind().Query(new(Q))
		require.Error(t, err)
		var be *BindError
		require.ErrorAs(t, err, &be)
		require.Equal(t, BindSourceQuery, be.Source)
		require.Contains(t, []string{"a", "b"}, be.Field)
		require.ErrorAs(t, err, &MultiError{})
	})

	t.Run("NoRecognizedErrorType", func(t *testing.T) {
		t.Parallel()
		genericErrBinder := &customBinderReturningError{
			err:      errors.New("generic parse failure"),
			mimeType: "application/x-generic-error",
		}
		app := New()
		app.RegisterCustomBinder(genericErrBinder)
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })

		c.Request().SetBody([]byte("{}"))
		c.Request().Header.SetContentType("application/x-generic-error")

		type D struct{ Name string }
		err := c.Bind().Body(new(D))
		require.Error(t, err)
		var be *BindError
		require.ErrorAs(t, err, &be)
		require.Equal(t, BindSourceBody, be.Source)
		require.Empty(t, be.Field)
	})
}

// Test_BindError_Sources checks that a failed bind from each source (URI,
// query, header, cookie, body) is reported as a *BindError carrying that source.
func Test_BindError_Sources(t *testing.T) {
	t.Parallel()

	t.Run("URI", func(t *testing.T) {
		t.Parallel()
		type U struct {
			ID int `uri:"id"`
		}
		app := New()
		app.Get("/user/:id", func(ctx Ctx) error {
			err := ctx.Bind().URI(new(U))
			require.Error(t, err)
			var be *BindError
			require.ErrorAs(t, err, &be)
			require.Equal(t, BindSourceURI, be.Source)
			require.ErrorAs(t, err, &MultiError{})
			return nil
		})
		_, err := app.Test(httptest.NewRequest(http.MethodGet, "/user/notanint", http.NoBody))
		require.NoError(t, err)
	})

	t.Run("Query", func(t *testing.T) {
		t.Parallel()
		type Q struct {
			ID int `query:"id,required"`
		}
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })

		c.Request().URI().SetQueryString("")
		err := c.Bind().Query(new(Q))
		require.Error(t, err)
		var be *BindError
		require.ErrorAs(t, err, &be)
		require.Equal(t, BindSourceQuery, be.Source)
		require.ErrorAs(t, err, &MultiError{})
	})

	t.Run("Header", func(t *testing.T) {
		t.Parallel()
		type H struct {
			ID int `header:"x-id,required"`
		}
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })

		c.Request().Header.Del("X-Id")
		err := c.Bind().Header(new(H))
		require.Error(t, err)
		var be *BindError
		require.ErrorAs(t, err, &be)
		require.Equal(t, BindSourceHeader, be.Source)
		require.ErrorAs(t, err, &MultiError{})
	})

	t.Run("Cookie", func(t *testing.T) {
		t.Parallel()
		type C struct {
			ID int `cookie:"id,required"`
		}
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })

		c.Request().Header.DelCookie("id")
		err := c.Bind().Cookie(new(C))
		require.Error(t, err)
		var be *BindError
		require.ErrorAs(t, err, &be)
		require.Equal(t, BindSourceCookie, be.Source)
		require.ErrorAs(t, err, &MultiError{})
	})

	t.Run("Body", func(t *testing.T) {
		t.Parallel()
		type J struct {
			X int `json:"x"`
		}
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		t.Cleanup(func() { app.ReleaseCtx(c) })

		c.Request().SetBody([]byte(`{"x":"bad"}`))
		c.Request().Header.SetContentType(MIMEApplicationJSON)
		err := c.Bind().Body(new(J))
		require.Error(t, err)
		var be *BindError
		require.ErrorAs(t, err, &be)
		require.Equal(t, BindSourceBody, be.Source)
		// The underlying decoder error must stay reachable through the wrapper.
		var unmarshalErr *json.UnmarshalTypeError
		require.ErrorAs(t, err, &unmarshalErr)
	})
}

// Test_BindError_All checks that Bind().All reports the URI failure when the
// URI parameter cannot be converted, even though the JSON body is valid.
func Test_BindError_All(t *testing.T) {
	t.Parallel()

	type Req struct {
		Name string `json:"name"`
		ID   int    `uri:"id" json:"id"`
	}

	t.Run("URIFailsFirst", func(t *testing.T) {
		t.Parallel()
		app := New()
		app.Get("/users/:id", func(ctx Ctx) error {
			err := ctx.Bind().All(new(Req))
			require.Error(t, err)
			var be *BindError
			require.ErrorAs(t, err, &be)
			require.Equal(t, BindSourceURI, be.Source)
			require.ErrorAs(t, err, &MultiError{})
			return nil
		})

		req := httptest.NewRequest(http.MethodGet, "/users/notanint", bytes.NewReader([]byte(`{"name":"ok"}`)))
		req.Header.Set("Content-Type", MIMEApplicationJSON)
		_, err := app.Test(req)
		require.NoError(t, err)
	})
}

// go test -run Test_Bind_Query -v
func Test_Bind_Query(t *testing.T) {
	t.Parallel()
	app := New(Config{
		EnableSplittingOnParsers: true,
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Query struct {
		Name  string
		Hobby []string
		ID    int
	}
	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")
	c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball&hobby=football")
	q := new(Query)
	require.NoError(t, c.Bind().Query(q))
	require.Len(t, q.Hobby, 2)

	c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball,football")
	q = new(Query)
	require.NoError(t, c.Bind().Query(q))
	require.Len(t, q.Hobby, 2)

	c.Request().URI().SetQueryString("id=1&name=tom&hobby=soccer&hobby=basketball,football")
	q = new(Query)
	require.NoError(t, c.Bind().Query(q))
	require.Len(t, q.Hobby, 3)

	empty := new(Query)
	c.Request().URI().SetQueryString("")
	require.NoError(t, c.Bind().Query(empty))
	require.Empty(t, empty.Hobby)

	type Query2 struct {
		Name            string
		Hobby           string
		Default         string `query:"default,default:hello"`
		FavouriteDrinks []string
		Empty           []string
		Alloc           []string
		Defaults        []string `query:"defaults,default:hello|world"`
		No              []int64
		ID              int
		Bool            bool
	}

	c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball,football&favouriteDrinks=milo,coke,pepsi&alloc=&no=1")
	q2 := new(Query2)
	q2.Bool = true
	q2.Name = helloWorld
	require.NoError(t, c.Bind().Query(q2))
	require.Equal(t, "basketball,football", q2.Hobby)
	require.True(t, q2.Bool)
	require.Equal(t, "tom", q2.Name) // check value get overwritten
	require.Equal(t, []string{"milo", "coke", "pepsi"}, q2.FavouriteDrinks)
	var nilSlice []string
	require.Equal(t, nilSlice, q2.Empty)
	require.Equal(t, []string{""}, q2.Alloc)
	require.Equal(t, []int64{1}, q2.No)
	require.Equal(t, "hello", q2.Default)
	require.Equal(t, []string{"hello", "world"}, q2.Defaults)

	type RequiredQuery struct {
		Name string `query:"name,required"`
	}
	rq := new(RequiredQuery)
	c.Request().URI().SetQueryString("")
	err := c.Bind().Query(rq)
	require.Error(t, err)
	require.Equal(t, "bind \"name\" from query: name is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})

	type ArrayQuery struct {
		Data []string
	}
	aq := new(ArrayQuery)
	c.Request().URI().SetQueryString("data[]=john&data[]=doe")
	require.NoError(t, c.Bind().Query(aq))
	require.Len(t, aq.Data, 2)
}

// go test -run Test_Bind_Query_Map -v
func Test_Bind_Query_Map(t *testing.T) {
	t.Parallel()
	app := New(Config{
		EnableSplittingOnParsers: true,
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")
	c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball&hobby=football")
	q := make(map[string][]string)
	require.NoError(t, c.Bind().Query(&q))
	require.Len(t, q["hobby"], 2)

	c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball,football")
	q = make(map[string][]string)
	require.NoError(t, c.Bind().Query(&q))
	require.Len(t, q["hobby"], 2)

	c.Request().URI().SetQueryString("id=1&name=tom&hobby=soccer&hobby=basketball,football")
	q = make(map[string][]string)
	require.NoError(t, c.Bind().Query(&q))
	require.Len(t, q["hobby"], 3)

	c.Request().URI().SetQueryString("id=1&name=tom&hobby=soccer")
	qq := make(map[string]string)
	require.NoError(t, c.Bind().Query(&qq))
	require.Equal(t, "1", qq["id"])

	empty := make(map[string][]string)
	c.Request().URI().SetQueryString("")
	require.NoError(t, c.Bind().Query(&empty))
	require.Empty(t, empty["hobby"])

	em := make(map[string][]int)
	c.Request().URI().SetQueryString("")
	require.ErrorIs(t, c.Bind().Query(&em), binder.ErrMapNotConvertible)
}

// go test -run Test_Bind_Query_WithSetParserDecoder -v
func Test_Bind_Query_WithSetParserDecoder(t *testing.T) {
	type NonRFCTime time.Time

	nonRFCConverter := func(value string) reflect.Value {
		if v, err := time.Parse("2006-01-02", value); err == nil {
			return reflect.ValueOf(v)
		}
		return reflect.Value{}
	}

	nonRFCTime := binder.ParserType{
		CustomType: NonRFCTime{},
		Converter:  nonRFCConverter,
	}

	binder.SetParserDecoder(binder.ParserConfig{
		IgnoreUnknownKeys: true,
		ParserType:        []binder.ParserType{nonRFCTime},
		ZeroEmpty:         true,
		SetAliasTag:       "query",
	})

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type NonRFCTimeInput struct {
		Date  NonRFCTime `query:"date"`
		Title string     `query:"title"`
		Body  string     `query:"body"`
	}

	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")
	q := new(NonRFCTimeInput)

	c.Request().URI().SetQueryString("date=2021-04-10&title=CustomDateTest&Body=October")
	require.NoError(t, c.Bind().Query(q))
	require.Equal(t, "CustomDateTest", q.Title)
	// NonRFCTime has no String method, so %v prints the raw time.Time fields;
	// the nil *Location pointer is rendered by fmt as <nil>.
	date := fmt.Sprintf("%v", q.Date)
	require.Equal(t, "{0 63753609600 <nil>}", date)
	require.Equal(t, "October", q.Body)

	c.Request().URI().SetQueryString("date=2021-04-10&title&Body=October")
	q = &NonRFCTimeInput{
		Title: "Existing title",
		Body:  "Existing Body",
	}
	require.NoError(t, c.Bind().Query(q))
	require.Empty(t, q.Title)
}

// go test -run Test_Bind_Query_Schema -v
func Test_Bind_Query_Schema(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Query1 struct {
		Name   string `query:"name,required"`
		Nested struct {
			Age int `query:"age"`
		} `query:"nested,required"`
	}
	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")

	c.Request().URI().SetQueryString("name=tom&nested.age=10")
	q := new(Query1)
	require.NoError(t, c.Bind().Query(q))

	c.Request().URI().SetQueryString("namex=tom&nested.age=10")
	q = new(Query1)
	err := c.Bind().Query(q)
	require.Error(t, err)
	require.Equal(t, "bind \"name\" from query: name is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})

	c.Request().URI().SetQueryString("name=tom&nested.agex=10")
	q = new(Query1)
	require.NoError(t, c.Bind().Query(q))

	c.Request().URI().SetQueryString("name=tom&test.age=10")
	q = new(Query1)
	err = c.Bind().Query(q)
	require.Error(t, err)
	require.Equal(t, "bind \"nested\" from query: nested is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})

	type Query2 struct {
		Name   string `query:"name"`
		Nested struct {
			Age int `query:"age,required"`
		} `query:"nested"`
	}
	c.Request().URI().SetQueryString("name=tom&nested.age=10")
	q2 := new(Query2)
	require.NoError(t, c.Bind().Query(q2))

	c.Request().URI().SetQueryString("nested.age=10")
	q2 = new(Query2)
	require.NoError(t, c.Bind().Query(q2))

	c.Request().URI().SetQueryString("nested.agex=10")
	q2 = new(Query2)
	err = c.Bind().Query(q2)
	require.Error(t, err)
	require.Equal(t, "bind \"nested.age\" from query: nested.age is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})

	c.Request().URI().SetQueryString("nested.agex=10")
	q2 = new(Query2)
	err = c.Bind().Query(q2)
	require.Error(t, err)
	require.Equal(t, "bind \"nested.age\" from query: nested.age is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})

	type Node struct {
		Next  *Node `query:"next,required"`
		Value int   `query:"val,required"`
	}
	c.Request().URI().SetQueryString("val=1&next.val=3")
	n := new(Node)
	require.NoError(t, c.Bind().Query(n))
	require.Equal(t, 1, n.Value)
	require.Equal(t, 3, n.Next.Value)

	c.Request().URI().SetQueryString("next.val=2")
	n = new(Node)
	err = c.Bind().Query(n)
	require.Error(t, err)
	require.Equal(t, "bind \"val\" from query: val is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})

	c.Request().URI().SetQueryString("val=3&next.value=2")
	n = new(Node)
	n.Next = new(Node)
	require.NoError(t, c.Bind().Query(n))
	require.Equal(t, 3, n.Value)
	require.Equal(t, 0, n.Next.Value)

	type Person struct {
		Name string `query:"name"`
		Age  int    `query:"age"`
	}

	type CollectionQuery struct {
		Data []Person `query:"data"`
	}

	c.Request().URI().SetQueryString("data[0][name]=john&data[0][age]=10&data[1][name]=doe&data[1][age]=12")
	cq := new(CollectionQuery)
	require.NoError(t, c.Bind().Query(cq))
	require.Len(t, cq.Data, 2)
	require.Equal(t, "john", cq.Data[0].Name)
	require.Equal(t, 10, cq.Data[0].Age)
	require.Equal(t, "doe", cq.Data[1].Name)
	require.Equal(t, 12, cq.Data[1].Age)

	c.Request().URI().SetQueryString("data.0.name=john&data.0.age=10&data.1.name=doe&data.1.age=12")
	cq = new(CollectionQuery)
	require.NoError(t, c.Bind().Query(cq))
	require.Len(t, cq.Data, 2)
	require.Equal(t, "john", cq.Data[0].Name)
	require.Equal(t, 10, cq.Data[0].Age)
	require.Equal(t, "doe", cq.Data[1].Name)
	require.Equal(t, 12, cq.Data[1].Age)
}

// go test -run Test_Bind_Header -v
func Test_Bind_Header(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Header struct {
		Name  string
		Hobby []string
		ID    int
	}
	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")

	c.Request().Header.Add("id", "1")
	c.Request().Header.Add("Name", "John Doe")
	c.Request().Header.Add("Hobby", "golang,fiber")
	q := new(Header)
	require.NoError(t, c.Bind().Header(q))
	require.Len(t, q.Hobby, 1)

	c.Request().Header.Del("hobby")
	c.Request().Header.Add("Hobby", "golang,fiber,go")
	q = new(Header)
	require.NoError(t, c.Bind().Header(q))
	require.Len(t, q.Hobby, 1)

	empty := new(Header)
	c.Request().Header.Del("hobby")
	// Bind the request headers (not the query string) so this really checks
	// that a removed Hobby header leaves the slice empty.
	require.NoError(t, c.Bind().Header(empty))
	require.Empty(t, empty.Hobby)

	type Header2 struct {
		Name            string
		Hobby           string
		FavouriteDrinks []string
		Empty           []string
		Alloc           []string
		No              []int64
		ID              int
		Bool            bool
	}

	c.Request().Header.Add("id", "2")
	c.Request().Header.Add("Name", "Jane Doe")
	c.Request().Header.Del("hobby")
	c.Request().Header.Add("Hobby", "go,fiber")
	c.Request().Header.Add("favouriteDrinks", "milo,coke,pepsi")
	c.Request().Header.Add("alloc", "")
	c.Request().Header.Add("no", "1")

	h2 := new(Header2)
	h2.Bool = true
	h2.Name = helloWorld
	require.NoError(t, c.Bind().Header(h2))
	require.Equal(t, "go,fiber", h2.Hobby)
	require.True(t, h2.Bool)
	require.Equal(t, "Jane Doe", h2.Name) // check value get overwritten
	require.Equal(t, []string{"milo,coke,pepsi"}, h2.FavouriteDrinks)
	var nilSlice []string
	require.Equal(t, nilSlice, h2.Empty)
	require.Equal(t, []string{""}, h2.Alloc)
	require.Equal(t, []int64{1}, h2.No)

	type RequiredHeader struct {
		Name string `header:"name,required"`
	}
	rh := new(RequiredHeader)
	c.Request().Header.Del("name")
	err := c.Bind().Header(rh)
	require.Error(t, err)
	require.Equal(t, "bind \"name\" from header: name is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})
}

// go test -run Test_Bind_Header_Map -v
func Test_Bind_Header_Map(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")

	c.Request().Header.Add("id", "1")
	c.Request().Header.Add("Name", "John Doe")
	c.Request().Header.Add("Hobby", "golang,fiber")
	q := make(map[string][]string, 0)
	require.NoError(t, c.Bind().Header(&q))
	require.Len(t, q["Hobby"], 1)

	c.Request().Header.Del("hobby")
	c.Request().Header.Add("Hobby", "golang,fiber,go")
	q = make(map[string][]string, 0)
	require.NoError(t, c.Bind().Header(&q))
	require.Len(t, q["Hobby"], 1)

	empty := make(map[string][]string, 0)
	c.Request().Header.Del("hobby")
	// Use the header binder so the missing Hobby header is what gets checked.
	require.NoError(t, c.Bind().Header(&empty))
	require.Empty(t, empty["Hobby"])
}

// go test -run Test_Bind_Header_WithSetParserDecoder -v
func Test_Bind_Header_WithSetParserDecoder(t *testing.T) {
	type NonRFCTime time.Time

	nonRFCConverter := func(value string) reflect.Value {
		if v, err := time.Parse("2006-01-02", value); err == nil {
			return reflect.ValueOf(v)
		}
		return reflect.Value{}
	}

	nonRFCTime := binder.ParserType{
		CustomType: NonRFCTime{},
		Converter:  nonRFCConverter,
	}

	binder.SetParserDecoder(binder.ParserConfig{
		IgnoreUnknownKeys: true,
		ParserType:        []binder.ParserType{nonRFCTime},
		ZeroEmpty:         true,
		SetAliasTag:       "req",
	})

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type NonRFCTimeInput struct {
		Date  NonRFCTime `req:"date"`
		Title string     `req:"title"`
		Body  string     `req:"body"`
	}

	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")
	r := new(NonRFCTimeInput)

	c.Request().Header.Add("Date", "2021-04-10")
	c.Request().Header.Add("Title", "CustomDateTest")
	c.Request().Header.Add("Body", "October")

	require.NoError(t, c.Bind().Header(r))
	require.Equal(t, "CustomDateTest", r.Title)
	// NonRFCTime has no String method, so %v prints the raw time.Time fields;
	// the nil *Location pointer is rendered by fmt as <nil>.
	date := fmt.Sprintf("%v", r.Date)
	require.Equal(t, "{0 63753609600 <nil>}", date)
	require.Equal(t, "October", r.Body)

	c.Request().Header.Add("Title", "")
	r = &NonRFCTimeInput{
		Title: "Existing title",
		Body:  "Existing Body",
	}
	require.NoError(t, c.Bind().Header(r))
	require.Empty(t, r.Title)
}

// go test -run Test_Bind_Header_Schema -v
func Test_Bind_Header_Schema(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Header1 struct {
		Name   string `header:"Name,required"`
		Nested struct {
			Age int `header:"Age"`
		} `header:"Nested,required"`
	}
	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")

	c.Request().Header.Add("Name", "tom")
	c.Request().Header.Add("Nested.Age", "10")
	q := new(Header1)
	require.NoError(t, c.Bind().Header(q))

	c.Request().Header.Del("Name")
	q = new(Header1)
	err := c.Bind().Header(q)
	require.Error(t, err)
	require.Equal(t, "bind \"Name\" from header: Name is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})

	c.Request().Header.Add("Name", "tom")
	c.Request().Header.Del("Nested.Age")
	c.Request().Header.Add("Nested.Agex", "10")
	q = new(Header1)
	require.NoError(t, c.Bind().Header(q))

	c.Request().Header.Del("Nested.Agex")
	q = new(Header1)
	err = c.Bind().Header(q)
	require.Error(t, err)
	require.Equal(t, "bind \"Nested\" from header: Nested is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})

	c.Request().Header.Del("Nested.Agex")
	c.Request().Header.Del("Name")

	type Header2 struct {
		Name   string `header:"Name"`
		Nested struct {
			Age int `header:"age,required"`
		} `header:"Nested"`
	}

	c.Request().Header.Add("Name", "tom")
	c.Request().Header.Add("Nested.Age", "10")

	h2 := new(Header2)
	require.NoError(t, c.Bind().Header(h2))

	c.Request().Header.Del("Name")
	h2 = new(Header2)
	require.NoError(t, c.Bind().Header(h2))

	c.Request().Header.Del("Name")
	c.Request().Header.Del("Nested.Age")
	c.Request().Header.Add("Nested.Agex", "10")
	h2 = new(Header2)
	err = c.Bind().Header(h2)
	require.Error(t, err)
	require.Equal(t, "bind \"Nested.age\" from header: Nested.age is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})

	type Node struct {
		Next  *Node `header:"Next,required"`
		Value int   `header:"Val,required"`
	}
	c.Request().Header.Add("Val", "1")
	c.Request().Header.Add("Next.Val", "3")
	n := new(Node)
	require.NoError(t, c.Bind().Header(n))
	require.Equal(t, 1, n.Value)
	require.Equal(t, 3, n.Next.Value)

	c.Request().Header.Del("Val")
	n = new(Node)
	err = c.Bind().Header(n)
	require.Error(t, err)
	require.Equal(t, "bind \"Val\" from header: Val is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})

	c.Request().Header.Add("Val", "3")
	c.Request().Header.Del("Next.Val")
	c.Request().Header.Add("Next.Value", "2")
	n = new(Node)
	n.Next = new(Node)
	require.NoError(t, c.Bind().Header(n))
	require.Equal(t, 3, n.Value)
	require.Equal(t, 0, n.Next.Value)
}

// go test -run Test_Bind_RespHeader -v
func Test_Bind_RespHeader(t *testing.T) {
	t.Parallel()
	app := New(Config{
		EnableSplittingOnParsers: true,
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Header struct {
		Name  string
		Hobby []string
		ID    int
	}
	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")

	c.Response().Header.Add("id", "1")
	c.Response().Header.Add("Name", "John Doe")
	c.Response().Header.Add("Hobby", "golang,fiber")
	q := new(Header)
	require.NoError(t, c.Bind().RespHeader(q))
	require.Len(t, q.Hobby, 2)

	c.Response().Header.Del("hobby")
	c.Response().Header.Add("Hobby", "golang,fiber,go")
	q = new(Header)
	require.NoError(t, c.Bind().RespHeader(q))
	require.Len(t, q.Hobby, 3)

	empty := new(Header)
	c.Response().Header.Del("hobby")
	// Bind the response headers (not the query string) so this really checks
	// that a removed Hobby header leaves the slice empty.
	require.NoError(t, c.Bind().RespHeader(empty))
	require.Empty(t, empty.Hobby)

	type Header2 struct {
		Name            string
		Hobby           string
		FavouriteDrinks []string
		Empty           []string
		Alloc           []string
		No              []int64
		ID              int
		Bool            bool
	}

	c.Response().Header.Add("id", "2")
	c.Response().Header.Add("Name", "Jane Doe")
	c.Response().Header.Del("hobby")
	c.Response().Header.Add("Hobby", "go,fiber")
	c.Response().Header.Add("favouriteDrinks", "milo,coke,pepsi")
	c.Response().Header.Add("alloc", "")
	c.Response().Header.Add("no", "1")

	h2 := new(Header2)
	h2.Bool = true
	h2.Name = helloWorld
	require.NoError(t, c.Bind().RespHeader(h2))
	require.Equal(t, "go,fiber", h2.Hobby)
	require.True(t, h2.Bool)
	require.Equal(t, "Jane Doe", h2.Name) // check value get overwritten
	require.Equal(t, []string{"milo", "coke", "pepsi"}, h2.FavouriteDrinks)
	var nilSlice []string
	require.Equal(t, nilSlice, h2.Empty)
	require.Equal(t, []string{""}, h2.Alloc)
	require.Equal(t, []int64{1}, h2.No)

	type RequiredHeader struct {
		Name string `respHeader:"name,required"`
	}
	rh := new(RequiredHeader)
	c.Response().Header.Del("name")
	err := c.Bind().RespHeader(rh)
	require.Error(t, err)
	require.Equal(t, "bind \"name\" from respHeader: name is empty", err.Error())
	require.ErrorAs(t, err, &MultiError{})
}

// go test -run Test_Bind_RespHeader_Map -v
func Test_Bind_RespHeader_Map(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")

	c.Response().Header.Add("id", "1")
	c.Response().Header.Add("Name", "John Doe")
	c.Response().Header.Add("Hobby", "golang,fiber")
	q := make(map[string][]string, 0)
	require.NoError(t, c.Bind().RespHeader(&q))
	require.Len(t, q["Hobby"], 1)

	c.Response().Header.Del("hobby")
	c.Response().Header.Add("Hobby", "golang,fiber,go")
	q = make(map[string][]string, 0)
	require.NoError(t, c.Bind().RespHeader(&q))
	require.Len(t, q["Hobby"], 1)

	empty := make(map[string][]string, 0)
	c.Response().Header.Del("hobby")
	// Use the response-header binder so the missing Hobby header is what gets checked.
	require.NoError(t, c.Bind().RespHeader(&empty))
	require.Empty(t, empty["Hobby"])
}

// go test -v -run=^$ -bench=Benchmark_Bind_Query -benchmem -count=4
func Benchmark_Bind_Query(b *testing.B) {
	var err error

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Query struct {
		Name  string
		Hobby []string
		ID    int
	}
	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")
	c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball&hobby=football")
	q := new(Query)
	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Query(q)
	}

	require.NoError(b, err)
	require.Equal(b, "tom", q.Name)
	require.Equal(b, 1, q.ID)
	require.Len(b, q.Hobby, 2)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Query_Default -benchmem -count=4
func Benchmark_Bind_Query_Default(b *testing.B) {
	var err error

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Query struct {
		Name  string   `query:"name,default:tom"`
		Hobby []string `query:"hobby,default:football|basketball"`
		ID    int      `query:"id,default:1"`
	}
	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")
	c.Request().URI().SetQueryString("")
	q := new(Query)
	b.ReportAllocs()

	for b.Loop() {
		*q = Query{}
		err = c.Bind().Query(q)
	}

	require.NoError(b, err)
	require.Equal(b, "tom", q.Name)
	require.Equal(b, 1, q.ID)
	require.Len(b, q.Hobby, 2)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Query_Map -benchmem -count=4
func Benchmark_Bind_Query_Map(b *testing.B) {
	var err error

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")
	c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball&hobby=football")
	q := make(map[string][]string)
	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Query(&q)
	}

	require.NoError(b, err)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Query_WithParseParam -benchmem -count=4
func Benchmark_Bind_Query_WithParseParam(b *testing.B) {
	var err error

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Person struct {
		Name string `query:"name"`
		Age  int    `query:"age"`
	}

	type CollectionQuery struct {
		Data []Person `query:"data"`
	}

	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")
	c.Request().URI().SetQueryString("data[0][name]=john&data[0][age]=10")
	cq := new(CollectionQuery)

	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Query(cq)
	}

	require.NoError(b, err)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Query_Comma -benchmem -count=4
func Benchmark_Bind_Query_Comma(b *testing.B) {
	var err error

	app := New(Config{
		EnableSplittingOnParsers: true,
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Query struct {
		Name  string
		Hobby []string
		ID    int
	}
	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")
	c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball,football")
	q := new(Query)
	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Query(q)
	}

	require.NoError(b, err)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Header -benchmem -count=4
func Benchmark_Bind_Header(b *testing.B) {
	var err error

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type ReqHeader struct {
		Name  string
		Hobby []string
		ID    int
	}
	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")

	c.Request().Header.Add("id", "1")
	c.Request().Header.Add("Name", "John Doe")
	c.Request().Header.Add("Hobby", "golang,fiber")

	q := new(ReqHeader)
	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Header(q)
	}

	require.NoError(b, err)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Header_Map -benchmem -count=4
func Benchmark_Bind_Header_Map(b *testing.B) {
	var err error
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")

	c.Request().Header.Add("id", "1")
	c.Request().Header.Add("Name", "John Doe")
	c.Request().Header.Add("Hobby", "golang,fiber")

	q := make(map[string][]string)
	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Header(&q)
	}

	require.NoError(b, err)
}

// go test -v -run=^$ -bench=Benchmark_Bind_RespHeader -benchmem -count=4
func Benchmark_Bind_RespHeader(b *testing.B) {
	var err error

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type ReqHeader struct {
		Name  string
		Hobby []string
		ID    int
	}
	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")

	c.Response().Header.Add("id", "1")
	c.Response().Header.Add("Name", "John Doe")
	c.Response().Header.Add("Hobby", "golang,fiber")

	q := new(ReqHeader)
	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().RespHeader(q)
	}

	require.NoError(b, err)
}

// go test -v -run=^$ -bench=Benchmark_Bind_RespHeader_Map -benchmem -count=4
func Benchmark_Bind_RespHeader_Map(b *testing.B) {
	var err error
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Request().SetBody([]byte(``))
	c.Request().Header.SetContentType("")

	c.Response().Header.Add("id", "1")
	c.Response().Header.Add("Name", "John Doe")
	c.Response().Header.Add("Hobby", "golang,fiber")

	q := make(map[string][]string)
	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().RespHeader(&q)
	}

	require.NoError(b, err)
}

// go test -run Test_Bind_Body -v
func Test_Bind_Body(t *testing.T) {
	t.Parallel()
	app := New(Config{
		MsgPackEncoder: msgpack.Marshal,
		MsgPackDecoder: msgpack.Unmarshal,
		CBOREncoder:    cbor.Marshal,
		CBORDecoder:    cbor.Unmarshal,
	})
	reqBody := []byte(`{"name":"john"}`)

	type Demo struct {
		Name  string   `json:"name" xml:"name" form:"name" query:"name" msgpack:"name"`
		Names []string `json:"names" xml:"names" form:"names" query:"names" msgpack:"names"`
	}

	// Helper function to test compressed bodies
	testCompressedBody := func(t *testing.T, compressedBody []byte, encoding string) {
		t.Helper()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().Header.SetContentType(MIMEApplicationJSON)
		c.Request().Header.Set(fasthttp.HeaderContentEncoding, encoding)
		c.Request().SetBody(compressedBody)
		c.Request().Header.SetContentLength(len(compressedBody))
		d := new(Demo)
		require.NoError(t, c.Bind().Body(d))
		require.Equal(t, "john", d.Name)
		c.Request().Header.Del(fasthttp.HeaderContentEncoding)
	}

	t.Run("Gzip", func(t *testing.T) {
		t.Parallel()
		compressedBody := fasthttp.AppendGzipBytes(nil, reqBody)
		require.NotEqual(t, reqBody, compressedBody)
		testCompressedBody(t, compressedBody, "gzip")
	})

	t.Run("Deflate", func(t *testing.T) {
		t.Parallel()
		compressedBody := fasthttp.AppendDeflateBytes(nil, reqBody)
		require.NotEqual(t, reqBody, compressedBody)
		testCompressedBody(t, compressedBody, "deflate")
	})

	t.Run("Brotli", func(t *testing.T) {
		t.Parallel()
		compressedBody := fasthttp.AppendBrotliBytes(nil, reqBody)
		require.NotEqual(t, reqBody, compressedBody)
		testCompressedBody(t, compressedBody, "br")
	})

	t.Run("Zstd", func(t *testing.T) {
		t.Parallel()
		compressedBody := fasthttp.AppendZstdBytes(nil, reqBody)
		require.NotEqual(t, reqBody, compressedBody)
		testCompressedBody(t, compressedBody, "zstd")
	})

	testDecodeParser := func(t *testing.T, contentType string, body []byte) {
		t.Helper()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().Header.SetContentType(contentType)
		c.Request().SetBody(body)
		c.Request().Header.SetContentLength(len(body))
		d := new(Demo)
		require.NoError(t, c.Bind().Body(d))
		require.Equal(t, "john", d.Name)
	}

	testErrorParser := func(t *testing.T, contentType string, body []byte) {
		t.Helper()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().Header.SetContentType(contentType)
		c.Request().SetBody(body)
		c.Request().Header.SetContentLength(len(body))
		d := new(Demo)
		err := c.Bind().Body(d)
		require.Error(t, err)
	}

	t.Run("JSON", func(t *testing.T) {
		testDecodeParser(t, MIMEApplicationJSON, []byte(`{"name":"john"}`))
	})

	t.Run("MsgPack", func(t *testing.T) {
		testDecodeParser(t, MIMEApplicationMsgPack, []byte{0x81, 0xa4, 0x6e, 0x61, 0x6d, 0x65, 0xa4, 0x6a, 0x6f, 0x68, 0x6e})
		testErrorParser(t, MIMEApplicationMsgPack, []byte{0xFF, 0xFF})
	})

	t.Run("CBOR", func(t *testing.T) {
		enc, err := cbor.Marshal(&Demo{Name: "john"})
		// Stop here on failure instead of decoding a nil payload.
		require.NoError(t, err)
		testDecodeParser(t, MIMEApplicationCBOR, enc)

		// Test invalid CBOR data
		t.Run("Invalid", func(t *testing.T) {
			invalidData := []byte{0xFF, 0xFF} // Invalid CBOR data
			c := app.AcquireCtx(&fasthttp.RequestCtx{})
			c.Request().Header.SetContentType(MIMEApplicationCBOR)
			c.Request().SetBody(invalidData)
			d := new(Demo)
			require.Error(t, c.Bind().Body(d))
		})
	})

	t.Run("XML", func(t *testing.T) {
		// xml.Unmarshal needs an enclosing element; a bare "john" has no start element.
		testDecodeParser(t, MIMEApplicationXML, []byte(`<Demo><name>john</name></Demo>`))
	})

	t.Run("Form", func(t *testing.T) {
		testDecodeParser(t, MIMEApplicationForm, []byte("name=john"))
	})

	t.Run("MultipartForm", func(t *testing.T) {
		testDecodeParser(t, MIMEMultipartForm+`;boundary="b"`, []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--"))
	})

	testDecodeParserError := func(t *testing.T, contentType, body string) {
		t.Helper()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().Header.SetContentType(contentType)
		c.Request().SetBody([]byte(body))
		c.Request().Header.SetContentLength(len(body))
		require.Error(t, c.Bind().Body(nil))
	}

	t.Run("ErrorInvalidContentType", func(t *testing.T) {
		testDecodeParserError(t, "invalid-content-type", "")
	})

	t.Run("ErrorUnknownContentTypeReturnsUnprocessableEntity", func(t *testing.T) {
		t.Parallel()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().Header.SetContentType("application/unknown-type")
		c.Request().SetBody([]byte("some body"))
		c.Request().Header.SetContentLength(9)
		d := new(Demo)
		err := c.Bind().Body(d)
		require.Error(t, err)
		require.ErrorIs(t, err, ErrUnprocessableEntity)
	})

	t.Run("ErrorMalformedMultipart", func(t *testing.T) {
		testDecodeParserError(t, MIMEMultipartForm+`;boundary="b"`, "--b")
	})

	type CollectionQuery struct {
		Data []Demo `query:"data"`
	}

	t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) {
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().Reset()

		buf := &bytes.Buffer{}
		writer := multipart.NewWriter(buf)
		require.NoError(t, writer.WriteField("data.0.name", "john"))
		require.NoError(t, writer.WriteField("data.1.name", "doe"))
		require.NoError(t, writer.Close())

		c.Request().Header.SetContentType(writer.FormDataContentType())
		c.Request().SetBody(buf.Bytes())
		c.Request().Header.SetContentLength(len(c.Body()))

		cq := new(CollectionQuery)
		require.NoError(t, c.Bind().Body(cq))
		require.Len(t, cq.Data, 2)
		require.Equal(t, "john", cq.Data[0].Name)
		require.Equal(t, "doe", cq.Data[1].Name)
	})

	t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) {
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().Reset()

		buf := &bytes.Buffer{}
		writer := multipart.NewWriter(buf)
		require.NoError(t, writer.WriteField("data[0][name]", "john"))
		require.NoError(t, writer.WriteField("data[1][name]", "doe"))
		require.NoError(t, writer.Close())

		c.Request().Header.SetContentType(writer.FormDataContentType())
		c.Request().SetBody(buf.Bytes())
		c.Request().Header.SetContentLength(len(c.Body()))

		cq := new(CollectionQuery)
		require.NoError(t, c.Bind().Body(cq))
		require.Len(t, cq.Data, 2)
		require.Equal(t, "john", cq.Data[0].Name)
		require.Equal(t, "doe", cq.Data[1].Name)
	})

	t.Run("CollectionQuerySquareBrackets", func(t *testing.T) {
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().Reset()

		c.Request().Header.SetContentType(MIMEApplicationForm)
		c.Request().SetBody([]byte("data[0][name]=john&data[1][name]=doe"))
		c.Request().Header.SetContentLength(len(c.Body()))
		cq := new(CollectionQuery)
		require.NoError(t, c.Bind().Body(cq))
		require.Len(t, cq.Data, 2)
		require.Equal(t, "john", cq.Data[0].Name)
		require.Equal(t, "doe", cq.Data[1].Name)
	})

	t.Run("CollectionQueryDotNotation", func(t *testing.T) {
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().Reset()

		c.Request().Header.SetContentType(MIMEApplicationForm)
		c.Request().SetBody([]byte("data.0.name=john&data.1.name=doe"))
		c.Request().Header.SetContentLength(len(c.Body()))
		cq := new(CollectionQuery)
		require.NoError(t, c.Bind().Body(cq))
		require.Len(t, cq.Data, 2)
		require.Equal(t, "john", cq.Data[0].Name)
		require.Equal(t, "doe", cq.Data[1].Name)
	})
}

// go test -run Test_Bind_Body_WithSetParserDecoder
func Test_Bind_Body_WithSetParserDecoder(t *testing.T) {
	type CustomTime time.Time

	timeConverter := func(value string) reflect.Value {
		if v, err := time.Parse("2006-01-02", value); err == nil {
			return reflect.ValueOf(v)
		}
		return reflect.Value{}
	}

	customTime := binder.ParserType{
		CustomType: CustomTime{},
		Converter:  timeConverter,
	}

	binder.SetParserDecoder(binder.ParserConfig{
		IgnoreUnknownKeys: true,
		ParserType:        []binder.ParserType{customTime},
		ZeroEmpty:         true,
		SetAliasTag:       "form",
	})

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Demo struct {
		Date  CustomTime `form:"date"`
		Title string     `form:"title"`
		Body  string     `form:"body"`
	}

	testDecodeParser := func(contentType, body string) {
		c.Request().Header.SetContentType(contentType)
		c.Request().SetBody([]byte(body))
		c.Request().Header.SetContentLength(len(body))
		d := Demo{
			Title: "Existing title",
			Body:  "Existing Body",
		}
		require.NoError(t, c.Bind().Body(&d))
		// CustomTime has no String method, so %v prints the raw time.Time fields;
		// the nil *Location pointer is rendered by fmt as <nil>.
		date := fmt.Sprintf("%v", d.Date)
		require.Equal(t, "{0 63743587200 <nil>}", date)
		require.Empty(t, d.Title)
		require.Equal(t, "New Body", d.Body)
	}

	testDecodeParser(MIMEApplicationForm, "date=2020-12-15&title=&body=New Body")
	testDecodeParser(MIMEMultipartForm+`; boundary="b"`, "--b\r\nContent-Disposition: form-data; name=\"date\"\r\n\r\n2020-12-15\r\n--b\r\nContent-Disposition: form-data; name=\"title\"\r\n\r\n\r\n--b\r\nContent-Disposition: form-data; name=\"body\"\r\n\r\nNew Body\r\n--b--")
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_JSON -benchmem -count=4
func Benchmark_Bind_Body_JSON(b *testing.B) {
	var err error

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Demo struct {
		Name string `json:"name"`
	}
	body, err := json.Marshal(&Demo{Name: "john"})
	// Abort setup on failure rather than benchmarking an empty body.
	require.NoError(b, err)
	c.Request().SetBody(body)
	c.Request().Header.SetContentType(MIMEApplicationJSON)
	c.Request().Header.SetContentLength(len(body))
	d := new(Demo)

	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Body(d)
	}
	require.NoError(b, err)
	require.Equal(b, "john", d.Name)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_MsgPack -benchmem -count=4
func Benchmark_Bind_Body_MsgPack(b *testing.B) {
	var err error

	app := New(
		Config{
			MsgPackEncoder: msgpack.Marshal,
			MsgPackDecoder: msgpack.Unmarshal,
			CBOREncoder:    cbor.Marshal,
			CBORDecoder:    cbor.Unmarshal,
		},
	)
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Demo struct {
		Name string `msgpack:"name"`
	}
	body := []byte{0x81, 0xa4, 0x6e, 0x61, 0x6d, 0x65, 0xa4, 0x6a, 0x6f, 0x68, 0x6e} // {"name":"john"}
	c.Request().SetBody(body)
	c.Request().Header.SetContentType(MIMEApplicationMsgPack)
	c.Request().Header.SetContentLength(len(body))
	d := new(Demo)

	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Body(d)
	}
	require.NoError(b, err)
	require.Equal(b, "john", d.Name)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_XML -benchmem -count=4
func Benchmark_Bind_Body_XML(b *testing.B) {
	var err error

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Demo struct {
		Name string `xml:"name"`
	}
	// xml.Unmarshal needs an enclosing element around <name>.
	body := []byte("<Demo><name>john</name></Demo>")
	c.Request().SetBody(body)
	c.Request().Header.SetContentType(MIMEApplicationXML)
	c.Request().Header.SetContentLength(len(body))
	d := new(Demo)

	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Body(d)
	}
	require.NoError(b, err)
	require.Equal(b, "john", d.Name)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_CBOR -benchmem -count=4
func Benchmark_Bind_Body_CBOR(b *testing.B) {
	var err error

	app := New(Config{
		CBOREncoder: cbor.Marshal,
		CBORDecoder: cbor.Unmarshal,
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Demo struct {
		Name string `json:"name"`
	}
	body, err := cbor.Marshal(&Demo{Name: "john"})
	// Abort setup on failure rather than benchmarking an empty body.
	require.NoError(b, err)
	c.Request().SetBody(body)
	c.Request().Header.SetContentType(MIMEApplicationCBOR)
	c.Request().Header.SetContentLength(len(body))
	d := new(Demo)

	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Body(d)
	}
	require.NoError(b, err)
	require.Equal(b, "john", d.Name)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_Form -benchmem -count=4
func Benchmark_Bind_Body_Form(b *testing.B) {
	var err error

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Demo struct {
		Name string `form:"name"`
	}
	body := []byte("name=john")
	c.Request().SetBody(body)
	c.Request().Header.SetContentType(MIMEApplicationForm)
	c.Request().Header.SetContentLength(len(body))
	d := new(Demo)

	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Body(d)
	}
	require.NoError(b, err)
	require.Equal(b, "john", d.Name)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_MultipartForm -benchmem -count=4
func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
	var err error

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Demo struct {
		Name string `form:"name"`
	}

	buf := &bytes.Buffer{}
	writer := multipart.NewWriter(buf)
	require.NoError(b, writer.WriteField("name", "john"))
	require.NoError(b, writer.Close())
	body := buf.Bytes()

	c.Request().SetBody(body)
	c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
	c.Request().Header.SetContentLength(len(body))
	d := new(Demo)

	b.ReportAllocs()

	for b.Loop() {
		err = c.Bind().Body(d)
	}
	require.NoError(b, err)
	require.Equal(b, "john", d.Name)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_MultipartForm_Nested -benchmem -count=4
func Benchmark_Bind_Body_MultipartForm_Nested(b *testing.B) {
	var err error

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	type Person struct {
		Name string `form:"name"`
		Age  int    `form:"age"`
	}

	type Demo struct {
		Name    string   `form:"name"`
		Persons []Person `form:"persons"`
	}

	buf := &bytes.Buffer{}
	writer := multipart.NewWriter(buf)
	require.NoError(b, writer.WriteField("name", "john"))
	require.NoError(b,
writer.WriteField("persons.0.name", "john")) require.NoError(b, writer.WriteField("persons[0][age]", "10")) require.NoError(b, writer.WriteField("persons[1][name]", "doe")) require.NoError(b, writer.WriteField("persons.1.age", "20")) require.NoError(b, writer.Close()) body := buf.Bytes() c.Request().SetBody(body) c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary()) c.Request().Header.SetContentLength(len(body)) d := new(Demo) b.ReportAllocs() for b.Loop() { err = c.Bind().Body(d) } require.NoError(b, err) require.Equal(b, "john", d.Name) require.Equal(b, "john", d.Persons[0].Name) require.Equal(b, 10, d.Persons[0].Age) require.Equal(b, "doe", d.Persons[1].Name) require.Equal(b, 20, d.Persons[1].Age) } // go test -v -run=^$ -bench=Benchmark_Bind_Body_Form_Map -benchmem -count=4 func Benchmark_Bind_Body_Form_Map(b *testing.B) { var err error app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) body := []byte("name=john") c.Request().SetBody(body) c.Request().Header.SetContentType(MIMEApplicationForm) c.Request().Header.SetContentLength(len(body)) d := make(map[string]string) b.ReportAllocs() for b.Loop() { err = c.Bind().Body(&d) } require.NoError(b, err) require.Equal(b, "john", d["name"]) } // go test -run Test_Bind_URI func Test_Bind_URI(t *testing.T) { t.Parallel() app := New() app.Get("/test1/userId/role/:roleId", func(c Ctx) error { type Demo struct { UserID uint `uri:"userId"` RoleID uint `uri:"roleId"` } d := new(Demo) if err := c.Bind().URI(d); err != nil { t.Fatal(err) } require.Equal(t, uint(111), d.UserID) require.Equal(t, uint(222), d.RoleID) return nil }) _, err := app.Test(httptest.NewRequest(MethodGet, "/test1/111/role/222", http.NoBody)) require.NoError(t, err) _, err = app.Test(httptest.NewRequest(MethodGet, "/test2/111/role/222", http.NoBody)) require.NoError(t, err) } // go test -run Test_Bind_URI_Map func Test_Bind_URI_Map(t *testing.T) { t.Parallel() app := New() app.Get("/test1/userId/role/:roleId", func(c 
Ctx) error { d := make(map[string]string) if err := c.Bind().URI(&d); err != nil { t.Fatal(err) } require.Equal(t, uint(111), d["userId"]) require.Equal(t, uint(222), d["roleId"]) return nil }) _, err := app.Test(httptest.NewRequest(MethodGet, "/test1/111/role/222", http.NoBody)) require.NoError(t, err) _, err = app.Test(httptest.NewRequest(MethodGet, "/test2/111/role/222", http.NoBody)) require.NoError(t, err) } // go test -v -run=^$ -bench=Benchmark_Bind_URI -benchmem -count=4 func Benchmark_Bind_URI(b *testing.B) { var err error app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed c.route = &Route{ Params: []string{ "param1", "param2", "param3", "param4", }, } c.values = [maxParams]string{ "john", "doe", "is", "awesome", } var res struct { Param1 string `uri:"param1"` Param2 string `uri:"param2"` Param3 string `uri:"param3"` Param4 string `uri:"param4"` } b.ReportAllocs() for b.Loop() { err = c.Bind().URI(&res) } require.NoError(b, err) require.Equal(b, "john", res.Param1) require.Equal(b, "doe", res.Param2) require.Equal(b, "is", res.Param3) require.Equal(b, "awesome", res.Param4) } // go test -v -run=^$ -bench=Benchmark_Bind_URI_Map -benchmem -count=4 func Benchmark_Bind_URI_Map(b *testing.B) { var err error app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed c.route = &Route{ Params: []string{ "param1", "param2", "param3", "param4", }, } c.values = [maxParams]string{ "john", "doe", "is", "awesome", } res := make(map[string]string) b.ReportAllocs() for b.Loop() { err = c.Bind().URI(&res) } require.NoError(b, err) require.Equal(b, "john", res["param1"]) require.Equal(b, "doe", res["param2"]) require.Equal(b, "is", res["param3"]) require.Equal(b, "awesome", res["param4"]) } // go test -run Test_Bind_Cookie -v func Test_Bind_Cookie(t *testing.T) { t.Parallel() app := New(Config{ EnableSplittingOnParsers: true, }) c := 
app.AcquireCtx(&fasthttp.RequestCtx{}) type Cookie struct { Name string Hobby []string ID int } c.Request().SetBody([]byte(``)) c.Request().Header.SetContentType("") c.Request().Header.SetCookie("id", "1") c.Request().Header.SetCookie("Name", "John Doe") c.Request().Header.SetCookie("Hobby", "golang,fiber") q := new(Cookie) require.NoError(t, c.Bind().Cookie(q)) require.Len(t, q.Hobby, 2) c.Request().Header.DelCookie("hobby") c.Request().Header.SetCookie("Hobby", "golang,fiber,go") q = new(Cookie) require.NoError(t, c.Bind().Cookie(q)) require.Len(t, q.Hobby, 3) empty := new(Cookie) c.Request().Header.DelCookie("hobby") require.NoError(t, c.Bind().Query(empty)) require.Empty(t, empty.Hobby) type Cookie2 struct { Name string Hobby string FavouriteDrinks []string Empty []string Alloc []string No []int64 ID int Bool bool } c.Request().Header.SetCookie("id", "2") c.Request().Header.SetCookie("Name", "Jane Doe") c.Request().Header.DelCookie("hobby") c.Request().Header.SetCookie("Hobby", "go,fiber") c.Request().Header.SetCookie("favouriteDrinks", "milo,coke,pepsi") c.Request().Header.SetCookie("alloc", "") c.Request().Header.SetCookie("no", "1") h2 := new(Cookie2) h2.Bool = true h2.Name = helloWorld require.NoError(t, c.Bind().Cookie(h2)) require.Equal(t, "go,fiber", h2.Hobby) require.True(t, h2.Bool) require.Equal(t, "Jane Doe", h2.Name) // check value get overwritten require.Equal(t, []string{"milo", "coke", "pepsi"}, h2.FavouriteDrinks) var nilSlice []string require.Equal(t, nilSlice, h2.Empty) require.Equal(t, []string{""}, h2.Alloc) require.Equal(t, []int64{1}, h2.No) type RequiredCookie struct { Name string `cookie:"name,required"` } rh := new(RequiredCookie) c.Request().Header.DelCookie("name") err := c.Bind().Cookie(rh) require.Error(t, err) require.Equal(t, "bind \"name\" from cookie: name is empty", err.Error()) require.ErrorAs(t, err, &MultiError{}) } // go test -run Test_Bind_Cookie_Map -v func Test_Bind_Cookie_Map(t *testing.T) { t.Parallel() app := 
New(Config{ EnableSplittingOnParsers: true, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetBody([]byte(``)) c.Request().Header.SetContentType("") c.Request().Header.SetCookie("id", "1") c.Request().Header.SetCookie("Name", "John Doe") c.Request().Header.SetCookie("Hobby", "golang,fiber") q := make(map[string][]string) require.NoError(t, c.Bind().Cookie(&q)) require.Len(t, q["Hobby"], 2) c.Request().Header.DelCookie("hobby") c.Request().Header.SetCookie("Hobby", "golang,fiber,go") q = make(map[string][]string) require.NoError(t, c.Bind().Cookie(&q)) require.Len(t, q["Hobby"], 3) empty := make(map[string][]string) c.Request().Header.DelCookie("hobby") require.NoError(t, c.Bind().Query(&empty)) require.Empty(t, empty["Hobby"]) } // go test -run Test_Bind_Cookie_WithSetParserDecoder -v func Test_Bind_Cookie_WithSetParserDecoder(t *testing.T) { type NonRFCTime time.Time nonRFCConverter := func(value string) reflect.Value { if v, err := time.Parse("2006-01-02", value); err == nil { return reflect.ValueOf(v) } return reflect.Value{} } nonRFCTime := binder.ParserType{ CustomType: NonRFCTime{}, Converter: nonRFCConverter, } binder.SetParserDecoder(binder.ParserConfig{ IgnoreUnknownKeys: true, ParserType: []binder.ParserType{nonRFCTime}, ZeroEmpty: true, SetAliasTag: "cerez", }) app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) type NonRFCTimeInput struct { Date NonRFCTime `cerez:"date"` Title string `cerez:"title"` Body string `cerez:"body"` } c.Request().SetBody([]byte(``)) c.Request().Header.SetContentType("") r := new(NonRFCTimeInput) c.Request().Header.SetCookie("Date", "2021-04-10") c.Request().Header.SetCookie("Title", "CustomDateTest") c.Request().Header.SetCookie("Body", "October") require.NoError(t, c.Bind().Cookie(r)) require.Equal(t, "CustomDateTest", r.Title) date := fmt.Sprintf("%v", r.Date) require.Equal(t, "{0 63753609600 }", date) require.Equal(t, "October", r.Body) c.Request().Header.SetCookie("Title", "") r = &NonRFCTimeInput{ Title: 
"Existing title", Body: "Existing Body", } require.NoError(t, c.Bind().Cookie(r)) require.Empty(t, r.Title) } // go test -run Test_Bind_Cookie_Schema -v func Test_Bind_Cookie_Schema(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) type Cookie1 struct { Name string `cookie:"Name,required"` Nested struct { Age int `cookie:"Age"` } `cookie:"Nested,required"` } c.Request().SetBody([]byte(``)) c.Request().Header.SetContentType("") c.Request().Header.SetCookie("Name", "tom") c.Request().Header.SetCookie("Nested.Age", "10") q := new(Cookie1) require.NoError(t, c.Bind().Cookie(q)) c.Request().Header.DelCookie("Name") q = new(Cookie1) err := c.Bind().Cookie(q) require.Error(t, err) require.Equal(t, "bind \"Name\" from cookie: Name is empty", err.Error()) require.ErrorAs(t, err, &MultiError{}) c.Request().Header.SetCookie("Name", "tom") c.Request().Header.DelCookie("Nested.Age") c.Request().Header.SetCookie("Nested.Agex", "10") q = new(Cookie1) require.NoError(t, c.Bind().Cookie(q)) c.Request().Header.DelCookie("Nested.Agex") q = new(Cookie1) err = c.Bind().Cookie(q) require.Error(t, err) require.Equal(t, "bind \"Nested\" from cookie: Nested is empty", err.Error()) require.ErrorAs(t, err, &MultiError{}) c.Request().Header.DelCookie("Nested.Agex") c.Request().Header.DelCookie("Name") type Cookie2 struct { Name string `cookie:"Name"` Nested struct { Age int `cookie:"Age,required"` } `cookie:"Nested"` } c.Request().Header.SetCookie("Name", "tom") c.Request().Header.SetCookie("Nested.Age", "10") h2 := new(Cookie2) require.NoError(t, c.Bind().Cookie(h2)) c.Request().Header.DelCookie("Name") h2 = new(Cookie2) require.NoError(t, c.Bind().Cookie(h2)) c.Request().Header.DelCookie("Name") c.Request().Header.DelCookie("Nested.Age") c.Request().Header.SetCookie("Nested.Agex", "10") h2 = new(Cookie2) err = c.Bind().Cookie(h2) require.Error(t, err) require.Equal(t, "bind \"Nested.Age\" from cookie: Nested.Age is empty", err.Error()) require.ErrorAs(t, 
err, &MultiError{}) type Node struct { Next *Node `cookie:"Next,required"` Value int `cookie:"Val,required"` } c.Request().Header.SetCookie("Val", "1") c.Request().Header.SetCookie("Next.Val", "3") n := new(Node) require.NoError(t, c.Bind().Cookie(n)) require.Equal(t, 1, n.Value) require.Equal(t, 3, n.Next.Value) c.Request().Header.DelCookie("Val") n = new(Node) err = c.Bind().Cookie(n) require.Error(t, err) require.Equal(t, "bind \"Val\" from cookie: Val is empty", err.Error()) require.ErrorAs(t, err, &MultiError{}) c.Request().Header.SetCookie("Val", "3") c.Request().Header.DelCookie("Next.Val") c.Request().Header.SetCookie("Next.Value", "2") n = new(Node) n.Next = new(Node) require.NoError(t, c.Bind().Cookie(n)) require.Equal(t, 3, n.Value) require.Equal(t, 0, n.Next.Value) } // go test -v -run=^$ -bench=Benchmark_Bind_Cookie -benchmem -count=4 func Benchmark_Bind_Cookie(b *testing.B) { var err error app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) type Cookie struct { Name string Hobby []string ID int } c.Request().SetBody([]byte(``)) c.Request().Header.SetContentType("") c.Request().Header.SetCookie("id", "1") c.Request().Header.SetCookie("Name", "John Doe") c.Request().Header.SetCookie("Hobby", "golang,fiber") q := new(Cookie) b.ReportAllocs() for b.Loop() { err = c.Bind().Cookie(q) } require.NoError(b, err) } // go test -v -run=^$ -bench=Benchmark_Bind_Cookie_Map -benchmem -count=4 func Benchmark_Bind_Cookie_Map(b *testing.B) { var err error app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetBody([]byte(``)) c.Request().Header.SetContentType("") c.Request().Header.SetCookie("id", "1") c.Request().Header.SetCookie("Name", "John Doe") c.Request().Header.SetCookie("Hobby", "golang,fiber") q := make(map[string][]string) b.ReportAllocs() for b.Loop() { err = c.Bind().Cookie(&q) } require.NoError(b, err) } // custom binder for testing type customBinder struct{} func (*customBinder) Name() string { return "custom" } func (*customBinder) 
MIMETypes() []string { return []string{"test", "test2"} } func (*customBinder) Parse(c Ctx, out any) error { return json.Unmarshal(c.Body(), out) } // customBinderReturningError returns a fixed error for testing extractFieldFromError branches. type customBinderReturningError struct { err error mimeType string } func (*customBinderReturningError) Name() string { return "error-binder" } func (b *customBinderReturningError) MIMETypes() []string { if b.mimeType != "" { return []string{b.mimeType} } return []string{"application/x-unknown-key-test", "application/x-empty-field-test"} } func (b *customBinderReturningError) Parse(_ Ctx, _ any) error { return b.err } // go test -run Test_Bind_CustomBinder func Test_Bind_CustomBinder(t *testing.T) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) // Register binder customBinder := &customBinder{} app.RegisterCustomBinder(customBinder) type Demo struct { Name string `json:"name"` } body := []byte(`{"name":"john"}`) c.Request().SetBody(body) c.Request().Header.SetContentType("test") c.Request().Header.SetContentLength(len(body)) d := new(Demo) require.NoError(t, c.Bind().Body(d)) require.NoError(t, c.Bind().Custom("custom", d)) require.Equal(t, ErrCustomBinderNotFound, c.Bind().Custom("not_custom", d)) require.Equal(t, "john", d.Name) } // go test -run Test_Bind_CustomBinder_Source func Test_Bind_CustomBinder_Source(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) app.RegisterCustomBinder(&customBinder{}) type Demo struct { Name string `json:"name"` } c.Request().SetBody([]byte(`{invalid json`)) c.Request().Header.SetContentLength(14) err := c.Bind().Custom("custom", new(Demo)) require.Error(t, err) var be *BindError require.ErrorAs(t, err, &be) require.Equal(t, "custom", be.Source) } // go test -run Test_Bind_CustomBinder_Validation func Test_Bind_CustomBinder_Validation(t *testing.T) { t.Parallel() app := New(Config{StructValidator: 
&structValidator{}}) app.RegisterCustomBinder(&customBinder{}) t.Run("Body_custom_binder_validation_pass", func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) body := []byte(`{"name":"john"}`) c.Request().SetBody(body) c.Request().Header.SetContentType("test") c.Request().Header.SetContentLength(len(body)) out := new(simpleQuery) require.NoError(t, c.Bind().Body(out)) require.Equal(t, "john", out.Name) }) t.Run("Body_custom_binder_validation_fail", func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) body := []byte(`{"name":"invalid"}`) c.Request().SetBody(body) c.Request().Header.SetContentType("test") c.Request().Header.SetContentLength(len(body)) out := new(simpleQuery) require.EqualError(t, c.Bind().Body(out), "you should have entered right name") }) t.Run("Custom_binder_validation_pass", func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) body := []byte(`{"name":"john"}`) c.Request().SetBody(body) c.Request().Header.SetContentLength(len(body)) out := new(simpleQuery) require.NoError(t, c.Bind().Custom("custom", out)) require.Equal(t, "john", out.Name) }) t.Run("Custom_binder_validation_fail", func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) body := []byte(`{"name":"invalid"}`) c.Request().SetBody(body) c.Request().Header.SetContentLength(len(body)) out := new(simpleQuery) require.EqualError(t, c.Bind().Custom("custom", out), "you should have entered right name") }) } // go test -run Test_Bind_WithAutoHandling func Test_Bind_WithAutoHandling(t *testing.T) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) type RequiredQuery struct { Name string `query:"name,required"` } rq := new(RequiredQuery) c.Request().URI().SetQueryString("") err := c.Bind().WithAutoHandling().Query(rq) require.Equal(t, 
StatusBadRequest, c.Response().StatusCode()) require.Equal(t, "Bad request: name is empty", err.Error()) } // simple struct validator for testing type structValidator struct{} func (*structValidator) Validate(out any) error { out = reflect.ValueOf(out).Elem().Interface() sq, ok := out.(simpleQuery) if !ok { return errors.New("failed to type-assert to simpleQuery") } if sq.Name != "john" { return errors.New("you should have entered right name") } return nil } type simpleQuery struct { Name string `query:"name" json:"name"` } type countingStructValidator struct { calls int } func (v *countingStructValidator) Validate(_ any) error { v.calls++ return nil } // go test -run Test_Bind_Form_Map_SkipsStructValidator func Test_Bind_Form_Map_SkipsStructValidator(t *testing.T) { t.Parallel() makeRequest := func(c Ctx) { body := []byte("name=john") c.Request().SetBody(body) c.Request().Header.SetContentType(MIMEApplicationForm) c.Request().Header.SetContentLength(len(body)) } t.Run("without validator", func(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) makeRequest(c) req := make(map[string]string) require.NoError(t, c.Bind().Form(&req)) require.Equal(t, "john", req["name"]) }) t.Run("with struct validator configured", func(t *testing.T) { t.Parallel() validator := &countingStructValidator{} app := New(Config{StructValidator: validator}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) makeRequest(c) req := make(map[string]string) require.NoError(t, c.Bind().Form(&req)) require.Equal(t, "john", req["name"]) require.Equal(t, 0, validator.calls) }) } // go test -run Test_Bind_SkipValidation func Test_Bind_SkipValidation(t *testing.T) { t.Parallel() app := New(Config{StructValidator: &structValidator{}}) t.Run("validation enabled", func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) rq := 
new(simpleQuery) c.Request().URI().SetQueryString("name=efe") require.Equal(t, "you should have entered right name", c.Bind().SkipValidation(false).Query(rq).Error()) }) t.Run("validation skipped", func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) rq := new(simpleQuery) c.Request().URI().SetQueryString("name=efe") require.NoError(t, c.Bind().SkipValidation(true).Query(rq)) require.Equal(t, "efe", rq.Name) }) } // go test -run Test_Bind_SkipValidation_Usage func Test_Bind_SkipValidation_Usage(t *testing.T) { t.Parallel() app := New(Config{StructValidator: &structValidator{}}) t.Run("skip validation for json", func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) body := []byte(`{"name":"efe"}`) c.Request().SetBody(body) c.Request().Header.SetContentType(MIMEApplicationJSON) c.Request().Header.SetContentLength(len(body)) req := new(simpleQuery) require.NoError(t, c.Bind().SkipValidation(true).JSON(req)) require.Equal(t, "efe", req.Name) }) t.Run("re-enable validation explicitly", func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) body := []byte(`{"name":"efe"}`) c.Request().SetBody(body) c.Request().Header.SetContentType(MIMEApplicationJSON) c.Request().Header.SetContentLength(len(body)) req := new(simpleQuery) require.EqualError(t, c.Bind().SkipValidation(false).JSON(req), "you should have entered right name") }) t.Run("toggle on same bind instance", func(t *testing.T) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) body := []byte(`{"name":"efe"}`) c.Request().SetBody(body) c.Request().Header.SetContentType(MIMEApplicationJSON) c.Request().Header.SetContentLength(len(body)) req := new(simpleQuery) require.NoError(t, c.Bind().SkipValidation(true).JSON(req)) c.Request().SetBody(body) 
c.Request().Header.SetContentType(MIMEApplicationJSON) c.Request().Header.SetContentLength(len(body)) req = new(simpleQuery) require.EqualError(t, c.Bind().SkipValidation(false).JSON(req), "you should have entered right name") }) } // go test -run Test_Bind_ValidateStruct_NilTarget func Test_Bind_ValidateStruct_NilTarget(t *testing.T) { t.Parallel() validator := &countingStructValidator{} app := New(Config{StructValidator: validator}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(c) }) require.NoError(t, c.Bind().validateStruct(nil)) require.Equal(t, 0, validator.calls) } // go test -run Test_Bind_StructValidator func Test_Bind_StructValidator(t *testing.T) { app := New(Config{StructValidator: &structValidator{}}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) rq := new(simpleQuery) c.Request().URI().SetQueryString("name=efe") require.Equal(t, "you should have entered right name", c.Bind().Query(rq).Error()) rq = new(simpleQuery) c.Request().URI().SetQueryString("name=john") require.NoError(t, c.Bind().Query(rq)) } // go test -run Test_Bind_RepeatParserWithSameStruct -v func Test_Bind_RepeatParserWithSameStruct(t *testing.T) { t.Parallel() app := New(Config{ CBOREncoder: cbor.Marshal, CBORDecoder: cbor.Unmarshal, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) type Request struct { QueryParam string `query:"query_param"` HeaderParam string `header:"header_param"` BodyParam string `json:"body_param" xml:"body_param" form:"body_param"` } r := new(Request) c.Request().URI().SetQueryString("query_param=query_param") require.NoError(t, c.Bind().Query(r)) require.Equal(t, "query_param", r.QueryParam) c.Request().Header.Add("header_param", "header_param") require.NoError(t, c.Bind().Header(r)) require.Equal(t, "header_param", r.HeaderParam) var gzipJSON bytes.Buffer w := gzip.NewWriter(&gzipJSON) _, err := w.Write([]byte(`{"body_param":"body_param"}`)) require.NoError(t, err) err = w.Close() require.NoError(t, err) 
c.Request().Header.SetContentType(MIMEApplicationJSON) c.Request().Header.Set(HeaderContentEncoding, "gzip") c.Request().SetBody(gzipJSON.Bytes()) c.Request().Header.SetContentLength(len(gzipJSON.Bytes())) require.NoError(t, c.Bind().Body(r)) require.Equal(t, "body_param", r.BodyParam) c.Request().Header.Del(HeaderContentEncoding) testDecodeParser := func(contentType, body string) { c.Request().Header.SetContentType(contentType) c.Request().SetBody([]byte(body)) c.Request().Header.SetContentLength(len(body)) require.NoError(t, c.Bind().Body(r)) require.Equal(t, "body_param", r.BodyParam) } cb, err := cbor.Marshal(&Request{BodyParam: "body_param"}) require.NoError(t, err, "Failed to marshal CBOR data") testDecodeParser(MIMEApplicationJSON, `{"body_param":"body_param"}`) testDecodeParser(MIMEApplicationXML, `body_param`) testDecodeParser(MIMEApplicationCBOR, string(cb)) testDecodeParser(MIMEApplicationForm, "body_param=body_param") testDecodeParser(MIMEMultipartForm+`;boundary="b"`, "--b\r\nContent-Disposition: form-data; name=\"body_param\"\r\n\r\nbody_param\r\n--b--") } type RequestConfig struct { Headers map[string]string Cookies map[string]string ContentType string Query string Body []byte } func (rc *RequestConfig) ApplyTo(ctx Ctx) { if rc.Body != nil { ctx.Request().SetBody(rc.Body) ctx.Request().Header.SetContentLength(len(rc.Body)) } if rc.ContentType != "" { ctx.Request().Header.SetContentType(rc.ContentType) } for k, v := range rc.Headers { ctx.Request().Header.Set(k, v) } for k, v := range rc.Cookies { ctx.Request().Header.SetCookie(k, v) } if rc.Query != "" { ctx.Request().URI().SetQueryString(rc.Query) } } // go test -run Test_Bind_All func Test_Bind_All(t *testing.T) { t.Parallel() type User struct { Avatar *multipart.FileHeader `form:"avatar"` Name string `query:"name" json:"name" form:"name"` Email string `json:"email" form:"email"` Role string `header:"X-User-Role"` SessionID string `json:"session_id" cookie:"session_id"` ID int `uri:"id" query:"id" 
json:"id" form:"id"` } newBind := func(app *App) *Bind { return &Bind{ ctx: app.AcquireCtx(&fasthttp.RequestCtx{}), } } defaultConfig := func() *RequestConfig { return &RequestConfig{ ContentType: MIMEApplicationJSON, Body: []byte(`{"name":"john", "email": "john@doe.com", "session_id": "abc1234", "id": 1}`), Headers: map[string]string{ "X-User-Role": "admin", }, Cookies: map[string]string{ "session_id": "abc123", }, Query: "id=1&name=john", } } tests := []struct { out any expected *User config *RequestConfig name string wantErr bool }{ { name: "Invalid output type", out: 123, wantErr: true, }, { name: "Successful binding", out: new(User), config: defaultConfig(), expected: &User{ ID: 1, Name: "john", Email: "john@doe.com", Role: "admin", SessionID: "abc1234", }, }, { name: "Missing fields (partial JSON only)", out: new(User), config: &RequestConfig{ ContentType: MIMEApplicationJSON, Body: []byte(`{"name":"partial"}`), }, expected: &User{ Name: "partial", }, }, { name: "Override query with JSON", out: new(User), config: &RequestConfig{ ContentType: MIMEApplicationJSON, Body: []byte(`{"name":"fromjson", "id": 99}`), Query: "id=1&name=queryname", }, expected: &User{ Name: "fromjson", ID: 99, }, }, { name: "Form binding", out: new(User), config: &RequestConfig{ ContentType: MIMEApplicationForm, Body: []byte("id=2&name=formname&email=form@doe.com"), }, expected: &User{ ID: 2, Name: "formname", Email: "form@doe.com", }, }, { name: "Skip body when content-type missing", out: new(User), config: &RequestConfig{ Body: []byte(`{"name":"bodyname"}`), Query: "name=queryname", }, expected: &User{ Name: "queryname", }, }, { name: "Skip empty body despite content-type", out: new(User), config: &RequestConfig{ ContentType: MIMEApplicationJSON, Body: []byte{}, Query: "name=queryname", }, expected: &User{ Name: "queryname", }, }, } app := New() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() bind := newBind(app) if tt.config != nil { 
tt.config.ApplyTo(bind.ctx) } err := bind.All(tt.out) if tt.wantErr { require.Error(t, err) return } require.NoError(t, err) if tt.expected != nil { actual, ok := tt.out.(*User) require.True(t, ok) require.Equal(t, tt.expected.ID, actual.ID) require.Equal(t, tt.expected.Name, actual.Name) require.Equal(t, tt.expected.Email, actual.Email) require.Equal(t, tt.expected.Role, actual.Role) require.Equal(t, tt.expected.SessionID, actual.SessionID) } }) } } // go test -run Test_Bind_All_Uri_Precedence func Test_Bind_All_Uri_Precedence(t *testing.T) { t.Parallel() type User struct { Name string `json:"name"` Email string `json:"email"` ID int `uri:"id" json:"id" query:"id" form:"id"` } app := New() app.Post("/test1/:id", func(c Ctx) error { d := new(User) if err := c.Bind().All(d); err != nil { t.Fatal(err) } require.Equal(t, 111, d.ID) require.Equal(t, "john", d.Name) require.Equal(t, "john@doe.com", d.Email) return nil }) body := strings.NewReader(`{"id": 999, "name": "john", "email": "john@doe.com"}`) req := httptest.NewRequest(MethodPost, "/test1/111?id=888", body) req.Header.Set("Content-Type", "application/json") res, err := app.Test(req) require.NoError(t, err) require.Equal(t, 200, res.StatusCode) } // go test -run Test_Bind_All_Query_Precedence func Test_Bind_All_Query_Precedence(t *testing.T) { t.Parallel() type Data struct { ID int `query:"id" header:"id" cookie:"id"` } app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().URI().SetQueryString("id=5") c.Request().Header.Set("id", "3") c.Request().Header.SetCookie("id", "2") d := new(Data) require.NoError(t, (&Bind{ctx: c}).All(d)) require.Equal(t, 5, d.ID) } // go test -run Test_Bind_All_StructValidator func Test_Bind_All_StructValidator(t *testing.T) { t.Parallel() app := New(Config{StructValidator: &structValidator{}}) // Success case: name comes from body only ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) ctx.Request().Header.SetContentType(MIMEApplicationJSON) 
ctx.Request().SetBody([]byte(`{"name":"john"}`)) sq := new(simpleQuery) require.NoError(t, (&Bind{ctx: ctx}).All(sq)) require.Equal(t, "john", sq.Name) // Failure: missing name everywhere ctxFail := app.AcquireCtx(&fasthttp.RequestCtx{}) ctxFail.Request().Header.SetContentType(MIMEApplicationJSON) ctxFail.Request().SetBody([]byte(`{}`)) sqFail := new(simpleQuery) err := (&Bind{ctx: ctxFail}).WithoutAutoHandling().All(sqFail) require.EqualError(t, err, "you should have entered right name") } // go test -v -run=^$ -bench=Benchmark_Bind_All -benchmem -count=4 func BenchmarkBind_All(b *testing.B) { type User struct { SessionID string `json:"session_id" cookie:"session_id"` Name string `query:"name" json:"name" form:"name"` Email string `json:"email" form:"email"` Role string `header:"X-User-Role"` ID int `uri:"id" query:"id" json:"id" form:"id"` } app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) config := &RequestConfig{ ContentType: MIMEApplicationJSON, Body: []byte(`{"name":"john", "email": "john@doe.com", "session_id": "abc1234", "id": 1}`), Headers: map[string]string{ "X-User-Role": "admin", }, Cookies: map[string]string{ "session_id": "abc123", }, Query: "id=1&name=john", } bind := &Bind{ ctx: c, } for b.Loop() { user := &User{} config.ApplyTo(c) if err := bind.All(user); err != nil { b.Fatalf("unexpected error: %v", err) } } } ================================================ FILE: binder/README.md ================================================ # Fiber Binders **Binder** is a new request/response binding feature for Fiber introduced in Fiber v3. It replaces the old Fiber parsers and offers enhanced capabilities such as custom binder registration, struct validation, support for `map[string]string`, `map[string][]string`, and more. 
Binder replaces the following components: - `BodyParser` - `ParamsParser` - `GetReqHeaders` - `GetRespHeaders` - `AllParams` - `QueryParser` - `ReqHeaderParser` ## Default Binders Fiber provides several default binders out of the box: - [Form](form.go) - [Query](query.go) - [URI](uri.go) - [Header](header.go) - [Response Header](resp_header.go) - [Cookie](cookie.go) - [JSON](json.go) - [XML](xml.go) - [CBOR](cbor.go) - [MsgPack](msgpack.go) ## Guides ### Binding into a Struct Fiber supports binding request data directly into a struct using [gofiber/schema](https://github.com/gofiber/schema). Here's an example: ```go // Field names must start with an uppercase letter type Person struct { Name string `json:"name" xml:"name" form:"name"` Pass string `json:"pass" xml:"pass" form:"pass"` } app.Post("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().Body(p); err != nil { return err } log.Println(p.Name) // Output: john log.Println(p.Pass) // Output: doe // Additional logic... return nil }) // Run tests with the following curl commands: // JSON curl -X POST -H "Content-Type: application/json" --data "{\"name\":\"john\",\"pass\":\"doe\"}" localhost:3000 // XML curl -X POST -H "Content-Type: application/xml" --data "<login><name>john</name><pass>doe</pass></login>" localhost:3000 // URL-Encoded Form curl -X POST -H "Content-Type: application/x-www-form-urlencoded" --data "name=john&pass=doe" localhost:3000 // Multipart Form curl -X POST -F name=john -F pass=doe http://localhost:3000 // Query Parameters curl -X POST "http://localhost:3000/?name=john&pass=doe" ``` ### Binding into a Map Fiber allows binding request data into a `map[string]string` or `map[string][]string`. Here's an example: ```go app.Get("/", func(c fiber.Ctx) error { params := make(map[string][]string) if err := c.Bind().Query(params); err != nil { return err } log.Println(params["name"]) // Output: [john] log.Println(params["pass"]) // Output: [doe] log.Println(params["products"]) // Output: [shoe hat] // Additional logic...
return nil }) // Run tests with the following curl command: curl "http://localhost:3000/?name=john&pass=doe&products=shoe&products=hat" ``` ### Automatic Error Handling with `WithAutoHandling` By default, Fiber returns binder errors directly. To handle errors automatically and return a `400 Bad Request` status, use the `WithAutoHandling()` method. **Example:** ```go // Field names must start with an uppercase letter type Person struct { Name string `json:"name,required"` Pass string `json:"pass"` } app.Get("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().WithAutoHandling().JSON(p); err != nil { return err // Automatically returns status code 400 // Response: Bad request: name is empty } // Additional logic... return nil }) // Run tests with the following curl command: curl -X GET -H "Content-Type: application/json" --data "{\"pass\":\"doe\"}" localhost:3000 ``` ### Defining a Custom Binder Fiber maintains a minimal codebase by not including every possible binder. If you need to use a custom binder, you can easily register and utilize it. Here's an example of creating a `toml` binder. 
```go type Person struct { Name string `toml:"name"` Pass string `toml:"pass"` } type tomlBinding struct{} func (b *tomlBinding) Name() string { return "toml" } func (b *tomlBinding) MIMETypes() []string { return []string{"application/toml"} } func (b *tomlBinding) Parse(c fiber.Ctx, out any) error { return toml.Unmarshal(c.Body(), out) } func main() { app := fiber.New() app.RegisterCustomBinder(&tomlBinding{}) app.Get("/", func(c fiber.Ctx) error { out := new(Person) if err := c.Bind().Body(out); err != nil { return err } // Alternatively, specify the custom binder: // if err := c.Bind().Custom("toml", out); err != nil { // return err // } return c.SendString(out.Pass) // Output: test }) app.Listen(":3000") } // Run tests with the following curl command: curl -X GET -H "Content-Type: application/toml" --data "name = 'bar' pass = 'test'" localhost:3000 ``` ### Defining a Custom Validator All Fiber binders support struct validation if a validator is defined in the configuration. You can create your own validator or use existing ones like [go-playground/validator](https://github.com/go-playground/validator) or [go-ozzo/ozzo-validation](https://github.com/go-ozzo/ozzo-validation). Here's an example of a simple custom validator: ```go type Query struct { Name string `query:"name"` } type structValidator struct{} func (v *structValidator) Engine() any { return nil // Implement if using an external validation engine } func (v *structValidator) ValidateStruct(out any) error { data := reflect.ValueOf(out).Elem().Interface() query := data.(Query) if query.Name != "john" { return errors.New("you should have entered the correct name!") } return nil } func main() { app := fiber.New(fiber.Config{ StructValidator: &structValidator{}, }) app.Get("/", func(c fiber.Ctx) error { out := new(Query) if err := c.Bind().Query(out); err != nil { return err // Returns: you should have entered the correct name! 
} return c.SendString(out.Name) }) app.Listen(":3000") } // Run tests with the following curl command: curl "http://localhost:3000/?name=efe" ``` ================================================ FILE: binder/binder.go ================================================ package binder import ( "errors" "sync" ) // Binder errors var ( ErrSuitableContentNotFound = errors.New("binder: suitable content not found to parse body") ErrMapNotConvertible = errors.New("binder: map is not convertible to map[string]string or map[string][]string") ErrMapNilDestination = errors.New("binder: map destination is nil and cannot be initialized") ErrInvalidDestinationValue = errors.New("binder: invalid destination value") ErrUnmatchedBrackets = errors.New("unmatched brackets") ) var errPoolTypeAssertion = errors.New("failed to type-assert to T") var HeaderBinderPool = sync.Pool{ New: func() any { return &HeaderBinding{} }, } var RespHeaderBinderPool = sync.Pool{ New: func() any { return &RespHeaderBinding{} }, } var CookieBinderPool = sync.Pool{ New: func() any { return &CookieBinding{} }, } var QueryBinderPool = sync.Pool{ New: func() any { return &QueryBinding{} }, } var FormBinderPool = sync.Pool{ New: func() any { return &FormBinding{} }, } var URIBinderPool = sync.Pool{ New: func() any { return &URIBinding{} }, } var XMLBinderPool = sync.Pool{ New: func() any { return &XMLBinding{} }, } var JSONBinderPool = sync.Pool{ New: func() any { return &JSONBinding{} }, } var CBORBinderPool = sync.Pool{ New: func() any { return &CBORBinding{} }, } var MsgPackBinderPool = sync.Pool{ New: func() any { return &MsgPackBinding{} }, } // GetFromThePool retrieves a binder from the provided sync.Pool and panics if // the stored value cannot be cast to the requested type. func GetFromThePool[T any](pool *sync.Pool) T { binder, ok := pool.Get().(T) if !ok { panic(errPoolTypeAssertion) } return binder } // PutToThePool returns the binder to the provided sync.Pool. 
func PutToThePool[T any](pool *sync.Pool, binder T) { pool.Put(binder) } ================================================ FILE: binder/binder_test.go ================================================ package binder import ( "mime/multipart" "reflect" "strconv" "testing" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func Test_GetAndPutToThePool(t *testing.T) { t.Parallel() // Panics in case we get from another pool require.Panics(t, func() { _ = GetFromThePool[*HeaderBinding](&CookieBinderPool) }) // We get from the pool binder := GetFromThePool[*HeaderBinding](&HeaderBinderPool) PutToThePool(&HeaderBinderPool, binder) _ = GetFromThePool[*RespHeaderBinding](&RespHeaderBinderPool) _ = GetFromThePool[*QueryBinding](&QueryBinderPool) _ = GetFromThePool[*FormBinding](&FormBinderPool) _ = GetFromThePool[*URIBinding](&URIBinderPool) _ = GetFromThePool[*XMLBinding](&XMLBinderPool) _ = GetFromThePool[*JSONBinding](&JSONBinderPool) _ = GetFromThePool[*CBORBinding](&CBORBinderPool) _ = GetFromThePool[*MsgPackBinding](&MsgPackBinderPool) } func Test_Binders_ErrorPaths(t *testing.T) { t.Run("query binder invalid key", func(t *testing.T) { b := &QueryBinding{} req := fasthttp.AcquireRequest() req.URI().SetQueryString("invalid[%3Dval&name=john") defer fasthttp.ReleaseRequest(req) err := b.Bind(req, &struct{}{}) require.Error(t, err) require.Contains(t, err.Error(), "unmatched brackets") }) t.Run("form binder invalid key", func(t *testing.T) { b := &FormBinding{} req := fasthttp.AcquireRequest() req.SetBodyString("invalid[=val") req.Header.SetContentType("application/x-www-form-urlencoded") defer fasthttp.ReleaseRequest(req) err := b.Bind(req, &struct{}{}) require.Error(t, err) require.Contains(t, err.Error(), "unmatched brackets") }) t.Run("form binder bad multipart", func(t *testing.T) { b := &FormBinding{} req := fasthttp.AcquireRequest() req.Header.SetContentType(MIMEMultipartForm) defer fasthttp.ReleaseRequest(req) err := b.Bind(req, &struct{}{}) 
require.Error(t, err) }) } func Test_GetFieldCache_Panic(t *testing.T) { t.Parallel() require.Panics(t, func() { getFieldCache("unknown") }) } func Test_parseToMap_defaultCase(t *testing.T) { t.Parallel() m := map[string]int{} err := parseToMap(reflect.ValueOf(m), map[string][]string{"a": {"1"}}) require.NoError(t, err) require.Empty(t, m) m2 := map[string]string{} err = parseToMap(reflect.ValueOf(m2), map[string][]string{"empty": {}}) require.NoError(t, err) require.Empty(t, m2["empty"]) var zeroStringMap map[string]string err = parseToMap(reflect.ValueOf(&zeroStringMap).Elem(), map[string][]string{"name": {"john"}}) require.NoError(t, err) require.Equal(t, "john", zeroStringMap["name"]) } func Test_parse_function_maps(t *testing.T) { t.Parallel() m := map[string][]string{} err := parse("query", &m, map[string][]string{"a": {"b"}}) require.NoError(t, err) require.Equal(t, []string{"b"}, m["a"]) m2 := map[string]string{} err = parse("query", &m2, map[string][]string{"a": {"b"}}) require.NoError(t, err) require.Equal(t, "b", m2["a"]) var zeroStringMap map[string]string err = parse("query", &zeroStringMap, map[string][]string{"foo": {"bar", "baz"}}) require.NoError(t, err) require.Equal(t, "baz", zeroStringMap["foo"]) var zeroSliceMap map[string][]string err = parse("query", &zeroSliceMap, map[string][]string{"foo": {"bar", "baz"}}) require.NoError(t, err) require.Equal(t, []string{"bar", "baz"}, zeroSliceMap["foo"]) } func Test_SetParserDecoder_UnknownKeys(t *testing.T) { SetParserDecoder(ParserConfig{IgnoreUnknownKeys: false}) defer SetParserDecoder(ParserConfig{IgnoreUnknownKeys: true, ZeroEmpty: true}) type user struct { Name string `query:"name"` } data := map[string][]string{"name": {"john"}, "foo": {"bar"}} err := parseToStruct("query", &user{}, data) require.Error(t, err) SetParserDecoder(ParserConfig{IgnoreUnknownKeys: true, ZeroEmpty: true}) } func Test_SetParserDecoder_CustomConverter(t *testing.T) { type myInt int conv := func(s string) reflect.Value { v, 
_ := strconv.Atoi(s) //nolint:errcheck // not needed mi := myInt(v) return reflect.ValueOf(mi) } SetParserDecoder(ParserConfig{ParserType: []ParserType{{CustomType: myInt(0), Converter: conv}}}) defer SetParserDecoder(ParserConfig{IgnoreUnknownKeys: true, ZeroEmpty: true}) type data struct { V myInt `query:"v"` } d := new(data) err := parse("query", d, map[string][]string{"v": {"5"}}) require.NoError(t, err) require.Equal(t, myInt(5), d.V) } func Test_formatBindData_typeMismatch(t *testing.T) { t.Parallel() out := struct{}{} files := map[string][]*multipart.FileHeader{} err := formatBindData("query", out, files, "file", 123, false, false) require.Error(t, err) require.Equal(t, "unsupported value type: int", err.Error()) } ================================================ FILE: binder/cbor.go ================================================ package binder import ( "github.com/gofiber/utils/v2" ) // CBORBinding is the CBOR binder for CBOR request body. type CBORBinding struct { CBORDecoder utils.CBORUnmarshal } // Name returns the binding name. func (*CBORBinding) Name() string { return "cbor" } // Bind parses the request body as CBOR and returns the result. func (b *CBORBinding) Bind(body []byte, out any) error { return b.CBORDecoder(body, out) } // Reset resets the CBORBinding binder. func (b *CBORBinding) Reset() { b.CBORDecoder = nil } // UnimplementedCborMarshal panics to signal that a CBOR marshaler must be // configured before CBOR support can be used. func UnimplementedCborMarshal(_ any) ([]byte, error) { panic("Must explicitly setup CBOR, please check docs: https://docs.gofiber.io/next/guide/advance-format#cbor") } // UnimplementedCborUnmarshal panics to signal that a CBOR unmarshaler must be // configured before CBOR support can be used. 
func UnimplementedCborUnmarshal(_ []byte, _ any) error { panic("Must explicitly setup CBOR, please check docs: https://docs.gofiber.io/next/guide/advance-format#cbor") } ================================================ FILE: binder/cbor_test.go ================================================ package binder import ( "testing" "github.com/fxamacker/cbor/v2" "github.com/stretchr/testify/require" ) func Test_CBORBinder_Bind(t *testing.T) { t.Parallel() b := &CBORBinding{ CBORDecoder: cbor.Unmarshal, } require.Equal(t, "cbor", b.Name()) type Post struct { Title string `cbor:"title"` } type User struct { Name string `cbor:"name"` Posts []Post `cbor:"posts"` Names []string `cbor:"names"` Age int `cbor:"age"` } var user User wantedUser := User{ Name: "john", Names: []string{ "john", "doe", }, Age: 42, Posts: []Post{ {Title: "post1"}, {Title: "post2"}, {Title: "post3"}, }, } body, err := cbor.Marshal(wantedUser) require.NoError(t, err) err = b.Bind(body, &user) require.NoError(t, err) require.Equal(t, "john", user.Name) require.Equal(t, 42, user.Age) require.Len(t, user.Posts, 3) require.Equal(t, "post1", user.Posts[0].Title) require.Equal(t, "post2", user.Posts[1].Title) require.Equal(t, "post3", user.Posts[2].Title) require.Contains(t, user.Names, "john") require.Contains(t, user.Names, "doe") b.Reset() require.Nil(t, b.CBORDecoder) } func Benchmark_CBORBinder_Bind(b *testing.B) { b.ReportAllocs() binder := &CBORBinding{ CBORDecoder: cbor.Unmarshal, } type User struct { Name string `cbor:"name"` Age int `cbor:"age"` } var user User wantedUser := User{ Name: "john", Age: 42, } body, err := cbor.Marshal(wantedUser) require.NoError(b, err) for b.Loop() { err = binder.Bind(body, &user) } require.NoError(b, err) require.Equal(b, "john", user.Name) require.Equal(b, 42, user.Age) } func Test_UnimplementedCborMarshal_Panics(t *testing.T) { t.Parallel() require.Panics(t, func() { _, _ = UnimplementedCborMarshal(struct{ Name string }{Name: "test"}) //nolint:errcheck // this is 
just a test to trigger the panic }) } func Test_UnimplementedCborUnmarshal_Panics(t *testing.T) { t.Parallel() require.Panics(t, func() { var out any _ = UnimplementedCborUnmarshal([]byte{0xa0}, &out) //nolint:errcheck // this is just a test to trigger the panic }) } func Test_UnimplementedCborMarshal_PanicMessage(t *testing.T) { t.Parallel() defer func() { if r := recover(); r != nil { require.Contains(t, r, "Must explicitly setup CBOR") } }() _, _ = UnimplementedCborMarshal(struct{ Name string }{Name: "test"}) //nolint:errcheck // this is just a test to trigger the panic } func Test_UnimplementedCborUnmarshal_PanicMessage(t *testing.T) { t.Parallel() defer func() { if r := recover(); r != nil { require.Contains(t, r, "Must explicitly setup CBOR") } }() var out any _ = UnimplementedCborUnmarshal([]byte{0xa0}, &out) //nolint:errcheck // this is just a test to trigger the panic } ================================================ FILE: binder/cookie.go ================================================ package binder import ( "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) // CookieBinding is the cookie binder for cookie request body. type CookieBinding struct { EnableSplitting bool } // Name returns the binding name. func (*CookieBinding) Name() string { return "cookie" } // Bind parses the request cookie and returns the result. func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error { data := make(map[string][]string) for key, val := range req.Header.Cookies() { k := utils.UnsafeString(key) v := utils.UnsafeString(val) if err := formatBindData(b.Name(), out, data, k, v, b.EnableSplitting, false); err != nil { return err } } return parse(b.Name(), out, data) } // Reset resets the CookieBinding binder. 
func (b *CookieBinding) Reset() { b.EnableSplitting = false } ================================================ FILE: binder/cookie_test.go ================================================ package binder import ( "testing" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func Test_CookieBinder_Bind(t *testing.T) { t.Parallel() b := &CookieBinding{ EnableSplitting: true, } require.Equal(t, "cookie", b.Name()) type Post struct { Title string `form:"title"` } type User struct { Name string `form:"name"` Names []string `form:"names"` Posts []Post `form:"posts"` Age int `form:"age"` } var user User req := fasthttp.AcquireRequest() req.Header.SetCookie("name", "john") req.Header.SetCookie("names", "john,doe") req.Header.SetCookie("age", "42") t.Cleanup(func() { fasthttp.ReleaseRequest(req) }) err := b.Bind(req, &user) require.NoError(t, err) require.Equal(t, "john", user.Name) require.Equal(t, 42, user.Age) require.Contains(t, user.Names, "john") require.Contains(t, user.Names, "doe") b.Reset() require.False(t, b.EnableSplitting) } func Benchmark_CookieBinder_Bind(b *testing.B) { b.ReportAllocs() binder := &CookieBinding{ EnableSplitting: true, } type User struct { Name string `query:"name"` Posts []string `query:"posts"` Age int `query:"age"` } var user User req := fasthttp.AcquireRequest() b.Cleanup(func() { fasthttp.ReleaseRequest(req) }) req.Header.SetCookie("name", "john") req.Header.SetCookie("age", "42") req.Header.SetCookie("posts", "post1,post2,post3") var err error for b.Loop() { err = binder.Bind(req, &user) } require.NoError(b, err) require.Equal(b, "john", user.Name) require.Equal(b, 42, user.Age) require.Len(b, user.Posts, 3) require.Contains(b, user.Posts, "post1") require.Contains(b, user.Posts, "post2") require.Contains(b, user.Posts, "post3") } func Test_CookieBinder_Bind_ParseError(t *testing.T) { b := &CookieBinding{} type User struct { Age int `cookie:"age"` } var user User req := fasthttp.AcquireRequest() 
req.Header.SetCookie("age", "invalid") t.Cleanup(func() { fasthttp.ReleaseRequest(req) }) err := b.Bind(req, &user) require.Error(t, err) } ================================================ FILE: binder/form.go ================================================ package binder import ( "mime/multipart" "sync" "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) const MIMEMultipartForm string = "multipart/form-data" var ( formMapPool = sync.Pool{ New: func() any { return make(map[string][]string) }, } formFileMapPool = sync.Pool{ New: func() any { return make(map[string][]*multipart.FileHeader) }, } ) // FormBinding is the form binder for form request body. type FormBinding struct { EnableSplitting bool } // Name returns the binding name. func (*FormBinding) Name() string { return "form" } // Bind parses the request body and returns the result. func (b *FormBinding) Bind(req *fasthttp.Request, out any) error { // Handle multipart form if FilterFlags(utils.UnsafeString(req.Header.ContentType())) == MIMEMultipartForm { return b.bindMultipart(req, out) } data := acquireFormMap() defer releaseFormMap(data) for key, val := range req.PostArgs().All() { k := utils.UnsafeString(key) v := utils.UnsafeString(val) if err := formatBindData(b.Name(), out, data, k, v, b.EnableSplitting, true); err != nil { return err } } return parse(b.Name(), out, data) } // bindMultipart parses the request body and returns the result. 
func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error { multipartForm, err := req.MultipartForm() if err != nil { return err } data := acquireFormMap() defer releaseFormMap(data) for key, values := range multipartForm.Value { err = formatBindData(b.Name(), out, data, key, values, b.EnableSplitting, true) if err != nil { return err } } files := acquireFileHeaderMap() defer releaseFileHeaderMap(files) for key, values := range multipartForm.File { err = formatBindData(b.Name(), out, files, key, values, b.EnableSplitting, true) if err != nil { return err } } return parse(b.Name(), out, data, files) } // Reset resets the FormBinding binder. func (b *FormBinding) Reset() { b.EnableSplitting = false } func acquireFormMap() map[string][]string { m, ok := formMapPool.Get().(map[string][]string) if !ok { m = make(map[string][]string) } return m } func releaseFormMap(m map[string][]string) { clearFormMap(m) formMapPool.Put(m) } func acquireFileHeaderMap() map[string][]*multipart.FileHeader { m, ok := formFileMapPool.Get().(map[string][]*multipart.FileHeader) if !ok { m = make(map[string][]*multipart.FileHeader) } return m } func releaseFileHeaderMap(m map[string][]*multipart.FileHeader) { clearFileHeaderMap(m) formFileMapPool.Put(m) } func clearFormMap(m map[string][]string) { clear(m) } func clearFileHeaderMap(m map[string][]*multipart.FileHeader) { clear(m) } ================================================ FILE: binder/form_test.go ================================================ package binder import ( "bytes" "io" "mime/multipart" "testing" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func Test_FormBinder_Bind(t *testing.T) { t.Parallel() b := &FormBinding{ EnableSplitting: true, } require.Equal(t, "form", b.Name()) type Post struct { Title string `form:"title"` } type User struct { Name string `form:"name"` Names []string `form:"names"` Posts []Post `form:"posts"` Age int `form:"age"` } var user User req := 
fasthttp.AcquireRequest() req.SetBodyString("name=john&names=john,doe&age=42&posts[0][title]=post1&posts[1][title]=post2&posts[2][title]=post3") req.Header.SetContentType("application/x-www-form-urlencoded") t.Cleanup(func() { fasthttp.ReleaseRequest(req) }) err := b.Bind(req, &user) require.NoError(t, err) require.Equal(t, "john", user.Name) require.Equal(t, 42, user.Age) require.Len(t, user.Posts, 3) require.Equal(t, "post1", user.Posts[0].Title) require.Equal(t, "post2", user.Posts[1].Title) require.Equal(t, "post3", user.Posts[2].Title) require.Contains(t, user.Names, "john") require.Contains(t, user.Names, "doe") b.Reset() require.False(t, b.EnableSplitting) } func Test_FormBinder_Bind_ParseError(t *testing.T) { b := &FormBinding{} type User struct { Age int `form:"age"` } var user User req := fasthttp.AcquireRequest() req.SetBodyString("age=invalid") req.Header.SetContentType("application/x-www-form-urlencoded") t.Cleanup(func() { fasthttp.ReleaseRequest(req) }) err := b.Bind(req, &user) require.Error(t, err) } func Benchmark_FormBinder_Bind(b *testing.B) { b.ReportAllocs() binder := &FormBinding{ EnableSplitting: true, } type User struct { Name string `form:"name"` Posts []string `form:"posts"` Age int `form:"age"` } var user User req := fasthttp.AcquireRequest() req.SetBodyString("name=john&age=42&posts=post1,post2,post3") req.Header.SetContentType("application/x-www-form-urlencoded") var err error for b.Loop() { err = binder.Bind(req, &user) } require.NoError(b, err) require.Equal(b, "john", user.Name) require.Equal(b, 42, user.Age) require.Len(b, user.Posts, 3) } func Test_FormBinder_BindMultipart(t *testing.T) { t.Parallel() b := &FormBinding{ EnableSplitting: true, } require.Equal(t, "form", b.Name()) type Post struct { Title string `form:"title"` } type User struct { Avatar *multipart.FileHeader `form:"avatar"` Name string `form:"name"` Names []string `form:"names"` Posts []Post `form:"posts"` Avatars []*multipart.FileHeader `form:"avatars"` Age int 
`form:"age"` } var user User req := fasthttp.AcquireRequest() buf := &bytes.Buffer{} mw := multipart.NewWriter(buf) require.NoError(t, mw.WriteField("name", "john")) require.NoError(t, mw.WriteField("names", "john,eric")) require.NoError(t, mw.WriteField("names", "doe")) require.NoError(t, mw.WriteField("age", "42")) require.NoError(t, mw.WriteField("posts[0][title]", "post1")) require.NoError(t, mw.WriteField("posts[1][title]", "post2")) require.NoError(t, mw.WriteField("posts[2][title]", "post3")) writer, err := mw.CreateFormFile("avatar", "avatar.txt") require.NoError(t, err) _, err = writer.Write([]byte("avatar")) require.NoError(t, err) writer, err = mw.CreateFormFile("avatars", "avatar1.txt") require.NoError(t, err) _, err = writer.Write([]byte("avatar1")) require.NoError(t, err) writer, err = mw.CreateFormFile("avatars", "avatar2.txt") require.NoError(t, err) _, err = writer.Write([]byte("avatar2")) require.NoError(t, err) require.NoError(t, mw.Close()) req.Header.SetContentType(mw.FormDataContentType()) req.SetBody(buf.Bytes()) t.Cleanup(func() { fasthttp.ReleaseRequest(req) }) err = b.Bind(req, &user) require.NoError(t, err) require.Equal(t, "john", user.Name) require.Equal(t, 42, user.Age) require.Contains(t, user.Names, "john") require.Contains(t, user.Names, "doe") require.Contains(t, user.Names, "eric") require.Len(t, user.Posts, 3) require.Equal(t, "post1", user.Posts[0].Title) require.Equal(t, "post2", user.Posts[1].Title) require.Equal(t, "post3", user.Posts[2].Title) require.NotNil(t, user.Avatar) require.Equal(t, "avatar.txt", user.Avatar.Filename) require.Equal(t, "application/octet-stream", user.Avatar.Header.Get("Content-Type")) file, err := user.Avatar.Open() require.NoError(t, err) content, err := io.ReadAll(file) require.NoError(t, err) require.Equal(t, "avatar", string(content)) require.Len(t, user.Avatars, 2) require.Equal(t, "avatar1.txt", user.Avatars[0].Filename) require.Equal(t, "application/octet-stream", 
user.Avatars[0].Header.Get("Content-Type")) file, err = user.Avatars[0].Open() require.NoError(t, err) content, err = io.ReadAll(file) require.NoError(t, err) require.Equal(t, "avatar1", string(content)) require.Equal(t, "avatar2.txt", user.Avatars[1].Filename) require.Equal(t, "application/octet-stream", user.Avatars[1].Header.Get("Content-Type")) file, err = user.Avatars[1].Open() require.NoError(t, err) content, err = io.ReadAll(file) require.NoError(t, err) require.Equal(t, "avatar2", string(content)) } func Test_FormBinder_BindMultipart_ValueError(t *testing.T) { b := &FormBinding{} req := fasthttp.AcquireRequest() buf := &bytes.Buffer{} mw := multipart.NewWriter(buf) require.NoError(t, mw.WriteField("invalid[", "val")) require.NoError(t, mw.Close()) req.Header.SetContentType(mw.FormDataContentType()) req.SetBody(buf.Bytes()) t.Cleanup(func() { fasthttp.ReleaseRequest(req) }) err := b.Bind(req, &struct{}{}) require.Error(t, err) require.Contains(t, err.Error(), "unmatched brackets") } func Test_FormBinder_BindMultipart_FileError(t *testing.T) { b := &FormBinding{} req := fasthttp.AcquireRequest() buf := &bytes.Buffer{} mw := multipart.NewWriter(buf) writer, err := mw.CreateFormFile("invalid[", "file.txt") require.NoError(t, err) _, err = writer.Write([]byte("content")) require.NoError(t, err) require.NoError(t, mw.Close()) req.Header.SetContentType(mw.FormDataContentType()) req.SetBody(buf.Bytes()) t.Cleanup(func() { fasthttp.ReleaseRequest(req) }) err = b.Bind(req, &struct{}{}) require.Error(t, err) require.Contains(t, err.Error(), "unmatched brackets") } func Test_FormBinder_Bind_MapClearedBetweenRequests(t *testing.T) { t.Parallel() b := &FormBinding{} type payload struct { Name string `form:"name"` Age int `form:"age"` } firstReq := fasthttp.AcquireRequest() firstReq.SetBodyString("name=john&age=21") firstReq.Header.SetContentType("application/x-www-form-urlencoded") t.Cleanup(func() { fasthttp.ReleaseRequest(firstReq) }) var first payload 
require.NoError(t, b.Bind(firstReq, &first)) require.Equal(t, "john", first.Name) require.Equal(t, 21, first.Age) secondReq := fasthttp.AcquireRequest() secondReq.SetBodyString("age=42") secondReq.Header.SetContentType("application/x-www-form-urlencoded") t.Cleanup(func() { fasthttp.ReleaseRequest(secondReq) }) var second payload require.NoError(t, b.Bind(secondReq, &second)) require.Empty(t, second.Name) require.Equal(t, 42, second.Age) } func Test_FormBinder_BindMultipart_MapsClearedBetweenRequests(t *testing.T) { t.Parallel() b := &FormBinding{} type payload struct { // betteralign:ignore - test payload prioritizes readability over alignment Avatar *multipart.FileHeader `form:"avatar"` Name string `form:"name"` Age int `form:"age"` } firstReq := fasthttp.AcquireRequest() firstBuffer := &bytes.Buffer{} firstWriter := multipart.NewWriter(firstBuffer) require.NoError(t, firstWriter.WriteField("name", "john")) require.NoError(t, firstWriter.WriteField("age", "21")) firstFile, err := firstWriter.CreateFormFile("avatar", "avatar.txt") require.NoError(t, err) _, err = firstFile.Write([]byte("avatar-content")) require.NoError(t, err) require.NoError(t, firstWriter.Close()) firstReq.Header.SetContentType(firstWriter.FormDataContentType()) firstReq.SetBody(firstBuffer.Bytes()) t.Cleanup(func() { fasthttp.ReleaseRequest(firstReq) }) var first payload require.NoError(t, b.Bind(firstReq, &first)) require.Equal(t, "john", first.Name) require.Equal(t, 21, first.Age) require.NotNil(t, first.Avatar) require.Equal(t, "avatar.txt", first.Avatar.Filename) secondReq := fasthttp.AcquireRequest() secondBuffer := &bytes.Buffer{} secondWriter := multipart.NewWriter(secondBuffer) require.NoError(t, secondWriter.WriteField("age", "42")) require.NoError(t, secondWriter.Close()) secondReq.Header.SetContentType(secondWriter.FormDataContentType()) secondReq.SetBody(secondBuffer.Bytes()) t.Cleanup(func() { fasthttp.ReleaseRequest(secondReq) }) var second payload require.NoError(t, 
b.Bind(secondReq, &second)) require.Empty(t, second.Name) require.Equal(t, 42, second.Age) require.Nil(t, second.Avatar) } func Benchmark_FormBinder_BindMultipart(b *testing.B) { b.ReportAllocs() binder := &FormBinding{ EnableSplitting: true, } type User struct { Name string `form:"name"` Posts []string `form:"posts"` Age int `form:"age"` } var user User req := fasthttp.AcquireRequest() b.Cleanup(func() { fasthttp.ReleaseRequest(req) }) buf := &bytes.Buffer{} mw := multipart.NewWriter(buf) require.NoError(b, mw.WriteField("name", "john")) require.NoError(b, mw.WriteField("age", "42")) require.NoError(b, mw.WriteField("posts", "post1")) require.NoError(b, mw.WriteField("posts", "post2")) require.NoError(b, mw.WriteField("posts", "post3")) require.NoError(b, mw.Close()) req.Header.SetContentType(mw.FormDataContentType()) req.SetBody(buf.Bytes()) var err error for b.Loop() { err = binder.Bind(req, &user) } require.NoError(b, err) require.Equal(b, "john", user.Name) require.Equal(b, 42, user.Age) require.Len(b, user.Posts, 3) } ================================================ FILE: binder/header.go ================================================ package binder import ( "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) // HeaderBinding is the binder implementation used to populate values from HTTP headers. type HeaderBinding struct { EnableSplitting bool } // Name returns the binding name. func (*HeaderBinding) Name() string { return "header" } // Bind parses the request header and returns the result. func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error { data := make(map[string][]string) for key, val := range req.Header.All() { k := utils.UnsafeString(key) v := utils.UnsafeString(val) if err := formatBindData(b.Name(), out, data, k, v, b.EnableSplitting, false); err != nil { return err } } return parse(b.Name(), out, data) } // Reset resets the HeaderBinding binder. 
func (b *HeaderBinding) Reset() {
	// Restore the zero configuration so the binder can be reused.
	b.EnableSplitting = false
}

================================================ FILE: binder/header_test.go ================================================

package binder

import (
	"testing"

	"github.com/stretchr/testify/require"
	"github.com/valyala/fasthttp"
)

// Test_HeaderBinder_Bind binds request headers with splitting enabled and checks
// scalar fields, comma-split slice fields, and that Reset clears EnableSplitting.
func Test_HeaderBinder_Bind(t *testing.T) {
	t.Parallel()

	b := &HeaderBinding{
		EnableSplitting: true,
	}
	require.Equal(t, "header", b.Name())

	type User struct {
		Name  string   `header:"Name"`
		Names []string `header:"Names"`
		Posts []string `header:"Posts"`
		Age   int      `header:"Age"`
	}
	var user User

	req := fasthttp.AcquireRequest()
	req.Header.Set("name", "john")
	req.Header.Set("names", "john,doe")
	req.Header.Set("age", "42")
	req.Header.Set("posts", "post1,post2,post3")
	t.Cleanup(func() {
		fasthttp.ReleaseRequest(req)
	})

	err := b.Bind(req, &user)

	require.NoError(t, err)
	require.Equal(t, "john", user.Name)
	require.Equal(t, 42, user.Age)
	require.Len(t, user.Posts, 3)
	require.Equal(t, "post1", user.Posts[0])
	require.Equal(t, "post2", user.Posts[1])
	require.Equal(t, "post3", user.Posts[2])
	require.Contains(t, user.Names, "john")
	require.Contains(t, user.Names, "doe")

	b.Reset()
	require.False(t, b.EnableSplitting)
}

// Benchmark_HeaderBinder_Bind measures binding three headers (one comma-split
// into a slice) into a struct.
func Benchmark_HeaderBinder_Bind(b *testing.B) {
	b.ReportAllocs()

	binder := &HeaderBinding{
		EnableSplitting: true,
	}

	type User struct {
		Name  string   `header:"Name"`
		Posts []string `header:"Posts"`
		Age   int      `header:"Age"`
	}
	var user User

	req := fasthttp.AcquireRequest()
	b.Cleanup(func() {
		fasthttp.ReleaseRequest(req)
	})

	req.Header.Set("name", "john")
	req.Header.Set("age", "42")
	req.Header.Set("posts", "post1,post2,post3")

	var err error
	for b.Loop() {
		err = binder.Bind(req, &user)
	}

	require.NoError(b, err)
	require.Equal(b, "john", user.Name)
	require.Equal(b, 42, user.Age)
	require.Len(b, user.Posts, 3)
	require.Contains(b, user.Posts, "post1")
	require.Contains(b, user.Posts, "post2")
	require.Contains(b, user.Posts, "post3")
}

// Test_HeaderBinder_Bind_ParseError checks that a header value that cannot be
// converted to the destination field's type surfaces as an error.
func Test_HeaderBinder_Bind_ParseError(t *testing.T) {
	b := &HeaderBinding{}

	type User struct {
		Age int `header:"Age"`
	}
	var user User

	req := fasthttp.AcquireRequest()
	req.Header.Set("age", "invalid")
	t.Cleanup(func() {
		fasthttp.ReleaseRequest(req)
	})

	err := b.Bind(req, &user)
	require.Error(t, err)
}

================================================ FILE: binder/json.go ================================================

package binder

import (
	"github.com/gofiber/utils/v2"
)

// JSONBinding is the JSON binder for JSON request body.
type JSONBinding struct {
	// JSONDecoder performs the actual unmarshalling; it must be set before Bind.
	JSONDecoder utils.JSONUnmarshal
}

// Name returns the binding name.
func (*JSONBinding) Name() string {
	return "json"
}

// Bind decodes body into out with the configured JSONDecoder.
// NOTE(review): JSONDecoder is invoked without a nil check, so calling Bind
// after Reset (which sets it to nil) panics.
func (b *JSONBinding) Bind(body []byte, out any) error {
	return b.JSONDecoder(body, out)
}

// Reset resets the JSONBinding binder.
func (b *JSONBinding) Reset() {
	b.JSONDecoder = nil
}

================================================ FILE: binder/json_test.go ================================================

package binder

import (
	"encoding/json"
	"testing"

	"github.com/stretchr/testify/require"
)

// Test_JSON_Binding_Bind decodes a JSON document with a nested slice of objects
// and checks that Reset clears the decoder.
func Test_JSON_Binding_Bind(t *testing.T) {
	t.Parallel()

	b := &JSONBinding{
		JSONDecoder: json.Unmarshal,
	}
	require.Equal(t, "json", b.Name())

	type Post struct {
		Title string `json:"title"`
	}

	type User struct {
		Name  string `json:"name"`
		Posts []Post `json:"posts"`
		Age   int    `json:"age"`
	}
	var user User

	err := b.Bind([]byte(`{"name":"john","age":42,"posts":[{"title":"post1"},{"title":"post2"},{"title":"post3"}]}`), &user)

	require.NoError(t, err)
	require.Equal(t, "john", user.Name)
	require.Equal(t, 42, user.Age)
	require.Len(t, user.Posts, 3)
	require.Equal(t, "post1", user.Posts[0].Title)
	require.Equal(t, "post2", user.Posts[1].Title)
	require.Equal(t, "post3", user.Posts[2].Title)

	b.Reset()
	require.Nil(t, b.JSONDecoder)
}

// Benchmark_JSON_Binding_Bind measures decoding a small JSON document with
// encoding/json.
func Benchmark_JSON_Binding_Bind(b *testing.B) {
	b.ReportAllocs()

	binder := &JSONBinding{
		JSONDecoder: json.Unmarshal,
	}

	type User struct {
		Name  string   `json:"name"`
		Posts []string `json:"posts"`
		Age   int      `json:"age"`
	}
	var user User

	var err error
	for b.Loop() {
		err = binder.Bind([]byte(`{"name":"john","age":42,"posts":["post1","post2","post3"]}`), &user)
	}

	require.NoError(b, err)
	require.Equal(b, "john", user.Name)
	require.Equal(b, 42, user.Age)
	require.Len(b, user.Posts, 3)
	require.Equal(b, "post1", user.Posts[0])
	require.Equal(b, "post2", user.Posts[1])
	require.Equal(b, "post3", user.Posts[2])
}

================================================ FILE: binder/mapping.go ================================================

package binder

import (
	"fmt"
	"maps"
	"mime/multipart"
	"reflect"
	"strings"
	"sync"

	utilsstrings "github.com/gofiber/utils/v2/strings"
	"github.com/valyala/bytebufferpool"

	"github.com/gofiber/schema"
)

// ParserConfig configures the decoders that SetParserDecoder installs for every
// binder tag.
type ParserConfig struct {
	// SetAliasTag, when non-empty, is applied as the decoder's alias tag when a
	// decoder is built. NOTE(review): parseToStruct calls SetAliasTag with the
	// binder's own tag on every decode, so this value appears to be overridden —
	// confirm the intended use.
	SetAliasTag string
	// ParserType lists custom types whose converters are registered on each decoder.
	ParserType []ParserType
	// IgnoreUnknownKeys is forwarded to the decoder's IgnoreUnknownKeys option.
	IgnoreUnknownKeys bool
	// ZeroEmpty is forwarded to the decoder's ZeroEmpty option.
	ZeroEmpty bool
}

// ParserType pairs a custom type with the converter used to decode it from a
// raw string; both fields are needed to register the converter.
// Use ParserType with BodyParser for parsing custom type in form data.
type ParserType struct { CustomType any Converter func(string) reflect.Value } var ( decoderPoolMu sync.RWMutex // decoderPoolMap helps to improve binders decoderPoolMap = map[string]*sync.Pool{} // tags is used to classify parser's pool tags = []string{"header", "respHeader", "cookie", "query", "form", "uri"} ) func getDecoderPool(tag string) *sync.Pool { decoderPoolMu.RLock() pool := decoderPoolMap[tag] if pool == nil { decoderPoolMu.RUnlock() panic(fmt.Sprintf("decoder pool not initialized for tag %q", tag)) } decoderPoolMu.RUnlock() return pool } // SetParserDecoder allow globally change the option of form decoder, update decoderPool func SetParserDecoder(parserConfig ParserConfig) { decoderPoolMu.Lock() defer decoderPoolMu.Unlock() for _, tag := range tags { decoderPoolMap[tag] = &sync.Pool{New: func() any { return decoderBuilder(parserConfig) }} } } func decoderBuilder(parserConfig ParserConfig) any { decoder := schema.NewDecoder() decoder.IgnoreUnknownKeys(parserConfig.IgnoreUnknownKeys) if parserConfig.SetAliasTag != "" { decoder.SetAliasTag(parserConfig.SetAliasTag) } for _, v := range parserConfig.ParserType { decoder.RegisterConverter(reflect.ValueOf(v.CustomType).Interface(), v.Converter) } decoder.ZeroEmpty(parserConfig.ZeroEmpty) return decoder } func init() { decoderPoolMu.Lock() defer decoderPoolMu.Unlock() for _, tag := range tags { decoderPoolMap[tag] = &sync.Pool{New: func() any { return decoderBuilder(ParserConfig{ IgnoreUnknownKeys: true, ZeroEmpty: true, }) }} } } // parse data into the map or struct func parse(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error { ptrVal := reflect.ValueOf(out) // Get pointer value if ptrVal.Kind() == reflect.Ptr { ptrVal = ptrVal.Elem() } // Parse into the map if ptrVal.Kind() == reflect.Map && ptrVal.Type().Key().Kind() == reflect.String { return parseToMap(ptrVal, data) } // Parse into the struct return parseToStruct(aliasTag, out, data, files...) 
} // Parse data into the struct with gofiber/schema func parseToStruct(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error { // Get decoder from pool pool := getDecoderPool(aliasTag) schemaDecoder := pool.Get().(*schema.Decoder) //nolint:errcheck,forcetypeassert // not needed defer pool.Put(schemaDecoder) // Set alias tag schemaDecoder.SetAliasTag(aliasTag) if err := schemaDecoder.Decode(out, data, files...); err != nil { return fmt.Errorf("%w", err) } return nil } // Parse data into the map // thanks to https://github.com/gin-gonic/gin/blob/master/binding/binding.go func parseToMap(target reflect.Value, data map[string][]string) error { if !target.IsValid() { return ErrInvalidDestinationValue } if target.Kind() == reflect.Interface && !target.IsNil() { target = target.Elem() } if target.Kind() != reflect.Map || target.Type().Key().Kind() != reflect.String { return nil // nothing to do for non-map destinations } if target.IsNil() { if !target.CanSet() { return ErrMapNilDestination } target.Set(reflect.MakeMap(target.Type())) } switch target.Type().Elem().Kind() { case reflect.Slice: newMap, ok := target.Interface().(map[string][]string) if !ok { return ErrMapNotConvertible } maps.Copy(newMap, data) case reflect.String: newMap, ok := target.Interface().(map[string]string) if !ok { return ErrMapNotConvertible } for k, v := range data { if len(v) == 0 { newMap[k] = "" continue } newMap[k] = v[len(v)-1] } default: // Interface element maps (e.g. map[string]any) are left untouched because // the binder cannot safely infer element conversions without mutating // caller-provided values. These destinations therefore see a successful // no-op parse. 
return nil // it's not necessary to check all types } return nil } func parseParamSquareBrackets(k string) (string, error) { bb := bytebufferpool.Get() defer bytebufferpool.Put(bb) kbytes := []byte(k) openBracketsCount := 0 for i, b := range kbytes { if b == '[' { openBracketsCount++ if i+1 < len(kbytes) && kbytes[i+1] != ']' { if err := bb.WriteByte('.'); err != nil { return "", err //nolint:wrapcheck // unnecessary to wrap it } } continue } if b == ']' { openBracketsCount-- if openBracketsCount < 0 { return "", ErrUnmatchedBrackets } continue } if err := bb.WriteByte(b); err != nil { return "", err //nolint:wrapcheck // unnecessary to wrap it } } if openBracketsCount > 0 { return "", ErrUnmatchedBrackets } return bb.String(), nil } func isStringKeyMap(t reflect.Type) bool { return t.Kind() == reflect.Map && t.Key().Kind() == reflect.String } func isExported(f *reflect.StructField) bool { if f == nil { return false } return f.PkgPath == "" } func fieldName(f *reflect.StructField, aliasTag string) string { if f == nil { return "" } name := f.Tag.Get(aliasTag) if name == "" { name = f.Name } else if first, _, found := strings.Cut(name, ","); found { name = first } return utilsstrings.ToLower(name) } type fieldInfo struct { names map[string]reflect.Kind nestedKinds map[reflect.Kind]struct{} } func unwrapType(t reflect.Type) reflect.Type { for t.Kind() == reflect.Ptr { t = t.Elem() } return t } var ( headerFieldCache sync.Map respHeaderFieldCache sync.Map cookieFieldCache sync.Map queryFieldCache sync.Map formFieldCache sync.Map uriFieldCache sync.Map ) func getFieldCache(aliasTag string) *sync.Map { switch aliasTag { case "header": return &headerFieldCache case "respHeader": return &respHeaderFieldCache case "cookie": return &cookieFieldCache case "form": return &formFieldCache case "uri": return &uriFieldCache case "query": return &queryFieldCache } panic("unknown alias tag: " + aliasTag) } func buildFieldInfo(t reflect.Type, aliasTag string) fieldInfo { info := 
fieldInfo{ names: make(map[string]reflect.Kind), nestedKinds: make(map[reflect.Kind]struct{}), } for i := 0; i < t.NumField(); i++ { f := t.Field(i) if !isExported(&f) { continue } fieldType := unwrapType(f.Type) info.names[fieldName(&f, aliasTag)] = fieldType.Kind() if fieldType.Kind() == reflect.Struct { for j := 0; j < fieldType.NumField(); j++ { sf := fieldType.Field(j) if !isExported(&sf) { continue } nestedType := unwrapType(sf.Type) info.nestedKinds[nestedType.Kind()] = struct{}{} } } } return info } func equalFieldType(out any, kind reflect.Kind, key, aliasTag string) bool { typ := reflect.TypeOf(out).Elem() key = utilsstrings.ToLower(key) if isStringKeyMap(typ) { return true } if typ.Kind() != reflect.Struct { return false } cache := getFieldCache(aliasTag) val, ok := cache.Load(typ) if !ok { info := buildFieldInfo(typ, aliasTag) val, _ = cache.LoadOrStore(typ, info) } info, ok := val.(fieldInfo) if !ok { return false } if k, ok := info.names[key]; ok && k == kind { return true } if _, ok := info.nestedKinds[kind]; ok { return true } return false } // FilterFlags returns the media type value by trimming any parameters from a Content-Type header. 
func FilterFlags(content string) string { if i := strings.IndexAny(content, " ;"); i >= 0 { return content[:i] } return content } func formatBindData[T, K any](aliasTag string, out any, data map[string][]T, key string, value K, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay var err error if supportBracketNotation && strings.IndexByte(key, '[') >= 0 { key, err = parseParamSquareBrackets(key) if err != nil { return err } } switch v := any(value).(type) { case string: dataMap, ok := any(data).(map[string][]string) if !ok { return fmt.Errorf("unsupported value type: %T", value) } assignBindData(aliasTag, out, dataMap, key, v, enableSplitting) case []string: dataMap, ok := any(data).(map[string][]string) if !ok { return fmt.Errorf("unsupported value type: %T", value) } for _, val := range v { assignBindData(aliasTag, out, dataMap, key, val, enableSplitting) } case []*multipart.FileHeader: for _, val := range v { valT, ok := any(val).(T) if !ok { return fmt.Errorf("unsupported value type: %T", value) } data[key] = append(data[key], valT) } default: return fmt.Errorf("unsupported value type: %T", value) } return err } func assignBindData(aliasTag string, out any, data map[string][]string, key, value string, enableSplitting bool) { //nolint:revive // it's okay if enableSplitting && strings.IndexByte(value, ',') >= 0 && equalFieldType(out, reflect.Slice, key, aliasTag) { for v := range strings.SplitSeq(value, ",") { data[key] = append(data[key], v) } } else { data[key] = append(data[key], value) } } ================================================ FILE: binder/mapping_test.go ================================================ package binder import ( "fmt" "mime/multipart" "reflect" "strconv" "sync" "testing" "github.com/gofiber/schema" "github.com/stretchr/testify/require" ) func Test_EqualFieldType(t *testing.T) { t.Parallel() var out int require.False(t, equalFieldType(&out, reflect.Int, "key", "query")) var dummy struct{ f string } 
require.False(t, equalFieldType(&dummy, reflect.String, "key", "query")) var dummy2 struct{ f string } require.False(t, equalFieldType(&dummy2, reflect.String, "f", "query")) var user struct { Name string Address string `query:"address"` Age int `query:"AGE"` } require.True(t, equalFieldType(&user, reflect.String, "name", "query")) require.True(t, equalFieldType(&user, reflect.String, "Name", "query")) require.True(t, equalFieldType(&user, reflect.String, "address", "query")) require.True(t, equalFieldType(&user, reflect.String, "Address", "query")) require.True(t, equalFieldType(&user, reflect.Int, "AGE", "query")) require.True(t, equalFieldType(&user, reflect.Int, "age", "query")) var user2 struct { User struct { Name string Address string `query:"address"` Age int `query:"AGE"` } `query:"user"` } require.True(t, equalFieldType(&user2, reflect.String, "user.name", "query")) require.True(t, equalFieldType(&user2, reflect.String, "user.Name", "query")) require.True(t, equalFieldType(&user2, reflect.String, "user.address", "query")) require.True(t, equalFieldType(&user2, reflect.String, "user.Address", "query")) require.True(t, equalFieldType(&user2, reflect.Int, "user.AGE", "query")) require.True(t, equalFieldType(&user2, reflect.Int, "user.age", "query")) var pointerUser struct { Tags *[]string `query:"tags"` } require.True(t, equalFieldType(&pointerUser, reflect.Slice, "tags", "query")) type nested struct { Values []string `query:"values"` } var nestedWrapper struct { Nested *nested `query:"nested"` } require.True(t, equalFieldType(&nestedWrapper, reflect.Slice, "nested.values", "query")) type nestedPointerSlice struct { Values *[]string `query:"values"` } var nestedPointerWrapper struct { Nested *nestedPointerSlice `query:"nested"` } require.True(t, equalFieldType(&nestedPointerWrapper, reflect.Slice, "nested.values", "query")) } func Test_ParseParamSquareBrackets(t *testing.T) { t.Parallel() tests := []struct { err error input string expected string }{ { err: 
nil, input: "foo[bar]", expected: "foo.bar", }, { err: nil, input: "foo[bar][baz]", expected: "foo.bar.baz", }, { err: ErrUnmatchedBrackets, input: "foo[bar", expected: "", }, { err: ErrUnmatchedBrackets, input: "foo[bar][baz", expected: "", }, { err: ErrUnmatchedBrackets, input: "foo]bar[", expected: "", }, { err: nil, input: "foo[bar[baz]]", expected: "foo.bar.baz", }, { err: nil, input: "", expected: "", }, { err: nil, input: "[]", expected: "", }, { err: nil, input: "foo[]", expected: "foo", }, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { t.Parallel() result, err := parseParamSquareBrackets(tt.input) if tt.err != nil { require.ErrorIs(t, err, tt.err) } else { require.NoError(t, err) require.Equal(t, tt.expected, result) } }) } } func Test_parseToMap(t *testing.T) { t.Parallel() inputMap := map[string][]string{ "key1": {"value1", "value2"}, "key2": {"value3"}, "key3": {"value4"}, } // Test map[string]string m := make(map[string]string) err := parseToMap(reflect.ValueOf(m), inputMap) require.NoError(t, err) require.Equal(t, "value2", m["key1"]) require.Equal(t, "value3", m["key2"]) require.Equal(t, "value4", m["key3"]) // Test map[string][]string m2 := make(map[string][]string) err = parseToMap(reflect.ValueOf(m2), inputMap) require.NoError(t, err) require.Len(t, m2["key1"], 2) require.Contains(t, m2["key1"], "value1") require.Contains(t, m2["key1"], "value2") require.Len(t, m2["key2"], 1) require.Len(t, m2["key3"], 1) // Test map[string]any m3 := make(map[string]any) err = parseToMap(reflect.ValueOf(m3), inputMap) require.NoError(t, err) require.Empty(t, m3) var zeroStringMap map[string]string err = parseToMap(reflect.ValueOf(&zeroStringMap).Elem(), inputMap) require.NoError(t, err) require.Equal(t, "value2", zeroStringMap["key1"]) var zeroSliceMap map[string][]string err = parseToMap(reflect.ValueOf(&zeroSliceMap).Elem(), inputMap) require.NoError(t, err) require.Len(t, zeroSliceMap["key1"], 2) err = 
parseToMap(reflect.ValueOf(map[string]string(nil)), inputMap) require.ErrorIs(t, err, ErrMapNilDestination) } func Test_FilterFlags(t *testing.T) { t.Parallel() tests := []struct { input string expected string }{ { input: "text/javascript; charset=utf-8", expected: "text/javascript", }, { input: "text/javascript", expected: "text/javascript", }, { input: "text/javascript; charset=utf-8; foo=bar", expected: "text/javascript", }, { input: "text/javascript charset=utf-8", expected: "text/javascript", }, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { t.Parallel() result := FilterFlags(tt.input) require.Equal(t, tt.expected, result) }) } } func Benchmark_FilterFlags(b *testing.B) { b.ReportAllocs() cases := []string{ "text/javascript; charset=utf-8", "application/json", "text/plain; charset=utf-8; foo=bar", "text/javascript charset=utf-8", } for i := 0; i < b.N; i++ { _ = FilterFlags(cases[i&3]) } } func TestFormatBindData(t *testing.T) { t.Parallel() t.Run("string value with valid key", func(t *testing.T) { t.Parallel() out := struct{}{} data := make(map[string][]string) err := formatBindData("query", out, data, "name", "John", false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } if len(data["name"]) != 1 || data["name"][0] != "John" { t.Fatalf("expected data[\"name\"] = [John], got %v", data["name"]) } }) t.Run("unsupported value type", func(t *testing.T) { t.Parallel() out := struct{}{} data := make(map[string][]string) err := formatBindData("query", out, data, "age", 30, false, false) // int is unsupported if err == nil { t.Fatal("expected an error, got nil") } }) t.Run("bracket notation parsing error", func(t *testing.T) { t.Parallel() out := struct{}{} data := make(map[string][]string) err := formatBindData("query", out, data, "invalid[", "value", false, true) // malformed bracket notation if err == nil { t.Fatal("expected an error, got nil") } }) t.Run("handling multipart file headers", func(t *testing.T) { t.Parallel() out := 
struct{}{} data := make(map[string][]*multipart.FileHeader) files := []*multipart.FileHeader{ {Filename: "file1.txt"}, {Filename: "file2.txt"}, } err := formatBindData("query", out, data, "files", files, false, false) if err != nil { t.Fatalf("unexpected error: %v", err) } if len(data["files"]) != 2 { t.Fatalf("expected 2 files, got %d", len(data["files"])) } }) t.Run("type casting error", func(t *testing.T) { t.Parallel() out := struct{}{} data := map[string][]int{} // Incorrect type to force a casting error err := formatBindData("query", out, data, "key", "value", false, false) require.Equal(t, "unsupported value type: string", err.Error()) }) } func TestAssignBindData(t *testing.T) { t.Parallel() t.Run("splitting enabled with comma", func(t *testing.T) { t.Parallel() out := struct { Colors []string `query:"colors"` }{} data := make(map[string][]string) assignBindData("query", &out, data, "colors", "red,blue,green", true) require.Len(t, data["colors"], 3) }) t.Run("splitting disabled", func(t *testing.T) { t.Parallel() var out []string data := make(map[string][]string) assignBindData("query", out, data, "color", "red,blue", false) require.Len(t, data["color"], 1) }) } func Test_parseToStruct_MismatchedData(t *testing.T) { t.Parallel() type User struct { Name string `query:"name"` Age int `query:"age"` } data := map[string][]string{ "name": {"John"}, "age": {"invalidAge"}, } err := parseToStruct("query", &User{}, data) require.Error(t, err) require.EqualError(t, err, "schema: error converting value for \"age\"") } func Test_formatBindData_ErrorCases(t *testing.T) { t.Parallel() t.Run("unsupported value type int", func(t *testing.T) { t.Parallel() out := struct{}{} data := make(map[string][]string) err := formatBindData("query", out, data, "age", 30, false, false) // int is unsupported require.Error(t, err) require.EqualError(t, err, "unsupported value type: int") }) t.Run("unsupported value type map", func(t *testing.T) { t.Parallel() out := struct{}{} data := 
make(map[string][]string) err := formatBindData("query", out, data, "map", map[string]string{"key": "value"}, false, false) // map is unsupported require.Error(t, err) require.EqualError(t, err, "unsupported value type: map[string]string") }) t.Run("bracket notation parsing error", func(t *testing.T) { t.Parallel() out := struct{}{} data := make(map[string][]string) err := formatBindData("query", out, data, "invalid[", "value", false, true) // malformed bracket notation require.Error(t, err) require.EqualError(t, err, "unmatched brackets") }) t.Run("type casting error for []string", func(t *testing.T) { t.Parallel() out := struct{}{} data := make(map[string][]string) err := formatBindData("query", out, data, "names", 123, false, false) // invalid type for []string require.Error(t, err) require.EqualError(t, err, "unsupported value type: int") }) } func Test_decoderBuilder(t *testing.T) { t.Parallel() type customInt int conv := func(s string) reflect.Value { i, err := strconv.Atoi(s) if err != nil { panic(err) } return reflect.ValueOf(customInt(i)) } parserConfig := ParserConfig{ SetAliasTag: "custom", ParserType: []ParserType{{ CustomType: customInt(0), Converter: conv, }}, IgnoreUnknownKeys: false, ZeroEmpty: false, } decAny := decoderBuilder(parserConfig) dec, ok := decAny.(*schema.Decoder) require.True(t, ok) var out struct { X customInt `custom:"x"` } err := dec.Decode(&out, map[string][]string{"x": {"7"}}) require.NoError(t, err) require.Equal(t, customInt(7), out.X) } func Test_parseToMap_Extended(t *testing.T) { t.Parallel() data := map[string][]string{ "empty": {}, "key1": {"value1"}, } m := make(map[string]string) err := parseToMap(reflect.ValueOf(m), data) require.NoError(t, err) require.Empty(t, m["empty"]) m2 := make(map[string][]int) err = parseToMap(reflect.ValueOf(m2), data) require.ErrorIs(t, err, ErrMapNotConvertible) m3 := make(map[string]int) err = parseToMap(reflect.ValueOf(m3), data) require.NoError(t, err) } func Test_decoderPoolMapInit(t 
*testing.T) { t.Parallel() for _, tag := range tags { decAny := getDecoderPool(tag).Get() dec, ok := decAny.(*schema.Decoder) require.True(t, ok) require.NotNil(t, dec) getDecoderPool(tag).Put(decAny) } } func TestSetParserDecoderConcurrentAccess(t *testing.T) { t.Parallel() t.Cleanup(func() { SetParserDecoder(ParserConfig{ IgnoreUnknownKeys: true, ZeroEmpty: true, }) }) type queryUser struct { Name string `query:"name"` } data := map[string][]string{ "name": {"fiber"}, } parserConfig := ParserConfig{ IgnoreUnknownKeys: true, ZeroEmpty: true, } start := make(chan struct{}) const workers = 25 errCh := make(chan error, workers*2) var wg sync.WaitGroup runWorker := func(fn func() error) { wg.Go(func() { <-start defer func() { if r := recover(); r != nil { errCh <- fmt.Errorf("panic: %v", r) } }() if err := fn(); err != nil { errCh <- err } }) } for i := 0; i < workers; i++ { runWorker(func() error { SetParserDecoder(parserConfig) return nil }) runWorker(func() error { var out queryUser if err := parseToStruct("query", &out, data); err != nil { return err } if out.Name != "fiber" { return fmt.Errorf("unexpected name %q", out.Name) } return nil }) } close(start) wg.Wait() close(errCh) for err := range errCh { require.NoError(t, err) } } func Test_getFieldCache(t *testing.T) { t.Parallel() require.NotNil(t, getFieldCache("header")) require.NotNil(t, getFieldCache("respHeader")) require.NotNil(t, getFieldCache("cookie")) require.NotNil(t, getFieldCache("form")) require.NotNil(t, getFieldCache("uri")) require.NotNil(t, getFieldCache("query")) require.Panics(t, func() { getFieldCache("unknown") }) } func Test_EqualFieldType_Map(t *testing.T) { t.Parallel() m := map[string]int{} require.True(t, equalFieldType(&m, reflect.Int, "any", "query")) } func Test_equalFieldType_CacheTypeMismatch(t *testing.T) { type Sample struct { Field string `query:"field"` } cache := getFieldCache("query") typ := reflect.TypeOf(Sample{}) cache.Store(typ, 1) defer cache.Delete(typ) var s Sample 
require.False(t, equalFieldType(&s, reflect.String, "field", "query")) } func Test_buildFieldInfo_Unexported(t *testing.T) { t.Parallel() type nested struct { export int Exported int } _ = nested{export: 0} type outer struct { Name string Nested nested } info := buildFieldInfo(reflect.TypeOf(outer{}), "query") require.Contains(t, info.names, "name") _, ok := info.nestedKinds[reflect.Int] require.True(t, ok) } func Test_formatBindData_BracketNotationSuccess(t *testing.T) { t.Parallel() out := struct{}{} data := make(map[string][]string) err := formatBindData("query", out, data, "user[name]", "john", false, true) require.NoError(t, err) require.Equal(t, "john", data["user.name"][0]) } func Test_formatBindData_FileHeaderTypeMismatch(t *testing.T) { t.Parallel() out := struct{}{} data := map[string][]int{} files := []*multipart.FileHeader{{Filename: "file1.txt"}} err := formatBindData("query", out, data, "file", files, false, false) require.EqualError(t, err, "unsupported value type: []*multipart.FileHeader") } func Benchmark_equalFieldType(b *testing.B) { type Nested struct { Name string `query:"name"` } type User struct { Name string `query:"name"` Nested Nested `query:"user"` Age int `query:"age"` } var user User b.ReportAllocs() for b.Loop() { equalFieldType(&user, reflect.String, "name", "query") equalFieldType(&user, reflect.Int, "age", "query") equalFieldType(&user, reflect.String, "user.name", "query") } } ================================================ FILE: binder/msgpack.go ================================================ package binder import ( "github.com/gofiber/utils/v2" ) // MsgPackBinding is the MsgPack binder for MsgPack request body. type MsgPackBinding struct { MsgPackDecoder utils.MsgPackUnmarshal } // Name returns the binding name. func (*MsgPackBinding) Name() string { return "msgpack" } // Bind parses the request body as MsgPack and returns the result. 
func (b *MsgPackBinding) Bind(body []byte, out any) error { return b.MsgPackDecoder(body, out) } // Reset resets the MsgPackBinding binder. func (b *MsgPackBinding) Reset() { b.MsgPackDecoder = nil } // UnimplementedMsgpackMarshal panics to signal that a Msgpack marshaler must // be configured before MsgPack support can be used. func UnimplementedMsgpackMarshal(_ any) ([]byte, error) { panic("Must explicit setup Msgpack, please check docs: https://docs.gofiber.io/next/guide/advance-format#msgpack") } // UnimplementedMsgpackUnmarshal panics to signal that a Msgpack unmarshaler // must be configured before MsgPack support can be used. func UnimplementedMsgpackUnmarshal(_ []byte, _ any) error { panic("Must explicit setup Msgpack, please check docs: https://docs.gofiber.io/next/guide/advance-format#msgpack") } ================================================ FILE: binder/msgpack_test.go ================================================ package binder import ( "testing" "github.com/shamaton/msgpack/v3" "github.com/stretchr/testify/require" ) func Test_Msgpack_Binding_Bind(t *testing.T) { t.Parallel() b := &MsgPackBinding{ MsgPackDecoder: msgpack.Unmarshal, } require.Equal(t, "msgpack", b.Name()) type Post struct { Title string `msgpack:"title"` } type User struct { Name string `msgpack:"name"` Posts []Post `msgpack:"posts"` Age int `msgpack:"age"` } var user User // Prepare msgpack data input := map[string]any{ "name": "john", "age": 42, "posts": []map[string]any{ {"title": "post1"}, {"title": "post2"}, {"title": "post3"}, }, } data, err := msgpack.Marshal(input) require.NoError(t, err) err = b.Bind(data, &user) require.NoError(t, err) require.Equal(t, "john", user.Name) require.Equal(t, 42, user.Age) require.Len(t, user.Posts, 3) require.Equal(t, "post1", user.Posts[0].Title) require.Equal(t, "post2", user.Posts[1].Title) require.Equal(t, "post3", user.Posts[2].Title) b.Reset() require.Nil(t, b.MsgPackDecoder) } func Benchmark_Msgpack_Binding_Bind(b *testing.B) { 
b.ReportAllocs() binder := &MsgPackBinding{ MsgPackDecoder: msgpack.Unmarshal, } type User struct { Name string `msgpack:"name"` Posts []string `msgpack:"posts"` Age int `msgpack:"age"` } var user User var err error for b.Loop() { // {"name":"john","age":42,"posts":[{"title":"post1"},{"title":"post2"},{"title":"post3"}]} err = binder.Bind([]byte{ 0x83, 0xa4, 0x6e, 0x61, 0x6d, 0x65, 0xa4, 0x6a, 0x6f, 0x68, 0x6e, 0xa3, 0x61, 0x67, 0x65, 0x2a, 0xa5, 0x70, 0x6f, 0x73, 0x74, 0x73, 0x93, 0xa5, 0x70, 0x6f, 0x73, 0x74, 0x31, 0xa5, 0x70, 0x6f, 0x73, 0x74, 0x32, 0xa5, 0x70, 0x6f, 0x73, 0x74, 0x33, }, &user) } require.NoError(b, err) require.Equal(b, "john", user.Name) require.Equal(b, 42, user.Age) require.Len(b, user.Posts, 3) require.Equal(b, "post1", user.Posts[0]) require.Equal(b, "post2", user.Posts[1]) require.Equal(b, "post3", user.Posts[2]) } func Test_UnimplementedMsgpackMarshal_Panics(t *testing.T) { t.Parallel() require.Panics(t, func() { _, err := UnimplementedMsgpackMarshal(struct{ Name string }{Name: "test"}) require.NoError(t, err) }) } func Test_UnimplementedMsgpackUnmarshal_Panics(t *testing.T) { t.Parallel() require.Panics(t, func() { var out any err := UnimplementedMsgpackUnmarshal([]byte{0x80}, &out) require.NoError(t, err) }) } func Test_UnimplementedMsgpackMarshal_PanicMessage(t *testing.T) { t.Parallel() defer func() { if r := recover(); r != nil { require.Contains(t, r, "Must explicit setup Msgpack") } }() _, err := UnimplementedMsgpackMarshal(struct{ Name string }{Name: "test"}) require.NoError(t, err) } func Test_UnimplementedMsgpackUnmarshal_PanicMessage(t *testing.T) { t.Parallel() defer func() { if r := recover(); r != nil { require.Contains(t, r, "Must explicit setup Msgpack") } }() var out any err := UnimplementedMsgpackUnmarshal([]byte{0x80}, &out) require.NoError(t, err) } ================================================ FILE: binder/query.go ================================================ package binder import ( 
"github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) // QueryBinding is the query binder for query request body. type QueryBinding struct { EnableSplitting bool } // Name returns the binding name. func (*QueryBinding) Name() string { return "query" } // Bind parses the request query and returns the result. func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error { data := make(map[string][]string) for key, val := range reqCtx.URI().QueryArgs().All() { k := utils.UnsafeString(key) v := utils.UnsafeString(val) if err := formatBindData(b.Name(), out, data, k, v, b.EnableSplitting, true); err != nil { return err } } return parse(b.Name(), out, data) } // Reset resets the QueryBinding binder. func (b *QueryBinding) Reset() { b.EnableSplitting = false } ================================================ FILE: binder/query_test.go ================================================ package binder import ( "testing" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func Test_QueryBinder_Bind(t *testing.T) { t.Parallel() b := &QueryBinding{ EnableSplitting: true, } require.Equal(t, "query", b.Name()) type Post struct { Title string `query:"title"` } type User struct { Name string `query:"name"` Names []string `query:"names"` Posts []Post `query:"posts"` Age int `query:"age"` } var user User req := fasthttp.AcquireRequest() req.URI().SetQueryString("name=john&names=john,doe&age=42&posts[0][title]=post1&posts[1][title]=post2&posts[2][title]=post3") t.Cleanup(func() { fasthttp.ReleaseRequest(req) }) err := b.Bind(req, &user) require.NoError(t, err) require.Equal(t, "john", user.Name) require.Equal(t, 42, user.Age) require.Len(t, user.Posts, 3) require.Equal(t, "post1", user.Posts[0].Title) require.Equal(t, "post2", user.Posts[1].Title) require.Equal(t, "post3", user.Posts[2].Title) require.Contains(t, user.Names, "john") require.Contains(t, user.Names, "doe") b.Reset() require.False(t, b.EnableSplitting) } func Benchmark_QueryBinder_Bind(b 
*testing.B) { b.ReportAllocs() binder := &QueryBinding{ EnableSplitting: true, } type User struct { Name string `query:"name"` Posts []string `query:"posts"` Age int `query:"age"` } var user User req := fasthttp.AcquireRequest() b.Cleanup(func() { fasthttp.ReleaseRequest(req) }) req.URI().SetQueryString("name=john&age=42&posts=post1,post2,post3") var err error for b.Loop() { err = binder.Bind(req, &user) } require.NoError(b, err) require.Equal(b, "john", user.Name) require.Equal(b, 42, user.Age) require.Len(b, user.Posts, 3) require.Contains(b, user.Posts, "post1") require.Contains(b, user.Posts, "post2") require.Contains(b, user.Posts, "post3") } func Test_QueryBinder_Bind_PointerSlices(t *testing.T) { t.Parallel() binder := &QueryBinding{ EnableSplitting: true, } type Preferences struct { Tags *[]string `query:"tags"` } type Profile struct { Emails *[]string `query:"emails"` Prefs *Preferences `query:"preferences"` } var profile Profile req := fasthttp.AcquireRequest() req.URI().SetQueryString("emails=work,personal&preferences[tags]=golang,api") t.Cleanup(func() { fasthttp.ReleaseRequest(req) }) err := binder.Bind(req, &profile) require.NoError(t, err) require.NotNil(t, profile.Emails) require.ElementsMatch(t, []string{"work", "personal"}, *profile.Emails) require.NotNil(t, profile.Prefs) require.NotNil(t, profile.Prefs.Tags) require.ElementsMatch(t, []string{"golang", "api"}, *profile.Prefs.Tags) } ================================================ FILE: binder/resp_header.go ================================================ package binder import ( "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) // RespHeaderBinding is the respHeader binder for response header. type RespHeaderBinding struct { EnableSplitting bool } // Name returns the binding name. func (*RespHeaderBinding) Name() string { return "respHeader" } // Bind parses the response header and returns the result. 
func (b *RespHeaderBinding) Bind(resp *fasthttp.Response, out any) error { data := make(map[string][]string) for key, val := range resp.Header.All() { k := utils.UnsafeString(key) v := utils.UnsafeString(val) if err := formatBindData(b.Name(), out, data, k, v, b.EnableSplitting, false); err != nil { return err } } return parse(b.Name(), out, data) } // Reset resets the RespHeaderBinding binder. func (b *RespHeaderBinding) Reset() { b.EnableSplitting = false } ================================================ FILE: binder/resp_header_test.go ================================================ package binder import ( "testing" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func Test_RespHeaderBinder_Bind(t *testing.T) { t.Parallel() b := &RespHeaderBinding{ EnableSplitting: true, } require.Equal(t, "respHeader", b.Name()) type User struct { Name string `respHeader:"name"` Posts []string `respHeader:"posts"` Age int `respHeader:"age"` } var user User resp := fasthttp.AcquireResponse() resp.Header.Set("name", "john") resp.Header.Set("age", "42") resp.Header.Set("posts", "post1,post2,post3") t.Cleanup(func() { fasthttp.ReleaseResponse(resp) }) err := b.Bind(resp, &user) require.NoError(t, err) require.Equal(t, "john", user.Name) require.Equal(t, 42, user.Age) require.Equal(t, []string{"post1", "post2", "post3"}, user.Posts) b.Reset() require.False(t, b.EnableSplitting) } func Benchmark_RespHeaderBinder_Bind(b *testing.B) { b.ReportAllocs() binder := &RespHeaderBinding{ EnableSplitting: true, } type User struct { Name string `respHeader:"name"` Posts []string `respHeader:"posts"` Age int `respHeader:"age"` } var user User resp := fasthttp.AcquireResponse() resp.Header.Set("name", "john") resp.Header.Set("age", "42") resp.Header.Set("posts", "post1,post2,post3") b.Cleanup(func() { fasthttp.ReleaseResponse(resp) }) var err error for b.Loop() { err = binder.Bind(resp, &user) } require.NoError(b, err) require.Equal(b, "john", user.Name) require.Equal(b, 
42, user.Age) require.Equal(b, []string{"post1", "post2", "post3"}, user.Posts) } func Test_RespHeaderBinder_Bind_ParseError(t *testing.T) { b := &RespHeaderBinding{} type User struct { Age int `respHeader:"age"` } var user User resp := fasthttp.AcquireResponse() resp.Header.Set("age", "invalid") t.Cleanup(func() { fasthttp.ReleaseResponse(resp) }) err := b.Bind(resp, &user) require.Error(t, err) } ================================================ FILE: binder/uri.go ================================================ package binder // URIBinding is the binder implementation for populating values from route parameters. type URIBinding struct{} // Name returns the binding name. func (*URIBinding) Name() string { return "uri" } // Bind parses the URI parameters and returns the result. func (b *URIBinding) Bind(params []string, paramsFunc func(key string, defaultValue ...string) string, out any) error { data := make(map[string][]string, len(params)) for _, param := range params { data[param] = append(data[param], paramsFunc(param)) } return parse(b.Name(), out, data) } // Reset resets URIBinding binder. 
func (*URIBinding) Reset() {
	// Nothing to reset
}

================================================
FILE: binder/uri_test.go
================================================
package binder

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// Test_URIBinding_Bind binds route parameters served by a stub lookup function.
// The "posts" value stays a single element because URIBinding does not split on commas.
func Test_URIBinding_Bind(t *testing.T) {
	t.Parallel()

	b := &URIBinding{}
	require.Equal(t, "uri", b.Name())

	type User struct {
		Name  string   `uri:"name"`
		Posts []string `uri:"posts"`
		Age   int      `uri:"age"`
	}
	var user User

	paramsKey := []string{"name", "age", "posts"}
	paramsVals := []string{"john", "42", "post1,post2,post3"}
	// paramsFunc mimics the router's parameter lookup, returning "" for unknown keys.
	paramsFunc := func(key string, _ ...string) string {
		for i, k := range paramsKey {
			if k == key {
				return paramsVals[i]
			}
		}

		return ""
	}

	err := b.Bind(paramsKey, paramsFunc, &user)

	require.NoError(t, err)
	require.Equal(t, "john", user.Name)
	require.Equal(t, 42, user.Age)
	require.Equal(t, []string{"post1,post2,post3"}, user.Posts)

	b.Reset()
}

// Benchmark_URIBinding_Bind measures repeated binding through the stub lookup.
func Benchmark_URIBinding_Bind(b *testing.B) {
	b.ReportAllocs()

	binder := &URIBinding{}

	type User struct {
		Name  string   `uri:"name"`
		Posts []string `uri:"posts"`
		Age   int      `uri:"age"`
	}
	var user User

	paramsKey := []string{"name", "age", "posts"}
	paramsVals := []string{"john", "42", "post1,post2,post3"}
	paramsFunc := func(key string, _ ...string) string {
		for i, k := range paramsKey {
			if k == key {
				return paramsVals[i]
			}
		}

		return ""
	}

	var err error
	for b.Loop() {
		err = binder.Bind(paramsKey, paramsFunc, &user)
	}

	require.NoError(b, err)
	require.Equal(b, "john", user.Name)
	require.Equal(b, 42, user.Age)
	require.Equal(b, []string{"post1,post2,post3"}, user.Posts)
}

================================================
FILE: binder/xml.go
================================================
package binder

import (
	"fmt"

	"github.com/gofiber/utils/v2"
)

// XMLBinding decodes an XML request body using a pluggable decoder function.
type XMLBinding struct {
	// XMLDecoder performs the actual unmarshalling; it must be non-nil when Bind runs.
	XMLDecoder utils.XMLUnmarshal
}

// Name returns the binding name ("xml").
func (*XMLBinding) Name() string {
	return "xml"
}

// Bind decodes the XML document in body into out using the configured
// XMLDecoder, wrapping any decoder error as "failed to unmarshal xml: ...".
//
// NOTE(review): XMLDecoder is called unconditionally and Reset sets it to nil,
// so calling Bind on a reset binder without assigning a decoder again panics
// with a nil function call.
func (b *XMLBinding) Bind(body []byte, out any) error {
	if err := b.XMLDecoder(body, out); err != nil {
		return fmt.Errorf("failed to unmarshal xml: %w", err)
	}

	return nil
}

// Reset clears XMLDecoder; a decoder must be assigned again before the next Bind.
func (b *XMLBinding) Reset() {
	b.XMLDecoder = nil
}

================================================
FILE: binder/xml_test.go
================================================
package binder

import (
	"encoding/xml"
	"testing"

	"github.com/stretchr/testify/require"
)

// Test_XMLBinding_Bind checks field mapping (including the ignored `xml:"-"` field
// and the nested posts>post path) and that Reset clears XMLDecoder.
//
// NOTE(review): in this copy of the file the XML payload below has lost its element
// tags (only text nodes remain), so it cannot satisfy the assertions; restore the
// original markup from version control.
func Test_XMLBinding_Bind(t *testing.T) {
	t.Parallel()

	b := &XMLBinding{
		XMLDecoder: xml.Unmarshal,
	}
	require.Equal(t, "xml", b.Name())

	type Posts struct {
		XMLName xml.Name `xml:"post"`
		Title   string   `xml:"title"`
	}

	type User struct {
		Name   string  `xml:"name"`
		Ignore string  `xml:"-"`
		Posts  []Posts `xml:"posts>post"`
		Age    int     `xml:"age"`
	}

	user := new(User)
	err := b.Bind([]byte(` john 42 ignore post1 post2 `), user)
	require.NoError(t, err)
	require.Equal(t, "john", user.Name)
	require.Equal(t, 42, user.Age)
	require.Empty(t, user.Ignore)
	require.Len(t, user.Posts, 2)
	require.Equal(t, "post1", user.Posts[0].Title)
	require.Equal(t, "post2", user.Posts[1].Title)

	b.Reset()
	require.Nil(t, b.XMLDecoder)
}

// NOTE(review): the next line is extraction residue. Tag stripping removed the end of
// Test_XMLBinding_Bind_error and the head of Benchmark_XMLBinding_Bind, fusing the two
// (the benchmark body now references `b`/`binder` from a missing scope). Kept verbatim;
// restore both functions from version control rather than reconstructing them.
func Test_XMLBinding_Bind_error(t *testing.T) { t.Parallel() b := &XMLBinding{ XMLDecoder: xml.Unmarshal, } type User struct { Name string `xml:"name"` Age int `xml:"age"` } user := new(User) err := b.Bind([]byte(` john 42 unknown post"` Age int `xml:"age"` } user := new(User) data := []byte(` john 42 ignore post1 post2 `) var err error for b.Loop() { err = binder.Bind(data, user) } require.NoError(b, err) user = new(User) err = binder.Bind(data, user) require.NoError(b, err) require.Equal(b, "john", user.Name) require.Equal(b, 42, user.Age) require.Len(b, user.Posts, 2) require.Equal(b, "post1", user.Posts[0].Title) require.Equal(b, "post2", user.Posts[1].Title) }
================================================ FILE: client/README.md ================================================

Fiber Client

Easy-to-use HTTP client based on fasthttp (inspired by resty and axios)

The Features section below describes the Fiber Client's capabilities in detail.

## Features

> This section has not been fully written yet.

- GET, POST, PUT, DELETE, HEAD, PATCH, OPTIONS, etc.
- Simple, chainable methods for client settings and requests
- Request Body can be `string`, `[]byte`, `map`, `slice`
  - Auto-detects `Content-Type`
  - Buffer processing for `files`
  - Native `*fasthttp.Request` instance can be accessed during middleware and request execution via `Request.RawRequest`
  - Request Body can be read multiple times via `Request.RawRequest.GetBody()`
- Response object gives you more possibilities
  - Access as `[]byte` by `response.Body()` or access as `string` by `response.String()`
- Automatic marshal and unmarshal for JSON and XML content types
  - JSON is the default if you supply a struct/map without a `Content-Type` header
  - For auto-unmarshal, refer to:
    - Success scenario: `Request.SetResult()` and `Response.Result()`.
    - Error scenario: `Request.SetError()` and `Response.Error()`.
    - Supports RFC7807 - `application/problem+json` & `application/problem+xml`
- Provides an option to override JSON Marshal/Unmarshal and XML Marshal/Unmarshal

## Usage

The following samples will help you become comfortable with the `Fiber Client` library.

```go
// Import Fiber Client into your code and refer to it as `client`.
import "github.com/gofiber/fiber/v3/client"
```

### Simple GET

================================================
FILE: client/client.go
================================================
// Package client exposes Fiber's HTTP client built on top of fasthttp.
//
// It allows constructing new clients or wrapping existing fasthttp transports
// so applications can share pools, dialers, and TLS settings between Fiber and
// lower-level fasthttp integrations.
package client

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"encoding/xml"
	"errors"
	"io"
	"os"
	"path/filepath"
	"sync"
	"time"

	"github.com/fxamacker/cbor/v2"
	"github.com/gofiber/fiber/v3/log"
	"github.com/gofiber/utils/v2"
	"github.com/valyala/fasthttp"
	"github.com/valyala/fasthttp/fasthttpproxy"
)

// ErrFailedToAppendCert is the error passed to the client logger's Panicf when PEM
// data contains no certificate that could be added to the root CA pool.
var ErrFailedToAppendCert = errors.New("failed to append certificate")

// Client provides Fiber's high-level HTTP API while delegating transport work
// to fasthttp.Client, fasthttp.HostClient, or fasthttp.LBClient implementations.
//
// Settings configured on the client are shared across every request and may be
// overridden per request when needed.
type Client struct {
	// logger is used, for example, to Panicf on certificate-loading failures; see SetLogger.
	logger log.CommonLogger
	// transport is the wrapped fasthttp client that all Do* methods delegate to.
	transport httpClientTransport

	// Client-wide defaults for headers, query params, cookies, and path params.
	header  *Header
	params  *QueryParam
	cookies *Cookie
	path    *PathParam

	// Pluggable encoders/decoders for JSON, XML, and CBOR bodies (stdlib/cbor defaults set in newClient).
	jsonMarshal   utils.JSONMarshal
	jsonUnmarshal utils.JSONUnmarshal
	xmlMarshal    utils.XMLMarshal
	xmlUnmarshal  utils.XMLUnmarshal
	cborMarshal   utils.CBORMarshal
	cborUnmarshal utils.CBORUnmarshal

	// cookieJar is optional; Reset releases it and sets it to nil.
	cookieJar   *CookieJar
	retryConfig *RetryConfig

	baseURL   string
	userAgent string
	referer   string

	// User hooks are appended via AddRequestHook/AddResponseHook; builtin hooks
	// are installed by newClient.
	userRequestHooks     []RequestHook
	builtinRequestHooks  []RequestHook
	userResponseHooks    []ResponseHook
	builtinResponseHooks []ResponseHook

	timeout time.Duration

	// mu guards hook registration and TLS/dial/logger/retry updates. Several plain
	// setters and getters (e.g. SetBaseURL, SetTimeout, RetryConfig) do not take it.
	mu sync.RWMutex

	debug                  bool
	disablePathNormalizing bool
}

// Do executes the request using the underlying fasthttp transport.
//
// It mirrors [fasthttp.Client.Do], [fasthttp.HostClient.Do], or
// [fasthttp.LBClient.Do] depending on how the Fiber client was constructed.
func (c *Client) Do(req *fasthttp.Request, resp *fasthttp.Response) error {
	return c.transport.Do(req, resp)
}

// DoTimeout executes the request and waits for a response up to the provided timeout.
// It mirrors the behavior of the respective fasthttp client's DoTimeout implementation.
func (c *Client) DoTimeout(req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration) error {
	// Forwarded as-is; the concrete transport decides how the timeout is enforced.
	t := c.transport
	return t.DoTimeout(req, resp, timeout)
}

// DoDeadline sends req through the underlying transport and fills resp, waiting no
// later than deadline. It forwards to the wrapped fasthttp client's DoDeadline.
func (c *Client) DoDeadline(req *fasthttp.Request, resp *fasthttp.Response, deadline time.Time) error {
	t := c.transport
	return t.DoDeadline(req, resp, deadline)
}

// DoRedirects sends req and follows at most maxRedirects redirects via the
// underlying transport.
func (c *Client) DoRedirects(req *fasthttp.Request, resp *fasthttp.Response, maxRedirects int) error {
	t := c.transport
	return t.DoRedirects(req, resp, maxRedirects)
}

// CloseIdleConnections asks the underlying transport to drop idle connections,
// where the wrapped fasthttp client supports it.
func (c *Client) CloseIdleConnections() {
	t := c.transport
	t.CloseIdleConnections()
}

// currentTLSConfig reports the TLS configuration held by the transport (possibly nil).
func (c *Client) currentTLSConfig() *tls.Config {
	t := c.transport
	return t.TLSConfig()
}

// applyTLSConfig installs config on the transport.
func (c *Client) applyTLSConfig(config *tls.Config) {
	t := c.transport
	t.SetTLSConfig(config)
}

// applyDial installs dial as the transport's dial function.
func (c *Client) applyDial(dial fasthttp.DialFunc) {
	t := c.transport
	t.SetDial(dial)
}

// FasthttpClient exposes the wrapped *fasthttp.Client, or nil when the client
// was built around a different transport.
func (c *Client) FasthttpClient() *fasthttp.Client {
	std, ok := c.transport.(*standardClientTransport)
	if !ok {
		return nil
	}
	return std.client
}

// HostClient exposes the wrapped *fasthttp.HostClient, or nil when the client
// was built around a different transport.
func (c *Client) HostClient() *fasthttp.HostClient {
	host, ok := c.transport.(*hostClientTransport)
	if !ok {
		return nil
	}
	return host.client
}

// LBClient exposes the wrapped *fasthttp.LBClient, or nil when the client
// was built around a different transport.
func (c *Client) LBClient() *fasthttp.LBClient {
	lb, ok := c.transport.(*lbClientTransport)
	if !ok {
		return nil
	}
	return lb.client
}

// R returns a new Request, obtained from AcquireRequest, bound to this client.
func (c *Client) R() *Request {
	req := AcquireRequest()
	return req.SetClient(c)
}

// RequestHook returns the user-registered request hooks.
//
// The client's internal slice is returned directly (not a copy) and is read
// without acquiring c.mu.
func (c *Client) RequestHook() []RequestHook {
	return c.userRequestHooks
}

// AddRequestHook appends user request hooks under the client's write lock and
// returns the client for chaining.
func (c *Client) AddRequestHook(h ...RequestHook) *Client {
	c.mu.Lock()
	c.userRequestHooks = append(c.userRequestHooks, h...)
	c.mu.Unlock()

	return c
}

// ResponseHook returns the user-registered response hooks.
//
// The client's internal slice is returned directly (not a copy) and is read
// without acquiring c.mu.
func (c *Client) ResponseHook() []ResponseHook {
	return c.userResponseHooks
}

// AddResponseHook appends user response hooks under the client's write lock and
// returns the client for chaining.
func (c *Client) AddResponseHook(h ...ResponseHook) *Client {
	c.mu.Lock()
	c.userResponseHooks = append(c.userResponseHooks, h...)
	c.mu.Unlock()

	return c
}

// JSONMarshal reports the function used to encode JSON bodies.
func (c *Client) JSONMarshal() utils.JSONMarshal {
	return c.jsonMarshal
}

// SetJSONMarshal replaces the JSON encoder and returns the client for chaining.
func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client {
	c.jsonMarshal = f
	return c
}

// JSONUnmarshal reports the function used to decode JSON bodies.
func (c *Client) JSONUnmarshal() utils.JSONUnmarshal {
	return c.jsonUnmarshal
}

// SetJSONUnmarshal replaces the JSON decoder and returns the client for chaining.
func (c *Client) SetJSONUnmarshal(f utils.JSONUnmarshal) *Client {
	c.jsonUnmarshal = f
	return c
}

// XMLMarshal reports the function used to encode XML bodies.
func (c *Client) XMLMarshal() utils.XMLMarshal {
	return c.xmlMarshal
}

// SetXMLMarshal replaces the XML encoder and returns the client for chaining.
func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client {
	c.xmlMarshal = f
	return c
}

// XMLUnmarshal reports the function used to decode XML bodies.
func (c *Client) XMLUnmarshal() utils.XMLUnmarshal {
	return c.xmlUnmarshal
}

// SetXMLUnmarshal replaces the XML decoder and returns the client for chaining.
func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client {
	c.xmlUnmarshal = f
	return c
}

// CBORMarshal reports the function used to encode CBOR bodies.
func (c *Client) CBORMarshal() utils.CBORMarshal {
	return c.cborMarshal
}

// SetCBORMarshal replaces the CBOR encoder and returns the client for chaining.
func (c *Client) SetCBORMarshal(f utils.CBORMarshal) *Client {
	c.cborMarshal = f
	return c
}

// CBORUnmarshal reports the function used to decode CBOR bodies.
func (c *Client) CBORUnmarshal() utils.CBORUnmarshal {
	return c.cborUnmarshal
}

// SetCBORUnmarshal replaces the CBOR decoder and returns the client for chaining.
func (c *Client) SetCBORUnmarshal(f utils.CBORUnmarshal) *Client {
	c.cborUnmarshal = f
	return c
}

// TLSConfig returns the transport's TLS configuration, lazily installing a fresh
// config with a TLS 1.2 minimum when none is present. Runs under the client's
// write lock so the lazy initialisation happens at most once.
func (c *Client) TLSConfig() *tls.Config {
	c.mu.Lock()
	defer c.mu.Unlock()

	cfg := c.currentTLSConfig()
	if cfg == nil {
		cfg = &tls.Config{MinVersion: tls.VersionTLS12}
		c.applyTLSConfig(cfg)
	}
	return cfg
}

// SetTLSConfig installs config on the underlying transport under the client's lock.
func (c *Client) SetTLSConfig(config *tls.Config) *Client {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.applyTLSConfig(config)
	return c
}

// SetCertificates appends certs to the client certificates of the (possibly newly
// created) TLS configuration.
func (c *Client) SetCertificates(certs ...tls.Certificate) *Client {
	cfg := c.TLSConfig()
	cfg.Certificates = append(cfg.Certificates, certs...)
	return c
}

// SetRootCertificate reads PEM-encoded certificates from the file at path and adds
// them to the client's root CA pool.
func (c *Client) SetRootCertificate(path string) *Client {
	f, err := os.Open(filepath.Clean(path))
	if err != nil {
		c.logger.Panicf("client: %v", err)
	}
	// Registered before reading so the file is closed on every path, including panics.
	defer func() {
		if closeErr := f.Close(); closeErr != nil {
			c.logger.Panicf("client: failed to close file: %v", closeErr)
		}
	}()

	pemBytes, err := io.ReadAll(f)
	if err != nil {
		c.logger.Panicf("client: %v", err)
	}

	// Pool creation and the append/failure handling are shared with the string variant.
	c.SetRootCertificateFromString(string(pemBytes))
	return c
}

// SetRootCertificateFromString adds the PEM-encoded certificates in pem to the
// client's root CA pool, creating the pool if needed. The logger's Panicf is
// invoked with ErrFailedToAppendCert when no certificate could be parsed.
func (c *Client) SetRootCertificateFromString(pem string) *Client {
	cfg := c.TLSConfig()
	if cfg.RootCAs == nil {
		cfg.RootCAs = x509.NewCertPool()
	}

	if ok := cfg.RootCAs.AppendCertsFromPEM([]byte(pem)); !ok {
		c.logger.Panicf("client: %v", ErrFailedToAppendCert)
	}
	return c
}

// SetProxyURL routes all subsequent requests through the HTTP proxy at proxyURL by
// replacing the transport's dial function. The returned error is currently always nil.
func (c *Client) SetProxyURL(proxyURL string) error {
	dial := fasthttpproxy.FasthttpHTTPDialer(proxyURL)

	c.mu.Lock()
	defer c.mu.Unlock()

	c.applyDial(dial)
	return nil
}

// RetryConfig reports the current retry configuration (nil when unset).
func (c *Client) RetryConfig() *RetryConfig {
	return c.retryConfig
}

// SetRetryConfig stores the retry configuration under the client's lock.
func (c *Client) SetRetryConfig(config *RetryConfig) *Client {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.retryConfig = config
	return c
}

// BaseURL reports the base URL prefix applied to requests.
func (c *Client) BaseURL() string {
	return c.baseURL
}

// SetBaseURL sets the base URL prefix for all requests made by the client.
func (c *Client) SetBaseURL(url string) *Client {
	c.baseURL = url
	return c
}

// Header returns every client-level value stored for the header key.
func (c *Client) Header(key string) []string {
	values := c.header.PeekMultiple(key)
	return values
}

// AddHeader appends val to the client-level header key, keeping existing values.
// Client-level headers apply to every request.
func (c *Client) AddHeader(key, val string) *Client {
	c.header.Add(key, val)
	return c
}

// SetHeader replaces the client-level header key with val.
func (c *Client) SetHeader(key, val string) *Client {
	c.header.Set(key, val)
	return c
}

// AddHeaders appends every key/values pair of h to the client-level headers.
func (c *Client) AddHeaders(h map[string][]string) *Client {
	c.header.AddHeaders(h)
	return c
}

// SetHeaders replaces the client-level headers named in h with the given values.
// These headers are applied to all requests created from this client and can be
// overridden by request-level headers.
func (c *Client) SetHeaders(h map[string]string) *Client {
	c.header.SetHeaders(h)
	return c
}

// Param returns every client-level value of the query parameter key.
// The strings are zero-copy views (utils.UnsafeString) of the stored bytes.
func (c *Client) Param(key string) []string {
	raw := c.params.PeekMulti(key)
	values := make([]string, len(raw))
	for i := range raw {
		values[i] = utils.UnsafeString(raw[i])
	}
	return values
}

// AddParam appends val to the client-level query parameter key.
// Client-level params are applied to all requests created from this client.
func (c *Client) AddParam(key, val string) *Client {
	c.params.Add(key, val)
	return c
}

// SetParam replaces the client-level query parameter key with val.
func (c *Client) SetParam(key, val string) *Client {
	c.params.Set(key, val)
	return c
}

// AddParams appends every key/values pair of m to the client-level query parameters.
func (c *Client) AddParams(m map[string][]string) *Client {
	c.params.AddParams(m)
	return c
}

// SetParams replaces the client-level query parameters named in m.
func (c *Client) SetParams(m map[string]string) *Client { c.params.SetParams(m) return c } // SetParamsWithStruct sets multiple query parameters and their values using a struct. func (c *Client) SetParamsWithStruct(v any) *Client { c.params.SetParamsWithStruct(v) return c } // DelParams deletes one or more query parameters and their values from the client. func (c *Client) DelParams(key ...string) *Client { for _, v := range key { c.params.Del(v) } return c } // SetUserAgent sets the User-Agent header for the client. func (c *Client) SetUserAgent(ua string) *Client { c.userAgent = ua return c } // SetReferer sets the Referer header for the client. func (c *Client) SetReferer(r string) *Client { c.referer = r return c } // DisablePathNormalizing reports whether path normalizing is disabled for the client. func (c *Client) DisablePathNormalizing() bool { return c.disablePathNormalizing } // SetDisablePathNormalizing configures the client to disable or enable path normalizing. func (c *Client) SetDisablePathNormalizing(disable bool) *Client { c.disablePathNormalizing = disable return c } // PathParam returns the value of the specified path parameter. Returns an empty string if it does not exist. func (c *Client) PathParam(key string) string { if val, ok := (*c.path)[key]; ok { return val } return "" } // SetPathParam sets a single path parameter and its value in the client. func (c *Client) SetPathParam(key, val string) *Client { c.path.SetParam(key, val) return c } // SetPathParams sets multiple path parameters and their values in the client. func (c *Client) SetPathParams(m map[string]string) *Client { c.path.SetParams(m) return c } // SetPathParamsWithStruct sets multiple path parameters and their values using a struct. func (c *Client) SetPathParamsWithStruct(v any) *Client { c.path.SetParamsWithStruct(v) return c } // DelPathParams deletes one or more path parameters and their values from the client. 
func (c *Client) DelPathParams(key ...string) *Client {
	c.path.DelParams(key...)
	return c
}

// Cookie returns the client-level cookie key, or "" when it is absent.
func (c *Client) Cookie(key string) string {
	// A missing map key yields the zero value "", matching the documented fallback.
	return (*c.cookies)[key]
}

// SetCookie stores a single client-level cookie.
func (c *Client) SetCookie(key, val string) *Client {
	c.cookies.SetCookie(key, val)
	return c
}

// SetCookies stores every key/value pair of m as client-level cookies.
func (c *Client) SetCookies(m map[string]string) *Client {
	c.cookies.SetCookies(m)
	return c
}

// SetCookiesWithStruct stores client-level cookies from the fields of v.
func (c *Client) SetCookiesWithStruct(v any) *Client {
	c.cookies.SetCookiesWithStruct(v)
	return c
}

// DelCookies removes each named cookie from the client.
func (c *Client) DelCookies(key ...string) *Client {
	c.cookies.DelCookies(key...)
	return c
}

// SetTimeout sets the default timeout for all requests; a request-level timeout overrides it.
func (c *Client) SetTimeout(t time.Duration) *Client {
	c.timeout = t
	return c
}

// Debug turns debug-level logging on.
func (c *Client) Debug() *Client {
	c.debug = true
	return c
}

// DisableDebug turns debug-level logging off.
func (c *Client) DisableDebug() *Client {
	c.debug = false
	return c
}

// StreamResponseBody reports whether the transport streams response bodies.
func (c *Client) StreamResponseBody() bool {
	t := c.transport
	return t.StreamResponseBody()
}

// SetStreamResponseBody toggles response body streaming on the transport. When
// enabled, the body can be consumed incrementally via BodyStream() instead of being
// fully buffered, which helps with large responses or server-sent events.
func (c *Client) SetStreamResponseBody(enable bool) *Client {
	t := c.transport
	t.SetStreamResponseBody(enable)
	return c
}

// SetCookieJar attaches cookieJar to the client (nil detaches it).
func (c *Client) SetCookieJar(cookieJar *CookieJar) *Client {
	c.cookieJar = cookieJar
	return c
}

// Get sends a GET request to url, applying the optional Config (axios style).
func (c *Client) Get(url string, cfg ...Config) (*Response, error) {
	req := c.R()
	setConfigToRequest(req, cfg...)
	return req.Get(url)
}

// Post sends a POST request to url, applying the optional Config (axios style).
func (c *Client) Post(url string, cfg ...Config) (*Response, error) {
	req := c.R()
	setConfigToRequest(req, cfg...)
	return req.Post(url)
}

// Head sends a HEAD request to url, applying the optional Config (axios style).
func (c *Client) Head(url string, cfg ...Config) (*Response, error) {
	req := c.R()
	setConfigToRequest(req, cfg...)
	return req.Head(url)
}

// Put sends a PUT request to url, applying the optional Config (axios style).
func (c *Client) Put(url string, cfg ...Config) (*Response, error) {
	req := c.R()
	setConfigToRequest(req, cfg...)
	return req.Put(url)
}

// Delete sends a DELETE request to url, applying the optional Config (axios style).
func (c *Client) Delete(url string, cfg ...Config) (*Response, error) {
	req := c.R()
	setConfigToRequest(req, cfg...)
	return req.Delete(url)
}

// Options sends an OPTIONS request to url, applying the optional Config (axios style).
func (c *Client) Options(url string, cfg ...Config) (*Response, error) {
	req := c.R()
	setConfigToRequest(req, cfg...)
	return req.Options(url)
}

// Patch sends a PATCH request to url, applying the optional Config (axios style).
func (c *Client) Patch(url string, cfg ...Config) (*Response, error) {
	req := c.R()
	setConfigToRequest(req, cfg...)
	return req.Patch(url)
}

// Custom sends a request with an arbitrary method to url, applying the optional Config.
func (c *Client) Custom(url, method string, cfg ...Config) (*Response, error) {
	req := c.R()
	setConfigToRequest(req, cfg...)
	return req.Custom(url, method)
}

// SetDial replaces the transport's dial function under the client's lock.
func (c *Client) SetDial(dial fasthttp.DialFunc) *Client {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.applyDial(dial)
	return c
}

// SetLogger replaces the client's logger under the client's lock.
func (c *Client) SetLogger(logger log.CommonLogger) *Client {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.logger = logger
	return c
}

// Logger reports the client's logger.
func (c *Client) Logger() log.CommonLogger {
	return c.logger
}

// Reset returns the client to a mostly-default state. It swaps in a fresh
// fasthttp.Client transport, clears the scalar settings, releases and detaches
// any cookie jar, and empties the header/param/cookie/path collections. Hooks,
// codecs, and the logger are left untouched.
func (c *Client) Reset() {
	c.transport = newStandardClientTransport(&fasthttp.Client{})

	c.baseURL, c.userAgent, c.referer = "", "", ""
	c.timeout = 0
	c.retryConfig = nil
	c.debug, c.disablePathNormalizing = false, false

	if jar := c.cookieJar; jar != nil {
		jar.Release()
		c.cookieJar = nil
	}

	c.path.Reset()
	c.cookies.Reset()
	c.header.Reset()
	c.params.Reset()
}

// Config carries per-call request options for the axios-style helpers.
// When a body is supplied, non-[]byte/string values are sent as JSON.
// Body, FormData, and File are mutually exclusive with priority Body > FormData > File.
type Config struct {
	// Ctx, when non-nil, is attached via Request.SetContext.
	Ctx context.Context //nolint:containedctx // It's needed to be stored in the config.

	// Body is sent raw for []byte/string and JSON-encoded otherwise.
	Body any
	// Header, Param, Cookie, and PathParam are applied only when non-nil.
	Header    map[string]string
	Param     map[string]string
	Cookie    map[string]string
	PathParam map[string]string
	// FormData is used only when Body is nil.
	FormData map[string]string
	// UserAgent and Referer are applied only when non-empty.
	UserAgent string
	Referer   string
	// File is attached only when Body and FormData are both nil.
	File []*File
	// Timeout and MaxRedirects are applied only when non-zero.
	Timeout      time.Duration
	MaxRedirects int
	// DisablePathNormalizing turns off path normalizing for the request when true.
	DisablePathNormalizing bool
}

// setConfigToRequest copies the options of the first Config (if any) onto req.
func setConfigToRequest(req *Request, config ...Config) {
	if len(config) == 0 {
		return
	}
	cfg := config[0]

	// Scalar and map options; each is applied only when actually provided.
	if cfg.Ctx != nil {
		req.SetContext(cfg.Ctx)
	}
	if cfg.UserAgent != "" {
		req.SetUserAgent(cfg.UserAgent)
	}
	if cfg.Referer != "" {
		req.SetReferer(cfg.Referer)
	}
	if cfg.Header != nil {
		req.SetHeaders(cfg.Header)
	}
	if cfg.Param != nil {
		req.SetParams(cfg.Param)
	}
	if cfg.Cookie != nil {
		req.SetCookies(cfg.Cookie)
	}
	if cfg.PathParam != nil {
		req.SetPathParams(cfg.PathParam)
	}
	if cfg.Timeout != 0 {
		req.SetTimeout(cfg.Timeout)
	}
	if cfg.MaxRedirects != 0 {
		req.SetMaxRedirects(cfg.MaxRedirects)
	}
	if cfg.DisablePathNormalizing {
		req.SetDisablePathNormalizing(true)
	}

	// Payload: only the highest-priority source is used (Body > FormData > File).
	switch {
	case cfg.Body != nil:
		switch body := cfg.Body.(type) {
		case []byte:
			req.SetRawBody(body)
		case string:
			req.SetRawBody([]byte(body))
		default:
			req.SetJSON(cfg.Body)
		}
	case cfg.FormData != nil:
		req.SetFormDataWithMap(cfg.FormData)
	case len(cfg.File) != 0:
		req.AddFiles(cfg.File...)
	}
}

var (
	// defaultClient backs the package-level helpers (Get, Post, ...); see C and Replace.
	defaultClient *Client
	// replaceMu serialises Replace and its restore function.
	replaceMu sync.Mutex
	// defaultUserAgent is the fallback User-Agent value.
	defaultUserAgent = "fiber"
)

func init() {
	defaultClient = New()
}

// New returns a Client backed by a fresh, zero-configured fasthttp.Client.
func New() *Client {
	// Possible follow-up optimisation: pool the Fiber client and the fasthttp client
	// (and possibly the header, cookie, query, and path-param structs) to reduce
	// allocation cost.
	return NewWithClient(&fasthttp.Client{})
}

// NewWithClient returns a Client that sends requests through the given
// fasthttp.Client. It panics if c is nil.
func NewWithClient(c *fasthttp.Client) *Client {
	if c == nil {
		panic("fasthttp.Client must not be nil")
	}
	transport := newStandardClientTransport(c)
	return newClient(transport)
}

// NewWithHostClient returns a Client that sends requests through the given fasthttp.HostClient.
func NewWithHostClient(c *fasthttp.HostClient) *Client {
	if c == nil {
		panic("fasthttp.HostClient must not be nil")
	}

	return newClient(newHostClientTransport(c))
}

// NewWithLBClient creates and returns a new Client object from an existing
// fasthttp.LBClient. It panics if c is nil.
func NewWithLBClient(c *fasthttp.LBClient) *Client {
	if c == nil {
		panic("fasthttp.LBClient must not be nil")
	}

	return newClient(newLBClientTransport(c))
}

// newClient wires a Client around transport with empty header/param/cookie/path
// collections, the builtin request and response hooks, standard-library JSON/XML
// codecs, fxamacker/cbor codecs, and the default logger.
func newClient(transport httpClientTransport) *Client {
	return &Client{
		transport: transport,
		header: &Header{
			RequestHeader: &fasthttp.RequestHeader{},
		},
		params: &QueryParam{
			Args: fasthttp.AcquireArgs(),
		},
		cookies:              &Cookie{},
		path:                 &PathParam{},
		userRequestHooks:     []RequestHook{},
		builtinRequestHooks:  []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody},
		userResponseHooks:    []ResponseHook{},
		builtinResponseHooks: []ResponseHook{parserResponseCookie, logger},
		jsonMarshal:          json.Marshal,
		jsonUnmarshal:        json.Unmarshal,
		xmlMarshal:           xml.Marshal,
		cborMarshal:          cbor.Marshal,
		cborUnmarshal:        cbor.Unmarshal,
		xmlUnmarshal:         xml.Unmarshal,
		logger:               log.DefaultLogger[*log.Logger](),
	}
}

// C returns the default client.
//
// The read is taken under replaceMu, the same mutex Replace and its restore
// function use when writing defaultClient, so calling C concurrently with
// Replace is race-free. The lock is released before the client is used, so
// callers such as Get never hold it during a request.
func C() *Client {
	replaceMu.Lock()
	defer replaceMu.Unlock()

	return defaultClient
}

// Replace swaps the default client for c and returns a function that restores
// the previous one. Both the swap and the restore take replaceMu.
func Replace(c *Client) func() {
	replaceMu.Lock()
	defer replaceMu.Unlock()

	oldClient := defaultClient
	defaultClient = c

	return func() {
		replaceMu.Lock()
		defer replaceMu.Unlock()

		defaultClient = oldClient
	}
}

// Get sends a GET request using the default client.
func Get(url string, cfg ...Config) (*Response, error) {
	return C().Get(url, cfg...)
}

// Post sends a POST request using the default client.
func Post(url string, cfg ...Config) (*Response, error) {
	return C().Post(url, cfg...)
}

// Head sends a HEAD request using the default client.
func Head(url string, cfg ...Config) (*Response, error) {
	return C().Head(url, cfg...)
}

// Put sends a PUT request using the default client.
func Put(url string, cfg ...Config) (*Response, error) { return C().Put(url, cfg...) } // Delete sends a DELETE request using the default client. func Delete(url string, cfg ...Config) (*Response, error) { return C().Delete(url, cfg...) } // Options sends an OPTIONS request using the default client. func Options(url string, cfg ...Config) (*Response, error) { return C().Options(url, cfg...) } // Patch sends a PATCH request using the default client. func Patch(url string, cfg ...Config) (*Response, error) { return C().Patch(url, cfg...) } ================================================ FILE: client/client_test.go ================================================ package client import ( "context" "crypto/tls" "encoding/hex" "errors" "io" "net" "os" "path/filepath" "reflect" "sync" "sync/atomic" "testing" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/addon/retry" "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" ) func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App)) (app *fiber.App, addr string) { //nolint:nonamedreturns // gocritic unnamedResult requires explicit result names for clarity when returning app and address t.Helper() app = fiber.New() if beforeStarting != nil { beforeStarting(app) } addrChan := make(chan string) errChan := make(chan error, 1) go func(server *fiber.App) { err := server.Listen(":0", fiber.ListenConfig{ DisableStartupMessage: true, ListenerAddrFunc: func(addr net.Addr) { addrChan <- addr.String() }, }) if err != nil { errChan <- err } }(app) select { case addr = <-addrChan: return app, addr case err := <-errChan: t.Fatalf("Failed to start test server: %v", err) } return nil, "" } func Test_New_With_Client(t *testing.T) { t.Parallel() t.Run("with valid client", func(t *testing.T) { t.Parallel() c 
:= &fasthttp.Client{ MaxConnsPerHost: 5, } client := NewWithClient(c) require.NotNil(t, client) }) t.Run("with nil client", func(t *testing.T) { t.Parallel() require.PanicsWithValue(t, "fasthttp.Client must not be nil", func() { NewWithClient(nil) }) }) } func Test_New_With_HostClient(t *testing.T) { t.Parallel() t.Run("with valid host client", func(t *testing.T) { t.Parallel() hc := &fasthttp.HostClient{Addr: "example.com:80"} client := NewWithHostClient(hc) require.NotNil(t, client) require.Equal(t, hc, client.HostClient()) require.Nil(t, client.FasthttpClient()) }) t.Run("with nil host client", func(t *testing.T) { t.Parallel() require.PanicsWithValue(t, "fasthttp.HostClient must not be nil", func() { NewWithHostClient(nil) }) }) } func Test_New_With_LBClient(t *testing.T) { t.Parallel() t.Run("with valid lb client", func(t *testing.T) { t.Parallel() lb := &fasthttp.LBClient{ Clients: []fasthttp.BalancingClient{ &fasthttp.HostClient{Addr: "example.com:80"}, }, } client := NewWithLBClient(lb) require.NotNil(t, client) require.Equal(t, lb, client.LBClient()) require.Nil(t, client.FasthttpClient()) require.Nil(t, client.HostClient()) }) t.Run("with nil lb client", func(t *testing.T) { t.Parallel() require.PanicsWithValue(t, "fasthttp.LBClient must not be nil", func() { NewWithLBClient(nil) }) }) } func TestClientUnderlyingTransports(t *testing.T) { t.Parallel() std := New() require.NotNil(t, std.FasthttpClient()) require.Nil(t, std.HostClient()) require.Nil(t, std.LBClient()) hostTransport := &fasthttp.HostClient{Addr: "example.com:80"} host := NewWithHostClient(hostTransport) require.Nil(t, host.FasthttpClient()) require.Same(t, hostTransport, host.HostClient()) require.Nil(t, host.LBClient()) lbClient := &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{hostTransport}} lb := NewWithLBClient(lbClient) require.Nil(t, lb.FasthttpClient()) require.Nil(t, lb.HostClient()) require.Same(t, lbClient, lb.LBClient()) } func TestClientCBORUnmarshalOverride(t 
*testing.T) { t.Parallel() client := New() initial := client.CBORUnmarshal() require.NotNil(t, initial) called := false custom := func(b []byte, v any) error { _ = b _ = v called = true return nil } client.SetCBORUnmarshal(custom) fn := client.CBORUnmarshal() require.NotNil(t, fn) var payload []byte require.NoError(t, fn(payload, nil)) require.True(t, called) } func TestClientSetRootCertificateErrors(t *testing.T) { t.Parallel() client := New() require.Panics(t, func() { client.SetRootCertificate("does-not-exist.pem") }) tmpDir := t.TempDir() badPath := filepath.Join(tmpDir, "invalid.pem") require.NoError(t, os.WriteFile(badPath, []byte("not a pem"), 0o600)) require.Panics(t, func() { client.SetRootCertificate(badPath) }) } func TestClientSetRootCertificateFromStringError(t *testing.T) { t.Parallel() client := New() require.Panics(t, func() { client.SetRootCertificateFromString("invalid pem data") }) } func TestClientLoggerAccessors(t *testing.T) { t.Parallel() client := New() _ = client.Logger() contextual := log.WithContext(context.Background()) client.SetLogger(contextual) require.Equal(t, contextual, client.Logger()) } func TestClientResetClearsState(t *testing.T) { t.Parallel() client := New() jar := AcquireCookieJar() jar.hostCookies = map[string][]*fasthttp.Cookie{"example.com": {}} client.SetCookieJar(jar) client.SetBaseURL("http://example.com") client.SetTimeout(2 * time.Second) client.SetUserAgent("reset-agent") client.SetReferer("reset-ref") client.SetRetryConfig(&RetryConfig{MaxRetryCount: 3}) client.Debug() client.SetDisablePathNormalizing(true) client.SetHeaders(map[string]string{"X-Test": "value"}) client.SetParams(map[string]string{"p": "1"}) client.SetCookies(map[string]string{"cookie": "value"}) client.SetPathParams(map[string]string{"id": "123"}) client.Reset() require.NotNil(t, client.FasthttpClient()) require.Nil(t, client.HostClient()) require.Nil(t, client.LBClient()) require.Empty(t, client.BaseURL()) require.Zero(t, client.timeout) 
require.Empty(t, client.userAgent) require.Empty(t, client.referer) require.Nil(t, client.retryConfig) require.False(t, client.debug) require.False(t, client.disablePathNormalizing) require.Nil(t, client.cookieJar) require.Nil(t, jar.hostCookies) require.Empty(t, *client.path) require.Empty(t, *client.cookies) require.Equal(t, 0, client.header.Len()) require.Equal(t, 0, client.params.Len()) } func Test_Client_Add_Hook(t *testing.T) { t.Parallel() t.Run("add request hooks", func(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) client := New().AddRequestHook(func(_ *Client, _ *Request) error { buf.WriteString("hook1") return nil }) require.Len(t, client.RequestHook(), 1) client.AddRequestHook(func(_ *Client, _ *Request) error { buf.WriteString("hook2") return nil }, func(_ *Client, _ *Request) error { buf.WriteString("hook3") return nil }) require.Len(t, client.RequestHook(), 3) }) t.Run("add response hooks", func(t *testing.T) { t.Parallel() client := New().AddResponseHook(func(_ *Client, _ *Response, _ *Request) error { return nil }) require.Len(t, client.ResponseHook(), 1) hook1 := func(_ *Client, _ *Response, _ *Request) error { return nil } hook2 := func(_ *Client, _ *Response, _ *Request) error { return nil } client.AddResponseHook(hook1, hook2) require.Len(t, client.ResponseHook(), 3) }) } func Test_Client_HostClient_Behavior(t *testing.T) { t.Parallel() t.Run("do and redirects", func(t *testing.T) { t.Parallel() app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Get("/ok", func(c fiber.Ctx) error { return c.SendString("ok") }) app.Get("/redirect", func(c fiber.Ctx) error { return c.Redirect().Status(fiber.StatusFound).To("/ok") }) }) t.Cleanup(func() { require.NoError(t, app.Shutdown()) }) client := NewWithHostClient(&fasthttp.HostClient{Addr: addr}) resp, err := client.Get("http://" + addr + "/ok") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "ok", 
resp.String()) resp, err = client.Get("http://"+addr+"/redirect", Config{MaxRedirects: 1}) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "ok", resp.String()) }) t.Run("retries respect dial overrides", func(t *testing.T) { t.Parallel() app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { return c.SendString("retry") }) }) t.Cleanup(func() { require.NoError(t, app.Shutdown()) }) client := NewWithHostClient(&fasthttp.HostClient{Addr: addr}) client.SetRetryConfig(&RetryConfig{ InitialInterval: time.Millisecond, MaxRetryCount: 2, }) var attempts int32 client.SetDial(func(address string) (net.Conn, error) { if atomic.AddInt32(&attempts, 1) == 1 { return nil, errors.New("dial failure") } return fasthttp.Dial(address) }) resp, err := client.Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "retry", resp.String()) require.EqualValues(t, 2, atomic.LoadInt32(&attempts)) }) t.Run("tls configuration propagates", func(t *testing.T) { t.Parallel() serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() require.NoError(t, err) ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) ln = tls.NewListener(ln, serverTLSConf) app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("tls-host") }) go func() { assert.NoError(t, app.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, })) }() time.Sleep(1 * time.Second) client := NewWithHostClient(&fasthttp.HostClient{Addr: ln.Addr().String(), IsTLS: true}) resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) require.NoError(t, err) cfg := client.TLSConfig() require.Same(t, clientTLSConf, cfg) require.NotNil(t, cfg.RootCAs) require.Equal(t, clientTLSConf, client.HostClient().TLSConfig) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "tls-host", resp.String()) 
require.NoError(t, app.Shutdown()) }) t.Run("proxy overrides and reset", func(t *testing.T) { t.Parallel() client := NewWithHostClient(&fasthttp.HostClient{Addr: "example.com:80"}) require.NoError(t, client.SetProxyURL("http://127.0.0.1:8080")) require.NotNil(t, client.HostClient().Dial) var called int32 customDial := func(addr string) (net.Conn, error) { _ = addr atomic.AddInt32(&called, 1) return nil, errors.New("dial") } client.SetDial(customDial) _, err := client.HostClient().Dial("example.com:80") require.Error(t, err) require.EqualValues(t, 1, atomic.LoadInt32(&called)) client.Reset() require.NotNil(t, client.FasthttpClient()) require.Nil(t, client.HostClient()) require.Nil(t, client.LBClient()) }) t.Run("timeouts and close idle", func(t *testing.T) { t.Parallel() client := NewWithHostClient(&fasthttp.HostClient{Addr: "example.com:80"}) var dialCalls int32 dialErr := errors.New("dial failed") client.HostClient().Dial = func(addr string) (net.Conn, error) { _ = addr atomic.AddInt32(&dialCalls, 1) return nil, dialErr } req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() req.SetRequestURI("http://example.com/") err := client.DoTimeout(req, resp, 5*time.Millisecond) require.ErrorIs(t, err, dialErr) req.SetRequestURI("http://example.com/") err = client.DoDeadline(req, resp, time.Now().Add(5*time.Millisecond)) require.ErrorIs(t, err, dialErr) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) require.EqualValues(t, 2, atomic.LoadInt32(&dialCalls)) require.NotPanics(t, client.CloseIdleConnections) }) } func Test_Client_LBClient_Behavior(t *testing.T) { t.Parallel() newLBClient := func(addr string) *fasthttp.LBClient { return &fasthttp.LBClient{ Clients: []fasthttp.BalancingClient{ &fasthttp.HostClient{Addr: addr}, }, } } t.Run("do and redirects", func(t *testing.T) { t.Parallel() app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Get("/ok", func(c fiber.Ctx) error { return c.SendString("ok") }) app.Get("/redirect", func(c 
fiber.Ctx) error { return c.Redirect().Status(fiber.StatusFound).To("/ok") }) }) t.Cleanup(func() { require.NoError(t, app.Shutdown()) }) client := NewWithLBClient(newLBClient(addr)) resp, err := client.Get("http://" + addr + "/ok") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "ok", resp.String()) resp, err = client.Get("http://"+addr+"/redirect", Config{MaxRedirects: 1}) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "ok", resp.String()) }) t.Run("retries respect dial overrides", func(t *testing.T) { t.Parallel() app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { return c.SendString("retry") }) }) t.Cleanup(func() { require.NoError(t, app.Shutdown()) }) client := NewWithLBClient(newLBClient(addr)) client.SetRetryConfig(&RetryConfig{ InitialInterval: time.Millisecond, MaxRetryCount: 2, }) var attempts int32 client.SetDial(func(address string) (net.Conn, error) { if atomic.AddInt32(&attempts, 1) == 1 { return nil, errors.New("dial failure") } return fasthttp.Dial(address) }) resp, err := client.Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "retry", resp.String()) require.EqualValues(t, 2, atomic.LoadInt32(&attempts)) }) t.Run("tls configuration propagates", func(t *testing.T) { t.Parallel() serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() require.NoError(t, err) ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) ln = tls.NewListener(ln, serverTLSConf) app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("tls-lb") }) go func() { assert.NoError(t, app.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, })) }() time.Sleep(1 * time.Second) lb := newLBClient(ln.Addr().String()) host, ok := lb.Clients[0].(*fasthttp.HostClient) require.True(t, ok) host.IsTLS = true client := 
NewWithLBClient(lb) resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) require.NoError(t, err) cfg := client.TLSConfig() require.Same(t, clientTLSConf, cfg) require.NotNil(t, cfg.RootCAs) hc, ok := client.LBClient().Clients[0].(*fasthttp.HostClient) require.True(t, ok) require.Equal(t, clientTLSConf, hc.TLSConfig) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "tls-lb", resp.String()) require.NoError(t, app.Shutdown()) }) t.Run("proxy overrides and reset", func(t *testing.T) { t.Parallel() client := NewWithLBClient(&fasthttp.LBClient{ Clients: []fasthttp.BalancingClient{ &fasthttp.HostClient{Addr: "example.com:80"}, }, }) require.NoError(t, client.SetProxyURL("http://127.0.0.1:8080")) hc, ok := client.LBClient().Clients[0].(*fasthttp.HostClient) require.True(t, ok) require.NotNil(t, hc.Dial) var called int32 customDial := func(addr string) (net.Conn, error) { _ = addr atomic.AddInt32(&called, 1) return nil, errors.New("dial") } client.SetDial(customDial) _, err := hc.Dial("example.com:80") require.Error(t, err) require.EqualValues(t, 1, atomic.LoadInt32(&called)) client.Reset() require.NotNil(t, client.FasthttpClient()) require.Nil(t, client.LBClient()) require.Nil(t, client.HostClient()) }) t.Run("timeouts and close idle", func(t *testing.T) { t.Parallel() client := NewWithLBClient(&fasthttp.LBClient{ Clients: []fasthttp.BalancingClient{ &fasthttp.HostClient{Addr: "example.com:80"}, &fasthttp.HostClient{Addr: "example.org:80"}, }, }) var dialCalls int32 dialErr := errors.New("dial failed") for _, bc := range client.LBClient().Clients { if hc, ok := bc.(*fasthttp.HostClient); ok { hc.Dial = func(addr string) (net.Conn, error) { _ = addr atomic.AddInt32(&dialCalls, 1) return nil, dialErr } } } req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() req.SetRequestURI("http://example.com/") err := client.DoTimeout(req, resp, 5*time.Millisecond) require.ErrorIs(t, err, dialErr) 
req.SetRequestURI("http://example.com/") err = client.DoDeadline(req, resp, time.Now().Add(5*time.Millisecond)) require.ErrorIs(t, err, dialErr) fasthttp.ReleaseRequest(req) fasthttp.ReleaseResponse(resp) require.GreaterOrEqual(t, atomic.LoadInt32(&dialCalls), int32(2)) require.NotPanics(t, client.CloseIdleConnections) }) } func Test_Client_Add_Hook_CheckOrder(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) client := New(). AddRequestHook(func(_ *Client, _ *Request) error { buf.WriteString("hook1") return nil }). AddRequestHook(func(_ *Client, _ *Request) error { buf.WriteString("hook2") return nil }). AddRequestHook(func(_ *Client, _ *Request) error { buf.WriteString("hook3") return nil }) for _, hook := range client.RequestHook() { require.NoError(t, hook(client, &Request{})) } require.Equal(t, "hook1hook2hook3", buf.String()) } func Test_Client_Marshal(t *testing.T) { t.Parallel() t.Run("set json marshal", func(t *testing.T) { t.Parallel() client := New(). SetJSONMarshal(func(_ any) ([]byte, error) { return []byte("hello"), nil }) val, err := client.JSONMarshal()(nil) require.NoError(t, err) require.Equal(t, []byte("hello"), val) }) t.Run("set json marshal error", func(t *testing.T) { t.Parallel() emptyErr := errors.New("empty json") client := New(). SetJSONMarshal(func(_ any) ([]byte, error) { return nil, emptyErr }) val, err := client.JSONMarshal()(nil) require.Nil(t, val) require.ErrorIs(t, err, emptyErr) }) t.Run("set json unmarshal", func(t *testing.T) { t.Parallel() client := New(). SetJSONUnmarshal(func(_ []byte, _ any) error { return errors.New("empty json") }) err := client.JSONUnmarshal()(nil, nil) require.Equal(t, errors.New("empty json"), err) }) t.Run("set json unmarshal error", func(t *testing.T) { t.Parallel() client := New(). 
SetJSONUnmarshal(func(_ []byte, _ any) error { return errors.New("empty json") }) err := client.JSONUnmarshal()(nil, nil) require.Equal(t, errors.New("empty json"), err) }) t.Run("set xml marshal", func(t *testing.T) { t.Parallel() client := New(). SetXMLMarshal(func(_ any) ([]byte, error) { return []byte("hello"), nil }) val, err := client.XMLMarshal()(nil) require.NoError(t, err) require.Equal(t, []byte("hello"), val) }) t.Run("set xml marshal error", func(t *testing.T) { t.Parallel() client := New(). SetXMLMarshal(func(_ any) ([]byte, error) { return nil, errors.New("empty xml") }) val, err := client.XMLMarshal()(nil) require.Nil(t, val) require.Equal(t, errors.New("empty xml"), err) }) t.Run("set cbor marshal", func(t *testing.T) { t.Parallel() bs, err := hex.DecodeString("f6") if err != nil { t.Error(err) } client := New(). SetCBORMarshal(func(_ any) ([]byte, error) { return bs, nil }) val, err := client.CBORMarshal()(nil) require.NoError(t, err) require.Equal(t, bs, val) }) t.Run("set cbor marshal error", func(t *testing.T) { t.Parallel() client := New().SetCBORMarshal(func(_ any) ([]byte, error) { return nil, errors.New("invalid struct") }) val, err := client.CBORMarshal()(nil) require.Nil(t, val) require.Equal(t, errors.New("invalid struct"), err) }) t.Run("set xml unmarshal", func(t *testing.T) { t.Parallel() client := New(). SetXMLUnmarshal(func(_ []byte, _ any) error { return errors.New("empty xml") }) err := client.XMLUnmarshal()(nil, nil) require.Equal(t, errors.New("empty xml"), err) }) t.Run("set xml unmarshal error", func(t *testing.T) { t.Parallel() client := New(). 
SetXMLUnmarshal(func(_ []byte, _ any) error { return errors.New("empty xml") }) err := client.XMLUnmarshal()(nil, nil) require.Equal(t, errors.New("empty xml"), err) }) } func Test_Client_SetBaseURL(t *testing.T) { t.Parallel() client := New().SetBaseURL("http://example.com") require.Equal(t, "http://example.com", client.BaseURL()) } func Test_Client_Invalid_URL(t *testing.T) { t.Parallel() app, dial, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) go start() _, err := New().SetDial(dial). R(). Get("http//example") require.ErrorIs(t, err, ErrURLFormat) } func Test_Client_Unsupported_Protocol(t *testing.T) { t.Parallel() _, err := New(). R(). Get("ftp://example.com") require.ErrorIs(t, err, ErrURLFormat) } func Test_Client_ConcurrencyRequests(t *testing.T) { t.Parallel() app, dial, start := createHelperServer(t) app.All("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname() + " " + c.Method()) }) go start() client := New().SetDial(dial) wg := sync.WaitGroup{} for range 5 { for _, method := range []string{"GET", "POST", "PUT", "DELETE", "PATCH"} { m := method wg.Go(func() { resp, err := client.Custom("http://example.com", m) require.NoError(t, err) assert.Equal(t, "example.com "+m, utils.UnsafeString(resp.RawResponse.Body())) }) } } wg.Wait() } func Test_Get(t *testing.T) { t.Parallel() setupApp := func() (*fiber.App, string) { app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) }) return app, addr } t.Run("global get function", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() resp, err := Get("http://" + addr) require.NoError(t, err) require.Equal(t, "0.0.0.0", utils.UnsafeString(resp.RawResponse.Body())) }) t.Run("client get", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() resp, err := 
New().Get("http://" + addr) require.NoError(t, err) require.Equal(t, "0.0.0.0", utils.UnsafeString(resp.RawResponse.Body())) }) } func Test_Head(t *testing.T) { t.Parallel() setupApp := func() (*fiber.App, string) { app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Head("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) }) return app, addr } t.Run("global head function", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() resp, err := Head("http://" + addr) require.NoError(t, err) require.Equal(t, "7", resp.Header(fiber.HeaderContentLength)) require.Empty(t, utils.UnsafeString(resp.RawResponse.Body())) }) t.Run("client head", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() resp, err := New().Head("http://" + addr) require.NoError(t, err) require.Equal(t, "7", resp.Header(fiber.HeaderContentLength)) require.Empty(t, utils.UnsafeString(resp.RawResponse.Body())) }) } func Test_Post(t *testing.T) { t.Parallel() setupApp := func() (*fiber.App, string) { app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Post("/", func(c fiber.Ctx) error { return c.Status(fiber.StatusCreated). 
SendString(c.FormValue("foo")) }) }) return app, addr } t.Run("global post function", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() for range 5 { resp, err := Post("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, }) require.NoError(t, err) require.Equal(t, fiber.StatusCreated, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } }) t.Run("client post", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() for range 5 { resp, err := New().Post("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, }) require.NoError(t, err) require.Equal(t, fiber.StatusCreated, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } }) } func Test_Put(t *testing.T) { t.Parallel() setupApp := func() (*fiber.App, string) { app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Put("/", func(c fiber.Ctx) error { return c.SendString(c.FormValue("foo")) }) }) return app, addr } t.Run("global put function", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() for range 5 { resp, err := Put("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, }) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } }) t.Run("client put", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() for range 5 { resp, err := New().Put("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, }) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } }) } func Test_Delete(t *testing.T) { t.Parallel() setupApp := func() (*fiber.App, string) { app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Delete("/", func(c fiber.Ctx) error { return 
c.Status(fiber.StatusNoContent). SendString("deleted") }) }) return app, addr } t.Run("global delete function", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() time.Sleep(1 * time.Second) for range 5 { resp, err := Delete("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, }) require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) require.Empty(t, resp.String()) } }) t.Run("client delete", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() for range 5 { resp, err := New().Delete("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, }) require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) require.Empty(t, resp.String()) } }) } func Test_Options(t *testing.T) { t.Parallel() setupApp := func() (*fiber.App, string) { app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Options("/", func(c fiber.Ctx) error { c.Set(fiber.HeaderAllow, "GET, POST, PUT, DELETE, PATCH") return c.Status(fiber.StatusNoContent).SendString("") }) }) return app, addr } t.Run("global options function", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() for range 5 { resp, err := Options("http://" + addr) require.NoError(t, err) require.Equal(t, "GET, POST, PUT, DELETE, PATCH", resp.Header(fiber.HeaderAllow)) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) require.Empty(t, resp.String()) } }) t.Run("client options", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() for range 5 { resp, err := New().Options("http://" + addr) require.NoError(t, err) require.Equal(t, "GET, POST, PUT, DELETE, PATCH", resp.Header(fiber.HeaderAllow)) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) require.Empty(t, resp.String()) } }) } func Test_Patch(t 
*testing.T) { t.Parallel() setupApp := func() (*fiber.App, string) { app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Patch("/", func(c fiber.Ctx) error { return c.SendString(c.FormValue("foo")) }) }) return app, addr } t.Run("global patch function", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() time.Sleep(1 * time.Second) for range 5 { resp, err := Patch("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, }) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } }) t.Run("client patch", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() for range 5 { resp, err := New().Patch("http://"+addr, Config{ FormData: map[string]string{ "foo": "bar", }, }) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "bar", resp.String()) } }) } func Test_Client_UserAgent(t *testing.T) { t.Parallel() setupApp := func() (*fiber.App, string) { app, addr := startTestServerWithPort(t, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { return c.Send(c.Request().Header.UserAgent()) }) }) return app, addr } t.Run("default", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() for range 5 { resp, err := Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, defaultUserAgent, resp.String()) } }) t.Run("custom", func(t *testing.T) { t.Parallel() app, addr := setupApp() defer func() { require.NoError(t, app.Shutdown()) }() for range 5 { c := New(). 
SetUserAgent("ua") resp, err := c.Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "ua", resp.String()) } }) } func Test_Client_Header(t *testing.T) { t.Parallel() t.Run("add header", func(t *testing.T) { t.Parallel() req := New() req.AddHeader("foo", "bar").AddHeader("foo", "fiber") res := req.Header("foo") require.Len(t, res, 2) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) }) t.Run("set header", func(t *testing.T) { t.Parallel() req := New() req.AddHeader("foo", "bar").SetHeader("foo", "fiber") res := req.Header("foo") require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) }) t.Run("add headers", func(t *testing.T) { t.Parallel() req := New() req.SetHeader("foo", "bar"). AddHeaders(map[string][]string{ "foo": {"fiber", "buaa"}, "bar": {"foo"}, }) res := req.Header("foo") require.Len(t, res, 3) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) require.Equal(t, "buaa", res[2]) res = req.Header("bar") require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { t.Parallel() req := New() req.SetHeader("foo", "bar"). SetHeaders(map[string]string{ "foo": "fiber", "bar": "foo", }) res := req.Header("foo") require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) res = req.Header("bar") require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set header case-insensitive", func(t *testing.T) { t.Parallel() req := New() req.SetHeader("foo", "bar"). 
AddHeader("FOO", "fiber") res := req.Header("foo") require.Len(t, res, 2) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) }) } func Test_Client_Header_With_Server(t *testing.T) { handler := func(c fiber.Ctx) error { for key, value := range c.Request().Header.All() { if k := string(key); k == "K1" || k == "K2" { _, _ = c.Write(key) //nolint:errcheck // It is fine to ignore the error here _, _ = c.Write(value) //nolint:errcheck // It is fine to ignore the error here } } return nil } wrapAgent := func(c *Client) { c.SetHeader("k1", "v1"). AddHeader("k1", "v11"). AddHeaders(map[string][]string{ "k1": {"v22", "v33"}, }). SetHeaders(map[string]string{ "k2": "v2", }). AddHeader("k2", "v22") } testClient(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") } func Test_Client_Cookie(t *testing.T) { t.Parallel() t.Run("set cookie", func(t *testing.T) { t.Parallel() req := New(). SetCookie("foo", "bar") require.Equal(t, "bar", req.Cookie("foo")) req.SetCookie("foo", "bar1") require.Equal(t, "bar1", req.Cookie("foo")) }) t.Run("set cookies", func(t *testing.T) { t.Parallel() req := New(). SetCookies(map[string]string{ "foo": "bar", "bar": "foo", }) require.Equal(t, "bar", req.Cookie("foo")) require.Equal(t, "foo", req.Cookie("bar")) req.SetCookies(map[string]string{ "foo": "bar1", }) require.Equal(t, "bar1", req.Cookie("foo")) require.Equal(t, "foo", req.Cookie("bar")) }) t.Run("set cookies with struct", func(t *testing.T) { t.Parallel() type args struct { CookieString string `cookie:"string"` CookieInt int `cookie:"int"` } req := New().SetCookiesWithStruct(&args{ CookieInt: 5, CookieString: "foo", }) require.Equal(t, "5", req.Cookie("int")) require.Equal(t, "foo", req.Cookie("string")) }) t.Run("del cookies", func(t *testing.T) { t.Parallel() req := New(). 
SetCookies(map[string]string{ "foo": "bar", "bar": "foo", }) require.Equal(t, "bar", req.Cookie("foo")) require.Equal(t, "foo", req.Cookie("bar")) req.DelCookies("foo") require.Empty(t, req.Cookie("foo")) require.Equal(t, "foo", req.Cookie("bar")) }) } func Test_Client_Cookie_With_Server(t *testing.T) { t.Parallel() handler := func(c fiber.Ctx) error { return c.SendString( c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) } wrapAgent := func(c *Client) { c.SetCookie("k1", "v1"). SetCookies(map[string]string{ "k2": "v2", "k3": "v3", "k4": "v4", }).DelCookies("k4") } testClient(t, handler, wrapAgent, "v1v2v3") } func Test_Client_CookieJar(t *testing.T) { handler := func(c fiber.Ctx) error { return c.SendString( c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) } jar := AcquireCookieJar() defer ReleaseCookieJar(jar) jar.SetKeyValue("example.com", "k1", "v1") jar.SetKeyValue("example.com", "k2", "v2") jar.SetKeyValue("example", "k3", "v3") wrapAgent := func(c *Client) { c.SetCookieJar(jar) } testClient(t, handler, wrapAgent, "v1v2") } func Test_Client_CookieJar_Response(t *testing.T) { t.Parallel() t.Run("without expiration", func(t *testing.T) { t.Parallel() handler := func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "k4", Value: "v4", }) return c.SendString( c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) } jar := AcquireCookieJar() defer ReleaseCookieJar(jar) jar.SetKeyValue("example.com", "k1", "v1") jar.SetKeyValue("example.com", "k2", "v2") jar.SetKeyValue("example", "k3", "v3") wrapAgent := func(c *Client) { c.SetCookieJar(jar) } testClient(t, handler, wrapAgent, "v1v2") require.Len(t, jar.getCookiesByHost("example.com"), 3) }) t.Run("with expiration", func(t *testing.T) { t.Parallel() handler := func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "k4", Value: "v4", Expires: time.Now().Add(1 * time.Nanosecond), }) return c.SendString( c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) } jar := AcquireCookieJar() defer 
ReleaseCookieJar(jar) jar.SetKeyValue("example.com", "k1", "v1") jar.SetKeyValue("example.com", "k2", "v2") jar.SetKeyValue("example", "k3", "v3") wrapAgent := func(c *Client) { c.SetCookieJar(jar) } testClient(t, handler, wrapAgent, "v1v2") require.Len(t, jar.getCookiesByHost("example.com"), 2) }) t.Run("override cookie value", func(t *testing.T) { t.Parallel() handler := func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "k1", Value: "v2", }) return c.SendString( c.Cookies("k1") + c.Cookies("k2")) } jar := AcquireCookieJar() defer ReleaseCookieJar(jar) jar.SetKeyValue("example.com", "k1", "v1") jar.SetKeyValue("example.com", "k2", "v2") wrapAgent := func(c *Client) { c.SetCookieJar(jar) } testClient(t, handler, wrapAgent, "v1v2") for _, cookie := range jar.getCookiesByHost("example.com") { if string(cookie.Key()) == "k1" { require.Equal(t, "v2", string(cookie.Value())) } } }) t.Run("different domain", func(t *testing.T) { t.Parallel() handler := func(c fiber.Ctx) error { return c.SendString(c.Cookies("k1")) } jar := AcquireCookieJar() defer ReleaseCookieJar(jar) jar.SetKeyValue("example.com", "k1", "v1") wrapAgent := func(c *Client) { c.SetCookieJar(jar) } testClient(t, handler, wrapAgent, "v1") require.Len(t, jar.getCookiesByHost("example.com"), 1) require.Empty(t, jar.getCookiesByHost("example")) }) } func Test_Client_Referer(t *testing.T) { handler := func(c fiber.Ctx) error { return c.Send(c.Request().Header.Referer()) } wrapAgent := func(c *Client) { c.SetReferer("http://referer.com") } testClient(t, handler, wrapAgent, "http://referer.com") } func Test_Client_QueryParam(t *testing.T) { t.Parallel() t.Run("add param", func(t *testing.T) { t.Parallel() req := New() req.AddParam("foo", "bar").AddParam("foo", "fiber") res := req.Param("foo") require.Len(t, res, 2) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) }) t.Run("set param", func(t *testing.T) { t.Parallel() req := New() req.AddParam("foo", "bar").SetParam("foo", "fiber") res := 
req.Param("foo") require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) }) t.Run("add params", func(t *testing.T) { t.Parallel() req := New() req.SetParam("foo", "bar"). AddParams(map[string][]string{ "foo": {"fiber", "buaa"}, "bar": {"foo"}, }) res := req.Param("foo") require.Len(t, res, 3) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) require.Equal(t, "buaa", res[2]) res = req.Param("bar") require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { t.Parallel() req := New() req.SetParam("foo", "bar"). SetParams(map[string]string{ "foo": "fiber", "bar": "foo", }) res := req.Param("foo") require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) res = req.Param("bar") require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set params with struct", func(t *testing.T) { t.Parallel() type args struct { TString string TSlice []string TIntSlice []int `param:"int_slice"` TInt int TFloat float64 TBool bool } p := New() p.SetParamsWithStruct(&args{ TInt: 5, TString: "string", TFloat: 3.1, TBool: true, TSlice: []string{"foo", "bar"}, TIntSlice: []int{1, 2}, }) require.Empty(t, p.Param("unexport")) require.Len(t, p.Param("TInt"), 1) require.Equal(t, "5", p.Param("TInt")[0]) require.Len(t, p.Param("TString"), 1) require.Equal(t, "string", p.Param("TString")[0]) require.Len(t, p.Param("TFloat"), 1) require.Equal(t, "3.1", p.Param("TFloat")[0]) require.Len(t, p.Param("TBool"), 1) tslice := p.Param("TSlice") require.Len(t, tslice, 2) require.Equal(t, "foo", tslice[0]) require.Equal(t, "bar", tslice[1]) tint := p.Param("TSlice") require.Len(t, tint, 2) require.Equal(t, "foo", tint[0]) require.Equal(t, "bar", tint[1]) }) t.Run("del params", func(t *testing.T) { t.Parallel() req := New() req.SetParam("foo", "bar"). 
SetParams(map[string]string{ "foo": "fiber", "bar": "foo", }).DelParams("foo", "bar") res := req.Param("foo") require.Empty(t, res) res = req.Param("bar") require.Empty(t, res) }) } func Test_Client_QueryParam_With_Server(t *testing.T) { handler := func(c fiber.Ctx) error { _, _ = c.WriteString(c.Query("k1")) //nolint:errcheck // It is fine to ignore the error here _, _ = c.WriteString(c.Query("k2")) //nolint:errcheck // It is fine to ignore the error here return nil } wrapAgent := func(c *Client) { c.SetParam("k1", "v1"). AddParam("k2", "v2") } testClient(t, handler, wrapAgent, "v1v2") } func Test_Client_PathParam(t *testing.T) { t.Parallel() t.Run("set path param", func(t *testing.T) { t.Parallel() req := New(). SetPathParam("foo", "bar") require.Equal(t, "bar", req.PathParam("foo")) req.SetPathParam("foo", "bar1") require.Equal(t, "bar1", req.PathParam("foo")) }) t.Run("set path params", func(t *testing.T) { t.Parallel() req := New(). SetPathParams(map[string]string{ "foo": "bar", "bar": "foo", }) require.Equal(t, "bar", req.PathParam("foo")) require.Equal(t, "foo", req.PathParam("bar")) req.SetPathParams(map[string]string{ "foo": "bar1", }) require.Equal(t, "bar1", req.PathParam("foo")) require.Equal(t, "foo", req.PathParam("bar")) }) t.Run("set path params with struct", func(t *testing.T) { t.Parallel() type args struct { CookieString string `path:"string"` CookieInt int `path:"int"` } req := New().SetPathParamsWithStruct(&args{ CookieInt: 5, CookieString: "foo", }) require.Equal(t, "5", req.PathParam("int")) require.Equal(t, "foo", req.PathParam("string")) }) t.Run("del path params", func(t *testing.T) { t.Parallel() req := New(). 
SetPathParams(map[string]string{ "foo": "bar", "bar": "foo", }) require.Equal(t, "bar", req.PathParam("foo")) require.Equal(t, "foo", req.PathParam("bar")) req.DelPathParams("foo") require.Empty(t, req.PathParam("foo")) require.Equal(t, "foo", req.PathParam("bar")) }) } func Test_Client_PathParam_With_Server(t *testing.T) { app, dial, start := createHelperServer(t) app.Get("/:test", func(c fiber.Ctx) error { return c.SendString(c.Params("test")) }) go start() resp, err := New().SetDial(dial). SetPathParam("path", "test"). Get("http://example.com/:path") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "test", resp.String()) } func Test_Client_TLS(t *testing.T) { t.Parallel() serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() require.NoError(t, err) ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) ln = tls.NewListener(ln, serverTLSConf) app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("tls") }) go func() { assert.NoError(t, app.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, })) }() time.Sleep(1 * time.Second) client := New() resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) require.NoError(t, err) cfg := client.TLSConfig() require.Same(t, clientTLSConf, cfg) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "tls", resp.String()) } func Test_Client_TLS_Error(t *testing.T) { t.Parallel() serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() clientTLSConf.MaxVersion = tls.VersionTLS12 serverTLSConf.MinVersion = tls.VersionTLS13 require.NoError(t, err) ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) ln = tls.NewListener(ln, serverTLSConf) app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("tls") }) go func() { assert.NoError(t, app.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, })) }() time.Sleep(1 * 
time.Second) client := New() resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) require.Error(t, err) cfg := client.TLSConfig() require.Same(t, clientTLSConf, cfg) require.Nil(t, resp) } func Test_Client_TLS_Empty_TLSConfig(t *testing.T) { t.Parallel() serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() require.NoError(t, err) ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) ln = tls.NewListener(ln, serverTLSConf) app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("tls") }) go func() { assert.NoError(t, app.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, })) }() time.Sleep(1 * time.Second) client := New() resp, err := client.Get("https://" + ln.Addr().String()) require.Error(t, err) require.NotEqual(t, clientTLSConf, client.TLSConfig()) require.Nil(t, resp) } func Test_Client_SetCertificates(t *testing.T) { t.Parallel() serverTLSConf, _, err := tlstest.GetTLSConfigs() require.NoError(t, err) client := New().SetCertificates(serverTLSConf.Certificates...) 
require.Len(t, client.TLSConfig().Certificates, 1) } func Test_Client_SetRootCertificate(t *testing.T) { t.Parallel() client := New().SetRootCertificate("../.github/testdata/ssl.pem") require.NotNil(t, client.TLSConfig().RootCAs) } func Test_Client_SetRootCertificateFromString(t *testing.T) { t.Parallel() file, err := os.Open("../.github/testdata/ssl.pem") defer func() { require.NoError(t, file.Close()) }() require.NoError(t, err) pem, err := io.ReadAll(file) require.NoError(t, err) client := New().SetRootCertificateFromString(string(pem)) require.NotNil(t, client.TLSConfig().RootCAs) } func Test_Client_R(t *testing.T) { t.Parallel() client := New() req := client.R() require.Equal(t, "Request", reflect.TypeOf(req).Elem().Name()) require.Equal(t, client, req.Client()) } func Test_Replace(t *testing.T) { app, dial, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { return c.SendString(string(c.Request().Header.Peek("k1"))) }) go start() C().SetDial(dial) resp, err := Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Empty(t, resp.String()) r := New().SetDial(dial).SetHeader("k1", "v1") clean := Replace(r) resp, err = Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "v1", resp.String()) clean() C().SetDial(dial) resp, err = Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Empty(t, resp.String()) C().SetDial(nil) } func Test_Set_Config_To_Request(t *testing.T) { t.Parallel() t.Run("set ctx", func(t *testing.T) { t.Parallel() type ctxKey struct{} var key ctxKey = struct{}{} ctx := context.Background() ctx = context.WithValue(ctx, key, "v1") req := AcquireRequest() setConfigToRequest(req, Config{Ctx: ctx}) require.Equal(t, "v1", req.Context().Value(key)) }) t.Run("set useragent", func(t *testing.T) { t.Parallel() req := AcquireRequest() setConfigToRequest(req, 
Config{UserAgent: "agent"}) require.Equal(t, "agent", req.UserAgent()) }) t.Run("set referer", func(t *testing.T) { t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Referer: "referer"}) require.Equal(t, "referer", req.Referer()) }) t.Run("set header", func(t *testing.T) { req := AcquireRequest() setConfigToRequest(req, Config{Header: map[string]string{ "k1": "v1", }}) require.Equal(t, "v1", req.Header("k1")[0]) }) t.Run("set params", func(t *testing.T) { t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Param: map[string]string{ "k1": "v1", }}) require.Equal(t, "v1", req.Param("k1")[0]) }) t.Run("set cookies", func(t *testing.T) { t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Cookie: map[string]string{ "k1": "v1", }}) require.Equal(t, "v1", req.Cookie("k1")) }) t.Run("set pathparam", func(t *testing.T) { t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{PathParam: map[string]string{ "k1": "v1", }}) require.Equal(t, "v1", req.PathParam("k1")) }) t.Run("set timeout", func(t *testing.T) { t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Timeout: 1 * time.Second}) require.Equal(t, 1*time.Second, req.Timeout()) }) t.Run("set maxredirects", func(t *testing.T) { t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{MaxRedirects: 1}) require.Equal(t, 1, req.MaxRedirects()) }) t.Run("set body string", func(t *testing.T) { t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Body: "test"}) body, ok := req.body.([]byte) require.True(t, ok) require.Equal(t, "test", string(body)) }) t.Run("set body byte", func(t *testing.T) { t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Body: []byte("test")}) require.Equal(t, []byte("test"), req.body) }) t.Run("set body json", func(t *testing.T) { t.Parallel() req := AcquireRequest() type payload struct { Foo string `json:"foo"` } setConfigToRequest(req, Config{Body: payload{Foo: "bar"}}) 
payloadBody, ok := req.body.(payload) require.True(t, ok) require.Equal(t, payload{Foo: "bar"}, payloadBody) }) t.Run("set body map", func(t *testing.T) { t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{Body: map[string]string{ "foo": "bar", }}) require.Equal(t, map[string]string{ "foo": "bar", }, req.body) }) t.Run("set file", func(t *testing.T) { t.Parallel() req := AcquireRequest() setConfigToRequest(req, Config{File: []*File{ { name: "test", path: "path", }, }}) require.Equal(t, "path", req.File("test").path) }) } func Test_Client_SetProxyURL(t *testing.T) { t.Parallel() app, dial, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Get("isProxy")) }) go start() fasthttpClient := &fasthttp.Client{ Dial: dial, NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, } // Create a simple proxy server proxyServer := fiber.New() proxyServer.Use("*", func(c fiber.Ctx) error { req := fasthttp.AcquireRequest() resp := fasthttp.AcquireResponse() req.SetRequestURI(c.BaseURL()) req.Header.SetMethod(fasthttp.MethodGet) for key, value := range c.Request().Header.All() { req.Header.AddBytesKV(key, value) } req.Header.Set("isProxy", "true") if err := fasthttpClient.Do(req, resp); err != nil { return err } c.Status(resp.StatusCode()) c.RequestCtx().SetBody(resp.Body()) return nil }) addrChan := make(chan string) go func() { assert.NoError(t, proxyServer.Listen(":0", fiber.ListenConfig{ DisableStartupMessage: true, ListenerAddrFunc: func(addr net.Addr) { addrChan <- addr.String() }, })) }() t.Cleanup(func() { require.NoError(t, app.Shutdown()) }) time.Sleep(1 * time.Second) t.Run("success", func(t *testing.T) { t.Parallel() client := New() err := client.SetProxyURL(<-addrChan) require.NoError(t, err) resp, err := client.Get("http://localhost:3000") require.NoError(t, err) require.Equal(t, 200, resp.StatusCode()) require.Equal(t, "true", string(resp.Body())) }) t.Run("error", func(t *testing.T) { t.Parallel() 
client := New() err := client.SetProxyURL(":this is not a proxy") require.NoError(t, err) _, err = client.Get("http://localhost:3000") require.Error(t, err) }) } func Test_Client_SetRetryConfig(t *testing.T) { t.Parallel() retryConfig := &retry.Config{ InitialInterval: 1 * time.Second, MaxRetryCount: 3, } core, client, req := newCore(), New(), AcquireRequest() req.SetURL("http://exampleretry.com") client.SetRetryConfig(retryConfig) _, err := core.execute(context.Background(), client, req) require.Error(t, err) require.Equal(t, retryConfig.InitialInterval, client.RetryConfig().InitialInterval) require.Equal(t, retryConfig.MaxRetryCount, client.RetryConfig().MaxRetryCount) } func Benchmark_Client_Request(b *testing.B) { app, dial, start := createHelperServer(b) app.Get("/", func(c fiber.Ctx) error { return c.SendString("hello world") }) go start() client := New().SetDial(dial) b.ReportAllocs() var err error var resp *Response for b.Loop() { resp, err = client.Get("http://example.com") resp.Close() } require.NoError(b, err) } func Benchmark_Client_Request_Parallel(b *testing.B) { app, dial, start := createHelperServer(b) app.Get("/", func(c fiber.Ctx) error { return c.SendString("hello world") }) go start() client := New().SetDial(dial) b.ResetTimer() b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { var err error var resp *Response for pb.Next() { resp, err = client.Get("http://example.com") resp.Close() } require.NoError(b, err) }) } func Benchmark_Client_Request_Send_ContextCancel(b *testing.B) { app, ln, start := createHelperServer(b) startedCh := make(chan struct{}) errCh := make(chan error) respCh := make(chan *Response) app.Post("/", func(c fiber.Ctx) error { startedCh <- struct{}{} time.Sleep(time.Millisecond) // let cancel be called return c.Status(fiber.StatusOK).SendString("post") }) go start() client := New().SetDial(ln) b.ReportAllocs() b.ResetTimer() for b.Loop() { ctx, cancel := context.WithCancel(context.Background()) req := AcquireRequest(). 
SetClient(client). SetURL("http://example.com"). SetMethod(fiber.MethodPost). SetContext(ctx) go func(r *Request) { defer ReleaseRequest(r) resp, err := r.Send() respCh <- resp errCh <- err }(req) <-startedCh // request is made, we can cancel the context now cancel() require.Nil(b, <-respCh) require.ErrorIs(b, <-errCh, ErrTimeoutOrCancel) } } func Test_Client_StreamResponseBody(t *testing.T) { t.Parallel() t.Run("default value", func(t *testing.T) { t.Parallel() client := New() require.False(t, client.StreamResponseBody()) }) t.Run("enable streaming", func(t *testing.T) { t.Parallel() client := New() result := client.SetStreamResponseBody(true) require.True(t, client.StreamResponseBody()) require.Equal(t, client, result) }) t.Run("disable streaming", func(t *testing.T) { t.Parallel() client := New() client.SetStreamResponseBody(true) require.True(t, client.StreamResponseBody()) client.SetStreamResponseBody(false) require.False(t, client.StreamResponseBody()) }) t.Run("with host client", func(t *testing.T) { t.Parallel() hostClient := &fasthttp.HostClient{} client := NewWithHostClient(hostClient) client.SetStreamResponseBody(true) require.True(t, client.StreamResponseBody()) require.True(t, hostClient.StreamResponseBody) }) t.Run("with lb client", func(t *testing.T) { t.Parallel() hostClient := &fasthttp.HostClient{Addr: "example.com:80"} lbClient := &fasthttp.LBClient{ Clients: []fasthttp.BalancingClient{ hostClient, }, } client := NewWithLBClient(lbClient) client.SetStreamResponseBody(true) require.True(t, client.StreamResponseBody()) require.True(t, hostClient.StreamResponseBody) }) } ================================================ FILE: client/cookiejar.go ================================================ // The code was originally taken from https://github.com/valyala/fasthttp/pull/526. 
package client import ( "bytes" "net" "strings" "sync" "time" "github.com/gofiber/utils/v2" utilsbytes "github.com/gofiber/utils/v2/bytes" utilsstrings "github.com/gofiber/utils/v2/strings" "github.com/valyala/fasthttp" ) var cookieJarPool = sync.Pool{ New: func() any { return &CookieJar{} }, } // AcquireCookieJar returns an empty CookieJar object from the pool. func AcquireCookieJar() *CookieJar { jar, ok := cookieJarPool.Get().(*CookieJar) if !ok { panic(errCookieJarTypeAssertion) } return jar } // ReleaseCookieJar returns a CookieJar object to the pool. func ReleaseCookieJar(c *CookieJar) { c.Release() cookieJarPool.Put(c) } // CookieJar manages cookie storage for the client. It stores cookies keyed by host. type CookieJar struct { hostCookies map[string][]*fasthttp.Cookie mu sync.Mutex } // Get returns all cookies stored for a given URI. If there are no cookies for the // provided host, the returned slice will be nil. // // The CookieJar keeps its own copies of cookies, so it is safe to release the returned // cookies after use. func (cj *CookieJar) Get(uri *fasthttp.URI) []*fasthttp.Cookie { if uri == nil { return nil } secure := bytes.Equal(uri.Scheme(), httpsScheme) return cj.getByHostAndPath(uri.Host(), uri.Path(), secure) } // getByHostAndPath returns cookies stored for a specific host and path. func (cj *CookieJar) getByHostAndPath(host, path []byte, secure bool) []*fasthttp.Cookie { if cj.hostCookies == nil { return nil } var ( err error hostStr = utils.UnsafeString(host) ) // port must not be included. hostStr, _, err = net.SplitHostPort(hostStr) if err != nil { hostStr = utils.UnsafeString(host) } return cj.cookiesForRequest(hostStr, path, secure) } // getCookiesByHost returns cookies stored for a specific host, removing any that have expired. 
func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie { cj.mu.Lock() defer cj.mu.Unlock() now := time.Now() cookies := cj.hostCookies[host] kept := cookies[:0] for _, c := range cookies { // Remove expired cookies. if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) { fasthttp.ReleaseCookie(c) continue } kept = append(kept, c) } cj.hostCookies[host] = kept return kept } // cookiesForRequest returns cookies that match the given host, path and security settings. // //nolint:revive // secure is required to filter Secure cookies based on scheme func (cj *CookieJar) cookiesForRequest(host string, path []byte, secure bool) []*fasthttp.Cookie { cj.mu.Lock() defer cj.mu.Unlock() now := time.Now() var matched []*fasthttp.Cookie for domain, cookies := range cj.hostCookies { if !domainMatch(host, domain) { continue } kept := cookies[:0] for _, c := range cookies { if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) { fasthttp.ReleaseCookie(c) continue } kept = append(kept, c) if !pathMatch(path, c.Path()) { continue } if c.Secure() && !secure { continue } nc := fasthttp.AcquireCookie() nc.CopyTo(c) matched = append(matched, nc) } cj.hostCookies[domain] = kept } return matched } // Set stores the given cookies for the specified URI host. If a cookie key already exists, // it will be replaced by the new cookie value. // // CookieJar stores copies of the provided cookies, so they may be safely released after use. func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) { if uri == nil { return } cj.SetByHost(uri.Host(), cookies...) } // SetByHost stores the given cookies for the specified host. If a cookie key already exists, // it will be replaced by the new cookie value. // // CookieJar stores copies of the provided cookies, so they may be safely released after use. 
func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) { hostStr := utils.UnsafeString(host) if h, _, err := net.SplitHostPort(hostStr); err == nil { hostStr = h } hostStr = utilsstrings.ToLower(hostStr) hostKey := utils.CopyString(hostStr) cj.mu.Lock() defer cj.mu.Unlock() if cj.hostCookies == nil { cj.hostCookies = make(map[string][]*fasthttp.Cookie) } for _, cookie := range cookies { domain := utils.TrimLeft(cookie.Domain(), '.') utilsbytes.UnsafeToLower(domain) key := hostKey if len(domain) == 0 { cookie.SetDomain(hostStr) } else { key = utils.CopyString(utils.UnsafeString(domain)) cookie.SetDomainBytes(domain) } hostCookies := cj.hostCookies[key] existing := searchCookieByKeyAndPath(cookie.Key(), cookie.Path(), hostCookies) if existing == nil { existing = fasthttp.AcquireCookie() hostCookies = append(hostCookies, existing) } existing.CopyTo(cookie) cj.hostCookies[key] = hostCookies } } // SetKeyValue sets a cookie for the specified host with the given key and value. // // This function helps prevent extra allocations by avoiding duplication of repeated cookies. func (cj *CookieJar) SetKeyValue(host, key, value string) { c := fasthttp.AcquireCookie() c.SetKey(key) c.SetValue(value) cj.SetByHost(utils.UnsafeBytes(host), c) } // SetKeyValueBytes sets a cookie for the specified host using byte slices for the key and value. // // This function helps prevent extra allocations by avoiding duplication of repeated cookies. func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) { c := fasthttp.AcquireCookie() c.SetKeyBytes(key) c.SetValueBytes(value) cj.SetByHost(utils.UnsafeBytes(host), c) } // dumpCookiesToReq writes the stored cookies to the given request. 
func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) { uri := req.URI() secure := bytes.Equal(uri.Scheme(), httpsScheme) cookies := cj.getByHostAndPath(uri.Host(), uri.Path(), secure) for _, cookie := range cookies { req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value()) fasthttp.ReleaseCookie(cookie) } } // parseCookiesFromResp parses the cookies from the response and stores them for the specified host and path. func (cj *CookieJar) parseCookiesFromResp(host, _ []byte, resp *fasthttp.Response) { hostStr := utils.UnsafeString(host) if h, _, err := net.SplitHostPort(hostStr); err == nil { hostStr = h } hostStr = utilsstrings.ToLower(hostStr) hostKey := utils.CopyString(hostStr) cj.mu.Lock() defer cj.mu.Unlock() if cj.hostCookies == nil { cj.hostCookies = make(map[string][]*fasthttp.Cookie) } now := time.Now() for _, value := range resp.Header.Cookies() { tmp := fasthttp.AcquireCookie() _ = tmp.ParseBytes(value) //nolint:errcheck // ignore error domainBytes := utils.TrimLeft(tmp.Domain(), '.') utilsbytes.UnsafeToLower(domainBytes) key := hostKey if len(domainBytes) == 0 { tmp.SetDomain(hostStr) } else { key = utils.CopyString(utils.UnsafeString(domainBytes)) tmp.SetDomainBytes(domainBytes) } cookies := cj.hostCookies[key] c := searchCookieByKeyAndPath(tmp.Key(), tmp.Path(), cookies) if c == nil { c = fasthttp.AcquireCookie() cookies = append(cookies, c) } c.CopyTo(tmp) if c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now) { cj.hostCookies[key] = cookies } else { kept := cookies[:0] for _, v := range cookies { if v != c { kept = append(kept, v) } } cj.hostCookies[key] = kept fasthttp.ReleaseCookie(c) } fasthttp.ReleaseCookie(tmp) } } // Release releases all stored cookies. After this, the CookieJar is empty. func (cj *CookieJar) Release() { // FOLLOW-UP performance optimization: // Currently, a race condition is found because the reset method modifies a value // that is not a copy but a reference. A solution would be to make a copy. 
// for _, v := range cj.hostCookies { // for _, c := range v { // fasthttp.ReleaseCookie(c) // } // } cj.hostCookies = nil } // searchCookieByKeyAndPath looks up a cookie by its key and path from the provided slice of cookies. func searchCookieByKeyAndPath(key, path []byte, cookies []*fasthttp.Cookie) *fasthttp.Cookie { for _, c := range cookies { if bytes.Equal(key, c.Key()) { if pathMatch(path, c.Path()) { return c } } } return nil } // pathMatch determines whether the request path matches the cookie path // according to RFC 6265 section 5.1.4. func pathMatch(reqPath, cookiePath []byte) bool { if len(reqPath) == 0 { reqPath = []byte("/") } if len(cookiePath) == 0 { cookiePath = []byte("/") } if bytes.Equal(reqPath, cookiePath) { return true } if !bytes.HasPrefix(reqPath, cookiePath) { return false } if cookiePath[len(cookiePath)-1] == '/' { return true } return len(reqPath) > len(cookiePath) && reqPath[len(cookiePath)] == '/' } // domainMatch reports whether host domain-matches the given cookie domain. 
func domainMatch(host, domain string) bool { host = utilsstrings.UnsafeToLower(host) if host == domain { return true } return strings.HasSuffix(host, "."+domain) } ================================================ FILE: client/cookiejar_test.go ================================================ package client import ( "bytes" "testing" "time" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func checkKeyValue(t *testing.T, cj *CookieJar, cookie *fasthttp.Cookie, uri *fasthttp.URI, n int) { t.Helper() cs := cj.Get(uri) require.GreaterOrEqual(t, len(cs), n) c := cs[n-1] require.NotNil(t, c) require.Equal(t, string(c.Key()), string(cookie.Key())) require.Equal(t, string(c.Value()), string(cookie.Value())) } func Test_CookieJarGet(t *testing.T) { t.Parallel() url := []byte("http://fasthttp.com/") url1 := []byte("http://fasthttp.com/make/") url11 := []byte("http://fasthttp.com/hola") url2 := []byte("http://fasthttp.com/make/fasthttp") url3 := []byte("http://fasthttp.com/make/fasthttp/great") cj := &CookieJar{} c1 := &fasthttp.Cookie{} c1.SetKey("k") c1.SetValue("v") c1.SetPath("/make/") c2 := &fasthttp.Cookie{} c2.SetKey("kk") c2.SetValue("vv") c2.SetPath("/make/fasthttp") c3 := &fasthttp.Cookie{} c3.SetKey("kkk") c3.SetValue("vvv") c3.SetPath("/make/fasthttp/great") uri := fasthttp.AcquireURI() require.NoError(t, uri.Parse(nil, url)) uri1 := fasthttp.AcquireURI() require.NoError(t, uri1.Parse(nil, url1)) uri11 := fasthttp.AcquireURI() require.NoError(t, uri11.Parse(nil, url11)) uri2 := fasthttp.AcquireURI() require.NoError(t, uri2.Parse(nil, url2)) uri3 := fasthttp.AcquireURI() require.NoError(t, uri3.Parse(nil, url3)) cj.Set(uri1, c1, c2, c3) cookies := cj.Get(uri1) require.Len(t, cookies, 1) for _, cookie := range cookies { require.True(t, bytes.HasPrefix(uri1.Path(), cookie.Path())) } cookies = cj.Get(uri11) require.Empty(t, cookies) cookies = cj.Get(uri2) require.Len(t, cookies, 2) for _, cookie := range cookies { require.True(t, 
bytes.HasPrefix(uri2.Path(), cookie.Path())) } cookies = cj.Get(uri3) require.Len(t, cookies, 3) for _, cookie := range cookies { require.True(t, bytes.HasPrefix(uri3.Path(), cookie.Path())) } cookies = cj.Get(uri) require.Empty(t, cookies) } func Test_CookieJarGetExpired(t *testing.T) { t.Parallel() url1 := []byte("http://fasthttp.com/make/") uri1 := fasthttp.AcquireURI() require.NoError(t, uri1.Parse(nil, url1)) c1 := &fasthttp.Cookie{} c1.SetKey("k") c1.SetValue("v") c1.SetExpire(time.Now().Add(-time.Hour)) cj := &CookieJar{} cj.Set(uri1, c1) cookies := cj.Get(uri1) require.Empty(t, cookies) } func Test_CookieJarSet(t *testing.T) { t.Parallel() url := []byte("http://fasthttp.com/hello/world") cj := &CookieJar{} cookie := &fasthttp.Cookie{} cookie.SetKey("k") cookie.SetValue("v") uri := fasthttp.AcquireURI() require.NoError(t, uri.Parse(nil, url)) cj.Set(uri, cookie) checkKeyValue(t, cj, cookie, uri, 1) } func Test_CookieJarSetRepeatedCookieKeys(t *testing.T) { t.Parallel() host := "fast.http" cj := &CookieJar{} uri := fasthttp.AcquireURI() uri.SetHost(host) cookie := &fasthttp.Cookie{} cookie.SetKey("k") cookie.SetValue("v") cookie2 := &fasthttp.Cookie{} cookie2.SetKey("k") cookie2.SetValue("v2") cookie3 := &fasthttp.Cookie{} cookie3.SetKey("key") cookie3.SetValue("value") cj.Set(uri, cookie, cookie2, cookie3) cookies := cj.Get(uri) require.Len(t, cookies, 2) require.Equal(t, cookies[0].String(), cookie2.String()) require.True(t, bytes.Equal(cookies[0].Value(), cookie2.Value())) } func Test_CookieJarSetKeyValue(t *testing.T) { t.Parallel() host := "fast.http" cj := &CookieJar{} uri := fasthttp.AcquireURI() uri.SetHost(host) cj.SetKeyValue(host, "k", "v") cj.SetKeyValue(host, "key", "value") cj.SetKeyValue(host, "k", "vv") cj.SetKeyValue(host, "key", "value2") cookies := cj.Get(uri) require.Len(t, cookies, 2) } func Test_CookieJarGetFromResponse(t *testing.T) { t.Parallel() res := fasthttp.AcquireResponse() host := []byte("fast.http") uri := fasthttp.AcquireURI() 
uri.SetHostBytes(host) c := &fasthttp.Cookie{} c.SetKey("key") c.SetValue("val") c2 := &fasthttp.Cookie{} c2.SetKey("k") c2.SetValue("v") c3 := &fasthttp.Cookie{} c3.SetKey("kk") c3.SetValue("vv") res.Header.SetStatusCode(200) res.Header.SetCookie(c) res.Header.SetCookie(c2) res.Header.SetCookie(c3) cj := &CookieJar{} cj.parseCookiesFromResp(host, nil, res) cookies := cj.Get(uri) require.Len(t, cookies, 3) values := map[string]string{"key": "val", "k": "v", "kk": "vv"} for _, c := range cookies { k := string(c.Key()) v, ok := values[k] require.True(t, ok) require.Equal(t, v, string(c.Value())) delete(values, k) } require.Empty(t, values) } func Test_CookieJar_HostPort(t *testing.T) { t.Parallel() jar := &CookieJar{} uriSet := fasthttp.AcquireURI() require.NoError(t, uriSet.Parse(nil, []byte("http://fasthttp.com:80/path"))) c := &fasthttp.Cookie{} c.SetKey("k") c.SetValue("v") jar.Set(uriSet, c) // retrieve using a different port to ensure port is ignored uriGet := fasthttp.AcquireURI() require.NoError(t, uriGet.Parse(nil, []byte("http://fasthttp.com:8080/path"))) cookies := jar.Get(uriGet) require.Len(t, cookies, 1) require.Equal(t, "k", string(cookies[0].Key())) require.Equal(t, "v", string(cookies[0].Value())) require.Equal(t, "fasthttp.com", string(cookies[0].Domain())) } func Test_CookieJar_Domain(t *testing.T) { t.Parallel() jar := &CookieJar{} uri := fasthttp.AcquireURI() require.NoError(t, uri.Parse(nil, []byte("http://sub.example.com/"))) c := &fasthttp.Cookie{} c.SetKey("k") c.SetValue("v") c.SetDomain("example.com") jar.Set(uri, c) uri2 := fasthttp.AcquireURI() require.NoError(t, uri2.Parse(nil, []byte("http://other.example.com/"))) cookies := jar.Get(uri2) require.Len(t, cookies, 1) require.Equal(t, "k", string(cookies[0].Key())) require.Equal(t, "v", string(cookies[0].Value())) } func Test_CookieJar_Secure(t *testing.T) { t.Parallel() jar := &CookieJar{} uriHTTP := fasthttp.AcquireURI() require.NoError(t, uriHTTP.Parse(nil, 
[]byte("http://example.com/"))) c := &fasthttp.Cookie{} c.SetKey("k") c.SetValue("v") c.SetSecure(true) jar.Set(uriHTTP, c) cookies := jar.Get(uriHTTP) require.Empty(t, cookies) uriHTTPS := fasthttp.AcquireURI() require.NoError(t, uriHTTPS.Parse(nil, []byte("https://example.com/"))) cookies = jar.Get(uriHTTPS) require.Len(t, cookies, 1) require.Equal(t, "k", string(cookies[0].Key())) require.Equal(t, "v", string(cookies[0].Value())) } func Test_CookieJar_PathMatch(t *testing.T) { t.Parallel() jar := &CookieJar{} setURI := fasthttp.AcquireURI() require.NoError(t, setURI.Parse(nil, []byte("http://example.com/api"))) c := &fasthttp.Cookie{} c.SetKey("k") c.SetValue("v") c.SetPath("/api") jar.Set(setURI, c) uriExact := fasthttp.AcquireURI() require.NoError(t, uriExact.Parse(nil, []byte("http://example.com/api"))) require.Len(t, jar.Get(uriExact), 1) uriChild := fasthttp.AcquireURI() require.NoError(t, uriChild.Parse(nil, []byte("http://example.com/api/v1"))) require.Len(t, jar.Get(uriChild), 1) uriNoMatch := fasthttp.AcquireURI() require.NoError(t, uriNoMatch.Parse(nil, []byte("http://example.com/apiv1"))) require.Empty(t, jar.Get(uriNoMatch)) } ================================================ FILE: client/core.go ================================================ // Core pipeline scaffolds request execution for Fiber's HTTP client, including // hook invocation, retry orchestration, and timeout management around fasthttp // transports. package client import ( "context" "errors" "net" "strconv" "strings" "sync" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/addon/retry" ) const boundary = "FiberFormBoundary" // RequestHook is a function invoked before the request is sent. // It receives a Client and a Request, allowing you to modify the Request or Client data. type RequestHook func(*Client, *Request) error // ResponseHook is a function invoked after a response is received. 
// It receives a Client, Response, and Request, allowing you to modify the Response data // or perform actions based on the response. type ResponseHook func(*Client, *Response, *Request) error // RetryConfig is an alias for the `retry.Config` type from the `addon/retry` package. type RetryConfig = retry.Config // addMissingPort appends the appropriate port number to the given address if it doesn't have one. // If isTLS is true, it uses port 443; otherwise, it uses port 80. func addMissingPort(addr string, isTLS bool) string { //revive:disable-line:flag-parameter if strings.IndexByte(addr, ':') != -1 { return addr } port := 80 if isTLS { port = 443 } return net.JoinHostPort(addr, strconv.Itoa(port)) } // core stores middleware and plugin definitions and defines the request execution process. type core struct { client *Client req *Request ctx context.Context //nolint:containedctx // Context is needed here. } // getRetryConfig returns a copy of the client's retry configuration. func (c *core) getRetryConfig() *RetryConfig { c.client.mu.RLock() defer c.client.mu.RUnlock() cfg := c.client.RetryConfig() if cfg == nil { return nil } return &RetryConfig{ InitialInterval: cfg.InitialInterval, MaxBackoffTime: cfg.MaxBackoffTime, Multiplier: cfg.Multiplier, MaxRetryCount: cfg.MaxRetryCount, } } // execFunc is the core logic to send the request and receive the response. // It leverages the fasthttp client, optionally with retries or redirects. 
//
// The request is copied into a fresh fasthttp.Request and sent from a
// separate goroutine, so that cancellation of c.ctx can be observed while the
// transport call is still blocked. Redirects are followed only for GET/HEAD
// requests with maxRedirects > 0. When the client has a retry configuration,
// the send is wrapped in an exponential backoff.
//
// The outcome is delivered over pooled, unbuffered channels. If c.ctx is done
// first, ErrTimeoutOrCancel is returned immediately. A helper goroutine then
// receives whatever the worker eventually sends, releasing a *Response if
// there is one, so the worker never blocks forever on its send.
func (c *core) execFunc() (*Response, error) {
	// do not close, these will be returned to the pool
	errChan := acquireErrChan()
	respChan := acquireResponseChan()

	// Snapshot the retry settings before starting the worker goroutine.
	cfg := c.getRetryConfig()

	go func() {
		// retain both channels until they are drained
		// (these defers run after this goroutine's single send has been received).
		defer releaseErrChan(errChan)
		defer releaseResponseChan(respChan)

		reqv := fasthttp.AcquireRequest()
		defer fasthttp.ReleaseRequest(reqv)

		respv := fasthttp.AcquireResponse()
		// respv is swapped below on success, so release whatever it points to at exit.
		defer func() {
			if respv != nil {
				fasthttp.ReleaseResponse(respv)
			}
		}()

		c.req.RawRequest.CopyTo(reqv)
		// Re-attach the body stream explicitly after CopyTo, together with the
		// original Content-Length.
		if bodyStream := c.req.RawRequest.BodyStream(); bodyStream != nil {
			reqv.SetBodyStream(bodyStream, c.req.RawRequest.Header.ContentLength())
		}

		var err error
		if cfg != nil {
			// Use an exponential backoff retry strategy.
			err = retry.NewExponentialBackoff(*cfg).Retry(func() error {
				if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) {
					return c.client.DoRedirects(reqv, respv, c.req.maxRedirects)
				}
				return c.client.Do(reqv, respv)
			})
		} else {
			if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) {
				err = c.client.DoRedirects(reqv, respv, c.req.maxRedirects)
			} else {
				err = c.client.Do(reqv, respv)
			}
		}

		if err != nil {
			errChan <- err
			return
		}

		resp := AcquireResponse()
		resp.setClient(c.client)
		resp.setRequest(c.req)

		// Swap the fasthttp response with the Fiber response's RawResponse field.
		// This is required, as (*fasthttp.Response).CopyTo() explicitly does not
		// copy body streams.
		//
		// See: https://github.com/valyala/fasthttp/blob/v1.69.0/http.go#L909-L923
		//
		// The defer statement above ensures that the original RawResponse
		// (now stored in respv) will be properly released.
		resp.RawResponse, respv = respv, resp.RawResponse
		respChan <- resp
	}()

	select {
	case err := <-errChan:
		return nil, err
	case resp := <-respChan:
		return resp, nil
	case <-c.ctx.Done():
		go func() { // drain the channels and release the response
			// Exactly one of the two sends will happen; receiving it lets the
			// worker goroutine finish and return the channels to their pools.
			select {
			case resp := <-respChan:
				ReleaseResponse(resp)
			case <-errChan:
			}
		}()
		return nil, ErrTimeoutOrCancel
	}
}

// preHooks runs all request hooks before sending the request.
// User hooks run first, then built-in hooks. The first error stops the chain
// and is returned. The client's mutex is held (exclusively) for the whole run.
func (c *core) preHooks() error {
	c.client.mu.Lock()
	defer c.client.mu.Unlock()

	for _, f := range c.client.userRequestHooks {
		if err := f(c.client, c.req); err != nil {
			return err
		}
	}

	for _, f := range c.client.builtinRequestHooks {
		if err := f(c.client, c.req); err != nil {
			return err
		}
	}

	return nil
}

// afterHooks runs all response hooks after receiving the response.
// Built-in hooks run first, then user hooks (the reverse of preHooks). The
// first error stops the chain and is returned. The client's mutex is held
// (exclusively) for the whole run.
func (c *core) afterHooks(resp *Response) error {
	c.client.mu.Lock()
	defer c.client.mu.Unlock()

	for _, f := range c.client.builtinResponseHooks {
		if err := f(c.client, resp, c.req); err != nil {
			return err
		}
	}

	for _, f := range c.client.userResponseHooks {
		if err := f(c.client, resp, c.req); err != nil {
			return err
		}
	}

	return nil
}

// timeout applies the configured timeout to the request, if any.
// A positive request timeout takes precedence over the client timeout. c.ctx
// is replaced by the derived context. The returned cancel func is nil when no
// timeout applies, so callers must nil-check it before deferring.
func (c *core) timeout() context.CancelFunc {
	var cancel context.CancelFunc

	if c.req.timeout > 0 {
		c.ctx, cancel = context.WithTimeout(c.ctx, c.req.timeout)
	} else if c.client.timeout > 0 {
		c.ctx, cancel = context.WithTimeout(c.ctx, c.client.timeout)
	}

	return cancel
}

// execute runs all hooks, applies timeouts, sends the request, and runs response hooks.
// If an after-response hook fails, the response is closed and only the error
// is returned.
func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) {
	// Store references locally.
	c.ctx = ctx
	c.client = client
	c.req = req

	// Execute pre request hooks (user-defined and built-in).
	if err := c.preHooks(); err != nil {
		return nil, err
	}

	// Apply timeout if specified.
	cancel := c.timeout()
	if cancel != nil {
		defer cancel()
	}

	// Perform the actual HTTP request.
	resp, err := c.execFunc()
	if err != nil {
		return nil, err
	}

	// Execute after response hooks (built-in and then user-defined).
	if err := c.afterHooks(resp); err != nil {
		resp.Close()
		return nil, err
	}

	return resp, nil
}

// responseChanPool recycles the unbuffered *Response channels used by execFunc.
var responseChanPool = &sync.Pool{
	New: func() any {
		return make(chan *Response)
	},
}

// acquireResponseChan returns an empty, non-closed *Response channel from the pool.
// The returned channel may be returned to the pool with releaseResponseChan.
// It panics if the pool yields a value of an unexpected type.
func acquireResponseChan() chan *Response {
	ch, ok := responseChanPool.Get().(chan *Response)
	if !ok {
		panic(errResponseChanTypeAssertion)
	}
	return ch
}

// releaseResponseChan returns the *Response channel to the pool.
// It's the caller's responsibility to ensure that:
// - the channel is not closed
// - the channel is drained before returning it
// - the channel is not reused after returning it
func releaseResponseChan(ch chan *Response) {
	responseChanPool.Put(ch)
}

// errChanPool recycles the unbuffered error channels used by execFunc.
var errChanPool = &sync.Pool{
	New: func() any {
		return make(chan error)
	},
}

// acquireErrChan returns an empty, non-closed error channel from the pool.
// The returned channel may be returned to the pool with releaseErrChan.
// It panics if the pool yields a value of an unexpected type.
func acquireErrChan() chan error {
	ch, ok := errChanPool.Get().(chan error)
	if !ok {
		panic(errChanErrorTypeAssertion)
	}
	return ch
}

// releaseErrChan returns the error channel to the pool.
// It's caller's responsibility to ensure that:
// - the channel is not closed
// - the channel is drained before returning it
// - the channel is not reused after returning it
func releaseErrChan(ch chan error) {
	errChanPool.Put(ch)
}

// newCore returns a new core object.
func newCore() *core { return &core{} } var ( ErrTimeoutOrCancel = errors.New("timeout or cancel") ErrURLFormat = errors.New("the URL is incorrect") ErrNotSupportSchema = errors.New("protocol not supported; only http or https are allowed") ErrFileNoName = errors.New("the file should have a name") ErrBodyType = errors.New("the body type should be []byte") ErrNotSupportSaveMethod = errors.New("only file paths and io.Writer are supported") ErrBodyTypeNotSupported = errors.New("the body type is not supported") ) ================================================ FILE: client/core_test.go ================================================ package client import ( "bytes" "context" "crypto/tls" "errors" "net" "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" "github.com/gofiber/fiber/v3" ) func Test_AddMissing_Port(t *testing.T) { t.Parallel() type args struct { addr string isTLS bool } tests := []struct { name string want string args args }{ { name: "do anything", args: args{ addr: "example.com:1234", }, want: "example.com:1234", }, { name: "add 80 port", args: args{ addr: "example.com", }, want: "example.com:80", }, { name: "add 443 port", args: args{ addr: "example.com", isTLS: true, }, want: "example.com:443", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() require.Equal(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS)) }) } } func Test_Exec_Func(t *testing.T) { t.Parallel() ln := fasthttputil.NewInmemoryListener() app := fiber.New() app.Get("/normal", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) app.Get("/return-error", func(_ fiber.Ctx) error { return errors.New("the request is error") }) app.Get("/redirect", func(c fiber.Ctx) error { return c.Redirect().Status(fiber.StatusFound).To("/normal") }) app.Get("/hang-up", func(c fiber.Ctx) error { time.Sleep(time.Second) return 
c.SendString(c.Hostname() + " hang up") }) go func() { assert.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() time.Sleep(300 * time.Millisecond) t.Run("normal request", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() core.ctx = context.Background() core.client = client core.req = req client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.RawRequest.SetRequestURI("http://example.com/normal") resp, err := core.execFunc() require.NoError(t, err) require.Equal(t, 200, resp.RawResponse.StatusCode()) require.Equal(t, "example.com", string(resp.RawResponse.Body())) }) t.Run("follow redirect with retry config", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() core.ctx = context.Background() core.client = client core.req = req client.SetRetryConfig(&RetryConfig{MaxRetryCount: 1}) client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetMaxRedirects(1) req.RawRequest.Header.SetMethod(fiber.MethodGet) req.RawRequest.SetRequestURI("http://example.com/redirect") resp, err := core.execFunc() require.NoError(t, err) require.Equal(t, 200, resp.RawResponse.StatusCode()) require.Equal(t, "example.com", string(resp.RawResponse.Body())) }) t.Run("the request return an error", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() core.ctx = context.Background() core.client = client core.req = req client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.RawRequest.SetRequestURI("http://example.com/return-error") resp, err := core.execFunc() require.NoError(t, err) require.Equal(t, 500, resp.RawResponse.StatusCode()) require.Equal(t, "the request is error", string(resp.RawResponse.Body())) }) t.Run("the request timeout", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() ctx, cancel := context.WithTimeout(context.Background(), 
100*time.Millisecond) defer cancel() core.ctx = ctx core.client = client core.req = req client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.RawRequest.SetRequestURI("http://example.com/hang-up") _, err := core.execFunc() require.Equal(t, ErrTimeoutOrCancel, err) }) t.Run("cancel drains errChan", func(t *testing.T) { core, client, req := newCore(), New(), AcquireRequest() ctx, cancel := context.WithCancel(context.Background()) defer cancel() core.ctx = ctx core.client = client core.req = req req.RawRequest.SetRequestURI("http://example.com/drain-err") blockingTransport := newBlockingErrTransport(errors.New("upstream failure")) client.transport = blockingTransport defer blockingTransport.release() type execResult struct { resp *Response err error } resultCh := make(chan execResult, 1) go func() { resp, err := core.execFunc() resultCh <- execResult{resp: resp, err: err} }() select { case <-blockingTransport.called: case <-time.After(time.Second): t.Fatal("transport Do was not invoked") } cancel() var result execResult select { case result = <-resultCh: case <-time.After(time.Second): t.Fatal("execFunc did not return") } require.Nil(t, result.resp) require.ErrorIs(t, result.err, ErrTimeoutOrCancel) blockingTransport.release() select { case <-blockingTransport.finished: case <-time.After(time.Second): t.Fatal("transport Do did not finish") } }) } func Test_Execute(t *testing.T) { t.Parallel() ln := fasthttputil.NewInmemoryListener() app := fiber.New() app.Get("/normal", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) app.Get("/return-error", func(_ fiber.Ctx) error { return errors.New("the request is error") }) app.Get("/hang-up", func(c fiber.Ctx) error { time.Sleep(time.Second) return c.SendString(c.Hostname() + " hang up") }) go func() { assert.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() t.Run("add user request hooks", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), 
AcquireRequest() client.AddRequestHook(func(_ *Client, _ *Request) error { require.Equal(t, "http://example.com", req.URL()) return nil }) client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com") resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "Not Found", string(resp.RawResponse.Body())) }) t.Run("add user response hooks", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() client.AddResponseHook(func(_ *Client, _ *Response, req *Request) error { require.Equal(t, "http://example.com", req.URL()) return nil }) client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com") resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "Not Found", string(resp.RawResponse.Body())) }) t.Run("no timeout", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com/hang-up") resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) }) t.Run("client timeout", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() client.SetTimeout(500 * time.Millisecond) client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com/hang-up") _, err := core.execute(context.Background(), client, req) require.Equal(t, ErrTimeoutOrCancel, err) }) t.Run("request timeout", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com/hang-up"). 
SetTimeout(300 * time.Millisecond) _, err := core.execute(context.Background(), client, req) require.Equal(t, ErrTimeoutOrCancel, err) }) t.Run("request timeout has higher level", func(t *testing.T) { t.Parallel() core, client, req := newCore(), New(), AcquireRequest() client.SetTimeout(30 * time.Millisecond) client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.SetURL("http://example.com/hang-up"). SetTimeout(3000 * time.Millisecond) resp, err := core.execute(context.Background(), client, req) require.NoError(t, err) require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) }) } type blockingErrTransport struct { err error called chan struct{} unblock chan struct{} finished chan struct{} calledOnce sync.Once releaseOnce sync.Once finishedOnce sync.Once } func newBlockingErrTransport(err error) *blockingErrTransport { return &blockingErrTransport{ err: err, called: make(chan struct{}), unblock: make(chan struct{}), finished: make(chan struct{}), } } func (b *blockingErrTransport) Do(_ *fasthttp.Request, _ *fasthttp.Response) error { b.calledOnce.Do(func() { close(b.called) }) <-b.unblock b.finishedOnce.Do(func() { close(b.finished) }) return b.err } func (b *blockingErrTransport) DoTimeout(req *fasthttp.Request, resp *fasthttp.Response, _ time.Duration) error { return b.Do(req, resp) } func (b *blockingErrTransport) DoDeadline(req *fasthttp.Request, resp *fasthttp.Response, _ time.Time) error { return b.Do(req, resp) } func (b *blockingErrTransport) DoRedirects(req *fasthttp.Request, resp *fasthttp.Response, _ int) error { return b.Do(req, resp) } func (*blockingErrTransport) CloseIdleConnections() { } func (*blockingErrTransport) TLSConfig() *tls.Config { return nil } func (*blockingErrTransport) SetTLSConfig(_ *tls.Config) { } func (*blockingErrTransport) SetDial(_ fasthttp.DialFunc) { } func (*blockingErrTransport) Client() any { return nil } func (*blockingErrTransport) StreamResponseBody() bool { return false } func 
(*blockingErrTransport) SetStreamResponseBody(_ bool) { } func (b *blockingErrTransport) release() { b.releaseOnce.Do(func() { close(b.unblock) }) } func Test_Core_RequestBodyStream(t *testing.T) { t.Parallel() t.Run("request with body stream is properly copied", func(t *testing.T) { t.Parallel() app := fiber.New() app.Post("/echo", func(c fiber.Ctx) error { body := c.Body() return c.Send(body) }) ln := fasthttputil.NewInmemoryListener() go func() { err := app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}) if err != nil { panic(err) } }() t.Cleanup(func() { require.NoError(t, app.Shutdown()) }) client := New().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) // Create a request with a body stream using SetRawBody which properly sets the body streamContent := "this is streamed body content" req := AcquireRequest().SetClient(client) req.SetURL("http://example.com/echo") req.SetMethod(fiber.MethodPost) req.SetRawBody([]byte(streamContent)) resp, err := req.Send() require.NoError(t, err) defer resp.Close() require.Equal(t, streamContent, string(resp.Body())) }) t.Run("request body stream with content length", func(t *testing.T) { t.Parallel() resultCh := make(chan struct { body string length int }, 1) app := fiber.New() app.Post("/check-length", func(c fiber.Ctx) error { resultCh <- struct { body string length int }{ body: string(c.Body()), length: c.Request().Header.ContentLength(), } return c.SendString("ok") }) ln := fasthttputil.NewInmemoryListener() go func() { err := app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}) if err != nil { panic(err) } }() t.Cleanup(func() { require.NoError(t, app.Shutdown()) }) client := New().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) streamContent := "body with known length" req := AcquireRequest().SetClient(client) req.SetURL("http://example.com/check-length") req.SetMethod(fiber.MethodPost) req.SetRawBody([]byte(streamContent)) resp, err := req.Send() 
require.NoError(t, err) defer resp.Close() result := <-resultCh require.Equal(t, streamContent, result.body) require.Equal(t, len(streamContent), result.length) }) t.Run("raw body stream survives CopyTo", func(t *testing.T) { t.Parallel() const streamContent = "streaming raw request body" resultCh := make(chan struct { body string length int }, 1) app := fiber.New() app.Post("/copy-to-body-stream", func(c fiber.Ctx) error { body := string(c.Body()) resultCh <- struct { body string length int }{ body: body, length: c.Request().Header.ContentLength(), } return c.SendString(body) }) ln := fasthttputil.NewInmemoryListener() go func() { err := app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}) if err != nil { panic(err) } }() t.Cleanup(func() { require.NoError(t, app.Shutdown()) }) core, client, req := newCore(), New(), AcquireRequest() core.ctx = context.Background() core.client = client core.req = req client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req.RawRequest.SetRequestURI("http://example.com/copy-to-body-stream") req.RawRequest.Header.SetMethod(fiber.MethodPost) req.RawRequest.SetBodyStream(bytes.NewBufferString(streamContent), len(streamContent)) resp, err := core.execFunc() require.NoError(t, err) defer resp.Close() result := <-resultCh require.Equal(t, streamContent, string(resp.Body())) require.Equal(t, streamContent, result.body) require.Equal(t, len(streamContent), result.length) }) } ================================================ FILE: client/errors.go ================================================ package client import ( "errors" ) var ( errResponseChanTypeAssertion = errors.New("failed to type-assert to *Response") errChanErrorTypeAssertion = errors.New("failed to type-assert to chan error") errRequestTypeAssertion = errors.New("failed to type-assert to *Request") errFileTypeAssertion = errors.New("failed to type-assert to *File") errCookieJarTypeAssertion = errors.New("failed to type-assert to *CookieJar") 
errSyncPoolBuffer = errors.New("failed to retrieve buffer from a sync.Pool") ) ================================================ FILE: client/helper_test.go ================================================ package client import ( "net" "testing" "time" "github.com/fxamacker/cbor/v2" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp/fasthttputil" ) type testServer struct { app *fiber.App ch chan struct{} ln *fasthttputil.InmemoryListener tb testing.TB } func startTestServer(tb testing.TB, beforeStarting func(app *fiber.App)) *testServer { tb.Helper() ln := fasthttputil.NewInmemoryListener() app := fiber.New(fiber.Config{ CBOREncoder: cbor.Marshal, CBORDecoder: cbor.Unmarshal, }) if beforeStarting != nil { beforeStarting(app) } ch := make(chan struct{}) go func() { err := app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}) assert.NoError(tb, err) close(ch) }() return &testServer{ app: app, ch: ch, ln: ln, tb: tb, } } func (ts *testServer) stop() { ts.tb.Helper() if err := ts.app.Shutdown(); err != nil { ts.tb.Fatal(err) } select { case <-ts.ch: case <-time.After(time.Second): ts.tb.Fatalf("timeout when waiting for server close") } } func (ts *testServer) dial() func(addr string) (net.Conn, error) { ts.tb.Helper() return func(_ string) (net.Conn, error) { return ts.ln.Dial() } } func createHelperServer(tb testing.TB) (app *fiber.App, dial func(addr string) (net.Conn, error), start func()) { //nolint:nonamedreturns // gocritic unnamedResult requires explicit result identifiers for helper components tb.Helper() ln := fasthttputil.NewInmemoryListener() app = fiber.New() dial = func(_ string) (net.Conn, error) { return ln.Dial() } start = func() { require.NoError(tb, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) } return app, dial, start } func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { 
t.Helper() app, ln, start := createHelperServer(t) app.Get("/", handler) go start() c := 1 if len(count) > 0 { c = count[0] } client := New().SetDial(ln) for i := 0; i < c; i++ { req := AcquireRequest().SetClient(client) wrapAgent(req) resp, err := req.Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, excepted, resp.String()) resp.Close() } } func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { t.Helper() app, ln, start := createHelperServer(t) app.Get("/", handler) go start() c := 1 if len(count) > 0 { c = count[0] } client := New().SetDial(ln) for i := 0; i < c; i++ { req := AcquireRequest().SetClient(client) wrapAgent(req) _, err := req.Get("http://example.com") require.Equal(t, excepted.Error(), err.Error()) } } func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Client), excepted string, count ...int) { //nolint:unparam // maybe needed t.Helper() app, ln, start := createHelperServer(t) app.Get("/", handler) go start() c := 1 if len(count) > 0 { c = count[0] } for i := 0; i < c; i++ { client := New().SetDial(ln) wrapAgent(client) resp, err := client.Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, excepted, resp.String()) resp.Close() } } ================================================ FILE: client/hooks.go ================================================ package client import ( "crypto/rand" "fmt" "io" "mime/multipart" "os" "path/filepath" "regexp" "strconv" "strings" "sync" "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) var protocolCheck = regexp.MustCompile(`^https?://.*$`) var fileBufPool = sync.Pool{ New: func() any { b := make([]byte, 1<<20) // 1MB buffer return &b }, } const ( headerAccept = "Accept" applicationJSON = "application/json" applicationCBOR = "application/cbor" applicationXML = "application/xml" 
applicationForm = "application/x-www-form-urlencoded" multipartFormData = "multipart/form-data" letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" ) // unsafeRandString returns a random string of length n. // An error is returned if the random source fails. func unsafeRandString(n int) (string, error) { inputLength := byte(len(letterBytes)) // Compute the largest multiple of inputLength ≤ 256 to avoid modulo bias. // Any byte ≥ max will be rejected and re‑read. maxLength := byte(256 - (256 % int(inputLength))) out := make([]byte, n) buf := make([]byte, n) // Read n raw bytes in one shot if _, err := rand.Read(buf); err != nil { return "", fmt.Errorf("rand.Read failed: %w", err) } for i, b := range buf { // Reject values ≥ maxLength for b >= maxLength { if _, err := rand.Read(buf[i : i+1]); err != nil { return "", fmt.Errorf("rand.Read failed: %w", err) } b = buf[i] } out[i] = letterBytes[b%inputLength] } return utils.UnsafeString(out), nil } // parserRequestURL sets options for the hostclient and normalizes the URL. // It merges the baseURL with the request URI if needed and applies query and path parameters. func parserRequestURL(c *Client, req *Request) error { // Split URL into path and query parts using Cut (avoids allocation) uri, queryPart, _ := strings.Cut(req.url, "?") // If the URL doesn't start with http/https, prepend the baseURL. if !protocolCheck.MatchString(uri) { uri = c.baseURL + uri if !protocolCheck.MatchString(uri) { return ErrURLFormat } } // Set path parameters from the request and client. for key, val := range req.path.All() { uri = strings.ReplaceAll(uri, ":"+key, val) } for key, val := range c.path.All() { uri = strings.ReplaceAll(uri, ":"+key, val) } // Set the URI in the raw request. 
disablePathNormalizing := c.DisablePathNormalizing() || req.DisablePathNormalizing() req.RawRequest.SetRequestURI(uri) req.RawRequest.URI().DisablePathNormalizing = disablePathNormalizing if disablePathNormalizing { req.RawRequest.URI().SetPathBytes(req.RawRequest.URI().PathOriginal()) } // Merge query parameters (split query from fragment using Cut). queryOnly, hashPart, _ := strings.Cut(queryPart, "#") args := fasthttp.AcquireArgs() defer fasthttp.ReleaseArgs(args) args.Parse(queryOnly) for key, value := range c.params.All() { args.AddBytesKV(key, value) } for key, value := range req.params.All() { args.AddBytesKV(key, value) } req.RawRequest.URI().SetQueryStringBytes(utils.CopyBytes(args.QueryString())) req.RawRequest.URI().SetHash(hashPart) return nil } // parserRequestHeader merges client and request headers, and sets headers automatically based on the request data. // It also sets the User-Agent and Referer headers, and applies any cookies from the cookie jar. func parserRequestHeader(c *Client, req *Request) error { // Set HTTP method. req.RawRequest.Header.SetMethod(req.Method()) // Merge headers from the client. for key, value := range c.header.All() { req.RawRequest.Header.AddBytesKV(key, value) } // Merge headers from the request. for key, value := range req.header.All() { req.RawRequest.Header.AddBytesKV(key, value) } // Set Content-Type and Accept headers based on the request body type. switch req.bodyType { case jsonBody: req.RawRequest.Header.SetContentType(applicationJSON) req.RawRequest.Header.Set(headerAccept, applicationJSON) case xmlBody: req.RawRequest.Header.SetContentType(applicationXML) case cborBody: req.RawRequest.Header.SetContentType(applicationCBOR) case formBody: req.RawRequest.Header.SetContentType(applicationForm) case filesBody: req.RawRequest.Header.SetContentType(multipartFormData) // If boundary is default, append a random string to it. 
if req.boundary == boundary { randStr, err := unsafeRandString(16) if err != nil { return fmt.Errorf("boundary generation: %w", err) } req.boundary += randStr } req.RawRequest.Header.SetMultipartFormBoundary(req.boundary) default: // noBody or rawBody do not require special handling here. } // Set User-Agent header. req.RawRequest.Header.SetUserAgent(defaultUserAgent) if c.userAgent != "" { req.RawRequest.Header.SetUserAgent(c.userAgent) } if req.userAgent != "" { req.RawRequest.Header.SetUserAgent(req.userAgent) } // Set Referer header. req.RawRequest.Header.SetReferer(c.referer) if req.referer != "" { req.RawRequest.Header.SetReferer(req.referer) } // Set cookies from the cookie jar if available. if c.cookieJar != nil { c.cookieJar.dumpCookiesToReq(req.RawRequest) } // Set cookies from the client. for key, val := range c.cookies.All() { req.RawRequest.Header.SetCookie(key, val) } // Set cookies from the request. for key, val := range req.cookies.All() { req.RawRequest.Header.SetCookie(key, val) } return nil } // parserRequestBody serializes the request body based on its type and sets it into the RawRequest. func parserRequestBody(c *Client, req *Request) error { switch req.bodyType { case jsonBody: body, err := c.jsonMarshal(req.body) if err != nil { return err } req.RawRequest.SetBody(body) case xmlBody: body, err := c.xmlMarshal(req.body) if err != nil { return err } req.RawRequest.SetBody(body) case cborBody: body, err := c.cborMarshal(req.body) if err != nil { return err } req.RawRequest.SetBody(body) case formBody: req.RawRequest.SetBody(req.formData.QueryString()) case filesBody: return parserRequestBodyFile(req) case rawBody: if body, ok := req.body.([]byte); ok { //nolint:revive // ignore simplicity req.RawRequest.SetBody(body) } else { return ErrBodyType } case noBody: // No body to set. return nil default: return ErrBodyTypeNotSupported } return nil } // parserRequestBodyFile handles the case where the request contains files to be uploaded. 
func parserRequestBodyFile(req *Request) error { mw := multipart.NewWriter(req.RawRequest.BodyWriter()) err := mw.SetBoundary(req.boundary) if err != nil { return fmt.Errorf("set boundary error: %w", err) } defer func() { e := mw.Close() if e != nil { // Close errors are typically ignored. return } }() // Add form data. for key, value := range req.formData.All() { err = mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)) if err != nil { break } } if err != nil { return fmt.Errorf("write formdata error: %w", err) } // Add files. fileBuf, ok := fileBufPool.Get().(*[]byte) if !ok { return errSyncPoolBuffer } defer fileBufPool.Put(fileBuf) for i, f := range req.files { if f.name == "" && f.path == "" { return ErrFileNoName } // Set the file name if not provided. if f.name == "" && f.path != "" { f.path = filepath.Clean(f.path) f.name = filepath.Base(f.path) } // Set the field name if not provided. if f.fieldName == "" { f.fieldName = "file" + strconv.Itoa(i+1) } if err := addFormFile(mw, f, fileBuf); err != nil { return err } } return nil } func addFormFile(mw *multipart.Writer, f *File, fileBuf *[]byte) error { // If reader is not set, open the file. if f.reader == nil { var err error f.reader, err = os.Open(f.path) if err != nil { return fmt.Errorf("open file error: %w", err) } } // Ensure the file reader is always closed after copying. defer f.reader.Close() //nolint:errcheck // not needed // Create form file and copy the content. w, err := mw.CreateFormFile(f.fieldName, f.name) if err != nil { return fmt.Errorf("create file error: %w", err) } if _, err := io.CopyBuffer(w, f.reader, *fileBuf); err != nil { return fmt.Errorf("failed to copy file data: %w", err) } return nil } // parserResponseCookie parses the Set-Cookie headers from the response and stores them. 
func parserResponseCookie(c *Client, resp *Response, req *Request) error {
	var err error
	// Parse each Set-Cookie header into a pooled fasthttp.Cookie and append it
	// to resp.cookie. On the first parse failure the failing cookie is released
	// and the loop stops; cookies appended before it stay in resp.cookie.
	for key, value := range resp.RawResponse.Header.Cookies() {
		cookie := fasthttp.AcquireCookie()
		if err = cookie.ParseBytes(value); err != nil {
			fasthttp.ReleaseCookie(cookie)
			break
		}
		cookie.SetKeyBytes(key)
		resp.cookie = append(resp.cookie, cookie)
	}

	if err != nil {
		return err
	}

	// Store cookies in the cookie jar if available, scoped by the host and
	// path of the request that produced this response.
	if c.cookieJar != nil {
		c.cookieJar.parseCookiesFromResp(req.RawRequest.URI().Host(), req.RawRequest.URI().Path(), resp.RawResponse)
	}

	return nil
}

// logger is a response hook that logs request and response data if debug mode is enabled.
func logger(c *Client, resp *Response, req *Request) error {
	if !c.debug {
		return nil
	}

	// Emit the string form of the raw request and then the raw response,
	// each through the client's Debugf.
	c.logger.Debugf("%s\n", req.RawRequest.String())
	c.logger.Debugf("%s\n", resp.RawResponse.String())

	return nil
}

================================================
FILE: client/hooks_test.go
================================================
package client

import (
	"bytes"
	"encoding/xml"
	"fmt"
	"io"
	"net"
	"net/url"
	"path/filepath"
	"strings"
	"testing"

	"github.com/fxamacker/cbor/v2"
	"github.com/gofiber/fiber/v3"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/valyala/fasthttp"
)

// Test_Rand_String checks that unsafeRandString returns strings of the
// requested length whose characters all come from letterBytes.
func Test_Rand_String(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		args int
	}{
		{
			name: "test generate",
			args: 16,
		},
		{
			name: "test generate smaller string",
			args: 8,
		},
		{
			name: "test generate larger string",
			args: 32,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got, err := unsafeRandString(tt.args)
			require.NoError(t, err)
			require.Len(t, got, tt.args)
		})
	}

	t.Run("valid characters", func(t *testing.T) {
		t.Parallel()
		got, err := unsafeRandString(32)
		require.NoError(t, err)
		for i := 0; i < len(got); i++ {
			require.Contains(t, letterBytes, string(got[i]))
		}
	})
}

// Test_Parser_Request_URL covers how parserRequestURL combines the client
// base URL, the request URL, path parameters, query parameters and the
// path-normalizing switches into RawRequest's URI.
func Test_Parser_Request_URL(t *testing.T) {
	t.Parallel()

	t.Run("client baseurl should be set", func(t *testing.T) {
		t.Parallel()
		client := New().SetBaseURL("http://example.com/api")
		req := AcquireRequest().SetURL("")

		err := parserRequestURL(client, req)
		require.NoError(t, err)
		require.Equal(t, "http://example.com/api", req.RawRequest.URI().String())
	})

	t.Run("request url should be set", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest().SetURL("http://example.com/api")

		err := parserRequestURL(client, req)
		require.NoError(t, err)
		require.Equal(t, "http://example.com/api", req.RawRequest.URI().String())
	})

	t.Run("the request url will override baseurl with protocol", func(t *testing.T) {
		t.Parallel()
		client := New().SetBaseURL("http://example.com/api")
		req := AcquireRequest().SetURL("http://example.com/api/v1")

		err := parserRequestURL(client, req)
		require.NoError(t, err)
		require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String())
	})

	t.Run("the request url should be append after baseurl without protocol", func(t *testing.T) {
		t.Parallel()
		client := New().SetBaseURL("http://example.com/api")
		req := AcquireRequest().SetURL("/v1")

		err := parserRequestURL(client, req)
		require.NoError(t, err)
		require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String())
	})

	t.Run("the url is error", func(t *testing.T) {
		t.Parallel()
		client := New().SetBaseURL("example.com/api")
		req := AcquireRequest().SetURL("/v1")

		err := parserRequestURL(client, req)
		require.Equal(t, ErrURLFormat, err)
	})

	t.Run("the path param from client", func(t *testing.T) {
		t.Parallel()
		client := New().
			SetBaseURL("http://example.com/api/:id").
			SetPathParam("id", "5")
		req := AcquireRequest()

		err := parserRequestURL(client, req)
		require.NoError(t, err)
		require.Equal(t, "http://example.com/api/5", req.RawRequest.URI().String())
	})

	t.Run("the path param from request", func(t *testing.T) {
		t.Parallel()
		client := New().
			SetBaseURL("http://example.com/api/:id/:name").
			SetPathParam("id", "5")
		req := AcquireRequest().
			SetURL("/{key}").
			SetPathParams(map[string]string{
				"name": "fiber",
				"key":  "val",
			}).
			DelPathParams("key")

		err := parserRequestURL(client, req)
		require.NoError(t, err)
		require.Equal(t, "http://example.com/api/5/fiber/%7Bkey%7D", req.RawRequest.URI().String())
	})

	t.Run("the path param from request and client", func(t *testing.T) {
		t.Parallel()
		client := New().
			SetBaseURL("http://example.com/api/:id/:name").
			SetPathParam("id", "5")
		req := AcquireRequest().
			SetURL("/:key").
			SetPathParams(map[string]string{
				"name": "fiber",
				"key":  "val",
				"id":   "12",
			})

		err := parserRequestURL(client, req)
		require.NoError(t, err)
		require.Equal(t, "http://example.com/api/12/fiber/val", req.RawRequest.URI().String())
	})

	t.Run("query params from client should be set", func(t *testing.T) {
		t.Parallel()
		client := New().
			SetParam("foo", "bar")
		req := AcquireRequest().SetURL("http://example.com/api/v1")

		err := parserRequestURL(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("foo=bar"), req.RawRequest.URI().QueryString())
	})

	t.Run("query params from request should be set", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest().
			SetURL("http://example.com/api/v1").
			SetParam("bar", "foo")

		err := parserRequestURL(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("bar=foo"), req.RawRequest.URI().QueryString())
	})

	t.Run("query params should be merged", func(t *testing.T) {
		t.Parallel()
		client := New().
			SetParam("bar", "foo1")
		req := AcquireRequest().
			SetURL("http://example.com/api/v1?bar=foo2").
			SetParam("bar", "foo")

		err := parserRequestURL(client, req)
		require.NoError(t, err)

		values, err := url.ParseQuery(string(req.RawRequest.URI().QueryString()))
		require.NoError(t, err)

		// All three sources (client, URL, request) must contribute a value,
		// and nothing else may appear.
		flag1, flag2, flag3 := false, false, false
		for _, v := range values["bar"] {
			switch v {
			case "foo1":
				flag1 = true
			case "foo2":
				flag2 = true
			case "foo":
				flag3 = true
			default:
				t.Fatalf("unexpected query param value: %s", v)
			}
		}
		require.True(t, flag1)
		require.True(t, flag2)
		require.True(t, flag3)
	})

	t.Run("request disable path normalizing should be respected", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest().
			SetURL("https://example.my.host/other.host%2Fpath%2Fto%2Fdata%23123").
			SetDisablePathNormalizing(true)
		t.Cleanup(func() { ReleaseRequest(req) })

		err := parserRequestURL(client, req)
		require.NoError(t, err)
		require.Equal(t, "https://example.my.host/other.host%2Fpath%2Fto%2Fdata%23123", req.RawRequest.URI().String())
	})

	t.Run("client disable path normalizing should be respected", func(t *testing.T) {
		t.Parallel()
		client := New().SetDisablePathNormalizing(true)
		req := AcquireRequest().
			SetURL("https://example.my.host/other.host%2Fpath%2Fto%2Fdata%23123")
		t.Cleanup(func() { ReleaseRequest(req) })

		err := parserRequestURL(client, req)
		require.NoError(t, err)
		require.Equal(t, "https://example.my.host/other.host%2Fpath%2Fto%2Fdata%23123", req.RawRequest.URI().String())
	})
}

// Test_Parser_Request_Header covers how parserRequestHeader merges client and
// request headers, derives Content-Type from the body type, and applies the
// User-Agent, Referer and cookie settings.
func Test_Parser_Request_Header(t *testing.T) {
	t.Parallel()

	t.Run("client header should be set", func(t *testing.T) {
		t.Parallel()
		client := New().
			SetHeaders(map[string]string{
				fiber.HeaderContentType: "application/json",
			})

		req := AcquireRequest()

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("application/json"), req.RawRequest.Header.ContentType())
	})

	t.Run("request header should be set", func(t *testing.T) {
		t.Parallel()
		client := New()

		req := AcquireRequest().
			SetHeaders(map[string]string{
				fiber.HeaderContentType: "application/json, utf-8",
			})

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType())
	})

	t.Run("request header should override client header", func(t *testing.T) {
		t.Parallel()
		client := New().
			SetHeader(fiber.HeaderContentType, "application/xml")

		req := AcquireRequest().
			SetHeader(fiber.HeaderContentType, "application/json, utf-8")

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType())
	})

	t.Run("auto set json header", func(t *testing.T) {
		t.Parallel()
		type jsonData struct {
			Name string `json:"name"`
		}
		client := New()
		req := AcquireRequest().
			SetJSON(jsonData{
				Name: "foo",
			})

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte(applicationJSON), req.RawRequest.Header.ContentType()) //nolint:testifylint // test
	})

	t.Run("auto set xml header", func(t *testing.T) {
		t.Parallel()
		type xmlData struct {
			XMLName xml.Name `xml:"body"`
			Name    string   `xml:"name"`
		}
		client := New()
		req := AcquireRequest().
			SetXML(xmlData{
				Name: "foo",
			})

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte(applicationXML), req.RawRequest.Header.ContentType())
	})

	t.Run("auto set form data header", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest().
			SetFormDataWithMap(map[string]string{
				"foo":  "bar",
				"ball": "circle and square",
			})

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, applicationForm, string(req.RawRequest.Header.ContentType()))
	})

	t.Run("auto set file header", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest().
			AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))).
			SetFormData("foo", "bar")

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Contains(t, string(req.RawRequest.Header.MultipartFormBoundary()), "FiberFormBoundary")
		require.Contains(t, string(req.RawRequest.Header.ContentType()), multipartFormData)
	})

	t.Run("ua should have default value", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest()

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("fiber"), req.RawRequest.Header.UserAgent())
	})

	t.Run("ua in client should be set", func(t *testing.T) {
		t.Parallel()
		client := New().SetUserAgent("foo")
		req := AcquireRequest()

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("foo"), req.RawRequest.Header.UserAgent())
	})

	t.Run("ua in request should have higher level", func(t *testing.T) {
		t.Parallel()
		client := New().SetUserAgent("foo")
		req := AcquireRequest().SetUserAgent("bar")

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("bar"), req.RawRequest.Header.UserAgent())
	})

	t.Run("referer in client should be set", func(t *testing.T) {
		t.Parallel()
		client := New().SetReferer("https://example.com")
		req := AcquireRequest()

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer())
	})

	t.Run("referer in request should have higher level", func(t *testing.T) {
		t.Parallel()
		client := New().SetReferer("http://example.com")
		req := AcquireRequest().SetReferer("https://example.com")

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer())
	})

	t.Run("client cookie should be set", func(t *testing.T) {
		t.Parallel()
		client := New().
			SetCookie("foo", "bar").
			SetCookies(map[string]string{
				"bar":  "foo",
				"bar1": "foo1",
			}).
			DelCookies("bar1")

		req := AcquireRequest()

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo")))
		require.Equal(t, "foo", string(req.RawRequest.Header.Cookie("bar")))
		require.Empty(t, string(req.RawRequest.Header.Cookie("bar1")))
	})

	t.Run("request cookie should be set", func(t *testing.T) {
		t.Parallel()
		type cookies struct {
			Foo string `cookie:"foo"`
			Bar int    `cookie:"bar"`
		}

		client := New()

		req := AcquireRequest().
			SetCookiesWithStruct(&cookies{
				Foo: "bar",
				Bar: 67,
			})

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo")))
		require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar")))
		require.Empty(t, string(req.RawRequest.Header.Cookie("bar1")))
	})

	t.Run("request cookie will override client cookie", func(t *testing.T) {
		t.Parallel()
		type cookies struct {
			Foo string `cookie:"foo"`
			Bar int    `cookie:"bar"`
		}

		client := New().
			SetCookie("foo", "bar").
			SetCookies(map[string]string{
				"bar":  "foo",
				"bar1": "foo1",
			})

		req := AcquireRequest().
			SetCookiesWithStruct(&cookies{
				Foo: "bar",
				Bar: 67,
			})

		err := parserRequestHeader(client, req)
		require.NoError(t, err)
		require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo")))
		require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar")))
		require.Equal(t, "foo1", string(req.RawRequest.Header.Cookie("bar1")))
	})
}

// Test_Parser_Request_Body covers parserRequestBody for every body type,
// including the error paths for files, raw bodies and unknown body types.
func Test_Parser_Request_Body(t *testing.T) {
	t.Parallel()

	t.Run("json body", func(t *testing.T) {
		t.Parallel()
		type jsonData struct {
			Name string `json:"name"`
		}
		client := New()
		req := AcquireRequest().
			SetJSON(jsonData{
				Name: "foo",
			})

		err := parserRequestBody(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("{\"name\":\"foo\"}"), req.RawRequest.Body()) //nolint:testifylint // test
	})

	t.Run("xml body", func(t *testing.T) {
		t.Parallel()
		type xmlData struct {
			XMLName xml.Name `xml:"body"`
			Name    string   `xml:"name"`
		}
		client := New()
		req := AcquireRequest().
			SetXML(xmlData{
				Name: "foo",
			})

		err := parserRequestBody(client, req)
		require.NoError(t, err)
		// NOTE(review): this expectation looks like it lost its markup; if the
		// client's default XML marshaller is encoding/xml, the body would be
		// <body><name>foo</name></body>. Confirm against the upstream file.
		require.Equal(t, []byte("foo"), req.RawRequest.Body())
	})

	t.Run("CBOR body", func(t *testing.T) {
		t.Parallel()
		type cborData struct {
			Name string `cbor:"name"`
			Age  int    `cbor:"age"`
		}

		data := cborData{
			Name: "foo",
			Age:  12,
		}

		client := New()
		req := AcquireRequest().
			SetCBOR(data)

		err := parserRequestBody(client, req)
		require.NoError(t, err)

		encoded, err := cbor.Marshal(data)
		require.NoError(t, err)
		require.Equal(t, encoded, req.RawRequest.Body())
	})

	t.Run("form data body", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest().
			SetFormDataWithMap(map[string]string{
				"ball": "circle and square",
			})

		err := parserRequestBody(client, req)
		require.NoError(t, err)
		require.Equal(t, "ball=circle+and+square", string(req.RawRequest.Body()))
	})

	// NOTE(review): despite its name, this case asserts that an empty key and
	// value produce no error.
	t.Run("form data body error", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest().
			SetFormDataWithMap(map[string]string{
				"": "",
			})

		err := parserRequestBody(client, req)
		require.NoError(t, err)
	})

	t.Run("file body", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest().
			AddFileWithReader("hello", io.NopCloser(strings.NewReader("world")))

		err := parserRequestBody(client, req)
		require.NoError(t, err)
		require.Contains(t, string(req.RawRequest.Body()), "--FiberFormBoundary")
		require.Contains(t, string(req.RawRequest.Body()), "world")
	})

	t.Run("file body open error", func(t *testing.T) {
		t.Parallel()
		client := New()
		missingPath := filepath.Join(t.TempDir(), "missing.txt")
		req := AcquireRequest().AddFile(missingPath)

		err := parserRequestBody(client, req)
		require.ErrorContains(t, err, "open file error")
	})

	t.Run("file body missing path and name", func(t *testing.T) {
		t.Parallel()
		client := New()
		file := AcquireFile(SetFileReader(io.NopCloser(strings.NewReader("world"))))
		req := AcquireRequest().AddFiles(file)

		err := parserRequestBody(client, req)
		require.ErrorIs(t, err, ErrFileNoName)
	})

	t.Run("file and form data", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest().
			AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))).
			SetFormData("foo", "bar")

		err := parserRequestBody(client, req)
		require.NoError(t, err)
		require.Contains(t, string(req.RawRequest.Body()), "--FiberFormBoundary")
		require.Contains(t, string(req.RawRequest.Body()), "world")
		require.Contains(t, string(req.RawRequest.Body()), "bar")
	})

	t.Run("raw body", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest().
			SetRawBody([]byte("hello world"))

		err := parserRequestBody(client, req)
		require.NoError(t, err)
		require.Equal(t, []byte("hello world"), req.RawRequest.Body())
	})

	t.Run("raw body error", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest().
			SetRawBody([]byte("hello world"))

		// Clearing the body while bodyType stays rawBody makes the []byte
		// type assertion in parserRequestBody fail.
		req.body = nil

		err := parserRequestBody(client, req)
		require.ErrorIs(t, err, ErrBodyType)
	})

	t.Run("unsupported body type", func(t *testing.T) {
		t.Parallel()
		client := New()
		req := AcquireRequest()
		req.bodyType = 999 // some invalid type

		err := parserRequestBody(client, req)
		require.ErrorIs(t, err, ErrBodyTypeNotSupported)
	})
}

// dummyLogger is a test logger that records only Debugf output into buf;
// every other logging method is a no-op.
type dummyLogger struct {
	buf *bytes.Buffer
}

func (*dummyLogger) Trace(_ ...any) {}

func (*dummyLogger) Debug(_ ...any) {}

func (*dummyLogger) Info(_ ...any) {}

func (*dummyLogger) Warn(_ ...any) {}

func (*dummyLogger) Error(_ ...any) {}

func (*dummyLogger) Fatal(_ ...any) {}

func (*dummyLogger) Panic(_ ...any) {}

func (*dummyLogger) Tracef(_ string, _ ...any) {}

// Debugf formats the message into the captured buffer.
func (l *dummyLogger) Debugf(format string, v ...any) {
	fmt.Fprintf(l.buf, format, v...)
}

func (*dummyLogger) Infof(_ string, _ ...any) {}

func (*dummyLogger) Warnf(_ string, _ ...any) {}

func (*dummyLogger) Errorf(_ string, _ ...any) {}

func (*dummyLogger) Fatalf(_ string, _ ...any) {}

func (*dummyLogger) Panicf(_ string, _ ...any) {}

func (*dummyLogger) Tracew(_ string, _ ...any) {}

func (*dummyLogger) Debugw(_ string, _ ...any) {}

func (*dummyLogger) Infow(_ string, _ ...any) {}

func (*dummyLogger) Warnw(_ string, _ ...any) {}

func (*dummyLogger) Errorw(_ string, _ ...any) {}

func (*dummyLogger) Fatalw(_ string, _ ...any) {}

func (*dummyLogger) Panicw(_ string, _ ...any) {}

// Test_Client_Logger_Debug starts a real server on an ephemeral port and
// checks that a debug-enabled client writes the request and response dumps
// to its logger.
func Test_Client_Logger_Debug(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("response")
	})

	// The listener reports its bound address through addrChan.
	addrChan := make(chan string)
	go func() {
		assert.NoError(t, app.Listen(":0", fiber.ListenConfig{
			DisableStartupMessage: true,
			ListenerAddrFunc: func(addr net.Addr) {
				addrChan <- addr.String()
			},
		}))
	}()

	defer func(app *fiber.App) {
		require.NoError(t, app.Shutdown())
	}(app)

	var buf bytes.Buffer
	logger := &dummyLogger{buf: &buf}

	client := New()
	client.Debug().SetLogger(logger)

	addr := <-addrChan
	resp, err := client.Get("http://" + addr)
	require.NoError(t, err)
	defer resp.Close()

	// NOTE(review): repeats the err check made just above.
	require.NoError(t, err)
	require.Contains(t, buf.String(), "Host: "+addr)
	require.Contains(t, buf.String(), "Content-Length: 8")
}

// Test_Client_Logger_DisableDebug checks that nothing reaches the logger once
// debug output has been disabled on the client.
func Test_Client_Logger_DisableDebug(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("response")
	})

	addrChan := make(chan string)
	go func() {
		assert.NoError(t, app.Listen(":0", fiber.ListenConfig{
			DisableStartupMessage: true,
			ListenerAddrFunc: func(addr net.Addr) {
				addrChan <- addr.String()
			},
		}))
	}()

	defer func(app *fiber.App) {
		require.NoError(t, app.Shutdown())
	}(app)

	var buf bytes.Buffer
	logger := &dummyLogger{buf: &buf}

	client := New()
	client.DisableDebug().SetLogger(logger)

	addr := <-addrChan
	resp, err := client.Get("http://" + addr)
	require.NoError(t, err)
	defer resp.Close()

	// NOTE(review): repeats the err check made just above.
	require.NoError(t, err)
	require.Empty(t, buf.String())
}

// Benchmark_Parser_Request_Body_File measures parserRequestBodyFile building a
// multipart body with two form fields and fileCount files of fileSize bytes.
func Benchmark_Parser_Request_Body_File(b *testing.B) {
	b.Helper()

	const (
		fileCount = 3
		fileSize  = 32 << 10 // 32KB payload per file
	)

	formValues := map[string]string{
		"username": "fiber",
		"api_key":  "d5942ef5",
	}

	fileContents := make([][]byte, fileCount)
	for i := range fileContents {
		fileContents[i] = bytes.Repeat([]byte{byte('a' + i)}, fileSize)
	}

	b.ReportAllocs()

	for i := 0; i < b.N; i++ {
		// NOTE(review): totalBytes does not change between iterations; the
		// sum and SetBytes call could be hoisted above the loop.
		var totalBytes int64
		for _, c := range fileContents {
			totalBytes += int64(len(c))
		}
		b.SetBytes(totalBytes)

		req := newBenchmarkRequest(formValues, fileContents)
		if err := parserRequestBodyFile(req); err != nil {
			b.Fatalf("parserRequestBodyFile: %v", err)
		}

		releaseBenchmarkRequest(req)
	}
}

// newBenchmarkRequest builds a Request by hand with a fixed boundary, the
// given form values, and one in-memory file per entry of fileContents.
func newBenchmarkRequest(formValues map[string]string, fileContents [][]byte) *Request {
	req := &Request{
		boundary:   "FiberBenchmarkBoundary",
		formData:   FormData{Args: fasthttp.AcquireArgs()},
		RawRequest: fasthttp.AcquireRequest(),
		files:      make([]*File, len(fileContents)),
	}

	req.RawRequest.Header.SetContentType("multipart/form-data; boundary=" + req.boundary)

	for key, value := range formValues {
		req.formData.Set(key, value)
	}

	for i, content := range fileContents {
		req.files[i] = AcquireFile(
			SetFileName(fmt.Sprintf("file-%d.bin", i)),
			SetFileFieldName(fmt.Sprintf("file%d", i)),
			SetFileReader(io.NopCloser(bytes.NewReader(content))),
		)
	}

	return req
}

// releaseBenchmarkRequest returns everything acquired by newBenchmarkRequest
// to its pool.
func releaseBenchmarkRequest(req *Request) {
	fasthttp.ReleaseRequest(req.RawRequest)
	fasthttp.ReleaseArgs(req.formData.Args)
	for _, f := range req.files {
		ReleaseFile(f)
	}
}

================================================
FILE: client/request.go
================================================
package client

import (
	"bytes"
	"context"
	"errors"
	"io"
	"iter"
	"maps"
	"path/filepath"
	"reflect"
	"slices"
	"sort"
	"strconv"
	"sync"
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/utils/v2"
	"github.com/valyala/fasthttp"
)

// WithStruct is implemented by types that allow data to be stored from a struct via reflection.
type WithStruct interface {
	Add(name, obj string)
	Del(name string)
}

// bodyType defines the type of request body.
type bodyType int

// Enumeration of request body types.
const (
	noBody bodyType = iota
	jsonBody
	xmlBody
	formBody
	filesBody
	rawBody
	cborBody
)

// ErrClientNil is the value Request.SetClient panics with when given a nil client.
var ErrClientNil = errors.New("client cannot be nil")

// Request contains all data related to an HTTP request.
type Request struct {
	ctx context.Context //nolint:containedctx // Context is needed to be stored in the request.

	// body is interpreted according to bodyType.
	body    any
	header  Header
	params  QueryParam
	cookies Cookie
	path    PathParam

	client *Client

	formData FormData

	RawRequest *fasthttp.Request
	url        string
	method     string
	userAgent  string
	boundary   string
	referer    string
	files      []*File

	timeout      time.Duration
	maxRedirects int

	bodyType               bodyType
	disablePathNormalizing bool
}

// Method returns the HTTP method set in the Request.
func (r *Request) Method() string {
	return r.method
}

// SetMethod sets the HTTP method for the Request.
// It is recommended to use the specialized methods (e.g., Get, Post) instead.
func (r *Request) SetMethod(method string) *Request { r.method = method return r } // URL returns the URL set in the Request. func (r *Request) URL() string { return r.url } // SetURL sets the URL for the Request. func (r *Request) SetURL(url string) *Request { r.url = url return r } // Client returns the Client instance associated with this Request. func (r *Request) Client() *Client { return r.client } // SetClient sets the Client instance for the Request. func (r *Request) SetClient(c *Client) *Request { if c == nil { panic(ErrClientNil) } r.client = c return r } // Context returns the context associated with the Request. // If not set, a background context is returned. func (r *Request) Context() context.Context { if r.ctx == nil { return context.Background() } return r.ctx } // SetContext sets the context for the Request, allowing request cancellation if ctx is done. // See https://blog.golang.org/context article and the "context" package documentation. func (r *Request) SetContext(ctx context.Context) *Request { r.ctx = ctx return r } // Header returns all values associated with the given header key. func (r *Request) Header(key string) []string { return r.header.PeekMultiple(key) } type pair struct { k []string v []string } // Len implements sort.Interface and reports the number of tracked keys. func (p *pair) Len() int { return len(p.k) } // Swap implements sort.Interface and swaps the entries at the provided indices. func (p *pair) Swap(i, j int) { p.k[i], p.k[j] = p.k[j], p.k[i] p.v[i], p.v[j] = p.v[j], p.v[i] } // Less implements sort.Interface and orders entries lexicographically by key. func (p *pair) Less(i, j int) bool { return p.k[i] < p.k[j] } // Headers returns an iterator over all headers in the Request. // Use maps.Collect() to gather them into a map if needed. // // The returned values are only valid until the request object is released. // Do not store references to returned values; make copies instead. 
func (r *Request) Headers() iter.Seq2[string, []string] { return func(yield func(string, []string) bool) { peekKeys := r.header.PeekKeys() // Copy keys to immutable strings to decouple from fasthttp's internal buffers. keys := make([]string, len(peekKeys)) for i, key := range peekKeys { keys[i] = utils.UnsafeString(key) } for _, key := range keys { vals := r.header.PeekAll(key) valsStr := make([]string, len(vals)) for i, v := range vals { valsStr[i] = utils.UnsafeString(v) } if !yield(key, valsStr) { return } } } } // AddHeader adds a single header field and value to the Request. func (r *Request) AddHeader(key, val string) *Request { r.header.Add(key, val) return r } // SetHeader sets a single header field and value in the Request, overriding any previously set value. func (r *Request) SetHeader(key, val string) *Request { r.header.Del(key) r.header.Set(key, val) return r } // AddHeaders adds multiple header fields and values at once. func (r *Request) AddHeaders(h map[string][]string) *Request { r.header.AddHeaders(h) return r } // SetHeaders sets multiple header fields and values at once, overriding previously set values. func (r *Request) SetHeaders(h map[string]string) *Request { r.header.SetHeaders(h) return r } // Param returns all values associated with the given query parameter. func (r *Request) Param(key string) []string { tmp := r.params.PeekMulti(key) res := make([]string, 0, len(tmp)) for _, v := range tmp { res = append(res, utils.UnsafeString(v)) } return res } // Params returns an iterator over all query parameters in the Request. // Use maps.Collect() to gather them into a map if needed. // // The returned values are only valid until the request object is released. // Do not store references to returned values; make copies instead. 
func (r *Request) Params() iter.Seq2[string, []string] { return func(yield func(string, []string) bool) { vals := r.params.Len() if vals == 0 { return } prealloc := make([]string, 2*vals) p := pair{ k: prealloc[:0:vals], v: prealloc[vals : vals : 2*vals], } for k, v := range r.params.All() { p.k = append(p.k, utils.UnsafeString(k)) p.v = append(p.v, utils.UnsafeString(v)) } sort.Sort(&p) j := 0 for i := range vals { if i == vals-1 || p.k[i] != p.k[i+1] { if !yield(p.k[i], p.v[j:i+1]) { break } j = i + 1 } } } } // AddParam adds a single query parameter and value to the Request. func (r *Request) AddParam(key, val string) *Request { r.params.Add(key, val) return r } // SetParam sets a single query parameter and value in the Request, overriding any previously set value. func (r *Request) SetParam(key, val string) *Request { r.params.Set(key, val) return r } // AddParams adds multiple query parameters and their values at once. func (r *Request) AddParams(m map[string][]string) *Request { r.params.AddParams(m) return r } // SetParams sets multiple query parameters and their values at once, overriding previously set values. func (r *Request) SetParams(m map[string]string) *Request { r.params.SetParams(m) return r } // SetParamsWithStruct sets multiple query parameters from a struct, overriding previously set values. func (r *Request) SetParamsWithStruct(v any) *Request { r.params.SetParamsWithStruct(v) return r } // DelParams deletes one or more query parameters. func (r *Request) DelParams(key ...string) *Request { for _, v := range key { r.params.Del(v) } return r } // UserAgent returns the User-Agent header set in the Request. func (r *Request) UserAgent() string { return r.userAgent } // SetUserAgent sets the User-Agent header, overriding any previously set value. func (r *Request) SetUserAgent(ua string) *Request { r.userAgent = ua return r } // Boundary returns the multipart boundary used by the Request. 
func (r *Request) Boundary() string {
	return r.boundary
}

// SetBoundary replaces the multipart boundary and returns the Request for
// chaining.
func (r *Request) SetBoundary(b string) *Request {
	r.boundary = b
	return r
}

// Referer reports the Referer value configured on the Request.
func (r *Request) Referer() string {
	return r.referer
}

// SetReferer replaces the Referer value and returns the Request for chaining.
func (r *Request) SetReferer(referer string) *Request {
	r.referer = referer
	return r
}

// Cookie looks up the cookie named key and returns its value, or "" when no
// such cookie is set (a missing map entry yields the zero value).
func (r *Request) Cookie(key string) string {
	return r.cookies[key]
}

// Cookies yields every cookie name/value pair stored on the Request.
// Gather them with maps.Collect() when a map is needed.
func (r *Request) Cookies() iter.Seq2[string, string] {
	return r.cookies.All()
}

// SetCookie stores one cookie, replacing any earlier value for key.
func (r *Request) SetCookie(key, val string) *Request {
	r.cookies.SetCookie(key, val)
	return r
}

// SetCookies stores every cookie in m, replacing earlier values.
func (r *Request) SetCookies(m map[string]string) *Request {
	r.cookies.SetCookies(m)
	return r
}

// SetCookiesWithStruct stores cookies taken from the fields of v, replacing
// earlier values.
func (r *Request) SetCookiesWithStruct(v any) *Request {
	r.cookies.SetCookiesWithStruct(v)
	return r
}

// DelCookies removes every cookie whose name is listed in key.
func (r *Request) DelCookies(key ...string) *Request {
	r.cookies.DelCookies(key...)
	return r
}

// PathParam looks up the path parameter named key and returns its value, or
// "" when it is not set (a missing map entry yields the zero value).
func (r *Request) PathParam(key string) string {
	return r.path[key]
}

// PathParams returns an iterator over all path parameters.
// Use maps.Collect() to gather them into a map if needed.
func (r *Request) PathParams() iter.Seq2[string, string] {
	return r.path.All()
}

// SetPathParam stores one path parameter, replacing any earlier value for key.
func (r *Request) SetPathParam(key, val string) *Request {
	r.path.SetParam(key, val)
	return r
}

// SetPathParams stores every path parameter in m, replacing earlier values.
func (r *Request) SetPathParams(m map[string]string) *Request {
	r.path.SetParams(m)
	return r
}

// SetPathParamsWithStruct stores path parameters taken from the fields of v,
// replacing earlier values.
func (r *Request) SetPathParamsWithStruct(v any) *Request {
	r.path.SetParamsWithStruct(v)
	return r
}

// DelPathParams removes every path parameter whose name is listed in key.
func (r *Request) DelPathParams(key ...string) *Request {
	r.path.DelParams(key...)
	return r
}

// ResetPathParams removes all path parameters.
func (r *Request) ResetPathParams() *Request {
	r.path.Reset()
	return r
}

// SetJSON makes v the request body, to be encoded as JSON.
func (r *Request) SetJSON(v any) *Request {
	r.body, r.bodyType = v, jsonBody
	return r
}

// SetXML makes v the request body, to be encoded as XML.
func (r *Request) SetXML(v any) *Request {
	r.body, r.bodyType = v, xmlBody
	return r
}

// SetCBOR makes v the request body, to be encoded as CBOR.
func (r *Request) SetCBOR(v any) *Request {
	r.body, r.bodyType = v, cborBody
	return r
}

// SetRawBody makes the byte slice v the request body, sent unchanged.
func (r *Request) SetRawBody(v []byte) *Request {
	r.body, r.bodyType = v, rawBody
	return r
}

// resetBody drops the current body value and switches to body type t, except
// that a pending files body is kept when t is formBody so queued files are
// not discarded by adding form fields.
func (r *Request) resetBody(t bodyType) {
	r.body = nil

	keepFiles := r.bodyType == filesBody && t == formBody
	if !keepFiles {
		r.bodyType = t
	}
}

// FormData returns all values associated with a form field.
func (r *Request) FormData(key string) []string { tmp := r.formData.PeekMulti(key) res := make([]string, 0, len(tmp)) for _, v := range tmp { res = append(res, utils.UnsafeString(v)) } return res } // AllFormData returns an iterator over all form fields. // Use maps.Collect() to gather them into a map if needed. // // The returned values are only valid until the request object is released. // Do not store references to returned values; make copies instead. func (r *Request) AllFormData() iter.Seq2[string, []string] { return func(yield func(string, []string) bool) { vals := r.formData.Len() if vals == 0 { return } prealloc := make([]string, 2*vals) p := pair{ k: prealloc[:0:vals], v: prealloc[vals : vals : 2*vals], } for k, v := range r.formData.All() { p.k = append(p.k, utils.UnsafeString(k)) p.v = append(p.v, utils.UnsafeString(v)) } sort.Sort(&p) j := 0 for i := range vals { if i == vals-1 || p.k[i] != p.k[i+1] { if !yield(p.k[i], p.v[j:i+1]) { break } j = i + 1 } } } } // AddFormData adds a single form field and value to the Request. func (r *Request) AddFormData(key, val string) *Request { r.formData.Add(key, val) r.resetBody(formBody) return r } // SetFormData sets a single form field and value, overriding any previously set value. func (r *Request) SetFormData(key, val string) *Request { r.formData.Set(key, val) r.resetBody(formBody) return r } // AddFormDataWithMap adds multiple form fields and values to the Request. func (r *Request) AddFormDataWithMap(m map[string][]string) *Request { r.formData.AddWithMap(m) r.resetBody(formBody) return r } // SetFormDataWithMap sets multiple form fields and values at once, overriding previously set values. func (r *Request) SetFormDataWithMap(m map[string]string) *Request { r.formData.SetWithMap(m) r.resetBody(formBody) return r } // SetFormDataWithStruct sets multiple form fields from a struct, overriding previously set values. 
func (r *Request) SetFormDataWithStruct(v any) *Request { r.formData.SetWithStruct(v) r.resetBody(formBody) return r } // DelFormData deletes one or more form fields. func (r *Request) DelFormData(key ...string) *Request { r.formData.DelData(key...) r.resetBody(formBody) return r } // File returns the file associated with the given name. // If no name was provided during addition, it attempts to match by the file's base name. func (r *Request) File(name string) *File { for _, v := range r.files { switch v.name { case "": if filepath.Base(v.path) == name { return v } case name: return v default: continue } } return nil } // Files returns all files added to the Request. // // The returned values are only valid until the request object is released. // Do not store references to returned values; make copies instead. func (r *Request) Files() []*File { return r.files } // FileByPath returns the file associated with the given file path. func (r *Request) FileByPath(path string) *File { for _, v := range r.files { if v.path == path { return v } } return nil } // AddFile adds a single file by its path. func (r *Request) AddFile(path string) *Request { r.files = append(r.files, AcquireFile(SetFilePath(path))) r.resetBody(filesBody) return r } // AddFileWithReader adds a file using an io.ReadCloser. func (r *Request) AddFileWithReader(name string, reader io.ReadCloser) *Request { r.files = append(r.files, AcquireFile(SetFileName(name), SetFileReader(reader))) r.resetBody(filesBody) return r } // AddFiles adds multiple files at once. func (r *Request) AddFiles(files ...*File) *Request { r.files = append(r.files, files...) r.resetBody(filesBody) return r } // Timeout returns the timeout duration set in the Request. func (r *Request) Timeout() time.Duration { return r.timeout } // SetTimeout sets the timeout for the Request, overriding any previously set value. 
func (r *Request) SetTimeout(t time.Duration) *Request { r.timeout = t return r } // MaxRedirects returns the maximum number of redirects configured for the Request. func (r *Request) MaxRedirects() int { return r.maxRedirects } // SetMaxRedirects sets the maximum number of redirects, overriding any previously set value. func (r *Request) SetMaxRedirects(count int) *Request { r.maxRedirects = count return r } // DisablePathNormalizing reports whether path normalizing is disabled for the Request. func (r *Request) DisablePathNormalizing() bool { return r.disablePathNormalizing } // SetDisablePathNormalizing configures the Request to disable or enable path normalizing. func (r *Request) SetDisablePathNormalizing(disable bool) *Request { r.disablePathNormalizing = disable r.RawRequest.URI().DisablePathNormalizing = disable return r } // checkClient ensures that a Client is set. If none is set, it defaults to the global defaultClient. func (r *Request) checkClient() { if r.client == nil { r.SetClient(defaultClient) } } // Get sends a GET request to the given URL. func (r *Request) Get(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodGet).Send() } // Post sends a POST request to the given URL. func (r *Request) Post(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodPost).Send() } // Head sends a HEAD request to the given URL. func (r *Request) Head(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodHead).Send() } // Put sends a PUT request to the given URL. func (r *Request) Put(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodPut).Send() } // Delete sends a DELETE request to the given URL. func (r *Request) Delete(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodDelete).Send() } // Options sends an OPTIONS request to the given URL. 
func (r *Request) Options(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodOptions).Send() } // Patch sends a PATCH request to the given URL. func (r *Request) Patch(url string) (*Response, error) { return r.SetURL(url).SetMethod(fiber.MethodPatch).Send() } // Custom sends a request with a custom HTTP method to the given URL. func (r *Request) Custom(url, method string) (*Response, error) { return r.SetURL(url).SetMethod(method).Send() } // Send executes the Request. func (r *Request) Send() (*Response, error) { r.checkClient() return newCore().execute(r.Context(), r.Client(), r) } // Reset clears the Request object, returning it to its default state. // Used by ReleaseRequest to recycle the object. func (r *Request) Reset() { r.url = "" r.method = fiber.MethodGet r.userAgent = "" r.referer = "" r.ctx = nil r.body = nil r.timeout = 0 r.maxRedirects = 0 r.bodyType = noBody r.boundary = boundary r.disablePathNormalizing = false for len(r.files) != 0 { t := r.files[0] r.files = r.files[1:] ReleaseFile(t) } r.formData.Reset() r.path.Reset() r.cookies.Reset() r.header.Reset() r.params.Reset() r.RawRequest.Reset() } // Header wraps fasthttp.RequestHeader, storing headers for both client and request. type Header struct { *fasthttp.RequestHeader } // PeekMultiple returns multiple values of a header field with the same key. func (h *Header) PeekMultiple(key string) []string { var res []string byteKey := []byte(key) for k, value := range h.All() { if bytes.EqualFold(k, byteKey) { res = append(res, utils.UnsafeString(value)) } } return res } // AddHeaders adds multiple headers from a map. func (h *Header) AddHeaders(r map[string][]string) { for k, v := range r { for _, vv := range v { h.Add(k, vv) } } } // SetHeaders sets multiple headers from a map, overriding previously set values. func (h *Header) SetHeaders(r map[string]string) { for k, v := range r { h.Del(k) h.Set(k, v) } } // QueryParam wraps fasthttp.Args for query parameters. 
type QueryParam struct {
	*fasthttp.Args
}

// Keys returns the distinct query-parameter keys in first-occurrence order.
//
// The returned strings reference the underlying fasthttp.Args buffers; copy
// them if they must outlive the parameters.
func (p *QueryParam) Keys() []string {
	keys := make([]string, 0, p.Len())
	for key := range p.All() {
		k := utils.UnsafeString(key)
		// slices.Compact only drops adjacent duplicates, which misses repeats
		// such as "a", "b", "a"; a linear scan catches every repeat without
		// allocating a set for the (typically small) number of keys.
		if slices.Contains(keys, k) {
			continue
		}
		keys = append(keys, k)
	}
	return keys
}

// AddParams adds multiple parameters from a map.
func (p *QueryParam) AddParams(r map[string][]string) {
	for k, v := range r {
		for _, vv := range v {
			p.Add(k, vv)
		}
	}
}

// SetParams sets multiple parameters from a map, overriding previously set values.
func (p *QueryParam) SetParams(r map[string]string) {
	for k, v := range r {
		p.Set(k, v)
	}
}

// SetParamsWithStruct sets multiple parameters from a struct.
// Nested structs are not currently supported.
func (p *QueryParam) SetParamsWithStruct(v any) {
	SetValWithStruct(p, "param", v)
}

// Cookie is a map used to store cookies.
type Cookie map[string]string

// Add adds a cookie key-value pair.
func (c Cookie) Add(key, val string) {
	c[key] = val
}

// Del deletes a cookie by key.
func (c Cookie) Del(key string) {
	delete(c, key)
}

// SetCookie sets a single cookie value.
func (c Cookie) SetCookie(key, val string) {
	c[key] = val
}

// SetCookies sets multiple cookies from a map.
func (c Cookie) SetCookies(m map[string]string) {
	maps.Copy(c, m)
}

// SetCookiesWithStruct sets cookies from a struct.
// Nested structs are not currently supported.
func (c Cookie) SetCookiesWithStruct(v any) {
	SetValWithStruct(c, "cookie", v)
}

// DelCookies deletes multiple cookies by keys.
func (c Cookie) DelCookies(key ...string) {
	for _, v := range key {
		c.Del(v)
	}
}

// All returns an iterator over cookie key-value pairs.
//
// The returned key and value should not be retained after the iteration loop.
func (c Cookie) All() iter.Seq2[string, string] {
	return maps.All(c)
}

// Reset clears the Cookie map.
func (c Cookie) Reset() {
	clear(c)
}

// PathParam is a map used to store path parameters.
type PathParam map[string]string

// Add adds a path parameter key-value pair.
func (p PathParam) Add(key, val string) { p[key] = val } // Del deletes a path parameter by key. func (p PathParam) Del(key string) { delete(p, key) } // SetParam sets a single path parameter. func (p PathParam) SetParam(key, val string) { p[key] = val } // SetParams sets multiple path parameters from a map. func (p PathParam) SetParams(m map[string]string) { maps.Copy(p, m) } // SetParamsWithStruct sets multiple path parameters from a struct. // Nested structs are not currently supported. func (p PathParam) SetParamsWithStruct(v any) { SetValWithStruct(p, "path", v) } // DelParams deletes multiple path parameters. func (p PathParam) DelParams(key ...string) { for _, v := range key { p.Del(v) } } // All returns an iterator over path parameter key-value pairs. // // The returned key and value should not be retained after the iteration loop. func (p PathParam) All() iter.Seq2[string, string] { return maps.All(p) } // Reset clears the PathParam map. func (p PathParam) Reset() { clear(p) } // FormData wraps fasthttp.Args for URL-encoded bodies and form data. type FormData struct { *fasthttp.Args } // Keys returns all keys from the form data. func (f *FormData) Keys() []string { keys := make([]string, 0, f.Len()) for key := range f.All() { keys = append(keys, utils.UnsafeString(key)) } return slices.Compact(keys) } // Add adds a single form field. func (f *FormData) Add(key, val string) { f.Args.Add(key, val) } // Set sets a single form field, overriding previously set values. func (f *FormData) Set(key, val string) { f.Args.Set(key, val) } // AddWithMap adds multiple form fields from a map. func (f *FormData) AddWithMap(m map[string][]string) { for k, v := range m { for _, vv := range v { f.Add(k, vv) } } } // SetWithMap sets multiple form fields from a map, overriding previously set values. func (f *FormData) SetWithMap(m map[string]string) { for k, v := range m { f.Set(k, v) } } // SetWithStruct sets multiple form fields from a struct. 
// Nested structs are not currently supported. func (f *FormData) SetWithStruct(v any) { SetValWithStruct(f, "form", v) } // DelData deletes multiple form fields. func (f *FormData) DelData(key ...string) { for _, v := range key { f.Del(v) } } // Reset clears the FormData object. func (f *FormData) Reset() { f.Args.Reset() } // File represents a file to be sent with the request. type File struct { reader io.ReadCloser name string fieldName string path string } // SetName sets the file's name. func (f *File) SetName(n string) { f.name = n } // SetFieldName sets the key associated with the file in the body. func (f *File) SetFieldName(n string) { f.fieldName = n } // SetPath sets the file's path. func (f *File) SetPath(p string) { f.path = p } // SetReader sets the file's reader, which will be closed in the parserBody hook. func (f *File) SetReader(r io.ReadCloser) { f.reader = r } // Reset clears the File object. func (f *File) Reset() { f.name = "" f.fieldName = "" f.path = "" f.reader = nil } var requestPool = &sync.Pool{ New: func() any { return &Request{ header: Header{RequestHeader: &fasthttp.RequestHeader{}}, params: QueryParam{Args: fasthttp.AcquireArgs()}, cookies: Cookie{}, path: PathParam{}, boundary: boundary, formData: FormData{Args: fasthttp.AcquireArgs()}, files: make([]*File, 0), RawRequest: fasthttp.AcquireRequest(), } }, } // AcquireRequest returns a new (pooled) Request object. func AcquireRequest() *Request { req, ok := requestPool.Get().(*Request) if !ok { panic(errRequestTypeAssertion) } return req } // ReleaseRequest returns the Request object to the pool. // Do not use the released Request afterward to avoid data races. func ReleaseRequest(req *Request) { req.Reset() requestPool.Put(req) } var filePool sync.Pool // SetFileFunc defines a function that modifies a File object. type SetFileFunc func(f *File) // SetFileName sets the file name. 
func SetFileName(n string) SetFileFunc { return func(f *File) { f.SetName(n) } } // SetFileFieldName sets the file's field name. func SetFileFieldName(p string) SetFileFunc { return func(f *File) { f.SetFieldName(p) } } // SetFilePath sets the file path. func SetFilePath(p string) SetFileFunc { return func(f *File) { f.SetPath(p) } } // SetFileReader sets the file's reader. func SetFileReader(r io.ReadCloser) SetFileFunc { return func(f *File) { f.SetReader(r) } } // AcquireFile returns a (pooled) File object and applies the provided SetFileFunc functions to it. func AcquireFile(setter ...SetFileFunc) *File { fv := filePool.Get() if fv != nil { f, ok := fv.(*File) if !ok { panic(errFileTypeAssertion) } for _, v := range setter { v(f) } return f } f := &File{} for _, v := range setter { v(f) } return f } // ReleaseFile returns the File object to the pool. // Do not use the released File afterward to avoid data races. func ReleaseFile(f *File) { f.Reset() filePool.Put(f) } // SetValWithStruct sets values using a struct. The struct's fields are examined via reflection. // `p` is a type that implements WithStruct. `tagName` defines the struct tag to look for. // `v` is the struct containing data. // // Fields in `v` should be string, int, int8, int16, int32, int64, uint, // uint8, uint16, uint32, uint64, float32, float64, complex64, // complex128 or bool. Arrays or slices are inserted sequentially with the // same key. Other types are ignored. func SetValWithStruct(p WithStruct, tagName string, v any) { valueOfV := reflect.ValueOf(v) typeOfV := reflect.TypeOf(v) // The value should be a struct or a pointer to a struct. if typeOfV.Kind() == reflect.Pointer && typeOfV.Elem().Kind() == reflect.Struct { valueOfV = valueOfV.Elem() typeOfV = typeOfV.Elem() } else if typeOfV.Kind() != reflect.Struct { return } // A helper function to set values. 
var setVal func(name string, val reflect.Value) setVal = func(name string, val reflect.Value) { switch val.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: p.Add(name, strconv.Itoa(int(val.Int()))) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: p.Add(name, strconv.FormatUint(val.Uint(), 10)) case reflect.Float32, reflect.Float64: p.Add(name, strconv.FormatFloat(val.Float(), 'f', -1, 64)) case reflect.Complex64, reflect.Complex128: p.Add(name, strconv.FormatComplex(val.Complex(), 'f', -1, 128)) case reflect.Bool: if val.Bool() { p.Add(name, "true") } else { p.Add(name, "false") } case reflect.String: p.Add(name, val.String()) case reflect.Slice, reflect.Array: for i := 0; i < val.Len(); i++ { setVal(name, val.Index(i)) } default: return } } for i := 0; i < typeOfV.NumField(); i++ { field := typeOfV.Field(i) if !field.IsExported() { continue } name := field.Tag.Get(tagName) if name == "" { name = field.Name } val := valueOfV.Field(i) // To cover slice and array, we delete the val then add it. p.Del(name) setVal(name, val) } } ================================================ FILE: client/request_bench_test.go ================================================ package client import ( "runtime" "runtime/metrics" "strconv" "testing" ) // BenchmarkRequestHeapScan measures how much heap memory the GC needs to scan // when a batch of requests is created and released. 
func BenchmarkRequestHeapScan(b *testing.B) { samples := []metrics.Sample{ {Name: "/gc/scan/heap:bytes"}, {Name: "/gc/scan/total:bytes"}, } b.ReportAllocs() b.StopTimer() b.ResetTimer() const batchSize = 512 var totalScanHeap, totalScanAll uint64 for i := 0; i < b.N; i++ { reqs := make([]*Request, batchSize) // revive:disable-next-line:call-to-gc // Ensure consistent heap state before measuring scan metrics runtime.GC() metrics.Read(samples) startScanHeap := samples[0].Value.Uint64() startScanAll := samples[1].Value.Uint64() b.StartTimer() for j := range reqs { req := AcquireRequest() req.SetHeader("X-Benchmark", "value") req.SetCookie("session", strconv.Itoa(j)) req.SetPathParam("id", strconv.Itoa(j)) req.SetParam("page", strconv.Itoa(j)) reqs[j] = req } b.StopTimer() // revive:disable-next-line:call-to-gc // Force GC to capture post-benchmark scan metrics runtime.GC() metrics.Read(samples) totalScanHeap += samples[0].Value.Uint64() - startScanHeap totalScanAll += samples[1].Value.Uint64() - startScanAll for _, req := range reqs { ReleaseRequest(req) } } if b.N > 0 { b.ReportMetric(float64(totalScanHeap)/float64(b.N), "scan-bytes-heap/op") b.ReportMetric(float64(totalScanAll)/float64(b.N), "scan-bytes-total/op") } } ================================================ FILE: client/request_test.go ================================================ package client import ( "bytes" "context" "errors" "io" "maps" "mime/multipart" "net" "os" "path/filepath" "regexp" "strings" "testing" "time" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" ) func Test_Request_Method(t *testing.T) { t.Parallel() req := AcquireRequest() req.SetMethod("GET") require.Equal(t, "GET", req.Method()) req.SetMethod("POST") require.Equal(t, "POST", req.Method()) req.SetMethod("PUT") require.Equal(t, "PUT", req.Method()) req.SetMethod("DELETE") require.Equal(t, "DELETE", 
req.Method()) req.SetMethod("PATCH") require.Equal(t, "PATCH", req.Method()) req.SetMethod("OPTIONS") require.Equal(t, "OPTIONS", req.Method()) req.SetMethod("HEAD") require.Equal(t, "HEAD", req.Method()) req.SetMethod("TRACE") require.Equal(t, "TRACE", req.Method()) req.SetMethod("CUSTOM") require.Equal(t, "CUSTOM", req.Method()) } func Test_Request_URL(t *testing.T) { t.Parallel() req := AcquireRequest() req.SetURL("http://example.com/normal") require.Equal(t, "http://example.com/normal", req.URL()) req.SetURL("https://example.com/normal") require.Equal(t, "https://example.com/normal", req.URL()) } func Test_Request_Client(t *testing.T) { t.Parallel() client := New() req := AcquireRequest() req.SetClient(client) require.Equal(t, client, req.Client()) } func Test_Request_Context(t *testing.T) { t.Parallel() req := AcquireRequest() ctx := req.Context() type ctxKey struct{} var key ctxKey = struct{}{} require.Nil(t, ctx.Value(key)) ctx = context.WithValue(ctx, key, "string") req.SetContext(ctx) ctx = req.Context() v, ok := ctx.Value(key).(string) require.True(t, ok) require.Equal(t, "string", v) } func Test_Request_Header(t *testing.T) { t.Parallel() t.Run("add header", func(t *testing.T) { t.Parallel() req := AcquireRequest() req.AddHeader("foo", "bar").AddHeader("foo", "fiber") res := req.Header("foo") require.Len(t, res, 2) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) }) t.Run("set header", func(t *testing.T) { t.Parallel() req := AcquireRequest() req.AddHeader("foo", "bar").SetHeader("foo", "fiber") res := req.Header("foo") require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) }) t.Run("add headers", func(t *testing.T) { t.Parallel() req := AcquireRequest() req.SetHeader("foo", "bar"). 
AddHeaders(map[string][]string{ "foo": {"fiber", "buaa"}, "bar": {"foo"}, }) res := req.Header("foo") require.Len(t, res, 3) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) require.Equal(t, "buaa", res[2]) res = req.Header("bar") require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { t.Parallel() req := AcquireRequest() req.SetHeader("foo", "bar"). SetHeaders(map[string]string{ "foo": "fiber", "bar": "foo", }) res := req.Header("foo") require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) res = req.Header("bar") require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) } func Test_Request_Headers(t *testing.T) { t.Parallel() req := AcquireRequest() req.AddHeaders(map[string][]string{ "foo": {"bar", "fiber"}, "bar": {"foo"}, }) headers := maps.Collect(req.Headers()) require.Contains(t, headers["Foo"], "fiber") require.Contains(t, headers["Foo"], "bar") require.Contains(t, headers["Bar"], "foo") require.Len(t, headers, 2) } func Benchmark_Request_Headers(b *testing.B) { req := AcquireRequest() req.AddHeaders(map[string][]string{ "foo": {"bar", "fiber"}, "bar": {"foo"}, }) b.ReportAllocs() for b.Loop() { for k, v := range req.Headers() { _ = k _ = v } } } func Test_Request_QueryParam(t *testing.T) { t.Parallel() t.Run("add param", func(t *testing.T) { t.Parallel() req := AcquireRequest() req.AddParam("foo", "bar").AddParam("foo", "fiber") res := req.Param("foo") require.Len(t, res, 2) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) }) t.Run("set param", func(t *testing.T) { t.Parallel() req := AcquireRequest() req.AddParam("foo", "bar").SetParam("foo", "fiber") res := req.Param("foo") require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) }) t.Run("add params", func(t *testing.T) { t.Parallel() req := AcquireRequest() req.SetParam("foo", "bar"). 
AddParams(map[string][]string{ "foo": {"fiber", "buaa"}, "bar": {"foo"}, }) res := req.Param("foo") require.Len(t, res, 3) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) require.Equal(t, "buaa", res[2]) res = req.Param("bar") require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { t.Parallel() req := AcquireRequest() req.SetParam("foo", "bar"). SetParams(map[string]string{ "foo": "fiber", "bar": "foo", }) res := req.Param("foo") require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) res = req.Param("bar") require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set params with struct", func(t *testing.T) { t.Parallel() type args struct { TString string TSlice []string TIntSlice []int `param:"int_slice"` TInt int TFloat float64 TBool bool } p := AcquireRequest() p.SetParamsWithStruct(&args{ TInt: 5, TString: "string", TFloat: 3.1, TBool: true, TSlice: []string{"foo", "bar"}, TIntSlice: []int{1, 2}, }) require.Empty(t, p.Param("unexport")) require.Len(t, p.Param("TInt"), 1) require.Equal(t, "5", p.Param("TInt")[0]) require.Len(t, p.Param("TString"), 1) require.Equal(t, "string", p.Param("TString")[0]) require.Len(t, p.Param("TFloat"), 1) require.Equal(t, "3.1", p.Param("TFloat")[0]) require.Len(t, p.Param("TBool"), 1) tslice := p.Param("TSlice") require.Len(t, tslice, 2) require.Equal(t, "foo", tslice[0]) require.Equal(t, "bar", tslice[1]) tint := p.Param("TSlice") require.Len(t, tint, 2) require.Equal(t, "foo", tint[0]) require.Equal(t, "bar", tint[1]) }) t.Run("del params", func(t *testing.T) { t.Parallel() req := AcquireRequest() req.SetParam("foo", "bar"). 
SetParams(map[string]string{ "foo": "fiber", "bar": "foo", }).DelParams("foo", "bar") res := req.Param("foo") require.Empty(t, res) res = req.Param("bar") require.Empty(t, res) }) } func Test_Request_Params(t *testing.T) { t.Parallel() t.Run("empty iterator", func(t *testing.T) { t.Parallel() req := AcquireRequest() t.Cleanup(func() { ReleaseRequest(req) }) called := false req.Params()(func(_ string, _ []string) bool { called = true return true }) require.False(t, called) }) t.Run("populated iterator", func(t *testing.T) { t.Parallel() req := AcquireRequest() t.Cleanup(func() { ReleaseRequest(req) }) req.AddParams(map[string][]string{ "foo": {"bar", "fiber"}, "bar": {"foo"}, }) pathParams := maps.Collect(req.Params()) require.Contains(t, pathParams["foo"], "bar") require.Contains(t, pathParams["foo"], "fiber") require.Contains(t, pathParams["bar"], "foo") require.Len(t, pathParams, 2) }) } func Benchmark_Request_Params(b *testing.B) { req := AcquireRequest() req.AddParams(map[string][]string{ "foo": {"bar", "fiber"}, "bar": {"foo"}, }) b.ReportAllocs() for b.Loop() { for k, v := range req.Params() { _ = k _ = v } } } func Test_Request_UA(t *testing.T) { t.Parallel() req := AcquireRequest().SetUserAgent("fiber") require.Equal(t, "fiber", req.UserAgent()) req.SetUserAgent("foo") require.Equal(t, "foo", req.UserAgent()) } func Test_Request_Referer(t *testing.T) { t.Parallel() req := AcquireRequest().SetReferer("http://example.com") require.Equal(t, "http://example.com", req.Referer()) req.SetReferer("https://example.com") require.Equal(t, "https://example.com", req.Referer()) } func Test_Request_Cookie(t *testing.T) { t.Parallel() t.Run("set cookie", func(t *testing.T) { t.Parallel() req := AcquireRequest(). SetCookie("foo", "bar") require.Equal(t, "bar", req.Cookie("foo")) req.SetCookie("foo", "bar1") require.Equal(t, "bar1", req.Cookie("foo")) }) t.Run("set cookies", func(t *testing.T) { t.Parallel() req := AcquireRequest(). 
SetCookies(map[string]string{ "foo": "bar", "bar": "foo", }) require.Equal(t, "bar", req.Cookie("foo")) require.Equal(t, "foo", req.Cookie("bar")) req.SetCookies(map[string]string{ "foo": "bar1", }) require.Equal(t, "bar1", req.Cookie("foo")) require.Equal(t, "foo", req.Cookie("bar")) }) t.Run("set cookies with struct", func(t *testing.T) { t.Parallel() type args struct { CookieString string `cookie:"string"` CookieInt int `cookie:"int"` } req := AcquireRequest().SetCookiesWithStruct(&args{ CookieInt: 5, CookieString: "foo", }) require.Equal(t, "5", req.Cookie("int")) require.Equal(t, "foo", req.Cookie("string")) }) t.Run("del cookies", func(t *testing.T) { t.Parallel() req := AcquireRequest(). SetCookies(map[string]string{ "foo": "bar", "bar": "foo", }) require.Equal(t, "bar", req.Cookie("foo")) require.Equal(t, "foo", req.Cookie("bar")) req.DelCookies("foo") require.Empty(t, req.Cookie("foo")) require.Equal(t, "foo", req.Cookie("bar")) }) } func Test_Request_Cookies(t *testing.T) { t.Parallel() req := AcquireRequest() req.SetCookies(map[string]string{ "foo": "bar", "bar": "foo", }) cookies := maps.Collect(req.Cookies()) require.Equal(t, "bar", cookies["foo"]) require.Equal(t, "foo", cookies["bar"]) require.NotPanics(t, func() { for _, v := range req.Cookies() { if v == "bar" { break } } }) require.Len(t, cookies, 2) } func Benchmark_Request_Cookies(b *testing.B) { req := AcquireRequest() req.SetCookies(map[string]string{ "foo": "bar", "bar": "foo", }) b.ReportAllocs() for b.Loop() { for k, v := range req.Cookies() { _ = k _ = v } } } func Test_Request_PathParam(t *testing.T) { t.Parallel() t.Run("set path param", func(t *testing.T) { t.Parallel() req := AcquireRequest(). SetPathParam("foo", "bar") require.Equal(t, "bar", req.PathParam("foo")) req.SetPathParam("foo", "bar1") require.Equal(t, "bar1", req.PathParam("foo")) }) t.Run("set path params", func(t *testing.T) { t.Parallel() req := AcquireRequest(). 
SetPathParams(map[string]string{ "foo": "bar", "bar": "foo", }) require.Equal(t, "bar", req.PathParam("foo")) require.Equal(t, "foo", req.PathParam("bar")) req.SetPathParams(map[string]string{ "foo": "bar1", }) require.Equal(t, "bar1", req.PathParam("foo")) require.Equal(t, "foo", req.PathParam("bar")) }) t.Run("set path params with struct", func(t *testing.T) { t.Parallel() type args struct { CookieString string `path:"string"` CookieInt int `path:"int"` } req := AcquireRequest().SetPathParamsWithStruct(&args{ CookieInt: 5, CookieString: "foo", }) require.Equal(t, "5", req.PathParam("int")) require.Equal(t, "foo", req.PathParam("string")) }) t.Run("del path params", func(t *testing.T) { t.Parallel() req := AcquireRequest(). SetPathParams(map[string]string{ "foo": "bar", "bar": "foo", }) require.Equal(t, "bar", req.PathParam("foo")) require.Equal(t, "foo", req.PathParam("bar")) req.DelPathParams("foo") require.Empty(t, req.PathParam("foo")) require.Equal(t, "foo", req.PathParam("bar")) }) t.Run("clear path params", func(t *testing.T) { t.Parallel() req := AcquireRequest(). 
SetPathParams(map[string]string{ "foo": "bar", "bar": "foo", }) require.Equal(t, "bar", req.PathParam("foo")) require.Equal(t, "foo", req.PathParam("bar")) req.ResetPathParams() require.Empty(t, req.PathParam("foo")) require.Empty(t, req.PathParam("bar")) }) } func Test_Request_PathParams(t *testing.T) { t.Parallel() req := AcquireRequest() req.SetPathParams(map[string]string{ "foo": "bar", "bar": "foo", }) pathParams := maps.Collect(req.PathParams()) require.Equal(t, "bar", pathParams["foo"]) require.Equal(t, "foo", pathParams["bar"]) require.Len(t, pathParams, 2) require.NotPanics(t, func() { for _, v := range req.PathParams() { if v == "bar" { break } } }) } func Benchmark_Request_PathParams(b *testing.B) { req := AcquireRequest() req.SetPathParams(map[string]string{ "foo": "bar", "bar": "foo", }) b.ReportAllocs() for b.Loop() { for k, v := range req.PathParams() { _ = k _ = v } } } func Test_Request_FormData(t *testing.T) { t.Parallel() t.Run("add form data", func(t *testing.T) { t.Parallel() req := AcquireRequest() defer ReleaseRequest(req) req.AddFormData("foo", "bar").AddFormData("foo", "fiber") res := req.FormData("foo") require.Len(t, res, 2) require.Equal(t, "bar", res[0]) require.Equal(t, "fiber", res[1]) }) t.Run("set param", func(t *testing.T) { t.Parallel() req := AcquireRequest() defer ReleaseRequest(req) req.AddFormData("foo", "bar").SetFormData("foo", "fiber") res := req.FormData("foo") require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) }) t.Run("add params", func(t *testing.T) { t.Parallel() req := AcquireRequest() defer ReleaseRequest(req) req.SetFormData("foo", "bar"). 
AddFormDataWithMap(map[string][]string{ "foo": {"fiber", "buaa"}, "bar": {"foo"}, }) res := req.FormData("foo") require.Len(t, res, 3) require.Contains(t, res, "bar") require.Contains(t, res, "buaa") require.Contains(t, res, "fiber") res = req.FormData("bar") require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set headers", func(t *testing.T) { t.Parallel() req := AcquireRequest() defer ReleaseRequest(req) req.SetFormData("foo", "bar"). SetFormDataWithMap(map[string]string{ "foo": "fiber", "bar": "foo", }) res := req.FormData("foo") require.Len(t, res, 1) require.Equal(t, "fiber", res[0]) res = req.FormData("bar") require.Len(t, res, 1) require.Equal(t, "foo", res[0]) }) t.Run("set params with struct", func(t *testing.T) { t.Parallel() type args struct { TString string TSlice []string TIntSlice []int `form:"int_slice"` TInt int TFloat float64 TBool bool } p := AcquireRequest() defer ReleaseRequest(p) p.SetFormDataWithStruct(&args{ TInt: 5, TString: "string", TFloat: 3.1, TBool: true, TSlice: []string{"foo", "bar"}, TIntSlice: []int{1, 2}, }) require.Empty(t, p.FormData("unexport")) require.Len(t, p.FormData("TInt"), 1) require.Equal(t, "5", p.FormData("TInt")[0]) require.Len(t, p.FormData("TString"), 1) require.Equal(t, "string", p.FormData("TString")[0]) require.Len(t, p.FormData("TFloat"), 1) require.Equal(t, "3.1", p.FormData("TFloat")[0]) require.Len(t, p.FormData("TBool"), 1) tslice := p.FormData("TSlice") require.Len(t, tslice, 2) require.Contains(t, tslice, "bar") require.Contains(t, tslice, "foo") tint := p.FormData("TSlice") require.Len(t, tint, 2) require.Contains(t, tint, "bar") require.Contains(t, tint, "foo") }) t.Run("del params", func(t *testing.T) { t.Parallel() req := AcquireRequest() defer ReleaseRequest(req) req.SetFormData("foo", "bar"). 
SetFormDataWithMap(map[string]string{ "foo": "fiber", "bar": "foo", }).DelFormData("foo", "bar") res := req.FormData("foo") require.Empty(t, res) res = req.FormData("bar") require.Empty(t, res) }) } func Test_Request_File(t *testing.T) { t.Parallel() t.Run("add file", func(t *testing.T) { t.Parallel() req := AcquireRequest(). AddFile("../.github/index.html"). AddFiles(AcquireFile(SetFileName("tmp.txt"))) require.Equal(t, "../.github/index.html", req.File("index.html").path) require.Equal(t, "../.github/index.html", req.FileByPath("../.github/index.html").path) require.Equal(t, "tmp.txt", req.File("tmp.txt").name) require.Nil(t, req.File("tmp2.txt")) require.Nil(t, req.FileByPath("tmp2.txt")) }) t.Run("add file by reader", func(t *testing.T) { t.Parallel() req := AcquireRequest(). AddFileWithReader("tmp.txt", io.NopCloser(strings.NewReader("world"))) require.Equal(t, "tmp.txt", req.File("tmp.txt").name) content, err := io.ReadAll(req.File("tmp.txt").reader) require.NoError(t, err) require.Equal(t, "world", string(content)) }) t.Run("add files", func(t *testing.T) { t.Parallel() req := AcquireRequest(). 
AddFiles(AcquireFile(SetFileName("tmp.txt")), AcquireFile(SetFileName("foo.txt"))) require.Equal(t, "tmp.txt", req.File("tmp.txt").name) require.Equal(t, "foo.txt", req.File("foo.txt").name) }) } func Test_Request_Files(t *testing.T) { t.Parallel() req := AcquireRequest() req.AddFile("../.github/index.html") req.AddFiles(AcquireFile(SetFileName("tmp.txt"))) files := req.Files() require.Equal(t, "../.github/index.html", files[0].path) require.Nil(t, files[0].reader) require.Equal(t, "tmp.txt", files[1].name) require.Nil(t, files[1].reader) require.Len(t, files, 2) } func Benchmark_Request_Files(b *testing.B) { req := AcquireRequest() req.AddFile("../.github/index.html") req.AddFiles(AcquireFile(SetFileName("tmp.txt"))) b.ReportAllocs() for b.Loop() { for k, v := range req.Files() { _ = k _ = v } } } func Test_Request_Timeout(t *testing.T) { t.Parallel() req := AcquireRequest().SetTimeout(5 * time.Second) require.Equal(t, 5*time.Second, req.Timeout()) } func Test_Request_Invalid_URL(t *testing.T) { t.Parallel() resp, err := AcquireRequest(). Get("http://example.com\r\n\r\nGET /\r\n\r\n") require.Equal(t, ErrURLFormat, err) require.Equal(t, (*Response)(nil), resp) } func Test_Request_Unsupported_Protocol(t *testing.T) { t.Parallel() resp, err := AcquireRequest(). 
Get("ftp://example.com") require.Equal(t, ErrURLFormat, err) require.Equal(t, (*Response)(nil), resp) } func Test_Request_Get(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) go start() time.Sleep(100 * time.Millisecond) client := New().SetDial(ln) for range 5 { req := AcquireRequest().SetClient(client) resp, err := req.Get("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "example.com", resp.String()) resp.Close() } } func Test_Request_Post(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { return c.Status(fiber.StatusCreated). SendString(c.FormValue("foo")) }) go start() time.Sleep(100 * time.Millisecond) client := New().SetDial(ln) for range 5 { resp, err := AcquireRequest(). SetClient(client). SetFormData("foo", "bar"). Post("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusCreated, resp.StatusCode()) require.Equal(t, "bar", resp.String()) resp.Close() } } func Test_Request_Head(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Head("/", func(c fiber.Ctx) error { return c.SendString(c.Hostname()) }) go start() time.Sleep(100 * time.Millisecond) client := New().SetDial(ln) for range 5 { resp, err := AcquireRequest(). SetClient(client). Head("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Empty(t, resp.String()) resp.Close() } } func Test_Request_Put(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Put("/", func(c fiber.Ctx) error { return c.SendString(c.FormValue("foo")) }) go start() time.Sleep(100 * time.Millisecond) client := New().SetDial(ln) for range 5 { resp, err := AcquireRequest(). SetClient(client). SetFormData("foo", "bar"). 
Put("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "bar", resp.String()) resp.Close() } } func Test_Request_Delete(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Delete("/", func(c fiber.Ctx) error { return c.Status(fiber.StatusNoContent). SendString("deleted") }) go start() time.Sleep(100 * time.Millisecond) client := New().SetDial(ln) for range 5 { resp, err := AcquireRequest(). SetClient(client). Delete("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) require.Empty(t, resp.String()) resp.Close() } } func Test_Request_Options(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Options("/", func(c fiber.Ctx) error { return c.Status(fiber.StatusOK). SendString("options") }) go start() time.Sleep(100 * time.Millisecond) client := New().SetDial(ln) for range 5 { resp, err := AcquireRequest(). SetClient(client). Options("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "options", resp.String()) resp.Close() } } func Test_Request_Send(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { return c.Status(fiber.StatusOK). SendString("post") }) go start() time.Sleep(100 * time.Millisecond) client := New().SetDial(ln) for range 5 { resp, err := AcquireRequest(). SetClient(client). SetURL("http://example.com"). SetMethod(fiber.MethodPost). Send() require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "post", resp.String()) resp.Close() } } func Test_Request_Patch(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Patch("/", func(c fiber.Ctx) error { return c.SendString(c.FormValue("foo")) }) go start() time.Sleep(100 * time.Millisecond) client := New().SetDial(ln) for range 5 { resp, err := AcquireRequest(). SetClient(client). 
SetFormData("foo", "bar"). Patch("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "bar", resp.String()) resp.Close() } } func Test_Request_Header_With_Server(t *testing.T) { t.Parallel() handler := func(c fiber.Ctx) error { for key, value := range c.Request().Header.All() { if k := string(key); k == "K1" || k == "K2" { _, err := c.Write(key) require.NoError(t, err) _, err = c.Write(value) require.NoError(t, err) } } return nil } wrapAgent := func(r *Request) { r.SetHeader("k1", "v1"). AddHeader("k1", "v11"). AddHeaders(map[string][]string{ "k1": {"v22", "v33"}, }). SetHeaders(map[string]string{ "k2": "v2", }). AddHeader("k2", "v22") } testRequest(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") } func Test_Request_UserAgent_With_Server(t *testing.T) { t.Parallel() handler := func(c fiber.Ctx) error { return c.Send(c.Request().Header.UserAgent()) } t.Run("default", func(t *testing.T) { t.Parallel() testRequest(t, handler, func(_ *Request) {}, defaultUserAgent, 5) }) t.Run("custom", func(t *testing.T) { t.Parallel() testRequest(t, handler, func(agent *Request) { agent.SetUserAgent("ua") }, "ua", 5) }) } func Test_Request_Cookie_With_Server(t *testing.T) { t.Parallel() handler := func(c fiber.Ctx) error { return c.SendString( c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) } wrapAgent := func(req *Request) { req.SetCookie("k1", "v1"). 
SetCookies(map[string]string{ "k2": "v2", "k3": "v3", "k4": "v4", }).DelCookies("k4") } testRequest(t, handler, wrapAgent, "v1v2v3") } func Test_Request_Referer_With_Server(t *testing.T) { t.Parallel() handler := func(c fiber.Ctx) error { return c.Send(c.Request().Header.Referer()) } wrapAgent := func(req *Request) { req.SetReferer("http://referer.com") } testRequest(t, handler, wrapAgent, "http://referer.com") } func Test_Request_QueryString_With_Server(t *testing.T) { t.Parallel() handler := func(c fiber.Ctx) error { return c.Send(c.Request().URI().QueryString()) } wrapAgent := func(req *Request) { req.SetParam("foo", "bar"). SetParams(map[string]string{ "bar": "baz", }) } testRequest(t, handler, wrapAgent, "foo=bar&bar=baz") } func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { t.Helper() basename := filepath.Base(filename) require.Equal(t, fh.Filename, basename) b1, err := os.ReadFile(filepath.Clean(filename)) require.NoError(t, err) b2 := make([]byte, fh.Size) f, err := fh.Open() require.NoError(t, err) defer func() { require.NoError(t, f.Close()) }() _, err = f.Read(b2) require.NoError(t, err) require.Equal(t, b1, b2) } func Test_Request_Body_With_Server(t *testing.T) { t.Parallel() t.Run("json body", func(t *testing.T) { t.Parallel() testRequest(t, func(c fiber.Ctx) error { require.Equal(t, "application/json", string(c.Request().Header.ContentType())) return c.SendString(string(c.Request().Body())) }, func(agent *Request) { agent.SetJSON(map[string]string{ "success": "hello", }) }, "{\"success\":\"hello\"}", ) }) t.Run("xml body", func(t *testing.T) { t.Parallel() testRequest(t, func(c fiber.Ctx) error { require.Equal(t, "application/xml", string(c.Request().Header.ContentType())) return c.SendString(string(c.Request().Body())) }, func(agent *Request) { type args struct { Content string `xml:"content"` } agent.SetXML(args{ Content: "hello", }) }, "hello", ) }) t.Run("cbor body", func(t *testing.T) { t.Parallel() testRequest(t, 
func(c fiber.Ctx) error { require.Equal(t, "application/cbor", string(c.Request().Header.ContentType())) return c.SendString(string(c.Request().Body())) }, func(agent *Request) { type args struct { Content string `cbor:"content"` } agent.SetCBOR(args{ Content: "hello", }) }, "\xa1gcontentehello", ) }) t.Run("formdata", func(t *testing.T) { t.Parallel() testRequest(t, func(c fiber.Ctx) error { require.Equal(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) return c.Send([]byte("foo=" + c.FormValue("foo") + "&bar=" + c.FormValue("bar") + "&fiber=" + c.FormValue("fiber"))) }, func(agent *Request) { agent.SetFormData("foo", "bar"). SetFormDataWithMap(map[string]string{ "bar": "baz", "fiber": "fast", }) }, "foo=bar&bar=baz&fiber=fast") }) t.Run("multipart form", func(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) mf, err := c.MultipartForm() require.NoError(t, err) require.Equal(t, "bar", mf.Value["foo"][0]) return c.Send(c.Request().Body()) }) go start() client := New().SetDial(ln) req := AcquireRequest(). SetClient(client). SetBoundary("myBoundary"). SetFormData("foo", "bar"). 
AddFiles(AcquireFile( SetFileName("hello.txt"), SetFileFieldName("foo"), SetFileReader(io.NopCloser(strings.NewReader("world"))), )) require.Equal(t, "myBoundary", req.Boundary()) resp, err := req.Post("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) form, err := multipart.NewReader(bytes.NewReader(resp.Body()), "myBoundary").ReadForm(1024 * 1024) require.NoError(t, err) require.Equal(t, "bar", form.Value["foo"][0]) resp.Close() }) t.Run("multipart form send file", func(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) fh1, err := c.FormFile("field1") require.NoError(t, err) require.Equal(t, "name", fh1.Filename) buf := make([]byte, fh1.Size) f, err := fh1.Open() require.NoError(t, err) defer func() { require.NoError(t, f.Close()) }() _, err = f.Read(buf) require.NoError(t, err) require.Equal(t, "form file", string(buf)) fh2, err := c.FormFile("file2") require.NoError(t, err) checkFormFile(t, fh2, "../.github/testdata/index.html") fh3, err := c.FormFile("file3") require.NoError(t, err) checkFormFile(t, fh3, "../.github/testdata/index.tmpl") return c.SendString("multipart form files") }) go start() client := New().SetDial(ln) for range 5 { req := AcquireRequest(). SetClient(client). AddFiles( AcquireFile( SetFileFieldName("field1"), SetFileName("name"), SetFileReader(io.NopCloser(bytes.NewReader([]byte("form file")))), ), ). AddFile("../.github/testdata/index.html"). AddFile("../.github/testdata/index.tmpl"). 
SetBoundary("myBoundary") resp, err := req.Post("http://example.com") require.NoError(t, err) require.Equal(t, "multipart form files", resp.String()) resp.Close() } }) t.Run("multipart random boundary", func(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Post("/", func(c fiber.Ctx) error { reg := regexp.MustCompile(`multipart/form-data; boundary=[\-\w]{33}`) require.True(t, reg.MatchString(c.Get(fiber.HeaderContentType))) return c.Send(c.Request().Body()) }) go start() client := New().SetDial(ln) req := AcquireRequest(). SetClient(client). SetFormData("foo", "bar"). AddFiles(AcquireFile( SetFileName("hello.txt"), SetFileFieldName("foo"), SetFileReader(io.NopCloser(strings.NewReader("world"))), )) resp, err := req.Post("http://example.com") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) }) t.Run("raw body", func(t *testing.T) { t.Parallel() testRequest(t, func(c fiber.Ctx) error { return c.SendString(string(c.Request().Body())) }, func(agent *Request) { agent.SetRawBody([]byte("hello")) }, "hello", ) }) } func Test_Request_AllFormData(t *testing.T) { t.Parallel() t.Run("empty iterator", func(t *testing.T) { t.Parallel() req := AcquireRequest() t.Cleanup(func() { ReleaseRequest(req) }) called := false req.AllFormData()(func(_ string, _ []string) bool { called = true return true }) require.False(t, called) }) t.Run("populated iterator", func(t *testing.T) { t.Parallel() req := AcquireRequest() t.Cleanup(func() { ReleaseRequest(req) }) req.AddFormDataWithMap(map[string][]string{ "foo": {"bar", "fiber"}, "bar": {"foo"}, }) pathParams := maps.Collect(req.AllFormData()) require.Contains(t, pathParams["foo"], "bar") require.Contains(t, pathParams["foo"], "fiber") require.Contains(t, pathParams["bar"], "foo") require.Len(t, pathParams, 2) }) } func Benchmark_Request_AllFormData(b *testing.B) { req := AcquireRequest() req.AddFormDataWithMap(map[string][]string{ "foo": {"bar", "fiber"}, "bar": {"foo"}, }) b.ReportAllocs() 
for b.Loop() { for k, v := range req.AllFormData() { _ = k _ = v } } } func Test_Request_Error_Body_With_Server(t *testing.T) { t.Parallel() t.Run("json error", func(t *testing.T) { t.Parallel() testRequestFail(t, func(c fiber.Ctx) error { return c.SendString("") }, func(agent *Request) { agent.SetJSON(complex(1, 1)) }, errors.New("json: unsupported type: complex128"), ) }) t.Run("xml error", func(t *testing.T) { t.Parallel() testRequestFail(t, func(c fiber.Ctx) error { return c.SendString("") }, func(agent *Request) { agent.SetXML(complex(1, 1)) }, errors.New("xml: unsupported type: complex128"), ) }) t.Run("form body with invalid boundary", func(t *testing.T) { t.Parallel() _, err := AcquireRequest(). SetBoundary("*"). AddFileWithReader("t.txt", io.NopCloser(strings.NewReader("world"))). Get("http://example.com") require.Equal(t, "set boundary error: mime: invalid boundary character", err.Error()) }) t.Run("open non exist file", func(t *testing.T) { t.Parallel() _, err := AcquireRequest(). AddFile("non-exist-file!"). Get("http://example.com") require.Contains(t, err.Error(), "open non-exist-file!") }) } func Test_Request_Timeout_With_Server(t *testing.T) { t.Parallel() app, ln, start := createHelperServer(t) app.Get("/", func(c fiber.Ctx) error { time.Sleep(time.Millisecond * 200) return c.SendString("timeout") }) go start() client := New().SetDial(ln) _, err := AcquireRequest(). SetClient(client). SetTimeout(50 * time.Millisecond). 
Get("http://example.com") require.Equal(t, ErrTimeoutOrCancel, err) } func Test_Request_MaxRedirects(t *testing.T) { t.Parallel() ln := fasthttputil.NewInmemoryListener() app := fiber.New() app.Get("/", func(c fiber.Ctx) error { if c.Request().URI().QueryArgs().Has("foo") { return c.Redirect().To("/foo") } return c.Redirect().To("/") }) app.Get("/foo", func(c fiber.Ctx) error { return c.SendString("redirect") }) go func() { assert.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() t.Run("success", func(t *testing.T) { t.Parallel() client := New().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) resp, err := AcquireRequest(). SetClient(client). SetMaxRedirects(1). Get("http://example.com?foo") body := resp.String() code := resp.StatusCode() require.Equal(t, 200, code) require.Equal(t, "redirect", body) require.NoError(t, err) resp.Close() }) t.Run("error", func(t *testing.T) { t.Parallel() client := New().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) resp, err := AcquireRequest(). SetClient(client). SetMaxRedirects(1). Get("http://example.com") require.Nil(t, resp) require.Equal(t, "too many redirects detected when doing the request", err.Error()) }) t.Run("MaxRedirects", func(t *testing.T) { t.Parallel() client := New().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) req := AcquireRequest(). SetClient(client). SetMaxRedirects(3) require.Equal(t, 3, req.MaxRedirects()) }) } func Test_SetValWithStruct(t *testing.T) { t.Parallel() // test SetValWithStruct via QueryParam struct. 
type args struct { TString string TSlice []string TIntSlice []int `param:"int_slice"` unexport int TInt int TUint uint TFloat float64 TComplex complex128 TBool bool } t.Run("the struct should be applied", func(t *testing.T) { t.Parallel() p := &QueryParam{ Args: fasthttp.AcquireArgs(), } SetValWithStruct(p, "param", args{ unexport: 5, TInt: 5, TUint: 5, TString: "string", TFloat: 3.1, TComplex: 3 + 4i, TBool: false, TSlice: []string{"foo", "bar"}, TIntSlice: []int{0, 1, 2}, }) require.Empty(t, string(p.Peek("unexport"))) require.Equal(t, []byte("5"), p.Peek("TInt")) require.Equal(t, []byte("5"), p.Peek("TUint")) require.Equal(t, []byte("string"), p.Peek("TString")) require.Equal(t, []byte("3.1"), p.Peek("TFloat")) require.Equal(t, []byte("(3+4i)"), p.Peek("TComplex")) require.Equal(t, []byte("false"), p.Peek("TBool")) require.True(t, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "foo" { return true } } return false }()) require.True(t, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "bar" { return true } } return false }()) require.True(t, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "0" { return true } } return false }()) require.True(t, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "1" { return true } } return false }()) require.True(t, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "2" { return true } } return false }()) }) t.Run("the pointer of a struct should be applied", func(t *testing.T) { t.Parallel() p := &QueryParam{ Args: fasthttp.AcquireArgs(), } SetValWithStruct(p, "param", &args{ TInt: 5, TString: "string", TFloat: 3.1, TBool: true, TSlice: []string{"foo", "bar"}, TIntSlice: []int{1, 2}, }) require.Equal(t, []byte("5"), p.Peek("TInt")) require.Equal(t, []byte("string"), p.Peek("TString")) require.Equal(t, []byte("3.1"), p.Peek("TFloat")) require.Equal(t, "true", string(p.Peek("TBool"))) require.True(t, func() bool { 
for _, v := range p.PeekMulti("TSlice") { if string(v) == "foo" { return true } } return false }()) require.True(t, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "bar" { return true } } return false }()) require.True(t, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "1" { return true } } return false }()) require.True(t, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "2" { return true } } return false }()) }) t.Run("error type should ignore", func(t *testing.T) { t.Parallel() p := &QueryParam{ Args: fasthttp.AcquireArgs(), } SetValWithStruct(p, "param", 5) require.Equal(t, 0, p.Len()) }) } func Benchmark_SetValWithStruct(b *testing.B) { // test SetValWithStruct via QueryParam struct. type args struct { TString string TSlice []string TIntSlice []int `param:"int_slice"` unexport int TInt int TUint uint TFloat float64 TComplex complex128 TBool bool } b.Run("the struct should be applied", func(b *testing.B) { p := &QueryParam{ Args: fasthttp.AcquireArgs(), } b.ReportAllocs() for b.Loop() { SetValWithStruct(p, "param", args{ unexport: 5, TInt: 5, TUint: 5, TString: "string", TFloat: 3.1, TComplex: 3 + 4i, TBool: false, TSlice: []string{"foo", "bar"}, TIntSlice: []int{0, 1, 2}, }) } require.Empty(b, string(p.Peek("unexport"))) require.Equal(b, []byte("5"), p.Peek("TInt")) require.Equal(b, []byte("5"), p.Peek("TUint")) require.Equal(b, []byte("string"), p.Peek("TString")) require.Equal(b, []byte("3.1"), p.Peek("TFloat")) require.Equal(b, []byte("(3+4i)"), p.Peek("TComplex")) require.Equal(b, []byte("false"), p.Peek("TBool")) require.True(b, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "foo" { return true } } return false }()) require.True(b, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "bar" { return true } } return false }()) require.True(b, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "0" { return true } } 
return false }()) require.True(b, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "1" { return true } } return false }()) require.True(b, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "2" { return true } } return false }()) }) b.Run("the pointer of a struct should be applied", func(b *testing.B) { p := &QueryParam{ Args: fasthttp.AcquireArgs(), } b.ReportAllocs() for b.Loop() { SetValWithStruct(p, "param", &args{ TInt: 5, TString: "string", TFloat: 3.1, TBool: true, TSlice: []string{"foo", "bar"}, TIntSlice: []int{1, 2}, }) } require.Equal(b, []byte("5"), p.Peek("TInt")) require.Equal(b, []byte("string"), p.Peek("TString")) require.Equal(b, []byte("3.1"), p.Peek("TFloat")) require.Equal(b, "true", string(p.Peek("TBool"))) require.True(b, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "foo" { return true } } return false }()) require.True(b, func() bool { for _, v := range p.PeekMulti("TSlice") { if string(v) == "bar" { return true } } return false }()) require.True(b, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "1" { return true } } return false }()) require.True(b, func() bool { for _, v := range p.PeekMulti("int_slice") { if string(v) == "2" { return true } } return false }()) }) b.Run("error type should ignore", func(b *testing.B) { p := &QueryParam{ Args: fasthttp.AcquireArgs(), } b.ReportAllocs() for b.Loop() { SetValWithStruct(p, "param", 5) } require.Equal(b, 0, p.Len()) }) } ================================================ FILE: client/response.go ================================================ package client import ( "bytes" "errors" "fmt" "io" "io/fs" "iter" "os" "path/filepath" "sync" "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) // Response represents the result of a request. It provides access to the response data. 
type Response struct { client *Client request *Request RawResponse *fasthttp.Response cookie []*fasthttp.Cookie } // setClient sets the client instance in the response. The client object is used by core functionalities. func (r *Response) setClient(c *Client) { r.client = c } // setRequest sets the request object in the response. The request is released when Response.Close is called. func (r *Response) setRequest(req *Request) { r.request = req } // Status returns the HTTP status message of the executed request. func (r *Response) Status() string { return string(r.RawResponse.Header.StatusMessage()) } // StatusCode returns the HTTP status code of the executed request. func (r *Response) StatusCode() int { return r.RawResponse.StatusCode() } // Protocol returns the HTTP protocol used for the request. func (r *Response) Protocol() string { return string(r.RawResponse.Header.Protocol()) } // Header returns the value of the specified response header field. func (r *Response) Header(key string) string { return utils.UnsafeString(r.RawResponse.Header.Peek(key)) } // Headers returns all headers in the response using an iterator. // Use maps.Collect() to gather them into a map if needed. // // The returned values are valid only until the response object is released. // Do not store references to returned values; make copies instead. func (r *Response) Headers() iter.Seq2[string, []string] { return func(yield func(string, []string) bool) { keys := r.RawResponse.Header.PeekKeys() for _, key := range keys { vals := r.RawResponse.Header.PeekAll(utils.UnsafeString(key)) valsStr := make([]string, len(vals)) for i, v := range vals { valsStr[i] = utils.UnsafeString(v) } if !yield(utils.UnsafeString(key), valsStr) { return } } } } // Cookies returns all cookies set by the response. // // The returned values are valid only until the response object is released. // Do not store references to returned values; make copies instead. 
func (r *Response) Cookies() []*fasthttp.Cookie {
	cookies := r.cookie
	return cookies
}

// Body returns the HTTP response body as buffered by the underlying fasthttp response.
// When the body is streamed (see IsStreaming), read it through BodyStream instead.
func (r *Response) Body() []byte {
	return r.RawResponse.Body()
}

// BodyStream returns a reader over the response body.
// If the underlying response exposes a body stream, that stream is returned as-is:
// the body is then not copied into memory, so a later call to Body may return an
// empty slice. Otherwise the buffered body is wrapped in a bytes.Reader.
func (r *Response) BodyStream() io.Reader {
	stream := r.RawResponse.BodyStream()
	if stream == nil {
		// Not streaming: serve the already-buffered body from memory.
		return bytes.NewReader(r.RawResponse.Body())
	}
	return stream
}

// IsStreaming reports whether the underlying response exposes a body stream.
func (r *Response) IsStreaming() bool {
	stream := r.RawResponse.BodyStream()
	return stream != nil
}

// String returns the response body as a string with surrounding whitespace
// trimmed (via utils.TrimSpace).
func (r *Response) String() string {
	body := string(r.Body())
	return utils.TrimSpace(body)
}

// JSON decodes the response body into v using the client's JSON unmarshaler.
// It returns ErrClientNil when the response is not bound to a client.
func (r *Response) JSON(v any) error {
	if c := r.client; c != nil {
		return c.jsonUnmarshal(r.Body(), v)
	}
	return ErrClientNil
}

// CBOR decodes the response body into v using the client's CBOR unmarshaler.
// It returns ErrClientNil when the response is not bound to a client.
func (r *Response) CBOR(v any) error {
	if c := r.client; c != nil {
		return c.cborUnmarshal(r.Body(), v)
	}
	return ErrClientNil
}

// XML decodes the response body into v using the client's XML unmarshaler.
// It returns ErrClientNil when the response is not bound to a client.
func (r *Response) XML(v any) error {
	if c := r.client; c != nil {
		return c.xmlUnmarshal(r.Body(), v)
	}
	return ErrClientNil
}

// Save writes the response body to a file or io.Writer.
// If a string path is provided, it creates directories if needed, then writes to a file.
// If an io.Writer is provided, it writes directly to it and then closes it when it
// also implements io.WriteCloser.
// When streaming is enabled, the body is read directly from the stream.
// Any other argument type yields ErrNotSupportSaveMethod.
func (r *Response) Save(v any) error { switch p := v.(type) { case string: file := filepath.Clean(p) dir := filepath.Dir(file) // Create directory if it doesn't exist if _, err := os.Stat(dir); err != nil { if !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("failed to check directory: %w", err) } if err = os.MkdirAll(dir, 0o750); err != nil { return fmt.Errorf("failed to create directory: %w", err) } } // Create and write to file outFile, err := os.Create(file) if err != nil { return fmt.Errorf("failed to create file: %w", err) } defer func() { _ = outFile.Close() }() //nolint:errcheck // not needed // Use BodyStream() which handles both streaming and non-streaming cases if _, err = io.Copy(outFile, r.BodyStream()); err != nil { return fmt.Errorf("failed to write response body to file: %w", err) } return nil case io.Writer: // Use BodyStream() which handles both streaming and non-streaming cases if _, err := io.Copy(p, r.BodyStream()); err != nil { return fmt.Errorf("failed to write response body to writer: %w", err) } // Close the writer if it implements io.WriteCloser if pc, ok := p.(io.WriteCloser); ok { _ = pc.Close() //nolint:errcheck // not needed } return nil default: return ErrNotSupportSaveMethod } } // Reset clears the Response object, making it ready for reuse. func (r *Response) Reset() { r.client = nil r.request = nil for len(r.cookie) != 0 { t := r.cookie[0] r.cookie = r.cookie[1:] fasthttp.ReleaseCookie(t) } r.RawResponse.Reset() } // Close releases both the Request and Response objects back to their pools. // After calling Close, do not use these objects. func (r *Response) Close() { if r.request != nil { tmp := r.request r.request = nil ReleaseRequest(tmp) } ReleaseResponse(r) } var responsePool = &sync.Pool{ New: func() any { return &Response{ cookie: []*fasthttp.Cookie{}, RawResponse: fasthttp.AcquireResponse(), } }, } // AcquireResponse returns a new (pooled) Response object. // When done, release it with ReleaseResponse to reduce GC load. 
func AcquireResponse() *Response { resp, ok := responsePool.Get().(*Response) if !ok { panic("unexpected type from responsePool.Get()") } return resp } // ReleaseResponse returns the Response object to the pool. // Do not use the released Response afterward to avoid data races. func ReleaseResponse(resp *Response) { resp.Reset() responsePool.Put(resp) } ================================================ FILE: client/response_test.go ================================================ package client import ( "bytes" "crypto/tls" "encoding/xml" "errors" "io" "net" "os" "path/filepath" "testing" "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gofiber/fiber/v3" ) func Test_Response_Status(t *testing.T) { t.Parallel() setupApp := func() *testServer { server := startTestServer(t, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { return c.SendString("foo") }) app.Get("/fail", func(c fiber.Ctx) error { return c.SendStatus(407) }) }) return server } t.Run("success", func(t *testing.T) { t.Parallel() server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). Get("http://example") require.NoError(t, err) require.Equal(t, "OK", resp.Status()) resp.Close() }) t.Run("fail", func(t *testing.T) { t.Parallel() server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). 
Get("http://example/fail") require.NoError(t, err) require.Equal(t, "Proxy Authentication Required", resp.Status()) resp.Close() }) } func Test_Response_Status_Code(t *testing.T) { t.Parallel() setupApp := func() *testServer { server := startTestServer(t, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { return c.SendString("foo") }) app.Get("/fail", func(c fiber.Ctx) error { return c.SendStatus(407) }) }) return server } t.Run("success", func(t *testing.T) { t.Parallel() server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). Get("http://example") require.NoError(t, err) require.Equal(t, 200, resp.StatusCode()) resp.Close() }) t.Run("fail", func(t *testing.T) { t.Parallel() server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). Get("http://example/fail") require.NoError(t, err) require.Equal(t, 407, resp.StatusCode()) resp.Close() }) } func Test_Response_Protocol(t *testing.T) { t.Parallel() t.Run("http", func(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { return c.SendString("foo") }) }) defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). 
Get("http://example") require.NoError(t, err) require.Equal(t, "HTTP/1.1", resp.Protocol()) resp.Close() }) t.Run("https", func(t *testing.T) { t.Parallel() serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() require.NoError(t, err) ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) ln = tls.NewListener(ln, serverTLSConf) app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Scheme()) }) go func() { assert.NoError(t, app.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, })) }() client := New() resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) require.NoError(t, err) require.Equal(t, clientTLSConf, client.TLSConfig()) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "https", resp.String()) require.Equal(t, "HTTP/1.1", resp.Protocol()) resp.Close() }) } func Test_Response_Header(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { c.Response().Header.Add("foo", "bar") return c.SendString("helo world") }) }) defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). Get("http://example.com") require.NoError(t, err) require.Equal(t, "bar", resp.Header("foo")) resp.Close() } func Test_Response_Headers(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { c.Response().Header.Add("foo", "bar") c.Response().Header.Add("foo", "bar2") c.Response().Header.Add("foo2", "bar") return c.SendString("hello world") }) }) defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). Get("http://example.com") require.NoError(t, err) headers := make(map[string][]string) for k, v := range resp.Headers() { headers[k] = append(headers[k], v...) 
} require.Equal(t, "hello world", resp.String()) require.Contains(t, headers["Foo"], "bar") require.Contains(t, headers["Foo"], "bar2") require.Contains(t, headers["Foo2"], "bar") require.Len(t, headers, 5) // Foo + Foo2 + Date + Content-Length + Content-Type resp.Close() } func Benchmark_Headers(b *testing.B) { server := startTestServer( b, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { c.Response().Header.Add("foo", "bar") c.Response().Header.Add("foo", "bar2") c.Response().Header.Add("foo", "bar3") c.Response().Header.Add("foo2", "bar") c.Response().Header.Add("foo2", "bar2") c.Response().Header.Add("foo2", "bar3") return c.SendString("helo world") }) }, ) client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). Get("http://example.com") require.NoError(b, err) b.Cleanup(func() { resp.Close() server.stop() }) b.ReportAllocs() for b.Loop() { for k, v := range resp.Headers() { _ = k _ = v } } } func Test_Response_Cookie(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "foo", Value: "bar", }) return c.SendString("helo world") }) }) defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). 
Get("http://example.com") require.NoError(t, err) require.Equal(t, "bar", string(resp.Cookies()[0].Value())) resp.Close() } func Test_Response_Body(t *testing.T) { t.Parallel() setupApp := func() *testServer { server := startTestServer(t, func(app *fiber.App) { app.Get("/", func(c fiber.Ctx) error { return c.SendString("hello world") }) app.Get("/json", func(c fiber.Ctx) error { return c.SendString("{\"status\":\"success\"}") }) app.Get("/xml", func(c fiber.Ctx) error { return c.SendString("success") }) app.Get("/cbor", func(c fiber.Ctx) error { type cborData struct { Name string `cbor:"name"` Age int `cbor:"age"` } return c.CBOR(cborData{ Name: "foo", Age: 12, }) }) }) return server } t.Run("raw body", func(t *testing.T) { t.Parallel() server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). Get("http://example.com") require.NoError(t, err) require.Equal(t, []byte("hello world"), resp.Body()) resp.Close() }) t.Run("string body", func(t *testing.T) { t.Parallel() server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). Get("http://example.com") require.NoError(t, err) require.Equal(t, "hello world", resp.String()) resp.Close() }) t.Run("json body", func(t *testing.T) { t.Parallel() type body struct { Status string `json:"status"` } server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). Get("http://example.com/json") require.NoError(t, err) tmp := &body{} err = resp.JSON(tmp) require.NoError(t, err) require.Equal(t, "success", tmp.Status) resp.Close() }) t.Run("xml body", func(t *testing.T) { t.Parallel() type body struct { Name xml.Name `xml:"status"` Status string `xml:"name"` } server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). 
Get("http://example.com/xml") require.NoError(t, err) tmp := &body{} err = resp.XML(tmp) require.NoError(t, err) require.Equal(t, "success", tmp.Status) resp.Close() }) t.Run("cbor body", func(t *testing.T) { t.Parallel() type cborData struct { Name string `cbor:"name"` Age int `cbor:"age"` } data := cborData{ Name: "foo", Age: 12, } server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). Get("http://example.com/cbor") require.NoError(t, err) tmp := &cborData{} err = resp.CBOR(tmp) require.NoError(t, err) require.Equal(t, data, *tmp) resp.Close() }) } func Test_Response_DecodeHelpers_ClientNilSafety(t *testing.T) { t.Parallel() t.Run("client nil returns exported error without panic", func(t *testing.T) { t.Parallel() type payload struct { Status string `json:"status" xml:"status" cbor:"status"` } t.Run("json", func(t *testing.T) { t.Parallel() resp := AcquireResponse() t.Cleanup(func() { ReleaseResponse(resp) }) resp.RawResponse.SetBodyString(`{"status":"success"}`) decoded := payload{} require.NotPanics(t, func() { err := resp.JSON(&decoded) require.ErrorIs(t, err, ErrClientNil) }) }) t.Run("xml", func(t *testing.T) { t.Parallel() resp := AcquireResponse() t.Cleanup(func() { ReleaseResponse(resp) }) resp.RawResponse.SetBodyString(`success`) decoded := payload{} require.NotPanics(t, func() { err := resp.XML(&decoded) require.ErrorIs(t, err, ErrClientNil) }) }) t.Run("cbor", func(t *testing.T) { t.Parallel() resp := AcquireResponse() t.Cleanup(func() { ReleaseResponse(resp) }) resp.RawResponse.SetBodyString("not-cbor") decoded := payload{} require.NotPanics(t, func() { err := resp.CBOR(&decoded) require.ErrorIs(t, err, ErrClientNil) }) }) }) t.Run("decode helpers still work with client", func(t *testing.T) { t.Parallel() type payload struct { Status string `json:"status" xml:"status" cbor:"status"` } t.Run("json", func(t *testing.T) { t.Parallel() resp := AcquireResponse() t.Cleanup(func() { 
ReleaseResponse(resp) }) resp.setClient(New()) resp.RawResponse.SetBodyString(`{"status":"success"}`) decoded := payload{} err := resp.JSON(&decoded) require.NoError(t, err) require.Equal(t, "success", decoded.Status) }) t.Run("xml", func(t *testing.T) { t.Parallel() resp := AcquireResponse() t.Cleanup(func() { ReleaseResponse(resp) }) resp.setClient(New()) resp.RawResponse.SetBodyString(`success`) decoded := payload{} err := resp.XML(&decoded) require.NoError(t, err) require.Equal(t, "success", decoded.Status) }) t.Run("cbor", func(t *testing.T) { t.Parallel() client := New() resp := AcquireResponse() t.Cleanup(func() { ReleaseResponse(resp) }) resp.setClient(client) body, err := client.cborMarshal(payload{Status: "success"}) require.NoError(t, err) resp.RawResponse.SetBody(body) decoded := payload{} err = resp.CBOR(&decoded) require.NoError(t, err) require.Equal(t, "success", decoded.Status) }) }) } func Test_Response_Save(t *testing.T) { t.Parallel() setupApp := func() *testServer { server := startTestServer(t, func(app *fiber.App) { app.Get("/json", func(c fiber.Ctx) error { return c.SendString("{\"status\":\"success\"}") }) }) return server } t.Run("file path", func(t *testing.T) { t.Parallel() server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). 
Get("http://example.com/json") require.NoError(t, err) err = resp.Save("./test/tmp.json") require.NoError(t, err) defer func() { _, statErr := os.Stat("./test/tmp.json") require.NoError(t, statErr) statErr = os.RemoveAll("./test") require.NoError(t, statErr) }() file, err := os.Open("./test/tmp.json") require.NoError(t, err) defer func(file *os.File) { closeErr := file.Close() require.NoError(t, closeErr) }(file) data, err := io.ReadAll(file) require.NoError(t, err) require.JSONEq(t, "{\"status\":\"success\"}", string(data)) }) t.Run("io.Writer", func(t *testing.T) { t.Parallel() server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). Get("http://example.com/json") require.NoError(t, err) buf := &bytes.Buffer{} err = resp.Save(buf) require.NoError(t, err) require.JSONEq(t, "{\"status\":\"success\"}", buf.String()) }) t.Run("io.Copy error when saving to file is surfaced", func(t *testing.T) { t.Parallel() resp := AcquireResponse() defer ReleaseResponse(resp) resp.RawResponse.SetBodyStream(&errorReader{err: errors.New("copy failure")}, -1) target := filepath.Join(t.TempDir(), "out.txt") err := resp.Save(target) require.ErrorContains(t, err, "failed to write response body to file: copy failure") }) t.Run("io.Copy error when saving to writer is surfaced", func(t *testing.T) { t.Parallel() resp := AcquireResponse() defer ReleaseResponse(resp) resp.RawResponse.SetBodyStream(bytes.NewBufferString("data"), len("data")) err := resp.Save(&errorWriter{err: errors.New("sink closed")}) require.ErrorContains(t, err, "failed to write response body to writer: sink closed") }) t.Run("error type", func(t *testing.T) { t.Parallel() server := setupApp() defer server.stop() client := New().SetDial(server.dial()) resp, err := AcquireRequest(). SetClient(client). 
Get("http://example.com/json") require.NoError(t, err) err = resp.Save(nil) require.Error(t, err) }) } func Test_Response_BodyStream(t *testing.T) { t.Parallel() t.Run("basic streaming", func(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/stream", func(c fiber.Ctx) error { return c.SendStream(bytes.NewReader([]byte("streaming data"))) }) }) defer server.stop() client := New().SetDial(server.dial()).SetStreamResponseBody(true) resp, err := client.Get("http://example.com/stream") require.NoError(t, err) defer resp.Close() bodyStream := resp.BodyStream() require.NotNil(t, bodyStream) data, err := io.ReadAll(bodyStream) require.NoError(t, err) require.Equal(t, "streaming data", string(data)) }) t.Run("large response streaming", func(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/large", func(c fiber.Ctx) error { data := make([]byte, 1024) for i := range data { data[i] = byte('A' + i%26) } return c.SendStream(bytes.NewReader(data)) }) }) defer server.stop() client := New().SetDial(server.dial()).SetStreamResponseBody(true) resp, err := client.Get("http://example.com/large") require.NoError(t, err) defer resp.Close() bodyStream := resp.BodyStream() require.NotNil(t, bodyStream) buffer := make([]byte, 256) var totalRead []byte for { n, err := bodyStream.Read(buffer) if n > 0 { totalRead = append(totalRead, buffer[:n]...) 
} if err == io.EOF { break } require.NoError(t, err) } require.Len(t, totalRead, 1024) }) } func Test_Response_BodyStream_Fallback(t *testing.T) { t.Parallel() t.Run("non-streaming response fallback to bytes.Reader", func(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/regular", func(c fiber.Ctx) error { return c.SendString("regular response body") }) }) defer server.stop() client := New().SetDial(server.dial()) resp, err := client.Get("http://example.com/regular") require.NoError(t, err) defer resp.Close() require.False(t, resp.IsStreaming()) bodyStream := resp.BodyStream() require.NotNil(t, bodyStream) data, err := io.ReadAll(bodyStream) require.NoError(t, err) require.Equal(t, "regular response body", string(data)) }) } func Test_Response_IsStreaming(t *testing.T) { t.Parallel() t.Run("streaming disabled", func(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/regular", func(c fiber.Ctx) error { return c.SendString("regular content") }) }) defer server.stop() client := New().SetDial(server.dial()) resp, err := client.Get("http://example.com/regular") require.NoError(t, err) defer resp.Close() require.False(t, resp.IsStreaming()) }) t.Run("bodystream always works regardless of streaming state", func(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/test", func(c fiber.Ctx) error { return c.SendString("test content") }) }) defer server.stop() // Test with streaming enabled client1 := New().SetDial(server.dial()).SetStreamResponseBody(true) resp1, err := client1.Get("http://example.com/test") require.NoError(t, err) defer resp1.Close() bodyStream1 := resp1.BodyStream() require.NotNil(t, bodyStream1) data1, err := io.ReadAll(bodyStream1) require.NoError(t, err) require.Equal(t, "test content", string(data1)) // Test with streaming disabled client2 := New().SetDial(server.dial()).SetStreamResponseBody(false) resp2, err := 
client2.Get("http://example.com/test") require.NoError(t, err) defer resp2.Close() require.False(t, resp2.IsStreaming()) bodyStream2 := resp2.BodyStream() require.NotNil(t, bodyStream2) data2, err := io.ReadAll(bodyStream2) require.NoError(t, err) require.Equal(t, "test content", string(data2)) }) } func Test_Response_Save_Streaming(t *testing.T) { t.Parallel() t.Run("save streaming response to file", func(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/stream", func(c fiber.Ctx) error { return c.SendStream(bytes.NewReader([]byte("streaming file content"))) }) }) defer server.stop() client := New().SetDial(server.dial()).SetStreamResponseBody(true) resp, err := client.Get("http://example.com/stream") require.NoError(t, err) defer resp.Close() testFile := filepath.Join(t.TempDir(), "stream_test.txt") err = resp.Save(testFile) require.NoError(t, err) data, err := os.ReadFile(testFile) //nolint:gosec // test file is created in a temp directory require.NoError(t, err) require.Equal(t, "streaming file content", string(data)) }) t.Run("save streaming response to io.Writer", func(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/stream", func(c fiber.Ctx) error { return c.SendStream(bytes.NewReader([]byte("streaming writer content"))) }) }) defer server.stop() client := New().SetDial(server.dial()).SetStreamResponseBody(true) resp, err := client.Get("http://example.com/stream") require.NoError(t, err) defer resp.Close() var buf bytes.Buffer err = resp.Save(&buf) require.NoError(t, err) require.Equal(t, "streaming writer content", buf.String()) }) t.Run("save non-streaming response to file using BodyStream", func(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/regular", func(c fiber.Ctx) error { return c.SendString("regular file content") }) }) defer server.stop() client := New().SetDial(server.dial()) resp, err := 
client.Get("http://example.com/regular") require.NoError(t, err) defer resp.Close() testFile := filepath.Join(t.TempDir(), "regular_test.txt") err = resp.Save(testFile) require.NoError(t, err) data, err := os.ReadFile(testFile) //nolint:gosec // test file is created in a temp directory require.NoError(t, err) require.Equal(t, "regular file content", string(data)) }) t.Run("save to io.WriteCloser closes writer", func(t *testing.T) { t.Parallel() server := startTestServer(t, func(app *fiber.App) { app.Get("/test", func(c fiber.Ctx) error { return c.SendString("test content") }) }) defer server.stop() client := New().SetDial(server.dial()) resp, err := client.Get("http://example.com/test") require.NoError(t, err) defer resp.Close() // Create a mock WriteCloser to verify Close is called mockWriter := &mockWriteCloser{} err = resp.Save(mockWriter) require.NoError(t, err) require.True(t, mockWriter.closed, "Save() should close io.WriteCloser") require.Equal(t, "test content", mockWriter.buf.String()) }) } // mockWriteCloser is a helper to verify that Save() closes io.WriteCloser type mockWriteCloser struct { buf bytes.Buffer closed bool } type errorReader struct { err error } func (m *errorReader) Read(_ []byte) (int, error) { return 0, m.err } type errorWriter struct { err error } func (m *errorWriter) Write(_ []byte) (int, error) { return 0, m.err } func (m *mockWriteCloser) Write(p []byte) (int, error) { return m.buf.Write(p) //nolint:wrapcheck // propagate buffer write error directly for test helper } func (m *mockWriteCloser) Close() error { m.closed = true return nil } ================================================ FILE: client/transport.go ================================================ // Transport adapters unify fasthttp clients behind a shared interface so the // Fiber client can coordinate behavior like redirects, TLS overrides, and // dial customizations regardless of the underlying transport type. 
package client import ( "bytes" "crypto/tls" "time" "github.com/valyala/fasthttp" ) // defaultRedirectLimit mirrors fasthttp's default when callers supply a negative redirect cap. const defaultRedirectLimit = 16 var ( // Pre-allocated byte slice for http/https scheme comparison httpScheme = []byte("http") httpsScheme = []byte("https") ) // httpClientTransport unifies the operations exposed by the Fiber client across // the fasthttp.Client, fasthttp.HostClient, and fasthttp.LBClient adapters so // helper logic can treat the concrete transports uniformly. type httpClientTransport interface { Do(req *fasthttp.Request, resp *fasthttp.Response) error DoTimeout(req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration) error DoDeadline(req *fasthttp.Request, resp *fasthttp.Response, deadline time.Time) error DoRedirects(req *fasthttp.Request, resp *fasthttp.Response, maxRedirects int) error CloseIdleConnections() TLSConfig() *tls.Config SetTLSConfig(config *tls.Config) SetDial(dial fasthttp.DialFunc) Client() any StreamResponseBody() bool SetStreamResponseBody(enable bool) } // standardClientTransport adapts fasthttp.Client to the httpClientTransport // interface used by Fiber's client helpers. 
type standardClientTransport struct { client *fasthttp.Client } func newStandardClientTransport(client *fasthttp.Client) *standardClientTransport { return &standardClientTransport{client: client} } func (s *standardClientTransport) Do(req *fasthttp.Request, resp *fasthttp.Response) error { return s.client.Do(req, resp) } func (s *standardClientTransport) DoTimeout(req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration) error { return s.client.DoTimeout(req, resp, timeout) } func (s *standardClientTransport) DoDeadline(req *fasthttp.Request, resp *fasthttp.Response, deadline time.Time) error { return s.client.DoDeadline(req, resp, deadline) } func (s *standardClientTransport) DoRedirects(req *fasthttp.Request, resp *fasthttp.Response, maxRedirects int) error { return s.client.DoRedirects(req, resp, maxRedirects) } func (s *standardClientTransport) CloseIdleConnections() { s.client.CloseIdleConnections() } func (s *standardClientTransport) TLSConfig() *tls.Config { return s.client.TLSConfig } func (s *standardClientTransport) SetTLSConfig(config *tls.Config) { s.client.TLSConfig = config } func (s *standardClientTransport) SetDial(dial fasthttp.DialFunc) { s.client.Dial = dial } func (s *standardClientTransport) Client() any { return s.client } func (s *standardClientTransport) StreamResponseBody() bool { return s.client.StreamResponseBody } func (s *standardClientTransport) SetStreamResponseBody(enable bool) { s.client.StreamResponseBody = enable } // hostClientTransport adapts fasthttp.HostClient to the httpClientTransport // interface used by Fiber's client helpers. 
type hostClientTransport struct { client *fasthttp.HostClient } func newHostClientTransport(client *fasthttp.HostClient) *hostClientTransport { return &hostClientTransport{client: client} } func (h *hostClientTransport) Do(req *fasthttp.Request, resp *fasthttp.Response) error { return h.client.Do(req, resp) } func (h *hostClientTransport) DoTimeout(req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration) error { return h.client.DoTimeout(req, resp, timeout) } func (h *hostClientTransport) DoDeadline(req *fasthttp.Request, resp *fasthttp.Response, deadline time.Time) error { return h.client.DoDeadline(req, resp, deadline) } func (h *hostClientTransport) DoRedirects(req *fasthttp.Request, resp *fasthttp.Response, maxRedirects int) error { return h.client.DoRedirects(req, resp, maxRedirects) } func (h *hostClientTransport) CloseIdleConnections() { h.client.CloseIdleConnections() } func (h *hostClientTransport) TLSConfig() *tls.Config { return h.client.TLSConfig } func (h *hostClientTransport) SetTLSConfig(config *tls.Config) { h.client.TLSConfig = config } func (h *hostClientTransport) SetDial(dial fasthttp.DialFunc) { h.client.Dial = dial } func (h *hostClientTransport) Client() any { return h.client } func (h *hostClientTransport) StreamResponseBody() bool { return h.client.StreamResponseBody } func (h *hostClientTransport) SetStreamResponseBody(enable bool) { h.client.StreamResponseBody = enable } // lbClientTransport adapts fasthttp.LBClient to the httpClientTransport // interface used by Fiber's client helpers. 
type lbClientTransport struct { client *fasthttp.LBClient } func newLBClientTransport(client *fasthttp.LBClient) *lbClientTransport { return &lbClientTransport{client: client} } func (l *lbClientTransport) Do(req *fasthttp.Request, resp *fasthttp.Response) error { return l.client.Do(req, resp) } func (l *lbClientTransport) DoTimeout(req *fasthttp.Request, resp *fasthttp.Response, timeout time.Duration) error { return l.client.DoTimeout(req, resp, timeout) } func (l *lbClientTransport) DoDeadline(req *fasthttp.Request, resp *fasthttp.Response, deadline time.Time) error { return l.client.DoDeadline(req, resp, deadline) } // DoRedirects proxies redirect handling through doRedirectsWithClient so the // load-balanced transport mirrors fasthttp.Client semantics despite // fasthttp.LBClient not exposing DoRedirects directly. func (l *lbClientTransport) DoRedirects(req *fasthttp.Request, resp *fasthttp.Response, maxRedirects int) error { return doRedirectsWithClient(req, resp, maxRedirects, l.client) } func (l *lbClientTransport) CloseIdleConnections() { forEachHostClient(l.client, func(hc *fasthttp.HostClient) { hc.CloseIdleConnections() }) } func (l *lbClientTransport) TLSConfig() *tls.Config { if len(l.client.Clients) == 0 { return nil } return extractTLSConfig(l.client.Clients) } func (l *lbClientTransport) SetTLSConfig(config *tls.Config) { forEachHostClient(l.client, func(hc *fasthttp.HostClient) { hc.TLSConfig = config }) } func (l *lbClientTransport) SetDial(dial fasthttp.DialFunc) { forEachHostClient(l.client, func(hc *fasthttp.HostClient) { hc.Dial = dial }) } func (l *lbClientTransport) Client() any { return l.client } func (l *lbClientTransport) StreamResponseBody() bool { if len(l.client.Clients) == 0 { return false } // Return the StreamResponseBody setting from the first HostClient var streamEnabled bool for _, c := range l.client.Clients { if walkBalancingClientWithBreak(c, func(hc *fasthttp.HostClient) bool { streamEnabled = hc.StreamResponseBody return 
true }) { break } } return streamEnabled } func (l *lbClientTransport) SetStreamResponseBody(enable bool) { forEachHostClient(l.client, func(hc *fasthttp.HostClient) { hc.StreamResponseBody = enable }) } // forEachHostClient applies fn to every host client reachable from the provided // load balancer by recursively following nested balancers and wrapper types. func forEachHostClient(lb *fasthttp.LBClient, fn func(*fasthttp.HostClient)) { for _, c := range lb.Clients { walkBalancingClient(c, fn) } } // walkBalancingClient traverses balancing clients recursively, invoking fn for // every host client discovered beneath the current node. func walkBalancingClient(client any, fn func(*fasthttp.HostClient)) { walkBalancingClientWithBreak(client, func(hc *fasthttp.HostClient) bool { fn(hc) return false }) } // extractTLSConfig returns the first TLS configuration discovered while walking // the provided balancing clients so cached settings flow through nested load // balancers without redundant traversal. func extractTLSConfig(clients []fasthttp.BalancingClient) *tls.Config { var cfg *tls.Config for _, c := range clients { if walkBalancingClientWithBreak(c, func(hc *fasthttp.HostClient) bool { if hc.TLSConfig != nil { cfg = hc.TLSConfig return true } return false }) { break } } return cfg } // walkBalancingClientWithBreak traverses balancing clients recursively and // invokes fn for each host client until fn signals success, enabling early // termination once a match is found. 
func walkBalancingClientWithBreak(client any, fn func(*fasthttp.HostClient) bool) bool { switch c := client.(type) { case *fasthttp.HostClient: return fn(c) case *fasthttp.LBClient: for _, nestedClient := range c.Clients { if walkBalancingClientWithBreak(nestedClient, fn) { return true } } case interface{ LBClient() *fasthttp.LBClient }: if nested := c.LBClient(); nested != nil { if walkBalancingClientWithBreak(nested, fn) { return true } } } return false } // redirectClient describes the minimal Do-capable surface needed by // doRedirectsWithClient so transports that do not expose DoRedirects (such as // fasthttp.LBClient) can participate in redirect handling. type redirectClient interface { Do(req *fasthttp.Request, resp *fasthttp.Response) error } // doRedirectsWithClient mirrors fasthttp's redirect loop for transports that do // not expose DoRedirects (e.g. fasthttp.LBClient). The helper always issues the // initial request, respects zero redirect limits, falls back to the default cap // for negative values, and validates redirect targets before following them. 
func doRedirectsWithClient(req *fasthttp.Request, resp *fasthttp.Response, maxRedirects int, client redirectClient) error { currentURL := req.URI().String() redirects := 0 singleRequestOnly := maxRedirects <= 0 if maxRedirects < 0 { maxRedirects = defaultRedirectLimit singleRequestOnly = false } for { req.SetRequestURI(currentURL) if err := client.Do(req, resp); err != nil { return err } statusCode := resp.Header.StatusCode() if !fasthttp.StatusCodeIsRedirect(statusCode) { return nil } if singleRequestOnly { return nil } redirects++ if redirects > maxRedirects { return fasthttp.ErrTooManyRedirects } location := resp.Header.Peek("Location") if len(location) == 0 { return fasthttp.ErrMissingLocation } nextURL, err := composeRedirectURL(currentURL, location, req.DisableRedirectPathNormalizing) if err != nil { return err } currentURL = nextURL if req.Header.IsPost() && (statusCode == fasthttp.StatusMovedPermanently || statusCode == fasthttp.StatusFound || statusCode == fasthttp.StatusSeeOther) { req.Header.SetMethod(fasthttp.MethodGet) req.SetBody(nil) req.Header.Del(fasthttp.HeaderContentType) req.Header.Del(fasthttp.HeaderContentLength) } } } // composeRedirectURL resolves a redirect target relative to the current request // URL while rejecting suspicious payloads (e.g. control characters) and // restricting schemes to HTTP/S so caller-provided Location headers cannot // trigger arbitrary transports. 
func composeRedirectURL(base string, location []byte, disablePathNormalizing bool) (string, error) {
	// Reject ASCII control characters (0x00-0x1F) and DEL (0x7F) before the
	// value is handed to the URI parser.
	for _, b := range location {
		if b < 0x20 || b == 0x7f {
			return "", fasthttp.ErrorInvalidURI
		}
	}

	uri := fasthttp.AcquireURI()
	defer fasthttp.ReleaseURI(uri)

	// Start from the current URL and apply the Location value on top of it so
	// that relative targets resolve against base.
	uri.Update(base)
	uri.UpdateBytes(location)
	uri.DisablePathNormalizing = disablePathNormalizing

	// Only http and https targets may be followed (compared case-insensitively).
	if scheme := uri.Scheme(); len(scheme) > 0 && !bytes.EqualFold(scheme, httpScheme) && !bytes.EqualFold(scheme, httpsScheme) {
		return "", fasthttp.ErrorInvalidURI
	}

	// A target that carries a scheme must also name a host.
	if len(uri.Scheme()) > 0 && len(uri.Host()) == 0 {
		return "", fasthttp.ErrorInvalidURI
	}

	return uri.String(), nil
}

================================================
FILE: client/transport_test.go
================================================
package client

import (
	"crypto/tls"
	"errors"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	"github.com/valyala/fasthttp"
)

// stubBalancingClient is a fasthttp.BalancingClient that is neither a
// *fasthttp.HostClient nor exposes an LBClient() accessor, so the tree walkers
// have nothing to descend into.
type stubBalancingClient struct{}

func (stubBalancingClient) DoDeadline(*fasthttp.Request, *fasthttp.Response, time.Time) error {
	return nil
}

func (stubBalancingClient) PendingRequests() int { return 0 }

// lbBalancingClient wraps an optional *fasthttp.LBClient and exposes it through
// an LBClient() accessor, letting tests build nested balancer trees.
type lbBalancingClient struct {
	client *fasthttp.LBClient
}

func (l *lbBalancingClient) DoDeadline(req *fasthttp.Request, resp *fasthttp.Response, deadline time.Time) error {
	if l.client == nil {
		return nil
	}
	return l.client.DoDeadline(req, resp, deadline)
}

func (*lbBalancingClient) PendingRequests() int { return 0 }

func (l *lbBalancingClient) LBClient() *fasthttp.LBClient { return l.client }

// stubRedirectCall scripts one response returned by stubRedirectClient. A nil
// status or location leaves that part of the response unset.
type stubRedirectCall struct {
	err      error
	status   *int
	location *string
}

// ptrInt returns a pointer to a copy of v (for stubRedirectCall.status).
func ptrInt(v int) *int { return &v }

// ptrString returns a pointer to a copy of v (for stubRedirectCall.location).
func ptrString(v string) *string { return &v }

// stubRedirectClient replays its scripted calls in order and counts every Do
// invocation. Once the script is exhausted it answers 200 OK.
type stubRedirectClient struct {
	calls     []stubRedirectCall
	callCount int
}

func (s *stubRedirectClient) Do(req *fasthttp.Request, resp *fasthttp.Response) error {
	_ = req
	s.callCount++
	if len(s.calls) == 0 {
		resp.Reset()
		resp.Header.SetStatusCode(fasthttp.StatusOK)
		return nil
	}
	call := s.calls[0]
	s.calls = s.calls[1:]
	resp.Reset()
	if call.status != nil {
		resp.Header.SetStatusCode(*call.status)
	}
	if call.location != nil {
		resp.Header.Set("Location", *call.location)
	}
	return call.err
}

// CallCount reports how many times Do has been invoked.
func (s *stubRedirectClient) CallCount() int {
	return s.callCount
}

// TestStandardClientTransportCoverage checks that every Do variant reaches the
// wrapped fasthttp.Client's dialer, and that Client/TLSConfig/SetTLSConfig
// operate on that same client.
func TestStandardClientTransportCoverage(t *testing.T) {
	t.Parallel()

	var dialCount atomic.Int32
	client := &fasthttp.Client{}
	client.Dial = func(addr string) (net.Conn, error) {
		_ = addr
		dialCount.Add(1)
		return nil, errors.New("dial error")
	}

	transport := newStandardClientTransport(client)
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	req.SetRequestURI("http://example.com/")
	require.Error(t, transport.Do(req, resp))
	req.SetRequestURI("http://example.com/")
	require.Error(t, transport.DoTimeout(req, resp, time.Millisecond))
	req.SetRequestURI("http://example.com/")
	require.Error(t, transport.DoDeadline(req, resp, time.Now().Add(time.Second)))

	transport.CloseIdleConnections()

	underlying, ok := transport.Client().(*fasthttp.Client)
	require.True(t, ok)
	require.Same(t, client, underlying)
	require.Equal(t, int32(3), dialCount.Load())

	clientTLS := &tls.Config{ServerName: "standard", MinVersion: tls.VersionTLS12}
	client.TLSConfig = clientTLS
	cfg := transport.TLSConfig()
	require.Same(t, clientTLS, cfg)

	override := &tls.Config{ServerName: "override", MinVersion: tls.VersionTLS13}
	transport.SetTLSConfig(override)
	require.Equal(t, override, client.TLSConfig)
}

// TestHostClientTransportClientAccessor checks that the host transport exposes
// and updates the wrapped *fasthttp.HostClient.
func TestHostClientTransportClientAccessor(t *testing.T) {
	t.Parallel()

	host := &fasthttp.HostClient{Addr: "example.com:80"}
	transport := newHostClientTransport(host)
	current, ok := transport.Client().(*fasthttp.HostClient)
	require.True(t, ok)
	require.Same(t, host, current)

	hostTLS := &tls.Config{ServerName: "host", MinVersion: tls.VersionTLS12}
	host.TLSConfig = hostTLS
	cfg := transport.TLSConfig()
	require.Same(t, hostTLS, cfg)

	override := &tls.Config{ServerName: "host-override", MinVersion: tls.VersionTLS13}
	transport.SetTLSConfig(override)
	require.Equal(t, override, host.TLSConfig)
}

// TestLBClientTransportAccessorsAndOverrides builds a balancer tree mixing plain
// host clients, a non-host stub and nested LBClient() wrappers. It checks that
// TLS and dial overrides reach every reachable *fasthttp.HostClient, however
// deeply nested.
func TestLBClientTransportAccessorsAndOverrides(t *testing.T) {
	t.Parallel()

	hostWithoutOverrides := &fasthttp.HostClient{Addr: "example.com:80"}
	nestedDialHost := &fasthttp.HostClient{Addr: "example.org:80"}
	nestedTLSHost := &fasthttp.HostClient{Addr: "example.net:80", TLSConfig: &tls.Config{ServerName: "example", MinVersion: tls.VersionTLS12}}
	multiLevelHost := &fasthttp.HostClient{Addr: "example.edu:80"}

	nestedDialHost.Dial = func(addr string) (net.Conn, error) {
		_ = addr
		return nil, errors.New("original dial")
	}
	multiLevelHost.Dial = func(addr string) (net.Conn, error) {
		_ = addr
		return nil, errors.New("multi-level dial")
	}

	nestedDialLB := &lbBalancingClient{client: &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{nestedDialHost}}}
	nestedTLSLB := &lbBalancingClient{client: &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{nestedTLSHost}}}
	multiLevelLeaf := &lbBalancingClient{client: &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{multiLevelHost}}}
	multiLevelWrapper := &lbBalancingClient{client: &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{multiLevelLeaf}}}

	lb := &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{
		stubBalancingClient{},
		hostWithoutOverrides,
		nestedDialLB,
		nestedTLSLB,
		multiLevelWrapper,
	}}

	transport := newLBClientTransport(lb)
	require.Same(t, lb, transport.Client())

	cfg := transport.TLSConfig()
	require.Same(t, nestedTLSHost.TLSConfig, cfg)

	overrideTLS := &tls.Config{ServerName: "override", MinVersion: tls.VersionTLS12}
	transport.SetTLSConfig(overrideTLS)
	require.Equal(t, overrideTLS, hostWithoutOverrides.TLSConfig)
	require.Equal(t, overrideTLS, nestedDialHost.TLSConfig)
	require.Equal(t, overrideTLS, nestedTLSHost.TLSConfig)
	require.Equal(t, overrideTLS, multiLevelHost.TLSConfig)

	cfg = transport.TLSConfig()
	require.Same(t, overrideTLS, cfg)
	cfg.ServerName = "mutated"
	require.Equal(t, "mutated", transport.TLSConfig().ServerName)

	overrideDialCalled := atomic.Bool{}
	overrideDial := func(addr string) (net.Conn, error) {
		_ = addr
		overrideDialCalled.Store(true)
		return nil, errors.New("override dial")
	}

	transport.SetDial(overrideDial)

	overrideDialCalled.Store(false)
	_, err := hostWithoutOverrides.Dial("example.com:80")
	require.Error(t, err)
	require.True(t, overrideDialCalled.Load())

	overrideDialCalled.Store(false)
	_, err = nestedDialHost.Dial("example.org:80")
	require.Error(t, err)
	require.True(t, overrideDialCalled.Load())

	overrideDialCalled.Store(false)
	_, err = multiLevelHost.Dial("example.edu:80")
	require.Error(t, err)
	require.True(t, overrideDialCalled.Load())
}

// TestExtractTLSConfigVariations covers nil and host-less inputs, a direct host
// client, and host clients nested one and two LBClient() wrappers deep.
func TestExtractTLSConfigVariations(t *testing.T) {
	t.Parallel()

	require.Nil(t, extractTLSConfig(nil))
	require.Nil(t, extractTLSConfig([]fasthttp.BalancingClient{stubBalancingClient{}}))

	host := &fasthttp.HostClient{TLSConfig: &tls.Config{ServerName: "configured", MinVersion: tls.VersionTLS12}}
	require.Equal(t, host.TLSConfig, extractTLSConfig([]fasthttp.BalancingClient{host}))

	nested := &fasthttp.HostClient{TLSConfig: &tls.Config{ServerName: "nested", MinVersion: tls.VersionTLS12}}
	nestedLB := &lbBalancingClient{client: &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{nested}}}
	require.Equal(t, nested.TLSConfig, extractTLSConfig([]fasthttp.BalancingClient{nestedLB}))

	multiLayerLB := &lbBalancingClient{client: &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{nestedLB}}}
	require.Equal(t, nested.TLSConfig, extractTLSConfig([]fasthttp.BalancingClient{multiLayerLB}))
}

// TestWalkBalancingClientWithBreak covers each branch of the walker:
//   - a direct host client;
//   - a non-host stub, which must never reach fn;
//   - an LBClient() wrapper;
//   - a plain *fasthttp.LBClient.
func TestWalkBalancingClientWithBreak(t *testing.T) {
	t.Parallel()

	host := &fasthttp.HostClient{}
	require.True(t, walkBalancingClientWithBreak(host, func(*fasthttp.HostClient) bool { return true }))

	require.False(t, walkBalancingClientWithBreak(stubBalancingClient{}, func(*fasthttp.HostClient) bool {
		t.Fatal("unexpected call")
		return false
	}))

	nested := &fasthttp.HostClient{}
	nestedLB := &lbBalancingClient{client: &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{nested}}}
	require.True(t, walkBalancingClientWithBreak(nestedLB, func(*fasthttp.HostClient) bool { return true }))

	directNestedHost := &fasthttp.HostClient{}
	directNestedLB := &fasthttp.LBClient{Clients: []fasthttp.BalancingClient{directNestedHost}}
	require.True(t, walkBalancingClientWithBreak(directNestedLB, func(hc *fasthttp.HostClient) bool {
		require.Same(t, directNestedHost, hc)
		return true
	}))
}

// TestDoRedirectsWithClientBranches drives doRedirectsWithClient through:
//   - POST->GET downgrades on 301 and 303;
//   - the zero-limit single-request path;
//   - each error exit: missing Location, non-HTTP scheme, control characters,
//     and an exceeded limit.
func TestDoRedirectsWithClientBranches(t *testing.T) {
	t.Parallel()

	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	req.SetRequestURI("http://example.com/start")
	req.Header.SetMethod(fasthttp.MethodPost)
	req.Header.SetContentType("application/json")
	req.SetBodyString("payload")

	client := &stubRedirectClient{calls: []stubRedirectCall{{status: ptrInt(fasthttp.StatusMovedPermanently), location: ptrString("/redirect")}, {status: ptrInt(fasthttp.StatusOK)}}}
	require.NoError(t, doRedirectsWithClient(req, resp, -1, client))
	require.Equal(t, fasthttp.MethodGet, string(req.Header.Method()))
	require.Equal(t, "http://example.com/redirect", req.URI().String())
	require.Empty(t, req.Body())
	require.Empty(t, req.Header.ContentType())

	resp.Reset()
	req.Header.SetMethod(fasthttp.MethodPost)
	req.Header.SetContentType("application/json")
	req.SetBodyString("payload")
	seeOtherClient := &stubRedirectClient{calls: []stubRedirectCall{{status: ptrInt(fasthttp.StatusSeeOther), location: ptrString("/see-other")}, {status: ptrInt(fasthttp.StatusOK)}}}
	require.NoError(t, doRedirectsWithClient(req, resp, -1, seeOtherClient))
	require.Equal(t, fasthttp.MethodGet, string(req.Header.Method()))
	require.Equal(t, "http://example.com/see-other", req.URI().String())
	require.Empty(t, req.Body())
	require.Empty(t, req.Header.ContentType())

	// maxRedirects == 0: exactly one call, and the redirect response and the
	// original request are left untouched.
	resp.Reset()
	req.Header.SetMethod(fasthttp.MethodPost)
	req.SetRequestURI("http://example.com/again")
	req.SetBodyString("payload")
	singleCall := &stubRedirectClient{calls: []stubRedirectCall{{status: ptrInt(fasthttp.StatusFound), location: ptrString("/ignored")}}}
	require.NoError(t, doRedirectsWithClient(req, resp, 0, singleCall))
	require.Equal(t, fasthttp.StatusFound, resp.StatusCode())
	require.Equal(t, fasthttp.MethodPost, string(req.Header.Method()))
	require.Equal(t, "http://example.com/again", req.URI().String())
	require.Equal(t, "payload", string(req.Body()))
	require.Equal(t, 1, singleCall.CallCount())
	require.Equal(t, fasthttp.StatusFound, resp.Header.StatusCode())

	resp.Reset()
	req.Header.SetMethod(fasthttp.MethodPost)
	req.SetRequestURI("http://example.com/start")
	client = &stubRedirectClient{calls: []stubRedirectCall{{status: ptrInt(fasthttp.StatusFound)}}}
	require.ErrorIs(t, doRedirectsWithClient(req, resp, 1, client), fasthttp.ErrMissingLocation)

	resp.Reset()
	req.Header.SetMethod(fasthttp.MethodPost)
	req.SetRequestURI("http://example.com/start")
	client = &stubRedirectClient{calls: []stubRedirectCall{{status: ptrInt(fasthttp.StatusMovedPermanently), location: ptrString("ftp://example.com")}}}
	require.ErrorIs(t, doRedirectsWithClient(req, resp, 1, client), fasthttp.ErrorInvalidURI)

	resp.Reset()
	req.Header.SetMethod(fasthttp.MethodPost)
	req.SetRequestURI("http://example.com/start")
	client = &stubRedirectClient{calls: []stubRedirectCall{{status: ptrInt(fasthttp.StatusFound), location: ptrString("/bad\x00path")}}}
	require.ErrorIs(t, doRedirectsWithClient(req, resp, 1, client), fasthttp.ErrorInvalidURI)

	resp.Reset()
	req.Header.SetMethod(fasthttp.MethodPost)
	req.SetRequestURI("http://example.com/start")
	client = &stubRedirectClient{calls: []stubRedirectCall{{status: ptrInt(fasthttp.StatusMovedPermanently), location: ptrString("/loop")}, {status: ptrInt(fasthttp.StatusFound), location: ptrString("/final")}, {status: ptrInt(fasthttp.StatusOK)}}}
	require.ErrorIs(t, doRedirectsWithClient(req, resp, 1, client), fasthttp.ErrTooManyRedirects)
}

// TestDoRedirectsWithClientDefaultLimit checks that a negative limit falls back
// to defaultRedirectLimit. The loop makes defaultRedirectLimit+1 calls before
// reporting ErrTooManyRedirects.
func TestDoRedirectsWithClientDefaultLimit(t *testing.T) {
	t.Parallel()

	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	req.SetRequestURI("http://example.com/start")
	req.Header.SetMethod(fasthttp.MethodPost)

	calls := make([]stubRedirectCall, 0, defaultRedirectLimit+1)
	for i := 0; i < defaultRedirectLimit+1; i++ {
		calls = append(calls, stubRedirectCall{status: ptrInt(fasthttp.StatusFound), location: ptrString("/loop")})
	}

	client := &stubRedirectClient{calls: calls}
	err := doRedirectsWithClient(req, resp, -1, client)
	require.ErrorIs(t, err, fasthttp.ErrTooManyRedirects)
	require.Equal(t, defaultRedirectLimit+1, client.CallCount())
}

// Test_StandardClientTransport_StreamResponseBody checks that the streaming
// flag defaults to false and is mirrored onto the wrapped fasthttp.Client.
func Test_StandardClientTransport_StreamResponseBody(t *testing.T) {
	t.Parallel()

	t.Run("default value", func(t *testing.T) {
		t.Parallel()
		transport := newStandardClientTransport(&fasthttp.Client{})
		require.False(t, transport.StreamResponseBody())
	})

	t.Run("enable streaming", func(t *testing.T) {
		t.Parallel()
		client := &fasthttp.Client{}
		transport := newStandardClientTransport(client)
		transport.SetStreamResponseBody(true)
		require.True(t, transport.StreamResponseBody())
		require.True(t, client.StreamResponseBody)
	})

	t.Run("disable streaming", func(t *testing.T) {
		t.Parallel()
		client := &fasthttp.Client{}
		transport := newStandardClientTransport(client)
		transport.SetStreamResponseBody(true)
		require.True(t, transport.StreamResponseBody())
		transport.SetStreamResponseBody(false)
		require.False(t, transport.StreamResponseBody())
		require.False(t, client.StreamResponseBody)
	})
}

// Test_HostClientTransport_StreamResponseBody checks that the streaming flag
// defaults to false and is mirrored onto the wrapped *fasthttp.HostClient.
func Test_HostClientTransport_StreamResponseBody(t *testing.T) {
	t.Parallel()

	t.Run("default value", func(t *testing.T) {
		t.Parallel()
		hostClient := &fasthttp.HostClient{}
		transport := newHostClientTransport(hostClient)
		require.False(t, transport.StreamResponseBody())
	})

	t.Run("enable streaming", func(t *testing.T) {
		t.Parallel()
		hostClient := &fasthttp.HostClient{}
		transport := newHostClientTransport(hostClient)
		transport.SetStreamResponseBody(true)
		require.True(t, transport.StreamResponseBody())
		require.True(t, hostClient.StreamResponseBody)
	})

	t.Run("disable streaming", func(t *testing.T) {
		t.Parallel()
		hostClient := &fasthttp.HostClient{}
		transport := newHostClientTransport(hostClient)
		transport.SetStreamResponseBody(true)
		require.True(t, transport.StreamResponseBody())
		transport.SetStreamResponseBody(false)
		require.False(t, transport.StreamResponseBody())
		require.False(t, hostClient.StreamResponseBody)
	})
}

// Test_LBClientTransport_StreamResponseBody checks that the streaming flag is
// false for an empty balancer and is propagated to every host client it holds.
func Test_LBClientTransport_StreamResponseBody(t *testing.T) {
	t.Parallel()

	t.Run("empty clients", func(t *testing.T) {
		t.Parallel()
		lbClient := &fasthttp.LBClient{
			Clients: []fasthttp.BalancingClient{},
		}
		transport := newLBClientTransport(lbClient)
		require.False(t, transport.StreamResponseBody())
	})

	t.Run("single host client", func(t *testing.T) {
		t.Parallel()
		hostClient := &fasthttp.HostClient{Addr: "example.com:80"}
		lbClient := &fasthttp.LBClient{
			Clients: []fasthttp.BalancingClient{hostClient},
		}
		transport := newLBClientTransport(lbClient)

		// Test default
		require.False(t, transport.StreamResponseBody())

		// Enable streaming
		transport.SetStreamResponseBody(true)
		require.True(t, transport.StreamResponseBody())
		require.True(t, hostClient.StreamResponseBody)

		// Disable streaming
		transport.SetStreamResponseBody(false)
		require.False(t, transport.StreamResponseBody())
		require.False(t, hostClient.StreamResponseBody)
	})

	t.Run("multiple host clients", func(t *testing.T) {
		t.Parallel()
		hostClient1 := &fasthttp.HostClient{Addr: "example1.com:80"}
		hostClient2 := &fasthttp.HostClient{Addr: "example2.com:80"}
		lbClient := &fasthttp.LBClient{
			Clients: []fasthttp.BalancingClient{hostClient1, hostClient2},
		}
		transport := newLBClientTransport(lbClient)

		// Enable streaming on all clients
		transport.SetStreamResponseBody(true)
		require.True(t, transport.StreamResponseBody())
		require.True(t, hostClient1.StreamResponseBody)
		require.True(t, hostClient2.StreamResponseBody)

		// Disable streaming on all clients
		transport.SetStreamResponseBody(false)
		require.False(t, transport.StreamResponseBody())
		require.False(t, hostClient1.StreamResponseBody)
		require.False(t, hostClient2.StreamResponseBody)
	})
}

// Test_httpClientTransport_Interface runs the same contract checks against all
// three httpClientTransport implementations: a non-nil Client() and a
// round-trippable streaming flag.
func Test_httpClientTransport_Interface(t *testing.T) {
	t.Parallel()

	transports := []struct {
		transport httpClientTransport
		name      string
	}{
		{
			name:      "standardClientTransport",
			transport: newStandardClientTransport(&fasthttp.Client{}),
		},
		{
			name:      "hostClientTransport",
			transport: newHostClientTransport(&fasthttp.HostClient{}),
		},
		{
			name: "lbClientTransport",
			transport: newLBClientTransport(&fasthttp.LBClient{
				Clients: []fasthttp.BalancingClient{
					&fasthttp.HostClient{Addr: "example.com:80"},
				},
			}),
		},
	}

	for _, tt := range transports {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			transport := tt.transport
			require.NotNil(t, transport.Client())

			initialStream := transport.StreamResponseBody()
			transport.SetStreamResponseBody(!initialStream)
			require.Equal(t, !initialStream, transport.StreamResponseBody())

			transport.SetStreamResponseBody(initialStream)
			require.Equal(t, initialStream, transport.StreamResponseBody())
		})
	}
}

================================================
FILE: color.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io

package fiber

// Colors is a struct to define custom colors for Fiber app and middlewares.
// Any field left empty is filled from DefaultColors by defaultColors.
type Colors struct {
	// Black color.
	//
	// Optional. Default: "\u001b[90m"
	Black string

	// Red color.
	//
	// Optional. Default: "\u001b[91m"
	Red string

	// Green color.
	//
	// Optional. Default: "\u001b[92m"
	Green string

	// Yellow color.
	//
	// Optional. Default: "\u001b[93m"
	Yellow string

	// Blue color.
	//
	// Optional. Default: "\u001b[94m"
	Blue string

	// Magenta color.
	//
	// Optional. Default: "\u001b[95m"
	Magenta string

	// Cyan color.
	//
	// Optional. Default: "\u001b[96m"
	Cyan string

	// White color.
	//
	// Optional. Default: "\u001b[97m"
	White string

	// Reset color.
	//
	// Optional. Default: "\u001b[0m"
	Reset string
}

// DefaultColors holds the default color codes. defaultColors uses it to fill
// any Colors field that is left empty.
var DefaultColors = Colors{
	Black:   "\u001b[90m",
	Red:     "\u001b[91m",
	Green:   "\u001b[92m",
	Yellow:  "\u001b[93m",
	Blue:    "\u001b[94m",
	Magenta: "\u001b[95m",
	Cyan:    "\u001b[96m",
	White:   "\u001b[97m",
	Reset:   "\u001b[0m",
}

// defaultColors returns DefaultColors when colors is nil. Otherwise it returns
// a copy of *colors in which every empty field is replaced by the matching
// DefaultColors value. The caller's struct is not modified.
func defaultColors(colors *Colors) Colors {
	if colors == nil {
		return DefaultColors
	}

	// Work on a copy so the caller's Colors value stays unchanged.
	cfg := *colors

	if cfg.Black == "" {
		cfg.Black = DefaultColors.Black
	}

	if cfg.Red == "" {
		cfg.Red = DefaultColors.Red
	}

	if cfg.Green == "" {
		cfg.Green = DefaultColors.Green
	}

	if cfg.Yellow == "" {
		cfg.Yellow = DefaultColors.Yellow
	}

	if cfg.Blue == "" {
		cfg.Blue = DefaultColors.Blue
	}

	if cfg.Magenta == "" {
		cfg.Magenta = DefaultColors.Magenta
	}

	if cfg.Cyan == "" {
		cfg.Cyan = DefaultColors.Cyan
	}

	if cfg.White == "" {
		cfg.White = DefaultColors.White
	}

	if cfg.Reset == "" {
		cfg.Reset = DefaultColors.Reset
	}

	return cfg
}

================================================
FILE: constants.go
================================================
package fiber

// HTTP methods were copied from net/http.
const (
	MethodGet     = "GET"     // RFC 7231, 4.3.1
	MethodHead    = "HEAD"    // RFC 7231, 4.3.2
	MethodPost    = "POST"    // RFC 7231, 4.3.3
	MethodPut     = "PUT"     // RFC 7231, 4.3.4
	MethodPatch   = "PATCH"   // RFC 5789
	MethodDelete  = "DELETE"  // RFC 7231, 4.3.5
	MethodConnect = "CONNECT" // RFC 7231, 4.3.6
	MethodOptions = "OPTIONS" // RFC 7231, 4.3.7
	MethodTrace   = "TRACE"   // RFC 7231, 4.3.8
	methodUse     = "USE"     // internal pseudo-method, not an HTTP verb
)

// MIME types that are commonly used.
// The *CharsetUTF8 variants append "; charset=utf-8" to the base type.
const (
	MIMETextXML               = "text/xml"
	MIMETextHTML              = "text/html"
	MIMETextPlain             = "text/plain"
	MIMETextJavaScript        = "text/javascript"
	MIMETextCSS               = "text/css"
	MIMEApplicationXML        = "application/xml"
	MIMEApplicationJSON       = "application/json"
	MIMEApplicationJavaScript = "application/javascript"
	MIMEApplicationCBOR       = "application/cbor"
	MIMEApplicationForm       = "application/x-www-form-urlencoded"
	MIMEOctetStream           = "application/octet-stream"
	MIMEMultipartForm         = "multipart/form-data"
	MIMEApplicationMsgPack    = "application/vnd.msgpack"

	MIMETextXMLCharsetUTF8         = "text/xml; charset=utf-8"
	MIMETextHTMLCharsetUTF8        = "text/html; charset=utf-8"
	MIMETextPlainCharsetUTF8       = "text/plain; charset=utf-8"
	MIMETextJavaScriptCharsetUTF8  = "text/javascript; charset=utf-8"
	MIMETextCSSCharsetUTF8         = "text/css; charset=utf-8"
	MIMEApplicationXMLCharsetUTF8  = "application/xml; charset=utf-8"
	MIMEApplicationJSONCharsetUTF8 = "application/json; charset=utf-8"
)

// HTTP status codes were copied from net/http with the following updates:
// - Rename StatusNonAuthoritativeInfo to StatusNonAuthoritativeInformation
// - Add StatusSwitchProxy (306)
// NOTE: Keep this list in sync with statusMessage
const (
	StatusContinue                      = 100 // RFC 9110, 15.2.1
	StatusSwitchingProtocols            = 101 // RFC 9110, 15.2.2
	StatusProcessing                    = 102 // RFC 2518, 10.1
	StatusEarlyHints                    = 103 // RFC 8297
	StatusOK                            = 200 // RFC 9110, 15.3.1
	StatusCreated                       = 201 // RFC 9110, 15.3.2
	StatusAccepted                      = 202 // RFC 9110, 15.3.3
	StatusNonAuthoritativeInformation   = 203 // RFC 9110, 15.3.4
	StatusNoContent                     = 204 // RFC 9110, 15.3.5
	StatusResetContent                  = 205 // RFC 9110, 15.3.6
	StatusPartialContent                = 206 // RFC 9110, 15.3.7
	StatusMultiStatus                   = 207 // RFC 4918, 11.1
	StatusAlreadyReported               = 208 // RFC 5842, 7.1
	StatusIMUsed                        = 226 // RFC 3229, 10.4.1
	StatusMultipleChoices               = 300 // RFC 9110, 15.4.1
	StatusMovedPermanently              = 301 // RFC 9110, 15.4.2
	StatusFound                         = 302 // RFC 9110, 15.4.3
	StatusSeeOther                      = 303 // RFC 9110, 15.4.4
	StatusNotModified                   = 304 // RFC 9110, 15.4.5
	StatusUseProxy                      = 305 // RFC 9110, 15.4.6
	StatusSwitchProxy                   = 306 // RFC 9110, 15.4.7 (Unused)
	StatusTemporaryRedirect             = 307 // RFC 9110, 15.4.8
	StatusPermanentRedirect             = 308 // RFC 9110, 15.4.9
	StatusBadRequest                    = 400 // RFC 9110, 15.5.1
	StatusUnauthorized                  = 401 // RFC 9110, 15.5.2
	StatusPaymentRequired               = 402 // RFC 9110, 15.5.3
	StatusForbidden                     = 403 // RFC 9110, 15.5.4
	StatusNotFound                      = 404 // RFC 9110, 15.5.5
	StatusMethodNotAllowed              = 405 // RFC 9110, 15.5.6
	StatusNotAcceptable                 = 406 // RFC 9110, 15.5.7
	StatusProxyAuthRequired             = 407 // RFC 9110, 15.5.8
	StatusRequestTimeout                = 408 // RFC 9110, 15.5.9
	StatusConflict                      = 409 // RFC 9110, 15.5.10
	StatusGone                          = 410 // RFC 9110, 15.5.11
	StatusLengthRequired                = 411 // RFC 9110, 15.5.12
	StatusPreconditionFailed            = 412 // RFC 9110, 15.5.13
	StatusRequestEntityTooLarge         = 413 // RFC 9110, 15.5.14
	StatusRequestURITooLong             = 414 // RFC 9110, 15.5.15
	StatusUnsupportedMediaType          = 415 // RFC 9110, 15.5.16
	StatusRequestedRangeNotSatisfiable  = 416 // RFC 9110, 15.5.17
	StatusExpectationFailed             = 417 // RFC 9110, 15.5.18
	StatusTeapot                        = 418 // RFC 9110, 15.5.19 (Unused)
	StatusMisdirectedRequest            = 421 // RFC 9110, 15.5.20
	StatusUnprocessableEntity           = 422 // RFC 9110, 15.5.21
	StatusLocked                        = 423 // RFC 4918, 11.3
	StatusFailedDependency              = 424 // RFC 4918, 11.4
	StatusTooEarly                      = 425 // RFC 8470, 5.2
	StatusUpgradeRequired               = 426 // RFC 9110, 15.5.22
	StatusPreconditionRequired          = 428 // RFC 6585, 3
	StatusTooManyRequests               = 429 // RFC 6585, 4
	StatusRequestHeaderFieldsTooLarge   = 431 // RFC 6585, 5
	StatusUnavailableForLegalReasons    = 451 // RFC 7725, 3
	StatusInternalServerError           = 500 // RFC 9110, 15.6.1
	StatusNotImplemented                = 501 // RFC 9110, 15.6.2
	StatusBadGateway                    = 502 // RFC 9110, 15.6.3
	StatusServiceUnavailable            = 503 // RFC 9110, 15.6.4
	StatusGatewayTimeout                = 504 // RFC 9110, 15.6.5
	StatusHTTPVersionNotSupported       = 505 // RFC 9110, 15.6.6
	StatusVariantAlsoNegotiates         = 506 // RFC 2295, 8.1
	StatusInsufficientStorage           = 507 // RFC 4918, 11.5
	StatusLoopDetected                  = 508 // RFC 5842, 7.2
	StatusNotExtended                   = 510 // RFC 2774, 7
	StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6
)

// Errors are predefined errors, one for each 4xx and 5xx status code declared
// above, each created with NewError.
var (
	ErrBadRequest                    = NewError(StatusBadRequest)                    // 400
	ErrUnauthorized                  = NewError(StatusUnauthorized)                  // 401
	ErrPaymentRequired               = NewError(StatusPaymentRequired)               // 402
	ErrForbidden                     = NewError(StatusForbidden)                     // 403
	ErrNotFound                      = NewError(StatusNotFound)                      // 404
	ErrMethodNotAllowed              = NewError(StatusMethodNotAllowed)              // 405
	ErrNotAcceptable                 = NewError(StatusNotAcceptable)                 // 406
	ErrProxyAuthRequired             = NewError(StatusProxyAuthRequired)             // 407
	ErrRequestTimeout                = NewError(StatusRequestTimeout)                // 408
	ErrConflict                      = NewError(StatusConflict)                      // 409
	ErrGone                          = NewError(StatusGone)                          // 410
	ErrLengthRequired                = NewError(StatusLengthRequired)                // 411
	ErrPreconditionFailed            = NewError(StatusPreconditionFailed)            // 412
	ErrRequestEntityTooLarge         = NewError(StatusRequestEntityTooLarge)         // 413
	ErrRequestURITooLong             = NewError(StatusRequestURITooLong)             // 414
	ErrUnsupportedMediaType          = NewError(StatusUnsupportedMediaType)          // 415
	ErrRequestedRangeNotSatisfiable  = NewError(StatusRequestedRangeNotSatisfiable)  // 416
	ErrExpectationFailed             = NewError(StatusExpectationFailed)             // 417
	ErrTeapot                        = NewError(StatusTeapot)                        // 418
	ErrMisdirectedRequest            = NewError(StatusMisdirectedRequest)            // 421
	ErrUnprocessableEntity           = NewError(StatusUnprocessableEntity)           // 422
	ErrLocked                        = NewError(StatusLocked)                        // 423
	ErrFailedDependency              = NewError(StatusFailedDependency)              // 424
	ErrTooEarly                      = NewError(StatusTooEarly)                      // 425
	ErrUpgradeRequired               = NewError(StatusUpgradeRequired)               // 426
	ErrPreconditionRequired          = NewError(StatusPreconditionRequired)          // 428
	ErrTooManyRequests               = NewError(StatusTooManyRequests)               // 429
	ErrRequestHeaderFieldsTooLarge   = NewError(StatusRequestHeaderFieldsTooLarge)   // 431
	ErrUnavailableForLegalReasons    = NewError(StatusUnavailableForLegalReasons)    // 451
	ErrInternalServerError           = NewError(StatusInternalServerError)           // 500
	ErrNotImplemented                = NewError(StatusNotImplemented)                // 501
	ErrBadGateway                    = NewError(StatusBadGateway)                    // 502
	ErrServiceUnavailable            = NewError(StatusServiceUnavailable)            // 503
	ErrGatewayTimeout                = NewError(StatusGatewayTimeout)                // 504
	ErrHTTPVersionNotSupported       = NewError(StatusHTTPVersionNotSupported)       // 505
	ErrVariantAlsoNegotiates         = NewError(StatusVariantAlsoNegotiates)         // 506
	ErrInsufficientStorage           = NewError(StatusInsufficientStorage)           // 507
	ErrLoopDetected                  = NewError(StatusLoopDetected)                  // 508
	ErrNotExtended                   = NewError(StatusNotExtended)                   // 510
	ErrNetworkAuthenticationRequired = NewError(StatusNetworkAuthenticationRequired) // 511
)

// HTTP Headers were copied from net/http.
const (
	HeaderAuthorization                      = "Authorization"
	HeaderProxyAuthenticate                  = "Proxy-Authenticate"
	HeaderProxyAuthorization                 = "Proxy-Authorization"
	HeaderWWWAuthenticate                    = "WWW-Authenticate"
	HeaderAge                                = "Age"
	HeaderCacheControl                       = "Cache-Control"
	HeaderClearSiteData                      = "Clear-Site-Data"
	HeaderExpires                            = "Expires"
	HeaderPragma                             = "Pragma"
	HeaderWarning                            = "Warning"
	HeaderAcceptCH                           = "Accept-CH"
	HeaderAcceptCHLifetime                   = "Accept-CH-Lifetime"
	HeaderContentDPR                         = "Content-DPR"
	HeaderDPR                                = "DPR"
	HeaderEarlyData                          = "Early-Data"
	HeaderSaveData                           = "Save-Data"
	HeaderViewportWidth                      = "Viewport-Width"
	HeaderWidth                              = "Width"
	HeaderETag                               = "ETag"
	HeaderIfMatch                            = "If-Match"
	HeaderIfModifiedSince                    = "If-Modified-Since"
	HeaderIfNoneMatch                        = "If-None-Match"
	HeaderIfUnmodifiedSince                  = "If-Unmodified-Since"
	HeaderLastModified                       = "Last-Modified"
	HeaderVary                               = "Vary"
	HeaderConnection                         = "Connection"
	HeaderKeepAlive                          = "Keep-Alive"
	HeaderAccept                             = "Accept"
	HeaderAcceptCharset                      = "Accept-Charset"
	HeaderAcceptEncoding                     = "Accept-Encoding"
	HeaderAcceptLanguage                     = "Accept-Language"
	HeaderCookie                             = "Cookie"
	HeaderExpect                             = "Expect"
	HeaderMaxForwards                        = "Max-Forwards"
	HeaderSetCookie                          = "Set-Cookie"
	HeaderAccessControlAllowCredentials      = "Access-Control-Allow-Credentials"
	HeaderAccessControlAllowHeaders          = "Access-Control-Allow-Headers"
	HeaderAccessControlAllowMethods          = "Access-Control-Allow-Methods"
	HeaderAccessControlAllowOrigin           = "Access-Control-Allow-Origin"
	HeaderAccessControlExposeHeaders         = "Access-Control-Expose-Headers"
	HeaderAccessControlMaxAge                = "Access-Control-Max-Age"
	HeaderAccessControlRequestHeaders        = "Access-Control-Request-Headers"
	HeaderAccessControlRequestMethod         = "Access-Control-Request-Method"
	HeaderOrigin                             = "Origin"
	HeaderTimingAllowOrigin                  = "Timing-Allow-Origin"
	HeaderXPermittedCrossDomainPolicies      = "X-Permitted-Cross-Domain-Policies"
	HeaderDNT                                = "DNT"
	HeaderTk                                 = "Tk"
	HeaderContentDisposition                 = "Content-Disposition"
	HeaderContentEncoding                    = "Content-Encoding"
	HeaderContentLanguage                    = "Content-Language"
	HeaderContentLength                      = "Content-Length"
	HeaderContentLocation                    = "Content-Location"
	HeaderContentType                        = "Content-Type"
	HeaderForwarded                          = "Forwarded"
	HeaderVia                                = "Via"
	HeaderXForwardedFor                      = "X-Forwarded-For"
	HeaderXForwardedHost                     = "X-Forwarded-Host"
	HeaderXForwardedProto                    = "X-Forwarded-Proto"
	HeaderXForwardedProtocol                 = "X-Forwarded-Protocol"
	HeaderXForwardedSsl                      = "X-Forwarded-Ssl"
	HeaderXUrlScheme                         = "X-Url-Scheme"
	HeaderLocation                           = "Location"
	HeaderFrom                               = "From"
	HeaderHost                               = "Host"
	HeaderReferer                            = "Referer"
	HeaderReferrerPolicy                     = "Referrer-Policy"
	HeaderUserAgent                          = "User-Agent"
	HeaderAllow                              = "Allow"
	HeaderServer                             = "Server"
	HeaderAcceptRanges                       = "Accept-Ranges"
	HeaderContentRange                       = "Content-Range"
	HeaderIfRange                            = "If-Range"
	HeaderRange                              = "Range"
	HeaderContentSecurityPolicy              = "Content-Security-Policy"
	HeaderContentSecurityPolicyReportOnly    = "Content-Security-Policy-Report-Only"
	HeaderCrossOriginResourcePolicy          = "Cross-Origin-Resource-Policy"
	HeaderExpectCT                           = "Expect-CT"
	HeaderPermissionsPolicy                  = "Permissions-Policy"
	HeaderPublicKeyPins                      = "Public-Key-Pins"
	HeaderPublicKeyPinsReportOnly            = "Public-Key-Pins-Report-Only"
	HeaderStrictTransportSecurity            = "Strict-Transport-Security"
	HeaderUpgradeInsecureRequests            = "Upgrade-Insecure-Requests"
	HeaderXContentTypeOptions                = "X-Content-Type-Options"
	HeaderXDownloadOptions                   = "X-Download-Options"
	HeaderXFrameOptions                      = "X-Frame-Options"
	HeaderXPoweredBy                         = "X-Powered-By"
	HeaderXXSSProtection                     = "X-XSS-Protection"
	HeaderLastEventID                        = "Last-Event-ID"
	HeaderNEL                                = "NEL"
	HeaderPingFrom                           = "Ping-From"
	HeaderPingTo                             = "Ping-To"
	HeaderReportTo                           = "Report-To"
	HeaderTE                                 = "TE"
	HeaderTrailer                            = "Trailer"
	HeaderTransferEncoding                   = "Transfer-Encoding"
	HeaderSecFetchSite                       = "Sec-Fetch-Site"
	HeaderSecWebSocketAccept                 = "Sec-WebSocket-Accept"
	HeaderSecWebSocketExtensions             = "Sec-WebSocket-Extensions"
	HeaderSecWebSocketKey                    = "Sec-WebSocket-Key"
	HeaderSecWebSocketProtocol               = "Sec-WebSocket-Protocol"
	HeaderSecWebSocketVersion                = "Sec-WebSocket-Version"
	HeaderAcceptPatch                        = "Accept-Patch"
	HeaderAcceptPushPolicy                   = "Accept-Push-Policy"
	HeaderAcceptSignature                    = "Accept-Signature"
	HeaderAltSvc                             = "Alt-Svc"
	HeaderDate                               = "Date"
	HeaderIndex                              = "Index"
	HeaderLargeAllocation                    = "Large-Allocation"
	HeaderLink                               = "Link"
	HeaderPushPolicy                         = "Push-Policy"
	HeaderRetryAfter                         = "Retry-After"
	HeaderServerTiming                       = "Server-Timing"
	HeaderSignature                          = "Signature"
	HeaderSignedHeaders                      = "Signed-Headers"
	HeaderSourceMap                          = "SourceMap"
	HeaderUpgrade                            = "Upgrade"
	HeaderXDNSPrefetchControl                = "X-DNS-Prefetch-Control"
	HeaderXPingback                          = "X-Pingback"
	HeaderXRequestID                         = "X-Request-ID"
	HeaderXRequestedWith                     = "X-Requested-With"
	HeaderXResponseTime                      = "X-Response-Time"
	HeaderXRobotsTag                         = "X-Robots-Tag"
	HeaderXUACompatible                      = "X-UA-Compatible"
	HeaderAccessControlAllowPrivateNetwork   = "Access-Control-Allow-Private-Network"
	HeaderAccessControlRequestPrivateNetwork = "Access-Control-Request-Private-Network"
)

// Network types that are commonly used
const (
	NetworkTCP  = "tcp"
	NetworkTCP4 = "tcp4"
	NetworkTCP6 = "tcp6"
	NetworkUnix = "unix"
)

// Compression types
const (
	StrGzip     = "gzip"
	StrCompress = "compress"
	StrIdentity = "identity"
	StrBr       = "br"
	StrDeflate  = "deflate"
	StrBrotli   = "brotli"
	StrZstd     = "zstd"
)

// Cookie SameSite
// https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-rfc6265bis-03#section-4.1.2.7
const (
	CookieSameSiteDisabled   = "disabled" // not in RFC; selects "do not set the SameSite attribute at all"
	CookieSameSiteLaxMode    = "Lax"
	CookieSameSiteStrictMode = "Strict"
	CookieSameSiteNoneMode   = "None"
)

// Route Constraints
const (
	ConstraintInt             = "int"
	ConstraintBool            = "bool"
	ConstraintFloat           = "float"
	ConstraintAlpha           = "alpha"
	ConstraintGUID            = "guid"
	ConstraintMinLen          = "minLen"
	ConstraintMaxLen          = "maxLen"
	ConstraintLen             = "len"
	ConstraintBetweenLen      = "betweenLen"
	ConstraintMinLenLower     = "minlen"
	ConstraintMaxLenLower     = "maxlen"
	ConstraintBetweenLenLower = "betweenlen"
	ConstraintMin             = "min"
	ConstraintMax             = "max"
	ConstraintRange           = "range"
	ConstraintDatetime        = "datetime"
	ConstraintRegex           = "regex"
)

================================================
FILE: ctx.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io

package fiber

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"maps"
	"mime/multipart"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/gofiber/utils/v2"
	utilsbytes "github.com/gofiber/utils/v2/bytes"
	"github.com/valyala/bytebufferpool"
	"github.com/valyala/fasthttp"
)

const (
	schemeHTTP  = "http"
	schemeHTTPS = "https"
)

const (
	// maxParams defines the maximum number of parameters per route.
	// It also sizes the fixed DefaultCtx.values array.
	maxParams         = 30
	maxDetectionPaths = 3
)

var (
	_ io.Writer       = (*DefaultCtx)(nil) // Compile-time check
	_ context.Context = (*DefaultCtx)(nil) // Compile-time check
)

// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int // userContextKey define the key name for storing context.Context in *fasthttp.RequestCtx const ( userContextKey contextKey = iota // __local_user_context__ ) // DefaultCtx is the default implementation of the Ctx interface // generation tool `go install github.com/vburenin/ifacemaker@f30b6f9bdbed4b5c4804ec9ba4a04a999525c202` // https://github.com/vburenin/ifacemaker/blob/f30b6f9bdbed4b5c4804ec9ba4a04a999525c202/ifacemaker.go#L14-L31 // //go:generate ifacemaker --file ctx.go --file req.go --file res.go --struct DefaultCtx --iface Ctx --pkg fiber --promoted --output ctx_interface_gen.go --not-exported true --iface-comment "Ctx represents the Context which hold the HTTP request and response.\nIt has methods for the request query string, parameters, body, HTTP headers and so on." type DefaultCtx struct { handlerCtx CustomCtx // Active custom context implementation, if any DefaultReq // Default request api DefaultRes // Default response api app *App // Reference to *App route *Route // Reference to *Route fasthttp *fasthttp.RequestCtx // Reference to *fasthttp.RequestCtx bind *Bind // Default bind reference redirect *Redirect // Default redirect reference viewBindMap Map // Default view map to bind template engine values [maxParams]string // Route parameter values baseURI string // HTTP base uri pathOriginal string // Original HTTP path flashMessages redirectionMsgs // Flash messages path []byte // HTTP path with the modifications by the configuration detectionPath []byte // Route detection path treePathHash int // Hash of the path for the search in the tree indexRoute int // Index of the current route indexHandler int // Index of the current handler methodInt int // HTTP method INT equivalent abandoned atomic.Bool // If true, ctx won't be pooled until ForceRelease is called matched bool // Non use route matched skipNonUseRoutes bool // Skip non-use routes while iterating middleware } // TLSHandler hosts the callback hooks Fiber invokes while 
negotiating TLS // connections, including optional client certificate lookups. type TLSHandler struct { clientHelloInfo *tls.ClientHelloInfo } // GetClientInfo Callback function to set ClientHelloInfo // Must comply with the method structure of https://cs.opensource.google/go/go/+/refs/tags/go1.20:src/crypto/tls/common.go;l=554-563 // Since we overlay the method of the TLS config in the listener method func (t *TLSHandler) GetClientInfo(info *tls.ClientHelloInfo) (*tls.Certificate, error) { t.clientHelloInfo = info return nil, nil //nolint:nilnil // Not returning anything useful here is probably fine } // Views is the interface that wraps the Render function. type Views interface { Load() error Render(out io.Writer, name string, binding any, layout ...string) error } // App returns the *App reference to the instance of the Fiber application func (c *DefaultCtx) App() *App { return c.app } // BaseURL returns (protocol + host + base path). func (c *DefaultCtx) BaseURL() string { // TODO: Could be improved: 53.8 ns/op 32 B/op 1 allocs/op // Should work like https://codeigniter.com/user_guide/helpers/url_helper.html if c.baseURI != "" { return c.baseURI } c.baseURI = c.Scheme() + "://" + c.Host() return c.baseURI } // RequestCtx returns *fasthttp.RequestCtx that carries a deadline // a cancellation signal, and other values across API boundaries. func (c *DefaultCtx) RequestCtx() *fasthttp.RequestCtx { return c.fasthttp } // Context returns a context implementation that was set by // user earlier or returns a non-nil, empty context, if it was not set earlier. func (c *DefaultCtx) Context() context.Context { if c.fasthttp == nil { return context.Background() } if ctx, ok := c.fasthttp.UserValue(userContextKey).(context.Context); ok && ctx != nil { return ctx } ctx := context.Background() c.SetContext(ctx) return ctx } // SetContext sets a context implementation by user. 
func (c *DefaultCtx) SetContext(ctx context.Context) { if c.fasthttp == nil { return } c.fasthttp.SetUserValue(userContextKey, ctx) } // Deadline returns the time when work done on behalf of this context // should be canceled. Deadline returns ok==false when no deadline is // set. Successive calls to Deadline return the same results. // // Due to current limitations in how fasthttp works, Deadline operates as a nop. // See: https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945 func (*DefaultCtx) Deadline() (time.Time, bool) { return time.Time{}, false } // Done returns a channel that's closed when work done on behalf of this // context should be canceled. Done may return nil if this context can // never be canceled. Successive calls to Done return the same value. // The close of the Done channel may happen asynchronously, // after the cancel function returns. // // Due to current limitations in how fasthttp works, Done operates as a nop. // See: https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945 func (*DefaultCtx) Done() <-chan struct{} { return nil } // Err mirrors context.Err, returning nil until cancellation and then the terminal error value. // // Due to current limitations in how fasthttp works, Err operates as a nop. // See: https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945 func (*DefaultCtx) Err() error { return nil } // Request return the *fasthttp.Request object // This allows you to use all fasthttp request methods // https://godoc.org/github.com/valyala/fasthttp#Request // Returns nil if the context has been released. func (c *DefaultCtx) Request() *fasthttp.Request { if c.fasthttp == nil { return nil } return &c.fasthttp.Request } // Response return the *fasthttp.Response object // This allows you to use all fasthttp response methods // https://godoc.org/github.com/valyala/fasthttp#Response // Returns nil if the context has been released. 
func (c *DefaultCtx) Response() *fasthttp.Response { if c.fasthttp == nil { return nil } return &c.fasthttp.Response } // Get returns the HTTP request header specified by field. // Field names are case-insensitive // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (c *DefaultCtx) Get(key string, defaultValue ...string) string { return c.DefaultReq.Get(key, defaultValue...) } // GetHeaders returns the HTTP request headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (c *DefaultCtx) GetHeaders() map[string][]string { return c.DefaultReq.GetHeaders() } // GetReqHeaders returns the HTTP request headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (c *DefaultCtx) GetReqHeaders() map[string][]string { return c.DefaultReq.GetHeaders() } // GetRespHeader returns the HTTP response header specified by field. // Field names are case-insensitive // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (c *DefaultCtx) GetRespHeader(key string, defaultValue ...string) string { return c.DefaultRes.Get(key, defaultValue...) } // GetRespHeaders returns the HTTP response headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (c *DefaultCtx) GetRespHeaders() map[string][]string { return c.DefaultRes.GetHeaders() } // ClientHelloInfo return CHI from context func (c *DefaultCtx) ClientHelloInfo() *tls.ClientHelloInfo { if c.app.tlsHandler != nil { return c.app.tlsHandler.clientHelloInfo } return nil } // Next executes the next method in the stack that matches the current route. 
func (c *DefaultCtx) Next() error { // Increment handler index c.indexHandler++ // Did we execute all route handlers? if c.indexHandler < len(c.route.Handlers) { if c.handlerCtx != nil { return c.route.Handlers[c.indexHandler](c.handlerCtx) } return c.route.Handlers[c.indexHandler](c) } if c.handlerCtx != nil { _, err := c.app.nextCustom(c.handlerCtx) return err } _, err := c.app.next(c) return err } // RestartRouting instead of going to the next handler. This may be useful after // changing the request path. Note that handlers might be executed again. func (c *DefaultCtx) RestartRouting() error { c.indexRoute = -1 if c.handlerCtx != nil { _, err := c.app.nextCustom(c.handlerCtx) return err } _, err := c.app.next(c) return err } func (c *DefaultCtx) setHandlerCtx(ctx CustomCtx) { if ctx == nil { c.handlerCtx = nil return } if defaultCtx, ok := ctx.(*DefaultCtx); ok && defaultCtx == c { c.handlerCtx = nil return } c.handlerCtx = ctx } // OriginalURL contains the original request URL. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. func (c *DefaultCtx) OriginalURL() string { return c.app.toString(c.fasthttp.Request.Header.RequestURI()) } // Path returns the path part of the request URL. // Optionally, you could override the path. // Make copies or use the Immutable setting to use the value outside the Handler. func (c *DefaultCtx) Path(override ...string) string { if len(override) != 0 && string(c.path) != override[0] { // Set new path to context c.pathOriginal = override[0] // Set new path to request context c.fasthttp.Request.URI().SetPath(c.pathOriginal) // Prettify path c.configDependentPaths() } return c.app.toString(c.path) } // RequestID returns the request identifier from the response header or request header. 
func (c *DefaultCtx) RequestID() string { if requestID := c.GetRespHeader(HeaderXRequestID); requestID != "" { return requestID } return c.Get(HeaderXRequestID) } // Req returns a convenience type whose API is limited to operations // on the incoming request. func (c *DefaultCtx) Req() Req { return &c.DefaultReq } // Res returns a convenience type whose API is limited to operations // on the outgoing response. func (c *DefaultCtx) Res() Res { return &c.DefaultRes } // Redirect returns the Redirect reference. // Use Redirect().Status() to set custom redirection status code. // If status is not specified, status defaults to 303 See Other. // You can use Redirect().To(), Redirect().Route() and Redirect().Back() for redirection. func (c *DefaultCtx) Redirect() *Redirect { if c.redirect == nil { c.redirect = AcquireRedirect() c.redirect.c = c } return c.redirect } // ViewBind Add vars to default view var map binding to template engine. // Variables are read by the Render method and may be overwritten. func (c *DefaultCtx) ViewBind(vars Map) error { // init viewBindMap - lazy map if c.viewBindMap == nil { c.viewBindMap = make(Map, len(vars)) } maps.Copy(c.viewBindMap, vars) return nil } // Route returns the matched Route struct. func (c *DefaultCtx) Route() *Route { if c.route == nil { // Fallback for fasthttp error handler return &Route{ path: c.pathOriginal, Path: c.pathOriginal, Method: c.Method(), Handlers: make([]Handler, 0), Params: make([]string, 0), } } return c.route } // FullPath returns the matched route path, including any group prefixes. func (c *DefaultCtx) FullPath() string { return c.Route().Path } // Matched returns true if the current request path was matched by the router. func (c *DefaultCtx) Matched() bool { return c.getMatched() } // IsMiddleware returns true if the current request handler was registered as middleware. 
func (c *DefaultCtx) IsMiddleware() bool { if c.route == nil { return false } if c.route.use { return true } // For route-level middleware, there will be a next handler in the chain return c.indexHandler+1 < len(c.route.Handlers) } // HasBody returns true if the request declares a body via Content-Length, Transfer-Encoding, or already buffered payload data. func (c *DefaultCtx) HasBody() bool { hdr := &c.fasthttp.Request.Header //nolint:revive // switch is exhaustive for all ContentLength() cases switch cl := hdr.ContentLength(); { case cl > 0: return true case cl == -1: // fasthttp reports -1 for Transfer-Encoding: chunked bodies. return true case cl == 0: if hasTransferEncodingBody(hdr) { return true } } return len(c.fasthttp.Request.Body()) > 0 } // OverrideParam overwrites a route parameter value by name. // If the parameter name does not exist in the route, this method does nothing. func (c *DefaultCtx) OverrideParam(name, value string) { // If no route is matched, there are no parameters to update if !c.Matched() { return } // Normalize wildcard (*) and plus (+) tokens to their internal // representations (*1, +1) used by the router. 
if name == "*" || name == "+" { name += "1" } if c.app.config.CaseSensitive { for i, param := range c.route.Params { if param == name { c.values[i] = value return } } return } nameBytes := utils.UnsafeBytes(name) for i, param := range c.route.Params { if utils.EqualFold(utils.UnsafeBytes(param), nameBytes) { c.values[i] = value return } } } func hasTransferEncodingBody(hdr *fasthttp.RequestHeader) bool { teBytes := hdr.Peek(HeaderTransferEncoding) var te string if len(teBytes) > 0 { te = utils.UnsafeString(teBytes) } else { for key, value := range hdr.All() { if !utils.EqualFold(utils.UnsafeString(key), HeaderTransferEncoding) { continue } te = utils.UnsafeString(value) break } } if te == "" { return false } hasEncoding := false for raw := range strings.SplitSeq(te, ",") { token := utils.TrimSpace(raw) if token == "" { continue } if idx := strings.IndexByte(token, ';'); idx >= 0 { token = utils.TrimSpace(token[:idx]) } if token == "" { continue } if utils.EqualFold(token, "identity") { continue } hasEncoding = true } return hasEncoding } // IsWebSocket returns true if the request includes a WebSocket upgrade handshake. func (c *DefaultCtx) IsWebSocket() bool { conn := c.fasthttp.Request.Header.Peek(HeaderConnection) var isUpgrade bool for v := range strings.SplitSeq(utils.UnsafeString(conn), ",") { if utils.EqualFold(utils.TrimSpace(v), "upgrade") { isUpgrade = true break } } if !isUpgrade { return false } return utils.EqualFold(c.fasthttp.Request.Header.Peek(HeaderUpgrade), websocketBytes) } // IsPreflight returns true if the request is a CORS preflight. func (c *DefaultCtx) IsPreflight() bool { if c.Method() != MethodOptions { return false } hdr := &c.fasthttp.Request.Header if len(hdr.Peek(HeaderAccessControlRequestMethod)) == 0 { return false } return len(hdr.Peek(HeaderOrigin)) > 0 } // SaveFile saves any multipart file to disk. 
func (*DefaultCtx) SaveFile(fileheader *multipart.FileHeader, path string) error {
	return fasthttp.SaveMultipartFile(fileheader, path)
}

// SaveFileToStorage reads the uploaded file into memory and stores it in
// storage under path, passing 0 as the expiration argument. Files larger than
// the app's BodyLimit (DefaultBodyLimit when BodyLimit <= 0) are rejected with
// an error wrapping fasthttp.ErrBodyTooLarge.
func (c *DefaultCtx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error {
	file, err := fileheader.Open()
	if err != nil {
		return fmt.Errorf("failed to open: %w", err)
	}
	defer file.Close() //nolint:errcheck // not needed

	limit := c.app.config.BodyLimit
	if limit <= 0 {
		limit = DefaultBodyLimit
	}

	// Reject early when the declared size already exceeds the limit.
	if fileheader.Size > 0 && fileheader.Size > int64(limit) {
		return fmt.Errorf("failed to read: %w", fasthttp.ErrBodyTooLarge)
	}

	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)

	// Read at most one byte past the limit so an oversized body is detectable.
	if _, err = buf.ReadFrom(io.LimitReader(file, int64(limit)+1)); err != nil {
		return fmt.Errorf("failed to read: %w", err)
	}
	if buf.Len() > limit {
		return fmt.Errorf("failed to read: %w", fasthttp.ErrBodyTooLarge)
	}

	// Copy out of the pooled buffer, which is reused after this function returns.
	payload := append([]byte(nil), buf.Bytes()...)
	if err = storage.SetWithContext(c.Context(), path, payload, 0); err != nil {
		return fmt.Errorf("failed to store: %w", err)
	}
	return nil
}

// Secure reports whether the request protocol is HTTPS.
func (c *DefaultCtx) Secure() bool {
	return c.Protocol() == schemeHTTPS
}

// Status sets the HTTP status code of the response and returns the context so
// calls can be chained.
func (c *DefaultCtx) Status(status int) Ctx {
	c.fasthttp.Response.SetStatusCode(status)
	return c
}

// String returns unique string representation of the ctx.
//
// The returned value may be useful for logging.
func (c *DefaultCtx) String() string { // Get buffer from pool buf := bytebufferpool.Get() // Start with the ID, converting it to a hex string without fmt.Sprintf buf.WriteByte('#') // Convert ID to hexadecimal id := strconv.FormatUint(c.fasthttp.ID(), 16) // Pad with leading zeros to ensure 16 characters for i := 0; i < (16 - len(id)); i++ { buf.WriteByte('0') } buf.WriteString(id) buf.WriteString(" - ") // Add local and remote addresses directly buf.WriteString(c.fasthttp.LocalAddr().String()) buf.WriteString(" <-> ") buf.WriteString(c.fasthttp.RemoteAddr().String()) buf.WriteString(" - ") // Add method and URI buf.Write(c.fasthttp.Request.Header.Method()) buf.WriteByte(' ') buf.Write(c.fasthttp.URI().FullURI()) // Allocate string str := buf.String() // Reset buffer buf.Reset() bytebufferpool.Put(buf) return str } // Value makes it possible to retrieve values (Locals) under keys scoped to the request // and therefore available to all following routes that match the request. If the context // has been released and c.fasthttp is nil (for example, after ReleaseCtx), Value returns nil. func (c *DefaultCtx) Value(key any) any { if c.fasthttp == nil { return nil } return c.fasthttp.UserValue(key) } var ( // xmlHTTPRequestBytes is precomputed for XHR detection xmlHTTPRequestBytes = []byte("xmlhttprequest") // websocketBytes is precomputed for WebSocket upgrade detection websocketBytes = []byte("websocket") ) // XHR returns a Boolean property, that is true, if the request's X-Requested-With header field is XMLHttpRequest, // indicating that the request was issued by a client library (such as jQuery). 
func (c *DefaultCtx) XHR() bool {
	requestedWith := c.fasthttp.Request.Header.Peek(HeaderXRequestedWith)
	return utils.EqualFold(requestedWith, xmlHTTPRequestBytes)
}

// configDependentPaths derives c.path (the user-facing path) and
// c.detectionPath (the routing path) from c.pathOriginal, applying the
// UnescapePath, CaseSensitive and StrictRouting settings, and refreshes
// c.treePathHash.
func (c *DefaultCtx) configDependentPaths() {
	// User-facing path: the original path, optionally URL-decoded in place.
	c.path = append(c.path[:0], c.pathOriginal...)
	if c.app.config.UnescapePath {
		c.path = fasthttp.AppendUnquotedArg(c.path[:0], c.path)
	}

	// Routing path: a copy of the user-facing path, normalized for matching only.
	c.detectionPath = append(c.detectionPath[:0], c.path...)
	if !c.app.config.CaseSensitive {
		c.detectionPath = utilsbytes.UnsafeToLower(c.detectionPath)
	}
	if n := len(c.detectionPath); !c.app.config.StrictRouting && n > 1 && c.detectionPath[n-1] == '/' {
		c.detectionPath = utils.TrimRight(c.detectionPath, '/')
	}

	// The first maxDetectionPaths bytes pick the route-tree bucket, so fewer
	// routes have to be tried; shorter paths fall into bucket 0.
	hash := 0
	if len(c.detectionPath) >= maxDetectionPaths {
		hash = int(c.detectionPath[0])<<16 | int(c.detectionPath[1])<<8 | int(c.detectionPath[2])
	}
	c.treePathHash = hash
}

// Reset is a method to reset context fields by given request when to use server handlers.
func (c *DefaultCtx) Reset(fctx *fasthttp.RequestCtx) { // Reset route and handler index c.indexRoute = -1 c.indexHandler = 0 // Reset matched flag c.matched = false c.skipNonUseRoutes = false // Set paths c.pathOriginal = c.app.toString(fctx.URI().PathOriginal()) // Set method c.methodInt = c.app.methodInt(utils.UnsafeString(fctx.Request.Header.Method())) // Attach *fasthttp.RequestCtx to ctx c.fasthttp = fctx // reset base uri c.baseURI = "" // Prettify path c.configDependentPaths() c.DefaultReq.c = c c.DefaultRes.c = c c.fasthttp.SetUserValue(userContextKey, nil) } // release is a method to reset context fields when to use ReleaseCtx() func (c *DefaultCtx) release() { c.route = nil c.fasthttp = nil if c.bind != nil { ReleaseBind(c.bind) c.bind = nil } c.flashMessages = c.flashMessages[:0] // Clear viewBindMap by deleting all keys (reuse underlying map if possible) clear(c.viewBindMap) if c.redirect != nil { ReleaseRedirect(c.redirect) c.redirect = nil } c.skipNonUseRoutes = false // performance: no need for using c.abandoned.Store(false) here, as it is always set to false when it was true in ForceRelease c.handlerCtx = nil c.DefaultReq.release() c.DefaultRes.release() } // Abandon marks this context as abandoned. An abandoned context will not be // returned to the pool when ReleaseCtx is called. // // This is used by the timeout middleware to return immediately while the // handler goroutine continues using the context safely. // // Only call ForceRelease after Abandon if you can guarantee no other goroutine // (including Fiber's requestHandler and ErrorHandler) will touch the context. // The timeout middleware intentionally does NOT call ForceRelease to avoid // races, which means timed-out requests leak their contexts until a safe // reclamation strategy exists. func (c *DefaultCtx) Abandon() { c.abandoned.Store(true) } // IsAbandoned returns true if Abandon() was called on this context. 
func (c *DefaultCtx) IsAbandoned() bool {
	return c.abandoned.Load()
}

// ForceRelease clears the abandoned flag and releases the context back to the
// pool.
// This MUST only be called after all goroutines (including requestHandler and
// ErrorHandler) have completely finished using this context. Calling it while
// any goroutine is still running causes races.
//
// When this DefaultCtx is embedded in a custom context, the object that came
// from the pool (AcquireCtx) is the custom wrapper recorded in handlerCtx, not
// c. That wrapper is released instead, so its own release hook runs and the
// pool keeps holding the custom type.
func (c *DefaultCtx) ForceRelease() {
	c.abandoned.Store(false)
	if wrapper := c.handlerCtx; wrapper != nil {
		c.app.ReleaseCtx(wrapper)
		return
	}
	c.app.ReleaseCtx(c)
}

// renderExtensions adds the view bindings and, when PassLocalsToViews is
// enabled, the request locals to bind if it is a non-nil Map. Keys already
// present in bind are never overwritten. It also makes sure the app list keys
// are generated.
func (c *DefaultCtx) renderExtensions(bind any) {
	// A nil Map cannot be written to; skip it instead of panicking.
	if bindMap, ok := bind.(Map); ok && bindMap != nil {
		// Bind view map
		for key, value := range c.viewBindMap {
			if _, exists := bindMap[key]; !exists {
				bindMap[key] = value
			}
		}

		// Check if the PassLocalsToViews option is enabled (by default it is disabled)
		if c.app.config.PassLocalsToViews {
			c.fasthttp.VisitUserValues(func(key []byte, val any) {
				// Convert the key once and reuse it for the lookup and the store.
				name := c.app.toString(key)
				if _, exists := bindMap[name]; !exists {
					bindMap[name] = val
				}
			})
		}
	}

	if len(c.app.mountFields.appListKeys) == 0 {
		c.app.generateAppListKeys()
	}
}

// Bind returns the binder for this context, acquiring it on first use. It binds
// body, cookies, headers, etc. into maps, map slices and structs, with custom
// binding support and detailed binding options.
// Replacement of: BodyParser, ParamsParser, GetReqHeaders, GetRespHeaders, AllParams, QueryParser, ReqHeaderParser
func (c *DefaultCtx) Bind() *Bind {
	if c.bind == nil {
		c.bind = AcquireBind()
	}
	c.bind.ctx = c
	return c.bind
}

// Methods to use with next stack.
func (c *DefaultCtx) getMethodInt() int { return c.methodInt } func (c *DefaultCtx) getIndexRoute() int { return c.indexRoute } func (c *DefaultCtx) getTreePathHash() int { return c.treePathHash } func (c *DefaultCtx) getDetectionPath() string { return c.app.toString(c.detectionPath) } func (c *DefaultCtx) getValues() *[maxParams]string { return &c.values } func (c *DefaultCtx) getMatched() bool { return c.matched } func (c *DefaultCtx) getSkipNonUseRoutes() bool { return c.skipNonUseRoutes } func (c *DefaultCtx) setIndexHandler(handler int) { c.indexHandler = handler } func (c *DefaultCtx) setIndexRoute(route int) { c.indexRoute = route } func (c *DefaultCtx) setMatched(matched bool) { c.matched = matched } func (c *DefaultCtx) setSkipNonUseRoutes(skip bool) { c.skipNonUseRoutes = skip } func (c *DefaultCtx) setRoute(route *Route) { c.route = route } func (c *DefaultCtx) getPathOriginal() string { return c.pathOriginal } ================================================ FILE: ctx_interface.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 🤖 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io package fiber import ( "github.com/valyala/fasthttp" ) // CustomCtx extends Ctx with the additional methods required by Fiber's // internals and middleware helpers. type CustomCtx interface { Ctx // Reset is a method to reset context fields by given request when to use server handlers. Reset(fctx *fasthttp.RequestCtx) // release is called before returning the context to the pool. release() // Abandon marks the context as abandoned. An abandoned context will not be // returned to the pool when ReleaseCtx is called. This is used by the timeout // middleware to return immediately while the handler goroutine continues. // The cleanup goroutine must call ForceRelease when the handler finishes. 
Abandon() // IsAbandoned returns true if the context has been abandoned. IsAbandoned() bool // ForceRelease releases an abandoned context back to the pool. // Must only be called after the handler goroutine has completely finished. ForceRelease() // Methods to use with next stack. getMethodInt() int getIndexRoute() int getTreePathHash() int getDetectionPath() string getPathOriginal() string getValues() *[maxParams]string getMatched() bool getSkipNonUseRoutes() bool setIndexHandler(handler int) setIndexRoute(route int) setMatched(matched bool) setSkipNonUseRoutes(skip bool) setRoute(route *Route) } // NewDefaultCtx constructs the default context implementation bound to the // provided application. func NewDefaultCtx(app *App) *DefaultCtx { // return ctx ctx := &DefaultCtx{ // Set app reference app: app, } ctx.DefaultReq.c = ctx ctx.DefaultRes.c = ctx return ctx } // AcquireCtx retrieves a new Ctx from the pool. func (app *App) AcquireCtx(fctx *fasthttp.RequestCtx) CustomCtx { ctx, ok := app.pool.Get().(CustomCtx) if !ok { panic(errCustomCtxTypeAssertion) } if app.hasCustomCtx { if setter, ok := ctx.(interface{ setHandlerCtx(CustomCtx) }); ok { setter.setHandlerCtx(ctx) } } ctx.Reset(fctx) return ctx } // ReleaseCtx releases the ctx back into the pool. // If the context was abandoned (e.g., by timeout middleware), this is a no-op. // Call ForceRelease only when you can guarantee no goroutines (including the // requestHandler and ErrorHandler) still touch the context; the timeout // middleware intentionally leaves abandoned contexts unreleased to avoid races. func (app *App) ReleaseCtx(c CustomCtx) { if c.IsAbandoned() { return } c.release() app.pool.Put(c) } ================================================ FILE: ctx_interface_gen.go ================================================ // Code generated by ifacemaker; DO NOT EDIT. 
package fiber import ( "bufio" "context" "crypto/tls" "io" "mime/multipart" "time" "github.com/valyala/fasthttp" ) // Ctx represents the Context which hold the HTTP request and response. // It has methods for the request query string, parameters, body, HTTP headers and so on. type Ctx interface { // App returns the *App reference to the instance of the Fiber application App() *App // BaseURL returns (protocol + host + base path). BaseURL() string // RequestCtx returns *fasthttp.RequestCtx that carries a deadline // a cancellation signal, and other values across API boundaries. RequestCtx() *fasthttp.RequestCtx // Context returns a context implementation that was set by // user earlier or returns a non-nil, empty context, if it was not set earlier. Context() context.Context // SetContext sets a context implementation by user. SetContext(ctx context.Context) // Deadline returns the time when work done on behalf of this context // should be canceled. Deadline returns ok==false when no deadline is // set. Successive calls to Deadline return the same results. // // Due to current limitations in how fasthttp works, Deadline operates as a nop. // See: https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945 Deadline() (time.Time, bool) // Done returns a channel that's closed when work done on behalf of this // context should be canceled. Done may return nil if this context can // never be canceled. Successive calls to Done return the same value. // The close of the Done channel may happen asynchronously, // after the cancel function returns. // // Due to current limitations in how fasthttp works, Done operates as a nop. // See: https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945 Done() <-chan struct{} // Err mirrors context.Err, returning nil until cancellation and then the terminal error value. // // Due to current limitations in how fasthttp works, Err operates as a nop. 
// See: https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945 Err() error // Request return the *fasthttp.Request object // This allows you to use all fasthttp request methods // https://godoc.org/github.com/valyala/fasthttp#Request // Returns nil if the context has been released. Request() *fasthttp.Request // Response return the *fasthttp.Response object // This allows you to use all fasthttp response methods // https://godoc.org/github.com/valyala/fasthttp#Response // Returns nil if the context has been released. Response() *fasthttp.Response // Get returns the HTTP request header specified by field. // Field names are case-insensitive // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. Get(key string, defaultValue ...string) string // GetHeaders returns the HTTP request headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. GetHeaders() map[string][]string // GetReqHeaders returns the HTTP request headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. GetReqHeaders() map[string][]string // GetRespHeader returns the HTTP response header specified by field. // Field names are case-insensitive // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. GetRespHeader(key string, defaultValue ...string) string // GetRespHeaders returns the HTTP response headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. GetRespHeaders() map[string][]string // ClientHelloInfo return CHI from context ClientHelloInfo() *tls.ClientHelloInfo // Next executes the next method in the stack that matches the current route. 
Next() error // RestartRouting instead of going to the next handler. This may be useful after // changing the request path. Note that handlers might be executed again. RestartRouting() error setHandlerCtx(ctx CustomCtx) // OriginalURL contains the original request URL. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. OriginalURL() string // Path returns the path part of the request URL. // Optionally, you could override the path. // Make copies or use the Immutable setting to use the value outside the Handler. Path(override ...string) string // RequestID returns the request identifier from the response header or request header. RequestID() string // Req returns a convenience type whose API is limited to operations // on the incoming request. Req() Req // Res returns a convenience type whose API is limited to operations // on the outgoing response. Res() Res // Redirect returns the Redirect reference. // Use Redirect().Status() to set custom redirection status code. // If status is not specified, status defaults to 303 See Other. // You can use Redirect().To(), Redirect().Route() and Redirect().Back() for redirection. Redirect() *Redirect // ViewBind Add vars to default view var map binding to template engine. // Variables are read by the Render method and may be overwritten. ViewBind(vars Map) error // Route returns the matched Route struct. Route() *Route // FullPath returns the matched route path, including any group prefixes. FullPath() string // Matched returns true if the current request path was matched by the router. Matched() bool // IsMiddleware returns true if the current request handler was registered as middleware. IsMiddleware() bool // HasBody returns true if the request declares a body via Content-Length, Transfer-Encoding, or already buffered payload data. HasBody() bool // OverrideParam overwrites a route parameter value by name. 
// If the parameter name does not exist in the route, this method does nothing. OverrideParam(name, value string) // IsWebSocket returns true if the request includes a WebSocket upgrade handshake. IsWebSocket() bool // IsPreflight returns true if the request is a CORS preflight. IsPreflight() bool // SaveFile saves any multipart file to disk. SaveFile(fileheader *multipart.FileHeader, path string) error // SaveFileToStorage saves any multipart file to an external storage system. SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error // Secure returns whether a secure connection was established. Secure() bool // Status sets the HTTP status for the response. // This method is chainable. Status(status int) Ctx // String returns unique string representation of the ctx. // // The returned value may be useful for logging. String() string // Value makes it possible to retrieve values (Locals) under keys scoped to the request // and therefore available to all following routes that match the request. If the context // has been released and c.fasthttp is nil (for example, after ReleaseCtx), Value returns nil. Value(key any) any // XHR returns a Boolean property, that is true, if the request's X-Requested-With header field is XMLHttpRequest, // indicating that the request was issued by a client library (such as jQuery). XHR() bool // configDependentPaths set paths for route recognition and prepared paths for the user, // here the features for caseSensitive, decoded paths, strict paths are evaluated configDependentPaths() // Reset is a method to reset context fields by given request when to use server handlers. Reset(fctx *fasthttp.RequestCtx) // release is a method to reset context fields when to use ReleaseCtx() release() // Abandon marks this context as abandoned. An abandoned context will not be // returned to the pool when ReleaseCtx is called. 
// // This is used by the timeout middleware to return immediately while the // handler goroutine continues using the context safely. // // Only call ForceRelease after Abandon if you can guarantee no other goroutine // (including Fiber's requestHandler and ErrorHandler) will touch the context. // The timeout middleware intentionally does NOT call ForceRelease to avoid // races, which means timed-out requests leak their contexts until a safe // reclamation strategy exists. Abandon() // IsAbandoned returns true if Abandon() was called on this context. IsAbandoned() bool // ForceRelease releases an abandoned context back to the pool. // This MUST only be called after all goroutines (including requestHandler and // ErrorHandler) have completely finished using this context. Calling it while // any goroutine is still running causes races. ForceRelease() renderExtensions(bind any) // Bind You can bind body, cookie, headers etc. into the map, map slice, struct easily by using Binding method. // It gives custom binding support, detailed binding options and more. // Replacement of: BodyParser, ParamsParser, GetReqHeaders, GetRespHeaders, AllParams, QueryParser, ReqHeaderParser Bind() *Bind // Methods to use with next stack. getMethodInt() int getIndexRoute() int getTreePathHash() int getDetectionPath() string getValues() *[maxParams]string getMatched() bool getSkipNonUseRoutes() bool setIndexHandler(handler int) setIndexRoute(route int) setMatched(matched bool) setSkipNonUseRoutes(skip bool) setRoute(route *Route) getPathOriginal() string // FullURL returns the full request URL (protocol + host + original URL). FullURL() string // UserAgent returns the User-Agent request header. UserAgent() string // Referer returns the Referer request header. Referer() string // AcceptLanguage returns the Accept-Language request header. AcceptLanguage() string // AcceptEncoding returns the Accept-Encoding request header. 
AcceptEncoding() string // HasHeader reports whether the request includes a header with the given key. HasHeader(key string) bool // MediaType returns the MIME type from the Content-Type header without parameters. MediaType() string // Charset returns the charset parameter from the Content-Type header. Charset() string // IsJSON reports whether the Content-Type header is JSON. IsJSON() bool // IsForm reports whether the Content-Type header is form-encoded. IsForm() bool // IsMultipart reports whether the Content-Type header is multipart form data. IsMultipart() bool // AcceptsJSON reports whether the Accept header allows JSON. AcceptsJSON() bool // AcceptsHTML reports whether the Accept header allows HTML. AcceptsHTML() bool // AcceptsXML reports whether the Accept header allows XML. AcceptsXML() bool // AcceptsEventStream reports whether the Accept header allows text/event-stream. AcceptsEventStream() bool // Accepts checks if the specified extensions or content types are acceptable. Accepts(offers ...string) string // AcceptsCharsets checks if the specified charset is acceptable. AcceptsCharsets(offers ...string) string // AcceptsEncodings checks if the specified encoding is acceptable. AcceptsEncodings(offers ...string) string // AcceptsLanguages checks if the specified language is acceptable using // RFC 4647 Basic Filtering. AcceptsLanguages(offers ...string) string // AcceptsLanguagesExtended checks if the specified language is acceptable using // RFC 4647 Extended Filtering. AcceptsLanguagesExtended(offers ...string) string // BodyRaw contains the raw body submitted in a POST request. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. 
BodyRaw() []byte //nolint:nonamedreturns // gocritic unnamedResult prefers naming decoded body, decode count, and error tryDecodeBodyInOrder(originalBody *[]byte, encodings []string) (body []byte, decodesRealized uint8, err error) // Body contains the raw body submitted in a POST request. // This method will decompress the body if the 'Content-Encoding' header is provided. // It returns the original (or decompressed) body data which is valid only within the handler. // Don't store direct references to the returned data. // If you need to keep the body's data later, make a copy or use the Immutable option. Body() []byte // Cookies are used for getting a cookie value by key. // Defaults to the empty string "" if the cookie doesn't exist. // If a default value is given, it will return that value if the cookie doesn't exist. // The returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. Cookies(key string, defaultValue ...string) string // FormFile returns the first file by key from a MultipartForm. FormFile(key string) (*multipart.FileHeader, error) // FormValue returns the first value by key from a MultipartForm. // Search is performed in QueryArgs, PostArgs, MultipartForm and FormFile in this particular order. // Defaults to the empty string "" if the form value doesn't exist. // If a default value is given, it will return that value if the form value does not exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. FormValue(key string, defaultValue ...string) string // Fresh returns true when the response is still “fresh” in the client's cache, // otherwise false is returned to indicate that the client cache is now stale // and the full response should be sent. 
// When a client sends the Cache-Control: no-cache request header to indicate an end-to-end // reload request, this module will return false to make handling these requests transparent. // https://github.com/jshttp/fresh/blob/master/index.js#L33 Fresh() bool // Host contains the host derived from the X-Forwarded-Host or Host HTTP header. // Returned value is only valid within the handler. Do not store any references. // In a network context, `Host` refers to the combination of a hostname and potentially a port number used for connecting, // while `Hostname` refers specifically to the name assigned to a device on a network, excluding any port information. // Example: URL: https://example.com:8080 -> Host: example.com:8080 // Make copies or use the Immutable setting instead. // Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy. Host() string // Hostname contains the hostname derived from the X-Forwarded-Host or Host HTTP header using the c.Host() method. // Returned value is only valid within the handler. Do not store any references. // Example: URL: https://example.com:8080 -> Hostname: example.com // Make copies or use the Immutable setting instead. // Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy. Hostname() string // Port returns the remote port of the request. Port() string // IP returns the remote IP address of the request. // If ProxyHeader and IP Validation is configured, it will parse that header and return the first valid IP address. // Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy. IP() string // extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear. // When IP validation is enabled, any invalid IPs will be omitted. extractIPsFromHeader(header string) []string // extractIPFromHeader will attempt to pull the real client IP from the given header when IP validation is enabled. 
// currently, it will return the first valid IP address in header. // when IP validation is disabled, it will simply return the value of the header without any inspection. // Implementation is almost the same as in extractIPsFromHeader, but without allocation of []string. extractIPFromHeader(header string) string // IPs returns a string slice of IP addresses specified in the X-Forwarded-For request header. // When IP validation is enabled, only valid IPs are returned. IPs() []string // Is returns the matching content type, // if the incoming request's Content-Type HTTP header field matches the MIME type specified by the type parameter Is(extension string) bool // Locals makes it possible to pass any values under keys scoped to the request // and therefore available to all following routes that match the request. // // All the values are removed from ctx after returning from the top // RequestHandler. Additionally, Close method is called on each value // implementing io.Closer before removing the value from ctx. Locals(key any, value ...any) any // Method returns the HTTP request method for the context, optionally overridden by the provided argument. // If no override is given or if the provided override is not a valid HTTP method, it returns the current method from the context. // Otherwise, it updates the context's method and returns the overridden method as a string. Method(override ...string) string // MultipartForm parse form entries from binary. // This returns a map[string][]string, so given a key, the value will be a string slice. MultipartForm() (*multipart.Form, error) // Params is used to get the route parameters. // Defaults to empty string "" if the param doesn't exist. // If a default value is given, it will return that value if the param doesn't exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. 
Params(key string, defaultValue ...string) string // Scheme contains the request protocol string: http or https for TLS requests. // Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy. Scheme() string // Protocol returns the HTTP protocol of request: HTTP/1.1 and HTTP/2. Protocol() string // Query returns the query string parameter in the url. // Defaults to empty string "" if the query doesn't exist. // If a default value is given, it will return that value if the query doesn't exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. Query(key string, defaultValue ...string) string // Queries returns a map of query parameters and their values. // // GET /?name=alex&wanna_cake=2&id= // Queries()["name"] == "alex" // Queries()["wanna_cake"] == "2" // Queries()["id"] == "" // // GET /?field1=value1&field1=value2&field2=value3 // Queries()["field1"] == "value2" // Queries()["field2"] == "value3" // // GET /?list_a=1&list_a=2&list_a=3&list_b[]=1&list_b[]=2&list_b[]=3&list_c=1,2,3 // Queries()["list_a"] == "3" // Queries()["list_b[]"] == "3" // Queries()["list_c"] == "1,2,3" // // GET /api/search?filters.author.name=John&filters.category.name=Technology&filters[customer][name]=Alice&filters[status]=pending // Queries()["filters.author.name"] == "John" // Queries()["filters.category.name"] == "Technology" // Queries()["filters[customer][name]"] == "Alice" // Queries()["filters[status]"] == "pending" Queries() map[string]string // Range returns a struct containing the type and a slice of ranges. Range(size int64) (Range, error) // Subdomains returns a slice of subdomains from the host, excluding the last `offset` components. // If the offset is negative or exceeds the number of subdomains, an empty slice is returned. // If the offset is zero every label (no trimming) is returned. 
Subdomains(offset ...int) []string // Stale returns the inverse of Fresh, indicating if the client's cached response is considered stale. Stale() bool // IsProxyTrusted checks trustworthiness of remote ip. // If Config.TrustProxy false, it returns false. // IsProxyTrusted can check remote ip by proxy ranges and ip map. IsProxyTrusted() bool // IsFromLocal will return true if request came from local. IsFromLocal() bool getBody() []byte // Append the specified value to the HTTP response header field. // If the header is not already set, it creates the header with the specified value. Append(field string, values ...string) // Attachment sets the HTTP response Content-Disposition header field to attachment. Attachment(filename ...string) // ClearCookie expires a specific cookie by key on the client side. // If no key is provided it expires all cookies that came with the request. ClearCookie(key ...string) // Cookie sets a cookie by passing a cookie struct. Cookie(cookie *Cookie) // Download transfers the file from path as an attachment. // Typically, browsers will prompt the user for download. // By default, the Content-Disposition header filename= parameter is the filepath (this typically appears in the browser dialog). // Override this default with the filename parameter. Download(file string, filename ...string) error // Format performs content-negotiation on the Accept HTTP header. // It uses Accepts to select a proper format and calls the matching // user-provided handler function. // If no accepted format is found, and a format with MediaType "default" is given, // that default handler is called. If no format is found and no default is given, // StatusNotAcceptable is sent. Format(handlers ...ResFmt) error // AutoFormat performs content-negotiation on the Accept HTTP header. // It uses Accepts to select a proper format. // The supported content types are text/html, text/plain, application/json, application/xml, application/vnd.msgpack, and application/cbor. 
// For more flexible content negotiation, use Format. // If the header is not specified or there is no proper format, text/plain is used. AutoFormat(body any) error // JSON converts any interface or string to JSON. // Array and slice values encode as JSON arrays, // except that []byte encodes as a base64-encoded string, // and a nil slice encodes as the null JSON value. // If the ctype parameter is given, this method will set the // Content-Type header equal to ctype. If ctype is not given, // The Content-Type header will be set to application/json; charset=utf-8. JSON(data any, ctype ...string) error // MsgPack converts any interface or string to MessagePack encoded bytes. // If the ctype parameter is given, this method will set the // Content-Type header equal to ctype. If ctype is not given, // The Content-Type header will be set to application/vnd.msgpack. MsgPack(data any, ctype ...string) error // CBOR converts any interface or string to CBOR encoded bytes. // If the ctype parameter is given, this method will set the // Content-Type header equal to ctype. If ctype is not given, // The Content-Type header will be set to application/cbor. CBOR(data any, ctype ...string) error // JSONP sends a JSON response with JSONP support. // This method is identical to JSON, except that it opts-in to JSONP callback support. // By default, the callback name is simply callback. JSONP(data any, callback ...string) error // XML converts any interface or string to XML. // This method also sets the content header to application/xml; charset=utf-8. XML(data any) error // Links joins the links followed by the property to populate the response's Link HTTP header field. Links(link ...string) // Location sets the response Location HTTP header to the specified path parameter. 
Location(path string) // getLocationFromRoute get URL location from route using parameters getLocationFromRoute(route *Route, params Map) (string, error) // GetRouteURL generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831" GetRouteURL(routeName string, params Map) (string, error) // Render a template with data and sends a text/html response. // We support the following engines: https://github.com/gofiber/template Render(name string, bind any, layouts ...string) error // Send sets the HTTP response body without copying it. // From this point onward the body argument must not be changed. Send(body []byte) error // SendEarlyHints allows the server to hint to the browser what resources a page would need // so the browser can preload them while waiting for the server's full response. Only Link // headers already written to the response will be transmitted as Early Hints. // // This is a HTTP/2+ feature but all browsers will either understand it or safely ignore it. // // NOTE: Older HTTP/1.1 non-browser clients may face compatibility issues. // // See: https://developer.chrome.com/docs/web-platform/early-hints and // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Link#syntax SendEarlyHints(hints []string) error // SendFile transfers the file from the specified path. // By default, the file is not compressed. To enable compression, set SendFile.Compress to true. // The Content-Type response HTTP header field is set based on the file's extension. // If the file extension is missing or invalid, the Content-Type is detected from the file's format. SendFile(file string, config ...SendFile) error // SendStatus sets the HTTP status code and if the response body is empty, // it sets the correct status message in the body. SendStatus(status int) error // SendString sets the HTTP response body for string types. 
// This means no type assertion, recommended for faster performance SendString(body string) error // SendStream sets response body stream and optional body size. SendStream(stream io.Reader, size ...int) error // SendStreamWriter sets response body stream writer SendStreamWriter(streamWriter func(*bufio.Writer)) error // Set sets the response's HTTP header field to the specified key, value. Set(key, val string) setCanonical(key, val string) // Type sets the Content-Type HTTP header to the MIME type specified by the file extension. Type(extension string, charset ...string) Ctx // Vary adds the given header field to the Vary response header. // This will append the header, if not already listed; otherwise, leaves it listed in the current location. Vary(fields ...string) // Write appends p into response body. Write(p []byte) (int, error) // Writef appends f & a into response body writer. Writef(f string, a ...any) (int, error) // WriteString appends s to response body. WriteString(s string) (int, error) // Drop closes the underlying connection without sending any response headers or body. // This can be useful for silently terminating client connections, such as in DDoS mitigation // or when blocking access to sensitive endpoints. Drop() error // End immediately flushes the current response and closes the underlying connection. 
End() error } ================================================ FILE: ctx_test.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 🤖 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io package fiber import ( "bufio" "bytes" "compress/gzip" "compress/zlib" "context" "crypto/tls" "embed" "encoding/hex" "encoding/xml" "errors" "fmt" "io" "math" "mime/multipart" "net" "net/http" "net/http/httptest" "os" "path/filepath" "runtime" "strconv" "strings" "sync/atomic" "testing" "text/template" "time" "github.com/fxamacker/cbor/v2" "github.com/gofiber/utils/v2" "github.com/shamaton/msgpack/v3" "github.com/stretchr/testify/require" "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3/internal/storage/memory" ) const epsilon = 0.001 type testContextKey struct{} type testNetAddr struct { network string address string } func (t testNetAddr) Network() string { return t.network } func (t testNetAddr) String() string { return t.address } // go test -run Test_Ctx_Accepts func Test_Ctx_Accepts(t *testing.T) { t.Parallel() app := New(Config{ CBOREncoder: cbor.Marshal, CBORDecoder: cbor.Unmarshal, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderAccept, "text/html,application/xhtml+xml,application/xml;q=0.9") require.Empty(t, c.Accepts("")) require.Empty(t, c.Req().Accepts()) require.Equal(t, ".xml", c.Accepts(".xml")) require.Empty(t, c.Accepts(".john")) require.Equal(t, "application/xhtml+xml", c.Accepts("application/xml", "application/xml+rss", "application/yaml", "application/xhtml+xml"), "must use client-preferred mime type") c.Request().Header.Set(HeaderAccept, "application/json, text/plain, */*;q=0") require.Empty(t, c.Accepts("html"), "must treat */*;q=0 as not acceptable") c.Request().Header.Set(HeaderAccept, "text/*, application/json") require.Equal(t, "html", c.Accepts("html")) 
require.Equal(t, "text/html", c.Accepts("text/html")) require.Equal(t, "json", c.Req().Accepts("json", "text")) require.Equal(t, "application/json", c.Accepts("application/json")) require.Empty(t, c.Accepts("image/png")) require.Empty(t, c.Accepts("png")) c.Request().Header.Set(HeaderAccept, "text/html, application/json") require.Equal(t, "text/*", c.Req().Accepts("text/*")) c.Request().Header.Set(HeaderAccept, "*/*") require.Equal(t, "html", c.Accepts("html")) c.Request().Header.Del(HeaderAccept) require.Equal(t, "json", c.Accepts("json", "html")) require.Equal(t, "application/json", c.Accepts("application/json", "text/html")) } // go test -run Test_Ctx_AcceptsHelpers func Test_Ctx_AcceptsHelpers(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderAccept, "text/html,application/json;q=0.9") require.True(t, c.AcceptsHTML()) require.True(t, c.AcceptsJSON()) require.False(t, c.AcceptsXML()) require.False(t, c.AcceptsEventStream()) c.Request().Header.Set(HeaderAccept, "application/xml") require.True(t, c.AcceptsXML()) require.False(t, c.AcceptsJSON()) c.Request().Header.Set(HeaderAccept, "text/event-stream") require.True(t, c.AcceptsEventStream()) } // go test -run Test_Ctx_ContentTypeHelpers func Test_Ctx_ContentTypeHelpers(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.Empty(t, c.MediaType()) require.Empty(t, c.Charset()) require.False(t, c.IsJSON()) c.Request().Header.Set(HeaderContentType, "application/json; charset=utf-8") //nolint:testifylint // This is a MIME type string, not JSON payload. 
require.Equal(t, MIMEApplicationJSON, c.MediaType()) require.Equal(t, "utf-8", c.Charset()) require.True(t, c.IsJSON()) require.False(t, c.IsForm()) require.False(t, c.IsMultipart()) c.Request().Header.Set(HeaderContentType, "text/html ; charset=\"UTF-8\"") require.Equal(t, MIMETextHTML, c.MediaType()) require.Equal(t, "UTF-8", c.Charset()) c.Request().Header.Set(HeaderContentType, MIMEApplicationForm) require.Equal(t, MIMEApplicationForm, c.MediaType()) require.Empty(t, c.Charset()) require.True(t, c.IsForm()) require.False(t, c.IsMultipart()) c.Request().Header.Set(HeaderContentType, MIMEMultipartForm+"; boundary=abc123") require.Equal(t, MIMEMultipartForm, c.MediaType()) require.Empty(t, c.Charset()) require.True(t, c.IsMultipart()) require.False(t, c.IsForm()) } // go test -run Test_Ctx_Charset func Test_Ctx_Charset(t *testing.T) { t.Parallel() testCases := []struct { name string contentType string expected string }{ { name: "no_parameters", contentType: "text/plain", expected: "", }, { name: "trailing_semicolon", contentType: "text/plain;", expected: "", }, { name: "empty_param_before_charset", contentType: "text/plain; ; charset=utf-8", expected: "utf-8", }, { name: "charset_with_spaces", contentType: "text/plain; charset = utf-8", expected: "utf-8", }, { name: "charset_in_middle", contentType: "text/plain; foo=bar; charset=iso-8859-1; baz=qux", expected: "iso-8859-1", }, { name: "charset_quoted", contentType: "text/plain; charset=\"utf-8\"; foo=bar", expected: "utf-8", }, { name: "non_charset_only", contentType: "text/plain; foo=bar", expected: "", }, { name: "missing_equals", contentType: "text/plain; charset", expected: "", }, { name: "empty_charset_value", contentType: "text/plain; charset=", expected: "", }, { name: "case_insensitive_charset", contentType: "text/plain; chArSet=Shift_JIS", expected: "Shift_JIS", }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { t.Parallel() app := New() c := 
app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderContentType, testCase.contentType) require.Equal(t, testCase.expected, c.Charset()) }) } } // go test -run Test_Ctx_HeaderHelpers func Test_Ctx_HeaderHelpers(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.SetHost("example.com") c.Request().SetRequestURI("/search?q=fiber") require.Equal(t, "http://example.com/search?q=fiber", c.FullURL()) c.Request().Header.Set(HeaderUserAgent, "fiber-agent") c.Request().Header.Set(HeaderReferer, "https://example.com") c.Request().Header.Set(HeaderAcceptLanguage, "en-US,en;q=0.9") c.Request().Header.Set(HeaderAcceptEncoding, "gzip, br") c.Request().Header.Set("X-Trace-Id", "trace") require.True(t, c.HasHeader("X-Trace-Id")) require.True(t, c.HasHeader("x-trace-id")) require.Equal(t, "fiber-agent", c.UserAgent()) require.Equal(t, "https://example.com", c.Referer()) require.Equal(t, "en-US,en;q=0.9", c.AcceptLanguage()) require.Equal(t, "gzip, br", c.AcceptEncoding()) c.Request().Header.Set(HeaderXRequestID, "request-id") c.Response().Header.Set(HeaderXRequestID, "response-id") require.Equal(t, "response-id", c.RequestID()) c.Response().Header.Del(HeaderXRequestID) require.Equal(t, "request-id", c.RequestID()) c.Request().Header.Del("X-Trace-Id") require.False(t, c.HasHeader("X-Trace-Id")) } // go test -run Test_Ctx_TypedParsingDefaults func Test_Ctx_TypedParsingDefaults(t *testing.T) { t.Parallel() app := New() app.Get("/:id", func(c Ctx) error { require.Equal(t, 5, Query[int](c, "count", 1)) require.Equal(t, 9, Query[int](c, "missing", 9)) require.Equal(t, 3, Query[int](c, "bad", 3)) require.Equal(t, 42, Params[int](c, "id", 7)) require.Equal(t, 7, Params[int](c, "missing", 7)) require.Equal(t, 11, GetReqHeader[int](c, "X-Limit", 4)) require.Equal(t, 4, GetReqHeader[int](c, "X-Bad", 4)) return c.SendStatus(StatusOK) }) req := httptest.NewRequest(MethodGet, "/42?count=5&bad=oops", http.NoBody) 
req.Header.Set("X-Limit", "11") req.Header.Set("X-Bad", "oops") resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusOK, resp.StatusCode) } // go test -v -run=^$ -bench=Benchmark_Ctx_Accepts -benchmem -count=4 func Benchmark_Ctx_Accepts(b *testing.B) { app := New(Config{ CBOREncoder: cbor.Marshal, CBORDecoder: cbor.Unmarshal, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) acceptHeader := "text/html,application/xhtml+xml,application/xml;q=0.9" c.Request().Header.Set("Accept", acceptHeader) acceptValues := [][]string{ {".xml"}, {"json", "xml"}, {"application/json", "application/xml"}, } expectedResults := []string{".xml", "xml", "application/xml"} for i := range acceptValues { b.Run(fmt.Sprintf("run-%#v", acceptValues[i]), func(bb *testing.B) { var res string bb.ReportAllocs() for bb.Loop() { res = c.Accepts(acceptValues[i]...) } require.Equal(bb, expectedResults[i], res) }) } } type customCtx struct { DefaultCtx } func (c *customCtx) Params(key string, defaultValue ...string) string { //revive:disable-line:unused-parameter // We need defaultValue for some cases return "prefix_" + c.DefaultCtx.Params(key) } // go test -run Test_Ctx_CustomCtx func Test_Ctx_CustomCtx(t *testing.T) { t.Parallel() app := NewWithCustomCtx(func(app *App) CustomCtx { return &customCtx{ DefaultCtx: *NewDefaultCtx(app), } }) app.Get("/:id", func(c Ctx) error { return c.SendString(c.Params("id")) }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/v3", &bytes.Buffer{})) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() body, err := io.ReadAll(resp.Body) require.NoError(t, err, "io.ReadAll(resp.Body)") require.Len(t, body, len("prefix_v3")) require.Equal(t, "prefix_v3", string(body)) require.Equal(t, MIMETextPlainCharsetUTF8, resp.Header.Get(HeaderContentType)) require.Equal(t, int64(len(body)), resp.ContentLength) } func 
Test_Ctx_CustomCtx_WithMiddleware(t *testing.T) { t.Parallel() app := NewWithCustomCtx(func(app *App) CustomCtx { return &customCtx{ DefaultCtx: *NewDefaultCtx(app), } }) app.Use(func(c Ctx) error { _, ok := c.(*customCtx) require.True(t, ok) return c.Next() }) app.Get("/", func(c Ctx) error { custom, ok := c.(*customCtx) require.True(t, ok) return c.SendString(custom.Params("")) }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() body, err := io.ReadAll(resp.Body) require.NoError(t, err, "io.ReadAll(resp.Body)") require.Equal(t, "prefix_", string(body)) } // go test -run Test_Ctx_CustomCtx func Test_Ctx_CustomCtx_and_Method(t *testing.T) { t.Parallel() // Create app with custom request methods methods := append(DefaultMethods, "JOHN") //nolint:gocritic // We want a new slice here app := NewWithCustomCtx(func(app *App) CustomCtx { return &customCtx{ DefaultCtx: *NewDefaultCtx(app), } }, Config{ RequestMethods: methods, }) // Add route with custom method app.Add([]string{"JOHN"}, "/doe", testEmptyHandler) resp, err := app.Test(httptest.NewRequest("JOHN", "/doe", http.NoBody)) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusOK, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err, "io.ReadAll(resp.Body)") require.Empty(t, body) require.Empty(t, resp.Header.Get(HeaderContentType)) require.Equal(t, int64(0), resp.ContentLength) // Add a new method require.Panics(t, func() { app.Add([]string{"JANE"}, "/jane", testEmptyHandler) }) } // go test -run Test_Ctx_Accepts_EmptyAccept func Test_Ctx_Accepts_EmptyAccept(t *testing.T) { t.Parallel() app := New(Config{ CBOREncoder: cbor.Marshal, CBORDecoder: cbor.Unmarshal, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.Equal(t, ".forwarded", c.Accepts(".forwarded")) } // go test -run 
Test_Ctx_Accepts_Wildcard func Test_Ctx_Accepts_Wildcard(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderAccept, "*/*;q=0.9") require.Equal(t, "html", c.Accepts("html")) require.Equal(t, "foo", c.Accepts("foo")) require.Equal(t, ".bar", c.Accepts(".bar")) c.Request().Header.Set(HeaderAccept, "text/html,application/*;q=0.9") require.Equal(t, "xml", c.Accepts("xml")) } // go test -run Test_Ctx_Accepts_MultiHeader func Test_Ctx_Accepts_MultiHeader(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Add(HeaderAccept, "text/plain;q=0.5") c.Request().Header.Add(HeaderAccept, "application/json") require.Equal(t, "application/json", c.Accepts("text/plain", "application/json")) } // go test -run Test_Ctx_AcceptsCharsets func Test_Ctx_AcceptsCharsets(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderAcceptCharset, "utf-8, iso-8859-1;q=0.5") require.Equal(t, "utf-8", c.AcceptsCharsets("utf-8")) } // go test -run Test_Ctx_AcceptsCharsets_MultiHeader func Test_Ctx_AcceptsCharsets_MultiHeader(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Add(HeaderAcceptCharset, "utf-8;q=0.1") c.Request().Header.Add(HeaderAcceptCharset, "iso-8859-1") require.Equal(t, "iso-8859-1", c.AcceptsCharsets("utf-8", "iso-8859-1")) } // go test -v -run=^$ -bench=Benchmark_Ctx_AcceptsCharsets -benchmem -count=4 func Benchmark_Ctx_AcceptsCharsets(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed c.Request().Header.Set("Accept-Charset", "utf-8, iso-8859-1;q=0.5") var res string b.ReportAllocs() for b.Loop() { res = c.AcceptsCharsets("utf-8") } require.Equal(b, "utf-8", res) } // go test -run Test_Ctx_AcceptsEncodings func Test_Ctx_AcceptsEncodings(t *testing.T) { t.Parallel() app := 
New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderAcceptEncoding, "deflate, gzip;q=1.0, *;q=0.5") require.Equal(t, "gzip", c.AcceptsEncodings("gzip")) require.Equal(t, "abc", c.AcceptsEncodings("abc")) } // go test -run Test_Ctx_AcceptsEncodings_MultiHeader func Test_Ctx_AcceptsEncodings_MultiHeader(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Add(HeaderAcceptEncoding, "deflate;q=0.3") c.Request().Header.Add(HeaderAcceptEncoding, "gzip") require.Equal(t, "gzip", c.AcceptsEncodings("deflate", "gzip")) } // go test -v -run=^$ -bench=Benchmark_Ctx_AcceptsEncodings -benchmem -count=4 func Benchmark_Ctx_AcceptsEncodings(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed c.Request().Header.Set(HeaderAcceptEncoding, "deflate, gzip;q=1.0, *;q=0.5") var res string b.ReportAllocs() for b.Loop() { res = c.AcceptsEncodings("gzip") } require.Equal(b, "gzip", res) } // go test -run Test_Ctx_AcceptsLanguages func Test_Ctx_AcceptsLanguages(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderAcceptLanguage, "fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5") require.Equal(t, "fr", c.AcceptsLanguages("fr")) } // go test -run Test_Ctx_AcceptsLanguages_MultiHeader func Test_Ctx_AcceptsLanguages_MultiHeader(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Add(HeaderAcceptLanguage, "de;q=0.4") c.Request().Header.Add(HeaderAcceptLanguage, "en") require.Equal(t, "en", c.AcceptsLanguages("de", "en")) } // go test -run Test_Ctx_AcceptsLanguages_BasicFiltering func Test_Ctx_AcceptsLanguages_BasicFiltering(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderAcceptLanguage, "en-US") require.Equal(t, "en-US", c.AcceptsLanguages("en", "en-US")) 
require.Empty(t, c.AcceptsLanguages("en")) c.Request().Header.Set(HeaderAcceptLanguage, "en-US, fr") require.Equal(t, "en-US", c.AcceptsLanguages("de", "en-US", "fr")) c.Request().Header.Set(HeaderAcceptLanguage, "en") require.Equal(t, "en-US", c.AcceptsLanguages("en-US")) c.Request().Header.Set(HeaderAcceptLanguage, "*") require.Equal(t, "en", c.AcceptsLanguages("en", "fr")) c.Request().Header.Set(HeaderAcceptLanguage, "en_US") require.Empty(t, c.AcceptsLanguages("en-US")) c.Request().Header.Set(HeaderAcceptLanguage, "en-*") require.Empty(t, c.AcceptsLanguages("en-US")) } // go test -run Test_Ctx_AcceptsLanguages_CaseInsensitive func Test_Ctx_AcceptsLanguages_CaseInsensitive(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderAcceptLanguage, "EN-us") require.Equal(t, "en-US", c.AcceptsLanguages("en-US")) } // go test -run Test_Ctx_AcceptsLanguagesExtended func Test_Ctx_AcceptsLanguagesExtended(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderAcceptLanguage, "en-*") require.Equal(t, "en-US", c.AcceptsLanguagesExtended("en-US")) c.Request().Header.Set(HeaderAcceptLanguage, "*-US") require.Equal(t, "en-US", c.AcceptsLanguagesExtended("en-US", "fr-CA")) c.Request().Header.Set(HeaderAcceptLanguage, "en-US-*") require.Equal(t, "en-US", c.AcceptsLanguagesExtended("en-US")) c.Request().Header.Set(HeaderAcceptLanguage, "en") require.Equal(t, "en-US", c.AcceptsLanguagesExtended("en-US")) c.Request().Header.Set(HeaderAcceptLanguage, "*") require.Equal(t, "en-US", c.AcceptsLanguagesExtended("en-US", "fr-CA")) c.Request().Header.Set(HeaderAcceptLanguage, "en-US") require.Equal(t, "en-US-CA", c.AcceptsLanguagesExtended("en-US-CA")) c.Request().Header.Set(HeaderAcceptLanguage, "en-*") require.Equal(t, "en-US-CA", c.AcceptsLanguagesExtended("en-US-CA")) } // go test -v -run=^$ -bench=Benchmark_Ctx_AcceptsLanguages -benchmem -count=4 func 
Benchmark_Ctx_AcceptsLanguages(b *testing.B) {
	// Measures AcceptsLanguages("fr") against a multi-entry Accept-Language header.
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
	c.Request().Header.Set(HeaderAcceptLanguage, "fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5")
	var res string
	b.ReportAllocs()
	for b.Loop() {
		res = c.AcceptsLanguages("fr")
	}
	require.Equal(b, "fr", res)
}

// go test -run Test_Ctx_App
// Test_Ctx_App checks that c.App() exposes the App whose config was modified
// before the context was acquired.
func Test_Ctx_App(t *testing.T) {
	t.Parallel()
	app := New()
	app.config.BodyLimit = 1000
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	require.Equal(t, 1000, c.App().config.BodyLimit)
}

// go test -run Test_Ctx_Append
// Test_Ctx_Append checks that Append skips values already present in the
// header, even when another stored value contains them as a substring or
// shares a prefix/suffix with them, and that calling Append without values
// leaves the header empty.
func Test_Ctx_Append(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Append("X-Test", "Hello")
	c.Append("X-Test", "World")
	c.Append("X-Test", "Hello", "World")
	// similar value in the middle
	c.Append("X2-Test", "World")
	c.Append("X2-Test", "XHello")
	c.Append("X2-Test", "Hello", "World")
	// similar value at the start
	c.Append("X3-Test", "XHello")
	c.Append("X3-Test", "World")
	c.Append("X3-Test", "Hello", "World")
	// try it with multiple similar values
	c.Append("X4-Test", "XHello")
	c.Append("X4-Test", "Hello")
	c.Append("X4-Test", "HelloZ")
	c.Append("X4-Test", "YHello")
	c.Append("X4-Test", "Hello")
	c.Append("X4-Test", "YHello")
	c.Append("X4-Test", "HelloZ")
	c.Append("X4-Test", "XHello")
	// without append value
	c.Append("X-Custom-Header")

	require.Equal(t, "Hello, World", string(c.Response().Header.Peek("X-Test")))
	require.Equal(t, "World, XHello, Hello", string(c.Response().Header.Peek("X2-Test")))
	require.Equal(t, "XHello, World, Hello", string(c.Response().Header.Peek("X3-Test")))
	require.Equal(t, "XHello, Hello, HelloZ, YHello", string(c.Response().Header.Peek("X4-Test")))
	require.Empty(t, string(c.Response().Header.Peek("x-custom-header")))
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Append -benchmem -count=4
// Benchmark_Ctx_Append repeatedly appends duplicate values; the header must
// stay at "Hello, World".
func Benchmark_Ctx_Append(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	b.ReportAllocs()
	for b.Loop() {
		c.Append("X-Custom-Header", "Hello")
		c.Append("X-Custom-Header", "World")
		c.Append("X-Custom-Header", "Hello")
	}
	require.Equal(b, "Hello, World", app.toString(c.Response().Header.Peek("X-Custom-Header")))
}

// go test -run Test_Ctx_Attachment
// Test_Ctx_Attachment checks the Content-Disposition (and Content-Type) headers
// Attachment produces for: no filename, a relative path, spaces, nested paths,
// quote/CRLF injection attempts, and non-ASCII names (which also get a
// filename* parameter).
func Test_Ctx_Attachment(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	// empty
	c.Attachment()
	require.Equal(t, `attachment`, string(c.Response().Header.Peek(HeaderContentDisposition)))

	// real filename
	c.Attachment("./static/img/logo.png")
	require.Equal(t, `attachment; filename="logo.png"`, string(c.Response().Header.Peek(HeaderContentDisposition)))
	require.Equal(t, "image/png", string(c.Response().Header.Peek(HeaderContentType)))

	// filename with spaces
	c.Attachment("report 2024.txt")
	require.Equal(t, `attachment; filename="report+2024.txt"`, string(c.Response().Header.Peek(HeaderContentDisposition)))

	// filename with nested path
	c.Attachment("../docs/archive.tar.gz")
	require.Equal(t, `attachment; filename="archive.tar.gz"`, string(c.Response().Header.Peek(HeaderContentDisposition)))

	// check quoting
	c.Attachment("another document.pdf\"\r\nBla: \"fasel")
	require.Equal(t, `attachment; filename="another+document.pdf%22Bla%3A+%22fasel"`, string(c.Response().Header.Peek(HeaderContentDisposition)))

	c.Attachment("файл.txt")
	header := string(c.Response().Header.Peek(HeaderContentDisposition))
	require.Contains(t, header, `filename="файл.txt"`)
	require.Contains(t, header, `filename*=UTF-8''%D1%84%D0%B0%D0%B9%D0%BB.txt`)
}

// go test -run Test_Ctx_Attachment_SanitizesFilenameControls
// Test_Ctx_Attachment_SanitizesFilenameControls checks that \r, \n, \t and NUL
// are stripped from the filename, that only the base name is kept, and that a
// name that is empty (or ".") after sanitizing falls back to "download".
func Test_Ctx_Attachment_SanitizesFilenameControls(t *testing.T) {
	t.Parallel()
	app := New()

	testCases := []struct {
		name     string
		filename string
		expected string
	}{
		{
			name:     "base name only",
			filename: "../docs/archive.tar.gz",
			expected: `attachment; filename="archive.tar.gz"`,
		},
		{
			name:     "controls stripped",
			filename: "down\r\nload\t\x00.txt",
			expected: `attachment; filename="download.txt"`,
		},
		{
			name:     "controls stripped without extension",
			filename: "report\r\n\t\x00",
			expected: `attachment; filename="report"`,
		},
		{
			name:     "empty after sanitize",
			filename: "\r\n\t\x00",
			expected: `attachment; filename="download"`,
		},
		{
			name:     "controls stripped in middle",
			filename: "file\rname\n\t\x00.bin",
			expected: `attachment; filename="filename.bin"`,
		},
		{
			name:     "dot fallback",
			filename: ".",
			expected: `attachment; filename="download"`,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			// Each subtest acquires its own context from the shared app.
			c := app.AcquireCtx(&fasthttp.RequestCtx{})
			c.Attachment(tc.filename)
			header := string(c.Response().Header.Peek(HeaderContentDisposition))
			require.Equal(t, tc.expected, header)
			require.NotContains(t, header, "\r")
			require.NotContains(t, header, "\n")
			require.NotContains(t, header, "\t")
			require.NotContains(t, header, "\x00")
		})
	}
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Attachment -benchmem -count=4
func Benchmark_Ctx_Attachment(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	b.ReportAllocs()
	for b.Loop() {
		// example with quote params
		c.Attachment("another document.pdf\"\r\nBla: \"fasel")
	}
	require.Equal(b, `attachment; filename="another+document.pdf%22Bla%3A+%22fasel"`, string(c.Response().Header.Peek(HeaderContentDisposition)))
}

// go test -run Test_Ctx_BaseURL
// Test_Ctx_BaseURL checks that BaseURL returns scheme and host only; the
// second call exercises the cached value.
func Test_Ctx_BaseURL(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().SetRequestURI("http://google.com/test")
	require.Equal(t, "http://google.com", c.BaseURL())
	// Check cache
	require.Equal(t, "http://google.com", c.BaseURL())
}

// go test -v -run=^$ -bench=Benchmark_Ctx_BaseURL -benchmem -count=4
func Benchmark_Ctx_BaseURL(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
	c.Request().SetHost("google.com:1337")
	c.Request().URI().SetPath("/haha/oke/lol")
	var res string
	b.ReportAllocs()
	for b.Loop() {
		res = c.BaseURL()
	}
	require.Equal(b, "http://google.com:1337", res)
}

// go test -run Test_Ctx_Body
// Test_Ctx_Body checks that an uncompressed request body is returned as-is.
func Test_Ctx_Body(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	c.Request().SetBody([]byte("john=doe"))
	require.Equal(t, []byte("john=doe"), c.Body())
}

// go test -run Test_Ctx_BodyRaw
// Test_Ctx_BodyRaw checks that BodyRaw returns the raw request body.
func Test_Ctx_BodyRaw(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	c.Request().SetBodyRaw([]byte("john=doe"))
	require.Equal(t, []byte("john=doe"), c.BodyRaw())
}

// go test -run Test_Ctx_BodyRaw_Immutable
// Test_Ctx_BodyRaw_Immutable repeats Test_Ctx_BodyRaw with Immutable set
// through the Config passed to New.
func Test_Ctx_BodyRaw_Immutable(t *testing.T) {
	t.Parallel()
	app := New(Config{Immutable: true})
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	c.Request().SetBodyRaw([]byte("john=doe"))
	require.Equal(t, []byte("john=doe"), c.BodyRaw())
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Body -benchmem -count=4
func Benchmark_Ctx_Body(b *testing.B) {
	const input = "john=doe"
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	c.Request().SetBody([]byte(input))
	b.ReportAllocs()
	for b.Loop() {
		_ = c.Body()
	}
	require.Equal(b, []byte(input), c.Body())
}

// go test -v -run=^$ -bench=Benchmark_Ctx_BodyRaw -benchmem -count=4
func Benchmark_Ctx_BodyRaw(b *testing.B) {
	const input = "john=doe"
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	c.Request().SetBodyRaw([]byte(input))
	b.ReportAllocs()
	for b.Loop() {
		_ = c.BodyRaw()
	}
	require.Equal(b, []byte(input), c.BodyRaw())
}

// go test -v -run=^$ -bench=Benchmark_Ctx_BodyRaw_Immutable -benchmem -count=4
func
Benchmark_Ctx_BodyRaw_Immutable(b *testing.B) {
	// Same as Benchmark_Ctx_BodyRaw, with Immutable set through the Config.
	const input = "john=doe"
	app := New(Config{Immutable: true})
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	c.Request().SetBodyRaw([]byte(input))
	b.ReportAllocs()
	for b.Loop() {
		_ = c.BodyRaw()
	}
	require.Equal(b, []byte(input), c.BodyRaw())
}

// go test -run Test_Ctx_Body_Immutable
// Test_Ctx_Body_Immutable checks Body() when Immutable is switched on after New().
func Test_Ctx_Body_Immutable(t *testing.T) {
	t.Parallel()
	app := New()
	app.config.Immutable = true
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	c.Request().SetBody([]byte("john=doe"))
	require.Equal(t, []byte("john=doe"), c.Body())
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Body_Immutable -benchmem -count=4
func Benchmark_Ctx_Body_Immutable(b *testing.B) {
	const input = "john=doe"
	app := New()
	app.config.Immutable = true
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	c.Request().SetBody([]byte(input))
	b.ReportAllocs()
	for b.Loop() {
		_ = c.Body()
	}
	require.Equal(b, []byte(input), c.Body())
}

// go test -run Test_Ctx_Body_With_Compression
func Test_Ctx_Body_With_Compression(t *testing.T) {
	t.Parallel()
	testCtxBodyWithCompression(t, false)
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Body_With_Compression -benchmem -count=4
func Benchmark_Ctx_Body_With_Compression(b *testing.B) {
	benchmarkCtxBodyWithCompression(b, false)
}

// go test -run Test_Ctx_Body_With_Compression_Immutable
func Test_Ctx_Body_With_Compression_Immutable(t *testing.T) {
	t.Parallel()
	testCtxBodyWithCompression(t, true)
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Body_With_Compression_Immutable -benchmem -count=4
func Benchmark_Ctx_Body_With_Compression_Immutable(b *testing.B) {
	benchmarkCtxBodyWithCompression(b, true)
}

// testCtxBodyWithCompression runs the Content-Encoding table shared by
// Test_Ctx_Body_With_Compression and Test_Ctx_Body_With_Compression_Immutable.
// When immutable is true, app.config.Immutable is switched on after New(), the
// same way Test_Ctx_Body_Immutable configures its app. Each case compresses its
// body according to the listed encodings, calls Body(), and checks the decoded
// result, the response status code, and that the raw request body is untouched.
func testCtxBodyWithCompression(t *testing.T, immutable bool) {
	t.Helper()

	tests := []struct {
		name            string
		contentEncoding string
		body            []byte
		expectedBody    []byte
	}{
		{
			name:            "gzip",
			contentEncoding: "gzip",
			body:            []byte("john=doe"),
			expectedBody:    []byte("john=doe"),
		},
		{
			name:            "gzip twice",
			contentEncoding: "gzip, gzip",
			body:            []byte("double"),
			expectedBody:    []byte("double"),
		},
		{
			name:            "unsupported_encoding",
			contentEncoding: "undefined",
			body:            []byte("keeps_ORIGINAL"),
			expectedBody:    []byte("Unsupported Media Type"),
		},
		{
			name:            "compress_not_implemented",
			contentEncoding: "compress",
			body:            []byte("foo"),
			expectedBody:    []byte("Not Implemented"),
		},
		{
			name:            "gzip then unsupported",
			contentEncoding: "gzip, undefined",
			body:            []byte("Go, be gzipped"),
			expectedBody:    []byte("Unsupported Media Type"),
		},
		{
			name:            "invalid_deflate",
			contentEncoding: "gzip,deflate",
			body:            []byte("I'm not correctly compressed"),
			expectedBody:    []byte(zlib.ErrHeader.Error()),
		},
		{
			name:            "identity",
			contentEncoding: "identity",
			body:            []byte("bar"),
			expectedBody:    []byte("bar"),
		},
	}

	// tCase is a fresh variable per iteration (Go 1.22+ semantics; this file
	// already relies on Go 1.24 APIs such as b.Loop and strings.SplitSeq), so
	// mutating tCase.body inside a parallel subtest cannot leak across cases.
	for _, tCase := range tests {
		t.Run(tCase.name, func(t *testing.T) {
			t.Parallel()
			app := New()
			if immutable {
				app.config.Immutable = true
			}
			c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed
			c.Request().Header.Set("Content-Encoding", tCase.contentEncoding)

			encs := strings.SplitSeq(tCase.contentEncoding, ",")
			for enc := range encs {
				enc = utils.TrimSpace(enc)
				// For invalid_deflate the deflate layer is deliberately never
				// applied, so decoding must fail with zlib.ErrHeader.
				if strings.Contains(tCase.name, "invalid_deflate") && enc == StrDeflate {
					continue
				}
				switch enc {
				case "gzip":
					var b bytes.Buffer
					gz := gzip.NewWriter(&b)
					_, err := gz.Write(tCase.body)
					require.NoError(t, err)
					require.NoError(t, gz.Flush())
					require.NoError(t, gz.Close())
					tCase.body = b.Bytes()
				case StrDeflate:
					var b bytes.Buffer
					fl := zlib.NewWriter(&b)
					_, err := fl.Write(tCase.body)
					require.NoError(t, err)
					require.NoError(t, fl.Flush())
					require.NoError(t, fl.Close())
					tCase.body = b.Bytes()
				default:
					// we do nothing and expect the original body to be returned
				}
			}

			c.Request().SetBody(tCase.body)
			body := c.Body()
			require.Equal(t, tCase.expectedBody, body)
			switch {
			case strings.Contains(tCase.name, "unsupported"):
				require.Equal(t, StatusUnsupportedMediaType, c.Response().StatusCode())
			case strings.Contains(tCase.name, "compress_not_implemented"):
				require.Equal(t, StatusNotImplemented, c.Response().StatusCode())
			default:
				require.Equal(t, StatusOK, c.Response().StatusCode())
			}

			// Check if body raw is the same as previous before decompression
			require.Equal(
				t, tCase.body, c.Request().Body(),
				"Body raw must be the same as set before",
			)
		})
	}
}

// benchmarkCtxBodyWithCompression runs the compressed-body benchmarks shared by
// Benchmark_Ctx_Body_With_Compression and its _Immutable variant. When
// immutable is true, app.config.Immutable is switched on after New() in every
// sub-benchmark. Each sub-benchmark is named after its Content-Encoding value
// and verifies the decoded body (or error text) once the loop finishes.
func benchmarkCtxBodyWithCompression(b *testing.B, immutable bool) {
	b.Helper()

	encodingErr := errors.New("failed to encoding data")
	var (
		compressGzip = func(data []byte) ([]byte, error) {
			var buf bytes.Buffer
			writer := gzip.NewWriter(&buf)
			if _, err := writer.Write(data); err != nil {
				return nil, encodingErr
			}
			if err := writer.Flush(); err != nil {
				return nil, encodingErr
			}
			if err := writer.Close(); err != nil {
				return nil, encodingErr
			}
			return buf.Bytes(), nil
		}
		compressDeflate = func(data []byte) ([]byte, error) {
			var buf bytes.Buffer
			writer := zlib.NewWriter(&buf)
			if _, err := writer.Write(data); err != nil {
				return nil, encodingErr
			}
			if err := writer.Flush(); err != nil {
				return nil, encodingErr
			}
			if err := writer.Close(); err != nil {
				return nil, encodingErr
			}
			return buf.Bytes(), nil
		}
	)
	const input = "john=doe"
	compressionTests := []struct {
		compressWriter  func([]byte) ([]byte, error)
		contentEncoding string
		expectedBody    []byte
	}{
		{
			contentEncoding: "gzip",
			compressWriter:  compressGzip,
			expectedBody:    []byte(input),
		},
		{
			contentEncoding: "gzip,invalid",
			compressWriter:  compressGzip,
			expectedBody:    []byte(ErrUnsupportedMediaType.Error()),
		},
		{
			contentEncoding: StrDeflate,
			compressWriter:  compressDeflate,
			expectedBody:    []byte(input),
		},
		{
			contentEncoding: "gzip,deflate",
			// Applies deflate first and gzip second, i.e. the reverse of the
			// header order; the expected result is zlib.ErrHeader.
			compressWriter: func(data []byte) ([]byte, error) {
				var (
					buf    bytes.Buffer
					writer interface {
						io.WriteCloser
						Flush() error
					}
					err error
				)

				// deflate
				writer = zlib.NewWriter(&buf)
				if _, err = writer.Write(data); err != nil {
					return nil, encodingErr
				}
				if err = writer.Flush(); err != nil {
					return nil, encodingErr
				}
				if err = writer.Close(); err != nil {
					return nil, encodingErr
				}

				data = make([]byte, buf.Len())
				copy(data, buf.Bytes())
				buf.Reset()

				// gzip
				writer = gzip.NewWriter(&buf)
				if _, err = writer.Write(data); err != nil {
					return nil, encodingErr
				}
				if err = writer.Flush(); err != nil {
					return nil, encodingErr
				}
				if err = writer.Close(); err != nil {
					return nil, encodingErr
				}

				return buf.Bytes(), nil
			},
			expectedBody: []byte(zlib.ErrHeader.Error()),
		},
	}

	b.ReportAllocs()
	for _, ct := range compressionTests {
		b.Run(ct.contentEncoding, func(b *testing.B) {
			app := New()
			if immutable {
				app.config.Immutable = true
			}
			c := app.AcquireCtx(&fasthttp.RequestCtx{})
			c.Request().Header.Set("Content-Encoding", ct.contentEncoding)
			compressedBody, err := ct.compressWriter([]byte(input))
			require.NoError(b, err)

			c.Request().SetBody(compressedBody)
			for b.Loop() {
				_ = c.Body()
			}
			require.Equal(b, ct.expectedBody, c.Body())
		})
	}
}

// go test -run Test_Ctx_RequestCtx
func Test_Ctx_RequestCtx(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	require.Equal(t, "*fasthttp.RequestCtx", fmt.Sprintf("%T", c.RequestCtx()))
}

// go test -run Test_Ctx_Cookie
// Test_Ctx_Cookie walks one cookie through several attribute combinations and
// checks the serialized Set-Cookie header after each change: default Lax,
// SameSite disabled, Strict, None+Secure, SessionOnly, zero expiry, Partitioned.
func Test_Ctx_Cookie(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	expire := time.Now().Add(24 * time.Hour)
	var dst []byte
	dst = expire.In(time.UTC).AppendFormat(dst, time.RFC1123)
	httpdate := strings.ReplaceAll(string(dst), "UTC", "GMT")
	cookie := &Cookie{
		Name:    "username",
		Value:   "john",
		Expires: expire,
		// SameSite: CookieSameSiteStrictMode, // default is "lax"
	}
	c.Res().Cookie(cookie)
	expect := "username=john; expires=" + httpdate + "; path=/; SameSite=Lax"
	require.Equal(t, expect, c.Res().Get(HeaderSetCookie))

	expect = "username=john; expires=" + httpdate + "; path=/"
	cookie.SameSite = CookieSameSiteDisabled
	c.Res().Cookie(cookie)
	require.Equal(t, expect, c.Res().Get(HeaderSetCookie))

	expect = "username=john; expires=" + httpdate + "; path=/; SameSite=Strict"
	cookie.SameSite = CookieSameSiteStrictMode
	c.Res().Cookie(cookie)
	require.Equal(t, expect, c.Res().Get(HeaderSetCookie))

	expect = "username=john; expires=" + httpdate + "; path=/; secure; SameSite=None"
	cookie.Secure = true
	cookie.SameSite = CookieSameSiteNoneMode
	c.Res().Cookie(cookie)
	require.Equal(t, expect, c.Res().Get(HeaderSetCookie))

	expect = "username=john; path=/; secure; SameSite=None"
	// should remove expires and max-age headers
	cookie.SessionOnly = true
	cookie.Expires = expire
	cookie.MaxAge = 10000
	c.Res().Cookie(cookie)
	require.Equal(t, expect, c.Res().Get(HeaderSetCookie))

	expect = "username=john; path=/; secure; SameSite=None"
	// should remove expires and max-age headers when no expire and no MaxAge (default time)
	cookie.SessionOnly = false
	cookie.Expires = time.Time{}
	cookie.MaxAge = 0
	c.Res().Cookie(cookie)
	require.Equal(t, expect, c.Res().Get(HeaderSetCookie))

	expect = "username=john; path=/; secure; SameSite=None; Partitioned"
	cookie.Partitioned = true
	c.Res().Cookie(cookie)
	require.Equal(t, expect, c.Res().Get(HeaderSetCookie))
}

// go test -run Test_Ctx_Cookie_PartitionedSecure
func
Test_Ctx_Cookie_PartitionedSecure(t *testing.T) {
	// A Secure, SameSite=None, Partitioned cookie is serialized with all three attributes.
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	ck := &Cookie{
		Name:        "ps",
		Value:       "v",
		Secure:      true,
		SameSite:    CookieSameSiteNoneMode,
		Partitioned: true,
	}
	c.Res().Cookie(ck)
	require.Equal(t, "ps=v; path=/; secure; SameSite=None; Partitioned", c.Res().Get(HeaderSetCookie))
}

// go test -run Test_Ctx_Cookie_Invalid
// Test_Ctx_Cookie_Invalid checks that cookies with an invalid name, value,
// domain, path, or expiry produce no Set-Cookie header at all.
func Test_Ctx_Cookie_Invalid(t *testing.T) {
	t.Parallel()
	app := New()

	cases := []*Cookie{
		{Name: "", Value: "a"},                          // empty name
		{Name: "foo bar", Value: "a"},                   // invalid char in name
		{Name: "n", Value: "bad\nval"},                  // invalid value byte
		{Name: "d", Value: "b", Domain: "in valid"},     // invalid domain spaces
		{Name: "d", Value: "b", Domain: "example..com"}, // invalid domain dots
		{Name: "i", Value: "b", Domain: "2001:db8::1"},  // ipv6 not allowed
		{Name: "p", Value: "b", Path: "\x00"},           // invalid path byte
		{Name: "e", Value: "b", Expires: time.Date(1500, 1, 1, 0, 0, 0, 0, time.UTC)}, // invalid expires
		// Note: Partitioned without Secure is auto-fixed (Secure=true set automatically per CHIPS spec)
	}

	for _, invalid := range cases {
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Res().Cookie(invalid)
		require.Empty(t, c.Res().Get(HeaderSetCookie))
		c.Response().Header.Reset()
		app.ReleaseCtx(c)
	}
}

// go test -run Test_Ctx_Cookie_DefaultPath
// Test_Ctx_Cookie_DefaultPath checks that an empty Path is serialized as "path=/".
func Test_Ctx_Cookie_DefaultPath(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	ck := &Cookie{
		Name:  "p",
		Value: "v",
		// Path intentionally empty to verify defaulting
	}
	c.Res().Cookie(ck)
	require.Equal(t,
		"p=v; path=/; SameSite=Lax",
		c.Res().Get(HeaderSetCookie),
	)
}

// go test -run Test_Ctx_Cookie_MaxAgeOnly
func Test_Ctx_Cookie_MaxAgeOnly(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	ck := &Cookie{
		Name:   "ttl",
		Value:  "v",
		MaxAge: 3600,
	}
	c.Res().Cookie(ck)
	require.Equal(t,
		"ttl=v; max-age=3600; path=/; SameSite=Lax",
		c.Res().Get(HeaderSetCookie),
	)
}

// go test -run Test_Ctx_Cookie_StrictPartitioned
func Test_Ctx_Cookie_StrictPartitioned(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	ck := &Cookie{
		Name:        "sp",
		Value:       "v",
		Secure:      true,
		SameSite:    CookieSameSiteStrictMode,
		Partitioned: true,
	}
	c.Res().Cookie(ck)
	require.Equal(t,
		"sp=v; path=/; secure; SameSite=Strict; Partitioned",
		c.Res().Get(HeaderSetCookie),
	)
}

// go test -run Test_Ctx_Cookie_SameSite_CaseInsensitive
// Test_Ctx_Cookie_SameSite_CaseInsensitive checks that SameSite values are
// matched case-insensitively, that "disabled" in any casing omits the
// attribute, and that unknown or empty values fall back to Lax.
func Test_Ctx_Cookie_SameSite_CaseInsensitive(t *testing.T) {
	t.Parallel()
	app := New()

	tests := []struct {
		name     string
		input    string
		expected string
	}{
		// Test case-insensitive Strict
		{"Strict lowercase", "strict", "SameSite=Strict"},
		{"Strict uppercase", "STRICT", "SameSite=Strict"},
		{"Strict mixed case", "StRiCt", "SameSite=Strict"},
		{"Strict proper case", "Strict", "SameSite=Strict"},

		// Test case-insensitive Lax
		{"Lax lowercase", "lax", "SameSite=Lax"},
		{"Lax uppercase", "LAX", "SameSite=Lax"},
		{"Lax mixed case", "LaX", "SameSite=Lax"},
		{"Lax proper case", "Lax", "SameSite=Lax"},

		// Test case-insensitive None
		{"None lowercase", "none", "SameSite=None"},
		{"None uppercase", "NONE", "SameSite=None"},
		{"None mixed case", "NoNe", "SameSite=None"},
		{"None proper case", "None", "SameSite=None"},

		// Test case-insensitive disabled. "Proper case" is Title case, like the
		// rows above; the lowercase spelling is already covered.
		{"Disabled lowercase", "disabled", ""},
		{"Disabled uppercase", "DISABLED", ""},
		{"Disabled mixed case", "DiSaBlEd", ""},
		{"Disabled proper case", "Disabled", ""},

		// Test invalid values default to Lax
		{"Invalid value", "invalid", "SameSite=Lax"},
		{"Empty value", "", "SameSite=Lax"},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			c := app.AcquireCtx(&fasthttp.RequestCtx{})
			defer app.ReleaseCtx(c)

			// Reset response
			c.Response().Reset()

			cookie := &Cookie{
				Name:     "test",
				Value:    "value",
				SameSite: tc.input,
			}
			c.Res().Cookie(cookie)

			setCookieHeader := c.Res().Get(HeaderSetCookie)
			if tc.expected == "" {
				// For disabled, SameSite should not appear in the header
				require.NotContains(t, setCookieHeader, "SameSite")
			} else {
				// For all other cases, the expected SameSite should appear
				require.Contains(t, setCookieHeader, tc.expected)
			}
		})
	}
}

// go test -run Test_Ctx_Cookie_SameSite_None_Secure
// Test_Ctx_Cookie_SameSite_None_Secure checks that SameSite=None (any casing)
// forces the secure attribute even when Secure is false, while other modes do not.
func Test_Ctx_Cookie_SameSite_None_Secure(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		name             string
		cookie           *Cookie
		expectedInHeader string
		shouldBeSecure   bool
	}{
		{
			name: "Empty value",
			cookie: &Cookie{
				Name:     "test",
				Value:    "value",
				SameSite: "",
			},
			expectedInHeader: "SameSite=Lax",
			shouldBeSecure:   false,
		},
		{
			name: "None uppercase",
			cookie: &Cookie{
				Name:     "test",
				Value:    "value",
				SameSite: "None",
			},
			expectedInHeader: "SameSite=None",
			shouldBeSecure:   true,
		},
		{
			name: "None lowercase",
			cookie: &Cookie{
				Name:     "test",
				Value:    "value",
				SameSite: "none",
			},
			expectedInHeader: "SameSite=None",
			shouldBeSecure:   true,
		},
		{
			name: "Lax proper case",
			cookie: &Cookie{
				Name:     "test",
				Value:    "value",
				SameSite: "Lax",
			},
			expectedInHeader: "SameSite=Lax",
			shouldBeSecure:   false,
		},
		{
			name: "Strict uppercase",
			cookie: &Cookie{
				Name:     "test",
				Value:    "value",
				SameSite: "STRICT",
			},
			expectedInHeader: "SameSite=Strict",
			shouldBeSecure:   false,
		},
		{
			name: "Disabled Secure",
			cookie: &Cookie{
				Name:     "test",
				Value:    "value",
				SameSite: "none",
				Secure:   false,
			},
			expectedInHeader: "SameSite=None",
			shouldBeSecure:   true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := New()
			ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
			defer app.ReleaseCtx(ctx)

			ctx.Cookie(tc.cookie)

			cookie := string(ctx.Response().Header.PeekCookie(tc.cookie.Name))
			require.Contains(t, cookie, tc.expectedInHeader)

			if tc.shouldBeSecure {
				require.Contains(t, cookie, "secure")
			} else {
				require.NotContains(t, cookie, "secure")
			}
		})
	}
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Cookie -benchmem -count=4
func Benchmark_Ctx_Cookie(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	b.ReportAllocs()
	for b.Loop() {
		c.Cookie(&Cookie{
			Name:  "John",
			Value: "Doe",
		})
	}
	require.Equal(b, "John=Doe; path=/; SameSite=Lax", app.toString(c.Response().Header.Peek("Set-Cookie")))
}

// go test -run Test_Ctx_Cookies
// Test_Ctx_Cookies checks request cookie lookup, including the default value
// returned for a missing cookie.
func Test_Ctx_Cookies(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().Header.Set("Cookie", "john=doe")
	require.Equal(t, "doe", c.Req().Cookies("john"))
	require.Equal(t, "default", c.Req().Cookies("unknown", "default"))
}

// go test -run Test_Ctx_Format
// Test_Ctx_Format checks media-type negotiation in Format: the chosen handler
// runs and its type becomes Content-Type, handler errors propagate, an
// unmatched Accept yields 406, "default" is used as a fallback, and calling
// Format without handlers returns ErrNoHandlers.
func Test_Ctx_Format(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	// set `accepted` to whatever media type was chosen by Format
	var accepted string
	formatHandlers := func(types ...string) []ResFmt {
		fmts := make([]ResFmt, 0, len(types))
		// The loop variable is deliberately not named t so that it does not
		// shadow the enclosing *testing.T.
		for _, mediaType := range types {
			typ := utils.CopyString(mediaType)
			fmts = append(fmts, ResFmt{MediaType: typ, Handler: func(_ Ctx) error {
				accepted = typ
				return nil
			}})
		}
		return fmts
	}

	c.Request().Header.Set(HeaderAccept, `text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7`)
	err := c.Res().Format(formatHandlers("application/xhtml+xml", "application/xml", "foo/bar")...)
	require.Equal(t, "application/xhtml+xml", accepted)
	require.Equal(t, "application/xhtml+xml", c.GetRespHeader(HeaderContentType))
	require.Equal(t, "application/xhtml+xml", c.Res().Get(HeaderContentType))
	require.NoError(t, err)
	require.NotEqual(t, StatusNotAcceptable, c.Response().StatusCode())

	err = c.Res().Format(formatHandlers("foo/bar;a=b")...)
	require.Equal(t, "foo/bar;a=b", accepted)
	require.Equal(t, "foo/bar;a=b", c.GetRespHeader(HeaderContentType))
	require.Equal(t, "foo/bar;a=b", c.Res().Get(HeaderContentType))
	require.NoError(t, err)
	require.NotEqual(t, StatusNotAcceptable, c.Response().StatusCode())

	myError := errors.New("this is an error")
	err = c.Format(ResFmt{MediaType: "text/html", Handler: func(_ Ctx) error { return myError }})
	require.ErrorIs(t, err, myError)

	c.Request().Header.Set(HeaderAccept, "application/json")
	err = c.Format(ResFmt{MediaType: "text/html", Handler: func(c Ctx) error { return c.SendStatus(StatusOK) }})
	require.Equal(t, StatusNotAcceptable, c.Response().StatusCode())
	require.NoError(t, err)

	c.Request().Header.Set(HeaderAccept, MIMEApplicationMsgPack)
	err = c.Format(ResFmt{MediaType: "text/html", Handler: func(c Ctx) error { return c.SendStatus(StatusOK) }})
	require.Equal(t, StatusNotAcceptable, c.Response().StatusCode())
	require.NoError(t, err)

	err = c.Format(formatHandlers("text/html", "default")...)
	require.Equal(t, "default", accepted)
	require.Equal(t, "text/html", c.GetRespHeader(HeaderContentType))
	require.Equal(t, "text/html", c.Res().Get(HeaderContentType))
	require.NoError(t, err)

	err = c.Format()
	require.ErrorIs(t, err, ErrNoHandlers)
}

// go test -run Test_Ctx_Format_NilHandler
// Test_Ctx_Format_NilHandler checks that a nil handler in any ResFmt entry
// yields an error naming its media type and index instead of panicking.
func Test_Ctx_Format_NilHandler(t *testing.T) {
	t.Parallel()

	t.Run("nil handler in first entry", func(t *testing.T) {
		t.Parallel()
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})

		var err error
		require.NotPanics(t, func() {
			err = c.Format(
				ResFmt{MediaType: "text/html", Handler: nil},
				ResFmt{MediaType: "application/json", Handler: func(_ Ctx) error { return nil }},
			)
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), `media type "text/html"`)
		require.Contains(t, err.Error(), "index 0")
	})

	t.Run("nil handler in matched media type", func(t *testing.T) {
		t.Parallel()
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().Header.Set(HeaderAccept, "application/json")

		var err error
		require.NotPanics(t, func() {
			err = c.Format(
				ResFmt{MediaType: "text/html", Handler: func(_ Ctx) error { return nil }},
				ResFmt{MediaType: "application/json", Handler: nil},
			)
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), `media type "application/json"`)
		require.Contains(t, err.Error(), "index 1")
	})

	t.Run("nil default handler", func(t *testing.T) {
		t.Parallel()
		app := New()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().Header.Set(HeaderAccept, "application/json")

		var err error
		require.NotPanics(t, func() {
			err = c.Format(
				ResFmt{MediaType: "text/html", Handler: func(_ Ctx) error { return nil }},
				ResFmt{MediaType: "default", Handler: nil},
			)
		})
		require.Error(t, err)
		require.Contains(t, err.Error(), `media type "default"`)
		require.Contains(t, err.Error(), "index 1")
	})
}

func Benchmark_Ctx_Format(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().Header.Set(HeaderAccept, "application/json,text/plain; format=flowed; q=0.9")

	fail := func(_ Ctx) error {
require.FailNow(b, "Wrong type chosen") return errors.New("Wrong type chosen") } ok := func(_ Ctx) error { return nil } var err error b.Run("with arg allocation", func(b *testing.B) { for b.Loop() { err = c.Format( ResFmt{MediaType: "application/xml", Handler: fail}, ResFmt{MediaType: "text/html", Handler: fail}, ResFmt{MediaType: "text/plain;format=fixed", Handler: fail}, ResFmt{MediaType: "text/plain;format=flowed", Handler: ok}, ) } require.NoError(b, err) }) b.Run("pre-allocated args", func(b *testing.B) { offers := []ResFmt{ {MediaType: "application/xml", Handler: fail}, {MediaType: "text/html", Handler: fail}, {MediaType: "text/plain;format=fixed", Handler: fail}, {MediaType: "text/plain;format=flowed", Handler: ok}, } for b.Loop() { err = c.Format(offers...) } require.NoError(b, err) }) c.Request().Header.Set("Accept", "text/plain") b.Run("text/plain", func(b *testing.B) { offers := []ResFmt{ {MediaType: "application/xml", Handler: fail}, {MediaType: "text/plain", Handler: ok}, } for b.Loop() { err = c.Format(offers...) } require.NoError(b, err) }) c.Request().Header.Set("Accept", "json") b.Run("json", func(b *testing.B) { offers := []ResFmt{ {MediaType: "xml", Handler: fail}, {MediaType: "html", Handler: fail}, {MediaType: "json", Handler: ok}, } for b.Loop() { err = c.Format(offers...) } require.NoError(b, err) }) c.Request().Header.Set("Accept", MIMEApplicationMsgPack) b.Run("msgpack", func(b *testing.B) { offers := []ResFmt{ {MediaType: "xml", Handler: fail}, {MediaType: "html", Handler: fail}, {MediaType: MIMEApplicationMsgPack, Handler: ok}, } for b.Loop() { err = c.Format(offers...) 
} require.NoError(b, err) }) } // go test -run Test_Ctx_AutoFormat func Test_Ctx_AutoFormat(t *testing.T) { t.Parallel() app := New(Config{ MsgPackEncoder: msgpack.Marshal, MsgPackDecoder: msgpack.Unmarshal, CBOREncoder: cbor.Marshal, CBORDecoder: cbor.Unmarshal, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderAccept, MIMETextPlain) err := c.AutoFormat([]byte("Hello, World!")) require.NoError(t, err) require.Equal(t, MIMETextPlainCharsetUTF8, c.GetRespHeader(HeaderContentType)) require.Equal(t, "Hello, World!", string(c.Response().Body())) c.Request().Header.Set(HeaderAccept, MIMETextHTML) err = c.Res().AutoFormat("Hello, World!") require.NoError(t, err) require.Equal(t, MIMETextHTMLCharsetUTF8, c.GetRespHeader(HeaderContentType)) require.Equal(t, "

Hello, World!

", string(c.Response().Body())) c.Request().Header.Set(HeaderAccept, MIMEApplicationJSON) err = c.AutoFormat("Hello, World!") require.NoError(t, err) require.Equal(t, MIMEApplicationJSONCharsetUTF8, c.GetRespHeader(HeaderContentType)) //nolint:testifylint // this is comparing content-type headers, not JSON content require.Equal(t, `"Hello, World!"`, string(c.Response().Body())) c.Request().Header.Set(HeaderAccept, MIMEApplicationMsgPack) err = c.AutoFormat("Hello, World!") require.NoError(t, err) require.Equal(t, MIMEApplicationMsgPack, c.GetRespHeader(HeaderContentType)) require.Equal(t, []byte{ 0xad, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x57, 0x6f, 0x72, 0x6c, 0x64, 0x21, }, c.Response().Body()) c.Request().Header.Set(HeaderAccept, MIMEApplicationCBOR) err = c.AutoFormat("Hello, World!") require.NoError(t, err) require.Equal(t, MIMEApplicationCBOR, c.GetRespHeader(HeaderContentType)) require.Equal(t, `6d48656c6c6f2c20576f726c6421`, hex.EncodeToString(c.Response().Body())) c.Request().Header.Set(HeaderAccept, MIMETextPlain) err = c.Res().AutoFormat(complex(1, 1)) require.NoError(t, err) require.Equal(t, MIMETextPlainCharsetUTF8, c.GetRespHeader(HeaderContentType)) require.Equal(t, "(1+1i)", string(c.Response().Body())) c.Request().Header.Set(HeaderAccept, MIMEApplicationXML) err = c.AutoFormat("Hello, World!") require.NoError(t, err) require.Equal(t, MIMEApplicationXMLCharsetUTF8, c.GetRespHeader(HeaderContentType)) require.Equal(t, `Hello, World!`, string(c.Response().Body())) err = c.AutoFormat(complex(1, 1)) require.Error(t, err) c.Request().Header.Set(HeaderAccept, MIMETextPlain) err = c.AutoFormat(Map{}) require.NoError(t, err) require.Equal(t, MIMETextPlainCharsetUTF8, c.GetRespHeader(HeaderContentType)) require.Equal(t, "map[]", string(c.Response().Body())) type broken string c.Request().Header.Set(HeaderAccept, "broken/accept") require.NoError(t, err) err = c.AutoFormat(broken("Hello, World!")) require.NoError(t, err) require.Equal(t, 
MIMETextPlainCharsetUTF8, c.GetRespHeader(HeaderContentType)) require.Equal(t, `Hello, World!`, string(c.Response().Body())) } func Test_Ctx_AutoFormat_Struct(t *testing.T) { t.Parallel() app := New(Config{ MsgPackEncoder: msgpack.Marshal, MsgPackDecoder: msgpack.Unmarshal, CBOREncoder: cbor.Marshal, CBORDecoder: cbor.Unmarshal, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type Message struct { Sender string `xml:"sender,attr"` Recipients []string Urgency int `xml:"urgency,attr"` } data := Message{ Recipients: []string{"Alice", "Bob"}, Sender: "Carol", Urgency: 3, } c.Request().Header.Set(HeaderAccept, MIMEApplicationJSON) err := c.AutoFormat(data) require.NoError(t, err) require.Equal(t, MIMEApplicationJSONCharsetUTF8, c.GetRespHeader(HeaderContentType)) //nolint:testifylint // this is comparing content-type headers, not JSON content require.JSONEq(t, `{"Sender":"Carol","Recipients":["Alice","Bob"],"Urgency":3}`, string(c.Response().Body()), ) c.Request().Header.Set(HeaderAccept, MIMEApplicationMsgPack) err = c.AutoFormat(data) require.NoError(t, err) require.Equal(t, MIMEApplicationMsgPack, c.GetRespHeader(HeaderContentType)) require.Equal(t, []byte{ // {"Sender":"Carol","Recipients":["Alice","Bob"],"Urgency":3} 0x83, 0xa6, 0x53, 0x65, 0x6e, 0x64, 0x65, 0x72, 0xa5, 0x43, 0x61, 0x72, 0x6f, 0x6c, 0xaa, 0x52, 0x65, 0x63, 0x69, 0x70, 0x69, 0x65, 0x6e, 0x74, 0x73, 0x92, 0xa5, 0x41, 0x6c, 0x69, 0x63, 0x65, 0xa3, 0x42, 0x6f, 0x62, 0xa7, 0x55, 0x72, 0x67, 0x65, 0x6e, 0x63, 0x79, 0x3, }, c.Response().Body()) c.Request().Header.Set(HeaderAccept, MIMEApplicationCBOR) err = c.AutoFormat(data) require.NoError(t, err) require.Equal(t, MIMEApplicationCBOR, c.GetRespHeader(HeaderContentType)) require.Equal(t, "a36653656e646572654361726f6c6a526563697069656e74738265416c69636563426f6267557267656e637903", hex.EncodeToString(c.Response().Body())) c.Request().Header.Set(HeaderAccept, MIMEApplicationXML) err = c.AutoFormat(data) require.NoError(t, err) require.Equal(t, 
MIMEApplicationXMLCharsetUTF8, c.GetRespHeader(HeaderContentType)) require.Equal(t, `AliceBob`, string(c.Response().Body()), ) } // go test -v -run=^$ -bench=Benchmark_Ctx_AutoFormat -benchmem -count=4 func Benchmark_Ctx_AutoFormat(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set("Accept", "text/plain") b.ReportAllocs() var err error for b.Loop() { err = c.AutoFormat("Hello, World!") } require.NoError(b, err) require.Equal(b, `Hello, World!`, string(c.Response().Body())) } // go test -v -run=^$ -bench=Benchmark_Ctx_AutoFormat_HTML -benchmem -count=4 func Benchmark_Ctx_AutoFormat_HTML(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set("Accept", "text/html") b.ReportAllocs() var err error for b.Loop() { err = c.AutoFormat("Hello, World!") } require.NoError(b, err) require.Equal(b, "

Hello, World!

", string(c.Response().Body())) } // go test -v -run=^$ -bench=Benchmark_Ctx_AutoFormat_JSON -benchmem -count=4 func Benchmark_Ctx_AutoFormat_JSON(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set("Accept", "application/json") b.ReportAllocs() var err error for b.Loop() { err = c.AutoFormat("Hello, World!") } require.NoError(b, err) require.Equal(b, `"Hello, World!"`, string(c.Response().Body())) } // go test -v -run=^$ -bench=Benchmark_Ctx_AutoFormat_MsgPack -benchmem -count=4 func Benchmark_Ctx_AutoFormat_MsgPack(b *testing.B) { app := New( Config{ MsgPackEncoder: msgpack.Marshal, MsgPackDecoder: msgpack.Unmarshal, }, ) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set("Accept", MIMEApplicationMsgPack) b.ReportAllocs() var err error for b.Loop() { err = c.AutoFormat("Hello, World!") } require.NoError(b, err) require.Equal(b, "\xadHello, World!", string(c.Response().Body())) } // go test -v -run=^$ -bench=Benchmark_Ctx_AutoFormat_XML -benchmem -count=4 func Benchmark_Ctx_AutoFormat_XML(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set("Accept", "application/xml") b.ReportAllocs() var err error for b.Loop() { err = c.AutoFormat("Hello, World!") } require.NoError(b, err) require.Equal(b, `Hello, World!`, string(c.Response().Body())) } // go test -run Test_Ctx_FormFile func Test_Ctx_FormFile(t *testing.T) { // TODO: We should clean this up t.Parallel() app := New() app.Post("/test", func(c Ctx) error { fh, err := c.FormFile("file") require.NoError(t, err) require.Equal(t, "test", fh.Filename) f, err := fh.Open() require.NoError(t, err) defer func() { require.NoError(t, f.Close()) }() b := new(bytes.Buffer) _, err = io.Copy(b, f) require.NoError(t, err) require.Equal(t, "hello world", b.String()) return nil }) body := &bytes.Buffer{} writer := multipart.NewWriter(body) ioWriter, err := writer.CreateFormFile("file", "test") require.NoError(t, err) _, err = 
ioWriter.Write([]byte("hello world")) require.NoError(t, err) require.NoError(t, writer.Close()) req := httptest.NewRequest(MethodPost, "/test", body) req.Header.Set(HeaderContentType, writer.FormDataContentType()) req.Header.Set(HeaderContentLength, strconv.Itoa(len(body.Bytes()))) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusOK, resp.StatusCode, "Status code") respBody, err := io.ReadAll(resp.Body) require.NoError(t, err, "io.ReadAll(resp.Body)") require.Empty(t, respBody) require.Empty(t, resp.Header.Get(HeaderContentType)) require.Equal(t, int64(0), resp.ContentLength) } // go test -run Test_Ctx_FormValue func Test_Ctx_FormValue(t *testing.T) { t.Parallel() app := New() app.Post("/test", func(c Ctx) error { require.Equal(t, "john", c.FormValue("name")) return nil }) body := &bytes.Buffer{} writer := multipart.NewWriter(body) require.NoError(t, writer.WriteField("name", "john")) require.NoError(t, writer.Close()) req := httptest.NewRequest(MethodPost, "/test", body) req.Header.Set("Content-Type", "multipart/form-data; boundary="+writer.Boundary()) req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes()))) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusOK, resp.StatusCode, "Status code") respBody, err := io.ReadAll(resp.Body) require.NoError(t, err, "io.ReadAll(resp.Body)") require.Empty(t, respBody) require.Empty(t, resp.Header.Get(HeaderContentType)) require.Equal(t, int64(0), resp.ContentLength) } // go test -v -run=^$ -bench=Benchmark_Ctx_Fresh_StaleEtag -benchmem -count=4 func Benchmark_Ctx_Fresh_StaleEtag(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) for b.Loop() { c.Request().Header.Set(HeaderIfNoneMatch, `"a", "b", "c", "d"`) c.Request().Header.Set(HeaderCacheControl, "c") c.Fresh() 
c.Request().Header.Set(HeaderIfNoneMatch, `"a", "b", "c", "d"`) c.Request().Header.Set(HeaderCacheControl, "e") c.Fresh() } } // go test -run Test_Ctx_Fresh func Test_Ctx_Fresh(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.False(t, c.Fresh()) c.Request().Header.Set(HeaderIfNoneMatch, "*") c.Request().Header.Set(HeaderCacheControl, "no-cache") require.False(t, c.Fresh()) c.Request().Header.Set(HeaderIfNoneMatch, "*") c.Request().Header.Set(HeaderCacheControl, ",no-cache,") require.False(t, c.Fresh()) c.Request().Header.Set(HeaderIfNoneMatch, "*") c.Request().Header.Set(HeaderCacheControl, "aa,no-cache,") require.False(t, c.Fresh()) c.Request().Header.Set(HeaderIfNoneMatch, "*") c.Request().Header.Set(HeaderCacheControl, ",no-cache,bb") require.False(t, c.Fresh()) c.Request().Header.Set(HeaderIfNoneMatch, `"675af34563dc-tr34"`) c.Request().Header.Set(HeaderCacheControl, "public") require.False(t, c.Fresh()) c.Request().Header.Set(HeaderIfNoneMatch, `"a", "b"`) c.Response().Header.Set(HeaderETag, `"c"`) require.False(t, c.Fresh()) c.Response().Header.Set(HeaderETag, `"a"`) require.True(t, c.Fresh()) c.Request().Header.Set(HeaderIfModifiedSince, "xxWed, 21 Oct 2015 07:28:00 GMT") c.Response().Header.Set(HeaderLastModified, "xxWed, 21 Oct 2015 07:28:00 GMT") require.False(t, c.Fresh()) c.Response().Header.Set(HeaderLastModified, "Wed, 21 Oct 2015 07:28:00 GMT") require.False(t, c.Fresh()) c.Request().Header.Set(HeaderIfModifiedSince, "Wed, 21 Oct 2015 07:28:00 GMT") require.True(t, c.Fresh()) c.Request().Header.Set(HeaderIfModifiedSince, "Wed, 21 Oct 2015 07:27:59 GMT") c.Response().Header.Set(HeaderLastModified, "Wed, 21 Oct 2015 07:28:00 GMT") require.False(t, c.Fresh()) } // go test -v -run=^$ -bench=Benchmark_Ctx_Fresh_WithNoCache -benchmem -count=4 func Benchmark_Ctx_Fresh_WithNoCache(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderIfNoneMatch, "*") 
c.Request().Header.Set(HeaderCacheControl, "no-cache") for b.Loop() { c.Fresh() } } // go test -v -run=^$ -bench=Benchmark_Ctx_Fresh_LastModified -benchmem -count=4 func Benchmark_Ctx_Fresh_LastModified(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Response().Header.Set(HeaderLastModified, "Wed, 21 Oct 2015 07:28:00 GMT") c.Request().Header.Set(HeaderIfModifiedSince, "Wed, 21 Oct 2015 07:28:00 GMT") for b.Loop() { c.Fresh() } } // go test -run Test_Ctx_Binders -v func Test_Ctx_Binders(t *testing.T) { t.Parallel() // setup app := New(Config{ EnableSplittingOnParsers: true, }) type TestEmbeddedStruct struct { Names []string `query:"names"` } type TestStruct struct { Name string NameWithDefault string `json:"name2" xml:"Name2" form:"name2" cookie:"name2" query:"name2" uri:"name2" header:"Name2"` TestEmbeddedStruct Class int ClassWithDefault int `json:"class2" xml:"Class2" form:"class2" cookie:"class2" query:"class2" uri:"class2" header:"Class2"` } withValues := func(t *testing.T, actionFn func(c Ctx, testStruct *TestStruct) error) { t.Helper() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) testStruct := new(TestStruct) require.NoError(t, actionFn(c, testStruct)) require.Equal(t, "foo", testStruct.Name) require.Equal(t, 111, testStruct.Class) require.Equal(t, "bar", testStruct.NameWithDefault) require.Equal(t, 222, testStruct.ClassWithDefault) require.Equal(t, []string{"foo", "bar", "test"}, testStruct.Names) } t.Run("Body:xml", func(t *testing.T) { t.Parallel() withValues(t, func(c Ctx, testStruct *TestStruct) error { c.Request().Header.SetContentType(MIMEApplicationXML) c.Request().SetBody([]byte(`foo111bar222foobartest`)) return c.Bind().Body(testStruct) }) }) t.Run("Body:form", func(t *testing.T) { t.Parallel() withValues(t, func(c Ctx, testStruct *TestStruct) error { c.Request().Header.SetContentType(MIMEApplicationForm) c.Request().SetBody([]byte(`name=foo&class=111&name2=bar&class2=222&names=foo,bar,test`)) 
return c.Bind().Body(testStruct) }) }) t.Run("BodyParser:json", func(t *testing.T) { t.Parallel() withValues(t, func(c Ctx, testStruct *TestStruct) error { c.Request().Header.SetContentType(MIMEApplicationJSON) c.Request().SetBody([]byte(`{"name":"foo","class":111,"name2":"bar","class2":222,"names":["foo","bar","test"]}`)) return c.Bind().Body(testStruct) }) }) t.Run("Body:multiform", func(t *testing.T) { t.Parallel() withValues(t, func(c Ctx, testStruct *TestStruct) error { body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\nfoo\r\n--b\r\nContent-Disposition: form-data; name=\"class\"\r\n\r\n111\r\n--b\r\nContent-Disposition: form-data; name=\"name2\"\r\n\r\nbar\r\n--b\r\nContent-Disposition: form-data; name=\"class2\"\r\n\r\n222\r\n--b\r\nContent-Disposition: form-data; name=\"names\"\r\n\r\nfoo\r\n--b\r\nContent-Disposition: form-data; name=\"names\"\r\n\r\nbar\r\n--b\r\nContent-Disposition: form-data; name=\"names\"\r\n\r\ntest\r\n--b--") c.Request().SetBody(body) c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary="b"`) c.Request().Header.SetContentLength(len(body)) return c.Bind().Body(testStruct) }) }) t.Run("Cookie", func(t *testing.T) { t.Parallel() withValues(t, func(c Ctx, testStruct *TestStruct) error { c.Request().Header.Set("Cookie", "name=foo;name2=bar;class=111;class2=222;names=foo,bar,test") return c.Bind().Cookie(testStruct) }) }) t.Run("Query", func(t *testing.T) { t.Parallel() withValues(t, func(c Ctx, testStruct *TestStruct) error { c.Request().URI().SetQueryString("name=foo&name2=bar&class=111&class2=222&names=foo,bar,test") return c.Bind().Query(testStruct) }) }) t.Run("URI", func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed defer app.ReleaseCtx(c) c.route = &Route{Params: []string{"name", "name2", "class", "class2"}} c.values = [maxParams]string{"foo", "bar", "111", "222"} testStruct := new(TestStruct) 
require.NoError(t, c.Bind().URI(testStruct)) require.Equal(t, "foo", testStruct.Name) require.Equal(t, 111, testStruct.Class) require.Equal(t, "bar", testStruct.NameWithDefault) require.Equal(t, 222, testStruct.ClassWithDefault) require.Nil(t, testStruct.Names) }) t.Run("ReqHeader", func(t *testing.T) { t.Parallel() withValues(t, func(c Ctx, testStruct *TestStruct) error { c.Request().Header.Add("name", "foo") c.Request().Header.Add("name2", "bar") c.Request().Header.Add("class", "111") c.Request().Header.Add("class2", "222") c.Request().Header.Add("names", "foo,bar,test") return c.Bind().Header(testStruct) }) }) } // go test -run Test_Ctx_Get func Test_Ctx_Get(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderAcceptCharset, "utf-8, iso-8859-1;q=0.5") c.Request().Header.Set(HeaderReferer, "Monster") require.Equal(t, "utf-8, iso-8859-1;q=0.5", c.Get(HeaderAcceptCharset)) require.Equal(t, "Monster", c.Get(HeaderReferer)) require.Equal(t, "default", c.Get("unknown", "default")) } // go test -run Test_Ctx_GetReqHeader func Test_Ctx_GetReqHeader(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set("foo", "bar") c.Request().Header.Set("id", "123") require.Equal(t, 123, GetReqHeader[int](c, "id")) require.Equal(t, "bar", GetReqHeader[string](c, "foo")) } // go test -run Test_Ctx_Host func Test_Ctx_Host(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") require.Equal(t, "google.com", c.Host()) } // go test -run Test_Ctx_Host_UntrustedProxy func Test_Ctx_Host_UntrustedProxy(t *testing.T) { t.Parallel() // Don't trust any proxy { app := New(Config{TrustProxy: true, TrustProxyConfig: TrustProxyConfig{Proxies: []string{}}}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, 
"google1.com") require.Equal(t, "google.com", c.Host()) app.ReleaseCtx(c) } // Trust to specific proxy list { app := New(Config{TrustProxy: true, TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.8.0.0", "0.8.0.1"}}}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") require.Equal(t, "google.com", c.Host()) app.ReleaseCtx(c) } } // go test -run Test_Ctx_Host_TrustedProxy func Test_Ctx_Host_TrustedProxy(t *testing.T) { t.Parallel() { app := New(Config{TrustProxy: true, TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.0.0.0", "0.8.0.1"}}}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") require.Equal(t, "google1.com", c.Host()) app.ReleaseCtx(c) } t.Run("TrimWhitespaceFromForwardedHost", func(t *testing.T) { t.Parallel() testCases := []struct { name string forwardedHost string expectedHost string }{ { name: "leading whitespace with comma", forwardedHost: " example.com, proxy1", expectedHost: "example.com", }, { name: "trailing whitespace with comma", forwardedHost: "example.com , proxy1", expectedHost: "example.com", }, { name: "leading and trailing whitespace with comma", forwardedHost: " example.com , proxy1", expectedHost: "example.com", }, { name: "no whitespace with comma", forwardedHost: "example.com, proxy1", expectedHost: "example.com", }, { name: "single value with whitespace", forwardedHost: " example.com ", expectedHost: "example.com", }, { name: "leading comma", forwardedHost: ",example.com", expectedHost: "", }, } app := New(Config{TrustProxy: true, TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.0.0.0", "0.8.0.1"}}}) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) c.Request().SetRequestURI("http://google.com/test") 
c.Request().Header.Set(HeaderXForwardedHost, tc.forwardedHost) require.Equal(t, tc.expectedHost, c.Host()) }) } }) } // go test -run Test_Ctx_Host_TrustedProxyRange func Test_Ctx_Host_TrustedProxyRange(t *testing.T) { t.Parallel() app := New(Config{TrustProxy: true, TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.0.0.0/30"}}}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") require.Equal(t, "google1.com", c.Host()) app.ReleaseCtx(c) } // go test -run Test_Ctx_Host_UntrustedProxyRange func Test_Ctx_Host_UntrustedProxyRange(t *testing.T) { t.Parallel() app := New(Config{TrustProxy: true, TrustProxyConfig: TrustProxyConfig{Proxies: []string{"1.0.0.0/30"}}}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") require.Equal(t, "google.com", c.Host()) app.ReleaseCtx(c) } // go test -v -run=^$ -bench=Benchmark_Ctx_Host -benchmem -count=4 func Benchmark_Ctx_Host(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") var host string b.ReportAllocs() for b.Loop() { host = c.Host() } require.Equal(b, "google.com", host) } // go test -run Test_Ctx_IsProxyTrusted func Test_Ctx_IsProxyTrusted(t *testing.T) { t.Parallel() { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) require.False(t, c.IsProxyTrusted()) } { app := New(Config{ TrustProxy: false, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.False(t, c.IsProxyTrusted()) } { app := New(Config{ TrustProxy: true, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.False(t, c.IsProxyTrusted()) } { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{}, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.False(t, c.IsProxyTrusted()) } { app := New(Config{ 
TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"127.0.0.1"}, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.False(t, c.IsProxyTrusted()) } { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"127.0.0.1/8"}, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.False(t, c.IsProxyTrusted()) } { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"0.0.0.0"}, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.True(t, c.IsProxyTrusted()) } { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"0.0.0.1/31"}, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.True(t, c.IsProxyTrusted()) } { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"0.0.0.1/31junk"}, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.False(t, c.IsProxyTrusted()) } { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Private: true, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.False(t, c.IsProxyTrusted()) } { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Loopback: true, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.False(t, c.IsProxyTrusted()) } { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ LinkLocal: true, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.False(t, c.IsProxyTrusted()) } } // go test -run Test_Ctx_Hostname func Test_Ctx_Hostname(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") require.Equal(t, "google.com", c.Hostname()) c.Request().SetRequestURI("http://google.com:8080/test") require.Equal(t, "google.com", c.Hostname()) } // go test -v -run=^$ -bench=Benchmark_Ctx_Hostname -benchmem -count=4 func Benchmark_Ctx_Hostname(b *testing.B) { app := New() c := 
app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().SetRequestURI("http://google.com:8080/test")
	var hostname string
	b.ReportAllocs()
	for b.Loop() {
		hostname = c.Hostname()
	}
	// Hostname strips the port from the request host.
	require.Equal(b, "google.com", hostname)

	// Trust to specific proxy list
	{
		app := New(Config{
			TrustProxy:       true,
			TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.8.0.0", "0.8.0.1"}},
		})
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().SetRequestURI("http://google.com/test")
		c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
		// The default remote address (0.0.0.0) is not in the proxy list,
		// so X-Forwarded-Host must be ignored for this ctx.
		require.Equal(b, "google.com", c.Hostname())
		app.ReleaseCtx(c)
	}
}

// go test -run Test_Ctx_Hostname_TrustedProxy
func Test_Ctx_Hostname_TrustedProxy(t *testing.T) {
	t.Parallel()
	{
		app := New(Config{
			TrustProxy:       true,
			TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.0.0.0", "0.8.0.1"}},
		})
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().SetRequestURI("http://google.com/test")
		c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
		require.Equal(t, "google1.com", c.Hostname())
		app.ReleaseCtx(c)
	}
}

// go test -run Test_Ctx_Hostname_TrustedProxy_Multiple
func Test_Ctx_Hostname_TrustedProxy_Multiple(t *testing.T) {
	t.Parallel()
	{
		app := New(Config{
			TrustProxy:       true,
			TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.0.0.0", "0.8.0.1"}},
		})
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		c.Request().SetRequestURI("http://google.com/test")
		c.Request().Header.Set(HeaderXForwardedHost, "google1.com, google2.com")
		require.Equal(t, "google1.com", c.Hostname())
		app.ReleaseCtx(c)
	}
}

// go test -run Test_Ctx_Hostname_TrustedProxyRange
func Test_Ctx_Hostname_TrustedProxyRange(t *testing.T) {
	t.Parallel()

	app := New(Config{
		TrustProxy:       true,
		TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.0.0.0/30"}},
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().SetRequestURI("http://google.com/test")
	c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
	require.Equal(t, "google1.com", c.Hostname())
	app.ReleaseCtx(c)
}

// go test -run Test_Ctx_Hostname_UntrustedProxyRange
func Test_Ctx_Hostname_UntrustedProxyRange(t *testing.T) {
	t.Parallel()

	app := New(Config{
		TrustProxy:       true,
		TrustProxyConfig: TrustProxyConfig{Proxies: []string{"1.0.0.0/30"}},
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().SetRequestURI("http://google.com/test")
	c.Request().Header.Set(HeaderXForwardedHost, "google1.com")
	require.Equal(t, "google.com", c.Hostname())
	app.ReleaseCtx(c)
}

// go test -run Test_Ctx_Port
func Test_Ctx_Port(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	require.Equal(t, "0", c.Port())
}

// go test -run Test_Ctx_Port_RemoteAddrVariants
func Test_Ctx_Port_RemoteAddrVariants(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name   string
		remote net.Addr
		want   string
	}{
		{
			name: "tcp",
			remote: &net.TCPAddr{
				IP:   net.IPv4(127, 0, 0, 1),
				Port: 8080,
			},
			want: "8080",
		},
		{
			name:   "unix",
			remote: &net.UnixAddr{Name: "/tmp/fiber.sock", Net: "unix"},
			want:   "",
		},
		{
			name:   "default-remote",
			remote: nil,
			want:   "0",
		},
		{
			name:   "string-host-port",
			remote: testNetAddr{network: "tcp", address: "192.0.2.1:443"},
			want:   "443",
		},
		{
			name:   "string-missing-port",
			remote: testNetAddr{network: "tcp", address: "192.0.2.1"},
			want:   "",
		},
		{
			name:   "string-ipv6-port",
			remote: testNetAddr{network: "tcp", address: "[2001:db8::1]:8443"},
			want:   "8443",
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			app := New()
			ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
			defaultCtx, ok := ctx.(*DefaultCtx)
			require.True(t, ok)

			defaultCtx.fasthttp.SetRemoteAddr(test.remote)
			require.Equal(t, test.want, ctx.Port())

			app.ReleaseCtx(ctx)
		})
	}
}

// go test -run Test_Ctx_PortInHandler
func Test_Ctx_PortInHandler(t *testing.T) {
	t.Parallel()
	app := New()

	app.Get("/port", func(c Ctx) error {
		return c.SendString(c.Port())
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/port", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "0", string(body))
}

// go test -run Test_Ctx_IP
func Test_Ctx_IP(t *testing.T) {
	t.Parallel()

	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	// default behavior will return the remote IP from the stack
	require.Equal(t, "0.0.0.0", c.IP())

	// X-Forwarded-For is set, but it is ignored because proxyHeader is not set
	c.Request().Header.Set(HeaderXForwardedFor, "0.0.0.1")
	require.Equal(t, "0.0.0.0", c.IP())
}

// go test -run Test_Ctx_IP_ProxyHeader
func Test_Ctx_IP_ProxyHeader(t *testing.T) {
	t.Parallel()

	// make sure that the same behavior exists for different proxy header names
	proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}

	for _, proxyHeaderName := range proxyHeaderNames {
		app := New(Config{
			ProxyHeader: proxyHeaderName,
			TrustProxy:  true,
			TrustProxyConfig: TrustProxyConfig{
				Proxies: []string{"0.0.0.0"},
			},
		})
		fastCtx := &fasthttp.RequestCtx{}
		fastCtx.SetRemoteAddr(net.Addr(&net.TCPAddr{IP: net.ParseIP("0.0.0.0")}))
		c := app.AcquireCtx(fastCtx)

		c.Request().Header.Set(proxyHeaderName, "0.0.0.1")
		require.Equal(t, "0.0.0.1", c.IP())

		// without IP validation we return the full string
		c.Request().Header.Set(proxyHeaderName, "0.0.0.1, 0.0.0.2")
		require.Equal(t, "0.0.0.1, 0.0.0.2", c.IP())

		// without IP validation we return invalid IPs
		c.Request().Header.Set(proxyHeaderName, "invalid, 0.0.0.2, 0.0.0.3")
		require.Equal(t, "invalid, 0.0.0.2, 0.0.0.3", c.IP())

		// when proxy header is enabled but the value is empty, without IP validation we return an empty string
		c.Request().Header.Set(proxyHeaderName, "")
		require.Empty(t, c.IP())

		// without IP validation we return an invalid IP
		c.Request().Header.Set(proxyHeaderName, "not-valid-ip")
		require.Equal(t, "not-valid-ip", c.IP())
	}
}

// go test -run Test_Ctx_IP_ProxyHeader_With_IP_Validation
func Test_Ctx_IP_ProxyHeader_With_IP_Validation(t *testing.T) {
	t.Parallel()

	// make sure that the same behavior exists for different proxy header names
	proxyHeaderNames := []string{"Real-Ip", HeaderXForwardedFor}

	for _, proxyHeaderName := range proxyHeaderNames {
		app := New(Config{
			EnableIPValidation: true,
			ProxyHeader:        proxyHeaderName,
			TrustProxy:         true,
			TrustProxyConfig: TrustProxyConfig{
				Proxies: []string{"0.0.0.0"},
			},
		})
		fastCtx := &fasthttp.RequestCtx{}
		fastCtx.SetRemoteAddr(net.Addr(&net.TCPAddr{IP: net.ParseIP("0.0.0.0")}))
		c := app.AcquireCtx(fastCtx)

		// when proxy header & validation is enabled and the value is a valid IP, we return it
		c.Request().Header.Set(proxyHeaderName, "0.0.0.1")
		require.Equal(t, "0.0.0.1", c.IP())

		// when proxy header & validation is enabled and the value is a list of IPs, we return the first valid IP
		c.Request().Header.Set(proxyHeaderName, "0.0.0.1, 0.0.0.2")
		require.Equal(t, "0.0.0.1", c.IP())

		c.Request().Header.Set(proxyHeaderName, "invalid, 0.0.0.2, 0.0.0.3")
		require.Equal(t, "0.0.0.2", c.IP())

		// when proxy header & validation is enabled but the value is empty, we will ignore the header
		c.Request().Header.Set(proxyHeaderName, "")
		require.Equal(t, "0.0.0.0", c.IP())

		// when proxy header & validation is enabled but the value is not an IP, we will ignore the header
		// and return the IP of the caller
		c.Request().Header.Set(proxyHeaderName, "not-valid-ip")
		require.Equal(t, "0.0.0.0", c.IP())
	}
}

// go test -run Test_Ctx_IP_UntrustedProxy
func Test_Ctx_IP_UntrustedProxy(t *testing.T) {
	t.Parallel()
	app := New(Config{
		TrustProxy:       true,
		TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.8.0.1"}},
		ProxyHeader:      HeaderXForwardedFor,
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().Header.Set(HeaderXForwardedFor, "0.0.0.1")
	require.Equal(t, "0.0.0.0", c.IP())
}

// go test -run Test_Ctx_IP_TrustedProxy
func Test_Ctx_IP_TrustedProxy(t *testing.T) {
	t.Parallel()
	app := New(Config{
		TrustProxy:       true,
		TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.0.0.0"}},
		ProxyHeader:      HeaderXForwardedFor,
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().Header.Set(HeaderXForwardedFor, "0.0.0.1")
	require.Equal(t, "0.0.0.1", c.IP())
}

func Test_Ctx_ProxyTrust_UnixRemoteAddr(t *testing.T) {
	t.Parallel()

	if runtime.GOOS == "windows" {
		t.Skip("unix sockets are not supported on windows in this test")
	}

	t.Run("unix_socket_enabled", func(t *testing.T) {
		t.Parallel()

		parts := strings.SplitN(runCtxProxyTrustUnixRemoteAddrCase(t, true), "|", 2)
		require.Len(t, parts, 2)
		require.Equal(t, "true", parts[0])
		require.Equal(t, "1.1.1.1", parts[1])
	})

	t.Run("unix_socket_disabled", func(t *testing.T) {
		t.Parallel()

		parts := strings.SplitN(runCtxProxyTrustUnixRemoteAddrCase(t, false), "|", 2)
		require.Len(t, parts, 2)
		require.Equal(t, "false", parts[0])
		require.Equal(t, "0.0.0.0", parts[1])
	})
}

// runCtxProxyTrustUnixRemoteAddrCase starts an app on a temporary unix socket,
// sends one request with X-Forwarded-For set, and returns the handler's
// "<IsProxyTrusted>|<IP>" response body.
func runCtxProxyTrustUnixRemoteAddrCase(t *testing.T, unixSocket bool) string {
	t.Helper()

	app := New(Config{
		TrustProxy: true,
		TrustProxyConfig: TrustProxyConfig{
			UnixSocket: unixSocket,
		},
		ProxyHeader: HeaderXForwardedFor,
	})

	app.Get("/ip", func(c Ctx) error {
		return c.SendString(fmt.Sprintf("%t|%s", c.IsProxyTrusted(), c.IP()))
	})

	tmp, err := os.MkdirTemp(os.TempDir(), "fiber-ctx-unix")
	require.NoError(t, err)
	t.Cleanup(func() {
		require.NoError(t, os.RemoveAll(tmp))
	})

	sock := filepath.Join(tmp, "fiber.sock")

	result := make(chan string, 1)
	errCh := make(chan error, 1)
	go func() {
		// Give app.Listen time to create the socket before dialing it.
		time.Sleep(1000 * time.Millisecond)

		client := &fasthttp.HostClient{
			Addr: sock,
			Dial: func(addr string) (net.Conn, error) {
				return net.Dial(NetworkUnix, addr)
			},
		}

		req := &fasthttp.Request{}
		resp := &fasthttp.Response{}
		req.SetRequestURI("http://fiber/ip")
		req.Header.Set(HeaderXForwardedFor, "1.1.1.1")

		// Keep the error goroutine-local instead of writing to the enclosing
		// function's err variable.
		if doErr := client.Do(req, resp); doErr != nil {
			result <- "" // Ensure result channel always receives a value
			errCh <- errors.Join(doErr, app.Shutdown())
			return
		}

		result <- string(resp.Body())
		errCh <- app.Shutdown()
	}()

	require.NoError(t, app.Listen(sock, ListenConfig{
		DisableStartupMessage: true,
		ListenerNetwork:       NetworkUnix,
		UnixSocketFileMode:    0o660,
	}))

	require.NoError(t, <-errCh)
	return <-result
}

// go test -run Test_Ctx_IPs -parallel
func Test_Ctx_IPs(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	// normal happy path test case
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1, 127.0.0.2, 127.0.0.3")
	require.Equal(t, []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, c.IPs())

	// inconsistent space formatting
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1,127.0.0.2  ,127.0.0.3")
	require.Equal(t, []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, c.IPs())

	// invalid IPs are allowed to be returned
	c.Request().Header.Set(HeaderXForwardedFor, "invalid, 127.0.0.1, 127.0.0.2")
	require.Equal(t, []string{"invalid", "127.0.0.1", "127.0.0.2"}, c.IPs())
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1, invalid, 127.0.0.2")
	require.Equal(t, []string{"127.0.0.1", "invalid", "127.0.0.2"}, c.IPs())

	// ensure that the ordering of IPs in the header is maintained
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.3, 127.0.0.1, 127.0.0.2")
	require.Equal(t, []string{"127.0.0.3", "127.0.0.1", "127.0.0.2"}, c.IPs())

	// ensure for IPv6
	c.Request().Header.Set(HeaderXForwardedFor, "9396:9549:b4f7:8ed0:4791:1330:8c06:e62d, invalid, 2345:0425:2CA1::0567:5673:23b5")
	require.Equal(t, []string{"9396:9549:b4f7:8ed0:4791:1330:8c06:e62d", "invalid", "2345:0425:2CA1::0567:5673:23b5"}, c.IPs())

	// empty header
	c.Request().Header.Set(HeaderXForwardedFor, "")
	require.Empty(t, c.IPs())

	// missing header: actually remove it so the absent-header path is exercised
	c.Request().Header.Del(HeaderXForwardedFor)
	require.Empty(t, c.IPs())
}

func Test_Ctx_IPs_With_IP_Validation(t *testing.T) {
	t.Parallel()
	app := New(Config{EnableIPValidation: true})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	// normal happy path test case
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1, 127.0.0.2, 127.0.0.3")
	require.Equal(t, []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, c.IPs())

	// inconsistent space formatting
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1,127.0.0.2  ,127.0.0.3")
	require.Equal(t, []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, c.IPs())

	// invalid IPs are in the header
	c.Request().Header.Set(HeaderXForwardedFor, "invalid, 127.0.0.1, 127.0.0.2")
	require.Equal(t, []string{"127.0.0.1", "127.0.0.2"}, c.IPs())
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1, invalid, 127.0.0.2")
	require.Equal(t, []string{"127.0.0.1", "127.0.0.2"}, c.IPs())

	// ensure that the ordering of IPs in the header is maintained
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.3, 127.0.0.1, 127.0.0.2")
	require.Equal(t, []string{"127.0.0.3", "127.0.0.1", "127.0.0.2"}, c.IPs())

	// ensure for IPv6
	c.Request().Header.Set(HeaderXForwardedFor, "f037:825e:eadb:1b7b:1667:6f0a:5356:f604, invalid, 9396:9549:b4f7:8ed0:4791:1330:8c06:e62d")
	require.Equal(t, []string{"f037:825e:eadb:1b7b:1667:6f0a:5356:f604", "9396:9549:b4f7:8ed0:4791:1330:8c06:e62d"}, c.IPs())

	// empty header
	c.Request().Header.Set(HeaderXForwardedFor, "")
	require.Empty(t, c.IPs())

	// missing header: actually remove it so the absent-header path is exercised
	c.Request().Header.Del(HeaderXForwardedFor)
	require.Empty(t, c.IPs())
}

// go test -v -run=^$ -bench=Benchmark_Ctx_IPs -benchmem -count=4
func Benchmark_Ctx_IPs(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1, invalid, 127.0.0.1")
	var res []string
	b.ReportAllocs()
	for b.Loop() {
		res = c.IPs()
	}
	require.Equal(b, []string{"127.0.0.1", "invalid", "127.0.0.1"}, res)
}

func Benchmark_Ctx_IPs_v6(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)
	c.Request().Header.Set(HeaderXForwardedFor, "f037:825e:eadb:1b7b:1667:6f0a:5356:f604, invalid, 2345:0425:2CA1::0567:5673:23b5")
	var res []string
	b.ReportAllocs()
	for b.Loop() {
		res = c.IPs()
	}
	require.Equal(b, []string{"f037:825e:eadb:1b7b:1667:6f0a:5356:f604", "invalid", "2345:0425:2CA1::0567:5673:23b5"}, res)
}

func Benchmark_Ctx_IPs_With_IP_Validation(b *testing.B) {
	app := New(Config{EnableIPValidation: true})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1, invalid, 127.0.0.1")
	var res []string
	b.ReportAllocs()
	for b.Loop() {
		res = c.IPs()
	}
	require.Equal(b, []string{"127.0.0.1", "127.0.0.1"}, res)
}

func Benchmark_Ctx_IPs_v6_With_IP_Validation(b *testing.B) {
	app := New(Config{EnableIPValidation: true})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)
	c.Request().Header.Set(HeaderXForwardedFor, "2345:0425:2CA1:0000:0000:0567:5673:23b5, invalid, 2345:0425:2CA1::0567:5673:23b5")
	var res []string
	b.ReportAllocs()
	for b.Loop() {
		res = c.IPs()
	}
	require.Equal(b, []string{"2345:0425:2CA1:0000:0000:0567:5673:23b5", "2345:0425:2CA1::0567:5673:23b5"}, res)
}

func Benchmark_Ctx_IP_With_ProxyHeader(b *testing.B) {
	app := New(Config{
		ProxyHeader: HeaderXForwardedFor,
		TrustProxy:  true,
		TrustProxyConfig: TrustProxyConfig{
			Loopback: true,
		},
	})
	fastCtx := &fasthttp.RequestCtx{}
	fastCtx.SetRemoteAddr(net.Addr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1")}))
	c := app.AcquireCtx(fastCtx)
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1")
	var res string
	b.ReportAllocs()
	for b.Loop() {
		res = c.IP()
	}
	require.Equal(b, "127.0.0.1", res)
}

func Benchmark_Ctx_IP_With_ProxyHeader_and_IP_Validation(b *testing.B) {
	app := New(Config{
		ProxyHeader: HeaderXForwardedFor,
		TrustProxy:  true,
		TrustProxyConfig: TrustProxyConfig{
			Loopback: true,
		},
		EnableIPValidation: true,
	})
	fastCtx := &fasthttp.RequestCtx{}
	fastCtx.SetRemoteAddr(net.Addr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1")}))
	c := app.AcquireCtx(fastCtx)
	c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1")
	var res string
	b.ReportAllocs()
	for b.Loop() {
		res = c.IP()
	}
	require.Equal(b, "127.0.0.1", res)
}

func Benchmark_Ctx_IP(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)
	var res string
	b.ReportAllocs()
	for b.Loop() {
		res = c.IP()
	}
	require.Equal(b, "0.0.0.0", res)
}

// go test -run Test_Ctx_Is
func Test_Ctx_Is(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Request().Header.Set(HeaderContentType, MIMETextHTML+"; boundary=something")
	require.True(t, c.Is(".html"))
	require.True(t, c.Is("html"))
	require.False(t, c.Is("json"))
	require.False(t, c.Is(".json"))
	require.False(t, c.Is(""))
	require.False(t, c.Is(".foooo"))

	c.Request().Header.Set(HeaderContentType, MIMEApplicationJSONCharsetUTF8)
	require.False(t, c.Is("html"))
	require.True(t, c.Is("json"))
	require.True(t, c.Is(".json"))

	c.Request().Header.Set(HeaderContentType, " application/json;charset=UTF-8")
	require.False(t, c.Is("html"))
	require.True(t, c.Is("json"))
	require.True(t, c.Is(".json"))

	c.Request().Header.Set(HeaderContentType, MIMEApplicationXMLCharsetUTF8)
	require.False(t, c.Is("html"))
	require.True(t, c.Is("xml"))
	require.True(t, c.Is(".xml"))

	c.Request().Header.Set(HeaderContentType, MIMETextPlain)
	require.False(t, c.Is("html"))
	require.True(t, c.Is("txt"))
	require.True(t, c.Is(".txt"))

	// case-insensitive and trimmed
	c.Request().Header.Set(HeaderContentType, "APPLICATION/JSON; charset=utf-8")
	require.True(t, c.Is("json"))
	require.True(t, c.Is(".json"))

	// mismatched subtype should not match
	c.Request().Header.Set(HeaderContentType, "application/json+xml")
	require.False(t, c.Is("json"))
	require.False(t, c.Is(".json"))
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Is -benchmem -count=4
func Benchmark_Ctx_Is(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().Header.Set(HeaderContentType, MIMEApplicationJSON)
	var res bool
	b.ReportAllocs()
	for b.Loop() {
		_ = c.Is(".json")
		res = c.Is("json")
	}
	require.True(b, res)
}

// go test -run Test_Ctx_Locals
func Test_Ctx_Locals(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use(func(c Ctx) error {
		c.Locals("john", "doe")
		return c.Next()
	})
	app.Get("/test", func(c Ctx) error {
		require.Equal(t, "doe", c.Locals("john"))
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// go test -run Test_Ctx_Deadline
func Test_Ctx_Deadline(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use(func(c Ctx) error {
		return c.Next()
	})
	app.Get("/test", func(c Ctx) error {
		deadline, ok := c.Deadline()
		require.Equal(t, time.Time{}, deadline)
		require.False(t, ok)
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// go test -run Test_Ctx_Done
func Test_Ctx_Done(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use(func(c Ctx) error {
		return c.Next()
	})
	app.Get("/test", func(c Ctx) error {
		var nilChan <-chan struct{}
		require.Equal(t, nilChan, c.Done())
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// go test -run Test_Ctx_Err
func Test_Ctx_Err(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use(func(c Ctx) error {
		return c.Next()
	})
	app.Get("/test", func(c Ctx) error {
		require.NoError(t, c.Err())
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// go test -run Test_Ctx_Value
func Test_Ctx_Value(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use(func(c Ctx) error {
		c.Locals("john", "doe")
		return c.Next()
	})
	app.Get("/test", func(c Ctx) error {
		require.Equal(t, "doe", c.Value("john"))
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// go test -run Test_Ctx_Value_AfterRelease
func Test_Ctx_Value_AfterRelease(t *testing.T) {
	t.Parallel()
	app := New()
	var ctx Ctx
	app.Get("/test", func(c Ctx) error {
		ctx = c
		c.Locals("test", "value")
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	// After the handler completes, the context is released and fasthttp is nil
	// Value should return nil instead of panicking
	require.NotPanics(t, func() {
		val := ctx.Value("test")
		require.Nil(t, val)
	})
}

// go test -run Test_Ctx_Value_InGoroutine
func Test_Ctx_Value_InGoroutine(t *testing.T) {
	t.Parallel()
	app := New()
	done := make(chan bool, 1)   // Buffered to prevent goroutine leak
	errCh := make(chan error, 1) // Channel to communicate errors from goroutine

	// Use a synchronization point to avoid race detector complaints
	// while still testing the defensive nil behavior
	start := make(chan struct{})

	app.Get("/test", func(c Ctx) error {
		c.Locals("test", "value")

		// Simulate a goroutine that uses the context (like minio.GetObject)
		go func() {
			// Wait for handler to complete and context to be released
			<-start

			defer func() {
				if r := recover(); r != nil {
					errCh <- fmt.Errorf("panic in goroutine: %v", r)
					return
				}
				done <- true
			}()

			// This simulates what happens when minio or other libraries
			// use the fiber.Ctx as a context.Context in a goroutine
			// The Value method should not panic even if fasthttp is nil
			val := c.Value("test")
			// The value might be nil if the context was released
			_ = val
		}()

		return nil
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	// Signal goroutine to proceed - context has been released after app.Test returns
	// since the handler (and its deferred ReleaseCtx) has completed
	close(start)

	// Wait for goroutine to complete with timeout
	select {
	case <-done:
		// Success - goroutine completed without panic
	case err := <-errCh:
		t.Fatalf("error from goroutine: %v", err)
	case <-time.After(1 * time.Second):
		t.Fatal("test timed out waiting for goroutine")
	}
}

// go test -run Test_Ctx_Context
func Test_Ctx_Context(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	t.Run("Nil_Context", func(t *testing.T) {
		t.Parallel()
		ctx := c.Context()
		require.Equal(t, ctx, context.Background())
	})
	t.Run("ValueContext", func(t *testing.T) {
		t.Parallel()
		var testKey testContextKey
		testValue := "Test Value"
		ctx := context.WithValue(context.Background(), testKey, testValue)
		require.Equal(t, testValue, ctx.Value(testKey))
	})
}

func Test_Ctx_AccessAfterHandlerPanics(t *testing.T) {
	t.Parallel()
	app := New()
	var ctx Ctx
	app.Get("/test", func(c Ctx) error {
		ctx = c
		return nil
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	require.Panics(t, func() {
		ctx.Locals("foo")
	})
}

func Test_Ctx_Context_AfterHandlerPanics(t *testing.T) {
	t.Parallel()
	app := New()
	var ctx Ctx
	app.Get("/test", func(c Ctx) error {
		ctx = c
		return nil
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	// After the fix, Context() returns context.Background() instead of panicking
	require.NotPanics(t, func() {
		c := ctx.Context()
		require.NotNil(t, c)
		require.Equal(t, context.Background(), c)
	})
}

// go test -run Test_Ctx_Request_Response_AfterRelease
func Test_Ctx_Request_Response_AfterRelease(t *testing.T) {
	t.Parallel()
	app := New()
	var ctx Ctx
	app.Get("/test", func(c Ctx) error {
		ctx = c
		return nil
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	// After the handler completes and context is released,
	// Request() and Response() should return nil instead of panicking
	require.NotPanics(t, func() {
		req := ctx.Request()
		require.Nil(t, req)

		res := ctx.Response()
		require.Nil(t, res)
	})
}

// go test -run Test_Ctx_SetContext
func Test_Ctx_SetContext(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	var testKey testContextKey
	testValue := "Test Value"
	ctx := context.WithValue(context.Background(), testKey, testValue)
	c.SetContext(ctx)
	require.Equal(t, testValue, c.Context().Value(testKey))
}

type contextHelperTestKey struct{}

func Test_Ctx_StoreInContext_Config(t *testing.T) {
	t.Parallel()

	t.Run("disabled", func(t *testing.T) {
		t.Parallel()

		app := New()
		raw := &fasthttp.RequestCtx{}
		c := app.AcquireCtx(raw)
		defer app.ReleaseCtx(c)

		StoreInContext(c, contextHelperTestKey{}, "locals-only")

		value, ok := c.Locals(contextHelperTestKey{}).(string)
		require.True(t, ok)
		require.Equal(t, "locals-only", value)
		require.Nil(t, c.Context().Value(contextHelperTestKey{}))
	})

	t.Run("enabled", func(t *testing.T) {
		t.Parallel()

		app := New(Config{PassLocalsToContext: true})
		raw := &fasthttp.RequestCtx{}
		c := app.AcquireCtx(raw)
		defer app.ReleaseCtx(c)

		StoreInContext(c, contextHelperTestKey{}, "both")

		value, ok := c.Locals(contextHelperTestKey{}).(string)
		require.True(t, ok)
		require.Equal(t, "both", value)

		contextValue, ok := c.Context().Value(contextHelperTestKey{}).(string)
		require.True(t, ok)
		require.Equal(t, "both", contextValue)
	})
}

func Test_Ctx_ValueFromContext_Config(t *testing.T) {
	t.Parallel()

	t.Run("fiber ctx disabled reads locals", func(t *testing.T) {
		t.Parallel()

		app := New()
		raw := &fasthttp.RequestCtx{}
		c := app.AcquireCtx(raw)
		defer app.ReleaseCtx(c)

		c.Locals(contextHelperTestKey{}, "locals")
		c.SetContext(context.WithValue(context.Background(), contextHelperTestKey{}, "context"))

		value, ok := ValueFromContext[string](c, contextHelperTestKey{})
		require.True(t, ok)
		require.Equal(t, "locals", value)
	})

	t.Run("fiber ctx enabled still reads locals", func(t *testing.T) {
		t.Parallel()

		app := New(Config{PassLocalsToContext: true})
		raw := &fasthttp.RequestCtx{}
		c := app.AcquireCtx(raw)
		defer app.ReleaseCtx(c)

		c.Locals(contextHelperTestKey{}, "locals")
		c.SetContext(context.WithValue(context.Background(), contextHelperTestKey{}, "context"))

		value, ok := ValueFromContext[string](c, contextHelperTestKey{})
		require.True(t, ok)
		require.Equal(t, "locals", value)
	})

	t.Run("fiber custom ctx enabled still reads locals", func(t *testing.T) {
		t.Parallel()

		app := NewWithCustomCtx(func(app *App) CustomCtx {
			return &customCtx{DefaultCtx: *NewDefaultCtx(app)}
		}, Config{PassLocalsToContext: true})
		raw := &fasthttp.RequestCtx{}
		c := app.AcquireCtx(raw)
		defer app.ReleaseCtx(c)

		c.Locals(contextHelperTestKey{}, "locals")
		c.SetContext(context.WithValue(context.Background(), contextHelperTestKey{}, "context"))

		value, ok := ValueFromContext[string](c, contextHelperTestKey{})
		require.True(t, ok)
		require.Equal(t, "locals", value)
	})

	t.Run("fasthttp request ctx", func(t *testing.T) {
		t.Parallel()

		raw := &fasthttp.RequestCtx{}
		raw.SetUserValue(contextHelperTestKey{}, "value")

		value, ok := ValueFromContext[string](raw, contextHelperTestKey{})
		require.True(t, ok)
		require.Equal(t, "value", value)
	})

	t.Run("context.Context", func(t *testing.T) {
		t.Parallel()

		ctx := context.WithValue(context.Background(), contextHelperTestKey{}, "value")

		value, ok := ValueFromContext[string](ctx, contextHelperTestKey{})
		require.True(t, ok)
		require.Equal(t, "value", value)
	})
}

// go test -run Test_Ctx_Context_Multiple_Requests
func Test_Ctx_Context_Multiple_Requests(t *testing.T) {
	t.Parallel()
	var testKey testContextKey
	testValue := "foobar-value"

	app := New()
	app.Get("/", func(c Ctx) error {
		ctx := c.Context()

		if ctx.Value(testKey) != nil {
			return c.SendStatus(StatusInternalServerError)
		}

		input := utils.CopyString(Query(c, "input", "NO_VALUE"))
		ctx = context.WithValue(ctx, testKey, fmt.Sprintf("%s_%s", testValue, input))
		c.SetContext(ctx)

		return c.Status(StatusOK).SendString(fmt.Sprintf("resp_%s_returned", input))
	})

	// Consecutive Requests
	for i := 1; i <= 10; i++ {
		t.Run(fmt.Sprintf("request_%d", i), func(t *testing.T) {
			t.Parallel()
			resp, err := app.Test(httptest.NewRequest(MethodGet, fmt.Sprintf("/?input=%d", i), http.NoBody))

			require.NoError(t, err, "Unexpected error from response")
			require.Equal(t, StatusOK, resp.StatusCode, "context.Context returned from c.Context() is reused")

			b, err := io.ReadAll(resp.Body)
			require.NoError(t, err, "Unexpected error from reading response body")
			require.Equal(t, fmt.Sprintf("resp_%d_returned", i), string(b), "response text incorrect")
		})
	}
}

// go test -run Test_Ctx_Locals_Generic
func Test_Ctx_Locals_Generic(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use(func(c Ctx) error {
		Locals(c, "john", "doe")
		Locals(c, "age", 18)
		Locals(c, "isHuman", true)
		return c.Next()
	})
	app.Get("/test", func(c Ctx) error {
		require.Equal(t, "doe", Locals[string](c, "john"))
		require.Equal(t, 18, Locals[int](c, "age"))
		require.True(t, Locals[bool](c, "isHuman"))
		require.Equal(t, 0, Locals[int](c, "isHuman"))
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// go test -run Test_Ctx_Locals_GenericCustomStruct
func Test_Ctx_Locals_GenericCustomStruct(t *testing.T) {
	t.Parallel()

	type User struct {
		name string
		age  int
	}

	app := New()
	app.Use(func(c Ctx) error {
		Locals(c, "user", User{name: "john", age: 18})
		return c.Next()
	})
	app.Use("/test", func(c Ctx) error {
		require.Equal(t, User{name: "john", age: 18}, Locals[User](c, "user"))
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// go test -run Test_Ctx_Method
func Test_Ctx_Method(t *testing.T) {
	t.Parallel()
	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod(MethodGet)
	app := New()
	c := app.AcquireCtx(fctx)

	require.Equal(t, MethodGet, c.Method())
	c.Method(MethodPost)
	require.Equal(t, MethodPost, c.Method())

	c.Method("MethodInvalid")
	require.Equal(t, MethodPost, c.Method())
}

// go test -run Test_Ctx_ClientHelloInfo
func Test_Ctx_ClientHelloInfo(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/ServerName", func(c Ctx) error {
		result := c.ClientHelloInfo()
		if result != nil {
			return c.SendString(result.ServerName)
		}

		return c.SendString("ClientHelloInfo is nil")
	})
	app.Get("/SignatureSchemes", func(c Ctx) error {
		result := c.ClientHelloInfo()
		if result != nil {
			return c.JSON(result.SignatureSchemes)
		}

		return c.SendString("ClientHelloInfo is nil")
	})
	app.Get("/SupportedVersions", func(c Ctx) error {
		result := c.ClientHelloInfo()
		if result != nil {
			return c.JSON(result.SupportedVersions)
		}

		return c.SendString("ClientHelloInfo is nil")
	})

	// Test without TLS handler
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/ServerName", http.NoBody))
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, []byte("ClientHelloInfo is nil"), body)

	// Test with TLS Handler
	const (
		pssWithSHA256 = 0x0804
		versionTLS13  = 0x0304
	)
	app.tlsHandler = &TLSHandler{clientHelloInfo: &tls.ClientHelloInfo{
		ServerName:        "example.golang",
		SignatureSchemes:  []tls.SignatureScheme{pssWithSHA256},
		SupportedVersions: []uint16{versionTLS13},
	}}

	// Test ServerName
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/ServerName", http.NoBody))
	require.NoError(t, err)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, []byte("example.golang"), body)

	// Test SignatureSchemes
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/SignatureSchemes", http.NoBody))
	require.NoError(t, err)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "["+strconv.Itoa(pssWithSHA256)+"]", string(body))

	// Test SupportedVersions
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/SupportedVersions", http.NoBody))
	require.NoError(t, err)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "["+strconv.Itoa(versionTLS13)+"]", string(body))
}

// go test -run Test_Ctx_InvalidMethod
func Test_Ctx_InvalidMethod(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/", func(_ Ctx) error {
		return nil
	})

	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod("InvalidMethod")
	fctx.Request.SetRequestURI("/")

	app.Handler()(fctx)

	require.Equal(t, 501, fctx.Response.StatusCode())
	require.Equal(t, []byte("Not Implemented"), fctx.Response.Body())
}

// go test -run Test_Ctx_MultipartForm
func Test_Ctx_MultipartForm(t *testing.T) {
	t.Parallel()
	app := New()

	app.Post("/test", func(c Ctx) error {
		result, err := c.MultipartForm()
		require.NoError(t, err)
		require.Equal(t, "john", result.Value["name"][0])
		return nil
	})

	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)
	require.NoError(t, writer.WriteField("name", "john"))
	require.NoError(t, writer.Close())

	req := httptest.NewRequest(MethodPost, "/test", body)
	req.Header.Set(HeaderContentType, "multipart/form-data; boundary="+writer.Boundary())
	req.Header.Set(HeaderContentLength, strconv.Itoa(body.Len()))

	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// go test -v -run=^$ -bench=Benchmark_Ctx_MultipartForm -benchmem -count=4
func Benchmark_Ctx_MultipartForm(b *testing.B) {
	app := New()

	app.Post("/", func(c Ctx) error {
		_, err := c.MultipartForm()
		return err
	})

	c := &fasthttp.RequestCtx{}

	body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--")
	c.Request.SetBody(body)
	c.Request.Header.SetContentType(MIMEMultipartForm + `;boundary="b"`)
	c.Request.Header.SetContentLength(len(body))

	h := app.Handler()

	b.ReportAllocs()
	for b.Loop() {
		h(c)
	}
}

// go test -run Test_Ctx_OriginalURL
func Test_Ctx_OriginalURL(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Request().Header.SetRequestURI("http://google.com/test?search=demo")
	require.Equal(t, "http://google.com/test?search=demo", c.OriginalURL())
}

// go test -race -run Test_Ctx_Params
func Test_Ctx_Params(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/test/:user", func(c Ctx) error {
		require.Equal(t, "john", c.Params("user"))
		return nil
	})
	app.Get("/test2/*", func(c Ctx) error {
		require.Equal(t, "im/a/cookie", c.Params("*"))
		return nil
	})
	app.Get("/test3/*/blafasel/*", func(c Ctx) error {
		require.Equal(t, "1111", c.Params("*1"))
		require.Equal(t, 1111, Params(c, "*1", 0))
		require.Equal(t, "2222", c.Params("*2"))
		require.Equal(t, 2222, Params(c, "*2", 0))
		require.Equal(t, "1111", c.Params("*"))
		require.Equal(t, 1111, Params(c, "*", 0))
		return nil
	})
	app.Get("/test4/:optional?", func(c Ctx) error {
		require.Empty(t, c.Params("optional"))
		require.Equal(t, "default", Params(c, "optional", "default"))
		return nil
	})
	app.Get("/test5/:id/:Id", func(c Ctx) error {
		require.Equal(t, "first", c.Params("id"))
		require.Equal(t, "first", c.Params("Id"))
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test/john", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	resp, err = app.Test(httptest.NewRequest(MethodGet, "/test2/im/a/cookie", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	resp, err = app.Test(httptest.NewRequest(MethodGet, "/test3/1111/blafasel/2222", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	resp, err = app.Test(httptest.NewRequest(MethodGet, "/test4", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	resp, err = app.Test(httptest.NewRequest(MethodGet, "/test5/first/second", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

func Test_Ctx_Params_ErrorHandler_Panic_Issue_2832(t *testing.T) {
	t.Parallel()
	app := New(Config{
		ErrorHandler: func(c Ctx, _ error) error {
			return c.SendString(c.Params("user"))
		},
		BodyLimit: 1 * 1024,
	})

	app.Get("/test/:user", func(_ Ctx) error {
		return NewError(StatusInternalServerError, "error")
	})

	largeBody := make([]byte, 2*1024)
	_, err := app.Test(httptest.NewRequest(MethodGet, "/test/john", bytes.NewReader(largeBody)))
	require.ErrorIs(t, err, fasthttp.ErrBodyTooLarge, "app.Test(req)")
}

func Test_Ctx_Params_Case_Sensitive(t *testing.T) {
	t.Parallel()
	app := New(Config{CaseSensitive: true})
	app.Get("/test/:User", func(c Ctx) error {
		require.Equal(t, "john", c.Params("User"))
		require.Empty(t, c.Params("user"))
		return nil
	})
	app.Get("/test2/:id/:Id", func(c Ctx) error {
		require.Equal(t, "first", c.Params("id"))
		require.Equal(t, "second", c.Params("Id"))
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test/john", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	resp, err = app.Test(httptest.NewRequest(MethodGet, "/test2/first/second", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

func Test_Ctx_Params_Immutable(t *testing.T) {
	t.Parallel()
	app := New(Config{Immutable: true})
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	c.route = &Route{Params: []string{"user"}}
	c.path = []byte("/test/john")
	c.values[0] = c.app.toString(c.path[6:])

	param := c.Params("user")
	c.path[6] = 'p'
	c.path[7] = 'a'
	c.path[8] = 'u'
	c.path[9] = 'l'

	require.Equal(t, "john", param)
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Params -benchmem -count=4
func Benchmark_Ctx_Params(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	c.route = &Route{
		Params: []string{
			"param1", "param2", "param3", "param4",
		},
	}
	c.values = [maxParams]string{
		"john", "doe", "is", "awesome",
	}
	var res string
	b.ReportAllocs()
	for b.Loop() {
		_ = c.Params("param1")
		_ = c.Params("param2")
		_ = c.Params("param3")
		res = c.Params("param4")
	}
	require.Equal(b, "awesome", res)
}

// go test -run Test_Ctx_Path
func Test_Ctx_Path(t *testing.T) {
	t.Parallel()
	app := New(Config{UnescapePath: true})
	app.Get("/test/:user", func(c Ctx) error {
		require.Equal(t, "/Test/John", c.Path())
		require.Equal(t, "/Test/John", string(c.Request().URI().Path()))
		// not strict && case-insensitive
		require.Equal(t, "/ABC/", c.Path("/ABC/"))
		require.Equal(t, "/ABC/", string(c.Request().URI().Path()))
		require.Equal(t, "/test/john/", c.Path("/test/john/"))
		require.Equal(t, "/test/john/", string(c.Request().URI().Path()))
		return nil
	})

	// test with special chars
	app.Get("/specialChars/:name", func(c Ctx) error {
		require.Equal(t, "/specialChars/créer", c.Path())
		// unescape is also working if you set the path afterwards
		require.Equal(t, "/اختبار/", c.Path("/%D8%A7%D8%AE%D8%AA%D8%A8%D8%A7%D8%B1/"))
		require.Equal(t, "/اختبار/", string(c.Request().URI().Path()))
		return nil
	})

	// Exercise the case-insensitive /test/:user route; without this request
	// its assertions above would never run.
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/Test/John", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	resp, err = app.Test(httptest.NewRequest(MethodGet, "/specialChars/cr%C3%A9er", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// go test -run Test_Ctx_Protocol
func Test_Ctx_Protocol(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	require.Equal(t, "HTTP/1.1", c.Protocol())

	c.Request().Header.SetProtocol("HTTP/2")
	require.Equal(t, "HTTP/2", c.Protocol())
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Protocol -benchmem -count=4
func Benchmark_Ctx_Protocol(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	var res string
	b.ReportAllocs()
	for b.Loop() {
		res = c.Protocol()
	}
	require.Equal(b, "HTTP/1.1", res)
}

// go test -run Test_Ctx_Scheme
func Test_Ctx_Scheme(t *testing.T) {
	t.Parallel()
	app := New(Config{
		TrustProxy: true,
		TrustProxyConfig: TrustProxyConfig{
			Proxies: []string{"0.0.0.0"},
		},
	})

	freq := &fasthttp.RequestCtx{}
	freq.SetRemoteAddr(net.Addr(&net.TCPAddr{IP: net.ParseIP("0.0.0.0")}))
	freq.Request.Header.Set("X-Forwarded", "invalid")

	c := app.AcquireCtx(freq)

	c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
	require.Equal(t, schemeHTTPS, c.Scheme())
	c.Request().Header.Reset()

	c.Request().Header.Set(HeaderXForwardedProtocol, schemeHTTPS)
	require.Equal(t, schemeHTTPS, c.Scheme())
	c.Request().Header.Reset()

	c.Request().Header.Set(HeaderXForwardedProto, "https, http")
	require.Equal(t, schemeHTTPS, c.Scheme())
	c.Request().Header.Reset()

	c.Request().Header.Set(HeaderXForwardedProtocol, "https, http")
	require.Equal(t, schemeHTTPS, c.Scheme())
	c.Request().Header.Reset()

	c.Request().Header.Set(HeaderXForwardedSsl, "on")
	require.Equal(t, schemeHTTPS, c.Scheme())
	c.Request().Header.Reset()

	c.Request().Header.Set(HeaderXUrlScheme, schemeHTTPS)
	require.Equal(t, schemeHTTPS, c.Scheme())
	c.Request().Header.Reset()

	require.Equal(t, schemeHTTP, c.Scheme())
}

// go test -run Test_Ctx_Scheme_HeaderNormalization
func Test_Ctx_Scheme_HeaderNormalization(t *testing.T) {
	t.Parallel()
	app := New(Config{
		TrustProxy: true,
		TrustProxyConfig: TrustProxyConfig{
			Proxies: []string{"0.0.0.0"},
		},
	})

	freq := &fasthttp.RequestCtx{}
	freq.SetRemoteAddr(net.Addr(&net.TCPAddr{IP: net.ParseIP("0.0.0.0")}))

	c := app.AcquireCtx(freq)

	c.Request().Header.Set("x-forwarded-proto", " HTTPS , http")
	require.Equal(t, schemeHTTPS, c.Scheme())
	c.Request().Header.Reset()

	c.Request().Header.Set("X-FORWARDED-PROTOCOL", " HTTPS")
	require.Equal(t, schemeHTTPS, c.Scheme())
	c.Request().Header.Reset()

	c.Request().Header.Set("x-url-scheme", " HTTPS ")
	require.Equal(t, schemeHTTPS, c.Scheme())
	c.Request().Header.Reset()

	c.Request().Header.Set("x-Forwarded-ProToCol", " HTTPS ")
	require.Equal(t, schemeHTTPS, c.Scheme())
	c.Request().Header.Reset()
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Scheme -benchmem -count=4
func Benchmark_Ctx_Scheme(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	var res string
	b.ReportAllocs()
	for b.Loop() {
		res =
c.Scheme()
	}
	require.Equal(b, "http", res)
}

// go test -run Test_Ctx_Scheme_TrustedProxy
// Test_Ctx_Scheme_TrustedProxy checks that, with 0.0.0.0 configured as a
// trusted proxy, each forwarding header alone switches Scheme to HTTPS.
// NOTE(review): this presumably relies on a zero-value RequestCtx reporting
// 0.0.0.0 as its remote address. Confirm against fasthttp.
func Test_Ctx_Scheme_TrustedProxy(t *testing.T) {
	t.Parallel()
	app := New(Config{TrustProxy: true, TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.0.0.0"}}})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	for _, fwd := range [...]struct{ key, val string }{
		{HeaderXForwardedProto, schemeHTTPS},
		{HeaderXForwardedProtocol, schemeHTTPS},
		{HeaderXForwardedSsl, "on"},
		{HeaderXUrlScheme, schemeHTTPS},
	} {
		c.Request().Header.Set(fwd.key, fwd.val)
		require.Equal(t, schemeHTTPS, c.Scheme())
		c.Request().Header.Reset()
	}

	// Once the headers are cleared, the plain scheme is reported again.
	require.Equal(t, schemeHTTP, c.Scheme())
}

// go test -run Test_Ctx_Scheme_TrustedProxyRange
// Test_Ctx_Scheme_TrustedProxyRange is the CIDR variant: the trusted proxy
// is given as the 0.0.0.0/30 range instead of a single address.
func Test_Ctx_Scheme_TrustedProxyRange(t *testing.T) {
	t.Parallel()
	app := New(Config{
		TrustProxy:       true,
		TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.0.0.0/30"}},
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	for _, fwd := range [...]struct{ key, val string }{
		{HeaderXForwardedProto, schemeHTTPS},
		{HeaderXForwardedProtocol, schemeHTTPS},
		{HeaderXForwardedSsl, "on"},
		{HeaderXUrlScheme, schemeHTTPS},
	} {
		c.Request().Header.Set(fwd.key, fwd.val)
		require.Equal(t, schemeHTTPS, c.Scheme())
		c.Request().Header.Reset()
	}

	require.Equal(t, schemeHTTP, c.Scheme())
}

// go test -run Test_Ctx_Scheme_UntrustedProxyRange
func Test_Ctx_Scheme_UntrustedProxyRange(t *testing.T) {
	t.Parallel()
	app := New(Config{
		TrustProxy:       true,
		TrustProxyConfig: TrustProxyConfig{Proxies: []string{"1.1.1.1/30"}},
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().Header.Set(HeaderXForwardedProto, schemeHTTPS)
require.Equal(t, schemeHTTP, c.Scheme())
	c.Request().Header.Reset()

	// The remaining forwarding headers are expected to be ignored too, because
	// the only trusted range configured here is 1.1.1.1/30.
	for _, fwd := range [...]struct{ key, val string }{
		{HeaderXForwardedProtocol, schemeHTTPS},
		{HeaderXForwardedSsl, "on"},
		{HeaderXUrlScheme, schemeHTTPS},
	} {
		c.Request().Header.Set(fwd.key, fwd.val)
		require.Equal(t, schemeHTTP, c.Scheme())
		c.Request().Header.Reset()
	}

	require.Equal(t, schemeHTTP, c.Scheme())
}

// go test -run Test_Ctx_Scheme_UnTrustedProxy
// Test_Ctx_Scheme_UnTrustedProxy checks that forwarding headers are ignored
// when the only trusted proxy (0.8.0.1) does not match the request.
func Test_Ctx_Scheme_UnTrustedProxy(t *testing.T) {
	t.Parallel()
	app := New(Config{
		TrustProxy:       true,
		TrustProxyConfig: TrustProxyConfig{Proxies: []string{"0.8.0.1"}},
	})
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	for _, fwd := range [...]struct{ key, val string }{
		{HeaderXForwardedProto, schemeHTTPS},
		{HeaderXForwardedProtocol, schemeHTTPS},
		{HeaderXForwardedSsl, "on"},
		{HeaderXUrlScheme, schemeHTTPS},
	} {
		c.Request().Header.Set(fwd.key, fwd.val)
		require.Equal(t, schemeHTTP, c.Scheme())
		c.Request().Header.Reset()
	}

	require.Equal(t, schemeHTTP, c.Scheme())
}

// go test -run Test_Ctx_Query
// Test_Ctx_Query covers Query lookups with and without a default value, via
// both the Ctx method and the generic Query helper.
func Test_Ctx_Query(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().URI().SetQueryString("search=john&age=20")
	require.Equal(t, "john", c.Query("search"))
	require.Equal(t, "20", c.Query("age"))
	require.Equal(t, "default", c.Query("unknown", "default"))

	// test with generic
	require.Equal(t, "john", Query[string](c, "search"))
	require.Equal(t, "20", Query[string](c, "age"))
	require.Equal(t, "default", Query(c, "unknown", "default"))
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Query -benchmem -count=4
func Benchmark_Ctx_Query(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().URI().SetQueryString("search=john&age=8") var res string b.ReportAllocs() for b.Loop() { res = Query[string](c, "search") } require.Equal(b, "john", res) } // go test -run Test_Ctx_Range func Test_Ctx_Range(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) testRange := func(header string, ranges ...RangeSet) { c.Request().Header.Set(HeaderRange, header) result, err := c.Range(1000) if len(ranges) == 0 { require.Error(t, err) } else { require.Equal(t, "bytes", result.Type) require.NoError(t, err) } require.Len(t, ranges, len(result.Ranges)) for i := range ranges { require.Equal(t, ranges[i], result.Ranges[i]) } } testRange("bytes=500") testRange("bytes=") testRange("bytes=500=") testRange("bytes=500-300") testRange("bytes=a-700", RangeSet{Start: 300, End: 999}) testRange("bytes=500-b", RangeSet{Start: 500, End: 999}) testRange("bytes=500-1000", RangeSet{Start: 500, End: 999}) testRange("bytes=500-700", RangeSet{Start: 500, End: 700}) testRange("bytes=0-0,2-1000", RangeSet{Start: 0, End: 0}, RangeSet{Start: 2, End: 999}) testRange("bytes=0-99,450-549,-100", RangeSet{Start: 0, End: 99}, RangeSet{Start: 450, End: 549}, RangeSet{Start: 900, End: 999}) testRange("bytes=500-700,601-999", RangeSet{Start: 500, End: 700}, RangeSet{Start: 601, End: 999}) testRange("bytes= 0-1", RangeSet{Start: 0, End: 1}) testRange("seconds=0-1") } func Test_Ctx_Range_LargeFile(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) size := int64(math.MaxInt32) + 1024 start := int64(math.MaxInt32) + 10 end := start + 50 c.Request().Header.Set(HeaderRange, fmt.Sprintf("bytes=%d-%d", start, end)) result, err := c.Range(size) require.NoError(t, err) require.Equal(t, "bytes", result.Type) require.Len(t, result.Ranges, 1) require.Equal(t, start, result.Ranges[0].Start) require.Equal(t, end, result.Ranges[0].End) c.Request().Header.Set(HeaderRange, "bytes=-200") result, err = c.Range(size) 
require.NoError(t, err)
	require.Equal(t, size-200, result.Ranges[0].Start)
	require.Equal(t, size-1, result.Ranges[0].End)
}

// Test_Ctx_Range_Overflow verifies that range bounds too large for an int64
// (2^63 here), whether in the start or the end position, are rejected with
// ErrRangeMalformed.
func Test_Ctx_Range_Overflow(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	// MaxUint64>>1 equals MaxInt64, so +1 is the smallest value above it.
	tooBig := uint64((math.MaxUint64 >> 1) + 1)
	c.Request().Header.Set(HeaderRange, fmt.Sprintf("bytes=%d-100", tooBig))
	_, err := c.Range(math.MaxInt64)
	require.ErrorIs(t, err, ErrRangeMalformed)

	c.Request().Header.Set(HeaderRange, fmt.Sprintf("bytes=0-%d", tooBig))
	_, err = c.Range(math.MaxInt64)
	require.ErrorIs(t, err, ErrRangeMalformed)
}

// Test_Ctx_Range_Unsatisfiable verifies that when a handler returns the error
// from Range for a range entirely past a 10-byte resource, the response is
// 416 with a "bytes */10" Content-Range header.
func Test_Ctx_Range_Unsatisfiable(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/", func(c Ctx) error {
		_, err := c.Range(10)
		if err != nil {
			return err
		}
		return c.SendString("ok")
	})

	req := httptest.NewRequest(MethodGet, "http://example.com/", http.NoBody)
	req.Header.Set(HeaderRange, "bytes=20-30")
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, StatusRequestedRangeNotSatisfiable, resp.StatusCode)
	require.Equal(t, "bytes */10", resp.Header.Get(HeaderContentRange))
}

// Test_Ctx_Range_TooManyRanges verifies that with MaxRanges set to 2, a header
// listing three ranges yields the same 416 / "bytes */10" response.
func Test_Ctx_Range_TooManyRanges(t *testing.T) {
	t.Parallel()
	app := New(Config{MaxRanges: 2})
	app.Get("/", func(c Ctx) error {
		_, err := c.Range(10)
		if err != nil {
			return err
		}
		return c.SendString("ok")
	})

	req := httptest.NewRequest(MethodGet, "http://example.com/", http.NoBody)
	req.Header.Set(HeaderRange, "bytes=0-1,2-3,4-5")
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, StatusRequestedRangeNotSatisfiable, resp.StatusCode)
	require.Equal(t, "bytes */10", resp.Header.Get(HeaderContentRange))
}

// Test_Ctx_Range_SuffixNormalization serves a 123-byte body through a handler
// that honours a single range. It checks how suffix ranges ("bytes=-N") are
// resolved, including suffixes equal to or larger than the body.
func Test_Ctx_Range_SuffixNormalization(t *testing.T) {
	t.Parallel()
	body := bytes.Repeat([]byte("x"), 123)
	// newApp builds a fresh app; each subtest below calls it for its own instance.
	newApp := func() *App {
		app := New()
		app.Get("/", func(c Ctx) error {
			rangesHeader := c.Get(HeaderRange)
			if rangesHeader == "" {
				return c.Send(body)
			}
			rangeData, err := c.Range(int64(len(body)))
			if err != nil {
				return err
			}
			// Only a single range is served; anything else is answered as unsatisfiable.
			if len(rangeData.Ranges) != 1 {
c.Status(StatusRequestedRangeNotSatisfiable)
				c.Set(HeaderContentRange, fmt.Sprintf("bytes */%d", len(body)))
				return ErrRequestedRangeNotSatisfiable
			}
			currentRange := rangeData.Ranges[0]
			contentRange := fmt.Sprintf("bytes %d-%d/%d", currentRange.Start, currentRange.End, len(body))
			c.Set(HeaderContentRange, contentRange)
			// A range that covers the whole body is answered with 200 rather than 206.
			statusCode := StatusPartialContent
			if currentRange.Start == 0 && currentRange.End == int64(len(body))-1 {
				statusCode = StatusOK
			}
			c.Status(statusCode)
			// Range ends are inclusive, hence End+1 as the exclusive slice bound.
			return c.Send(body[currentRange.Start : currentRange.End+1])
		})
		return app
	}

	testCases := []struct {
		name             string
		rangeHeader      string
		contentRange     string
		statusCode       int
		expectedBodySize int
	}{
		{
			name:             "suffix less than size",
			rangeHeader:      "bytes=-20",
			contentRange:     "bytes 103-122/123",
			statusCode:       StatusPartialContent,
			expectedBodySize: 20,
		},
		{
			name:             "suffix equal to size",
			rangeHeader:      "bytes=-123",
			contentRange:     "bytes 0-122/123",
			statusCode:       StatusOK,
			expectedBodySize: 123,
		},
		{
			name:             "suffix larger than size",
			rangeHeader:      "bytes=-9999",
			contentRange:     "bytes 0-122/123",
			statusCode:       StatusOK,
			expectedBodySize: 123,
		},
		{
			name:         "unsatisfiable mixed ranges",
			rangeHeader:  "bytes=200-400,700-1200",
			contentRange: "bytes */123",
			statusCode:   StatusRequestedRangeNotSatisfiable,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := newApp()
			req := httptest.NewRequest(MethodGet, "http://example.com/", http.NoBody)
			if tc.rangeHeader != "" {
				req.Header.Set(HeaderRange, tc.rangeHeader)
			}
			resp, err := app.Test(req)
			require.NoError(t, err)
			t.Cleanup(func() {
				require.NoError(t, resp.Body.Close())
			})
			require.Equal(t, tc.statusCode, resp.StatusCode)
			require.Equal(t, tc.contentRange, resp.Header.Get(HeaderContentRange))
			// A zero expectedBodySize (the unsatisfiable case) skips the body check.
			if tc.expectedBodySize > 0 {
				bodyBytes, bodyErr := io.ReadAll(resp.Body)
				require.NoError(t, bodyErr)
				require.Len(t, bodyBytes, tc.expectedBodySize)
			}
		})
	}
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Range -benchmem -count=4
func Benchmark_Ctx_Range(b *testing.B) {
app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	// Each header is parsed against a 1000-byte resource; only the first
	// resulting range is checked.
	testCases := []struct {
		str   string
		start int64
		end   int64
	}{
		{str: "bytes=-700", start: 300, end: 999},
		{str: "bytes=500-", start: 500, end: 999},
		{str: "bytes=500-1000", start: 500, end: 999},
		{str: "bytes=0-700,800-1000", start: 0, end: 700},
	}

	for _, tc := range testCases {
		b.Run(tc.str, func(b *testing.B) {
			c.Request().Header.Set(HeaderRange, tc.str)
			var (
				result Range
				err    error
			)
			for b.Loop() {
				result, err = c.Range(1000)
			}
			require.NoError(b, err)
			require.Equal(b, "bytes", result.Type)
			require.Equal(b, tc.start, result.Ranges[0].Start)
			require.Equal(b, tc.end, result.Ranges[0].End)
		})
	}
}

// go test -run Test_Ctx_Route
// Test_Ctx_Route checks Route() inside a matched handler. It also checks a
// context that was never routed, which reports Path "/", Method GET and no
// handlers.
func Test_Ctx_Route(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/test", func(c Ctx) error {
		require.Equal(t, "/test", c.Route().Path)
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	require.Equal(t, "/", c.Route().Path)
	require.Equal(t, MethodGet, c.Route().Method)
	require.Empty(t, c.Route().Handlers)
}

// go test -run Test_Ctx_FullPath
// Test_Ctx_FullPath checks that FullPath reports the registered route path.
func Test_Ctx_FullPath(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/test", func(c Ctx) error {
		require.Equal(t, "/test", c.FullPath())
		return c.SendStatus(StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	defer func() {
		require.NoError(t, resp.Body.Close())
	}()
	require.Equal(t, StatusOK, resp.StatusCode)
}

// go test -run Test_Ctx_FullPath_Group
// Test_Ctx_FullPath_Group checks that FullPath includes the group prefix.
func Test_Ctx_FullPath_Group(t *testing.T) {
	t.Parallel()
	app := New()
	group := app.Group("/v1")
	group.Get("/test", func(c Ctx) error {
		require.Equal(t, "/v1/test", c.FullPath())
		return c.SendStatus(StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	defer func() {
		require.NoError(t, resp.Body.Close())
	}()
	require.Equal(t, StatusOK, resp.StatusCode)
}

// go test -run Test_Ctx_FullPath_Middleware
// Test_Ctx_FullPath_Middleware records FullPath() in a global middleware
// before and after c.Next(). It must read "/" before the route handler runs
// and "/test" after the matched route has executed.
func Test_Ctx_FullPath_Middleware(t *testing.T) {
	t.Parallel()
	app := New()
	var recorded []string

	app.Use(func(c Ctx) error {
		recorded = append(recorded, c.FullPath())
		if err := c.Next(); err != nil {
			return err
		}
		recorded = append(recorded, c.FullPath())
		return nil
	})

	app.Get("/test", func(c Ctx) error {
		require.Equal(t, "/test", c.FullPath())
		return c.SendStatus(StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	defer func() {
		require.NoError(t, resp.Body.Close())
	}()
	require.Equal(t, StatusOK, resp.StatusCode)
	require.Equal(t, []string{"/", "/test"}, recorded)
}

// go test -run Test_Ctx_RouteNormalized
// Test_Ctx_RouteNormalized checks that "//test" is not collapsed onto "/test".
// The request gets a 404, so the handler (and its assertion) is not expected
// to run.
func Test_Ctx_RouteNormalized(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/test", func(c Ctx) error {
		require.Equal(t, "/test", c.Route().Path)
		return nil
	})
	resp, err := app.Test(httptest.NewRequest(MethodGet, "//test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusNotFound, resp.StatusCode, "Status code")
}

// go test -run Test_Ctx_SaveFile
// Test_Ctx_SaveFile uploads a multipart file and, in the handler, saves it with
// SaveFile to a temporary file, reads it back and removes it.
func Test_Ctx_SaveFile(t *testing.T) {
	// TODO We should clean this up
	t.Parallel()
	app := New()

	app.Post("/test", func(c Ctx) error {
		fh, err := c.Req().FormFile("file")
		require.NoError(t, err)

		tempFile, err := os.CreateTemp(os.TempDir(), "test-")
		require.NoError(t, err)

		// Close and delete the temporary file once the handler returns.
		defer func(file *os.File) {
			closeErr := file.Close()
			require.NoError(t, closeErr)
			closeErr = os.Remove(file.Name())
			require.NoError(t, closeErr)
		}(tempFile)
		err = c.SaveFile(fh, tempFile.Name())
		require.NoError(t, err)

		bs, err := os.ReadFile(tempFile.Name())
		require.NoError(t, err)
		require.Equal(t, "hello world", string(bs))
		return nil
	})

	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	ioWriter, err := writer.CreateFormFile("file", "test")
	require.NoError(t, err)

	_,
err = ioWriter.Write([]byte("hello world"))
	require.NoError(t, err)
	require.NoError(t, writer.Close())

	req := httptest.NewRequest(MethodPost, "/test", body)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	req.Header.Set("Content-Length", strconv.Itoa(body.Len()))

	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// createMultipartFileHeader encodes data as a single multipart form field
// named "file" with the given filename. It parses the message back with
// multipart.Reader and returns the resulting *multipart.FileHeader. Any
// temporary files ReadForm created are removed via t.Cleanup.
func createMultipartFileHeader(t *testing.T, filename string, data []byte) *multipart.FileHeader {
	t.Helper()

	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	ioWriter, err := writer.CreateFormFile("file", filename)
	require.NoError(t, err)

	_, err = ioWriter.Write(data)
	require.NoError(t, err)

	require.NoError(t, writer.Close())

	multipartReader := multipart.NewReader(bytes.NewReader(body.Bytes()), writer.Boundary())
	// maxMemory is the full encoded size, so the file part should fit in memory.
	form, err := multipartReader.ReadForm(int64(body.Len()))
	require.NoError(t, err)
	t.Cleanup(func() {
		require.NoError(t, form.RemoveAll())
	})

	files := form.File["file"]
	require.Len(t, files, 1)

	return files[0]
}

// go test -run Test_Ctx_SaveFileToStorage
// Test_Ctx_SaveFileToStorage uploads a multipart file. Inside the handler it
// saves the file to an in-memory storage, reads it back and deletes it.
func Test_Ctx_SaveFileToStorage(t *testing.T) {
	t.Parallel()
	app := New()
	storage := memory.New()

	app.Post("/test", func(c Ctx) error {
		fh, err := c.FormFile("file")
		require.NoError(t, err)

		err = c.SaveFileToStorage(fh, "test", storage)
		require.NoError(t, err)

		// Check the lookup error before the value, so a failed Get is reported
		// as an error instead of as a content mismatch against a nil slice.
		file, err := storage.Get("test")
		require.NoError(t, err)
		require.Equal(t, []byte("hello world"), file)

		err = storage.Delete("test")
		require.NoError(t, err)

		return nil
	})

	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	ioWriter, err := writer.CreateFormFile("file", "test")
	require.NoError(t, err)

	_, err = ioWriter.Write([]byte("hello world"))
	require.NoError(t, err)
	require.NoError(t, writer.Close())

	req := httptest.NewRequest(MethodPost, "/test", body)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	req.Header.Set("Content-Length", strconv.Itoa(body.Len()))

	resp, err := app.Test(req)
require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// Test_Ctx_SaveFileToStorage_LargeUpload saves a 5 MiB file under an 8 MiB
// BodyLimit and checks that the full payload reaches storage.
func Test_Ctx_SaveFileToStorage_LargeUpload(t *testing.T) {
	t.Parallel()

	const (
		bodyLimit = 8 * 1024 * 1024
		fileSize  = 5 * 1024 * 1024
	)

	app := New(Config{BodyLimit: bodyLimit})
	storage := memory.New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	t.Cleanup(func() {
		app.ReleaseCtx(ctx)
	})

	fileHeader := createMultipartFileHeader(t, "large.bin", bytes.Repeat([]byte{'a'}, fileSize))

	err := ctx.SaveFileToStorage(fileHeader, "test", storage)
	require.NoError(t, err)

	stored, err := storage.Get("test")
	require.NoError(t, err)
	require.Len(t, stored, fileSize)
}

// Test_Ctx_SaveFileToStorage_LimitExceeded checks that a file larger than
// BodyLimit is rejected with fasthttp.ErrBodyTooLarge.
func Test_Ctx_SaveFileToStorage_LimitExceeded(t *testing.T) {
	t.Parallel()

	const (
		allowedSize = 1024
		fileSize    = allowedSize + 512
	)

	app := New(Config{BodyLimit: allowedSize})
	storage := memory.New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	t.Cleanup(func() {
		app.ReleaseCtx(ctx)
	})

	fileHeader := createMultipartFileHeader(t, "too-large.bin", bytes.Repeat([]byte{'a'}, fileSize))

	err := ctx.SaveFileToStorage(fileHeader, "test", storage)
	require.ErrorIs(t, err, fasthttp.ErrBodyTooLarge)
}

// Test_Ctx_SaveFileToStorage_LimitExceededUnknownSize is the same check with
// the declared FileHeader.Size set to -1. The limit must still be enforced
// when the size is not known up front.
func Test_Ctx_SaveFileToStorage_LimitExceededUnknownSize(t *testing.T) {
	t.Parallel()

	const (
		allowedSize = 1024
		fileSize    = allowedSize + 256
	)

	app := New(Config{BodyLimit: allowedSize})
	storage := memory.New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	t.Cleanup(func() {
		app.ReleaseCtx(ctx)
	})

	fileHeader := createMultipartFileHeader(t, "unknown-size.bin", bytes.Repeat([]byte{'a'}, fileSize))
	fileHeader.Size = -1

	err := ctx.SaveFileToStorage(fileHeader, "test", storage)
	require.ErrorIs(t, err, fasthttp.ErrBodyTooLarge)
}

// captureStorage is a storage test double. Values written through
// SetWithContext are kept by reference in data. Get, Set, Delete and their
// context variants other than SetWithContext fail the test.
type captureStorage struct {
	t    *testing.T
	data map[string][]byte
}

// helperFailure aborts the test with a formatted message.
func (s *captureStorage) helperFailure(msg string, args ...any) {
	s.t.Helper()
	s.t.Fatalf(msg, args...)
} func (s *captureStorage) ensureStore(key string, val []byte) { s.t.Helper() if key == "" || len(val) == 0 { return } if s.data == nil { s.data = make(map[string][]byte) } s.data[key] = val } func (s *captureStorage) GetWithContext(context.Context, string) ([]byte, error) { s.helperFailure("unexpected call to GetWithContext") return nil, nil } func (s *captureStorage) Get(string) ([]byte, error) { s.helperFailure("unexpected call to Get") return nil, nil } func (s *captureStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error { s.ensureStore(key, val) return nil } func (s *captureStorage) Set(key string, _ []byte, _ time.Duration) error { s.helperFailure("unexpected call to Set for key %q", key) return nil } func (s *captureStorage) DeleteWithContext(context.Context, string) error { s.helperFailure("unexpected call to DeleteWithContext") return nil } func (s *captureStorage) Delete(string) error { s.helperFailure("unexpected call to Delete") return nil } func (s *captureStorage) ResetWithContext(context.Context) error { s.data = nil return nil } func (s *captureStorage) Reset() error { s.data = nil return nil } func (s *captureStorage) Close() error { if s == nil { return nil } s.data = nil return nil } func Test_Ctx_SaveFileToStorage_BufferNotReused(t *testing.T) { t.Parallel() app := New() storage := &captureStorage{t: t} ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) const payloadSize = 1024 firstPayload := bytes.Repeat([]byte{'a'}, payloadSize) secondPayload := bytes.Repeat([]byte{'b'}, payloadSize) firstHeader := createMultipartFileHeader(t, "first.bin", firstPayload) require.NoError(t, ctx.SaveFileToStorage(firstHeader, "first", storage)) firstStored := storage.data["first"] require.Equal(t, firstPayload, firstStored) secondHeader := createMultipartFileHeader(t, "second.bin", secondPayload) require.NoError(t, ctx.SaveFileToStorage(secondHeader, "second", storage)) require.Equal(t, 
secondPayload, storage.data["second"])
	require.Equal(t, firstPayload, firstStored, "stored data must not rely on pooled buffers")
}

// mockContextAwareStorage is a storage test double that accepts only
// SetWithContext. On that call it:
//   - runs validateCtx on the received context;
//   - checks ctx.Value(key) against expectedValue;
//   - invokes cancel;
//   - waits up to 100ms for the received context to report Done.
//
// ctxMatched and cancelObserved record which steps succeeded.
type mockContextAwareStorage struct {
	t              *testing.T
	key            any
	expectedValue  any
	validateCtx    func(context.Context)
	cancel         context.CancelFunc
	ctxMatched     atomic.Bool
	cancelObserved atomic.Bool
}

// helperFailure aborts the test with a formatted message.
func (s *mockContextAwareStorage) helperFailure(msg string, args ...any) {
	s.t.Helper()
	s.t.Fatalf(msg, args...)
}

func (s *mockContextAwareStorage) GetWithContext(context.Context, string) ([]byte, error) {
	s.helperFailure("unexpected call to GetWithContext")
	return nil, nil
}

func (s *mockContextAwareStorage) Get(string) ([]byte, error) {
	s.helperFailure("unexpected call to Get")
	return nil, nil
}

// SetWithContext performs the context checks described on the type.
func (s *mockContextAwareStorage) SetWithContext(ctx context.Context, _ string, _ []byte, _ time.Duration) error {
	s.t.Helper()
	// helperFailure calls Fatalf, which stops this goroutine, so a nil
	// validateCtx is never invoked below.
	if s.validateCtx == nil {
		s.helperFailure("validateCtx must be configured before SetWithContext")
	}
	s.validateCtx(ctx)
	if val := ctx.Value(s.key); val != s.expectedValue {
		s.helperFailure("storage observed unexpected context value: %v", val)
	}
	s.ctxMatched.Store(true)

	// Cancel the caller's context, then require the context we were handed to
	// observe that cancellation within a short deadline.
	if s.cancel != nil {
		s.cancel()
	}

	select {
	case <-ctx.Done():
		s.cancelObserved.Store(true)
	case <-time.After(100 * time.Millisecond):
		s.helperFailure("storage did not observe context cancellation")
	}
	return nil
}

func (s *mockContextAwareStorage) Set(string, []byte, time.Duration) error {
	s.helperFailure("unexpected call to Set")
	return nil
}

func (s *mockContextAwareStorage) DeleteWithContext(context.Context, string) error {
	s.helperFailure("unexpected call to DeleteWithContext")
	return nil
}

func (s *mockContextAwareStorage) Delete(string) error {
	s.helperFailure("unexpected call to Delete")
	return nil
}

func (s *mockContextAwareStorage) ResetWithContext(context.Context) error {
	s.helperFailure("unexpected call to ResetWithContext")
	return nil
}

func (s *mockContextAwareStorage) Reset() error {
	s.helperFailure("unexpected call to Reset")
	return nil
}

// Close does nothing and always returns nil.
func (s
*mockContextAwareStorage) Close() error {
	if s == nil {
		return nil
	}
	return nil
}

// go test -run Test_Ctx_SaveFileToStorage_ContextPropagation
// Test_Ctx_SaveFileToStorage_ContextPropagation installs a cancellable,
// value-carrying context on the Ctx via SetContext. It checks that
// SaveFileToStorage hands exactly that context to storage.SetWithContext.
func Test_Ctx_SaveFileToStorage_ContextPropagation(t *testing.T) {
	t.Parallel()

	type ctxKeyType string
	const ctxKey ctxKeyType = "storage-context-key"

	storage := &mockContextAwareStorage{t: t, key: ctxKey, expectedValue: "expected-context-value"}
	app := New()

	app.Post("/test", func(c Ctx) error {
		fh, err := c.FormFile("file")
		require.NoError(t, err)

		ctxWithValue := context.WithValue(context.Background(), ctxKey, storage.expectedValue)
		ctx, cancel := context.WithCancel(ctxWithValue)
		// Identity comparison: storage must receive this exact context value,
		// not one derived from it.
		storage.validateCtx = func(received context.Context) {
			if received != ctx {
				storage.helperFailure("storage received unexpected context instance")
			}
		}
		storage.cancel = cancel

		c.SetContext(ctx)

		err = c.SaveFileToStorage(fh, "test", storage)
		require.NoError(t, err)
		require.True(t, storage.ctxMatched.Load(), "storage should receive the context installed on Ctx")
		require.True(t, storage.cancelObserved.Load(), "storage should observe context cancellation")

		return nil
	})

	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	ioWriter, err := writer.CreateFormFile("file", "test")
	require.NoError(t, err)

	_, err = ioWriter.Write([]byte("hello world"))
	require.NoError(t, err)
	require.NoError(t, writer.Close())

	req := httptest.NewRequest(MethodPost, "/test", body)
	req.Header.Set("Content-Type", writer.FormDataContentType())
	req.Header.Set("Content-Length", strconv.Itoa(len(body.Bytes())))

	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
}

// go test -run Test_Ctx_Secure
// Test_Ctx_Secure checks that a context without a TLS connection is not secure.
func Test_Ctx_Secure(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	// TODO Add TLS conn
	require.False(t, c.Secure())
}

// go test -run Test_Ctx_Stale
// Test_Ctx_Stale checks that a fresh context with no conditional headers is stale.
func Test_Ctx_Stale(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	require.True(t,
c.Stale()) } // go test -run Test_Ctx_Subdomains func Test_Ctx_Subdomains(t *testing.T) { app := New() type tc struct { name string host string offset []int // nil ⇒ call without argument want []string } cases := []tc{ { name: "default offset (2) drops registrable domain + TLD", host: "john.doe.is.awesome.google.com", offset: nil, // Subdomains() want: []string{"john", "doe", "is", "awesome"}, }, { name: "custom offset trims N right-hand labels", host: "john.doe.is.awesome.google.com", offset: []int{4}, want: []string{"john", "doe"}, }, { name: "offset too high returns empty", host: "john.doe.is.awesome.google.com", offset: []int{10}, want: []string{}, }, { name: "zero offset returns all labels", host: "john.doe.google.com", offset: []int{0}, want: []string{"john", "doe", "google", "com"}, }, { name: "offset 1 keeps registrable domain", host: "john.doe.google.com", offset: []int{1}, want: []string{"john", "doe", "google"}, }, { name: "negative offset returns empty", host: "john.doe.google.com", offset: []int{-1}, want: []string{}, }, { name: "offset equal len returns empty", host: "john.doe.com", offset: []int{3}, want: []string{}, }, { name: "offset equal len returns empty", host: "john.doe.com", offset: []int{3}, want: []string{}, }, { name: "zero offset returns all labels with port present", host: "localhost:3000", offset: []int{0}, want: []string{"localhost"}, }, { name: "host with port — custom offset trims 2 labels", host: "foo.bar.example.com:8080", offset: []int{2}, want: []string{"foo", "bar"}, }, { name: "fully qualified domain trims trailing dot", host: "john.doe.example.com.", offset: nil, want: []string{"john", "doe"}, }, { name: "punycode domain is decoded", host: "xn--bcher-kva.example.com", offset: nil, want: []string{"bücher"}, }, { name: "punycode fqdn is decoded", host: "xn--bcher-kva.example.com.", offset: nil, want: []string{"bücher"}, }, { name: "punycode decode failure uses fallback", host: "xn--bcher--.example.com", offset: nil, want: 
[]string{"xn--bcher--"},
		},
		{
			name:   "invalid host keeps original lowercased",
			host:   "Foo Bar",
			offset: []int{0},
			want:   []string{"foo bar"},
		},
		{
			name:   "IPv4 host returns empty",
			host:   "192.168.0.1",
			offset: nil,
			want:   []string{},
		},
		{
			name:   "IPv6 host returns empty",
			host:   "[2001:db8::1]",
			offset: nil,
			want:   []string{},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Each case uses its own context, released when the subtest ends.
			c := app.AcquireCtx(&fasthttp.RequestCtx{})
			defer app.ReleaseCtx(c)

			c.Request().URI().SetHost(tc.host)
			got := c.Subdomains(tc.offset...)
			require.Equal(t, tc.want, got)
		})
	}
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Subdomains -benchmem -count=4
func Benchmark_Ctx_Subdomains(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().SetRequestURI("http://john.doe.google.com")
	var res []string
	b.ReportAllocs()
	for b.Loop() {
		res = c.Subdomains()
	}
	require.Equal(b, []string{"john", "doe"}, res)
}

// go test -run Test_Ctx_ClearCookie
// Test_Ctx_ClearCookie checks that ClearCookie with a name expires that
// cookie. It also checks that ClearCookie without arguments expires the
// request's cookies.
func Test_Ctx_ClearCookie(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().Header.Set(HeaderCookie, "john=doe")
	c.Res().ClearCookie("john")
	require.True(t, strings.HasPrefix(string(c.Response().Header.Peek(HeaderSetCookie)), "john=; expires="))

	// NOTE(review): the assertions below require both test1 and test2 to be
	// cleared, so the second Set on the Cookie header must add test2 alongside
	// test1 instead of replacing it. That depends on how fasthttp handles Set
	// for request cookies; confirm against the fasthttp version in use.
	c.Request().Header.Set(HeaderCookie, "test1=dummy")
	c.Request().Header.Set(HeaderCookie, "test2=dummy")
	c.ClearCookie()
	require.Contains(t, string(c.Response().Header.Peek(HeaderSetCookie)), "test1=; expires=")
	require.Contains(t, string(c.Response().Header.Peek(HeaderSetCookie)), "test2=; expires=")
}

// go test -race -run Test_Ctx_Download
// Test_Ctx_Download checks that Download sends the file body and sets an
// attachment Content-Disposition, with and without an explicit filename.
func Test_Ctx_Download(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	require.NoError(t, c.Download("ctx.go", "Awesome File!"))

	f, err := os.Open("./ctx.go")
	require.NoError(t, err)
	defer func() {
		require.NoError(t, f.Close())
	}()

	expect, err := io.ReadAll(f)
	require.NoError(t, err)
	require.Equal(t, expect, c.Response().Body())
	require.Equal(t, `attachment; 
filename="Awesome+File%21"`, string(c.Response().Header.Peek(HeaderContentDisposition)))

	// Without an explicit name, the served file's own name is used.
	require.NoError(t, c.Res().Download("ctx.go"))
	require.Equal(t, `attachment; filename="ctx.go"`, string(c.Response().Header.Peek(HeaderContentDisposition)))

	// A non-ASCII name gets both a plain filename and a UTF-8, percent-encoded
	// filename* parameter.
	require.NoError(t, c.Download("ctx.go", "файл.txt"))
	header := string(c.Response().Header.Peek(HeaderContentDisposition))
	require.Contains(t, header, `filename="файл.txt"`)
	require.Contains(t, header, `filename*=UTF-8''%D1%84%D0%B0%D0%B9%D0%BB.txt`)
}

// go test -race -run Test_Ctx_Download_SanitizesFilenameControls
// Test_Ctx_Download_SanitizesFilenameControls checks the filename written to
// Content-Disposition. It must be reduced to its base name and stripped of
// control characters, and falls back to "download" when nothing usable
// remains (including ".").
func Test_Ctx_Download_SanitizesFilenameControls(t *testing.T) {
	t.Parallel()

	app := New()
	testCases := []struct {
		name     string
		filename string
		expected string
	}{
		{
			name:     "base name only",
			filename: "../docs/archive.tar.gz",
			expected: `attachment; filename="archive.tar.gz"`,
		},
		{
			name:     "controls stripped",
			filename: "down\r\nload\t\x00.txt",
			expected: `attachment; filename="download.txt"`,
		},
		{
			name:     "empty after sanitize",
			filename: "\r\n\t\x00",
			expected: `attachment; filename="download"`,
		},
		{
			name:     "dot fallback",
			filename: ".",
			expected: `attachment; filename="download"`,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			c := app.AcquireCtx(&fasthttp.RequestCtx{})
			require.NoError(t, c.Download("ctx.go", tc.filename))
			header := string(c.Response().Header.Peek(HeaderContentDisposition))
			require.Equal(t, tc.expected, header)
			// Independently of the exact expected value, no CR/LF (header
			// splitting), tab or NUL byte may survive into the header.
			require.NotContains(t, header, "\r")
			require.NotContains(t, header, "\n")
			require.NotContains(t, header, "\t")
			require.NotContains(t, header, "\x00")
		})
	}
}

// go test -race -run Test_Ctx_SendEarlyHints
// Test_Ctx_SendEarlyHints sends an early hint from the handler and then a 400
// final response. It checks the final response's status, Link header values
// and body.
// NOTE(review): the hint value has no "<uri>" target; a Link value normally
// looks like "<https://example.com/app.js>; rel=preload; as=script". This may
// be an artifact of how this copy was extracted. Confirm against upstream.
func Test_Ctx_SendEarlyHints(t *testing.T) {
	t.Parallel()
	app := New()

	hints := []string{"; rel=preload; as=script"}
	app.Get("/earlyhints", func(c Ctx) error {
		err := c.SendEarlyHints(hints)
		require.NoError(t, err, "SendEarlyHints")
		c.Status(StatusBadRequest)
		return c.SendString("fail")
	})

	req := httptest.NewRequest(MethodGet, "/earlyhints", http.NoBody)
resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusBadRequest, resp.StatusCode, "Status code") require.Equal(t, hints, resp.Header["Link"], "Link header") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "fail", string(body)) } // go test -race -run Test_Ctx_SendFile func Test_Ctx_SendFile(t *testing.T) { t.Parallel() app := New() // fetch file content f, err := os.Open("./ctx.go") require.NoError(t, err) defer func() { require.NoError(t, f.Close()) }() expectFileContent, err := io.ReadAll(f) require.NoError(t, err) // fetch file info for the not modified test case fI, err := os.Stat("./ctx.go") require.NoError(t, err) // simple test case c := app.AcquireCtx(&fasthttp.RequestCtx{}) err = c.SendFile("ctx.go") // check expectation require.NoError(t, err) require.Equal(t, expectFileContent, c.Response().Body()) require.Equal(t, StatusOK, c.Response().StatusCode()) app.ReleaseCtx(c) // test with custom error code c = app.AcquireCtx(&fasthttp.RequestCtx{}) err = c.Res().Status(StatusInternalServerError).SendFile("ctx.go") // check expectation require.NoError(t, err) require.Equal(t, expectFileContent, c.Response().Body()) require.Equal(t, StatusInternalServerError, c.Response().StatusCode()) app.ReleaseCtx(c) // test not modified c = app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderIfModifiedSince, fI.ModTime().Format(time.RFC1123)) err = c.SendFile("ctx.go") // check expectation require.NoError(t, err) require.Equal(t, StatusNotModified, c.Response().StatusCode()) require.Equal(t, []byte(nil), c.Response().Body()) app.ReleaseCtx(c) } // go test -race -run Test_Ctx_SendFile_ContentType func Test_Ctx_SendFile_ContentType(t *testing.T) { t.Parallel() app := New() // 1) simple case c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.Res().SendFile("./.github/testdata/fs/img/fiber.png") // check expectation require.NoError(t, err) require.Equal(t, StatusOK, c.Response().StatusCode()) 
require.Equal(t, "image/png", string(c.Response().Header.Peek(HeaderContentType)))
	app.ReleaseCtx(c)

	// 2) set by valid file extension, not file header
	// see: https://github.com/valyala/fasthttp/blob/d795f13985f16622a949ea9fc3459cf54dc78b3e/fs.go#L1638
	c = app.AcquireCtx(&fasthttp.RequestCtx{})
	err = c.SendFile("./.github/testdata/fs/img/fiberpng.jpeg")
	// check expectation
	require.NoError(t, err)
	require.Equal(t, StatusOK, c.Response().StatusCode())
	require.Equal(t, "image/jpeg", string(c.Response().Header.Peek(HeaderContentType)))
	app.ReleaseCtx(c)

	// 3) set by file header if extension is invalid
	c = app.AcquireCtx(&fasthttp.RequestCtx{})
	err = c.SendFile("./.github/testdata/fs/img/fiberpng.notvalidext")
	// check expectation
	require.NoError(t, err)
	require.Equal(t, StatusOK, c.Response().StatusCode())
	require.Equal(t, "image/png", string(c.Response().Header.Peek(HeaderContentType)))
	app.ReleaseCtx(c)

	// 4) set by file header if extension is missing
	c = app.AcquireCtx(&fasthttp.RequestCtx{})
	err = c.SendFile("./.github/testdata/fs/img/fiberpng")
	// check expectation
	require.NoError(t, err)
	require.Equal(t, StatusOK, c.Response().StatusCode())
	require.Equal(t, "image/png", string(c.Response().Header.Peek(HeaderContentType)))
	app.ReleaseCtx(c)
}

// Test_Ctx_SendFile_Download verifies that SendFile with Download: true serves
// the file content unchanged and sets Content-Disposition to "attachment".
func Test_Ctx_SendFile_Download(t *testing.T) {
	t.Parallel()
	app := New()

	// fetch file content
	expectFileContent, err := os.ReadFile("./ctx.go")
	require.NoError(t, err)

	// simple test case
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	err = c.SendFile("ctx.go", SendFile{
		Download: true,
	})
	// check expectation
	require.NoError(t, err)
	require.Equal(t, expectFileContent, c.Response().Body())
	require.Equal(t, "attachment", string(c.Response().Header.Peek(HeaderContentDisposition)))
	require.Equal(t, StatusOK, c.Response().StatusCode())
app.ReleaseCtx(c)
}

// Test_Ctx_SendFile_MaxAge verifies that SendFile with MaxAge serves the file
// unchanged and sets a public Cache-Control header carrying that max-age.
func Test_Ctx_SendFile_MaxAge(t *testing.T) {
	t.Parallel()
	app := New()

	// fetch file content
	expectFileContent, err := os.ReadFile("./ctx.go")
	require.NoError(t, err)

	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	err = c.SendFile("ctx.go", SendFile{
		MaxAge: 100,
	})
	// check expectation
	require.NoError(t, err)
	require.Equal(t, expectFileContent, c.Response().Body())
	require.Equal(t, "public, max-age=100", string(c.RequestCtx().Response.Header.Peek(HeaderCacheControl)), "CacheControl Control")
	require.Equal(t, StatusOK, c.Response().StatusCode())
}

// Test_Static_Compress verifies that SendFile with Compress: true encodes the
// response with each algorithm requested via Accept-Encoding.
func Test_Static_Compress(t *testing.T) {
	t.Parallel()
	app := New()

	// Size of the uncompressed file; a compressed response must not match it.
	info, statErr := os.Stat("./ctx.go")
	require.NoError(t, statErr)
	rawSize := info.Size()

	app.Get("/file", func(c Ctx) error {
		return c.SendFile("ctx.go", SendFile{
			Compress: true,
		})
	})

	// Note: deflate is not supported by fasthttp.FS
	algorithms := []string{"zstd", "gzip", "br"}

	for _, algo := range algorithms {
		t.Run(algo+"_compression", func(t *testing.T) {
			t.Parallel()

			req := httptest.NewRequest(MethodGet, "/file", http.NoBody)
			req.Header.Set("Accept-Encoding", algo)
			resp, err := app.Test(req, TestConfig{
				Timeout:       10 * time.Second,
				FailOnTimeout: true,
			})

			require.NoError(t, err, "app.Test(req)")
			require.Equal(t, 200, resp.StatusCode, "Status code")
			// Assert the negotiated encoding directly; comparing Content-Length
			// against a hard-coded byte count silently passes once ctx.go changes.
			require.Equal(t, algo, resp.Header.Get(HeaderContentEncoding))
			require.NotEqual(t, rawSize, resp.ContentLength)
		})
	}
}

func Test_Ctx_SendFile_Compress_CheckCompressed(t *testing.T) {
	t.Parallel()
	app := New()

	// fetch file content
	f, err := os.Open("./ctx.go")
	require.NoError(t, err)
	t.Cleanup(func() {
		require.NoError(t, f.Close())
	})
	expectedFileContent, err := io.ReadAll(f)
	require.NoError(t, err)

	sendFileBodyReader := func(compression string) ([]byte, error) {
		t.Helper()
		c := app.AcquireCtx(&fasthttp.RequestCtx{})
		defer app.ReleaseCtx(c)
		c.Request().Header.Add(HeaderAcceptEncoding,
compression) err := c.SendFile("./ctx.go", SendFile{ Compress: true, }) return c.Response().Body(), err } t.Run("gzip", func(t *testing.T) { t.Parallel() b, err := sendFileBodyReader("gzip") require.NoError(t, err) body, err := fasthttp.AppendGunzipBytes(nil, b) require.NoError(t, err) require.Equal(t, expectedFileContent, body) }) t.Run("zstd", func(t *testing.T) { t.Parallel() b, err := sendFileBodyReader("zstd") require.NoError(t, err) body, err := fasthttp.AppendUnzstdBytes(nil, b) require.NoError(t, err) require.Equal(t, expectedFileContent, body) }) t.Run("br", func(t *testing.T) { t.Parallel() b, err := sendFileBodyReader("br") require.NoError(t, err) body, err := fasthttp.AppendUnbrotliBytes(nil, b) require.NoError(t, err) require.Equal(t, expectedFileContent, body) }) } //go:embed ctx.go var embedFile embed.FS func Test_Ctx_SendFile_EmbedFS(t *testing.T) { t.Parallel() app := New() f, err := os.Open("./ctx.go") require.NoError(t, err) defer func() { require.NoError(t, f.Close()) }() expectFileContent, err := io.ReadAll(f) require.NoError(t, err) app.Get("/test", func(c Ctx) error { return c.SendFile("ctx.go", SendFile{ FS: embedFile, }) }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, expectFileContent, body) } // go test -race -run Test_Ctx_SendFile_404 func Test_Ctx_SendFile_404(t *testing.T) { t.Parallel() app := New() app.Get("/", func(c Ctx) error { return c.SendFile("ctx12.go") }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, StatusNotFound, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "sendfile: file ctx12.go not found", string(body)) } func Test_Ctx_SendFile_Multiple(t *testing.T) { t.Parallel() app := New() app.Get("/test", func(c Ctx) error { switch 
c.Query("file") { case "1": return c.SendFile("ctx.go") case "2": return c.SendFile("app.go") case "3": return c.SendFile("ctx.go", SendFile{ Download: true, }) case "4": return c.SendFile("app_test.go", SendFile{ FS: os.DirFS("."), }) default: return c.SendStatus(StatusNotFound) } }) app.Get("/test2", func(c Ctx) error { return c.SendFile("ctx.go", SendFile{ Download: true, }) }) testCases := []struct { url string body string contentDisposition string }{ {url: "/test?file=1", body: "type DefaultCtx struct", contentDisposition: ""}, {url: "/test?file=2", body: "type App struct", contentDisposition: ""}, {url: "/test?file=3", body: "type DefaultCtx struct", contentDisposition: "attachment"}, {url: "/test?file=4", body: "Test_App_MethodNotAllowed", contentDisposition: ""}, {url: "/test2", body: "type DefaultCtx struct", contentDisposition: "attachment"}, {url: "/test2", body: "type DefaultCtx struct", contentDisposition: "attachment"}, } for _, tc := range testCases { resp, err := app.Test(httptest.NewRequest(MethodGet, tc.url, http.NoBody)) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) require.Equal(t, tc.contentDisposition, resp.Header.Get(HeaderContentDisposition)) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(body), tc.body) } app.sendfilesMutex.RLock() defer app.sendfilesMutex.RUnlock() require.Len(t, app.sendfiles, 3) } // go test -race -run Test_Ctx_SendFile_Immutable func Test_Ctx_SendFile_Immutable(t *testing.T) { t.Parallel() app := New() var endpointsForTest []string addEndpoint := func(file, endpoint string) { endpointsForTest = append(endpointsForTest, endpoint) app.Get(endpoint, func(c Ctx) error { if err := c.SendFile(file); err != nil { require.NoError(t, err) return err } return c.SendStatus(200) }) } // relative paths addEndpoint("./.github/index.html", "/relativeWithDot") addEndpoint(filepath.FromSlash("./.github/index.html"), "/relativeOSWithDot") addEndpoint(".github/index.html", 
"/relative") addEndpoint(filepath.FromSlash(".github/index.html"), "/relativeOS") // absolute paths if path, err := filepath.Abs(".github/index.html"); err != nil { require.NoError(t, err) } else { addEndpoint(path, "/absolute") addEndpoint(filepath.FromSlash(path), "/absoluteOS") // os related } for _, endpoint := range endpointsForTest { t.Run(endpoint, func(t *testing.T) { t.Parallel() // 1st try resp, err := app.Test(httptest.NewRequest(MethodGet, endpoint, http.NoBody)) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) // 2nd try resp, err = app.Test(httptest.NewRequest(MethodGet, endpoint, http.NoBody)) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) }) } } // go test -race -run Test_Ctx_SendFile_RestoreOriginalURL func Test_Ctx_SendFile_RestoreOriginalURL(t *testing.T) { t.Parallel() app := New() app.Get("/", func(c Ctx) error { originalURL := utils.CopyString(c.OriginalURL()) err := c.SendFile("ctx.go") require.Equal(t, originalURL, c.OriginalURL()) return err }) _, err1 := app.Test(httptest.NewRequest(MethodGet, "/?test=true", http.NoBody)) // second request required to confirm with zero allocation _, err2 := app.Test(httptest.NewRequest(MethodGet, "/?test=true", http.NoBody)) require.NoError(t, err1) require.NoError(t, err2) } func Test_SendFile_withRoutes(t *testing.T) { t.Parallel() app := New() app.Get("/file", func(c Ctx) error { return c.SendFile("ctx.go") }) app.Get("/file/download", func(c Ctx) error { return c.SendFile("ctx.go", SendFile{ Download: true, }) }) app.Get("/file/fs", func(c Ctx) error { return c.SendFile("ctx.go", SendFile{ FS: os.DirFS("."), }) }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/file", http.NoBody)) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(MethodGet, "/file/download", http.NoBody)) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) require.Equal(t, "attachment", 
resp.Header.Get(HeaderContentDisposition)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/file/fs", http.NoBody)) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) } func Test_SendFile_ByteRange(t *testing.T) { if runtime.GOOS == windowsOS { t.Skip("SendFile byte-range tests are flaky on Windows") } content := []byte("0123456789") tmpDir := t.TempDir() fixture := filepath.Join(tmpDir, "fixture.txt") require.NoError(t, os.WriteFile(fixture, content, 0o600)) app := New() app.Get("/range", func(c Ctx) error { return c.SendFile(fixture, SendFile{ByteRange: true}) }) app.Get("/norange", func(c Ctx) error { return c.SendFile(fixture) }) t.Run("satisfiable single range", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/range", http.NoBody) req.Header.Set(HeaderRange, "bytes=0-4") resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusPartialContent, resp.StatusCode) require.Equal(t, "bytes", resp.Header.Get(HeaderAcceptRanges)) require.Equal(t, "bytes 0-4/10", resp.Header.Get(HeaderContentRange)) require.EqualValues(t, len(content[:5]), resp.ContentLength) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, content[:5], body) }) t.Run("single byte range", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/range", http.NoBody) req.Header.Set(HeaderRange, "bytes=4-4") resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusPartialContent, resp.StatusCode) require.Equal(t, "bytes", resp.Header.Get(HeaderAcceptRanges)) require.Equal(t, "bytes 4-4/10", resp.Header.Get(HeaderContentRange)) require.EqualValues(t, 1, resp.ContentLength) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, content[4:5], body) }) t.Run("open ended range", func(t *testing.T) { req := 
httptest.NewRequest(http.MethodGet, "/range", http.NoBody) req.Header.Set(HeaderRange, "bytes=4-") resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusPartialContent, resp.StatusCode) require.Equal(t, "bytes", resp.Header.Get(HeaderAcceptRanges)) require.Equal(t, "bytes 4-9/10", resp.Header.Get(HeaderContentRange)) require.EqualValues(t, len(content[4:]), resp.ContentLength) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, content[4:], body) }) t.Run("range exceeding end", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/range", http.NoBody) req.Header.Set(HeaderRange, "bytes=5-20") resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusPartialContent, resp.StatusCode) require.Equal(t, "bytes", resp.Header.Get(HeaderAcceptRanges)) require.Equal(t, "bytes 5-9/10", resp.Header.Get(HeaderContentRange)) require.EqualValues(t, len(content[5:]), resp.ContentLength) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, content[5:], body) }) t.Run("suffix range", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/range", http.NoBody) req.Header.Set(HeaderRange, "bytes=-3") resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusPartialContent, resp.StatusCode) require.Equal(t, "bytes", resp.Header.Get(HeaderAcceptRanges)) require.Equal(t, "bytes 7-9/10", resp.Header.Get(HeaderContentRange)) require.EqualValues(t, len(content[len(content)-3:]), resp.ContentLength) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, content[len(content)-3:], body) }) t.Run("suffix range exceeding size", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/range", http.NoBody) 
req.Header.Set(HeaderRange, "bytes=-20") resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusPartialContent, resp.StatusCode) require.Equal(t, "bytes", resp.Header.Get(HeaderAcceptRanges)) require.Equal(t, "bytes 0-9/10", resp.Header.Get(HeaderContentRange)) require.EqualValues(t, len(content), resp.ContentLength) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, content, body) }) t.Run("unsatisfiable range", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/range", http.NoBody) req.Header.Set(HeaderRange, "bytes=1000-2000") resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusRequestedRangeNotSatisfiable, resp.StatusCode) require.Equal(t, "bytes */10", resp.Header.Get(HeaderContentRange)) }) t.Run("unsatisfiable reversed range", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/range", http.NoBody) req.Header.Set(HeaderRange, "bytes=6-5") resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusRequestedRangeNotSatisfiable, resp.StatusCode) require.Equal(t, "bytes */10", resp.Header.Get(HeaderContentRange)) }) t.Run("unsatisfiable start past end", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/range", http.NoBody) req.Header.Set(HeaderRange, "bytes=10-") resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusRequestedRangeNotSatisfiable, resp.StatusCode) require.Equal(t, "bytes */10", resp.Header.Get(HeaderContentRange)) }) t.Run("range ignored when byte range disabled", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/norange", http.NoBody) req.Header.Set(HeaderRange, "bytes=0-4") resp, err := 
app.Test(req) require.NoError(t, err, "app.Test(req)") defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusOK, resp.StatusCode) require.Empty(t, resp.Header.Get(HeaderAcceptRanges)) require.Empty(t, resp.Header.Get(HeaderContentRange)) require.EqualValues(t, len(content), resp.ContentLength) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, content, body) }) } func Benchmark_Ctx_SendFile(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) b.ReportAllocs() var err error for b.Loop() { err = c.SendFile("ctx.go") } require.NoError(b, err) require.Contains(b, string(c.Response().Body()), "type DefaultCtx struct") } // go test -run Test_Ctx_JSON func Test_Ctx_JSON(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.Error(t, c.JSON(complex(1, 1))) // Test without ctype err := c.JSON(Map{ // map has no order "Name": "Grame", "Age": 20, }) require.NoError(t, err) require.JSONEq(t, `{"Age":20,"Name":"Grame"}`, string(c.Response().Body())) require.Equal(t, "application/json; charset=utf-8", string(c.Response().Header.Peek("content-type"))) // Test with ctype err = c.JSON(Map{ // map has no order "Name": "Grame", "Age": 20, }, "application/problem+json") require.NoError(t, err) require.JSONEq(t, `{"Age":20,"Name":"Grame"}`, string(c.Response().Body())) require.Equal(t, "application/problem+json", string(c.Response().Header.Peek("content-type"))) testEmpty := func(v any, r string) { err := c.JSON(v) require.NoError(t, err) require.Equal(t, r, string(c.Response().Body())) } testEmpty(nil, "null") testEmpty("", `""`) testEmpty(0, "0") testEmpty([]int{}, "[]") t.Run("custom json encoder", func(t *testing.T) { t.Parallel() app := New(Config{ JSONEncoder: func(_ any) ([]byte, error) { return []byte(`["custom","json"]`), nil }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.JSON(Map{ // map has no order "Name": "Grame", "Age": 20, }) require.NoError(t, err) 
require.Equal(t, `["custom","json"]`, string(c.Response().Body())) require.Equal(t, "application/json; charset=utf-8", string(c.Response().Header.Peek("content-type"))) }) } // go test -run=^$ -bench=Benchmark_Ctx_JSON -benchmem -count=4 func Benchmark_Ctx_JSON(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) type SomeStruct struct { Name string Age uint8 } data := SomeStruct{ Name: "Grame", Age: 20, } var err error b.ReportAllocs() for b.Loop() { err = c.JSON(data) } require.NoError(b, err) require.JSONEq(b, `{"Name":"Grame","Age":20}`, string(c.Response().Body())) } // go test -run Test_Ctx_MsgPack func Test_Ctx_MsgPack(t *testing.T) { t.Parallel() app := New(Config{ MsgPackEncoder: msgpack.Marshal, MsgPackDecoder: msgpack.Unmarshal, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.MsgPack(complex(1, 1)) require.NoError(t, err) require.Equal(t, "\u0600?\xf0\x00\x00\x00\x00\x00\x00?\xf0\x00\x00\x00\x00\x00\x00", string(c.Response().Body())) // Test without ctype err = c.MsgPack(Map{ // map has no order "Name": "Grame", }) require.NoError(t, err) require.Equal(t, "\x81\xa4Name\xa5Grame", string(c.Response().Body())) require.Equal(t, MIMEApplicationMsgPack, string(c.Response().Header.Peek("content-type"))) // Test with ctype err = c.MsgPack(Map{ // map has no order "Name": "Grame", }, "application/problem+msgpack") require.NoError(t, err) require.Equal(t, "\x81\xa4Name\xa5Grame", string(c.Response().Body())) require.Equal(t, "application/problem+msgpack", string(c.Response().Header.Peek("content-type"))) testEmpty := func(v any, r string) { err := c.MsgPack(v) require.NoError(t, err) require.Equal(t, r, string(c.Response().Body())) } testEmpty(nil, "\xc0") testEmpty("", "\xa0") testEmpty(0, "\x00") testEmpty([]int{}, "\x90") t.Run("custom msgpack encoder", func(t *testing.T) { t.Parallel() app := New(Config{ MsgPackEncoder: func(_ any) ([]byte, error) { return []byte(`["custom","msgpack"]`), nil }, }) c := 
app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.MsgPack(Map{ // map has no order "Name": "Grame", "Age": 20, }) require.NoError(t, err) require.Equal(t, `["custom","msgpack"]`, string(c.Response().Body())) require.Equal(t, MIMEApplicationMsgPack, string(c.Response().Header.Peek("content-type"))) }) t.Run("error msgpack", func(t *testing.T) { t.Parallel() app := New(Config{ MsgPackEncoder: func(_ any) ([]byte, error) { return []byte("error"), errors.New("msgpack error") }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.MsgPack(Map{ // map has no order "Name": "Grame", "Age": 20, }) require.Error(t, err) }) } // go test -run=^$ -bench=Benchmark_Ctx_MsgPack -benchmem -count=4 func Benchmark_Ctx_MsgPack(b *testing.B) { app := New( Config{ MsgPackEncoder: msgpack.Marshal, MsgPackDecoder: msgpack.Unmarshal, }, ) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type SomeStruct struct { Name string Age uint8 } data := SomeStruct{ Name: "Grame", Age: 20, } var err error b.ReportAllocs() for b.Loop() { err = c.MsgPack(data) } require.NoError(b, err) require.Equal(b, "\x82\xa4Name\xa5Grame\xa3Age\x14", string(c.Response().Body())) } // go test -run Test_Ctx_CBOR func Test_Ctx_CBOR(t *testing.T) { t.Parallel() app := New(Config{ CBOREncoder: cbor.Marshal, CBORDecoder: cbor.Unmarshal, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.Error(t, c.CBOR(complex(1, 1))) type dummyStruct struct { Name string Age int } // Test without ctype err := c.CBOR(dummyStruct{ // map has no order Name: "Grame", Age: 20, }) require.NoError(t, err) require.Equal(t, `a2644e616d65654772616d656341676514`, hex.EncodeToString(c.Response().Body())) require.Equal(t, "application/cbor", string(c.Response().Header.Peek("content-type"))) // Test with ctype err = c.CBOR(dummyStruct{ // map has no order Name: "Grame", Age: 20, }, "application/problem+cbor") require.NoError(t, err) require.Equal(t, `a2644e616d65654772616d656341676514`, hex.EncodeToString(c.Response().Body())) require.Equal(t, 
"application/problem+cbor", string(c.Response().Header.Peek("content-type"))) testEmpty := func(v any, r string) { cbErr := c.CBOR(v) require.NoError(t, cbErr) require.Equal(t, r, hex.EncodeToString(c.Response().Body())) } testEmpty(nil, "f6") testEmpty("", `60`) testEmpty(0, "00") testEmpty([]int{}, "80") // Test invalid types err = c.CBOR(make(chan int)) require.Error(t, err) err = c.CBOR(func() {}) require.Error(t, err) t.Run("custom cbor encoder", func(t *testing.T) { t.Parallel() app := New(Config{ CBOREncoder: func(_ any) ([]byte, error) { return []byte(hex.EncodeToString([]byte("random"))), nil }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.CBOR(Map{ // map has no order "Name": "Grame", "Age": 20, }) require.NoError(t, err) require.Equal(t, `72616e646f6d`, string(c.Response().Body())) require.Equal(t, "application/cbor", string(c.Response().Header.Peek("content-type"))) }) } // go test -run=^$ -bench=Benchmark_Ctx_CBOR -benchmem -count=4 func Benchmark_Ctx_CBOR(b *testing.B) { app := New(Config{ CBOREncoder: cbor.Marshal, CBORDecoder: cbor.Unmarshal, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type SomeStruct struct { Name string Age uint8 } data := SomeStruct{ Name: "Grame", Age: 20, } var err error b.ReportAllocs() for b.Loop() { err = c.CBOR(data) } require.NoError(b, err) require.Equal(b, `a2644e616d65654772616d656341676514`, hex.EncodeToString(c.Response().Body())) } // go test -run=^$ -bench=Benchmark_Ctx_JSON_Ctype -benchmem -count=4 func Benchmark_Ctx_JSON_Ctype(b *testing.B) { app := New() // TODO: Check extra allocs because of the interface stuff c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed type SomeStruct struct { Name string Age uint8 } data := SomeStruct{ Name: "Grame", Age: 20, } var err error b.ReportAllocs() for b.Loop() { err = c.JSON(data, "application/problem+json") } require.NoError(b, err) require.JSONEq(b, `{"Name":"Grame","Age":20}`, 
string(c.Response().Body())) require.Equal(b, "application/problem+json", string(c.Response().Header.Peek("content-type"))) } // go test -run Test_Ctx_JSONP func Test_Ctx_JSONP(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.Error(t, c.JSONP(complex(1, 1))) err := c.JSONP(Map{ "Name": "Grame", "Age": 20, }) require.NoError(t, err) require.Equal(t, `callback({"Age":20,"Name":"Grame"});`, string(c.Response().Body())) require.Equal(t, "text/javascript; charset=utf-8", string(c.Response().Header.Peek("content-type"))) err = c.Res().JSONP(Map{ "Name": "Grame", "Age": 20, }, "john") require.NoError(t, err) require.Equal(t, `john({"Age":20,"Name":"Grame"});`, string(c.Response().Body())) require.Equal(t, "text/javascript; charset=utf-8", string(c.Response().Header.Peek("content-type"))) t.Run("custom json encoder", func(t *testing.T) { t.Parallel() app := New(Config{ JSONEncoder: func(_ any) ([]byte, error) { return []byte(`["custom","json"]`), nil }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.JSONP(Map{ // map has no order "Name": "Grame", "Age": 20, }) require.NoError(t, err) require.Equal(t, `callback(["custom","json"]);`, string(c.Response().Body())) require.Equal(t, "text/javascript; charset=utf-8", string(c.Response().Header.Peek("content-type"))) }) } // go test -v -run=^$ -bench=Benchmark_Ctx_JSONP -benchmem -count=4 func Benchmark_Ctx_JSONP(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed type SomeStruct struct { Name string Age uint8 } data := SomeStruct{ Name: "Grame", Age: 20, } callback := "emit" var err error b.ReportAllocs() for b.Loop() { err = c.JSONP(data, callback) } require.NoError(b, err) require.Equal(b, `emit({"Name":"Grame","Age":20});`, string(c.Response().Body())) } // go test -run Test_Ctx_XML func Test_Ctx_XML(t *testing.T) { t.Parallel() app := New() c := 
app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed require.Error(t, c.JSON(complex(1, 1))) type xmlResult struct { XMLName xml.Name `xml:"Users"` Names []string `xml:"Names"` Ages []int `xml:"Ages"` } err := c.XML(xmlResult{ Names: []string{"Grame", "John"}, Ages: []int{1, 12, 20}, }) require.NoError(t, err) require.Equal(t, `GrameJohn11220`, string(c.Response().Body())) require.Equal(t, "application/xml; charset=utf-8", string(c.Response().Header.Peek("content-type"))) testEmpty := func(v any, r string) { err := c.XML(v) require.NoError(t, err) require.Equal(t, r, string(c.Response().Body())) } testEmpty(nil, "") testEmpty("", ``) testEmpty(0, "0") testEmpty([]int{}, "") t.Run("custom xml encoder", func(t *testing.T) { t.Parallel() app := New(Config{ XMLEncoder: func(_ any) ([]byte, error) { return []byte(`xml`), nil }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) type xmlResult struct { XMLName xml.Name `xml:"Users"` Names []string `xml:"Names"` Ages []int `xml:"Ages"` } err := c.XML(xmlResult{ Names: []string{"Grame", "John"}, Ages: []int{1, 12, 20}, }) require.NoError(t, err) require.Equal(t, `xml`, string(c.Response().Body())) require.Equal(t, "application/xml; charset=utf-8", string(c.Response().Header.Peek("content-type"))) }) } // go test -run=^$ -bench=Benchmark_Ctx_XML -benchmem -count=4 func Benchmark_Ctx_XML(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed type SomeStruct struct { Name string `xml:"Name"` Age uint8 `xml:"Age"` } data := SomeStruct{ Name: "Grame", Age: 20, } var err error b.ReportAllocs() for b.Loop() { err = c.XML(data) } require.NoError(b, err) require.Equal(b, `Grame20`, string(c.Response().Body())) } // go test -run Test_Ctx_Links func Test_Ctx_Links(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Links() require.Empty(t, 
string(c.Response().Header.Peek(HeaderLink)))

	c.Links(
		"http://api.example.com/users?page=2", "next",
		"http://api.example.com/users?page=5", "last",
	)
	// Each target URI is enclosed in angle brackets, followed by its rel parameter.
	require.Equal(t, `<http://api.example.com/users?page=2>; rel="next",<http://api.example.com/users?page=5>; rel="last"`, string(c.Response().Header.Peek(HeaderLink)))
}

// go test -v -run=^$ -bench=Benchmark_Ctx_Links -benchmem -count=4
func Benchmark_Ctx_Links(b *testing.B) {
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	b.ReportAllocs()
	for b.Loop() {
		c.Links(
			"http://api.example.com/users?page=2", "next",
			"http://api.example.com/users?page=5", "last",
		)
	}

	// Guard against the benchmark silently measuring a broken implementation.
	require.Equal(b, `<http://api.example.com/users?page=2>; rel="next",<http://api.example.com/users?page=5>; rel="last"`, string(c.Response().Header.Peek(HeaderLink)))
}

// go test -run Test_Ctx_Location
func Test_Ctx_Location(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	c.Location("http://example.com")
	require.Equal(t, "http://example.com", string(c.Response().Header.Peek(HeaderLocation)))
}

// go test -run Test_Ctx_Next
func Test_Ctx_Next(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use("/", func(c Ctx) error {
		return c.Next()
	})
	app.Get("/test", func(c Ctx) error {
		c.Set("X-Next-Result", "Works")
		return nil
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "http://example.com/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")
	require.Equal(t, "Works", resp.Header.Get("X-Next-Result"))
}

// go test -run Test_Ctx_Next_Error
func Test_Ctx_Next_Error(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use("/", func(c Ctx) error {
		c.Set("X-Next-Result", "Works")
		return ErrNotFound
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "http://example.com/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusNotFound, resp.StatusCode, "Status code")
	require.Equal(t, "Works", resp.Header.Get("X-Next-Result"))
}

// go test -run Test_Ctx_Render
func Test_Ctx_Render(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	err :=
c.Render("./.github/testdata/index.tmpl", Map{ "Title": "Hello, World!", }) require.NoError(t, err) require.Equal(t, "

Hello, World!

", string(c.Response().Body())) err = c.Render("./.github/testdata/template-non-exists.html", nil) require.Error(t, err) err = c.Res().Render("./.github/testdata/template-invalid.html", nil) require.Error(t, err) } func Test_Ctx_RenderWithoutLocals(t *testing.T) { t.Parallel() app := New(Config{ PassLocalsToViews: false, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Locals("Title", "Hello, World!") err := c.Render("./.github/testdata/index.tmpl", Map{}) require.NoError(t, err) require.Equal(t, "

", string(c.Response().Body())) } func Test_Ctx_RenderWithLocals(t *testing.T) { t.Parallel() app := New(Config{ PassLocalsToViews: true, }) t.Run("EmptyBind", func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Locals("Title", "Hello, World!") err := c.Render("./.github/testdata/index.tmpl", Map{}) require.NoError(t, err) require.Equal(t, "

Hello, World!

", string(c.Response().Body())) }) t.Run("NilBind", func(t *testing.T) { t.Parallel() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Locals("Title", "Hello, World!") err := c.Render("./.github/testdata/index.tmpl", nil) require.NoError(t, err) require.Equal(t, "

Hello, World!

", string(c.Response().Body())) }) } func Test_Ctx_Matched_AfterNext(t *testing.T) { t.Parallel() app := New() app.Use(func(c Ctx) error { require.False(t, c.Matched()) err := c.Next() if c.Path() == "/one" { require.True(t, c.Matched()) } else { require.False(t, c.Matched()) } return err }) app.Get("/one", func(c Ctx) error { return c.SendStatus(StatusOK) }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/one", http.NoBody)) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(MethodGet, "/missing", http.NoBody)) require.NoError(t, err) require.Equal(t, StatusNotFound, resp.StatusCode) } func Test_Ctx_Matched_RouteError(t *testing.T) { t.Parallel() app := New(Config{ ErrorHandler: func(c Ctx, err error) error { require.True(t, c.Matched()) return c.Status(StatusNotFound).SendString(err.Error()) }, }) app.Get("/", func(_ Ctx) error { return ErrNotFound }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, StatusNotFound, resp.StatusCode) } func Test_Ctx_IsMiddleware(t *testing.T) { t.Parallel() app := New() app.Use(func(c Ctx) error { require.True(t, c.IsMiddleware()) return c.Next() }) app.Get("/", func(c Ctx) error { require.False(t, c.IsMiddleware()) return c.SendStatus(StatusOK) }) app.Get("/route", func(c Ctx) error { require.True(t, c.IsMiddleware()) return c.Next() }, func(c Ctx) error { require.False(t, c.IsMiddleware()) return c.SendStatus(StatusOK) }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(MethodGet, "/route", http.NoBody)) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) } func Test_Ctx_HasBody(t *testing.T) { t.Parallel() app := New() acquire := func(t *testing.T) CustomCtx { t.Helper() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) require.NotNil(t, ctx) t.Cleanup(func() { 
app.ReleaseCtx(ctx) }) return ctx } setTransferEncoding := func(t *testing.T, ctx Ctx, value string) { t.Helper() hdr := &ctx.Request().Header hdr.DisableSpecialHeader() hdr.Set(HeaderTransferEncoding, value) hdr.Set(HeaderContentLength, "0") hdr.EnableSpecialHeader() require.Zero(t, hdr.ContentLength()) require.Empty(t, ctx.Request().Body()) } t.Run("body bytes", func(t *testing.T) { t.Parallel() ctx := acquire(t) ctx.Request().SetBody([]byte("test")) require.True(t, ctx.HasBody()) }) t.Run("content length header", func(t *testing.T) { t.Parallel() ctx := acquire(t) ctx.Request().Header.SetContentLength(4) require.True(t, ctx.HasBody()) }) t.Run("chunked sentinel", func(t *testing.T) { t.Parallel() ctx := acquire(t) ctx.Request().Header.SetContentLength(-1) require.Equal(t, -1, ctx.Request().Header.ContentLength()) require.Empty(t, ctx.Request().Body()) require.True(t, ctx.HasBody()) }) t.Run("transfer encoding chunked", func(t *testing.T) { t.Parallel() ctx := acquire(t) setTransferEncoding(t, ctx, "chunked") require.True(t, ctx.HasBody()) }) t.Run("transfer encoding whitespace", func(t *testing.T) { t.Parallel() ctx := acquire(t) setTransferEncoding(t, ctx, " ChUnKeD ") require.True(t, ctx.HasBody()) }) t.Run("transfer encoding parameters", func(t *testing.T) { t.Parallel() ctx := acquire(t) setTransferEncoding(t, ctx, "chunked; q=1") require.True(t, ctx.HasBody()) }) t.Run("transfer encoding multiple values", func(t *testing.T) { t.Parallel() ctx := acquire(t) setTransferEncoding(t, ctx, "gzip, chunked") require.True(t, ctx.HasBody()) }) t.Run("transfer encoding identity", func(t *testing.T) { t.Parallel() ctx := acquire(t) setTransferEncoding(t, ctx, "identity") require.False(t, ctx.HasBody()) }) t.Run("transfer encoding identity then chunked", func(t *testing.T) { t.Parallel() ctx := acquire(t) setTransferEncoding(t, ctx, "identity, chunked") require.True(t, ctx.HasBody()) }) t.Run("no body", func(t *testing.T) { t.Parallel() ctx := acquire(t) 
require.False(t, ctx.HasBody()) }) } func Test_Ctx_IsWebSocket(t *testing.T) { t.Parallel() app := New() ws := app.AcquireCtx(&fasthttp.RequestCtx{}) require.NotNil(t, ws) t.Cleanup(func() { app.ReleaseCtx(ws) }) ws.Request().Header.Set(HeaderConnection, "keep-alive, Upgrade") ws.Request().Header.Set(HeaderUpgrade, "websocket") require.True(t, ws.IsWebSocket()) non := app.AcquireCtx(&fasthttp.RequestCtx{}) require.NotNil(t, non) t.Cleanup(func() { app.ReleaseCtx(non) }) non.Request().Header.Set(HeaderConnection, "not-an-upgrade") non.Request().Header.Set(HeaderUpgrade, "websocket") require.False(t, non.IsWebSocket()) } func Test_Ctx_IsPreflight(t *testing.T) { t.Parallel() app := New() preCtx := &fasthttp.RequestCtx{} preCtx.Request.Header.SetMethod(MethodOptions) preCtx.Request.Header.Set(HeaderAccessControlRequestMethod, MethodGet) preCtx.Request.Header.Set(HeaderOrigin, "https://example.com") pre := app.AcquireCtx(preCtx) require.NotNil(t, pre) t.Cleanup(func() { app.ReleaseCtx(pre) }) require.True(t, pre.IsPreflight()) noOriginCtx := &fasthttp.RequestCtx{} noOriginCtx.Request.Header.SetMethod(MethodOptions) noOriginCtx.Request.Header.Set(HeaderAccessControlRequestMethod, MethodGet) noOrigin := app.AcquireCtx(noOriginCtx) require.NotNil(t, noOrigin) t.Cleanup(func() { app.ReleaseCtx(noOrigin) }) require.False(t, noOrigin.IsPreflight()) optCtx := &fasthttp.RequestCtx{} optCtx.Request.Header.SetMethod(MethodOptions) optCtx.Request.Header.Set(HeaderOrigin, "https://example.com") opt := app.AcquireCtx(optCtx) require.NotNil(t, opt) t.Cleanup(func() { app.ReleaseCtx(opt) }) require.False(t, opt.IsPreflight()) } func Test_Ctx_RenderWithViewBind(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.ViewBind(Map{ "Title": "Hello, World!", }) require.NoError(t, err) err = c.Render("./.github/testdata/index.tmpl", Map{}) require.NoError(t, err) buf := bytebufferpool.Get() buf.WriteString("overwrite") defer bytebufferpool.Put(buf) 
require.NoError(t, err) require.Equal(t, "

Hello, World!

", string(c.Response().Body())) } func Test_Ctx_RenderWithOverwrittenViewBind(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.ViewBind(Map{ "Title": "Hello, World!", }) require.NoError(t, err) err = c.Render("./.github/testdata/index.tmpl", Map{ "Title": "Hello from Fiber!", }) require.NoError(t, err) buf := bytebufferpool.Get() buf.WriteString("overwrite") defer bytebufferpool.Put(buf) require.Equal(t, "

Hello from Fiber!

", string(c.Response().Body())) } func Test_Ctx_RenderWithViewBindLocals(t *testing.T) { t.Parallel() app := New(Config{ PassLocalsToViews: true, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.ViewBind(Map{ "Title": "Hello, World!", }) require.NoError(t, err) c.Locals("Summary", "Test") err = c.Render("./.github/testdata/template.tmpl", Map{}) require.NoError(t, err) require.Equal(t, "

Hello, World! Test

", string(c.Response().Body())) require.Equal(t, "

Hello, World! Test

", string(c.Response().Body())) } func Test_Ctx_RenderWithLocalsAndBinding(t *testing.T) { t.Parallel() engine := &testTemplateEngine{} err := engine.Load() require.NoError(t, err) app := New(Config{ PassLocalsToViews: true, Views: engine, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Locals("Title", "This is a test.") err = c.Render("index.tmpl", Map{ "Title": "Hello, World!", }) require.NoError(t, err) require.Equal(t, "

Hello, World!

", string(c.Response().Body())) } func Benchmark_Ctx_RenderWithLocalsAndViewBind(b *testing.B) { engine := &testTemplateEngine{} err := engine.Load() require.NoError(b, err) app := New(Config{ PassLocalsToViews: true, Views: engine, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) err = c.ViewBind(Map{ "Title": "Hello, World!", }) require.NoError(b, err) c.Locals("Summary", "Test") b.ReportAllocs() for b.Loop() { err = c.Render("template.tmpl", Map{}) } require.NoError(b, err) require.Equal(b, "

Hello, World! Test

", string(c.Response().Body())) } func Benchmark_Ctx_RenderLocals(b *testing.B) { engine := &testTemplateEngine{} err := engine.Load() require.NoError(b, err) app := New(Config{ PassLocalsToViews: true, }) app.config.Views = engine c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Locals("Title", "Hello, World!") b.ReportAllocs() for b.Loop() { err = c.Render("index.tmpl", Map{}) } require.NoError(b, err) require.Equal(b, "

Hello, World!

", string(c.Response().Body())) } func Benchmark_Ctx_RenderViewBind(b *testing.B) { engine := &testTemplateEngine{} err := engine.Load() require.NoError(b, err) app := New() app.config.Views = engine c := app.AcquireCtx(&fasthttp.RequestCtx{}) err = c.ViewBind(Map{ "Title": "Hello, World!", }) require.NoError(b, err) b.ReportAllocs() for b.Loop() { err = c.Render("index.tmpl", Map{}) } require.NoError(b, err) require.Equal(b, "

Hello, World!

", string(c.Response().Body())) } // go test -run Test_Ctx_RestartRouting func Test_Ctx_RestartRouting(t *testing.T) { t.Parallel() app := New() calls := 0 app.Get("/", func(c Ctx) error { calls++ if calls < 3 { return c.RestartRouting() } return nil }) resp, err := app.Test(httptest.NewRequest(MethodGet, "http://example.com/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") require.Equal(t, 3, calls, "Number of calls") } // go test -run Test_Ctx_RestartRoutingWithChangedPath func Test_Ctx_RestartRoutingWithChangedPath(t *testing.T) { t.Parallel() app := New() var executedOldHandler, executedNewHandler bool app.Get("/old", func(c Ctx) error { c.Path("/new") return c.RestartRouting() }) app.Get("/old", func(_ Ctx) error { executedOldHandler = true return nil }) app.Get("/new", func(_ Ctx) error { executedNewHandler = true return nil }) resp, err := app.Test(httptest.NewRequest(MethodGet, "http://example.com/old", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") require.False(t, executedOldHandler, "Executed old handler") require.True(t, executedNewHandler, "Executed new handler") } // go test -run Test_Ctx_RestartRoutingWithChangedPathAnd404 func Test_Ctx_RestartRoutingWithChangedPathAndCatchAll(t *testing.T) { t.Parallel() app := New() app.Get("/new", func(_ Ctx) error { return nil }) app.Use(func(c Ctx) error { c.Path("/new") // c.Next() would fail this test as a 404 is returned from the next handler return c.RestartRouting() }) app.Use(func(_ Ctx) error { return ErrNotFound }) resp, err := app.Test(httptest.NewRequest(MethodGet, "http://example.com/old", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") } type testTemplateEngine struct { templates *template.Template path string } func (t *testTemplateEngine) Render(w io.Writer, name string, bind any, layout ...string) 
error { if len(layout) == 0 { if err := t.templates.ExecuteTemplate(w, name, bind); err != nil { return fmt.Errorf("failed to execute template without layout: %w", err) } return nil } if err := t.templates.ExecuteTemplate(w, name, bind); err != nil { return fmt.Errorf("failed to execute template: %w", err) } if err := t.templates.ExecuteTemplate(w, layout[0], bind); err != nil { return fmt.Errorf("failed to execute template with layout: %w", err) } return nil } func (t *testTemplateEngine) Load() error { if t.path == "" { t.path = "testdata" } t.templates = template.Must(template.ParseGlob("./.github/" + t.path + "/*.tmpl")) return nil } // go test -run Test_Ctx_Render_Engine func Test_Ctx_Render_Engine(t *testing.T) { t.Parallel() engine := &testTemplateEngine{} require.NoError(t, engine.Load()) app := New() app.config.Views = engine c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.Render("index.tmpl", Map{ "Title": "Hello, World!", }) require.NoError(t, err) require.Equal(t, "

Hello, World!

", string(c.Response().Body())) } // go test -run Test_Ctx_Render_Engine_With_View_Layout func Test_Ctx_Render_Engine_With_View_Layout(t *testing.T) { t.Parallel() engine := &testTemplateEngine{} require.NoError(t, engine.Load()) app := New(Config{ViewsLayout: "main.tmpl"}) app.config.Views = engine c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.Render("index.tmpl", Map{ "Title": "Hello, World!", }) require.NoError(t, err) require.Equal(t, "

Hello, World!

I'm main

", string(c.Response().Body())) } // go test -v -run=^$ -bench=Benchmark_Ctx_Render_Engine -benchmem -count=4 func Benchmark_Ctx_Render_Engine(b *testing.B) { engine := &testTemplateEngine{} err := engine.Load() require.NoError(b, err) app := New() app.config.Views = engine c := app.AcquireCtx(&fasthttp.RequestCtx{}) b.ReportAllocs() for b.Loop() { err = c.Render("index.tmpl", Map{ "Title": "Hello, World!", }) } require.NoError(b, err) require.Equal(b, "

Hello, World!

", string(c.Response().Body())) } // go test -v -run=^$ -bench=Benchmark_Ctx_Get_Location_From_Route -benchmem -count=4 func Benchmark_Ctx_Get_Location_From_Route(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed app.Get("/user/:name", func(c Ctx) error { return c.SendString(c.Params("name")) }).Name("User") var err error var location string for b.Loop() { route := app.GetRoute("User") location, err = c.getLocationFromRoute(&route, Map{"name": "fiber"}) } require.Equal(b, "/user/fiber", location) require.NoError(b, err) } // go test -run Test_Ctx_Get_Location_From_Route_name func Test_Ctx_Get_Location_From_Route_name(t *testing.T) { t.Parallel() t.Run("case-insensitive", func(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) app.Get("/user/:name", func(c Ctx) error { return c.SendString(c.Params("name")) }).Name("User") location, err := c.GetRouteURL("User", Map{"name": "fiber"}) require.NoError(t, err) require.Equal(t, "/user/fiber", location) location, err = c.GetRouteURL("User", Map{"Name": "fiber"}) require.NoError(t, err) require.Equal(t, "/user/fiber", location) }) t.Run("case-sensitive", func(t *testing.T) { t.Parallel() app := New(Config{CaseSensitive: true}) c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) app.Get("/user/:name", func(c Ctx) error { return c.SendString(c.Params("name")) }).Name("User") location, err := c.GetRouteURL("User", Map{"name": "fiber"}) require.NoError(t, err) require.Equal(t, "/user/fiber", location) location, err = c.GetRouteURL("User", Map{"Name": "fiber"}) require.NoError(t, err) require.Equal(t, "/user/", location) }) } // go test -run Test_Ctx_Get_Location_From_Route_name_Optional_greedy func Test_Ctx_Get_Location_From_Route_name_Optional_greedy(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) app.Get("/:phone/*/send/*", func(c Ctx) error { return 
c.SendString("Phone: " + c.Params("phone") + "\nFirst Param: " + c.Params("*1") + "\nSecond Param: " + c.Params("*2")) }).Name("SendSms") location, err := c.GetRouteURL("SendSms", Map{ "phone": "23456789", "*1": "sms", "*2": "test-msg", }) require.NoError(t, err) require.Equal(t, "/23456789/sms/send/test-msg", location) } // go test -run Test_Ctx_Get_Location_From_Route_name_Optional_greedy_one_param func Test_Ctx_Get_Location_From_Route_name_Optional_greedy_one_param(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) app.Get("/:phone/*/send", func(c Ctx) error { return c.SendString("Phone: " + c.Params("phone") + "\nFirst Param: " + c.Params("*1")) }).Name("SendSms") location, err := c.GetRouteURL("SendSms", Map{ "phone": "23456789", "*": "sms", }) require.NoError(t, err) require.Equal(t, "/23456789/sms/send", location) } type errorTemplateEngine struct{} func (errorTemplateEngine) Render(_ io.Writer, _ string, _ any, _ ...string) error { return errors.New("errorTemplateEngine") } func (errorTemplateEngine) Load() error { return nil } // go test -run Test_Ctx_Render_Engine_Error func Test_Ctx_Render_Engine_Error(t *testing.T) { t.Parallel() app := New() app.config.Views = errorTemplateEngine{} c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.Render("index.tmpl", nil) require.Error(t, err) } // go test -run Test_Ctx_Render_Go_Template func Test_Ctx_Render_Go_Template(t *testing.T) { t.Parallel() file, err := os.CreateTemp(os.TempDir(), "fiber") require.NoError(t, err) defer func() { removeErr := os.Remove(file.Name()) require.NoError(t, removeErr) }() _, err = file.WriteString("template") require.NoError(t, err) err = file.Close() require.NoError(t, err) app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) err = c.Render(file.Name(), nil) require.NoError(t, err) require.Equal(t, "template", string(c.Response().Body())) } // go test -run Test_Ctx_Send func Test_Ctx_Send(t *testing.T) { t.Parallel() app := New() c := 
app.AcquireCtx(&fasthttp.RequestCtx{}) require.NoError(t, c.Send([]byte("Hello, World"))) require.NoError(t, c.Send([]byte("Don't crash please"))) require.NoError(t, c.Send([]byte("1337"))) require.Equal(t, "1337", string(c.Response().Body())) } // go test -v -run=^$ -bench=Benchmark_Ctx_Send -benchmem -count=4 func Benchmark_Ctx_Send(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) byt := []byte("Hello, World!") b.ReportAllocs() var err error for b.Loop() { err = c.Send(byt) } require.NoError(b, err) require.Equal(b, "Hello, World!", string(c.Response().Body())) } // go test -run Test_Ctx_SendStatus func Test_Ctx_SendStatus(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.SendStatus(415) require.NoError(t, err) require.Equal(t, 415, c.Response().StatusCode()) require.Equal(t, "Unsupported Media Type", string(c.Response().Body())) } func Test_Ctx_SendStatusNoBodyResponses(t *testing.T) { t.Parallel() testCases := []struct { name string status int }{ { name: "Informational", status: StatusContinue, }, { name: "Processing", status: StatusProcessing, }, { name: "SwitchingProtocols", status: StatusSwitchingProtocols, }, { name: "EarlyHints", status: StatusEarlyHints, }, { name: "NoContent", status: StatusNoContent, }, { name: "ResetContent", status: StatusResetContent, }, { name: "NotModified", status: StatusNotModified, }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Response().SetBodyString("preset body") err := c.SendStatus(testCase.status) require.NoError(t, err) require.Empty(t, c.Response().Body()) require.Equal(t, 0, c.Response().Header.ContentLength()) }) } } // go test -run Test_Ctx_SendString func Test_Ctx_SendString(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.SendString("Don't crash please") require.NoError(t, err) require.Equal(t, 
"Don't crash please", string(c.Response().Body())) } // go test -run Test_Ctx_SendStream func Test_Ctx_SendStream(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.SendStream(bytes.NewReader([]byte("Don't crash please"))) require.NoError(t, err) require.Equal(t, "Don't crash please", string(c.Response().Body())) err = c.SendStream(bytes.NewReader([]byte("Don't crash please")), len([]byte("Don't crash please"))) require.NoError(t, err) require.Equal(t, "Don't crash please", string(c.Response().Body())) err = c.SendStream(bufio.NewReader(bytes.NewReader([]byte("Hello bufio")))) require.NoError(t, err) require.Equal(t, "Hello bufio", string(c.Response().Body())) } // go test -run Test_Ctx_SendStreamWriter func Test_Ctx_SendStreamWriter(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) err := c.SendStreamWriter(func(w *bufio.Writer) { w.WriteString("Don't crash please") //nolint:errcheck // It is fine to ignore the error }) require.NoError(t, err) require.Equal(t, "Don't crash please", string(c.Response().Body())) err = c.SendStreamWriter(func(w *bufio.Writer) { for lineNum := 1; lineNum <= 5; lineNum++ { fmt.Fprintf(w, "Line %d\n", lineNum) if flushErr := w.Flush(); flushErr != nil { t.Errorf("unexpected error: %s", flushErr) return } } }) require.NoError(t, err) require.Equal(t, "Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n", string(c.Response().Body())) err = c.SendStreamWriter(func(_ *bufio.Writer) {}) require.NoError(t, err) require.Empty(t, c.Response().Body()) } // go test -run Test_Ctx_SendStreamWriter_Interrupted func Test_Ctx_SendStreamWriter_Interrupted(t *testing.T) { t.Parallel() app := New() var flushed atomic.Int32 var flushErrLine atomic.Int32 app.Get("/", func(c Ctx) error { return c.SendStreamWriter(func(w *bufio.Writer) { for lineNum := 1; lineNum <= 5; lineNum++ { fmt.Fprintf(w, "Line %d\n", lineNum) if err := w.Flush(); err != nil { flushErrLine.Store(int32(lineNum)) 
//nolint:gosec // G115 - lineNum is 1-5, fits int32 return } if lineNum <= 3 { flushed.Add(1) } if lineNum == 3 { time.Sleep(500 * time.Millisecond) } } }) }) req := httptest.NewRequest(MethodGet, "/", http.NoBody) testConfig := TestConfig{ // allow enough time for three lines to flush before // the test connection is closed but stop before the // fourth line is sent Timeout: 200 * time.Millisecond, FailOnTimeout: true, // Changed to true to test interrupted behavior } resp, err := app.Test(req, testConfig) // With FailOnTimeout: true, we should get a timeout error require.ErrorIs(t, err, os.ErrDeadlineExceeded) require.Nil(t, resp) } // go test -run Test_Ctx_Set func Test_Ctx_Set(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Set("X-1", "1") c.Set("X-2", "2") c.Set("X-3", "3") c.Set("X-3", "1337") require.Equal(t, "1", string(c.Response().Header.Peek("x-1"))) require.Equal(t, "2", string(c.Response().Header.Peek("x-2"))) require.Equal(t, "1337", string(c.Response().Header.Peek("x-3"))) } // go test -run Test_Ctx_Set_Splitter func Test_Ctx_Set_Splitter(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Set("Location", "foo\r\nSet-Cookie:%20SESSIONID=MaliciousValue\r\n") h := string(c.Response().Header.Peek("Location")) require.NotContains(t, h, "\r\n") c.Set("Location", "foo\nSet-Cookie:%20SESSIONID=MaliciousValue\n") h = string(c.Response().Header.Peek("Location")) require.NotContains(t, h, "\n") } // go test -v -run=^$ -bench=Benchmark_Ctx_Set -benchmem -count=4 func Benchmark_Ctx_Set(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) val := "1431-15132-3423" b.ReportAllocs() for b.Loop() { c.Set(HeaderXRequestID, val) } } // go test -run Test_Ctx_Status func Test_Ctx_Status(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Status(400) require.Equal(t, 400, c.Response().StatusCode()) err := c.Status(415).Send([]byte("Hello, 
World")) require.NoError(t, err) require.Equal(t, 415, c.Response().StatusCode()) require.Equal(t, "Hello, World", string(c.Response().Body())) } // go test -run Test_Ctx_Type func Test_Ctx_Type(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Type(".json") require.Equal(t, "application/json; charset=utf-8", string(c.Response().Header.Peek("Content-Type"))) c.Type("json", "utf-8") require.Equal(t, "application/json; charset=utf-8", string(c.Response().Header.Peek("Content-Type"))) c.Type(".html") require.Equal(t, "text/html; charset=utf-8", string(c.Response().Header.Peek("Content-Type"))) c.Type("html", "utf-8") require.Equal(t, "text/html; charset=utf-8", string(c.Response().Header.Peek("Content-Type"))) // Test other text types get UTF-8 by default c.Type("txt") require.Equal(t, "text/plain; charset=utf-8", string(c.Response().Header.Peek("Content-Type"))) c.Type("css") require.Equal(t, "text/css; charset=utf-8", string(c.Response().Header.Peek("Content-Type"))) c.Type("js") require.Equal(t, "text/javascript; charset=utf-8", string(c.Response().Header.Peek("Content-Type"))) c.Type("xml") require.Equal(t, "application/xml; charset=utf-8", string(c.Response().Header.Peek("Content-Type"))) // Test binary types don't get charset c.Type("png") require.Equal(t, "image/png", string(c.Response().Header.Peek("Content-Type"))) c.Type("pdf") require.Equal(t, "application/pdf", string(c.Response().Header.Peek("Content-Type"))) // Test custom charset override c.Type("html", "iso-8859-1") require.Equal(t, "text/html; charset=iso-8859-1", string(c.Response().Header.Peek("Content-Type"))) } // go test -run Test_shouldIncludeCharset func Test_shouldIncludeCharset(t *testing.T) { t.Parallel() // Test text/* types - should include charset require.True(t, shouldIncludeCharset("text/html")) require.True(t, shouldIncludeCharset("text/plain")) require.True(t, shouldIncludeCharset("text/css")) require.True(t, 
shouldIncludeCharset("text/javascript")) require.True(t, shouldIncludeCharset("text/xml")) // Test explicit application types - should include charset require.True(t, shouldIncludeCharset("application/json")) require.True(t, shouldIncludeCharset("application/javascript")) require.True(t, shouldIncludeCharset("application/xml")) // Test +json suffixes - should include charset require.True(t, shouldIncludeCharset("application/problem+json")) require.True(t, shouldIncludeCharset("application/vnd.api+json")) require.True(t, shouldIncludeCharset("application/hal+json")) require.True(t, shouldIncludeCharset("application/merge-patch+json")) // Test +xml suffixes - should include charset require.True(t, shouldIncludeCharset("application/soap+xml")) require.True(t, shouldIncludeCharset("application/xhtml+xml")) require.True(t, shouldIncludeCharset("application/atom+xml")) require.True(t, shouldIncludeCharset("application/rss+xml")) // Test binary types - should NOT include charset require.False(t, shouldIncludeCharset("image/png")) require.False(t, shouldIncludeCharset("image/jpeg")) require.False(t, shouldIncludeCharset("application/pdf")) require.False(t, shouldIncludeCharset("application/octet-stream")) require.False(t, shouldIncludeCharset("video/mp4")) require.False(t, shouldIncludeCharset("audio/mpeg")) // Test other application types - should NOT include charset require.False(t, shouldIncludeCharset("application/cbor")) require.False(t, shouldIncludeCharset("application/x-www-form-urlencoded")) require.False(t, shouldIncludeCharset("application/vnd.msgpack")) } // go test -v -run=^$ -bench=Benchmark_Ctx_Type -benchmem -count=4 func Benchmark_Ctx_Type(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) b.ReportAllocs() for b.Loop() { c.Type(".json") c.Type("json") } } // go test -v -run=^$ -bench=Benchmark_Ctx_Type_Charset -benchmem -count=4 func Benchmark_Ctx_Type_Charset(b *testing.B) { app := New() c := 
app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed b.ReportAllocs() for b.Loop() { c.Type(".json", "utf-8") c.Type("json", "utf-8") } } // go test -run Test_Ctx_Vary func Test_Ctx_Vary(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Vary("Origin") c.Vary("User-Agent") c.Vary("Accept-Encoding", "Accept") require.Equal(t, "Origin, User-Agent, Accept-Encoding, Accept", string(c.Response().Header.Peek("Vary"))) } // go test -v -run=^$ -bench=Benchmark_Ctx_Vary -benchmem -count=4 func Benchmark_Ctx_Vary(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed b.ReportAllocs() for b.Loop() { c.Vary("Origin", "User-Agent") } } // go test -run Test_Ctx_Write func Test_Ctx_Write(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) _, err := c.WriteString("Hello, ") require.NoError(t, err) _, err = c.WriteString("World!") require.NoError(t, err) require.Equal(t, "Hello, World!", string(c.Response().Body())) } // go test -v -run=^$ -bench=Benchmark_Ctx_Write -benchmem -count=4 func Benchmark_Ctx_Write(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) byt := []byte("Hello, World!") b.ReportAllocs() var err error for b.Loop() { _, err = c.Write(byt) } require.NoError(b, err) } // go test -run Test_Ctx_Writef func Test_Ctx_Writef(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) world := "World!" _, err := c.Writef("Hello, %s", world) require.NoError(t, err) require.Equal(t, "Hello, World!", string(c.Response().Body())) } // go test -v -run=^$ -bench=Benchmark_Ctx_Writef -benchmem -count=4 func Benchmark_Ctx_Writef(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed world := "World!" 
b.ReportAllocs() var err error for b.Loop() { _, err = c.Writef("Hello, %s", world) } require.NoError(b, err) } // go test -run Test_Ctx_WriteString func Test_Ctx_WriteString(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) _, err := c.WriteString("Hello, ") require.NoError(t, err) _, err = c.WriteString("World!") require.NoError(t, err) require.Equal(t, "Hello, World!", string(c.Response().Body())) } // go test -run Test_Ctx_XHR func Test_Ctx_XHR(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderXRequestedWith, "XMLHttpRequest") require.True(t, c.XHR()) } // go test -run=^$ -bench=Benchmark_Ctx_XHR -benchmem -count=4 func Benchmark_Ctx_XHR(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderXRequestedWith, "XMLHttpRequest") var equal bool b.ReportAllocs() for b.Loop() { equal = c.XHR() } require.True(b, equal) } // go test -v -run=^$ -bench=Benchmark_Ctx_SendString_B -benchmem -count=4 func Benchmark_Ctx_SendString_B(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) body := "Hello, world!" 
b.ReportAllocs() var err error for b.Loop() { err = c.SendString(body) } require.NoError(b, err) require.Equal(b, []byte("Hello, world!"), c.Response().Body()) } // go test -run Test_Ctx_Queries -v func Test_Ctx_Queries(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetBody([]byte(``)) c.Request().Header.SetContentType("") c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball,football&favouriteDrinks=milo,coke,pepsi&alloc=&no=1&field1=value1&field1=value2&field2=value3&list_a=1&list_a=2&list_a=3&list_b[]=1&list_b[]=2&list_b[]=3&list_c=1,2,3") queries := c.Queries() require.Equal(t, "1", queries["id"]) require.Equal(t, "tom", queries["name"]) require.Equal(t, "basketball,football", queries["hobby"]) require.Equal(t, "milo,coke,pepsi", queries["favouriteDrinks"]) require.Empty(t, queries["alloc"]) require.Equal(t, "1", queries["no"]) require.Equal(t, "value2", queries["field1"]) require.Equal(t, "value3", queries["field2"]) require.Equal(t, "3", queries["list_a"]) require.Equal(t, "3", queries["list_b[]"]) require.Equal(t, "1,2,3", queries["list_c"]) c.Request().URI().SetQueryString("filters.author.name=John&filters.category.name=Technology&filters[customer][name]=Alice&filters[status]=pending") queries = c.Queries() require.Equal(t, "John", queries["filters.author.name"]) require.Equal(t, "Technology", queries["filters.category.name"]) require.Equal(t, "Alice", queries["filters[customer][name]"]) require.Equal(t, "pending", queries["filters[status]"]) c.Request().URI().SetQueryString("tags=apple,orange,banana&filters[tags]=apple,orange,banana&filters[category][name]=fruits&filters.tags=apple,orange,banana&filters.category.name=fruits") queries = c.Req().Queries() require.Equal(t, "apple,orange,banana", queries["tags"]) require.Equal(t, "apple,orange,banana", queries["filters[tags]"]) require.Equal(t, "fruits", queries["filters[category][name]"]) require.Equal(t, "apple,orange,banana", 
queries["filters.tags"]) require.Equal(t, "fruits", queries["filters.category.name"]) c.Request().URI().SetQueryString("filters[tags][0]=apple&filters[tags][1]=orange&filters[tags][2]=banana&filters[category][name]=fruits") queries = c.Queries() require.Equal(t, "apple", queries["filters[tags][0]"]) require.Equal(t, "orange", queries["filters[tags][1]"]) require.Equal(t, "banana", queries["filters[tags][2]"]) require.Equal(t, "fruits", queries["filters[category][name]"]) } // go test -v -run=^$ -bench=Benchmark_Ctx_Queries -benchmem -count=4 func Benchmark_Ctx_Queries(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) b.ReportAllocs() c.Request().URI().SetQueryString("id=1&name=tom&hobby=basketball,football&favouriteDrinks=milo,coke,pepsi&alloc=&no=1") var queries map[string]string for b.Loop() { queries = c.Queries() } require.Equal(b, "1", queries["id"]) require.Equal(b, "tom", queries["name"]) require.Equal(b, "basketball,football", queries["hobby"]) require.Equal(b, "milo,coke,pepsi", queries["favouriteDrinks"]) require.Empty(b, queries["alloc"]) require.Equal(b, "1", queries["no"]) } // go test -run Test_Ctx_BodyStreamWriter func Test_Ctx_BodyStreamWriter(t *testing.T) { t.Parallel() ctx := &fasthttp.RequestCtx{} ctx.SetBodyStreamWriter(func(w *bufio.Writer) { fmt.Fprintf(w, "body writer line 1\n") if err := w.Flush(); err != nil { t.Errorf("unexpected error: %s", err) } fmt.Fprintf(w, "body writer line 2\n") }) require.True(t, ctx.IsBodyStream()) s := ctx.Response.String() br := bufio.NewReader(bytes.NewBufferString(s)) var resp fasthttp.Response require.NoError(t, resp.Read(br)) body := string(resp.Body()) expectedBody := "body writer line 1\nbody writer line 2\n" require.Equal(t, expectedBody, body) } // go test -v -run=^$ -bench=Benchmark_Ctx_BodyStreamWriter -benchmem -count=4 func Benchmark_Ctx_BodyStreamWriter(b *testing.B) { ctx := &fasthttp.RequestCtx{} user := []byte(`{"name":"john"}`) b.ReportAllocs() var err error for b.Loop() 
{ ctx.ResetBody() ctx.SetBodyStreamWriter(func(w *bufio.Writer) { for range 10 { _, err = w.Write(user) if flushErr := w.Flush(); flushErr != nil { return } } }) } require.NoError(b, err) } func Test_Ctx_String(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) require.Equal(t, "#0000000000000000 - 0.0.0.0:0 <-> 0.0.0.0:0 - GET http:///", c.String()) } // go test -v -run=^$ -bench=Benchmark_Ctx_String -benchmem -count=4 func Benchmark_Ctx_String(b *testing.B) { var str string app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) b.ReportAllocs() for b.Loop() { str = ctx.String() } require.Equal(b, "#0000000000000000 - 0.0.0.0:0 <-> 0.0.0.0:0 - GET http:///", str) } // go test -run Test_Ctx_IsFromLocal_X_Forwarded func Test_Ctx_IsFromLocal_X_Forwarded(t *testing.T) { t.Parallel() // Test unset X-Forwarded-For header. { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) // fasthttp returns "0.0.0.0" as IP as there is no remote address. require.Equal(t, "0.0.0.0", c.IP()) require.False(t, c.IsFromLocal()) } // Test when setting X-Forwarded-For header to localhost "127.0.0.1" { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1") defer app.ReleaseCtx(c) require.False(t, c.IsFromLocal()) } // Test when setting X-Forwarded-For header to localhost "::1" { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderXForwardedFor, "::1") defer app.ReleaseCtx(c) require.False(t, c.IsFromLocal()) } // Test when setting X-Forwarded-For to full localhost IPv6 address "0:0:0:0:0:0:0:1" { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderXForwardedFor, "0:0:0:0:0:0:0:1") defer app.ReleaseCtx(c) require.False(t, c.IsFromLocal()) } // Test for a random IP address. 
{ app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(HeaderXForwardedFor, "93.46.8.90") require.False(t, c.Req().IsFromLocal()) } } // go test -run Test_Ctx_IsFromLocal_RemoteAddr func Test_Ctx_IsFromLocal_RemoteAddr(t *testing.T) { t.Parallel() localIPv4 := net.Addr(&net.TCPAddr{IP: net.ParseIP("127.0.0.1")}) localIPv6 := net.Addr(&net.TCPAddr{IP: net.ParseIP("::1")}) localIPv6long := net.Addr(&net.TCPAddr{IP: net.ParseIP("0:0:0:0:0:0:0:1")}) zeroIPv4 := net.Addr(&net.TCPAddr{IP: net.IPv4zero}) someIPv4 := net.Addr(&net.TCPAddr{IP: net.ParseIP("93.46.8.90")}) someIPv6 := net.Addr(&net.TCPAddr{IP: net.ParseIP("2001:0db8:85a3:0000:0000:8a2e:0370:7334")}) // Test for the case fasthttp remoteAddr is set to "127.0.0.1". { app := New() fastCtx := &fasthttp.RequestCtx{} fastCtx.SetRemoteAddr(localIPv4) c := app.AcquireCtx(fastCtx) defer app.ReleaseCtx(c) require.Equal(t, "127.0.0.1", c.IP()) require.True(t, c.IsFromLocal()) } // Test for the case fasthttp remoteAddr is set to "::1". { app := New() fastCtx := &fasthttp.RequestCtx{} fastCtx.SetRemoteAddr(localIPv6) c := app.AcquireCtx(fastCtx) defer app.ReleaseCtx(c) require.Equal(t, "::1", c.Req().IP()) require.True(t, c.Req().IsFromLocal()) } // Test for the case fasthttp remoteAddr is set to "0:0:0:0:0:0:0:1". { app := New() fastCtx := &fasthttp.RequestCtx{} fastCtx.SetRemoteAddr(localIPv6long) c := app.AcquireCtx(fastCtx) defer app.ReleaseCtx(c) // fasthttp should return "::1" for "0:0:0:0:0:0:0:1". // otherwise IsFromLocal() will break. require.Equal(t, "::1", c.IP()) require.True(t, c.IsFromLocal()) } // Test for the case fasthttp remoteAddr is set to "0.0.0.0". { app := New() fastCtx := &fasthttp.RequestCtx{} fastCtx.SetRemoteAddr(zeroIPv4) c := app.AcquireCtx(fastCtx) defer app.ReleaseCtx(c) require.Equal(t, "0.0.0.0", c.IP()) require.False(t, c.IsFromLocal()) } // Test for the case fasthttp remoteAddr is set to "93.46.8.90". 
{ app := New() fastCtx := &fasthttp.RequestCtx{} fastCtx.SetRemoteAddr(someIPv4) c := app.AcquireCtx(fastCtx) defer app.ReleaseCtx(c) require.Equal(t, "93.46.8.90", c.IP()) require.False(t, c.IsFromLocal()) } // Test for the case fasthttp remoteAddr is set to "2001:0db8:85a3:0000:0000:8a2e:0370:7334". { app := New() fastCtx := &fasthttp.RequestCtx{} fastCtx.SetRemoteAddr(someIPv6) c := app.AcquireCtx(fastCtx) defer app.ReleaseCtx(c) require.Equal(t, "2001:db8:85a3::8a2e:370:7334", c.IP()) require.False(t, c.IsFromLocal()) } // Test for the case fasthttp remoteAddr is set to a Unix socket. // Unix sockets are inherently local - only processes on the same host can connect. { app := New() fastCtx := &fasthttp.RequestCtx{} unixAddr := &net.UnixAddr{Name: "/tmp/fiber.sock", Net: "unix"} fastCtx.SetRemoteAddr(unixAddr) c := app.AcquireCtx(fastCtx) defer app.ReleaseCtx(c) require.True(t, c.IsFromLocal()) } } // go test -run Test_Ctx_extractIPsFromHeader -v func Test_Ctx_extractIPsFromHeader(t *testing.T) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set("x-forwarded-for", "1.1.1.1,8.8.8.8 , /n, \n,1.1, a.c, 6.,6., , a,,42.118.81.169,10.0.137.108") ips := c.IPs() res := ips[len(ips)-2] require.Equal(t, "42.118.81.169", res) } // go test -run Test_Ctx_extractIPsFromHeader -v func Test_Ctx_extractIPsFromHeader_EnableValidateIp(t *testing.T) { app := New() app.config.EnableIPValidation = true c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set("x-forwarded-for", "1.1.1.1,8.8.8.8 , /n, \n,1.1, a.c, 6.,6., , a,,42.118.81.169,10.0.137.108") ips := c.IPs() res := ips[len(ips)-2] require.Equal(t, "42.118.81.169", res) } // go test -run Test_Ctx_GetRespHeaders func Test_Ctx_GetRespHeaders(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Set("test", "Hello, World 👋!") c.Set("foo", "bar") c.Response().Header.Set("multi", "one") c.Response().Header.Add("multi", "two") 
c.Response().Header.Set(HeaderContentType, "application/json") require.Equal(t, map[string][]string{ "Content-Type": {"application/json"}, "Foo": {"bar"}, "Multi": {"one", "two"}, "Test": {"Hello, World 👋!"}, }, c.GetRespHeaders()) require.Equal(t, map[string][]string{ "Content-Type": {"application/json"}, "Foo": {"bar"}, "Multi": {"one", "two"}, "Test": {"Hello, World 👋!"}, }, c.Res().GetHeaders()) } func Benchmark_Ctx_GetRespHeaders(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Response().Header.Set("test", "Hello, World 👋!") c.Response().Header.Set("foo", "bar") c.Response().Header.Set(HeaderContentType, "application/json") b.ReportAllocs() var headers map[string][]string for b.Loop() { headers = c.GetRespHeaders() } require.Equal(b, map[string][]string{ "Content-Type": {"application/json"}, "Foo": {"bar"}, "Test": {"Hello, World 👋!"}, }, headers) } // go test -run Test_Ctx_GetReqHeaders func Test_Ctx_GetReqHeaders(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set("test", "Hello, World 👋!") c.Request().Header.Set("foo", "bar") c.Request().Header.Set("multi", "one") c.Request().Header.Add("multi", "two") c.Request().Header.Set(HeaderContentType, "application/json") require.Equal(t, map[string][]string{ "Content-Type": {"application/json"}, "Foo": {"bar"}, "Test": {"Hello, World 👋!"}, "Multi": {"one", "two"}, }, c.GetReqHeaders()) require.Equal(t, map[string][]string{ "Content-Type": {"application/json"}, "Foo": {"bar"}, "Test": {"Hello, World 👋!"}, "Multi": {"one", "two"}, }, c.GetHeaders()) } func Test_Ctx_Set_SanitizeHeaderValue(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Set("X-Test", "foo\r\nbar: bad") headerVal := string(c.Response().Header.Peek("X-Test")) require.Equal(t, "foo bar: bad", headerVal) } func Benchmark_Ctx_GetReqHeaders(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) 
c.Request().Header.Set("test", "Hello, World 👋!") c.Request().Header.Set("foo", "bar") c.Request().Header.Set(HeaderContentType, "application/json") b.ReportAllocs() var headers map[string][]string for b.Loop() { headers = c.GetReqHeaders() } require.Equal(b, map[string][]string{ "Content-Type": {"application/json"}, "Foo": {"bar"}, "Test": {"Hello, World 👋!"}, }, headers) } // go test -run Test_Ctx_Drop -v func Test_Ctx_Drop(t *testing.T) { t.Parallel() app := New() // Handler that calls Drop app.Get("/block-me", func(c Ctx) error { return c.Drop() }) // Additional handler that just calls return app.Get("/no-response", func(_ Ctx) error { return nil }) // Test the Drop method resp, err := app.Test(httptest.NewRequest(MethodGet, "/block-me", http.NoBody)) require.ErrorIs(t, err, ErrTestGotEmptyResponse) require.Nil(t, resp) // Test the no-response handler resp, err = app.Test(httptest.NewRequest(MethodGet, "/no-response", http.NoBody)) require.NoError(t, err) require.NotNil(t, resp) require.Equal(t, StatusOK, resp.StatusCode) require.Equal(t, "0", resp.Header.Get("Content-Length")) } // go test -run Test_Ctx_DropWithMiddleware -v func Test_Ctx_DropWithMiddleware(t *testing.T) { t.Parallel() app := New() // Middleware that calls Drop app.Use(func(c Ctx) error { err := c.Next() c.Set("X-Test", "test") return err }) // Handler that calls Drop app.Get("/block-me", func(c Ctx) error { return c.Drop() }) // Test the Drop method resp, err := app.Test(httptest.NewRequest(MethodGet, "/block-me", http.NoBody)) require.ErrorIs(t, err, ErrTestGotEmptyResponse) require.Nil(t, resp) } // go test -run Test_Ctx_End func Test_Ctx_End(t *testing.T) { app := New() app.Get("/", func(c Ctx) error { c.SendString("Hello, World!") //nolint:errcheck // unnecessary to check error return c.End() }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err) require.NotNil(t, resp) require.Equal(t, StatusOK, resp.StatusCode) body, err := 
io.ReadAll(resp.Body) require.NoError(t, err, "io.ReadAll(resp.Body)") require.Equal(t, "Hello, World!", string(body)) } // go test -run Test_Ctx_End_after_timeout func Test_Ctx_End_after_timeout(t *testing.T) { app := New() // Early flushing handler app.Get("/", func(c Ctx) error { time.Sleep(2 * time.Second) return c.End() }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.ErrorIs(t, err, os.ErrDeadlineExceeded) require.Nil(t, resp) } // go test -run Test_Ctx_End_with_drop_middleware func Test_Ctx_End_with_drop_middleware(t *testing.T) { app := New() // Middleware that will drop connections // that persist after c.Next() app.Use(func(c Ctx) error { c.Next() //nolint:errcheck // unnecessary to check error return c.Drop() }) // Early flushing handler app.Get("/", func(c Ctx) error { c.SendStatus(StatusOK) //nolint:errcheck // unnecessary to check error return c.End() }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.NoError(t, err) require.NotNil(t, resp) require.Equal(t, StatusOK, resp.StatusCode) } // go test -run Test_Ctx_End_after_drop func Test_Ctx_End_after_drop(t *testing.T) { app := New() // Middleware that ends the request // after c.Next() app.Use(func(c Ctx) error { c.Next() //nolint:errcheck // unnecessary to check error return c.End() }) // Early flushing handler app.Get("/", func(c Ctx) error { return c.Drop() }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody)) require.ErrorIs(t, err, ErrTestGotEmptyResponse) require.Nil(t, resp) } // go test -run Test_Ctx_OverrideParam func Test_Ctx_OverrideParam(t *testing.T) { t.Parallel() t.Run("route_params", func(t *testing.T) { // a basic request to check if OverrideParam functions correctly on different scenarios // - Does it change an existing param (it should) // - Does it ignore a non-existing param (it should) t.Parallel() app := New() app.Get("/user/:name/:id", func(c Ctx) error { c.OverrideParam("name", "overridden") 
c.OverrideParam("nonexistent", "ignored") require.Equal(t, "overridden", c.Params("name")) require.Equal(t, "123", c.Params("id")) require.Empty(t, c.Params("nonexistent")) require.Equal(t, []string{"name", "id"}, c.Route().Params) return c.JSON(map[string]any{ "name": c.Params("name"), "id": c.Params("id"), "all": c.Route().Params, }) }) req, err := http.NewRequest(http.MethodGet, "/user/original/123", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusOK, resp.StatusCode) }) t.Run("plus_wildcard_params", func(t *testing.T) { t.Parallel() app := New() app.Get("/files+/+", func(c Ctx) error { c.OverrideParam("+", "changed") c.OverrideParam("+2", "changed2") require.Equal(t, "changed", c.Params("+")) require.Equal(t, "changed2", c.Params("+2")) return nil }, ) req, err := http.NewRequest(http.MethodGet, "/filesoriginal/original2", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusOK, resp.StatusCode) }) t.Run("wildcard_params", func(t *testing.T) { t.Parallel() app := New() app.Get("/files/*", func(c Ctx) error { c.OverrideParam("*", "changed") require.Equal(t, "changed", c.Params("*")) return nil }) req, err := http.NewRequest(http.MethodGet, "/files/testing", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusOK, resp.StatusCode) }) t.Run("multi_wildcard_params", func(t *testing.T) { t.Parallel() app := New() app.Get("/files/*/*", func(c Ctx) error { c.OverrideParam("*", "changed") c.OverrideParam("*2", "changed2") require.Equal(t, "changed", c.Params("*")) require.Equal(t, "changed2", c.Params("*2")) return nil }) req, err := http.NewRequest(http.MethodGet, "/files/testing/testing", http.NoBody) require.NoError(t, 
err) resp, err := app.Test(req) require.NoError(t, err) defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusOK, resp.StatusCode) }) t.Run("case_sensitive", func(t *testing.T) { t.Parallel() // Ensure OverrideParam respects the CaseSensitive configuration app := New(Config{ CaseSensitive: true, }) app.Get("/user/:name", func(c Ctx) error { c.OverrideParam("name", "overridden") require.Equal(t, "overridden", c.Params("name")) require.Empty(t, c.Params("NAME")) return c.SendStatus(StatusOK) }) req, err := http.NewRequest(http.MethodGet, "/user/original", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusOK, resp.StatusCode) }) t.Run("case_insensitive", func(t *testing.T) { t.Parallel() // CaseInsensitive mode (default) app := New(Config{ CaseSensitive: false, }) app.Get("/user/:name", func(c Ctx) error { c.OverrideParam("NAME", "overridden") require.Equal(t, "overridden", c.Params("name")) require.Equal(t, "overridden", c.Params("NAME")) return c.SendStatus(StatusOK) }) req, err := http.NewRequest(http.MethodGet, "/user/original", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) defer func() { require.NoError(t, resp.Body.Close()) }() require.Equal(t, StatusOK, resp.StatusCode) }) t.Run("nil_router", func(t *testing.T) { t.Parallel() // Ensure OverrideParam handles nil route context gracefully app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) c, ok := ctx.(*DefaultCtx) require.True(t, ok) defer app.ReleaseCtx(c) c.route = nil c.OverrideParam("test", "value") // Should not change require.Empty(t, c.Params("test")) }) } func Test_Ctx_AbandonSkipsReleaseCtx(t *testing.T) { t.Parallel() app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // controlled test setup ctx.route = &Route{} t.Cleanup(func() { ctx.ForceRelease() }) 
require.False(t, ctx.IsAbandoned()) ctx.Abandon() require.True(t, ctx.IsAbandoned()) app.ReleaseCtx(ctx) require.True(t, ctx.IsAbandoned(), "ReleaseCtx must not pool abandoned contexts") require.NotNil(t, ctx.fasthttp, "ReleaseCtx should not reset fasthttp on abandoned ctx") require.NotNil(t, ctx.route, "ReleaseCtx should not reset route on abandoned ctx") } func Test_Ctx_ForceReleaseClearsAbandon(t *testing.T) { t.Parallel() app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // controlled test setup ctx.route = &Route{} ctx.Abandon() ctx.ForceRelease() require.False(t, ctx.IsAbandoned(), "ForceRelease should clear abandon flag") require.Nil(t, ctx.fasthttp, "ForceRelease should release fasthttp reference") require.Nil(t, ctx.route, "ForceRelease should reset route before pooling") } // go test -v -run=^$ -bench=Benchmark_Ctx_IsProxyTrusted -benchmem -count=4 func Benchmark_Ctx_IsProxyTrusted(b *testing.B) { // Scenario without trusted proxy check b.Run("NoProxyCheck", func(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com:8080/test") b.ReportAllocs() for b.Loop() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) // Scenario without trusted proxy check in parallel b.Run("NoProxyCheckParallel", func(b *testing.B) { app := New() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com:8080/test") for pb.Next() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) }) // Scenario with trusted proxy check simple b.Run("WithProxyCheckSimple", func(b *testing.B) { app := New(Config{ TrustProxy: true, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") b.ReportAllocs() for b.Loop() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) // Scenario with trusted proxy check 
simple in parallel b.Run("WithProxyCheckSimpleParallel", func(b *testing.B) { app := New(Config{ TrustProxy: true, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") for pb.Next() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) }) // Scenario with trusted proxy check b.Run("WithProxyCheck", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"0.0.0.0"}, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") b.ReportAllocs() for b.Loop() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) // Scenario with trusted proxy check in parallel b.Run("WithProxyCheckParallel", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"0.0.0.0"}, }, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") for pb.Next() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) }) // Scenario with trusted proxy check allow private b.Run("WithProxyCheckAllowPrivate", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Private: true, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") b.ReportAllocs() for b.Loop() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) // Scenario with trusted proxy check allow private in parallel b.Run("WithProxyCheckAllowPrivateParallel", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Private: true, }, }) b.ReportAllocs() b.ResetTimer() 
b.RunParallel(func(pb *testing.PB) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") for pb.Next() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) }) // Scenario with trusted proxy check allow private as subnets b.Run("WithProxyCheckAllowPrivateAsSubnets", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "fc00::/7"}, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") b.ReportAllocs() for b.Loop() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) // Scenario with trusted proxy check allow private as subnets in parallel b.Run("WithProxyCheckAllowPrivateAsSubnetsParallel", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "fc00::/7"}, }, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") for pb.Next() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) }) // Scenario with trusted proxy check allow private, loopback, and link-local b.Run("WithProxyCheckAllowAll", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Private: true, Loopback: true, LinkLocal: true, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") b.ReportAllocs() for b.Loop() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) // Scenario with trusted proxy check allow private, loopback, and link-local in parallel b.Run("WithProxyCheckAllowAllParallel", func(b *testing.B) { app := 
New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Private: true, Loopback: true, LinkLocal: true, }, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") for pb.Next() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) }) // Scenario with trusted proxy check allow private, loopback, and link-local as subnets b.Run("WithProxyCheckAllowAllowAllAsSubnets", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{ // Link-local "169.254.0.0/16", "fe80::/10", // Loopback "127.0.0.0/8", "::1/128", // Private "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "fc00::/7", }, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") b.ReportAllocs() b.ResetTimer() for b.Loop() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) // Scenario with trusted proxy check allow private, loopback, and link-local as subnets in parallel b.Run("WithProxyCheckAllowAllowAllAsSubnetsParallel", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{ // Link-local "169.254.0.0/16", "fe80::/10", // Loopback "127.0.0.0/8", "::1/128", // Private "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "fc00::/7", }, }, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") for pb.Next() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) }) // Scenario with trusted proxy check with subnet b.Run("WithProxyCheckSubnet", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"0.0.0.0/8"}, }, }) c := 
app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") b.ReportAllocs() for b.Loop() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) // Scenario with trusted proxy check with subnet in parallel b.Run("WithProxyCheckParallelSubnet", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"0.0.0.0/8"}, }, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") for pb.Next() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) }) // Scenario with trusted proxy check with multiple subnet b.Run("WithProxyCheckMultipleSubnet", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"192.168.0.0/24", "10.0.0.0/16", "0.0.0.0/8"}, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") b.ReportAllocs() for b.Loop() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) // Scenario with trusted proxy check with multiple subnet in parallel b.Run("WithProxyCheckParallelMultipleSubnet", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{"192.168.0.0/24", "10.0.0.0/16", "0.0.0.0/8"}, }, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") for pb.Next() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) }) // Scenario with trusted proxy check with all subnets b.Run("WithProxyCheckAllSubnets", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{ 
"127.0.0.0/8", // Loopback addresses "169.254.0.0/16", // Link-Local addresses "fe80::/10", // Link-Local addresses "192.168.0.0/16", // Private Network addresses "172.16.0.0/12", // Private Network addresses "10.0.0.0/8", // Private Network addresses "fc00::/7", // Unique Local addresses "173.245.48.0/20", // My custom range "0.0.0.0/8", // All IPv4 addresses }, }, }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/test") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") b.ReportAllocs() for b.Loop() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) // Scenario with trusted proxy check with all subnets in parallel b.Run("WithProxyCheckParallelAllSubnets", func(b *testing.B) { app := New(Config{ TrustProxy: true, TrustProxyConfig: TrustProxyConfig{ Proxies: []string{ "127.0.0.0/8", // Loopback addresses "169.254.0.0/16", // Link-Local addresses "fe80::/10", // Link-Local addresses "192.168.0.0/16", // Private Network addresses "172.16.0.0/12", // Private Network addresses "10.0.0.0/8", // Private Network addresses "fc00::/7", // Unique Local addresses "173.245.48.0/20", // My custom range "0.0.0.0/8", // All IPv4 addresses }, }, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com/") c.Request().Header.Set(HeaderXForwardedHost, "google1.com") for pb.Next() { c.IsProxyTrusted() } app.ReleaseCtx(c) }) }) } func Benchmark_Ctx_IsFromLocalhost(b *testing.B) { // Scenario without localhost check b.Run("Non_Localhost", func(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://google.com:8080/test") b.ReportAllocs() for b.Loop() { c.IsFromLocal() } app.ReleaseCtx(c) }) // Scenario with localhost check b.Run("Localhost", func(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().SetRequestURI("http://localhost:8080/test") b.ReportAllocs() 
for b.Loop() { c.IsFromLocal() } app.ReleaseCtx(c) }) } // go test -v -run=^$ -bench=Benchmark_Ctx_OverrideParam -benchmem -count=4 func Benchmark_Ctx_OverrideParam(b *testing.B) { app := New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) c, ok := ctx.(*DefaultCtx) if !ok { b.Fatal("AcquireCtx did not return *DefaultCtx") } defer app.ReleaseCtx(c) c.values = [maxParams]string{"original", "12345"} c.route = &Route{Params: []string{"name", "id"}} c.setMatched(true) b.ReportAllocs() b.ResetTimer() for b.Loop() { c.OverrideParam("name", "changed") } } ================================================ FILE: docs/addon/_category_.json ================================================ { "label": "\uD83D\uDD0C Addon", "position": 5, "collapsed": true, "link": { "type": "generated-index", "description": "Addon is an additional useful package that can be used in Fiber." } } ================================================ FILE: docs/addon/retry.md ================================================ --- id: retry --- # Retry Addon The Retry addon for [Fiber](https://github.com/gofiber/fiber) retries failed network operations using exponential backoff with jitter. It repeatedly invokes a function until it succeeds or the maximum number of attempts is exhausted. Jitter at each step breaks client synchronization and helps avoid collisions. If all attempts fail, the addon returns an error. 
## Table of Contents - [Signatures](#signatures) - [Examples](#examples) - [Default Config](#default-config) - [Custom Config](#custom-config) - [Config](#config) - [Default Config Example](#default-config-example) ## Signatures ```go func NewExponentialBackoff(config ...retry.Config) *retry.ExponentialBackoff ``` ## Examples ```go package main import ( "fmt" "github.com/gofiber/fiber/v3/addon/retry" "github.com/gofiber/fiber/v3/client" ) func main() { expBackoff := retry.NewExponentialBackoff(retry.Config{}) // Local variables used inside Retry var resp *client.Response var err error // Retry a network request and return an error to signal another attempt err = expBackoff.Retry(func() error { cc := client.New() resp, err = cc.Get("https://gofiber.io") if err != nil { return fmt.Errorf("GET gofiber.io failed: %w", err) } if resp.StatusCode() != 200 { return fmt.Errorf("GET gofiber.io returned status %d, want 200", resp.StatusCode()) } return nil }) // If all retries failed, panic if err != nil { panic(err) } fmt.Printf("GET gofiber.io succeeded with status code %d\n", resp.StatusCode()) } ``` ## Default Config ```go retry.NewExponentialBackoff() ``` ## Custom Config ```go retry.NewExponentialBackoff(retry.Config{ InitialInterval: 2 * time.Second, MaxBackoffTime: 64 * time.Second, Multiplier: 2.0, MaxRetryCount: 15, }) ``` ## Config ```go // Config defines the config for addon. type Config struct { // InitialInterval defines the initial time interval for backoff algorithm. // // Optional. Default: 1 * time.Second InitialInterval time.Duration // MaxBackoffTime defines the maximum time duration for the backoff algorithm. // Once the backoff reaches this duration, every remaining retry waits at most // MaxBackoffTime (32 seconds by default). // // Optional. Default: 32 * time.Second MaxBackoffTime time.Duration // Multiplier defines multiplier number of the backoff algorithm. // // Optional. Default: 2.0 Multiplier float64 // MaxRetryCount defines maximum retry count for the backoff algorithm. // // Optional.
Default: 10 MaxRetryCount int } ``` ## Default Config Example ```go // DefaultConfig is the default config for retry. var DefaultConfig = Config{ InitialInterval: 1 * time.Second, MaxBackoffTime: 32 * time.Second, Multiplier: 2.0, MaxRetryCount: 10, currentInterval: 1 * time.Second, } ``` ================================================ FILE: docs/api/_category_.json ================================================ { "label": "\uD83D\uDEE0\uFE0F API", "position": 3, "link": { "type": "generated-index", "description": "API documentation for Fiber." } } ================================================ FILE: docs/api/app.md ================================================ --- id: app title: 🚀 App description: The `App` type represents your Fiber application. sidebar_position: 2 --- import Reference from '@site/src/components/reference'; ## Helpers ### GetString Returns `s` unchanged when [`Immutable`](./fiber.md#immutable) is disabled or `s` resides in read-only memory. Otherwise, it returns a detached copy using `strings.Clone`. ```go title="Signature" func (app *App) GetString(s string) string ``` ### GetBytes Returns `b` unchanged when [`Immutable`](./fiber.md#immutable) is disabled or `b` resides in read-only memory. Otherwise, it returns a detached copy. ```go title="Signature" func (app *App) GetBytes(b []byte) []byte ``` ### ReloadViews Reloads the configured view engine on demand by calling its `Load` method. Use this helper in development workflows (e.g., file watchers or debug-only routes) to pick up template changes without restarting the server. Returns an error if no view engine is configured or reloading fails. 
```go title="Signature" func (app *App) ReloadViews() error ``` ```go title="Example" app := fiber.New(fiber.Config{Views: engine}) app.Get("/dev/reload", func(c fiber.Ctx) error { if err := app.ReloadViews(); err != nil { return err } return c.SendString("Templates reloaded") }) ``` ## Routing import RoutingHandler from './../partials/routing/handler.md'; ### Route Handlers ### Mounting Mount another Fiber instance with [`app.Use`](./app.md#use), similar to Express's [`router.use`](https://expressjs.com/en/api.html#router.use). ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() micro := fiber.New() // Mount the micro app on the "/john" route app.Use("/john", micro) // GET /john/doe -> 200 OK micro.Get("/doe", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) log.Fatal(app.Listen(":3000")) } ``` ### MountPath The `MountPath` method returns the path pattern on which a sub-app was mounted. ```go title="Signature" func (app *App) MountPath() string ``` ```go title="Example" package main import ( "fmt" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() one := fiber.New() two := fiber.New() three := fiber.New() two.Use("/three", three) one.Use("/two", two) app.Use("/one", one) fmt.Println("Mount paths:") fmt.Println("one.MountPath():", one.MountPath()) // "/one" fmt.Println("two.MountPath():", two.MountPath()) // "/one/two" fmt.Println("three.MountPath():", three.MountPath()) // "/one/two/three" fmt.Println("app.MountPath():", app.MountPath()) // "" } ``` :::caution Mounting order is important for `MountPath`. To get correct mount paths, start mounting from the deepest app. ::: ### Group You can group routes by creating a `*Group` struct.
```go title="Signature" func (app *App) Group(prefix string, handlers ...any) Router ``` ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() api := app.Group("/api", handler) // /api v1 := api.Group("/v1", handler) // /api/v1 v1.Get("/list", handler) // /api/v1/list v1.Get("/user", handler) // /api/v1/user v2 := api.Group("/v2", handler) // /api/v2 v2.Get("/list", handler) // /api/v2/list v2.Get("/user", handler) // /api/v2/user log.Fatal(app.Listen(":3000")) } func handler(c fiber.Ctx) error { return c.SendString("Handler response") } ``` ### RouteChain Returns an instance of a single route, which you can then use to handle HTTP verbs with optional middleware. Similar to [`Express`](https://expressjs.com/en/api.html#app.route). ```go title="Signature" func (app *App) RouteChain(path string) Register ```
Click here to see the `Register` interface ```go type Register interface { All(handler any, handlers ...any) Register Get(handler any, handlers ...any) Register Head(handler any, handlers ...any) Register Post(handler any, handlers ...any) Register Put(handler any, handlers ...any) Register Delete(handler any, handlers ...any) Register Connect(handler any, handlers ...any) Register Options(handler any, handlers ...any) Register Trace(handler any, handlers ...any) Register Patch(handler any, handlers ...any) Register Add(methods []string, handler any, handlers ...any) Register RouteChain(path string) Register } ```
```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() // Use `RouteChain` as a chainable route declaration method app.RouteChain("/test").Get(func(c fiber.Ctx) error { return c.SendString("GET /test") }) app.RouteChain("/events").All(func(c fiber.Ctx) error { // Runs for all HTTP verbs first // Think of it as route-specific middleware! return c.Next() }). Get(func(c fiber.Ctx) error { return c.SendString("GET /events") }). Post(func(c fiber.Ctx) error { // Maybe add a new event... return c.SendString("POST /events") }) // Combine multiple routes app.RouteChain("/reports").RouteChain("/daily").Get(func(c fiber.Ctx) error { return c.SendString("GET /reports/daily") }) // Use multiple methods app.RouteChain("/api").Get(func(c fiber.Ctx) error { return c.SendString("GET /api") }).Post(func(c fiber.Ctx) error { return c.SendString("POST /api") }) log.Fatal(app.Listen(":3000")) } ``` ### Route Defines routes with a common prefix inside the supplied function. Internally it uses [`Group`](#group) to create a sub-router and accepts an optional name prefix. ```go title="Signature" func (app *App) Route(prefix string, fn func(router Router), name ...string) Router ``` ```go title="Example" app.Route("/test", func(api fiber.Router) { api.Get("/foo", handler).Name("foo") // /test/foo (name: test.foo) api.Get("/bar", handler).Name("bar") // /test/bar (name: test.bar) }, "test.") ``` ### HandlersCount Returns the number of registered handlers. ```go title="Signature" func (app *App) HandlersCount() uint32 ``` ### Stack Returns the underlying router stack.
```go title="Signature" func (app *App) Stack() [][]*Route ``` ```go title="Example" package main import ( "encoding/json" "fmt" "log" "github.com/gofiber/fiber/v3" ) var handler = func(c fiber.Ctx) error { return nil } func main() { app := fiber.New() app.Get("/john/:age", handler) app.Post("/register", handler) data, _ := json.MarshalIndent(app.Stack(), "", " ") fmt.Println(string(data)) log.Fatal(app.Listen(":3000")) } ```
Click here to see the result ```json [ [ { "method": "GET", "path": "/john/:age", "params": [ "age" ] } ], [ { "method": "HEAD", "path": "/john/:age", "params": [ "age" ] } ], [ { "method": "POST", "path": "/register", "params": null } ] ] ```
### Name This method assigns the name to the latest created route. ```go title="Signature" func (app *App) Name(name string) Router ``` ```go title="Example" package main import ( "encoding/json" "fmt" "log" "github.com/gofiber/fiber/v3" ) func main() { var handler = func(c fiber.Ctx) error { return nil } app := fiber.New() app.Get("/", handler) app.Name("index") app.Get("/doe", handler).Name("home") app.Trace("/tracer", handler).Name("tracert") app.Delete("/delete", handler).Name("delete") a := app.Group("/a") a.Name("fd.") a.Get("/test", handler).Name("test") data, _ := json.MarshalIndent(app.Stack(), "", " ") fmt.Println(string(data)) log.Fatal(app.Listen(":3000")) } ```
Click here to see the result ```json [ [ { "method": "GET", "name": "index", "path": "/", "params": null }, { "method": "GET", "name": "home", "path": "/doe", "params": null }, { "method": "GET", "name": "fd.test", "path": "/a/test", "params": null } ], [ { "method": "HEAD", "name": "", "path": "/", "params": null }, { "method": "HEAD", "name": "", "path": "/doe", "params": null }, { "method": "HEAD", "name": "", "path": "/a/test", "params": null } ], null, null, [ { "method": "DELETE", "name": "delete", "path": "/delete", "params": null } ], null, null, [ { "method": "TRACE", "name": "tracert", "path": "/tracer", "params": null } ], null ] ```
### GetRoute This method retrieves a route by its name. ```go title="Signature" func (app *App) GetRoute(name string) Route ``` ```go title="Example" package main import ( "encoding/json" "fmt" "log" "github.com/gofiber/fiber/v3" ) var handler = func(c fiber.Ctx) error { return nil } func main() { app := fiber.New() app.Get("/", handler).Name("index") route := app.GetRoute("index") data, _ := json.MarshalIndent(route, "", " ") fmt.Println(string(data)) log.Fatal(app.Listen(":3000")) } ```
Click here to see the result ```json { "method": "GET", "name": "index", "path": "/", "params": null } ```
### GetRoutes This method retrieves all routes. ```go title="Signature" func (app *App) GetRoutes(filterUseOption ...bool) []Route ``` When `filterUseOption` is set to `true`, it filters out routes registered by middleware. ```go title="Example" package main import ( "encoding/json" "fmt" "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Post("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }).Name("index") routes := app.GetRoutes(true) data, _ := json.MarshalIndent(routes, "", " ") fmt.Println(string(data)) log.Fatal(app.Listen(":3000")) } ```
Click here to see the result ```json [ { "method": "POST", "name": "index", "path": "/", "params": null } ] ```
## Config `Config` returns the [app config](./fiber.md#config) as a value (read-only). ```go title="Signature" func (app *App) Config() Config ``` ## Handler `Handler` returns the server handler that can be used to serve custom [`\*fasthttp.RequestCtx`](https://pkg.go.dev/github.com/valyala/fasthttp#RequestCtx) requests. ```go title="Signature" func (app *App) Handler() fasthttp.RequestHandler ``` ## ErrorHandler `ErrorHandler` runs the error handler configured for the application with the given context and error. Some middleware uses it to handle errors directly. ```go title="Signature" func (app *App) ErrorHandler(ctx Ctx, err error) error ``` ## NewWithCustomCtx `NewWithCustomCtx` creates a new `*App` and sets the custom context factory function at construction time. ```go title="Signature" func NewWithCustomCtx(fn func(app *App) CustomCtx, config ...Config) *App ``` ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" ) type CustomCtx struct { fiber.DefaultCtx } func (c *CustomCtx) Params(key string, defaultValue ...string) string { return "prefix_" + c.DefaultCtx.Params(key, defaultValue...) } func main() { app := fiber.NewWithCustomCtx(func(app *fiber.App) fiber.CustomCtx { return &CustomCtx{ DefaultCtx: *fiber.NewDefaultCtx(app), } }) app.Get("/:id", func(c fiber.Ctx) error { return c.SendString(c.Params("id")) }) log.Fatal(app.Listen(":3000")) } ``` ## RegisterCustomBinder You can register custom binders to use with [`Bind().Custom("name")`](bind.md#custom). They should be compatible with the `CustomBinder` interface.
```go title="Signature" func (app *App) RegisterCustomBinder(binder CustomBinder) ``` ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" "gopkg.in/yaml.v2" ) type User struct { Name string `yaml:"name"` } type customBinder struct{} func (*customBinder) Name() string { return "custom" } func (*customBinder) MIMETypes() []string { return []string{"application/yaml"} } func (*customBinder) Parse(c fiber.Ctx, out any) error { // Parse YAML body return yaml.Unmarshal(c.Body(), out) } func main() { app := fiber.New() // Register custom binder app.RegisterCustomBinder(&customBinder{}) app.Post("/custom", func(c fiber.Ctx) error { var user User // Use Custom binder by name if err := c.Bind().Custom("custom", &user); err != nil { return err } return c.JSON(user) }) app.Post("/normal", func(c fiber.Ctx) error { var user User // Custom binder is used by the MIME type if err := c.Bind().Body(&user); err != nil { return err } return c.JSON(user) }) log.Fatal(app.Listen(":3000")) } ``` ## RegisterCustomConstraint `RegisterCustomConstraint` allows you to register custom constraints. ```go title="Signature" func (app *App) RegisterCustomConstraint(constraint CustomConstraint) ``` See the [Custom Constraint](../guide/routing.md#custom-constraint) section for more information. ## SetTLSHandler Use `SetTLSHandler` to set [`ClientHelloInfo`](https://datatracker.ietf.org/doc/html/rfc8446#section-4.1.2) when using TLS with a `Listener`. ```go title="Signature" func (app *App) SetTLSHandler(tlsHandler *TLSHandler) ``` ## Test Testing your application is done with the `Test` method. Use this method for creating `_test.go` files or when you need to debug your routing logic. The default timeout is `1s`; to disable a timeout altogether, pass a `TestConfig` struct with `Timeout: 0`. 
```go title="Signature" func (app *App) Test(req *http.Request, config ...TestConfig) (*http.Response, error) ``` ```go title="Example" package main import ( "fmt" "io" "net/http/httptest" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() // Create route with GET method for test: app.Get("/", func(c fiber.Ctx) error { fmt.Println(c.BaseURL()) // => http://google.com fmt.Println(c.Get("X-Custom-Header")) // => hi return c.SendString("hello, World!") }) // Create http.Request req := httptest.NewRequest("GET", "http://google.com", nil) req.Header.Set("X-Custom-Header", "hi") // Perform the test resp, _ := app.Test(req) // Do something with the results: if resp.StatusCode == fiber.StatusOK { body, _ := io.ReadAll(resp.Body) fmt.Println(string(body)) // => hello, World! } } ``` If not provided, TestConfig is set to the following defaults: ```go title="Default TestConfig" config := fiber.TestConfig{ Timeout: time.Second, FailOnTimeout: true, } ``` :::caution This is **not** the same as supplying an empty `TestConfig{}` to `app.Test()`. An empty `TestConfig{}` is equivalent to supplying: ```go title="Empty TestConfig" cfg := fiber.TestConfig{ Timeout: 0, FailOnTimeout: false, } ``` This results in a test with no timeout. ::: ## Hooks `Hooks` is a method to return the [hooks](./hooks.md) property. ```go title="Signature" func (app *App) Hooks() *Hooks ``` ## RebuildTree The `RebuildTree` method is designed to rebuild the route tree and enable dynamic route registration. It returns a pointer to the `App` instance. ```go title="Signature" func (app *App) RebuildTree() *App ``` **Note:** Use this method with caution. It is **not** thread-safe and calling it can be very performance-intensive, so it should be used sparingly and only in development mode. Avoid using it concurrently.
### Example Usage Here’s an example of how to define and register routes dynamically: ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Get("/define", func(c fiber.Ctx) error { // Define a new route dynamically app.Get("/dynamically-defined", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) // Rebuild the route tree to register the new route app.RebuildTree() return c.SendStatus(fiber.StatusOK) }) log.Fatal(app.Listen(":3000")) } ``` In this example, a new route is defined and then `RebuildTree()` is called to ensure the new route is registered and available. ## RemoveRoute This method removes a route by path. You must call the `RebuildTree()` method after the removal to finalize the update and rebuild the routing tree. If no methods are specified, the route will be removed for all HTTP methods defined in the app. To limit removal to specific methods, provide them as additional arguments. ```go title="Signature" func (app *App) RemoveRoute(path string, methods ...string) ``` ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Get("/api/feature-a", func(c fiber.Ctx) error { app.RemoveRoute("/api/feature", fiber.MethodGet) app.RebuildTree() // Redefine route app.Get("/api/feature", func(c fiber.Ctx) error { return c.SendString("Testing feature-a") }) app.RebuildTree() return c.SendStatus(fiber.StatusOK) }) app.Get("/api/feature-b", func(c fiber.Ctx) error { app.RemoveRoute("/api/feature", fiber.MethodGet) app.RebuildTree() // Redefine route app.Get("/api/feature", func(c fiber.Ctx) error { return c.SendString("Testing feature-b") }) app.RebuildTree() return c.SendStatus(fiber.StatusOK) }) log.Fatal(app.Listen(":3000")) } ``` ## RemoveRouteByName This method removes a route by name. If no methods are specified, the route will be removed for all HTTP methods defined in the app. 
To limit removal to specific methods, provide them as additional arguments. ```go title="Signature" func (app *App) RemoveRouteByName(name string, methods ...string) ``` ## RemoveRouteFunc This method removes every route for which `matchFunc` returns `true`; `matchFunc` receives each candidate `*Route`. If no methods are specified, matching routes are removed for all HTTP methods defined in the app. To limit removal to specific methods, provide them as additional arguments. ```go title="Signature" func (app *App) RemoveRouteFunc(matchFunc func(r *Route) bool, methods ...string) ``` ================================================ FILE: docs/api/bind.md ================================================ --- id: bind title: 📎 Bind description: Binds the request and response items to a struct. sidebar_position: 4 toc_max_heading_level: 4 --- Bindings parse request bodies, URL parameters, query parameters, headers, cookies, response headers, and more into structs. :::info Binder-returned values are valid only within the handler. To keep them, copy the data or enable the [**`Immutable`**](./fiber.md#immutable) setting. [Read more...](../#zero-allocation) ::: ## Binders - [All](#all) - [Body](#body) - [CBOR](#cbor) - [Form](#form) - [JSON](#json) - [MsgPack](#msgpack) - [XML](#xml) - [Cookie](#cookie) - [Header](#header) - [Query](#query) - [RespHeader](#respheader) - [URI](#uri) ### All The `All` function binds data from URL parameters, the request body, query parameters, headers, and cookies into `out`. Sources are applied in the following order using struct field tags. #### Precedence Order The binding sources have the following precedence: 1. **URL Parameters (URI)** 2. **Request Body (e.g., JSON or form data)** 3. **Query Parameters** 4. **Request Headers** 5. **Cookies** :::info The request body is only included as a binding source when the request has both a non-empty body **and** a non-empty `Content-Type` header.
::: ```go title="Signature" func (b *Bind) All(out any) error ``` ```go title="Example" type User struct { Name string `query:"name" json:"name" form:"name"` Email string `json:"email" form:"email"` Role string `header:"X-User-Role"` SessionID string `json:"session_id" cookie:"session_id"` ID int `uri:"id" query:"id" json:"id" form:"id"` } app.Post("/users", func(c fiber.Ctx) error { user := new(User) if err := c.Bind().All(user); err != nil { return err } // All available data is now bound to the user struct return c.JSON(user) }) ``` ### Body Binds the request body to a struct. Use tags that match the content type. For example, to parse a JSON body with a `Pass` field, declare `json:"pass"`. | Content-Type | Struct Tag | | ----------------------------------- | ---------- | | `application/x-www-form-urlencoded` | `form` | | `multipart/form-data` | `form` | | `application/json` | `json` | | `application/xml` | `xml` | | `text/xml` | `xml` | | `application/vnd.msgpack` | `msgpack` | ```go title="Signature" func (b *Bind) Body(out any) error ``` ```go title="Example" type Person struct { Name string `json:"name" xml:"name" form:"name" msgpack:"name"` Pass string `json:"pass" xml:"pass" form:"pass" msgpack:"pass"` } app.Post("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().Body(p); err != nil { return err } log.Println(p.Name) // john log.Println(p.Pass) // doe // ... 
}) ``` Test the handler with these `curl` commands: ```bash # JSON curl -X POST -H "Content-Type: application/json" --data "{\"name\":\"john\",\"pass\":\"doe\"}" localhost:3000 # MsgPack curl -X POST -H "Content-Type: application/vnd.msgpack" --data-binary $'\x82\xa4name\xa4john\xa4pass\xa3doe' localhost:3000 # XML curl -X POST -H "Content-Type: application/xml" --data "<person><name>john</name><pass>doe</pass></person>" localhost:3000 # Form URL-Encoded curl -X POST -H "Content-Type: application/x-www-form-urlencoded" --data "name=john&pass=doe" localhost:3000 # Multipart Form curl -X POST -F name=john -F pass=doe http://localhost:3000 ``` ### CBOR > **Note:** Before using any CBOR-related features, make sure to follow the [CBOR setup instructions](../guide/advance-format.md#cbor). Binds the request CBOR body to a struct. It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a CBOR body with a field called `Pass`, you would use a struct field with `cbor:"pass"`. ```go title="Signature" func (b *Bind) CBOR(out any) error ``` ```go title="Example" // Field names should start with an uppercase letter type Person struct { Name string `cbor:"name"` Pass string `cbor:"pass"` } app.Post("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().CBOR(p); err != nil { return err } log.Println(p.Name) // john log.Println(p.Pass) // doe // ... }) ``` Test the handler with this `curl` command: ```bash curl -X POST -H "Content-Type: application/cbor" --data-binary $'\xa2dnamedjohndpasscdoe' localhost:3000 ``` ### Form Binds the request or multipart form body data to a struct. It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a form body with a field called `Pass`, you would use a struct field with `form:"pass"`.
```go title="Signature" func (b *Bind) Form(out any) error ``` ```go title="Example" type Person struct { Name string `form:"name"` Pass string `form:"pass"` } app.Post("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().Form(p); err != nil { return err } log.Println(p.Name) // john log.Println(p.Pass) // doe // ... }) ``` Run tests with the following `curl` commands for both `application/x-www-form-urlencoded` and `multipart/form-data`: ```bash curl -X POST -H "Content-Type: application/x-www-form-urlencoded" --data "name=john&pass=doe" localhost:3000 ``` ```bash curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" localhost:3000 ``` :::info If you need to bind multipart file, you can use `*multipart.FileHeader`, `*[]*multipart.FileHeader` or `[]*multipart.FileHeader` as a field type. ::: ```go title="Example" type Person struct { Name string `form:"name"` Pass string `form:"pass"` Avatar *multipart.FileHeader `form:"avatar"` } app.Post("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().Form(p); err != nil { return err } log.Println(p.Name) // john log.Println(p.Pass) // doe log.Println(p.Avatar.Filename) // file.txt // ... }) ``` Run tests with the following `curl` command: ```bash curl -X POST -H "Content-Type: multipart/form-data" -F "name=john" -F "pass=doe" -F 'avatar=@filename' localhost:3000 ``` ### JSON Binds the request JSON body to a struct. It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a JSON body with a field called `Pass`, you would use a struct field with `json:"pass"`. ```go title="Signature" func (b *Bind) JSON(out any) error ``` ```go title="Example" type Person struct { Name string `json:"name"` Pass string `json:"pass"` } app.Post("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().JSON(p); err != nil { return err } log.Println(p.Name) // john log.Println(p.Pass) // doe // ... 
}) ``` Run tests with the following `curl` command: ```bash curl -X POST -H "Content-Type: application/json" --data "{\"name\":\"john\",\"pass\":\"doe\"}" localhost:3000 ``` ### MsgPack > **Note:** Before using any MsgPack-related features, make sure to follow the [MsgPack setup instructions](../guide/advance-format.md#msgpack). Binds the request MsgPack body to a struct. It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a Msgpack body with a field called `Pass`, you would use a struct field with `msgpack:"pass"`. > Our library uses [shamaton-msgpack](https://github.com/shamaton/msgpack) which uses `msgpack` struct tags by default. If you want to use other libraries, you may need to update the struct tags accordingly. ```go title="Signature" func (b *Bind) MsgPack(out any) error ``` ```go title="Example" type Person struct { Name string `msgpack:"name"` Pass string `msgpack:"pass"` } app.Post("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().MsgPack(p); err != nil { return err } log.Println(p.Name) // john log.Println(p.Pass) // doe // ... }) ``` Run tests with the following `curl` command: ```bash curl -X POST -H "Content-Type: application/vnd.msgpack" --data-binary $'\x82\xa4name\xa4john\xa4pass\xa3doe' localhost:3000 ``` ### XML Binds the request XML body to a struct. It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse an XML body with a field called `Pass`, you would use a struct field with `xml:"pass"`. ```go title="Signature" func (b *Bind) XML(out any) error ``` ```go title="Example" // Field names should start with an uppercase letter type Person struct { Name string `xml:"name"` Pass string `xml:"pass"` } app.Post("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().XML(p); err != nil { return err } log.Println(p.Name) // john log.Println(p.Pass) // doe // ... 
}) ``` Run tests with the following `curl` command: ```bash curl -X POST -H "Content-Type: application/xml" --data "<person><name>john</name><pass>doe</pass></person>" localhost:3000 ``` ### Cookie This method is similar to [Body Binding](#body), but for cookie parameters. It is important to use the struct tag `cookie`. For example, if you want to parse a cookie with a field called `Age`, you would use a struct field with `cookie:"age"`. ```go title="Signature" func (b *Bind) Cookie(out any) error ``` ```go title="Example" type Person struct { Name string `cookie:"name"` Age int `cookie:"age"` Job bool `cookie:"job"` } app.Get("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().Cookie(p); err != nil { return err } log.Println(p.Name) // Joseph log.Println(p.Age) // 23 log.Println(p.Job) // true // ... }) ``` Run tests with the following `curl` command: ```bash curl --cookie "name=Joseph; age=23; job=true" http://localhost:3000/ ``` ### Header This method is similar to [Body Binding](#body), but for request headers. It is important to use the struct tag `header`. For example, if you want to parse a request header with a field called `Pass`, you would use a struct field with `header:"pass"`. ```go title="Signature" func (b *Bind) Header(out any) error ``` ```go title="Example" type Person struct { Name string `header:"name"` Pass string `header:"pass"` Products []string `header:"products"` } app.Get("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().Header(p); err != nil { return err } log.Println(p.Name) // john log.Println(p.Pass) // doe log.Println(p.Products) // [shoe hat] // ... }) ``` Run tests with the following `curl` command: ```bash curl "http://localhost:3000/" -H "name: john" -H "pass: doe" -H "products: shoe,hat" ``` ### Query This method is similar to [Body Binding](#body), but for query parameters. It is important to use the struct tag `query`.
For example, if you want to parse a query parameter with a field called `Pass`, you would use a struct field with `query:"pass"`. ```go title="Signature" func (b *Bind) Query(out any) error ``` ```go title="Example" type Person struct { Name string `query:"name"` Pass string `query:"pass"` Products []string `query:"products"` } app.Get("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().Query(p); err != nil { return err } log.Println(p.Name) // john log.Println(p.Pass) // doe // Depending on fiber.Config{EnableSplittingOnParsers: false} - default log.Println(p.Products) // ["shoe,hat"] // With fiber.Config{EnableSplittingOnParsers: true} // log.Println(p.Products) // ["shoe", "hat"] // ... }) ``` Run tests with the following `curl` command: ```bash curl "http://localhost:3000/?name=john&pass=doe&products=shoe,hat" ``` :::info For more parser settings, please refer to [Config](fiber.md#enablesplittingonparsers) ::: #### Array Query Parameters Fiber supports several formats for passing array values via query parameters. The following table gives an overview: | Format | Example | Requires `EnableSplittingOnParsers` | | ------------------------ | ---------------------------------------------- | ----------------------------------- | | Repeated key | `?colors=red&colors=blue` | No | | Bracket notation | `?colors[]=red&colors[]=blue` | No | | Comma-separated | `?colors=red,blue` | **Yes** | | Indexed bracket notation | `?posts[0][title]=Hello&posts[1][title]=World` | No | | Nested bracket notation | `?preferences[tags]=golang,api` | No (comma splitting: **Yes**) | ##### Repeated Key The most common approach. 
Repeat the same query key for each value: ```text GET /api?colors=red&colors=blue&colors=green ``` ```go title="Struct" type Filter struct { Colors []string `query:"colors"` } // Result: Colors = ["red", "blue", "green"] ``` ```bash title="curl" curl "http://localhost:3000/api?colors=red&colors=blue&colors=green" ``` ##### Bracket Notation Append `[]` to the key name. This is common in PHP-style and JavaScript frameworks: ```text GET /api?colors[]=red&colors[]=blue&colors[]=green ``` ```go title="Struct" type Filter struct { Colors []string `query:"colors"` } // Result: Colors = ["red", "blue", "green"] ``` ```bash title="curl" curl "http://localhost:3000/api?colors[]=red&colors[]=blue&colors[]=green" ``` :::note The struct field tag stays `query:"colors"` (without brackets). Fiber strips the `[]` automatically. ::: ##### Comma-Separated Values Pass multiple values in a single parameter, separated by commas. This format requires [`EnableSplittingOnParsers`](fiber.md#enablesplittingonparsers) to be set to `true`. ```text GET /api?colors=red,blue,green ``` ```go title="Struct" type Filter struct { Colors []string `query:"colors"` } ``` ```go title="App Setup" // EnableSplittingOnParsers is required for comma splitting app := fiber.New(fiber.Config{ EnableSplittingOnParsers: true, }) // Result: Colors = ["red", "blue", "green"] ``` Without `EnableSplittingOnParsers`, the entire string `"red,blue,green"` is treated as a **single** element. 
```go title="Default behavior (EnableSplittingOnParsers: false)" // GET /api?colors=red,blue,green // Result: Colors = ["red,blue,green"] ← single element ``` ```bash title="curl" curl "http://localhost:3000/api?colors=red,blue,green" ``` You can also mix comma-separated values with repeated keys when splitting is enabled: ```text GET /api?hobby=soccer&hobby=basketball,football ``` ```go type Query struct { Hobby []string `query:"hobby"` } // With EnableSplittingOnParsers: true // Result: Hobby = ["soccer", "basketball", "football"] ← 3 elements ``` ##### Indexed Bracket Notation (Nested Structs) Use indexed brackets to bind arrays of nested structs: ```text GET /api?posts[0][title]=Hello&posts[0][author]=Alice&posts[1][title]=World&posts[1][author]=Bob ``` ```go title="Struct" type Post struct { Title string `query:"title"` Author string `query:"author"` } type Request struct { Posts []Post `query:"posts"` } // Result: Posts = [{Title: "Hello", Author: "Alice"}, {Title: "World", Author: "Bob"}] ``` ```bash title="curl" curl "http://localhost:3000/api?posts[0][title]=Hello&posts[0][author]=Alice&posts[1][title]=World&posts[1][author]=Bob" ``` ##### Nested Bracket Notation (Without Index) Use bracket notation to access fields of a nested struct: ```text GET /api?preferences[tags]=golang,api ``` ```go title="Struct" type Preferences struct { Tags *[]string `query:"tags"` } type Profile struct { Prefs *Preferences `query:"preferences"` } // With EnableSplittingOnParsers: true // Result: *Prefs.Tags = ["golang", "api"] ``` ```bash title="curl" curl "http://localhost:3000/api?preferences[tags]=golang,api" ``` :::note Pointer fields (`*[]string`, `*Preferences`) let you distinguish between a missing parameter (`nil`) and an empty one. When the parameter is present, Fiber allocates the pointer automatically. ::: ### RespHeader This method is similar to [Body Binding](#body), but for response headers. It is important to use the struct tag `respHeader`. 
For example, if you want to parse a response header with a field called `Pass`, you would use a struct field with `respHeader:"pass"`. ```go title="Signature" func (b *Bind) RespHeader(out any) error ``` ```go title="Example" type Person struct { Name string `respHeader:"name"` Pass string `respHeader:"pass"` Products []string `respHeader:"products"` } app.Get("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().RespHeader(p); err != nil { return err } log.Println(p.Name) // john log.Println(p.Pass) // doe log.Println(p.Products) // [shoe hat] // ... }) ``` Run tests with the following `curl` command: ```bash curl "http://localhost:3000/" -H "name: john" -H "pass: doe" -H "products: shoe,hat" ``` ### URI This method is similar to [Body Binding](#body), but for path parameters. It is important to use the struct tag `uri`. For example, if you want to parse a path parameter with a field called `Pass`, you would use a struct field with `uri:"pass"`. ```go title="Signature" func (b *Bind) URI(out any) error ``` ```go title="Example" // GET http://example.com/user/111 app.Get("/user/:id", func(c fiber.Ctx) error { param := struct { ID uint `uri:"id"` }{} if err := c.Bind().URI(&param); err != nil { return err } // ... return c.SendString(fmt.Sprintf("User ID: %d", param.ID)) }) ``` ## BindError When a bind method fails to parse (e.g. invalid JSON, bad type conversion), the behavior depends on the error-handling mode. In **manual handling** (the default), the binder returns a `*BindError` wrapping the underlying error — use `errors.As` to extract it and branch on the binding source or field. In **automatic handling** (enabled via `WithAutoHandling`), parse failures are instead converted to a `*fiber.Error` with HTTP status 400; `*BindError` is never surfaced to the caller in that mode. If you are using `WithAutoHandling`, check for `*fiber.Error` or an HTTP 400 response rather than using `errors.As` for `*BindError`.
```go type BindError struct { Source string // "uri", "query", "body", "header", "cookie", or "respHeader" Field string // struct field or tag key that failed (best-effort, may be empty) Err error // underlying error; use errors.As to inspect } ``` Source constants: `BindSourceURI`, `BindSourceQuery`, `BindSourceHeader`, `BindSourceCookie`, `BindSourceBody`, `BindSourceRespHeader`. ### Branching on source Use `errors.As` to extract `*BindError` and branch on `Source` for RFC-correct status codes (e.g. 404 for URI failures vs 400 for body/query): ```go title="Example" // With manual handling mode (default behavior) // Will not work with WithAutoHandling() var req struct { ID int `uri:"id"` Name string `json:"name"` } if err := c.Bind().All(&req); err != nil { var be *fiber.BindError if errors.As(err, &be) && be.Source == fiber.BindSourceURI { return c.Status(fiber.StatusNotFound).JSON(fiber.Map{"error": "not found"}) } return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid request"}) } ``` ### Validation vs binding errors Validation errors (from `StructValidator`) are **not** wrapped in `BindError`. Use `errors.As(err, &be)` to distinguish: it succeeds only for parsing/binding failures, not for validation failures. ## Custom To use custom binders, you have to use this method. You can register them using the [RegisterCustomBinder](./app.md#registercustombinder) method of the Fiber instance. 
```go title="Signature" func (b *Bind) Custom(name string, dest any) error ``` ```go title="Example" app := fiber.New() // My custom binder type customBinder struct{} func (cb *customBinder) Name() string { return "custom" } func (cb *customBinder) MIMETypes() []string { return []string{"application/yaml"} } func (cb *customBinder) Parse(c fiber.Ctx, out any) error { // parse YAML body return yaml.Unmarshal(c.Body(), out) } // Register custom binder app.RegisterCustomBinder(&customBinder{}) type User struct { Name string `yaml:"name"` } // curl -X POST http://localhost:3000/custom -H "Content-Type: application/yaml" -d "name: John" app.Post("/custom", func(c fiber.Ctx) error { var user User // Use Custom binder by name if err := c.Bind().Custom("custom", &user); err != nil { return err } return c.JSON(user) }) ``` Internally, custom binders are also used in the [Body](#body) method. The `MIMETypes` method is used to check if the custom binder should be used for the given content type. ## Options For more control over error handling, you can use the following methods. ### WithAutoHandling If you want to handle binder errors automatically, you can use `WithAutoHandling`. If there's an error, it will return the error and set HTTP status to `400 Bad Request`. This function does NOT panic; therefore, you must still return on error explicitly. ```go title="Signature" func (b *Bind) WithAutoHandling() *Bind ``` ### WithoutAutoHandling To handle binder errors manually, you can use the `WithoutAutoHandling` method. It's the default behavior of the binder. ```go title="Signature" func (b *Bind) WithoutAutoHandling() *Bind ``` ### SkipValidation To enable or disable validation for the current bind chain, use `SkipValidation`. By default, validation is enabled (`skip = false`).
```go title="Signature" func (b *Bind) SkipValidation(skip bool) *Bind ``` ## SetParserDecoder Allows you to configure the BodyParser/QueryParser decoder based on schema options, providing the possibility to add custom types for parsing. ```go title="Signature" func SetParserDecoder(parserConfig binder.ParserConfig) ``` `binder.ParserConfig` has the following fields: ```go type ParserConfig struct { IgnoreUnknownKeys bool ParserType []ParserType ZeroEmpty bool SetAliasTag string } type ParserType struct { CustomType any Converter func(string) reflect.Value } ``` ```go title="Example" type CustomTime time.Time // String returns the time in string format func (ct *CustomTime) String() string { t := time.Time(*ct).String() return t } // Converter for CustomTime type with format "2006-01-02" var timeConverter = func(value string) reflect.Value { fmt.Println("timeConverter:", value) if v, err := time.Parse("2006-01-02", value); err == nil { return reflect.ValueOf(CustomTime(v)) } return reflect.Value{} } customTime := binder.ParserType{ CustomType: CustomTime{}, Converter: timeConverter, } // Add custom type to the Decoder settings binder.SetParserDecoder(binder.ParserConfig{ IgnoreUnknownKeys: true, ParserType: []binder.ParserType{customTime}, ZeroEmpty: true, }) // Example using CustomTime with non-RFC3339 format type Demo struct { Date CustomTime `form:"date" query:"date"` Title string `form:"title" query:"title"` Body string `form:"body" query:"body"` } app.Post("/body", func(c fiber.Ctx) error { var d Demo if err := c.Bind().Body(&d); err != nil { return err } fmt.Println("d.Date:", d.Date.String()) return c.JSON(d) }) app.Get("/query", func(c fiber.Ctx) error { var d Demo if err := c.Bind().Query(&d); err != nil { return err } fmt.Println("d.Date:", d.Date.String()) return c.JSON(d) }) // Run tests with the following curl commands: # Body Binding curl -X POST -F title=title -F body=body -F date=2021-10-20 http://localhost:3000/body # Query Binding curl -X GET 
"http://localhost:3000/query?title=title&body=body&date=2021-10-20" ``` ## Validation Validation is also possible with the binding methods. You can specify your validation rules using the `validate` struct tag. Specify your struct validator in the [config](./fiber.md#structvalidator). The validator must implement the `StructValidator` interface, which requires a `Validate` method that takes an `any` type and returns an error. ```go title="Interface" type StructValidator interface { Validate(out any) error } ``` ### Setup Your Validator in the Config ```go title="Example" import "github.com/go-playground/validator/v10" type structValidator struct { validate *validator.Validate } // Validate method implementation func (v *structValidator) Validate(out any) error { return v.validate.Struct(out) } // Setup your validator in the Fiber config app := fiber.New(fiber.Config{ StructValidator: &structValidator{validate: validator.New()}, }) ``` Fiber only runs `StructValidator` for struct destinations (or pointers to structs). Binding into maps and other non-struct types skips the validator step. ### Usage of Validation in Binding Methods ```go title="Example" type Person struct { Name string `json:"name" validate:"required"` Age int `json:"age" validate:"gte=18,lte=60"` } app.Post("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().JSON(p); err != nil { // Receives validation errors return err } }) ``` ## Default Fields You can set default values for fields in the struct by using the `default` struct tag. Supported types: - `bool` - Float variants (`float32`, `float64`) - Int variants (`int`, `int8`, `int16`, `int32`, `int64`) - Uint variants (`uint`, `uint8`, `uint16`, `uint32`, `uint64`) - `string` - A slice of the above types. Use `|` to separate slice items. - A pointer to one of the above types (**pointers to slices and slices of pointers are not supported**). 
```go title="Example" type Person struct { Name string `query:"name,default:john"` Pass string `query:"pass"` Products []string `query:"products,default:shoe|hat"` } app.Get("/", func(c fiber.Ctx) error { p := new(Person) if err := c.Bind().Query(p); err != nil { return err } log.Println(p.Name) // john log.Println(p.Pass) // doe log.Println(p.Products) // [shoe hat] // ... }) ``` Run tests with the following `curl` command: ```bash curl "http://localhost:3000/?pass=doe" ``` ================================================ FILE: docs/api/constants.md ================================================ --- id: constants title: 📋 Constants description: Core HTTP constants used throughout Fiber. sidebar_position: 10 --- ### HTTP methods (mirrors `net/http`) ```go const ( MethodGet = "GET" // RFC 7231, 4.3.1 MethodHead = "HEAD" // RFC 7231, 4.3.2 MethodPost = "POST" // RFC 7231, 4.3.3 MethodPut = "PUT" // RFC 7231, 4.3.4 MethodPatch = "PATCH" // RFC 5789 MethodDelete = "DELETE" // RFC 7231, 4.3.5 MethodConnect = "CONNECT" // RFC 7231, 4.3.6 MethodOptions = "OPTIONS" // RFC 7231, 4.3.7 MethodTrace = "TRACE" // RFC 7231, 4.3.8 methodUse = "USE" ) ``` ### Common MIME types ```go const ( MIMETextXML = "text/xml" MIMETextHTML = "text/html" MIMETextPlain = "text/plain" MIMETextJavaScript = "text/javascript" MIMETextCSS = "text/css" MIMEApplicationXML = "application/xml" MIMEApplicationJSON = "application/json" MIMEApplicationCBOR = "application/cbor" MIMEApplicationForm = "application/x-www-form-urlencoded" MIMEOctetStream = "application/octet-stream" MIMEMultipartForm = "multipart/form-data" MIMETextXMLCharsetUTF8 = "text/xml; charset=utf-8" MIMETextHTMLCharsetUTF8 = "text/html; charset=utf-8" MIMETextPlainCharsetUTF8 = "text/plain; charset=utf-8" MIMETextJavaScriptCharsetUTF8 = "text/javascript; charset=utf-8" MIMETextCSSCharsetUTF8 = "text/css; charset=utf-8" MIMEApplicationXMLCharsetUTF8 = "application/xml; charset=utf-8" MIMEApplicationJSONCharsetUTF8 = 
"application/json; charset=utf-8" ) ``` ### HTTP status codes (mirrors `net/http`) ```go const ( StatusContinue = 100 // RFC 7231, 6.2.1 StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2 StatusProcessing = 102 // RFC 2518, 10.1 StatusEarlyHints = 103 // RFC 8297 StatusOK = 200 // RFC 7231, 6.3.1 StatusCreated = 201 // RFC 7231, 6.3.2 StatusAccepted = 202 // RFC 7231, 6.3.3 StatusNonAuthoritativeInformation = 203 // RFC 7231, 6.3.4 StatusNoContent = 204 // RFC 7231, 6.3.5 StatusResetContent = 205 // RFC 7231, 6.3.6 StatusPartialContent = 206 // RFC 7233, 4.1 StatusMultiStatus = 207 // RFC 4918, 11.1 StatusAlreadyReported = 208 // RFC 5842, 7.1 StatusIMUsed = 226 // RFC 3229, 10.4.1 StatusMultipleChoices = 300 // RFC 7231, 6.4.1 StatusMovedPermanently = 301 // RFC 7231, 6.4.2 StatusFound = 302 // RFC 7231, 6.4.3 StatusSeeOther = 303 // RFC 7231, 6.4.4 StatusNotModified = 304 // RFC 7232, 4.1 StatusUseProxy = 305 // RFC 7231, 6.4.5 StatusSwitchProxy = 306 // RFC 9110, 15.4.7 (Unused) StatusTemporaryRedirect = 307 // RFC 7231, 6.4.7 StatusPermanentRedirect = 308 // RFC 7538, 3 StatusBadRequest = 400 // RFC 7231, 6.5.1 StatusUnauthorized = 401 // RFC 7235, 3.1 StatusPaymentRequired = 402 // RFC 7231, 6.5.2 StatusForbidden = 403 // RFC 7231, 6.5.3 StatusNotFound = 404 // RFC 7231, 6.5.4 StatusMethodNotAllowed = 405 // RFC 7231, 6.5.5 StatusNotAcceptable = 406 // RFC 7231, 6.5.6 StatusProxyAuthRequired = 407 // RFC 7235, 3.2 StatusRequestTimeout = 408 // RFC 7231, 6.5.7 StatusConflict = 409 // RFC 7231, 6.5.8 StatusGone = 410 // RFC 7231, 6.5.9 StatusLengthRequired = 411 // RFC 7231, 6.5.10 StatusPreconditionFailed = 412 // RFC 7232, 4.2 StatusRequestEntityTooLarge = 413 // RFC 7231, 6.5.11 StatusRequestURITooLong = 414 // RFC 7231, 6.5.12 StatusUnsupportedMediaType = 415 // RFC 7231, 6.5.13 StatusRequestedRangeNotSatisfiable = 416 // RFC 7233, 4.4 StatusExpectationFailed = 417 // RFC 7231, 6.5.14 StatusTeapot = 418 // RFC 7168, 2.3.3 StatusMisdirectedRequest = 421 // RFC 
7540, 9.1.2 StatusUnprocessableEntity = 422 // RFC 4918, 11.2 StatusLocked = 423 // RFC 4918, 11.3 StatusFailedDependency = 424 // RFC 4918, 11.4 StatusTooEarly = 425 // RFC 8470, 5.2. StatusUpgradeRequired = 426 // RFC 7231, 6.5.15 StatusPreconditionRequired = 428 // RFC 6585, 3 StatusTooManyRequests = 429 // RFC 6585, 4 StatusRequestHeaderFieldsTooLarge = 431 // RFC 6585, 5 StatusUnavailableForLegalReasons = 451 // RFC 7725, 3 StatusInternalServerError = 500 // RFC 7231, 6.6.1 StatusNotImplemented = 501 // RFC 7231, 6.6.2 StatusBadGateway = 502 // RFC 7231, 6.6.3 StatusServiceUnavailable = 503 // RFC 7231, 6.6.4 StatusGatewayTimeout = 504 // RFC 7231, 6.6.5 StatusHTTPVersionNotSupported = 505 // RFC 7231, 6.6.6 StatusVariantAlsoNegotiates = 506 // RFC 2295, 8.1 StatusInsufficientStorage = 507 // RFC 4918, 11.5 StatusLoopDetected = 508 // RFC 5842, 7.2 StatusNotExtended = 510 // RFC 2774, 7 StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6 ) ``` ### Errors ```go var ( ErrBadRequest = NewError(StatusBadRequest) // RFC 7231, 6.5.1 ErrUnauthorized = NewError(StatusUnauthorized) // RFC 7235, 3.1 ErrPaymentRequired = NewError(StatusPaymentRequired) // RFC 7231, 6.5.2 ErrForbidden = NewError(StatusForbidden) // RFC 7231, 6.5.3 ErrNotFound = NewError(StatusNotFound) // RFC 7231, 6.5.4 ErrMethodNotAllowed = NewError(StatusMethodNotAllowed) // RFC 7231, 6.5.5 ErrNotAcceptable = NewError(StatusNotAcceptable) // RFC 7231, 6.5.6 ErrProxyAuthRequired = NewError(StatusProxyAuthRequired) // RFC 7235, 3.2 ErrRequestTimeout = NewError(StatusRequestTimeout) // RFC 7231, 6.5.7 ErrConflict = NewError(StatusConflict) // RFC 7231, 6.5.8 ErrGone = NewError(StatusGone) // RFC 7231, 6.5.9 ErrLengthRequired = NewError(StatusLengthRequired) // RFC 7231, 6.5.10 ErrPreconditionFailed = NewError(StatusPreconditionFailed) // RFC 7232, 4.2 ErrRequestEntityTooLarge = NewError(StatusRequestEntityTooLarge) // RFC 7231, 6.5.11 ErrRequestURITooLong = NewError(StatusRequestURITooLong) // RFC 
7231, 6.5.12 ErrUnsupportedMediaType = NewError(StatusUnsupportedMediaType) // RFC 7231, 6.5.13 ErrRequestedRangeNotSatisfiable = NewError(StatusRequestedRangeNotSatisfiable) // RFC 7233, 4.4 ErrExpectationFailed = NewError(StatusExpectationFailed) // RFC 7231, 6.5.14 ErrTeapot = NewError(StatusTeapot) // RFC 7168, 2.3.3 ErrMisdirectedRequest = NewError(StatusMisdirectedRequest) // RFC 7540, 9.1.2 ErrUnprocessableEntity = NewError(StatusUnprocessableEntity) // RFC 4918, 11.2 ErrLocked = NewError(StatusLocked) // RFC 4918, 11.3 ErrFailedDependency = NewError(StatusFailedDependency) // RFC 4918, 11.4 ErrTooEarly = NewError(StatusTooEarly) // RFC 8470, 5.2. ErrUpgradeRequired = NewError(StatusUpgradeRequired) // RFC 7231, 6.5.15 ErrPreconditionRequired = NewError(StatusPreconditionRequired) // RFC 6585, 3 ErrTooManyRequests = NewError(StatusTooManyRequests) // RFC 6585, 4 ErrRequestHeaderFieldsTooLarge = NewError(StatusRequestHeaderFieldsTooLarge) // RFC 6585, 5 ErrUnavailableForLegalReasons = NewError(StatusUnavailableForLegalReasons) // RFC 7725, 3 ErrInternalServerError = NewError(StatusInternalServerError) // RFC 7231, 6.6.1 ErrNotImplemented = NewError(StatusNotImplemented) // RFC 7231, 6.6.2 ErrBadGateway = NewError(StatusBadGateway) // RFC 7231, 6.6.3 ErrServiceUnavailable = NewError(StatusServiceUnavailable) // RFC 7231, 6.6.4 ErrGatewayTimeout = NewError(StatusGatewayTimeout) // RFC 7231, 6.6.5 ErrHTTPVersionNotSupported = NewError(StatusHTTPVersionNotSupported) // RFC 7231, 6.6.6 ErrVariantAlsoNegotiates = NewError(StatusVariantAlsoNegotiates) // RFC 2295, 8.1 ErrInsufficientStorage = NewError(StatusInsufficientStorage) // RFC 4918, 11.5 ErrLoopDetected = NewError(StatusLoopDetected) // RFC 5842, 7.2 ErrNotExtended = NewError(StatusNotExtended) // RFC 2774, 7 ErrNetworkAuthenticationRequired = NewError(StatusNetworkAuthenticationRequired) // RFC 6585, 6 ) ``` HTTP Headers were copied from net/http. 
```go const ( HeaderAuthorization = "Authorization" HeaderProxyAuthenticate = "Proxy-Authenticate" HeaderProxyAuthorization = "Proxy-Authorization" HeaderWWWAuthenticate = "WWW-Authenticate" HeaderAge = "Age" HeaderCacheControl = "Cache-Control" HeaderClearSiteData = "Clear-Site-Data" HeaderExpires = "Expires" HeaderPragma = "Pragma" HeaderWarning = "Warning" HeaderAcceptCH = "Accept-CH" HeaderAcceptCHLifetime = "Accept-CH-Lifetime" HeaderContentDPR = "Content-DPR" HeaderDPR = "DPR" HeaderEarlyData = "Early-Data" HeaderSaveData = "Save-Data" HeaderViewportWidth = "Viewport-Width" HeaderWidth = "Width" HeaderETag = "ETag" HeaderIfMatch = "If-Match" HeaderIfModifiedSince = "If-Modified-Since" HeaderIfNoneMatch = "If-None-Match" HeaderIfUnmodifiedSince = "If-Unmodified-Since" HeaderLastModified = "Last-Modified" HeaderVary = "Vary" HeaderConnection = "Connection" HeaderKeepAlive = "Keep-Alive" HeaderAccept = "Accept" HeaderAcceptCharset = "Accept-Charset" HeaderAcceptEncoding = "Accept-Encoding" HeaderAcceptLanguage = "Accept-Language" HeaderCookie = "Cookie" HeaderExpect = "Expect" HeaderMaxForwards = "Max-Forwards" HeaderSetCookie = "Set-Cookie" HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials" HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers" HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods" HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin" HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers" HeaderAccessControlMaxAge = "Access-Control-Max-Age" HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers" HeaderAccessControlRequestMethod = "Access-Control-Request-Method" HeaderOrigin = "Origin" HeaderTimingAllowOrigin = "Timing-Allow-Origin" HeaderXPermittedCrossDomainPolicies = "X-Permitted-Cross-Domain-Policies" HeaderDNT = "DNT" HeaderTk = "Tk" HeaderContentDisposition = "Content-Disposition" HeaderContentEncoding = "Content-Encoding" HeaderContentLanguage = "Content-Language" 
HeaderContentLength = "Content-Length" HeaderContentLocation = "Content-Location" HeaderContentType = "Content-Type" HeaderForwarded = "Forwarded" HeaderVia = "Via" HeaderXForwardedFor = "X-Forwarded-For" HeaderXForwardedHost = "X-Forwarded-Host" HeaderXForwardedProto = "X-Forwarded-Proto" HeaderXForwardedProtocol = "X-Forwarded-Protocol" HeaderXForwardedSsl = "X-Forwarded-Ssl" HeaderXUrlScheme = "X-Url-Scheme" HeaderLocation = "Location" HeaderFrom = "From" HeaderHost = "Host" HeaderReferer = "Referer" HeaderReferrerPolicy = "Referrer-Policy" HeaderUserAgent = "User-Agent" HeaderAllow = "Allow" HeaderServer = "Server" HeaderAcceptRanges = "Accept-Ranges" HeaderContentRange = "Content-Range" HeaderIfRange = "If-Range" HeaderRange = "Range" HeaderContentSecurityPolicy = "Content-Security-Policy" HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only" HeaderCrossOriginResourcePolicy = "Cross-Origin-Resource-Policy" HeaderExpectCT = "Expect-CT" HeaderFeaturePolicy = "Feature-Policy" HeaderPublicKeyPins = "Public-Key-Pins" HeaderPublicKeyPinsReportOnly = "Public-Key-Pins-Report-Only" HeaderStrictTransportSecurity = "Strict-Transport-Security" HeaderUpgradeInsecureRequests = "Upgrade-Insecure-Requests" HeaderXContentTypeOptions = "X-Content-Type-Options" HeaderXDownloadOptions = "X-Download-Options" HeaderXFrameOptions = "X-Frame-Options" HeaderXPoweredBy = "X-Powered-By" HeaderXXSSProtection = "X-XSS-Protection" HeaderLastEventID = "Last-Event-ID" HeaderNEL = "NEL" HeaderPingFrom = "Ping-From" HeaderPingTo = "Ping-To" HeaderReportTo = "Report-To" HeaderTE = "TE" HeaderTrailer = "Trailer" HeaderTransferEncoding = "Transfer-Encoding" HeaderSecWebSocketAccept = "Sec-WebSocket-Accept" HeaderSecWebSocketExtensions = "Sec-WebSocket-Extensions" HeaderSecWebSocketKey = "Sec-WebSocket-Key" HeaderSecWebSocketProtocol = "Sec-WebSocket-Protocol" HeaderSecWebSocketVersion = "Sec-WebSocket-Version" HeaderAcceptPatch = "Accept-Patch" HeaderAcceptPushPolicy = 
"Accept-Push-Policy" HeaderAcceptSignature = "Accept-Signature" HeaderAltSvc = "Alt-Svc" HeaderDate = "Date" HeaderIndex = "Index" HeaderLargeAllocation = "Large-Allocation" HeaderLink = "Link" HeaderPushPolicy = "Push-Policy" HeaderRetryAfter = "Retry-After" HeaderServerTiming = "Server-Timing" HeaderSignature = "Signature" HeaderSignedHeaders = "Signed-Headers" HeaderSourceMap = "SourceMap" HeaderUpgrade = "Upgrade" HeaderXDNSPrefetchControl = "X-DNS-Prefetch-Control" HeaderXPingback = "X-Pingback" HeaderXRequestID = "X-Request-ID" HeaderXRequestedWith = "X-Requested-With" HeaderXRobotsTag = "X-Robots-Tag" HeaderXUACompatible = "X-UA-Compatible" HeaderAccessControlAllowPrivateNetwork = "Access-Control-Allow-Private-Network" HeaderAccessControlRequestPrivateNetwork = "Access-Control-Request-Private-Network" ) ``` ================================================ FILE: docs/api/ctx.md ================================================ --- id: ctx title: 🧠 Ctx description: >- The Ctx interface represents the Context which holds the HTTP request and response. It has methods for the request query string, parameters, body, HTTP headers, and so on. sidebar_position: 3 --- ### Abandon Marks the context as abandoned. An abandoned context will not be returned to the pool when `ReleaseCtx` is called. This is used internally by the [timeout middleware](../middleware/timeout.md) to return immediately while the handler goroutine continues safely. ```go title="Signature" func (c fiber.Ctx) Abandon() func (c fiber.Ctx) IsAbandoned() bool func (c fiber.Ctx) ForceRelease() ``` | Method | Description | |:---------------|:----------------------------------------------------------------------------| | `Abandon()` | Marks the context as abandoned. ReleaseCtx becomes a no-op for this context. | | `IsAbandoned()`| Returns `true` if `Abandon()` was called on this context. | | `ForceRelease()`| Releases an abandoned context back to the pool. 
Must only be called after the handler has completely finished. | :::caution These methods are primarily for internal use and advanced middleware development. Most applications should not need to call them directly. ::: ### App Returns the [\*App](app.md) reference so you can easily access all application settings. ```go title="Signature" func (c fiber.Ctx) App() *App ``` ```go title="Example" app.Get("/stack", func(c fiber.Ctx) error { return c.JSON(c.App().Stack()) }) ``` ### Bind Bind returns a helper for decoding the request body, query string, headers, cookies, and more. For full details, see the [Bind](./bind.md) documentation. ```go title="Signature" func (c fiber.Ctx) Bind() *Bind ``` ```go title="Example" app.Post("/", func(c fiber.Ctx) error { user := new(User) // Bind the request body to a struct: return c.Bind().Body(user) }) ``` ### Context Returns a `context.Context` that was previously set with [`SetContext`](#setcontext). If no context was set, it returns `context.Background()`. Unlike `fiber.Ctx` itself, the returned context is safe to use after the handler completes. ```go title="Signature" func (c fiber.Ctx) Context() context.Context ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { ctx := c.Context() go doWork(ctx) return nil }) ``` ### context.Context `Ctx` implements `context.Context`. However, due to [current limitations in how fasthttp](https://github.com/valyala/fasthttp/issues/965#issuecomment-777268945) works, `Deadline()`, `Done()` and `Err()` are no-ops. The `fiber.Ctx` instance is reused after the handler returns and must not be used for asynchronous operations once the handler has completed. Call [`Context`](#context) within the handler to obtain a `context.Context` that can be used outside the handler.
```go title="Signature" func (c fiber.Ctx) Deadline() (deadline time.Time, ok bool) func (c fiber.Ctx) Done() <-chan struct{} func (c fiber.Ctx) Err() error func (c fiber.Ctx) Value(key any) any ``` ```go title="Example" func doSomething(ctx context.Context) { // ... } app.Get("/", func(c fiber.Ctx) error { doSomething(c) }) ``` #### Value Value can be used to retrieve [**`Locals`**](#locals). ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Locals(userKey, "admin") user := c.Value(userKey) // returns "admin" }) ``` ### Drop Terminates the client connection silently without sending any HTTP headers or response body. This can be used for scenarios where you want to block certain requests without notifying the client, such as mitigating DDoS attacks or protecting sensitive endpoints from unauthorized access. ```go title="Signature" func (c fiber.Ctx) Drop() error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { if c.IP() == "192.168.1.1" { return c.Drop() } return c.SendString("Hello World!") }) ``` ### FullPath Returns the full path of the matched route. This includes any prefixes that were added by [groups](../guide/routing.md#grouping) or mounts. ```go title="Signature" func (c fiber.Ctx) FullPath() string ``` ```go title="Example" api := app.Group("/api") api.Get("/users/:id", func(c fiber.Ctx) error { return c.JSON(fiber.Map{ "route": c.FullPath(), // "/api/users/:id" }) }) app.Use(func(c fiber.Ctx) error { beforeNext := c.FullPath() // "/" if err := c.Next(); err != nil { return err } afterNext := c.FullPath() // "/api/users/:id" // ... react to the downstream handler's route path return nil }) ``` ### GetReqHeaders Returns the HTTP request headers as a map. Because a header can appear multiple times in a request, each key maps to a slice with all values for that header. ```go title="Signature" func (c fiber.Ctx) GetReqHeaders() map[string][]string ``` :::info The returned value is valid only within the handler. 
Do not store references. Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation) ::: ### GetRespHeader Returns the HTTP response header specified by the field. :::tip The match is **case-insensitive**. ::: ```go title="Signature" func (c fiber.Ctx) GetRespHeader(key string, defaultValue ...string) string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.GetRespHeader("X-Request-Id") // "8d7ad5e3-aaf3-450b-a241-2beb887efd54" c.GetRespHeader("Content-Type") // "text/plain" c.GetRespHeader("something", "john") // "john" // .. }) ``` :::info The returned value is valid only within the handler. Do not store references. Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation) ::: ### GetRespHeaders Returns the HTTP response headers as a map. Since a header can be set multiple times in a single request, the values of the map are slices of strings containing all the different values of the header. ```go title="Signature" func (c fiber.Ctx) GetRespHeaders() map[string][]string ``` :::info The returned value is valid only within the handler. Do not store references. Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation) ::: ### GetRouteURL Generates URLs to named routes, with parameters. 
URLs are relative, for example: "/user/1831" ```go title="Signature" func (c fiber.Ctx) GetRouteURL(routeName string, params Map) (string, error) ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { return c.SendString("Home page") }).Name("home") app.Get("/user/:id", func(c fiber.Ctx) error { return c.SendString(c.Params("id")) }).Name("user.show") app.Get("/test", func(c fiber.Ctx) error { location, _ := c.GetRouteURL("user.show", fiber.Map{"id": 1}) return c.SendString(location) }) // /test returns "/user/1" ``` ### HasBody Returns `true` if the incoming request contains a body or a `Content-Length` header greater than zero. ```go title="Signature" func (c fiber.Ctx) HasBody() bool ``` ```go title="Example" app.Post("/", func(c fiber.Ctx) error { if !c.HasBody() { return c.SendStatus(fiber.StatusBadRequest) } return c.SendString("OK") }) ``` ### IsMiddleware Returns `true` if the current request handler was registered as middleware. ```go title="Signature" func (c fiber.Ctx) IsMiddleware() bool ``` ```go title="Example" app.Get("/route", func(c fiber.Ctx) error { fmt.Println(c.IsMiddleware()) // true return c.Next() }, func(c fiber.Ctx) error { fmt.Println(c.IsMiddleware()) // false return c.SendStatus(fiber.StatusOK) }) ``` ### IsPreflight Returns `true` if the request is a CORS preflight (`OPTIONS` + `Access-Control-Request-Method` + `Origin`). ```go title="Signature" func (c fiber.Ctx) IsPreflight() bool ``` ```go title="Example" app.Use(func(c fiber.Ctx) error { if c.IsPreflight() { return c.SendStatus(fiber.StatusNoContent) } return c.Next() }) ``` ### IsWebSocket Returns `true` if the request includes a WebSocket upgrade handshake. ```go title="Signature" func (c fiber.Ctx) IsWebSocket() bool ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { if c.IsWebSocket() { // handle websocket } return c.Next() }) ``` ### Locals Stores variables scoped to the request, making them available only to matching routes. 
The variables are removed after the request completes. If a stored value implements `io.Closer`, Fiber calls its `Close` method before removal. :::tip This is useful if you want to pass some **specific** data to the next middleware. Remember to perform type assertions when retrieving the data to ensure it is of the expected type. You can also use a non-exported type as a key to avoid collisions. ::: ```go title="Signature" func (c fiber.Ctx) Locals(key any, value ...any) any ``` ```go title="Example" // keyType is an unexported type for keys defined in this package. // This prevents collisions with keys defined in other packages. type keyType int // userKey is the key for user.User values in Contexts. It is // unexported; clients use user.NewContext and user.FromContext // instead of using this key directly. var userKey keyType app.Use(func(c fiber.Ctx) error { c.Locals(userKey, "admin") // Stores the string "admin" under a non-exported type key return c.Next() }) app.Get("/admin", func(c fiber.Ctx) error { user, ok := c.Locals(userKey).(string) // Retrieves the data stored under the key and performs a type assertion if ok && user == "admin" { return c.Status(fiber.StatusOK).SendString("Welcome, admin!") } return c.SendStatus(fiber.StatusForbidden) }) ``` An alternative version of the `Locals` method that takes advantage of Go's generics feature is also available. This version allows for the manipulation and retrieval of local values within a request's context with a more specific data type. 
```go title="Signature" func Locals[V any](c fiber.Ctx, key any, value ...V) V ``` ```go title="Example" app.Use(func(c fiber.Ctx) error { fiber.Locals[string](c, "john", "doe") fiber.Locals[int](c, "age", 18) fiber.Locals[bool](c, "isHuman", true) return c.Next() }) app.Get("/test", func(c fiber.Ctx) error { fiber.Locals[string](c, "john") // "doe" fiber.Locals[int](c, "age") // 18 fiber.Locals[bool](c, "isHuman") // true return nil }) ``` Make sure to understand and correctly implement the `Locals` method in both its standard and generic form for better control over route-specific data within your application. ### Matched Returns `true` if the current request path was matched by the router. ```go title="Signature" func (c fiber.Ctx) Matched() bool ``` ```go title="Example" app.Use(func(c fiber.Ctx) error { if c.Matched() { return c.Next() } return c.Status(fiber.StatusNotFound).SendString("Not Found") }) ``` ### Next When **Next** is called, it executes the next handler in the stack that matches the current route and returns whatever error that downstream chain produced. Returning a non-nil error from a handler ends the chain and passes the error to the [error handler](../guide/error-handling). ```go title="Signature" func (c fiber.Ctx) Next() error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { fmt.Println("1st route!") return c.Next() }) app.Get("*", func(c fiber.Ctx) error { fmt.Println("2nd route!") return c.Next() }) app.Get("/", func(c fiber.Ctx) error { fmt.Println("3rd route!") return c.SendString("Hello, World!") }) ``` ### OverrideParam Overwrites the value of an existing route parameter. :::note If the parameter does not exist, this method does nothing.
::: ```go title="Signature" func (c fiber.Ctx) OverrideParam(name, value string) ``` ```go title="Example" // GET http://example.com/user app.Get("/user/:name", func(c fiber.Ctx) error { // mutate parameter c.OverrideParam("name", "new value") return c.SendString(c.Params("name")) // sends "new value" }) // GET http://example.com/shop/tech/1 app.Get("/shop/*", func(c fiber.Ctx) error { // mutate parameter c.OverrideParam("*", "new tech") // replaces "tech/1" with "new tech" return c.SendString(c.Params("*")) // sends "new tech" }) ``` Unnamed route parameters can be accessed by their character (`*` or `+`) followed by their position index (e.g., `*1` for the first wildcard, `*2` for the second). ```go title="Example" // GET /v1/brand/4/shop/blue/xs app.Get("/v1/*/shop/*", func(c fiber.Ctx) error { // mutate parameter c.OverrideParam("*1", "updated brand") c.OverrideParam("*2", "updated data") param1 := c.Params("*1") // "updated brand" param2 := c.Params("*2") // "updated data" // ... }) ``` ### Redirect Returns the Redirect reference. For detailed information, check the [Redirect](./redirect.md) documentation. ```go title="Signature" func (c fiber.Ctx) Redirect() *Redirect ``` ```go title="Example" app.Get("/coffee", func(c fiber.Ctx) error { return c.Redirect().To("/teapot") }) app.Get("/teapot", func(c fiber.Ctx) error { return c.Status(fiber.StatusTeapot).Send("🍵 short and stout 🍵") }) ``` ### Request Returns the [*fasthttp.Request](https://pkg.go.dev/github.com/valyala/fasthttp#Request) pointer. ```go title="Signature" func (c fiber.Ctx) Request() *fasthttp.Request ``` :::info Returns `nil` if the context has been released (e.g., after the handler completes and the context is returned to the pool). 
::: ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Request().Header.Method() // => []byte("GET") return nil }) ``` ### RequestCtx Returns the [\*fasthttp.RequestCtx](https://pkg.go.dev/github.com/valyala/fasthttp#RequestCtx) pointer. It is compatible with the `context.Context` interface, which carries deadlines, cancellation signals, and other values across API boundaries. ```go title="Signature" func (c fiber.Ctx) RequestCtx() *fasthttp.RequestCtx ``` :::info Please read the [Fasthttp Documentation](https://pkg.go.dev/github.com/valyala/fasthttp?tab=doc) for more information. ::: ### Reset Resets the context fields using the given request when using server handlers. ```go title="Signature" func (c fiber.Ctx) Reset(fctx *fasthttp.RequestCtx) ``` It is used outside of Fiber handlers to reset the context for the next request. ### Response Returns the [\*fasthttp.Response](https://pkg.go.dev/github.com/valyala/fasthttp#Response) pointer. ```go title="Signature" func (c fiber.Ctx) Response() *fasthttp.Response ``` :::info Returns `nil` if the context has been released (e.g., after the handler completes and the context is returned to the pool). ::: ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Response().BodyWriter().Write([]byte("Hello, World!")) // => "Hello, World!" return nil }) ``` ### RestartRouting Instead of executing the next method when calling [Next](ctx.md#next), **RestartRouting** restarts execution from the first method that matches the current route. This may be helpful after overriding the path, e.g., for an internal redirect. Note that handlers might be executed again, which could result in an infinite loop.
```go title="Signature" func (c fiber.Ctx) RestartRouting() error ``` ```go title="Example" app.Get("/new", func(c fiber.Ctx) error { return c.SendString("From /new") }) app.Get("/old", func(c fiber.Ctx) error { c.Path("/new") return c.RestartRouting() }) ``` ### Route Returns the matched [Route](https://pkg.go.dev/github.com/gofiber/fiber/v3#Route) struct. ```go title="Signature" func (c fiber.Ctx) Route() *Route ``` ```go title="Example" // http://localhost:8080/hello app.Get("/hello/:name", func(c fiber.Ctx) error { r := c.Route() fmt.Println(r.Method, r.Path, r.Params, r.Handlers) // GET /hello/:name handler [name] // ... }) ``` :::caution Do not rely on `c.Route()` in middleware **before** calling `c.Next()`: `c.Route()` returns the **last executed route**. ::: ```go title="Example" func MyMiddleware() fiber.Handler { return func(c fiber.Ctx) error { beforeNext := c.Route().Path // Will be '/' err := c.Next() afterNext := c.Route().Path // Will be '/hello/:name' return err } } ``` ### SetContext Sets the base `context.Context` used by [`Context`](#context). Use this to propagate deadlines, cancellation signals, or values to asynchronous operations. ```go title="Signature" func (c fiber.Ctx) SetContext(ctx context.Context) ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.SetContext(context.WithValue(context.Background(), "user", "alice")) ctx := c.Context() go doWork(ctx) return nil }) ``` ### String Returns a unique string representation of the context. ```go title="Signature" func (c fiber.Ctx) String() string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.String() // "#0000000100000001 - 127.0.0.1:3000 <-> 127.0.0.1:61516 - GET http://localhost:3000/" // ... }) ``` ### ViewBind Adds variables to the default view variable map that is passed to the template engine. Variables are read by the `Render` method and may be overwritten.
```go title="Signature" func (c fiber.Ctx) ViewBind(vars Map) error ``` ```go title="Example" app.Use(func(c fiber.Ctx) error { c.ViewBind(fiber.Map{ "Title": "Hello, World!", }) return c.Next() }) app.Get("/", func(c fiber.Ctx) error { return c.Render("xxx.tmpl", fiber.Map{}) // Render will use the Title variable }) ``` ## Request Methods which operate on the incoming request. :::tip Use `c.Req()` to limit gopls suggestions to only these methods! ::: ### AcceptEncoding Returns the `Accept-Encoding` request header. ```go title="Signature" func (c fiber.Ctx) AcceptEncoding() string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.AcceptEncoding() // "gzip, br" return nil }) ``` ### AcceptLanguage Returns the `Accept-Language` request header. ```go title="Signature" func (c fiber.Ctx) AcceptLanguage() string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.AcceptLanguage() // "en-US,en;q=0.9" return nil }) ``` ### Accepts Checks if the specified **extensions** or **content** **types** are acceptable. :::info Based on the request’s [Accept](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept) HTTP header. ::: ```go title="Signature" func (c fiber.Ctx) Accepts(offers ...string) string func (c fiber.Ctx) AcceptsCharsets(offers ...string) string func (c fiber.Ctx) AcceptsEncodings(offers ...string) string func (c fiber.Ctx) AcceptsLanguages(offers ...string) string func (c fiber.Ctx) AcceptsLanguagesExtended(offers ...string) string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Accepts("html") // "html" c.Accepts("text/html") // "text/html" c.Accepts("json", "text") // "json" c.Accepts("application/json") // "application/json" c.Accepts("text/plain", "application/json") // "application/json", due to quality c.Accepts("image/png") // "" c.Accepts("png") // "" // ... 
}) ``` ```go title="Example 2" // Accept: text/html, text/*, application/json, */*; q=0 app.Get("/", func(c fiber.Ctx) error { c.Accepts("text/plain", "application/json") // "application/json", due to specificity c.Accepts("application/json", "text/html") // "text/html", due to first match c.Accepts("image/png") // "", due to */* with q=0 is Not Acceptable // ... }) ``` Media-Type parameters are supported. ```go title="Example 3" // Accept: text/plain, application/json; version=1; foo=bar app.Get("/", func(c fiber.Ctx) error { // Extra parameters in the accept are ignored c.Accepts("text/plain;format=flowed") // "text/plain;format=flowed" // An offer must contain all parameters present in the Accept type c.Accepts("application/json") // "" // Parameter order and capitalization do not matter. Quotes on values are stripped. c.Accepts(`application/json;foo="bar";VERSION=1`) // "application/json;foo="bar";VERSION=1" }) ``` ```go title="Example 4" // Accept: text/plain;format=flowed;q=0.9, text/plain // i.e., "I prefer text/plain;format=flowed less than other forms of text/plain" app.Get("/", func(c fiber.Ctx) error { // Beware: the order in which offers are listed matters. // Although the client specified they prefer not to receive format=flowed, // the text/plain Accept matches with "text/plain;format=flowed" first, so it is returned. c.Accepts("text/plain;format=flowed", "text/plain") // "text/plain;format=flowed" // Here, things behave as expected: c.Accepts("text/plain", "text/plain;format=flowed") // "text/plain" }) ``` Fiber provides similar functions for the other accept headers. For `Accept-Language`, Fiber uses the [Basic Filtering](https://www.rfc-editor.org/rfc/rfc4647#section-3.3.1) algorithm. A language range matches an offer only if it exactly equals the tag or is a prefix followed by a hyphen. For example, the range `en` matches `en-US`, but `en-US` does not match `en`. 
`AcceptsLanguagesExtended` applies [Extended Filtering](https://www.rfc-editor.org/rfc/rfc4647#section-3.3.2) where `*` may match zero or more subtags and wildcard matches can slide across subtags unless blocked by a singleton like `x`. ```go // Accept-Charset: utf-8, iso-8859-1;q=0.2 // Accept-Encoding: gzip, compress;q=0.2 // Accept-Language: en;q=0.8, nl, ru app.Get("/", func(c fiber.Ctx) error { c.AcceptsCharsets("utf-16", "iso-8859-1") // "iso-8859-1" c.AcceptsEncodings("compress", "br") // "compress" c.AcceptsLanguages("pt", "nl", "ru") // "nl" c.AcceptsLanguagesExtended("en-US", "fr-CA") // depends on extended ranges in the request header // ... }) ``` ### AcceptsEventStream Returns `true` when the `Accept` header allows `text/event-stream`. ```go title="Signature" func (c fiber.Ctx) AcceptsEventStream() bool ``` ```go title="Example" // Accept: text/html, application/json;q=0.9 app.Get("/", func(c fiber.Ctx) error { c.AcceptsEventStream() // false return nil }) ``` ### AcceptsHTML Returns `true` when the `Accept` header allows HTML. ```go title="Signature" func (c fiber.Ctx) AcceptsHTML() bool ``` ```go title="Example" // Accept: text/html, application/json;q=0.9 app.Get("/", func(c fiber.Ctx) error { c.AcceptsHTML() // true return nil }) ``` ### AcceptsJSON Returns `true` when the `Accept` header allows JSON. ```go title="Signature" func (c fiber.Ctx) AcceptsJSON() bool ``` ```go title="Example" // Accept: text/html, application/json;q=0.9 app.Get("/", func(c fiber.Ctx) error { c.AcceptsJSON() // true return nil }) ``` ### AcceptsXML Returns `true` when the `Accept` header allows XML. ```go title="Signature" func (c fiber.Ctx) AcceptsXML() bool ``` ```go title="Example" // Accept: text/html, application/json;q=0.9 app.Get("/", func(c fiber.Ctx) error { c.AcceptsXML() // false return nil }) ``` ### BaseURL Returns the base URL (**protocol** + **host**) as a `string`. 
```go title="Signature" func (c fiber.Ctx) BaseURL() string ``` ```go title="Example" // GET https://example.com/page#chapter-1 app.Get("/", func(c fiber.Ctx) error { c.BaseURL() // "https://example.com" // ... }) ``` ### Body Based on the `Content-Encoding` request header, this method attempts to decompress the **body** bytes. If no `Content-Encoding` header is sent (or it is set to `identity`), it behaves like [BodyRaw](#bodyraw). If an unknown or unsupported encoding is encountered, the response status will be `415 Unsupported Media Type` or `501 Not Implemented`. ```go title="Signature" func (c fiber.Ctx) Body() []byte ``` ```go title="Example" // echo 'user=john' | gzip | curl -v -i --data-binary @- -H "Content-Encoding: gzip" http://localhost:8080 app.Post("/", func(c fiber.Ctx) error { // Decompress body from POST request based on the Content-Encoding and return the raw content: return c.Send(c.Body()) // []byte("user=john") }) ``` :::info The returned value is valid only within the handler. Do not store references. Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation) ::: ### BodyRaw Returns the raw request **body**. ```go title="Signature" func (c fiber.Ctx) BodyRaw() []byte ``` ```go title="Example" // curl -X POST http://localhost:8080 -d user=john app.Post("/", func(c fiber.Ctx) error { // Get raw body from POST request: return c.Send(c.BodyRaw()) // []byte("user=john") }) ``` :::info The returned value is valid only within the handler. Do not store references. Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation) ::: ### Charset Returns the `charset` parameter from the `Content-Type` header.
```go title="Signature" func (c fiber.Ctx) Charset() string ``` ```go title="Example" // Content-Type: application/json; charset=utf-8 app.Post("/", func(c fiber.Ctx) error { c.Charset() // "utf-8" return nil }) ``` ### ClientHelloInfo `ClientHelloInfo` contains information from a ClientHello message to guide application logic in the `GetCertificate` and `GetConfigForClient` callbacks. Refer to the [ClientHelloInfo](https://golang.org/pkg/crypto/tls/#ClientHelloInfo) struct documentation for details on the returned struct. ```go title="Signature" func (c fiber.Ctx) ClientHelloInfo() *tls.ClientHelloInfo ``` ```go title="Example" // GET http://example.com/hello app.Get("/hello", func(c fiber.Ctx) error { chi := c.ClientHelloInfo() // ... }) ``` ### Cookies Gets a cookie value by key. You can pass an optional default value that will be returned if the cookie key does not exist. ```go title="Signature" func (c fiber.Ctx) Cookies(key string, defaultValue ...string) string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { // Get cookie by key: c.Cookies("name") // "john" c.Cookies("empty", "doe") // "doe" // ... }) ``` :::info The returned value is valid only within the handler. Do not store references. Use [`App.GetString`](./app.md#getstring) or [`App.GetBytes`](./app.md#getbytes) when immutability is enabled, or manually copy values (for example with [`utils.CopyString`](https://github.com/gofiber/utils) / `utils.CopyBytes`) when it's disabled. [Read more...](../#zero-allocation) ::: ### FormFile Files from a multipart form can be retrieved by name; the **first** file for the given key is returned.
```go title="Signature" func (c fiber.Ctx) FormFile(key string) (*multipart.FileHeader, error) ``` ```go title="Example" app.Post("/", func(c fiber.Ctx) error { // Get first file from form field "document": file, err := c.FormFile("document") if err != nil { return err } // Save file to root directory: return c.SaveFile(file, fmt.Sprintf("./%s", file.Filename)) }) ``` ### FormValue Form values can be retrieved by name; the **first** value for the given key is returned. ```go title="Signature" func (c fiber.Ctx) FormValue(key string, defaultValue ...string) string ``` ```go title="Example" app.Post("/", func(c fiber.Ctx) error { // Get first value from form field "name": c.FormValue("name") // => "john" or "" if not exist // .. }) ``` :::info The returned value is valid only within the handler. Do not store references. Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation) ::: ### Fresh When the response is still **fresh** in the client's cache, **true** is returned; otherwise, **false** is returned to indicate that the client cache is now stale and the full response should be sent. When a client sends the `Cache-Control: no-cache` request header to indicate an end-to-end reload request, `Fresh` returns **false** to make handling these requests transparent. Read more at [https://expressjs.com/en/4x/api.html\#req.fresh](https://expressjs.com/en/4x/api.html#req.fresh) ```go title="Signature" func (c fiber.Ctx) Fresh() bool ``` ### FullURL Returns the full request URL (protocol + host + original URL). ```go title="Signature" func (c fiber.Ctx) FullURL() string ``` ```go title="Example" // GET http://example.com/search?q=fiber app.Get("/", func(c fiber.Ctx) error { c.FullURL() // "http://example.com/search?q=fiber" return nil }) ``` ### Get Returns the value of the HTTP request header specified by the given key. :::tip The match is **case-insensitive**.
::: ```go title="Signature" func (c fiber.Ctx) Get(key string, defaultValue ...string) string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Get("Content-Type") // "text/plain" c.Get("CoNtEnT-TypE") // "text/plain" c.Get("something", "john") // "john" // .. }) ``` :::info The returned value is valid only within the handler. Do not store references. Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation) ::: ### HasHeader Reports whether the request includes a header with the given key. ```go title="Signature" func (c fiber.Ctx) HasHeader(key string) bool ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.HasHeader("X-Trace-Id") return nil }) ``` ### Host Returns the host derived from the [Host](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) HTTP header. In a network context, [`Host`](#host) refers to the combination of a hostname and potentially a port number used for connecting, while [`Hostname`](#hostname) refers specifically to the name assigned to a device on a network, excluding any port information. ```go title="Signature" func (c fiber.Ctx) Host() string ``` ```go title="Example" // GET http://google.com:8080/search app.Get("/", func(c fiber.Ctx) error { c.Host() // "google.com:8080" c.Hostname() // "google.com" // ... }) ``` :::info The returned value is valid only within the handler. Do not store references. Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation) ::: ### Hostname Returns the hostname derived from the [Host](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) HTTP header. ```go title="Signature" func (c fiber.Ctx) Hostname() string ``` ```go title="Example" // GET http://google.com/search app.Get("/", func(c fiber.Ctx) error { c.Hostname() // "google.com" // ... }) ``` :::info The returned value is valid only within the handler. Do not store references. 
Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation) ::: ### IP Returns the remote IP address of the request. ```go title="Signature" func (c fiber.Ctx) IP() string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.IP() // "127.0.0.1" // ... }) ``` :::info By default, `c.IP()` returns the remote IP address from the TCP connection. When your Fiber app is behind a reverse proxy (like Nginx, Traefik, or a load balancer), you need to configure **both** [`TrustProxy`](fiber.md#trustproxy) and [`ProxyHeader`](fiber.md#proxyheader) to read the client IP from proxy headers like `X-Forwarded-For`. **Important:** You must enable `TrustProxy` and configure trusted proxy IPs to prevent header spoofing. Simply setting `ProxyHeader` alone will not work. **Note:** When using a proxy header such as `X-Forwarded-For`, `c.IP()` returns the raw header value unless [`EnableIPValidation`](fiber.md#enableipvalidation) is enabled. For `X-Forwarded-For`, this raw value may be a comma-separated list of IPs; enable `EnableIPValidation` if you need `c.IP()` to return a single, validated client IP. 
::: #### Configuration for apps behind a reverse proxy ```go title="Example - Basic Configuration" app := fiber.New(fiber.Config{ // Enable proxy support TrustProxy: true, // Specify which header contains the real client IP ProxyHeader: fiber.HeaderXForwardedFor, // Configure which proxy IPs to trust TrustProxyConfig: fiber.TrustProxyConfig{ // Trust private IP ranges (for internal load balancers) Private: true, // Or specify exact proxy IPs/ranges // Proxies: []string{"10.10.0.58", "192.168.0.0/24"}, }, }) ``` ```go title="Example - Specific Proxy IPs" app := fiber.New(fiber.Config{ TrustProxy: true, ProxyHeader: fiber.HeaderXForwardedFor, TrustProxyConfig: fiber.TrustProxyConfig{ // Trust only specific proxy IP addresses Proxies: []string{"10.10.0.58", "192.168.1.0/24"}, }, }) ``` See [`TrustProxy`](fiber.md#trustproxy) and [`TrustProxyConfig`](fiber.md#trustproxyconfig) for more details on security considerations and configuration options. ### IPs Returns a slice of IP addresses specified in the [X-Forwarded-For](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) request header. ```go title="Signature" func (c fiber.Ctx) IPs() []string ``` ```go title="Example" // X-Forwarded-For: proxy1, 127.0.0.1, proxy3 app.Get("/", func(c fiber.Ctx) error { c.IPs() // ["proxy1", "127.0.0.1", "proxy3"] // ... }) ``` :::caution Improper use of the X-Forwarded-For header can be a security risk. For details, see the [Security and privacy concerns](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For#security_and_privacy_concerns) section. ::: ### Is Returns **true** if the incoming request’s [Content-Type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) HTTP header field matches the [MIME type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Basics_of_HTTP/MIME_types) specified by the `extension` parameter. :::info If the request has **no** body, it returns **false**.
::: ```go title="Signature" func (c fiber.Ctx) Is(extension string) bool ``` ```go title="Example" // Content-Type: text/html; charset=utf-8 app.Get("/", func(c fiber.Ctx) error { c.Is("html") // true c.Is(".html") // true c.Is("json") // false // ... }) ``` ### IsForm Reports whether the `Content-Type` header is form-encoded. ```go title="Signature" func (c fiber.Ctx) IsForm() bool ``` ```go title="Example" // Content-Type: application/x-www-form-urlencoded app.Post("/", func(c fiber.Ctx) error { c.IsForm() // true return nil }) ``` ### IsFromLocal Returns `true` if the request came from localhost. ```go title="Signature" func (c fiber.Ctx) IsFromLocal() bool ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { // If request came from localhost, return true; else return false c.IsFromLocal() // ... }) ``` ### IsJSON Reports whether the `Content-Type` header is JSON. ```go title="Signature" func (c fiber.Ctx) IsJSON() bool ``` ```go title="Example" // Content-Type: application/json; charset=utf-8 app.Post("/", func(c fiber.Ctx) error { c.IsJSON() // true return nil }) ``` ### IsMultipart Reports whether the `Content-Type` header is multipart form data. ```go title="Signature" func (c fiber.Ctx) IsMultipart() bool ``` ```go title="Example" // Content-Type: multipart/form-data; boundary=abc123 app.Post("/", func(c fiber.Ctx) error { c.IsMultipart() // true return nil }) ``` ### IsProxyTrusted Checks the trustworthiness of the remote IP. If [`TrustProxy`](fiber.md#trustproxy) is `false`, it returns `true`. `IsProxyTrusted` can check the remote IP by proxy ranges and IP map. ```go title="Signature" func (c fiber.Ctx) IsProxyTrusted() bool ``` ```go title="Example" app := fiber.New(fiber.Config{ // TrustProxy enables the trusted proxy check TrustProxy: true, // TrustProxyConfig allows for configuring trusted proxies. 
// Proxies is a list of trusted proxy IP ranges/addresses TrustProxyConfig: fiber.TrustProxyConfig{ Proxies: []string{"0.8.0.0", "1.1.1.1/30"}, // IP address or IP address range Loopback: true, // Trust loopback addresses (127.0.0.0/8, ::1/128) UnixSocket: true, // Trust Unix domain socket connections }, }) app.Get("/", func(c fiber.Ctx) error { // If request came from trusted proxy, return true; else return false c.IsProxyTrusted() // ... }) ``` ### MediaType Returns the MIME type from the `Content-Type` header without parameters. ```go title="Signature" func (c fiber.Ctx) MediaType() string ``` ```go title="Example" // Content-Type: application/json; charset=utf-8 app.Post("/", func(c fiber.Ctx) error { c.MediaType() // "application/json" return nil }) ``` ### Method Returns a string corresponding to the HTTP method of the request: `GET`, `POST`, `PUT`, and so on. Optionally, you can override the method by passing a string. ```go title="Signature" func (c fiber.Ctx) Method(override ...string) string ``` ```go title="Example" app.Post("/override", func(c fiber.Ctx) error { c.Method() // "POST" c.Method("GET") c.Method() // "GET" // ... }) ``` ### MultipartForm To access multipart form entries, you can parse the binary with `MultipartForm()`. This returns a `*multipart.Form`, allowing you to access form values and files. 
```go title="Signature" func (c fiber.Ctx) MultipartForm() (*multipart.Form, error) ``` ```go title="Example" app.Post("/", func(c fiber.Ctx) error { // Parse the multipart form: if form, err := c.MultipartForm(); err == nil { // => *multipart.Form if token := form.Value["token"]; len(token) > 0 { // Get key value: fmt.Println(token[0]) } // Get all files from "documents" key: files := form.File["documents"] // => []*multipart.FileHeader // Loop through files: for _, file := range files { fmt.Println(file.Filename, file.Size, file.Header["Content-Type"][0]) // => "tutorial.pdf" 360641 "application/pdf" // Save the files to disk: if err := c.SaveFile(file, fmt.Sprintf("./%s", file.Filename)); err != nil { return err } } } return nil }) ``` ### OriginalURL Returns the original request URL. ```go title="Signature" func (c fiber.Ctx) OriginalURL() string ``` ```go title="Example" // GET http://example.com/search?q=something app.Get("/", func(c fiber.Ctx) error { c.OriginalURL() // "/search?q=something" // ... }) ``` :::info The returned value is valid only within the handler. Do not store references. Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation) ::: ### Params This method can be used to get the route parameters. You can pass an optional default value that will be returned if the param key does not exist. :::info Defaults to an empty string \(`""`\) if the param **doesn't** exist. ::: ```go title="Signature" func (c fiber.Ctx) Params(key string, defaultValue ...string) string ``` ```go title="Example" // GET http://example.com/user/fenny app.Get("/user/:name", func(c fiber.Ctx) error { c.Params("name") // "fenny" // ... }) // GET http://example.com/user/fenny/123 app.Get("/user/*", func(c fiber.Ctx) error { c.Params("*") // "fenny/123" c.Params("*1") // "fenny/123" // ... }) ``` Unnamed route parameters \(\*, +\) can be fetched by the **character** and the **counter** in the route. 
```go title="Example" // ROUTE: /v1/*/shop/* // GET: /v1/brand/4/shop/blue/xs c.Params("*1") // "brand/4" c.Params("*2") // "blue/xs" ``` For reasons of **backward compatibility**, the first parameter segment for the parameter character can also be accessed without the counter. ```go title="Example" app.Get("/v1/*/shop/*", func(c fiber.Ctx) error { c.Params("*") // outputs the value of the first wildcard segment }) ``` :::info The returned value is valid only within the handler. Do not store references. Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. [Read more...](../#zero-allocation) ::: In certain scenarios, it can be useful to have an alternative approach to handle different types of parameters, not just strings. This can be achieved with the generic function `Params[V GenericType](c fiber.Ctx, key string, defaultValue ...V) V`, which parses a route parameter and returns a value of the type specified by `V GenericType`. ```go title="Signature" func Params[V GenericType](c fiber.Ctx, key string, defaultValue ...V) V ``` ```go title="Example" // GET http://example.com/user/114 app.Get("/user/:id", func(c fiber.Ctx) error{ fiber.Params[string](c, "id") // returns "114" as string. fiber.Params[int](c, "id") // returns 114 as integer fiber.Params[string](c, "number") // returns "" (default string type) fiber.Params[int](c, "number") // returns 0 (default integer value type) }) ``` The generic `Params` function supports returning the following data types based on `V GenericType`: - Integer: `int`, `int8`, `int16`, `int32`, `int64` - Unsigned integer: `uint`, `uint8`, `uint16`, `uint32`, `uint64` - Floating-point numbers: `float32`, `float64` - Boolean: `bool` - String: `string` - Byte array: `[]byte` ### Path Contains the path part of the request URL. Optionally, you can override the path by passing a string.
For internal redirects, you might want to call [RestartRouting](ctx.md#restartrouting) instead of [Next](ctx.md#next). ```go title="Signature" func (c fiber.Ctx) Path(override ...string) string ``` ```go title="Example" // GET http://example.com/users?sort=desc app.Get("/users", func(c fiber.Ctx) error { c.Path() // "/users" c.Path("/john") c.Path() // "/john" // ... }) ``` ### Port Returns the remote port of the request. ```go title="Signature" func (c fiber.Ctx) Port() string ``` ```go title="Example" // GET http://example.com:8080 app.Get("/", func(c fiber.Ctx) error { c.Port() // "8080" // ... }) ``` ### Protocol Contains the request protocol string: `http`, or `https` for **TLS** requests. ```go title="Signature" func (c fiber.Ctx) Protocol() string ``` ```go title="Example" // GET http://example.com app.Get("/", func(c fiber.Ctx) error { c.Protocol() // "http" // ... }) ``` ### Queries `Queries` returns a map containing an entry for each query string parameter in the request URL. ```go title="Signature" func (c fiber.Ctx) Queries() map[string]string ``` ```go title="Example" // GET http://example.com/?name=alex&want_pizza=false&id= app.Get("/", func(c fiber.Ctx) error { m := c.Queries() m["name"] // "alex" m["want_pizza"] // "false" m["id"] // "" // ...
}) ``` ```go title="Example" // GET http://example.com/?field1=value1&field1=value2&field2=value3 app.Get("/", func (c fiber.Ctx) error { m := c.Queries() m["field1"] // "value2" m["field2"] // "value3" }) ``` ```go title="Example" // GET http://example.com/?list_a=1&list_a=2&list_a=3&list_b[]=1&list_b[]=2&list_b[]=3&list_c=1,2,3 app.Get("/", func(c fiber.Ctx) error { m := c.Queries() m["list_a"] // "3" m["list_b[]"] // "3" m["list_c"] // "1,2,3" }) ``` ```go title="Example" // GET /api/posts?filters.author.name=John&filters.category.name=Technology app.Get("/", func(c fiber.Ctx) error { m := c.Queries() m["filters.author.name"] // John m["filters.category.name"] // Technology }) ``` ```go title="Example" // GET /api/posts?tags=apple,orange,banana&filters[tags]=apple,orange,banana&filters[category][name]=fruits&filters.tags=apple,orange,banana&filters.category.name=fruits app.Get("/", func(c fiber.Ctx) error { m := c.Queries() m["tags"] // apple,orange,banana m["filters[tags]"] // apple,orange,banana m["filters[category][name]"] // fruits m["filters.tags"] // apple,orange,banana m["filters.category.name"] // fruits }) ``` ### Query This method returns a string corresponding to a query string parameter by name. You can pass an optional default value that will be returned if the query key does not exist. :::info If there is **no** query string, it returns an **empty string**. ::: ```go title="Signature" func (c fiber.Ctx) Query(key string, defaultValue ...string) string ``` ```go title="Example" // GET http://example.com/?order=desc&brand=nike app.Get("/", func(c fiber.Ctx) error { c.Query("order") // "desc" c.Query("brand") // "nike" c.Query("empty", "nike") // "nike" // ... }) ``` :::info The returned value is valid only within the handler. Do not store references. Make copies or use the [**`Immutable`**](./fiber.md#immutable) setting instead. 
[Read more...](../#zero-allocation) ::: In certain scenarios, it can be useful to have an alternative approach to handle different types of query parameters, not just strings. This can be achieved using a generic `Query` function known as `Query[V GenericType](c fiber.Ctx, key string, defaultValue ...V) V`. This function is capable of parsing a query string and returning a value of a type that is assumed and specified by `V GenericType`. Here is the signature for the generic `Query` function: ```go title="Signature" func Query[V GenericType](c fiber.Ctx, key string, defaultValue ...V) V ``` ```go title="Example" // GET http://example.com/?page=1&brand=nike&new=true app.Get("/", func(c fiber.Ctx) error { fiber.Query[int](c, "page") // 1 fiber.Query[string](c, "brand") // "nike" fiber.Query[bool](c, "new") // true // ... }) ``` In this case, `Query[V GenericType](c Ctx, key string, defaultValue ...V) V` can retrieve `page` as an integer, `brand` as a string, and `new` as a boolean. The function uses the appropriate parsing function for each specified type to ensure the correct type is returned. This simplifies the retrieval process of different types of query parameters, making your controller actions cleaner. The generic `Query` function supports returning the following data types based on `V GenericType`: - Integer: `int`, `int8`, `int16`, `int32`, `int64` - Unsigned integer: `uint`, `uint8`, `uint16`, `uint32`, `uint64` - Floating-point numbers: `float32`, `float64` - Boolean: `bool` - String: `string` - Byte array: `[]byte` ### Range Returns a struct containing the type and a slice of ranges. Only the canonical `bytes` unit is recognized and any optional whitespace around range specifiers will be ignored, as specified in RFC 9110. If none of the requested ranges are satisfiable, the method automatically sets the HTTP status code to **416 Range Not Satisfiable** and populates the `Content-Range` header with the current representation size. 
```go title="Signature" func (c fiber.Ctx) Range(size int64) (Range, error) ``` ```go title="Example" // Range: bytes=500-700, 700-900 app.Get("/", func(c fiber.Ctx) error { r := c.Range(1000) if r.Type == "bytes" { for _, rng := range r.Ranges { fmt.Println(rng) // [500, 700] } } }) ``` ### Referer Returns the `Referer` request header. ```go title="Signature" func (c fiber.Ctx) Referer() string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Referer() // "https://example.com" return nil }) ``` ### RequestID ```go title="Signature" func (c fiber.Ctx) RequestID() string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.RequestID() // "8d7ad5e3-aaf3-450b-a241-2beb887efd54" return nil }) ``` ### SaveFile Method is used to save **any** multipart file to disk. ```go title="Signature" func (c fiber.Ctx) SaveFile(fh *multipart.FileHeader, path string) error ``` ```go title="Example" app.Post("/", func(c fiber.Ctx) error { // Parse the multipart form: if form, err := c.MultipartForm(); err == nil { // => *multipart.Form // Get all files from "documents" key: files := form.File["documents"] // => []*multipart.FileHeader // Loop through files: for _, file := range files { fmt.Println(file.Filename, file.Size, file.Header["Content-Type"][0]) // => "tutorial.pdf" 360641 "application/pdf" // Save the files to disk: if err := c.SaveFile(file, fmt.Sprintf("./%s", file.Filename)); err != nil { return err } } return err } }) ``` ### SaveFileToStorage Method is used to save **any** multipart file to an external storage system. 
```go title="Signature" func (c fiber.Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error ``` ```go title="Example" storage := memory.New() app.Post("/", func(c fiber.Ctx) error { // Parse the multipart form: if form, err := c.MultipartForm(); err == nil { // => *multipart.Form // Get all files from "documents" key: files := form.File["documents"] // => []*multipart.FileHeader // Loop through files: for _, file := range files { fmt.Println(file.Filename, file.Size, file.Header["Content-Type"][0]) // => "tutorial.pdf" 360641 "application/pdf" // Save the files to storage: if err := c.SaveFileToStorage(file, fmt.Sprintf("./%s", file.Filename), storage); err != nil { return err } } return err } }) ``` ### Schema Returns the request scheme string: `http`, or `https` for TLS requests. :::info Please use [`Config.TrustProxy`](fiber.md#trustproxy) to prevent header spoofing if your app is behind a proxy. ::: ```go title="Signature" func (c fiber.Ctx) Schema() string ``` ```go title="Example" // GET http://example.com app.Get("/", func(c fiber.Ctx) error { c.Schema() // "http" // ... }) ``` ### Secure A boolean property that is `true` if a **TLS** connection is established. ```go title="Signature" func (c fiber.Ctx) Secure() bool ``` ```go title="Example" // Secure() method is equivalent to: c.Protocol() == "https" ``` ### Stale When the client's cached response is **stale**, this method returns **true**. It is the logical complement of [`Fresh`](#fresh), which checks whether the cached representation is still valid. [https://expressjs.com/en/4x/api.html#req.stale](https://expressjs.com/en/4x/api.html#req.stale) ```go title="Signature" func (c fiber.Ctx) Stale() bool ``` ### Subdomains Returns a slice with the host’s sub-domain labels: the dot-separated parts that precede the registrable domain (e.g. `example`) and the top-level domain (e.g. `com`).
The `subdomain offset` (default `2`) tells Fiber how many labels, counting from the right-hand side, are always discarded. Passing an `offset` argument lets you override that value for a single call. ```go func (c fiber.Ctx) Subdomains(offset ...int) []string ``` | `offset` | Result | Meaning | | ---------------------- | --------------------------------------- | --------------------------------------------- | | *omitted* → **2** | trim 2 right-most labels | drop the registrable domain **and** the TLD | | `1` to `len(labels)-1` | trim exactly `offset` right-most labels | custom trimming of available labels | | `>= len(labels)` | **return `[]`** | offset exceeds available labels → empty slice | | `0` | **return every label** | keep the entire host unchanged | | `< 0` | **return `[]`** | negative offsets are invalid → empty slice | #### Example ```go // Host: "tobi.ferrets.example.com" app.Get("/", func(c fiber.Ctx) error { c.Subdomains() // ["tobi", "ferrets"] c.Subdomains(1) // ["tobi", "ferrets", "example"] c.Subdomains(0) // ["tobi", "ferrets", "example", "com"] c.Subdomains(-1) // [] // ... }) ``` ### UserAgent Returns the `User-Agent` request header. ```go title="Signature" func (c fiber.Ctx) UserAgent() string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.UserAgent() // "Mozilla/5.0 ..." return nil }) ``` ### XHR A boolean property that is `true` if the request’s [X-Requested-With](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers) header field is [XMLHttpRequest](https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest), indicating that the request was issued by a client library (such as [jQuery](https://api.jquery.com/jQuery.ajax/)). ```go title="Signature" func (c fiber.Ctx) XHR() bool ``` ```go title="Example" // X-Requested-With: XMLHttpRequest app.Get("/", func(c fiber.Ctx) error { c.XHR() // true // ... }) ``` ## Response Methods which modify the response object. 
:::tip Use `c.Res()` to limit gopls suggestions to only these methods! ::: ### Append Appends the specified **value** to the HTTP response header field. :::caution If the header is **not** already set, it creates the header with the specified value. ::: ```go title="Signature" func (c fiber.Ctx) Append(field string, values ...string) ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Append("Link", "http://google.com", "http://localhost") // => Link: http://google.com, http://localhost c.Append("Link", "Test") // => Link: http://google.com, http://localhost, Test // ... }) ``` ### Attachment Sets the HTTP response [Content-Disposition](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header field to `attachment`. ```go title="Signature" func (c fiber.Ctx) Attachment(filename ...string) ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Attachment() // => Content-Disposition: attachment c.Attachment("./upload/images/logo.png") // => Content-Disposition: attachment; filename="logo.png" // => Content-Type: image/png // ... }) ``` Non-ASCII filenames are encoded using the `filename*` parameter as defined in [RFC 6266](https://www.rfc-editor.org/rfc/rfc6266) and [RFC 8187](https://www.rfc-editor.org/rfc/rfc8187): ```go title="Example" app.Get("/non-ascii", func(c fiber.Ctx) error { c.Attachment("./files/文件.txt") // => Content-Disposition: attachment; filename="文件.txt"; filename*=UTF-8''%E6%96%87%E4%BB%B6.txt return nil }) ``` ### AutoFormat Performs content-negotiation on the [Accept](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept) HTTP header. It uses [Accepts](ctx.md#accepts) to select a proper format. The supported content types are `text/html`, `text/plain`, `application/json`, `application/vnd.msgpack`, `application/xml`, and `application/cbor`. For more flexible content negotiation, use [Format](ctx.md#format). 
:::info If the header is **not** specified or there is **no** proper format, **text/plain** is used. ::: ```go title="Signature" func (c fiber.Ctx) AutoFormat(body any) error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { // Accept: text/plain c.AutoFormat("Hello, World!") // => Hello, World! // Accept: text/html c.AutoFormat("Hello, World!") // => <p>Hello, World!</p> type User struct { Name string } user := User{"John Doe"} // Accept: application/json c.AutoFormat(user) // => {"Name":"John Doe"} // Accept: application/vnd.msgpack c.AutoFormat(user) // => 81 a4 4e 61 6d 65 a8 4a 6f 68 6e 20 44 6f 65 // Accept: application/cbor c.AutoFormat(user) // => a1 64 4e 61 6d 65 68 4a 6f 68 6e 20 44 6f 65 // Accept: application/xml c.AutoFormat(user) // => <User><Name>John Doe</Name></User> // .. }) ``` ### CBOR CBOR converts any interface or string to CBOR encoded bytes. > **Note:** Before using any CBOR-related features, make sure to follow the [CBOR setup instructions](../guide/advance-format.md#cbor). :::info CBOR also sets the content header to the `ctype` parameter. If no `ctype` is passed in, the header is set to `application/cbor`. ::: ```go title="Signature" func (c fiber.Ctx) CBOR(data any, ctype ...string) error ``` ```go title="Example" type SomeStruct struct { Name string `cbor:"name"` Age uint8 `cbor:"age"` } app.Get("/cbor", func(c fiber.Ctx) error { // Create data struct: data := SomeStruct{ Name: "Grame", Age: 20, } return c.CBOR(data) // => Content-Type: application/cbor // => \xa2dnameeGramecage\x14 return c.CBOR(fiber.Map{ "name": "Grame", "age": 20, }) // => Content-Type: application/cbor // => \xa2dnameeGramecage\x14 return c.CBOR(fiber.Map{ "type": "https://example.com/probs/out-of-credit", "title": "You do not have enough credit.", "status": 403, "detail": "Your current balance is 30, but that costs 50.", "instance": "/account/12345/msgs/abc", }) // => Content-Type: application/cbor // => \xa5dtypex'https://example.com/probs/out-of-creditetitlex\x1eYou do not have enough credit.fstatus\x19\x01\x93fdetailx.Your current balance is 30, but that costs 50.hinstancew/account/12345/msgs/abc }) ``` ### ClearCookie Expires a client cookie (or all cookies if left empty).
```go title="Signature" func (c fiber.Ctx) ClearCookie(key ...string) ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { // Clears all cookies: c.ClearCookie() // Expire specific cookie by name: c.ClearCookie("user") // Expire multiple cookies by names: c.ClearCookie("token", "session", "track_id", "version") // ... }) ``` :::caution Web browsers and other compliant clients will only clear the cookie if the given options are identical to those when creating the cookie, excluding `Expires` and `MaxAge`. `ClearCookie` will not set these values for you - a technique similar to the one shown below should be used to ensure your cookie is deleted. ::: ```go title="Example" app.Get("/set", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "token", Value: "randomvalue", Expires: time.Now().Add(24 * time.Hour), HTTPOnly: true, SameSite: "Lax", }) // ... }) app.Get("/delete", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "token", Expires: fasthttp.CookieExpireDelete, // Use fasthttp's built-in constant HTTPOnly: true, SameSite: "Lax", }) // ... }) ``` You can also use `c.Cookie()` to expire cookies with specific `Path` or `Domain` attributes: ```go title="Example" app.Get("/logout", func(c fiber.Ctx) error { // Expire a cookie with path and domain c.Cookie(&fiber.Cookie{ Name: "token", Path: "/api", Domain: "example.com", Expires: fasthttp.CookieExpireDelete, }) return c.SendStatus(fiber.StatusOK) }) ``` ### Cookie Sets a cookie. 
```go title="Signature" func (c fiber.Ctx) Cookie(cookie *Cookie) ``` ```go type Cookie struct { Name string `json:"name"` // The name of the cookie Value string `json:"value"` // The value of the cookie Path string `json:"path"` // Specifies a URL path which is allowed to receive the cookie Domain string `json:"domain"` // Specifies the domain which is allowed to receive the cookie MaxAge int `json:"max_age"` // The maximum age (in seconds) of the cookie Expires time.Time `json:"expires"` // The expiration date of the cookie Secure bool `json:"secure"` // Indicates that the cookie should only be transmitted over a secure HTTPS connection HTTPOnly bool `json:"http_only"` // Indicates that the cookie is accessible only through the HTTP protocol SameSite string `json:"same_site"` // Controls whether or not a cookie is sent with cross-site requests Partitioned bool `json:"partitioned"` // Indicates if the cookie is stored in a partitioned cookie jar SessionOnly bool `json:"session_only"` // Indicates if the cookie is a session-only cookie } ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { // Create cookie cookie := new(fiber.Cookie) cookie.Name = "john" cookie.Value = "doe" cookie.Expires = time.Now().Add(24 * time.Hour) // Set cookie c.Cookie(cookie) // ... }) ``` :::info When setting a cookie with `SameSite=None`, Fiber automatically sets `Secure=true` as required by RFC 6265bis and modern browsers. This ensures compliance with the "None" SameSite policy which mandates that cookies must be sent over secure connections. For more information, see: - [Mozilla Documentation](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#none) - [Chrome Documentation](https://developers.google.com/search/blog/2020/01/get-ready-for-new-samesitenone-secure) ::: :::info Partitioned cookies allow partitioning the cookie jar by top-level site, enhancing user privacy by preventing cookies from being shared across different sites. 
This feature is particularly useful in scenarios where a user interacts with embedded third-party services that should not have access to the main site's cookies. You can check out [CHIPS](https://developers.google.com/privacy-sandbox/3pcd/chips) for more information. ::: ```go title="Example" app.Get("/", func(c fiber.Ctx) error { // Create a new partitioned cookie cookie := new(fiber.Cookie) cookie.Name = "user_session" cookie.Value = "abc123" cookie.Partitioned = true // This cookie will be stored in a separate jar when it's embedded into another website // Set the cookie in the response c.Cookie(cookie) return c.SendString("Partitioned cookie set") }) ``` ### Download Transfers the file from the given path as an `attachment`. Typically, browsers will prompt the user to download it. By default, the [Content-Disposition](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header `filename=` parameter is the base name of the file path, i.e. its last element (this typically appears in the browser dialog). Override this default with the `filename` parameter. ```go title="Signature" func (c fiber.Ctx) Download(file string, filename ...string) error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { return c.Download("./files/report-12345.pdf") // => Download report-12345.pdf return c.Download("./files/report-12345.pdf", "report.pdf") // => Download report.pdf }) ``` For filenames containing non-ASCII characters, a `filename*` parameter is added according to [RFC 6266](https://www.rfc-editor.org/rfc/rfc6266) and [RFC 8187](https://www.rfc-editor.org/rfc/rfc8187): ```go title="Example" app.Get("/non-ascii", func(c fiber.Ctx) error { return c.Download("./files/文件.txt") // => Content-Disposition: attachment; filename="文件.txt"; filename*=UTF-8''%E6%96%87%E4%BB%B6.txt }) ``` ### End End immediately flushes the current response and closes the underlying connection.
```go title="Signature" func (c fiber.Ctx) End() error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.SendString("Hello World!") return c.End() }) ``` :::caution Calling `c.End()` will disallow further writes to the underlying connection. ::: :::warning `c.End()` does **not** work in streaming mode (e.g. when using `fasthttp`'s `HijackConn` or `SendStream`). In streaming mode the connection is managed asynchronously and `ctx.Conn()` may return `nil`, so `c.End()` will return `nil` without flushing or closing the connection. ::: End can be used to stop a middleware from modifying a response of a handler/other middleware down the method chain when they regain control after calling `c.Next()`. ```go title="Example" // Error Logging/Responding middleware app.Use(func(c fiber.Ctx) error { err := c.Next() // Log errors & write the error to the response if err != nil { log.Printf("Got error in middleware: %v", err) return c.Writef("(got error %v)", err) } // No errors occurred return nil }) // Handler with simulated error app.Get("/", func(c fiber.Ctx) error { // Closes the connection instantly after writing from this handler // and disallow further modification of its response defer c.End() c.SendString("Hello, ... I forgot what comes next!") return errors.New("some error") }) ``` ### Format Performs content-negotiation on the [Accept](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept) HTTP header. It uses [Accepts](ctx.md#accepts) to select a proper format from the supplied offers. A default handler can be provided by setting the `MediaType` to `"default"`. If no offers match and no default is provided, a 406 (Not Acceptable) response is sent. The Content-Type is automatically set when a handler is selected. :::info If the Accept header is **not** specified, the first handler will be used. 
::: ```go title="Signature" func (c fiber.Ctx) Format(handlers ...ResFmt) error ``` ```go title="Example" // Accept: application/json => {"command":"eat","subject":"fruit"} // Accept: text/plain => Eat Fruit! // Accept: application/xml => Not Acceptable app.Get("/no-default", func(c fiber.Ctx) error { return c.Format( fiber.ResFmt{"application/json", func(c fiber.Ctx) error { return c.JSON(fiber.Map{ "command": "eat", "subject": "fruit", }) }}, fiber.ResFmt{"text/plain", func(c fiber.Ctx) error { return c.SendString("Eat Fruit!") }}, ) }) // Accept: application/json => {"command":"eat","subject":"fruit"} // Accept: text/plain => Eat Fruit! // Accept: application/xml => Eat Fruit! app.Get("/default", func(c fiber.Ctx) error { textHandler := func(c fiber.Ctx) error { return c.SendString("Eat Fruit!") } handlers := []fiber.ResFmt{ {"application/json", func(c fiber.Ctx) error { return c.JSON(fiber.Map{ "command": "eat", "subject": "fruit", }) }}, {"text/plain", textHandler}, {"default", textHandler}, } return c.Format(handlers...) }) ``` ### JSON Converts any **interface** or **string** to JSON using the [encoding/json](https://pkg.go.dev/encoding/json) package. :::info JSON also sets the content header to the `ctype` parameter. If no `ctype` is passed in, the header is set to `application/json; charset=utf-8` by default. 
::: ```go title="Signature" func (c fiber.Ctx) JSON(data any, ctype ...string) error ``` ```go title="Example" type SomeStruct struct { Name string Age uint8 } app.Get("/json", func(c fiber.Ctx) error { // Create data struct: data := SomeStruct{ Name: "Grame", Age: 20, } return c.JSON(data) // => Content-Type: application/json; charset=utf-8 // => {"Name": "Grame", "Age": 20} return c.JSON(fiber.Map{ "name": "Grame", "age": 20, }) // => Content-Type: application/json; charset=utf-8 // => {"name": "Grame", "age": 20} return c.JSON(fiber.Map{ "type": "https://example.com/probs/out-of-credit", "title": "You do not have enough credit.", "status": 403, "detail": "Your current balance is 30, but that costs 50.", "instance": "/account/12345/msgs/abc", }, "application/problem+json") // => Content-Type: application/problem+json // => "{ // => "type": "https://example.com/probs/out-of-credit", // => "title": "You do not have enough credit.", // => "status": 403, // => "detail": "Your current balance is 30, but that costs 50.", // => "instance": "/account/12345/msgs/abc", // => }" }) ``` ### JSONP Sends a JSON response with JSONP support. This method is identical to [JSON](ctx.md#json), except that it opts-in to JSONP callback support. By default, the callback name is simply `callback`. Override this by passing a **named string** in the method. ```go title="Signature" func (c fiber.Ctx) JSONP(data any, callback ...string) error ``` ```go title="Example" type SomeStruct struct { Name string Age uint8 } app.Get("/", func(c fiber.Ctx) error { // Create data struct: data := SomeStruct{ Name: "Grame", Age: 20, } return c.JSONP(data) // => callback({"Name": "Grame", "Age": 20}) return c.JSONP(data, "customFunc") // => customFunc({"Name": "Grame", "Age": 20}) }) ``` ### Links Joins the links followed by the property to populate the response’s [Link HTTP header](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link) field. 
```go title="Signature" func (c fiber.Ctx) Links(link ...string) ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Links( "http://api.example.com/users?page=2", "next", "http://api.example.com/users?page=5", "last", ) // Link: <http://api.example.com/users?page=2>; rel="next", // <http://api.example.com/users?page=5>; rel="last" // ... }) ``` ### Location Sets the response [Location](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Location) HTTP header to the specified path parameter. ```go title="Signature" func (c fiber.Ctx) Location(path string) ``` ```go title="Example" app.Post("/", func(c fiber.Ctx) error { c.Location("http://example.com") c.Location("/foo/bar") return nil }) ``` ### MsgPack > **Note:** Before using any MsgPack-related features, make sure to follow the [MsgPack setup instructions](../guide/advance-format.md#msgpack). A compact binary alternative to [JSON](#json) for efficient data transfer between micro-services or from server to client. MessagePack typically serializes faster and yields smaller payloads than plain JSON. Converts any **interface** or **string** to MsgPack using the [shamaton/msgpack](https://pkg.go.dev/github.com/shamaton/msgpack/v3) package. :::info MsgPack also sets the content header to the `ctype` parameter. If no `ctype` is passed in, the header is set to `application/vnd.msgpack`.
::: ```go title="Signature" func (c fiber.Ctx) MsgPack(data any, ctype ...string) error ``` ```go title="Example" type SomeStruct struct { Name string Age uint8 } app.Get("/msgpack", func(c fiber.Ctx) error { // Create data struct: data := SomeStruct{ Name: "Grame", Age: 20, } return c.MsgPack(data) // => Content-Type: application/vnd.msgpack // => 82 A4 4E 61 6D 65 A5 47 72 61 6D 65 A3 41 67 65 14 return c.MsgPack(fiber.Map{ "name": "Grame", "age": 20, }) // => Content-Type: application/vnd.msgpack // => 82 A4 6E 61 6D 65 A5 47 72 61 6D 65 A3 61 67 65 14 return c.MsgPack(fiber.Map{ "type": "https://example.com/probs/out-of-credit", "title": "You do not have enough credit.", "status": 403, "detail": "Your current balance is 30, but that costs 50.", "instance": "/account/12345/msgs/abc", }, "application/problem+msgpack") }) // => Content-Type: application/problem+msgpack // 85 A4 74 79 70 65 D9 27 68 74 74 70 73 3A 2F 2F 65 78 61 6D 70 6C 65 2E 63 6F 6D 2F 70 72 6F 62 73 2F 6F 75 74 2D 6F 66 2D 63 72 65 64 69 74 A5 74 69 74 6C 65 BE 59 6F 75 20 64 6F 20 6E 6F 74 20 68 61 76 65 20 65 6E 6F 75 67 68 20 63 72 65 64 69 74 2E A6 73 74 61 74 75 73 CD 01 93 A6 64 65 74 61 69 6C D9 2E 59 6F 75 72 20 63 75 72 72 65 6E 74 20 62 61 6C 61 6E 63 65 20 69 73 20 33 30 2C 20 62 75 74 20 74 68 61 74 20 63 6F 73 74 73 20 35 30 2E A8 69 6E 73 74 61 6E 63 65 B7 2F 61 63 63 6F 75 6E 74 2F 31 32 33 34 35 2F 6D 73 67 73 2F 61 62 63 ``` ### Render Renders a view with data and sends a `text/html` response. By default, `Render` uses the default [**Go Template engine**](https://pkg.go.dev/html/template/). If you want to use another view engine, please take a look at our [**Template middleware**](https://docs.gofiber.io/template). ```go title="Signature" func (c fiber.Ctx) Render(name string, bind any, layouts ...string) error ``` ### Send Sets the HTTP response body. 
```go title="Signature" func (c fiber.Ctx) Send(body []byte) error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { return c.Send([]byte("Hello, World!")) // => "Hello, World!" }) ``` Fiber also provides `SendString` and `SendStream` methods for raw inputs. :::tip Use these if you **don't need** type assertion; they are recommended for **faster** performance. ::: ```go title="Signature" func (c fiber.Ctx) SendString(body string) error func (c fiber.Ctx) SendStream(stream io.Reader, size ...int) error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") // => "Hello, World!" return c.SendStream(bytes.NewReader([]byte("Hello, World!"))) // => "Hello, World!" }) ``` ### SendEarlyHints Sends an informational `103 Early Hints` response with one or more [`Link` headers](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Link) before the final response. This allows the browser to start preloading resources while the server prepares the full response. :::caution Early Hints (`103` responses) are reliably supported only over HTTP/2 and newer. Legacy HTTP/1.1 clients may ignore these interim responses or misbehave when receiving them. See [Enabling HTTP/2](../guide/reverse-proxy#enabling-http2) for instructions on how to use a reverse proxy (e.g. Nginx or Traefik) to enable HTTP/2 support. ::: ```go title="Signature" func (c fiber.Ctx) SendEarlyHints(hints []string) error ``` ```go title="Example" hints := []string{"<https://cdn.com/app.js>; rel=preload; as=script"} app.Get("/early", func(c fiber.Ctx) error { if err := c.SendEarlyHints(hints); err != nil { return err } return c.SendString("done") }) ``` ### SendFile Transfers the file from the given path. Sets the [Content-Type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) response HTTP header field based on the **file** extension or format.
```go title="Config" // SendFile defines configuration options when to transfer file with SendFile. type SendFile struct { // FS is the file system to serve the static files from. // You can use interfaces compatible with fs.FS like embed.FS, os.DirFS etc. // // Optional. Default: nil FS fs.FS // When set to true, the server tries minimizing CPU usage by caching compressed files. // This works differently than the github.com/gofiber/compression middleware. // You have to set Content-Encoding header to compress the file. // Available compression methods are gzip, br, and zstd. // // Optional. Default: false Compress bool `json:"compress"` // When set to true, enables byte range requests. // // Optional. Default: false ByteRange bool `json:"byte_range"` // When set to true, enables direct download. // // Optional. Default: false Download bool `json:"download"` // Expiration duration for inactive file handlers. // Use a negative time.Duration to disable it. // // Optional. Default: 10 * time.Second CacheDuration time.Duration `json:"cache_duration"` // The value for the Cache-Control HTTP-header // that is set on the file response. MaxAge is defined in seconds. // // Optional. Default: 0 MaxAge int `json:"max_age"` } ``` ```go title="Signature" func (c fiber.Ctx) SendFile(file string, config ...SendFile) error ``` ```go title="Example" app.Get("/not-found", func(c fiber.Ctx) error { return c.SendFile("./public/404.html") // Disable compression return c.SendFile("./static/index.html", fiber.SendFile{ Compress: false, }) }) ``` :::info If the file path contains URL-specific characters (such as `#`), you have to escape them before passing the path to the `SendFile` function. ::: ```go title="Example" app.Get("/file-with-url-chars", func(c fiber.Ctx) error { return c.SendFile(url.PathEscape("hash_sign_#.txt")) }) ``` :::info You can set the `CacheDuration` config property to `-1` to disable caching.
::: ```go title="Example" app.Get("/file", func(c fiber.Ctx) error { return c.SendFile("style.css", fiber.SendFile{ CacheDuration: -1, }) }) ``` :::info You can use multiple `SendFile` calls with different configurations in a single route. Fiber creates different filesystem handlers per config. ::: ```go title="Example" app.Get("/file", func(c fiber.Ctx) error { switch c.Query("config") { case "filesystem": return c.SendFile("style.css", fiber.SendFile{ FS: os.DirFS(".") }) case "filesystem-compress": return c.SendFile("style.css", fiber.SendFile{ FS: os.DirFS("."), Compress: true, }) case "compress": return c.SendFile("style.css", fiber.SendFile{ Compress: true, }) default: return c.SendFile("style.css") } return nil }) ``` :::info For sending multiple files from an embedded file system, [this functionality](../middleware/static.md#serving-files-using-embedfs) can be used. ::: ### SendStatus Sets the status code and the correct status message in the body if the response body is **empty**. :::tip You can find all used status codes and messages [in the Fiber source code](https://github.com/gofiber/fiber/blob/dffab20bcdf4f3597d2c74633a7705a517d2c8c2/utils.go#L183-L244). ::: ```go title="Signature" func (c fiber.Ctx) SendStatus(status int) error ``` ```go title="Example" app.Get("/not-found", func(c fiber.Ctx) error { return c.SendStatus(415) // => 415 "Unsupported Media Type" c.SendString("Hello, World!") return c.SendStatus(415) // => 415 "Hello, World!" }) ``` ### SendStream Sets the response body to a stream of data and adds an optional body size. ```go title="Signature" func (c fiber.Ctx) SendStream(stream io.Reader, size ...int) error ``` :::info `SendStream` operates asynchronously. The handler returns immediately after setting up the stream, but the actual reading and sending of data happens **after** the handler completes. This is handled by the underlying `fasthttp` library. 
If the provided stream implements `io.Closer`, it will be automatically closed by `fasthttp` after the response is fully sent or if an error occurs. ::: :::caution When passing `fiber.Ctx` as a `context.Context` to libraries that spawn goroutines (e.g., for streaming operations), those goroutines may attempt to access the context after the handler returns. Since `fiber.Ctx` is recycled and released after the handler completes, this can cause issues. **Recommended approach**: Use `c.Context()` or `c.RequestCtx()` instead of passing `c` directly to such libraries. See the [Context Guide](../guide/context.md) for more details. ::: ```go title="Example" app.Get("/", func(c fiber.Ctx) error { return c.SendStream(bytes.NewReader([]byte("Hello, World!"))) // => "Hello, World!" }) ``` ```go title="Example with file streaming" app.Get("/download", func(c fiber.Ctx) error { file, err := os.Open("large-file.zip") if err != nil { return err } // File will be automatically closed by fasthttp after streaming completes stat, err := file.Stat() if err != nil { file.Close() return err } return c.SendStream(file, int(stat.Size())) }) ``` ### SendStreamWriter Sets the response body stream writer. :::note The argument `streamWriter` represents a function that populates the response body using a buffered stream writer. ::: ```go title="Signature" func (c Ctx) SendStreamWriter(streamWriter func(*bufio.Writer)) error ``` ```go title="Example" app.Get("/", func (c fiber.Ctx) error { return c.SendStreamWriter(func(w *bufio.Writer) { fmt.Fprintf(w, "Hello, World!\n") }) // => "Hello, World!" }) ``` :::info To send data before `streamWriter` returns, you can call `w.Flush()` on the provided writer. Otherwise, the buffered stream flushes after `streamWriter` returns. ::: :::note `w.Flush()` will return an error if the client disconnects before `streamWriter` finishes writing a response. 
::: ```go title="Example" app.Get("/wait", func(c fiber.Ctx) error { return c.SendStreamWriter(func(w *bufio.Writer) { // Begin Work fmt.Fprintf(w, "Please wait for 10 seconds\n") if err := w.Flush(); err != nil { log.Print("Client disconnected!") return } // Send progress over time time.Sleep(time.Second) for i := 0; i < 9; i++ { fmt.Fprintf(w, "Still waiting...\n") if err := w.Flush(); err != nil { // If client disconnected, cancel work and finish log.Print("Client disconnected!") return } time.Sleep(time.Second) } // Finish fmt.Fprintf(w, "Done!\n") }) }) ``` ### SendString Sets the response body to a string. ```go title="Signature" func (c fiber.Ctx) SendString(body string) error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") // => "Hello, World!" }) ``` ### Set Sets the response’s HTTP header field to the specified `key`, `value`. ```go title="Signature" func (c fiber.Ctx) Set(key string, val string) ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Set("Content-Type", "text/plain") // => "Content-Type: text/plain" // ... }) ``` ### Status Sets the HTTP status for the response. :::info This method is **chainable**. ::: ```go title="Signature" func (c fiber.Ctx) Status(status int) fiber.Ctx ``` ```go title="Example" app.Get("/fiber", func(c fiber.Ctx) error { c.Status(fiber.StatusOK) return nil }) app.Get("/hello", func(c fiber.Ctx) error { return c.Status(fiber.StatusBadRequest).SendString("Bad Request") }) app.Get("/world", func(c fiber.Ctx) error { return c.Status(fiber.StatusNotFound).SendFile("./public/gopher.png") }) ``` ### Type Sets the [Content-Type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) HTTP header to the MIME type listed [in the Nginx MIME types configuration](https://github.com/nginx/nginx/blob/master/conf/mime.types) specified by the file **extension**. :::info This method is **chainable**. 
::: ```go title="Signature" func (c fiber.Ctx) Type(ext string, charset ...string) fiber.Ctx ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Type(".html") // => "text/html" c.Type("html") // => "text/html" c.Type("png") // => "image/png" c.Type("json", "utf-8") // => "application/json; charset=utf-8" // ... }) ``` ### Vary Adds the given header field to the [Vary](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Vary) response header. The field is appended if it is not already listed; otherwise, it stays at its current position. :::info Multiple fields are **allowed**. ::: ```go title="Signature" func (c fiber.Ctx) Vary(fields ...string) ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Vary("Origin") // => Vary: Origin c.Vary("User-Agent") // => Vary: Origin, User-Agent // No duplicates c.Vary("Origin") // => Vary: Origin, User-Agent c.Vary("Accept-Encoding", "Accept") // => Vary: Origin, User-Agent, Accept-Encoding, Accept // ... }) ``` ### Write Implements the `io.Writer` interface. ```go title="Signature" func (c fiber.Ctx) Write(p []byte) (n int, err error) ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { c.Write([]byte("Hello, World!")) // => "Hello, World!" fmt.Fprintf(c, "%s\n", "Hello, World!") // => "Hello, World!" return nil }) ``` ### Writef Writes a formatted string using a format specifier. ```go title="Signature" func (c fiber.Ctx) Writef(format string, a ...any) (n int, err error) ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { world := "World!" c.Writef("Hello, %s", world) // => "Hello, World!" fmt.Fprintf(c, "%s\n", "Hello, World!") // => "Hello, World!" return nil }) ``` ### WriteString Writes a string to the response body. ```go title="Signature" func (c fiber.Ctx) WriteString(s string) (n int, err error) ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { _, err := c.WriteString("Hello, World!") // => "Hello, World!" return err
}) ``` ### XML Converts any **interface** or **string** to XML using the standard `encoding/xml` package. :::info XML also sets the `Content-Type` header to `application/xml; charset=utf-8`. ::: ```go title="Signature" func (c fiber.Ctx) XML(data any) error ``` ```go title="Example" type SomeStruct struct { XMLName xml.Name `xml:"Fiber"` Name string `xml:"Name"` Age uint8 `xml:"Age"` } app.Get("/", func(c fiber.Ctx) error { // Create data struct: data := SomeStruct{ Name: "Grame", Age: 20, } return c.XML(data) // => <Fiber><Name>Grame</Name><Age>20</Age></Fiber> }) ``` ================================================ FILE: docs/api/fiber.md ================================================ --- id: fiber title: 📦 Fiber description: Fiber represents the fiber package where you start to create an instance. sidebar_position: 1 --- import Reference from '@site/src/components/reference'; ## Server start ### New This method creates a new **App** instance. You can pass optional [config](#config) when creating a new instance. ```go title="Signature" func New(config ...Config) *App ``` ```go title="Example" // Default config app := fiber.New() // ... ``` ### Config You can pass an optional Config when creating a new Fiber instance. ```go title="Example" // Custom config app := fiber.New(fiber.Config{ CaseSensitive: true, StrictRouting: true, ServerHeader: "Fiber", AppName: "Test App v1.0.1", }) // ...
``` #### Config fields | Property | Type | Description | Default | |---------------------------------------------------------------------------------------|-----------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------| | AppName | `string` | Sets the application name used in logs and the Server header | `""` | | BodyLimit | `int` | Sets the maximum allowed size for a request body. Zero or negative values fall back to the default limit. If the size exceeds the configured limit, it sends `413 - Request Entity Too Large` response. This limit also applies when running Fiber through the adaptor middleware from `net/http`. | `4 * 1024 * 1024` | | CaseSensitive | `bool` | When enabled, `/Foo` and `/foo` are different routes. When disabled, `/Foo` and `/foo` are treated the same. | `false` | | CBORDecoder | `utils.CBORUnmarshal` | Allowing for flexibility in using another cbor library for decoding. | `binder.UnimplementedCborUnmarshal` | | CBOREncoder | `utils.CBORMarshal` | Allowing for flexibility in using another cbor library for encoding. 
| `binder.UnimplementedCborMarshal` | | ColorScheme | [`Colors`](https://github.com/gofiber/fiber/blob/main/color.go) | You can define custom color scheme. They'll be used for startup message, route list and some middlewares. | [`DefaultColors`](https://github.com/gofiber/fiber/blob/main/color.go) | | CompressedFileSuffixes | `map[string]string` | Adds a suffix to the original file name and tries saving the resulting compressed file under the new file name. | `{"gzip": ".fiber.gz", "br": ".fiber.br", "zstd": ".fiber.zst"}` | | Concurrency | `int` | Maximum number of concurrent connections. | `256 * 1024` | | DisableDefaultContentType | `bool` | When true, omits the default Content-Type header from the response. | `false` | | DisableDefaultDate | `bool` | When true, omits the Date header from the response. | `false` | | DisableHeadAutoRegister | `bool` | Prevents Fiber from automatically registering `HEAD` routes for each `GET` route so you can supply custom `HEAD` handlers; manual `HEAD` routes still override the generated ones. | `false` | | DisableHeaderNormalizing | `bool` | By default all header names are normalized: conteNT-tYPE -> Content-Type | `false` | | DisableKeepalive | `bool` | Disables keep-alive connections so the server closes each connection after the first response. | `false` | | DisablePreParseMultipartForm | `bool` | Will not pre parse Multipart Form data if set to true. This option is useful for servers that desire to treat multipart form data as a binary blob, or choose when to parse the data. | `false` | | EnableIPValidation | `bool` | If set to true, `c.IP()` and `c.IPs()` will validate IP addresses before returning them. Also, `c.IP()` will return only the first valid IP rather than just the raw header value that may be a comma separated string.

**WARNING:** There is a small performance cost to doing this validation. Keep disabled if speed is your only concern and your application is behind a trusted proxy that already validates this header. | `false` | | EnableSplittingOnParsers | `bool` | Splits query, body, and header parameters on commas when enabled.

For example, `/api?foo=bar,baz` becomes `foo[]=bar&foo[]=baz`. | `false` | | ErrorHandler | `ErrorHandler` | ErrorHandler is executed when an error is returned from a `fiber.Handler`. Error handlers of mounted sub-apps are retained by the top-level app and applied to requests that match their mount prefix. | `DefaultErrorHandler` | | GETOnly | `bool` | Rejects all non-GET requests if set to true. This option is useful as anti-DoS protection for servers accepting only GET requests. The request size is limited by ReadBufferSize if GETOnly is set. | `false` | | IdleTimeout | `time.Duration` | The maximum amount of time to wait for the next request when keep-alive is enabled. If IdleTimeout is zero, the value of ReadTimeout is used. | `0` | | Immutable | `bool` | When enabled, all values returned by context methods are immutable. By default, they are valid until you return from the handler; see issue [\#185](https://github.com/gofiber/fiber/issues/185). | `false` | | JSONDecoder | `utils.JSONUnmarshal` | Allowing for flexibility in using another json library for decoding. | `json.Unmarshal` | | JSONEncoder | `utils.JSONMarshal` | Allowing for flexibility in using another json library for encoding. | `json.Marshal` | | MaxRanges | `int` | Sets the maximum number of ranges parsed from a `Range` header. Zero or negative values fall back to the default limit. If the limit is exceeded, the request is rejected with `416 - Requested Range Not Satisfiable` and `Content-Range: bytes */<size>`, where `<size>` is the complete length of the resource. | `16` | | MsgPackDecoder | `utils.MsgPackUnmarshal` | Allowing for flexibility in using another msgpack library for decoding. | `binder.UnimplementedMsgpackUnmarshal` | | MsgPackEncoder | `utils.MsgPackMarshal` | Allowing for flexibility in using another msgpack library for encoding. | `binder.UnimplementedMsgpackMarshal` | | PassLocalsToContext | `bool` | Controls whether `StoreInContext` also propagates values into the request `context.Context` for Fiber-backed contexts.
`StoreInContext` always writes to `c.Locals()`. `ValueFromContext` for Fiber-backed contexts always reads from `c.Locals()`. | `false` | | PassLocalsToViews | `bool` | Enables passing the locals set on a `fiber.Ctx` to the template engine. See our **Template Middleware** for supported engines. | `false` | | ProxyHeader | `string` | Specifies the header name to read the client's real IP address from when behind a reverse proxy. Common values: `fiber.HeaderXForwardedFor`, `"X-Real-IP"`, `"CF-Connecting-IP"` (Cloudflare).

**Important:** This setting **requires** `TrustProxy` to be enabled; `TrustProxyConfig` controls which proxy IPs are trusted for reading this header. Without `TrustProxy`, this setting has no effect and `c.IP()` will always return the remote IP from the TCP connection.

**Behavior note:** `X-Forwarded-For` often contains a comma-separated chain of IP addresses. With the default `EnableIPValidation = false`, `c.IP()` will return the raw header value (the whole chain) rather than a single parsed client IP. With `EnableIPValidation = true`, `c.IP()` parses the header and returns the **first syntactically valid IP address** it finds; it does **not** walk the chain to find the first non-proxy hop. For a reliable client IP, configure your reverse proxy to overwrite or sanitize this header and/or to provide a single-IP header such as `"X-Real-IP"` or a provider-specific header like `"CF-Connecting-IP"`.

**Security Warning:** Headers can be easily spoofed. Always configure `TrustProxyConfig` to validate the proxy IP address; otherwise, malicious clients can forge headers to bypass IP-based access controls. | `""` | | ReadBufferSize | `int` | Per-connection buffer size for reading requests. This also limits the maximum header size. Increase this buffer if your clients send multi-KB RequestURIs and/or multi-KB headers \(for example, large cookies\). | `4096` | | ReadTimeout | `time.Duration` | The amount of time allowed to read the full request, including the body. The default timeout is unlimited. | `0` | | ReduceMemoryUsage | `bool` | Aggressively reduces memory usage at the cost of higher CPU usage if set to true. | `false` | | RequestMethods | `[]string` | Customizes the set of accepted HTTP methods. You can add or remove methods as needed. | `DefaultMethods` | | ServerHeader | `string` | Enables the `Server` HTTP header with the given value. | `""` | | StreamRequestBody | `bool` | Enables request body streaming and calls the handler sooner when the given body is larger than the current limit. | `false` | | StrictRouting | `bool` | When enabled, the router treats `/foo` and `/foo/` as different. Otherwise, the router treats `/foo` and `/foo/` as the same. | `false` | | StructValidator | `StructValidator` | Validator used to automatically validate header, form, query, and other bound data during binding. Fiber has no default validator, so the validation step is skipped when none is configured. | `nil` | | TrustProxy | `bool` | Enables trust of reverse proxy headers. When enabled, Fiber will check if the request is coming from a trusted proxy (configured in `TrustProxyConfig`) before reading values from proxy headers.

**Required for**: Using `ProxyHeader` to read client IP from headers like `X-Forwarded-For`.

**Behavior when enabled:** If the remote IP is trusted (matches `TrustProxyConfig`), then `c.IP()` reads from `ProxyHeader` (when configured; otherwise it uses `RemoteIP()`), `c.Scheme()` first checks standard proxy scheme headers (`X-Forwarded-Proto`, `X-Forwarded-Protocol`, `X-Forwarded-Ssl`, `X-Url-Scheme`) and falls back to the actual connection scheme if none are set, and `c.Hostname()` prefers `X-Forwarded-Host` but falls back to the request Host header when the proxy header is not present. If the remote IP is NOT trusted, these methods ignore proxy headers and use the actual connection values instead.

**Security:** This prevents header spoofing by validating the proxy's IP address. Always configure `TrustProxyConfig` when enabling this option and set `ProxyHeader` if you want `c.IP()` to use a specific header. | `false` | | TrustProxyConfig | `TrustProxyConfig` | Configures which proxy IP addresses or ranges to trust. Only effective when `TrustProxy` is enabled.

**Fields:**
• `Proxies` - List of trusted proxy IPs or CIDR ranges (e.g., `[]string{"10.10.0.58", "192.168.0.0/24"}`)
• `Loopback` - Trust loopback addresses (127.0.0.0/8, ::1/128)
• `Private` - Trust all private IP ranges (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7)
• `LinkLocal` - Trust link-local addresses (169.254.0.0/16, fe80::/10)
• `UnixSocket` - Trust Unix domain socket connections

**Example:** For an app behind Nginx at 10.10.0.58, use `TrustProxyConfig{Proxies: []string{"10.10.0.58"}}` or `TrustProxyConfig{Private: true}` if using private network IPs. | `{}` | | UnescapePath | `bool` | Decodes all percent-encoded characters in the route path before setting the path on the context, so that routing also works with URL-encoded special characters. | `false` | | Views | `Views` | Views is the interface that wraps the Render function. See our **Template Middleware** for supported engines. | `nil` | | ViewsLayout | `string` | ViewsLayout is the global layout used for all template renders unless it is overridden in the Render call. See our **Template Middleware** for supported engines. | `""` | | WriteBufferSize | `int` | Per-connection buffer size for writing responses. | `4096` | | WriteTimeout | `time.Duration` | The maximum duration before timing out writes of the response. The default timeout is unlimited. | `0` | | XMLDecoder | `utils.XMLUnmarshal` | Allowing for flexibility in using another XML library for decoding. | `xml.Unmarshal` | | XMLEncoder | `utils.XMLMarshal` | Allowing for flexibility in using another XML library for encoding. | `xml.Marshal` | ## Server listening ### Config You can pass an optional ListenConfig when calling the [`Listen`](#listen) or [`Listener`](#listener) method.
```go title="Example" // Custom config app.Listen(":8080", fiber.ListenConfig{ EnablePrefork: true, DisableStartupMessage: true, }) ``` #### Config fields | Property | Type | Description | Default | |-------------------------------------------------------------------------|-------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------| | BeforeServeFunc | `func(app *App) error` | Allows customizing and accessing fiber app before serving the app. | `nil` | | CertClientFile | `string` | Path of the client certificate. If you want to use mTLS, you must enter this field. | `""` | | CertFile | `string` | Path of the certificate file. If you want to use TLS, you must enter this field. | `""` | | CertKeyFile | `string` | Path of the certificate's private key. If you want to use TLS, you must enter this field. | `""` | | DisableStartupMessage | `bool` | When set to true, it will not print out the «Fiber» ASCII art and listening address. | `false` | | EnablePrefork | `bool` | When set to true, this will spawn multiple Go processes listening on the same port. | `false` | | EnablePrintRoutes | `bool` | If set to true, will print all routes with their method, path, and handler. | `false` | | GracefulContext | `context.Context` | Field to shutdown Fiber by given context gracefully. | `nil` | | ShutdownTimeout | `time.Duration` | Specifies the maximum duration to wait for the server to gracefully shutdown. When the timeout is reached, the graceful shutdown process is interrupted and forcibly terminated, and the `context.DeadlineExceeded` error is passed to the `OnPostShutdown` callback. Set to 0 to disable the timeout and wait indefinitely. 
| `10 * time.Second` | | ListenerAddrFunc | `func(addr net.Addr)` | Receives the network address (`net.Addr`) of the listener, letting you access the address the server is listening on. | `nil` | | ListenerNetwork | `string` | Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), "unix" (Unix Domain Sockets). WARNING: When prefork is set to true, only "tcp4" and "tcp6" can be chosen. | `tcp4` | | UnixSocketFileMode | `os.FileMode` | FileMode to set for Unix Domain Socket (ListenerNetwork must be "unix") | `0770` | | TLSConfigFunc | `func(tlsConfig *tls.Config)` | Allows customizing `tls.Config` as you want. Ignored when `TLSConfig` is set. | `nil` | | TLSConfig | `*tls.Config` | Recommended base TLS configuration (cloned). Use for external certificate providers via `GetCertificate`. When set, other TLS fields are ignored. | `nil` | | AutoCertManager | `*autocert.Manager` | Manages TLS certificates automatically using the ACME protocol. Enables integration with Let's Encrypt or other ACME-based providers. | `nil` | | TLSMinVersion | `uint16` | Allows customizing the TLS minimum version. | `tls.VersionTLS12` | ### Listen Listen serves HTTP requests from the given address. ```go title="Signature" func (app *App) Listen(addr string, config ...ListenConfig) error ``` ```go title="Basic Listen usage" // Listen on port :8080 app.Listen(":8080") // Listen on port :8080 with Prefork app.Listen(":8080", fiber.ListenConfig{EnablePrefork: true}) // Custom host app.Listen("127.0.0.1:8080") ``` #### Prefork Prefork is a feature that allows you to spawn multiple Go processes listening on the same port. This can be useful for scaling across multiple CPU cores. ```go title="Prefork listener" app.Listen(":8080", fiber.ListenConfig{EnablePrefork: true}) ``` This distributes the incoming connections between the spawned processes and allows more requests to be handled simultaneously. #### TLS Prefer `TLSConfig` for TLS configuration so you can fully control certificates and settings.
When `TLSConfig` is set, Fiber ignores `CertFile`, `CertKeyFile`, `CertClientFile`, `TLSMinVersion`, `AutoCertManager`, and `TLSConfigFunc`. TLS serves HTTPS requests from the given address using the `CertFile` and `CertKeyFile` paths as the TLS certificate and key files. ```go title="TLS with cert and key files" app.Listen(":443", fiber.ListenConfig{CertFile: "./cert.pem", CertKeyFile: "./cert.key"}) ``` #### TLS with client CA certificate `CertClientFile` only configures the client CA for mTLS when using `CertFile`/`CertKeyFile`. If `TLSConfig` is set, `CertClientFile` is ignored, so configure client CAs in the provided `tls.Config` instead. ```go title="TLS with client CA certificate" app.Listen(":443", fiber.ListenConfig{ CertFile: "./cert.pem", CertKeyFile: "./cert.key", CertClientFile: "./ca-chain-cert.pem", }) ``` #### TLS AutoCert support (ACME / Let's Encrypt) Provides automatic certificate management with Let's Encrypt and any other ACME-based provider. Because the ACME TLS-ALPN-01 challenge is validated on port 443, the example below listens on `:443`. ```go title="AutoCert (ACME) configuration" // Certificate manager certManager := &autocert.Manager{ Prompt: autocert.AcceptTOS, // Replace with your domain name HostPolicy: autocert.HostWhitelist("example.com"), // Folder to store the certificates Cache: autocert.DirCache("./certs"), } app.Listen(":443", fiber.ListenConfig{ AutoCertManager: certManager, }) ``` #### Precedence and conflicts - `TLSConfig` is preferred and ignores `CertFile`/`CertKeyFile`, `CertClientFile`, `AutoCertManager`, `TLSMinVersion`, and `TLSConfigFunc`. - `AutoCertManager` cannot be combined with `CertFile`/`CertKeyFile`. #### TLS with external certificate provider Use `TLSConfig` to supply a base `tls.Config` that can fetch certificates at runtime. `TLSConfig` is cloned and used as-is.
```go title="TLSConfig with dynamic certificate provider" app.Listen(":443", fiber.ListenConfig{ TLSConfig: &tls.Config{ GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { return myProvider.Certificate(info.ServerName) }, }, }) ``` #### Mutual TLS with TLSConfig Use `TLSConfig` to configure mutual TLS by setting `ClientAuth` and `ClientCAs`. This replaces `CertClientFile` when you manage TLS configuration directly. ```go title="TLSConfig with client CA pool" certPEM := []byte(certPEMString) keyPEM := []byte(keyPEMString) caPEM := []byte(caPEMString) cert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { log.Fatal(err) } clientCAs := x509.NewCertPool() if ok := clientCAs.AppendCertsFromPEM(caPEM); !ok { log.Fatal("failed to append client CA") } app.Listen(":443", fiber.ListenConfig{ TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, ClientAuth: tls.RequireAndVerifyClientCert, ClientCAs: clientCAs, }, }) ``` Load certificates from memory or environment variables and provide them via `TLSConfig`. ```go title="TLSConfig with in-memory certificate" certPEM := []byte(certPEMString) keyPEM := []byte(keyPEMString) cert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { log.Fatal(err) } app.Listen(":443", fiber.ListenConfig{ TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, }, }) ``` ```go title="TLSConfig with certificate from environment" certPEM := []byte(os.Getenv("TLS_CERT_PEM")) keyPEM := []byte(os.Getenv("TLS_KEY_PEM")) cert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { log.Fatal(err) } app.Listen(":443", fiber.ListenConfig{ TLSConfig: &tls.Config{ Certificates: []tls.Certificate{cert}, }, }) ``` ### Listener You can pass your own [`net.Listener`](https://pkg.go.dev/net/#Listener) using the `Listener` method. This method can be used to enable **TLS/HTTPS** with a custom tls.Config. 
```go title="Signature" func (app *App) Listener(ln net.Listener, config ...ListenConfig) error ``` ```go title="Examples" ln, _ := net.Listen("tcp", ":3000") cer, _ := tls.LoadX509KeyPair("server.crt", "server.key") ln = tls.NewListener(ln, &tls.Config{Certificates: []tls.Certificate{cer}}) app.Listener(ln) ``` ## Server Server returns the underlying [fasthttp server](https://pkg.go.dev/github.com/valyala/fasthttp#Server) ```go title="Signature" func (app *App) Server() *fasthttp.Server ``` ```go title="Examples" func main() { app := fiber.New() app.Server().MaxConnsPerIP = 1 // ... } ``` ## Server Shutdown Shutdown gracefully shuts down the server without interrupting any active connections. Shutdown works by first closing all open listeners and then waiting indefinitely for all connections to return to idle before shutting down. ShutdownWithTimeout will forcefully close any active connections after the timeout expires. ShutdownWithContext shuts down the server gracefully, and forcibly if the context's deadline is exceeded. Shutdown hooks will still be executed, even if an error occurs during the shutdown process, as they are deferred to ensure cleanup happens regardless of errors. ```go func (app *App) Shutdown() error func (app *App) ShutdownWithTimeout(timeout time.Duration) error func (app *App) ShutdownWithContext(ctx context.Context) error ``` ## Helper functions ### NewError NewError creates a new `*Error` instance with an optional message. ```go title="Signature" func NewError(code int, message ...string) *Error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { return fiber.NewError(782, "Custom error message") }) ``` ### NewErrorf NewErrorf creates a new `*Error` instance with an optional formatted message.
```go title="Signature" func NewErrorf(code int, message ...any) *Error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { return fiber.NewErrorf(782, "Custom error %s", "message") }) ``` ### IsChild IsChild determines if the current process is a result of Prefork. ```go title="Signature" func IsChild() bool ``` ```go title="Example" // Config app app := fiber.New() app.Get("/", func(c fiber.Ctx) error { if !fiber.IsChild() { fmt.Println("I'm the parent process") } else { fmt.Println("I'm a child process") } return c.SendString("Hello, World!") }) // ... // With prefork enabled, the parent process will spawn child processes app.Listen(":8080", fiber.ListenConfig{EnablePrefork: true}) ``` ================================================ FILE: docs/api/hooks.md ================================================ --- id: hooks title: 🎣 Hooks sidebar_position: 7 --- import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; Fiber lets you run custom callbacks at specific points in the routing lifecycle. Available hooks include: - [OnRoute](#onroute) - [OnName](#onname) - [OnGroup](#ongroup) - [OnGroupName](#ongroupname) - [OnListen](#onlisten) - [OnPreStartupMessage/OnPostStartupMessage](#onprestartupmessageonpoststartupmessage) - [ListenData](#listendata) - [OnFork](#onfork) - [OnPreShutdown](#onpreshutdown) - [OnPostShutdown](#onpostshutdown) - [OnMount](#onmount) ## Constants ```go // Handlers define functions to create hooks for Fiber. 
type OnRouteHandler = func(Route) error type OnNameHandler = OnRouteHandler type OnGroupHandler = func(Group) error type OnGroupNameHandler = OnGroupHandler type OnListenHandler = func(ListenData) error type OnForkHandler = func(int) error type OnPreStartupMessageHandler = func(*PreStartupMessageData) error type OnPostStartupMessageHandler = func(*PostStartupMessageData) error type OnPreShutdownHandler = func() error type OnPostShutdownHandler = func(error) error type OnMountHandler = func(*App) error ``` ## OnRoute Runs after each route is registered. The callback receives the route so you can inspect its properties. ```go title="Signature" func (h *Hooks) OnRoute(handler ...OnRouteHandler) ``` ## OnName Runs when a route is named. The callback receives the route. :::caution `OnName` only works with named routes, not groups. ::: ```go title="Signature" func (h *Hooks) OnName(handler ...OnNameHandler) ``` ```go package main import ( "fmt" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString(c.Route().Name) }).Name("index") app.Hooks().OnName(func(r fiber.Route) error { fmt.Print("Name: " + r.Name + ", ") return nil }) app.Hooks().OnName(func(r fiber.Route) error { fmt.Print("Method: " + r.Method + "\n") return nil }) app.Get("/add/user", func(c fiber.Ctx) error { return c.SendString(c.Route().Name) }).Name("addUser") app.Delete("/destroy/user", func(c fiber.Ctx) error { return c.SendString(c.Route().Name) }).Name("destroyUser") app.Listen(":5000") } // Results: // Name: addUser, Method: GET // Name: destroyUser, Method: DELETE ``` ## OnGroup Runs after each group is registered. The callback receives the group. ```go title="Signature" func (h *Hooks) OnGroup(handler ...OnGroupHandler) ``` ## OnGroupName Runs when a group is named. The callback receives the group. :::caution `OnGroupName` only works with named groups, not routes. 
::: ```go title="Signature" func (h *Hooks) OnGroupName(handler ...OnGroupNameHandler) ``` ## OnListen Runs when the app starts listening via `Listen` or `Listener`. ```go title="Signature" func (h *Hooks) OnListen(handler ...OnListenHandler) ``` ```go package main import ( "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Hooks().OnListen(func(listenData fiber.ListenData) error { if fiber.IsChild() { return nil } scheme := "http" if listenData.TLS { scheme = "https" } log.Println(scheme + "://" + listenData.Host + ":" + listenData.Port) return nil }) app.Listen(":5000", fiber.ListenConfig{ DisableStartupMessage: true, }) } ``` ## OnPreStartupMessage/OnPostStartupMessage Use `OnPreStartupMessage` to tweak the banner before Fiber prints it, and `OnPostStartupMessage` to run logic after the banner is printed (or skipped). You can use some helper functions to customize the banner inside the `OnPreStartupMessage` hook. ```go title="Signatures" // AddInfo adds an informational entry to the startup message with "INFO" label. func (sm *PreStartupMessageData) AddInfo(key, title, value string, priority ...int) // AddWarning adds a warning entry to the startup message with "WARNING" label. func (sm *PreStartupMessageData) AddWarning(key, title, value string, priority ...int) // AddError adds an error entry to the startup message with "ERROR" label. func (sm *PreStartupMessageData) AddError(key, title, value string, priority ...int) // EntryKeys returns all entry keys currently present in the startup message. func (sm *PreStartupMessageData) EntryKeys() []string // ResetEntries removes all existing entries from the startup message. func (sm *PreStartupMessageData) ResetEntries() // DeleteEntry removes a specific entry from the startup message by its key. func (sm *PreStartupMessageData) DeleteEntry(key string) ``` - Assign `sm.BannerHeader` to override the ASCII art banner.
Leave it empty to use the default banner provided by Fiber. - Set `sm.PreventDefault = true` to suppress the built-in banner without affecting other hooks. - `PostStartupMessageData` reports whether the banner was skipped via the `Disabled`, `IsChild`, and `Prevented` flags. ### Startup Message Customization ```go title="Customize the startup message" package main import ( "fmt" "os" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Hooks().OnPreStartupMessage(func(sm *fiber.PreStartupMessageData) error { sm.BannerHeader = "FOOBER " + sm.Version + "\n-------" // Optional: you can also remove old entries // sm.ResetEntries() sm.AddInfo("git-hash", "Git hash", os.Getenv("GIT_HASH")) sm.AddInfo("prefork", "Prefork", fmt.Sprintf("%v", sm.Prefork), 15) return nil }) app.Hooks().OnPostStartupMessage(func(sm *fiber.PostStartupMessageData) error { if !sm.Disabled && !sm.IsChild && !sm.Prevented { fmt.Println("startup completed") } return nil }) app.Listen(":5000") } ``` ### ListenData `ListenData` exposes runtime metadata about the listener: | Field | Type | Description | | --- | --- | --- | | `Host` | `string` | Resolved hostname or IP address. | | `Port` | `string` | The bound port. | | `TLS` | `bool` | Indicates whether TLS is enabled. | | `Version` | `string` | Fiber version reported in the startup banner. | | `AppName` | `string` | Application name from the configuration. | | `HandlerCount` | `int` | Total registered handler count. | | `ProcessCount` | `int` | Number of processes Fiber will use. | | `PID` | `int` | Current process identifier. | | `Prefork` | `bool` | Whether prefork is enabled. | | `ChildPIDs` | `[]int` | Child process identifiers when preforking. | | `ColorScheme` | [`Colors`](https://github.com/gofiber/fiber/blob/main/color.go) | Active color scheme for the startup message. | ## OnFork Runs when a child process is forked during prefork. The callback receives the child's process ID.
```go title="Signature" func (h *Hooks) OnFork(handler ...OnForkHandler) ``` ## OnPreShutdown Runs before the server shuts down. ```go title="Signature" func (h *Hooks) OnPreShutdown(handler ...OnPreShutdownHandler) ``` ## OnPostShutdown Runs after the server shuts down. ```go title="Signature" func (h *Hooks) OnPostShutdown(handler ...OnPostShutdownHandler) ``` ## OnMount Fires after a sub-app is mounted on a parent. The parent app is passed to the callback and it works for both app and group mounts. ```go title="Signature" func (h *Hooks) OnMount(handler ...OnMountHandler) ``` ```go package main import ( "fmt" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Get("/", testSimpleHandler).Name("x") subApp := fiber.New() subApp.Get("/test", testSimpleHandler) subApp.Hooks().OnMount(func(parent *fiber.App) error { fmt.Print("Mount path of parent app: " + parent.MountPath()) // Additional custom logic... return nil }) app.Use("/sub", subApp) } func testSimpleHandler(c fiber.Ctx) error { return c.SendString("Hello, Fiber!") } // Result: // Mount path of parent app: /sub ``` :::caution OnName, OnRoute, OnGroup, and OnGroupName are mount-sensitive. When you mount a sub-app that registers these hooks, route and group paths include the mount prefix. ::: ================================================ FILE: docs/api/log.md ================================================ --- id: log title: 📃 Log description: Fiber's built-in log package sidebar_position: 6 --- Logs help you observe program behavior, diagnose issues, and trigger alerts. Structured logs improve searchability and speed up troubleshooting. Fiber logs to standard output by default and exposes global helpers such as `log.Info`, `log.Errorf`, and `log.Warnw`. ## Log Levels ```go const ( LevelTrace Level = iota LevelDebug LevelInfo LevelWarn LevelError LevelFatal LevelPanic ) ``` ## Custom Log Fiber provides the generic `AllLogger[T]` interface for adapting various log libraries. 
```go type CommonLogger interface { Logger FormatLogger WithLogger } type ConfigurableLogger[T any] interface { // SetLevel sets logging level. SetLevel(level Level) // SetOutput sets the logger output. SetOutput(w io.Writer) // Logger returns the logger instance. Logger() T } type AllLogger[T any] interface { CommonLogger ConfigurableLogger[T] WithLogger } ``` ## Print Log **Note:** The Fatal level method will terminate the program after printing the log message. Please use it with caution. ### Basic Logging Call level-specific methods directly; entries use the `messageKey` (default `msg`). ```go log.Info("Hello, World!") log.Debug("Are you OK?") log.Info("42 is the answer to life, the universe, and everything") log.Warn("We are under attack!") log.Error("Houston, we have a problem.") log.Fatal("So Long, and Thanks for All the Fish.") log.Panic("The system is down.") ``` ### Formatted Logging Append `f` to format the message. ```go log.Debugf("Hello %s", "boy") log.Infof("%d is the answer to life, the universe, and everything", 42) log.Warnf("We are under attack, %s!", "boss") log.Errorf("%s, we have a problem.", "John Smith") log.Fatalf("So Long, and Thanks for All the %s.", "fish") ``` ### Key-Value Logging Key-value helpers log structured fields; mismatched pairs emit `KEYVALS UNPAIRED`. ```go log.Debugw("", "greeting", "Hello", "target", "boy") log.Infow("", "number", 42) log.Warnw("", "job", "boss") log.Errorw("", "name", "John Smith") log.Fatalw("", "fruit", "fish") ``` ## Global Log Fiber also exposes a global logger for quick messages. ```go import "github.com/gofiber/fiber/v3/log" log.Info("info") log.Warn("warn") ``` The example uses `log.DefaultLogger`, which writes to stdout. The [contrib](https://github.com/gofiber/contrib) repo offers adapters like `fiberzap` and `fiberzerolog`, or you can register your own with `log.SetLogger`. 
Here's an example using a custom logger: ```go import ( "log" "os" fiberlog "github.com/gofiber/fiber/v3/log" ) var _ fiberlog.AllLogger[*log.Logger] = (*customLogger)(nil) type customLogger struct { stdlog *log.Logger } // Implement required methods for the AllLogger interface... // Inject your custom logger fiberlog.SetLogger[*log.Logger](&customLogger{ stdlog: log.New(os.Stdout, "CUSTOM ", log.LstdFlags), }) // Retrieve the underlying *log.Logger for direct use std := fiberlog.DefaultLogger[*log.Logger]().Logger() std.Println("custom logging") ``` ## Set Level `log.SetLevel` sets the minimum level that will be output. The default is `LevelTrace`. **Note:** This method is not safe for concurrent use. ```go import "github.com/gofiber/fiber/v3/log" log.SetLevel(log.LevelInfo) ``` Setting the log level allows you to control the verbosity of the logs, filtering out messages below the specified level. ## Set Output `log.SetOutput` sets where logs are written. By default, they go to the console. ### Writing Logs to Stderr ```go import ( "os" "github.com/gofiber/fiber/v3/log" ) log.SetOutput(os.Stderr) ``` This lets you route logs to a file, service, or any destination. ### Writing Logs to a File To write to a file such as `test.log`: ```go // Output to ./test.log file f, err := os.OpenFile("test.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) if err != nil { log.Fatal("Failed to open log file:", err) } log.SetOutput(f) ``` ### Writing Logs to Both Console and File Write to both `test.log` and `stdout`: ```go // Output to ./test.log file file, err := os.OpenFile("test.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) if err != nil { log.Fatal("Failed to open log file:", err) } iw := io.MultiWriter(os.Stdout, file) log.SetOutput(iw) ``` ## Bind Context Bind a logger to a context with `log.WithContext`, which returns a `CommonLogger` tied to that context.
```go commonLogger := log.WithContext(ctx) commonLogger.Info("info") ``` Context binding adds request-specific data for easier tracing. ## Logger Use `Logger` to access the underlying logger and call its native methods: ```go logger := fiberlog.DefaultLogger[*log.Logger]() // Get the default logger instance stdlogger := logger.Logger() // stdlogger is *log.Logger stdlogger.SetFlags(0) // Hide timestamp by setting flags to 0 ``` ================================================ FILE: docs/api/redirect.md ================================================ --- id: redirect title: 🔄 Redirect description: Fiber's built-in redirect package sidebar_position: 5 toc_max_heading_level: 5 --- Redirect helpers send the client to another URL or route. ## Redirect Methods ### To Redirects to a URL built from the given path. Optionally set an HTTP [status](#status). :::info If unspecified, status defaults to **303 See Other**. ::: ```go title="Signature" func (r *Redirect) To(location string) error ``` ```go title="Example" app.Get("/coffee", func(c fiber.Ctx) error { // => HTTP - GET 301 /teapot return c.Redirect().Status(fiber.StatusMovedPermanently).To("/teapot") }) app.Get("/teapot", func(c fiber.Ctx) error { return c.Status(fiber.StatusTeapot).SendString("🍵 short and stout 🍵") }) ``` ```go title="More examples" app.Get("/", func(c fiber.Ctx) error { // => HTTP - GET 303 /foo/bar return c.Redirect().To("/foo/bar") // => HTTP - GET 303 ../login return c.Redirect().To("../login") // => HTTP - GET 303 http://example.com return c.Redirect().To("http://example.com") // => HTTP - GET 301 http://example.com return c.Redirect().Status(301).To("http://example.com") }) ``` ### Route Redirects to a named route with parameters and queries. :::info To send params and queries to a route, use the [`RedirectConfig`](#redirectconfig) struct.
::: ```go title="Signature" func (r *Redirect) Route(name string, config ...RedirectConfig) error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { // /user/fiber return c.Redirect().Route("user", fiber.RedirectConfig{ Params: fiber.Map{ "name": "fiber", }, }) }) app.Get("/with-queries", func(c fiber.Ctx) error { // /user/fiber?data[0][name]=john&data[0][age]=10&test=doe return c.Redirect().Route("user", fiber.RedirectConfig{ Params: fiber.Map{ "name": "fiber", }, Queries: map[string]string{ "data[0][name]": "john", "data[0][age]": "10", "test": "doe", }, }) }) app.Get("/user/:name", func(c fiber.Ctx) error { return c.SendString(c.Params("name")) }).Name("user") ``` ### Back Redirects to the referer. If the referer is missing, it falls back to the provided URL. You can also set the status code. :::info If unspecified, status defaults to **303 See Other**. ::: ```go title="Signature" func (r *Redirect) Back(fallback string) error ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { return c.SendString("Home page") }) app.Get("/test", func(c fiber.Ctx) error { c.Set("Content-Type", "text/html") return c.SendString(`<a href="/back">Back</a>`) }) app.Get("/back", func(c fiber.Ctx) error { return c.Redirect().Back("/") }) ``` ## Controls :::info Methods are **chainable**. ::: ### Status Sets the HTTP status code for the redirect. :::info It is used in conjunction with [**To**](#to), [**Route**](#route), and [**Back**](#back) methods. ::: ```go title="Signature" func (r *Redirect) Status(status int) *Redirect ``` ```go title="Example" app.Get("/coffee", func(c fiber.Ctx) error { // => HTTP - GET 301 /teapot return c.Redirect().Status(fiber.StatusMovedPermanently).To("/teapot") }) ``` ### RedirectConfig Sets the configuration for the redirect. :::info It is used in conjunction with the [**Route**](#route) method.
::: ```go title="Definition" // RedirectConfig is a config to use with Redirect().Route() type RedirectConfig struct { Params fiber.Map // Route parameters Queries map[string]string // Query map } ``` ### Flash Message Similar to [Laravel](https://laravel.com/docs/11.x/redirects#redirecting-with-flashed-session-data), you can flash a message and retrieve it in the next request. #### Messages Retrieve all flash messages. See [With](#with) for details. ```go title="Signature" func (r *Redirect) Messages() map[string]string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { messages := c.Redirect().Messages() return c.JSON(messages) }) ``` #### Message Get a flash message by key; see [With](#with). ```go title="Signature" func (r *Redirect) Message(key string) string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { message := c.Redirect().Message("status") return c.SendString(message) }) ``` #### OldInputs Retrieve stored input data. See [WithInput](#withinput). ```go title="Signature" func (r *Redirect) OldInputs() map[string]string ``` ```go title="Example" app.Get("/", func(c fiber.Ctx) error { oldInputs := c.Redirect().OldInputs() return c.JSON(oldInputs) }) ``` #### OldInput Get stored input data by key; see [WithInput](#withinput). ```go title="Signature" func (r *Redirect) OldInput(key string) string ``` ```go title="Example" app.Get("/name", func(c fiber.Ctx) error { oldInput := c.Redirect().OldInput("name") return c.SendString(oldInput) }) ``` #### With Send flash messages with `With`. ```go title="Signature" func (r *Redirect) With(key, value string) *Redirect ``` ```go title="Example" app.Get("/login", func(c fiber.Ctx) error { return c.Redirect().With("status", "Logged in successfully").To("/") }) app.Get("/", func(c fiber.Ctx) error { // => Logged in successfully return c.SendString(c.Redirect().Message("status")) }) ``` #### WithInput Send input data with `WithInput`, which stores them in a cookie.
It captures form, multipart, or query data depending on the request content type. ```go title="Signature" func (r *Redirect) WithInput() *Redirect ``` ```go title="Example" // curl -X POST http://localhost:3000/login -d "name=John" app.Post("/login", func(c fiber.Ctx) error { return c.Redirect().WithInput().Route("name") }) app.Get("/name", func(c fiber.Ctx) error { // => John return c.SendString(c.Redirect().OldInput("name")) }).Name("name") ``` ================================================ FILE: docs/api/services.md ================================================ --- id: services title: 🧩 Services sidebar_position: 9 --- Services wrap external dependencies. Add them to the application configuration through `Config.Services`, and Fiber starts them on launch and stops them during shutdown—useful during development and testing. Fiber also stores each service in the application's state, so you can retrieve it with `GetService` or `MustGetService` (see [State Management](./state)). ## Service Interface The `Service` interface defines methods a service must implement. ### Definition ```go type Service interface { // Start starts the service, returning an error if it fails. Start(ctx context.Context) error // String returns a string representation of the service. // It is used to print a human-readable name of the service in the startup message. String() string // State returns the current state of the service. State(ctx context.Context) (string, error) // Terminate terminates the service, returning an error if it fails. Terminate(ctx context.Context) error } ``` ## Service Methods ### Start Starts the service. Fiber calls this when the application starts. ```go func (s *SomeService) Start(ctx context.Context) error ``` ### String Returns a string representation of the service, used to print the service in the startup message. ```go func (s *SomeService) String() string ``` ### State Reports the current state of the service for the startup message.
```go func (s *SomeService) State(ctx context.Context) (string, error) ``` ### Terminate Stops the service after the application shuts down using a post-shutdown hook. ```go func (s *SomeService) Terminate(ctx context.Context) error ``` ## Comprehensive Examples ### Example: Adding a Service This example demonstrates how to add a Redis store as a service to the application, backed by the Testcontainers Redis Go module. ```go package main import ( "context" "fmt" "log" "time" "github.com/gofiber/fiber/v3" "github.com/redis/go-redis/v9" tcredis "github.com/testcontainers/testcontainers-go/modules/redis" ) const redisServiceName = "redis-store" type redisService struct { ctr *tcredis.RedisContainer } // Start initializes and starts the service. It implements the [fiber.Service] interface. func (s *redisService) Start(ctx context.Context) error { // start the service c, err := tcredis.Run(ctx, "redis:latest") if err != nil { return err } s.ctr = c return nil } // String returns a string representation of the service. // It is used to print a human-readable name of the service in the startup message. // It implements the [fiber.Service] interface. func (s *redisService) String() string { return redisServiceName } // State returns the current state of the service. // It implements the [fiber.Service] interface. func (s *redisService) State(ctx context.Context) (string, error) { state, err := s.ctr.State(ctx) if err != nil { return "", fmt.Errorf("container state: %w", err) } return state.Status, nil } // Terminate stops and removes the service. It implements the [fiber.Service] interface. func (s *redisService) Terminate(ctx context.Context) error { // stop the service return s.ctr.Terminate(ctx) } func main() { cfg := &fiber.Config{} // Initialize service. cfg.Services = append(cfg.Services, &redisService{}) // Define a context provider for the services startup. // This is useful to cancel the startup of the services if the context is canceled. 
// Default is context.Background(). startupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cfg.ServicesStartupContextProvider = func() context.Context { return startupCtx } // Define a context provider for the services shutdown. // This is useful to cancel the shutdown of the services if the context is canceled. // Default is context.Background(). shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cfg.ServicesShutdownContextProvider = func() context.Context { return shutdownCtx } app := fiber.New(*cfg) ctx := context.Background() // Obtain the Redis service from the application's State. redisSrv, ok := fiber.GetService[*redisService](app.State(), redisServiceName) if !ok || redisSrv == nil { log.Printf("Redis service not found") return } // Obtain the connection string from the service. connString, err := redisSrv.ctr.ConnectionString(ctx) if err != nil { log.Printf("Could not get connection string: %v", err) return } // Parse the connection string to create a Redis client. options, err := redis.ParseURL(connString) if err != nil { log.Printf("failed to parse connection string: %s", err) return } // Initialize the Redis client. rdb := redis.NewClient(options) // Check the Redis connection. if err := rdb.Ping(ctx).Err(); err != nil { log.Fatalf("Could not connect to Redis: %v", err) } app.Listen(":3000") } ``` ### Example: Add a service with the Store middleware This example shows how to use services with the Store middleware for dependency injection. It uses a Redis store backed by the Testcontainers Redis module. 
```go package main import ( "context" "encoding/json" "fmt" "log" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/logger" redisStore "github.com/gofiber/storage/redis/v3" "github.com/redis/go-redis/v9" tcredis "github.com/testcontainers/testcontainers-go/modules/redis" ) const ( redisServiceName = "redis-store" ) type User struct { ID int `json:"id"` Name string `json:"name"` Email string `json:"email"` } type redisService struct { ctr *tcredis.RedisContainer } // Start initializes and starts the service. It implements the [fiber.Service] interface. func (s *redisService) Start(ctx context.Context) error { // start the service c, err := tcredis.Run(ctx, "redis:latest") if err != nil { return err } s.ctr = c return nil } // String returns a string representation of the service. // It is used to print a human-readable name of the service in the startup message. // It implements the [fiber.Service] interface. func (s *redisService) String() string { return redisServiceName } // State returns the current state of the service. // It implements the [fiber.Service] interface. func (s *redisService) State(ctx context.Context) (string, error) { state, err := s.ctr.State(ctx) if err != nil { return "", fmt.Errorf("container state: %w", err) } return state.Status, nil } // Terminate stops and removes the service. It implements the [fiber.Service] interface. func (s *redisService) Terminate(ctx context.Context) error { // stop the service return s.ctr.Terminate(ctx) } func main() { cfg := &fiber.Config{} // Initialize service. cfg.Services = append(cfg.Services, &redisService{}) // Define a context provider for the services startup. // This is useful to cancel the startup of the services if the context is canceled. // Default is context.Background(). 
startupCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cfg.ServicesStartupContextProvider = func() context.Context { return startupCtx } // Define a context provider for the services shutdown. // This is useful to cancel the shutdown of the services if the context is canceled. // Default is context.Background(). shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cfg.ServicesShutdownContextProvider = func() context.Context { return shutdownCtx } app := fiber.New(*cfg) // Initialize default config app.Use(logger.New()) ctx := context.Background() // Obtain the Redis service from the application's State. redisSrv, ok := fiber.GetService[*redisService](app.State(), redisServiceName) if !ok || redisSrv == nil { log.Printf("Redis service not found") return } // Obtain the connection string from the service. connString, err := redisSrv.ctr.ConnectionString(ctx) if err != nil { log.Printf("Could not get connection string: %v", err) return } // define a GoFiber session store, backed by the Redis service store := redisStore.New(redisStore.Config{ URL: connString, }) app.Post("/user/create", func(c fiber.Ctx) error { var user User if err := c.Bind().JSON(&user); err != nil { return c.Status(fiber.StatusBadRequest).SendString(err.Error()) } json, err := json.Marshal(user) if err != nil { return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) } // Save the user to the database. 
err = store.Set(user.Email, json, time.Hour*24) if err != nil { return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) } return c.JSON(user) }) app.Get("/user/:id", func(c fiber.Ctx) error { id := c.Params("id") user, err := store.Get(id) if err == redis.Nil { return c.Status(fiber.StatusNotFound).SendString("User not found") } else if err != nil { return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) } return c.JSON(string(user)) }) app.Listen(":3000") } ``` ================================================ FILE: docs/api/state.md ================================================ --- id: state title: 🗂️ State Management sidebar_position: 8 --- State management provides a global key–value store for application dependencies and runtime data. The store is shared across the entire application and persists between requests. It's commonly used to store [Services](../api/services), which you can retrieve with the `GetService` or `MustGetService` functions. :::warning When prefork is enabled, each worker process has an independent state store, meaning state is not shared between them. ::: ## State Type `State` is a key–value store built on top of `sync.Map` to ensure safe concurrent access. It allows storage and retrieval of dependencies and configurations in a Fiber application as well as thread–safe access to runtime data. ### Definition ```go // State is a key–value store for Fiber's app, used as a global storage for the app's dependencies. // It is a thread–safe implementation of a map[string]any, using sync.Map. type State struct { dependencies sync.Map } ``` ## Methods on State ### Set Set adds or updates a key–value pair in the State. ```go // Set adds or updates a key–value pair in the State. func (s *State) Set(key string, value any) ``` **Usage Example:** ```go app.State().Set("appName", "My Fiber App") ``` ### Get Get retrieves a value from the State. 
```go title="Signature" func (s *State) Get(key string) (any, bool) ``` **Usage Example:** ```go value, ok := app.State().Get("appName") if ok { fmt.Println("App Name:", value) } ``` ### MustGet MustGet retrieves a value from the State and panics if the key is not found. ```go title="Signature" func (s *State) MustGet(key string) any ``` **Usage Example:** ```go appName := app.State().MustGet("appName") fmt.Println("App Name:", appName) ``` ### Has Has checks if a key exists in the State. ```go title="Signature" func (s *State) Has(key string) bool ``` **Usage Example:** ```go if app.State().Has("appName") { fmt.Println("App Name is set.") } ``` ### Delete Delete removes a key–value pair from the State. ```go title="Signature" func (s *State) Delete(key string) ``` **Usage Example:** ```go app.State().Delete("obsoleteKey") ``` ### Reset Reset removes all keys from the State, including those related to Services. ```go title="Signature" func (s *State) Reset() ``` **Usage Example:** ```go app.State().Reset() ``` ### Keys Keys returns a slice containing all keys present in the State. ```go title="Signature" func (s *State) Keys() []string ``` **Usage Example:** ```go keys := app.State().Keys() fmt.Println("State Keys:", keys) ``` ### Len Len returns the number of keys in the State. ```go // Len returns the number of keys in the State. func (s *State) Len() int ``` **Usage Example:** ```go fmt.Printf("Total State Entries: %d\n", app.State().Len()) ``` ### GetString GetString retrieves a string value from the State. It returns the string and a boolean indicating a successful type assertion. ```go title="Signature" func (s *State) GetString(key string) (string, bool) ``` **Usage Example:** ```go if appName, ok := app.State().GetString("appName"); ok { fmt.Println("App Name:", appName) } ``` ### GetInt GetInt retrieves an integer value from the State. It returns the int and a boolean indicating a successful type assertion. 
```go title="Signature" func (s *State) GetInt(key string) (int, bool) ``` **Usage Example:** ```go if count, ok := app.State().GetInt("userCount"); ok { fmt.Printf("User Count: %d\n", count) } ``` ### GetBool GetBool retrieves a boolean value from the State. It returns the bool and a boolean indicating a successful type assertion. ```go title="Signature" func (s *State) GetBool(key string) (bool, bool) ``` **Usage Example:** ```go if debug, ok := app.State().GetBool("debugMode"); ok { fmt.Printf("Debug Mode: %v\n", debug) } ``` ### GetFloat64 GetFloat64 retrieves a float64 value from the State. It returns the float64 and a boolean indicating a successful type assertion. ```go title="Signature" func (s *State) GetFloat64(key string) (float64, bool) ``` **Usage Example:** ```go if ratio, ok := app.State().GetFloat64("scalingFactor"); ok { fmt.Printf("Scaling Factor: %f\n", ratio) } ``` ### GetUint GetUint retrieves a `uint` value from the State. ```go title="Signature" func (s *State) GetUint(key string) (uint, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetUint("maxConnections"); ok { fmt.Printf("Max Connections: %d\n", val) } ``` ### GetInt8 GetInt8 retrieves an `int8` value from the State. ```go title="Signature" func (s *State) GetInt8(key string) (int8, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetInt8("threshold"); ok { fmt.Printf("Threshold: %d\n", val) } ``` ### GetInt16 GetInt16 retrieves an `int16` value from the State. ```go title="Signature" func (s *State) GetInt16(key string) (int16, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetInt16("minValue"); ok { fmt.Printf("Minimum Value: %d\n", val) } ``` ### GetInt32 GetInt32 retrieves an `int32` value from the State.
```go title="Signature" func (s *State) GetInt32(key string) (int32, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetInt32("portNumber"); ok { fmt.Printf("Port Number: %d\n", val) } ``` ### GetInt64 GetInt64 retrieves an `int64` value from the State. ```go title="Signature" func (s *State) GetInt64(key string) (int64, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetInt64("fileSize"); ok { fmt.Printf("File Size: %d\n", val) } ``` ### GetUint8 GetUint8 retrieves a `uint8` value from the State. ```go title="Signature" func (s *State) GetUint8(key string) (uint8, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetUint8("byteValue"); ok { fmt.Printf("Byte Value: %d\n", val) } ``` ### GetUint16 GetUint16 retrieves a `uint16` value from the State. ```go title="Signature" func (s *State) GetUint16(key string) (uint16, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetUint16("limit"); ok { fmt.Printf("Limit: %d\n", val) } ``` ### GetUint32 GetUint32 retrieves a `uint32` value from the State. ```go title="Signature" func (s *State) GetUint32(key string) (uint32, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetUint32("timeout"); ok { fmt.Printf("Timeout: %d\n", val) } ``` ### GetUint64 GetUint64 retrieves a `uint64` value from the State. ```go title="Signature" func (s *State) GetUint64(key string) (uint64, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetUint64("maxSize"); ok { fmt.Printf("Max Size: %d\n", val) } ``` ### GetUintptr GetUintptr retrieves a `uintptr` value from the State. ```go title="Signature" func (s *State) GetUintptr(key string) (uintptr, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetUintptr("pointerValue"); ok { fmt.Printf("Pointer Value: %d\n", val) } ``` ### GetFloat32 GetFloat32 retrieves a `float32` value from the State. 
```go title="Signature" func (s *State) GetFloat32(key string) (float32, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetFloat32("scalingFactor32"); ok { fmt.Printf("Scaling Factor (float32): %f\n", val) } ``` ### GetComplex64 GetComplex64 retrieves a `complex64` value from the State. ```go title="Signature" func (s *State) GetComplex64(key string) (complex64, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetComplex64("complexVal"); ok { fmt.Printf("Complex Value (complex64): %v\n", val) } ``` ### GetComplex128 GetComplex128 retrieves a `complex128` value from the State. ```go title="Signature" func (s *State) GetComplex128(key string) (complex128, bool) ``` **Usage Example:** ```go if val, ok := app.State().GetComplex128("complexVal128"); ok { fmt.Printf("Complex Value (complex128): %v\n", val) } ``` ## Generic Functions Fiber provides generic functions to retrieve state values with type safety and fallback options. ### GetState GetState retrieves a value from the State and casts it to the desired type. It returns the cast value and a boolean indicating if the cast was successful. ```go title="Signature" func GetState[T any](s *State, key string) (T, bool) ``` **Usage Example:** ```go // Retrieve an integer value safely. userCount, ok := GetState[int](app.State(), "userCount") if ok { fmt.Printf("User Count: %d\n", userCount) } ``` ### MustGetState MustGetState retrieves a value from the State and casts it to the desired type. It panics if the key is not found or if the type assertion fails. ```go title="Signature" func MustGetState[T any](s *State, key string) T ``` **Usage Example:** ```go // Retrieve the value or panic if it is not present. config := MustGetState[string](app.State(), "configFile") fmt.Println("Config File:", config) ``` ### GetStateWithDefault GetStateWithDefault retrieves a value from the State, casting it to the desired type. If the key is not present, it returns the provided default value. 
```go title="Signature" func GetStateWithDefault[T any](s *State, key string, defaultVal T) T ``` **Usage Example:** ```go // Retrieve a value with a fallback. requestCount := GetStateWithDefault[int](app.State(), "requestCount", 0) fmt.Printf("Request Count: %d\n", requestCount) ``` ### GetService GetService retrieves a Service from the State and casts it to the desired type. It returns the cast value and a boolean indicating if the cast was successful. ```go title="Signature" func GetService[T Service](s *State, key string) (T, bool) ``` **Usage Example:** ```go if srv, ok := fiber.GetService[*SomeService](app.State(), "someService"); ok { fmt.Printf("Some Service: %s\n", srv.String()) } ``` ### MustGetService MustGetService retrieves a Service from the State and casts it to the desired type. It panics if the key is not found or if the type assertion fails. ```go title="Signature" func MustGetService[T Service](s *State, key string) T ``` **Usage Example:** ```go srv := fiber.MustGetService[*SomeService](app.State(), "someService") ``` ## Comprehensive Examples ### Example: Request Counter This example demonstrates how to track the number of requests using the State. Note that the read-then-set update below is not atomic: concurrent requests can overwrite each other's increments. For an exact count, store a `*atomic.Int64` (from `sync/atomic`) in the State and call `Add(1)` on it instead. ```go package main import ( "fmt" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() // Initialize state with a counter. app.State().Set("requestCount", 0) // Middleware: Increase counter for every request. app.Use(func(c fiber.Ctx) error { count, _ := c.App().State().GetInt("requestCount") app.State().Set("requestCount", count+1) return c.Next() }) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello World!") }) app.Get("/stats", func(c fiber.Ctx) error { count, _ := c.App().State().Get("requestCount") return c.SendString(fmt.Sprintf("Total requests: %d", count)) }) app.Listen(":3000") } ``` ### Example: Environment–Specific Configuration This example shows how to configure different settings based on the environment.
```go package main import ( "os" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() // Determine environment. environment := os.Getenv("ENV") if environment == "" { environment = "development" } app.State().Set("environment", environment) // Set environment-specific configurations. if environment == "development" { app.State().Set("apiUrl", "http://localhost:8080/api") app.State().Set("debug", true) } else { app.State().Set("apiUrl", "https://api.production.com") app.State().Set("debug", false) } app.Get("/config", func(c fiber.Ctx) error { config := map[string]any{ "environment": environment, "apiUrl": fiber.GetStateWithDefault(c.App().State(), "apiUrl", ""), "debug": fiber.GetStateWithDefault(c.App().State(), "debug", false), } return c.JSON(config) }) app.Listen(":3000") } ``` ### Example: Dependency Injection with State Management This example demonstrates how to use the State for dependency injection in a Fiber application. ```go package main import ( "context" "fmt" "log" "github.com/gofiber/fiber/v3" "github.com/redis/go-redis/v9" ) type User struct { ID int `query:"id"` Name string `query:"name"` Email string `query:"email"` } func main() { app := fiber.New() ctx := context.Background() // Initialize Redis client. rdb := redis.NewClient(&redis.Options{ Addr: "localhost:6379", Password: "", DB: 0, }) // Check the Redis connection. if err := rdb.Ping(ctx).Err(); err != nil { log.Fatalf("Could not connect to Redis: %v", err) } // Inject the Redis client into Fiber's State for dependency injection. app.State().Set("redis", rdb) app.Get("/user/create", func(c fiber.Ctx) error { var user User if err := c.Bind().Query(&user); err != nil { return c.Status(fiber.StatusBadRequest).SendString(err.Error()) } // Retrieve the Redis client from the global state. rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis") if !ok { return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found") } // Save the user to the database. 
key := fmt.Sprintf("user:%d", user.ID) err := rdb.HSet(ctx, key, "name", user.Name, "email", user.Email).Err() if err != nil { return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) } return c.JSON(user) }) app.Get("/user/:id", func(c fiber.Ctx) error { id := c.Params("id") rdb, ok := fiber.GetState[*redis.Client](c.App().State(), "redis") if !ok { return c.Status(fiber.StatusInternalServerError).SendString("Redis client not found") } key := fmt.Sprintf("user:%s", id) user, err := rdb.HGetAll(ctx, key).Result() if err != nil { return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) } // HGETALL returns an empty hash (not redis.Nil) when the key does not exist. if len(user) == 0 { return c.Status(fiber.StatusNotFound).SendString("User not found") } return c.JSON(user) }) app.Listen(":3000") } ``` ================================================ FILE: docs/client/_category_.json ================================================ { "label": "\uD83C\uDF0E Client", "position": 6, "link": { "type": "generated-index", "description": "HTTP client for Fiber." } } ================================================ FILE: docs/client/examples.md ================================================ --- id: examples title: 🍳 Examples description: >- Client usage examples. sidebar_position: 5 --- import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; ## Basic Auth Clients send credentials via the `Authorization` header, while the server stores hashed passwords as shown in the middleware example.
```go package main import ( "encoding/base64" "fmt" "github.com/gofiber/fiber/v3/client" ) func main() { cc := client.New() out := base64.StdEncoding.EncodeToString([]byte("john:doe")) resp, err := cc.Get("http://localhost:3000", client.Config{ Header: map[string]string{ "Authorization": "Basic " + out, }, }) if err != nil { panic(err) } fmt.Print(string(resp.Body())) } ``` ```go package main import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/basicauth" ) func main() { app := fiber.New() app.Use( basicauth.New(basicauth.Config{ Users: map[string]string{ // "doe" hashed using SHA-256 "john": "{SHA256}eZ75KhGvkY4/t0HfQpNPO1aO0tk6wd908bjUGieTKm8=", }, }), ) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) app.Listen(":3000") } ``` ## TLS ```go package main import ( "crypto/tls" "crypto/x509" "fmt" "os" "github.com/gofiber/fiber/v3/client" ) func main() { cc := client.New() certPool, err := x509.SystemCertPool() if err != nil { panic(err) } cert, err := os.ReadFile("ssl.cert") if err != nil { panic(err) } certPool.AppendCertsFromPEM(cert) cc.SetTLSConfig(&tls.Config{ RootCAs: certPool, }) resp, err := cc.Get("https://localhost:3000") if err != nil { panic(err) } fmt.Print(string(resp.Body())) } ``` ```go package main import ( "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) err := app.Listen(":3000", fiber.ListenConfig{ CertFile: "ssl.cert", CertKeyFile: "ssl.key", }) if err != nil { panic(err) } } ``` ## Reusing fasthttp transports The Fiber client can wrap existing `fasthttp` clients so that you can reuse connection pools, custom dialers, or load-balancing logic that is already tuned for your infrastructure. 
### HostClient ```go package main import ( "log" "time" "github.com/gofiber/fiber/v3/client" "github.com/valyala/fasthttp" ) func main() { hc := &fasthttp.HostClient{ Addr: "api.internal:443", IsTLS: true, MaxConnDuration: 30 * time.Second, MaxIdleConnDuration: 10 * time.Second, } cc := client.NewWithHostClient(hc) resp, err := cc.Get("https://api.internal:443/status") if err != nil { log.Fatal(err) } log.Printf("status=%d body=%s", resp.StatusCode(), resp.Body()) } ``` ### LBClient ```go package main import ( "log" "time" "github.com/gofiber/fiber/v3/client" "github.com/valyala/fasthttp" ) func main() { lb := &fasthttp.LBClient{ Timeout: 2 * time.Second, Clients: []fasthttp.BalancingClient{ &fasthttp.HostClient{Addr: "edge-1.internal:8080"}, &fasthttp.HostClient{Addr: "edge-2.internal:8080"}, }, } cc := client.NewWithLBClient(lb) // Per-request overrides such as redirects, retries, TLS, and proxy dialers // are shared across every host client managed by the load balancer. resp, err := cc.Get("http://service.internal/api") if err != nil { log.Fatal(err) } log.Printf("status=%d body=%s", resp.StatusCode(), resp.Body()) } ``` ## Cookie jar The client can store and reuse cookies between requests by attaching a cookie jar. ### Request ```go func main() { jar := client.AcquireCookieJar() defer client.ReleaseCookieJar(jar) cc := client.New() cc.SetCookieJar(jar) jar.SetKeyValueBytes("httpbin.org", []byte("john"), []byte("doe")) resp, err := cc.Get("https://httpbin.org/cookies") if err != nil { panic(err) } fmt.Println(string(resp.Body())) } ```
Click here to see the result ```json { "cookies": { "john": "doe" } } ```
### Response Read cookies set by the server directly from the jar. ```go func main() { jar := client.AcquireCookieJar() defer client.ReleaseCookieJar(jar) cc := client.New() cc.SetCookieJar(jar) _, err := cc.Get("https://httpbin.org/cookies/set/john/doe") if err != nil { panic(err) } uri := fasthttp.AcquireURI() defer fasthttp.ReleaseURI(uri) uri.SetHost("httpbin.org") uri.SetPath("/cookies") fmt.Println(jar.Get(uri)) } ```
Click here to see the result ```plaintext [john=doe; path=/] ```
### Response (follow-up request) ```go func main() { jar := client.AcquireCookieJar() defer client.ReleaseCookieJar(jar) cc := client.New() cc.SetCookieJar(jar) _, err := cc.Get("https://httpbin.org/cookies/set/john/doe") if err != nil { panic(err) } resp, err := cc.Get("https://httpbin.org/cookies") if err != nil { panic(err) } fmt.Println(resp.String()) } ```
Click here to see the result ```json { "cookies": { "john": "doe" } } ```
================================================ FILE: docs/client/hooks.md ================================================ --- id: hooks title: 🎣 Hooks description: >- Hooks are used to manipulate the request/response process of the Fiber client. sidebar_position: 4 --- Hooks let you intercept and modify the request or response flow of the Fiber client. They are useful for: - Changing request parameters (e.g., URL, headers) before sending the request. - Logging request and response details. - Integrating complex tracing or monitoring tools. - Handling authentication, retries, or other custom logic. There are two kinds of hooks: ## Request Hooks **Request hooks** are functions executed before the HTTP request is sent. They follow the signature: ```go type RequestHook func(*Client, *Request) error ``` A request hook receives both the `Client` and the `Request` objects, allowing you to modify the request before it leaves your application. For example, you could: - Change the host URL. - Log request details (method, URL, headers). - Add or modify headers or query parameters. - Intercept and apply custom authentication logic. **Example:** ```go type Repository struct { Name string `json:"name"` FullName string `json:"full_name"` Description string `json:"description"` Homepage string `json:"homepage"` Owner struct { Login string `json:"login"` } `json:"owner"` } func main() { cc := client.New() // Add a request hook that modifies the request URL before sending. 
cc.AddRequestHook(func(c *client.Client, r *client.Request) error { r.SetURL("https://api.github.com/" + r.URL()) return nil }) resp, err := cc.Get("repos/gofiber/fiber") if err != nil { panic(err) } var repo Repository if err := resp.JSON(&repo); err != nil { panic(err) } fmt.Printf("Status code: %d\n", resp.StatusCode()) fmt.Printf("Repository: %s\n", repo.FullName) fmt.Printf("Description: %s\n", repo.Description) fmt.Printf("Homepage: %s\n", repo.Homepage) fmt.Printf("Owner: %s\n", repo.Owner.Login) fmt.Printf("Name: %s\n", repo.Name) fmt.Printf("Full Name: %s\n", repo.FullName) } ```
Click here to see the result ```plaintext Status code: 200 Repository: gofiber/fiber Description: ⚡️ Express inspired web framework written in Go Homepage: https://gofiber.io Owner: gofiber Name: fiber Full Name: gofiber/fiber ```
### Built-in Request Hooks Fiber includes built-in request hooks: - **parserRequestURL**: Normalizes and customizes the URL based on path and query parameters. Required for `PathParam` and `QueryParam` methods. - **parserRequestHeader**: Sets request headers, cookies, content type, referer, and user agent based on client and request properties. - **parserRequestBody**: Automatically serializes the request body (JSON, XML, form, file uploads, etc.). :::info If a request hook returns an error, Fiber stops the request and returns the error immediately. ::: **Example with Multiple Hooks:** ```go func main() { cc := client.New() cc.AddRequestHook(func(c *client.Client, r *client.Request) error { fmt.Println("Hook 1") return errors.New("error") }) cc.AddRequestHook(func(c *client.Client, r *client.Request) error { fmt.Println("Hook 2") return nil }) _, err := cc.Get("https://example.com/") if err != nil { panic(err) } } ```
Click here to see the result ```shell Hook 1 panic: error goroutine 1 [running]: main.main() main.go:25 +0xaa exit status 2 ```
## Response Hooks **Response hooks** are functions executed after the HTTP response is received. They follow the signature: ```go type ResponseHook func(*Client, *Response, *Request) error ``` A response hook receives the `Client`, `Response`, and `Request` objects, allowing you to inspect and modify the response or perform additional actions such as logging, tracing, or processing response data. **Example:** ```go func main() { cc := client.New() cc.AddResponseHook(func(c *client.Client, resp *client.Response, req *client.Request) error { fmt.Printf("Response Status Code: %d\n", resp.StatusCode()) fmt.Printf("HTTP protocol: %s\n\n", resp.Protocol()) fmt.Println("Response Headers:") for key, value := range resp.RawResponse.Header.All() { fmt.Printf("%s: %s\n", key, value) } return nil }) _, err := cc.Get("https://example.com/") if err != nil { panic(err) } } ```
Click here to see the result ```plaintext Response Status Code: 200 HTTP protocol: HTTP/1.1 Response Headers: Content-Length: 1256 Content-Type: text/html; charset=UTF-8 Server: ECAcc (dcd/7D5A) Age: 216114 Cache-Control: max-age=604800 Date: Fri, 10 May 2024 10:49:10 GMT Etag: "3147526947+gzip+ident" Expires: Fri, 17 May 2024 10:49:10 GMT Last-Modified: Thu, 17 Oct 2019 07:18:26 GMT Vary: Accept-Encoding X-Cache: HIT ```
### Built-in Response Hooks Fiber includes built-in response hooks: - **parserResponseCookie**: Parses cookies from the response and stores them in the response object and cookie jar if available. - **logger**: Logs information about the raw request and response. It uses the `log.CommonLogger` interface. :::info If a response hook returns an error, Fiber skips the remaining hooks and returns that error. ::: **Example with Multiple Response Hooks:** ```go func main() { cc := client.New() cc.AddResponseHook(func(c *client.Client, r1 *client.Response, r2 *client.Request) error { fmt.Println("Hook 1") return nil }) cc.AddResponseHook(func(c *client.Client, r1 *client.Response, r2 *client.Request) error { fmt.Println("Hook 2") return errors.New("error") }) cc.AddResponseHook(func(c *client.Client, r1 *client.Response, r2 *client.Request) error { fmt.Println("Hook 3") return nil }) _, err := cc.Get("https://example.com/") if err != nil { panic(err) } } ```
Click here to see the result ```shell Hook 1 Hook 2 panic: error goroutine 1 [running]: main.main() main.go:30 +0xd6 exit status 2 ```
## Hook Execution Order Hooks run in FIFO order (first in, first out), so they're executed in the order you add them. Keep this in mind when adding multiple hooks, as the order can affect the outcome. **Example:** ```go func main() { cc := client.New() cc.AddRequestHook(func(c *client.Client, r *client.Request) error { fmt.Println("Hook 1") return nil }) cc.AddRequestHook(func(c *client.Client, r *client.Request) error { fmt.Println("Hook 2") return nil }) _, err := cc.Get("https://example.com/") if err != nil { panic(err) } } ```
Click here to see the result ```plaintext Hook 1 Hook 2 ```
================================================ FILE: docs/client/request.md ================================================ --- id: request title: 📤 Request description: >- Request methods of Gofiber HTTP client. sidebar_position: 2 --- The `Request` struct in Fiber's HTTP client represents an HTTP request. It encapsulates the data required to send a request, including: - **URL**: The endpoint to which the request is sent. - **Method**: The HTTP method (GET, POST, PUT, DELETE, etc.). - **Headers**: Key-value pairs that provide additional information about the request or guide how the response should be processed. - **Body**: The data sent with the request, commonly used with methods like POST and PUT. - **Query Parameters**: Parameters appended to the URL to pass additional data or modify the request's behavior. This structure is designed to be both flexible and efficient, allowing you to easily build and modify HTTP requests as needed. ```go type Request struct { url string method string userAgent string boundary string referer string ctx context.Context header *Header params *QueryParam cookies *Cookie path *PathParam timeout time.Duration maxRedirects int client *Client body any formData *FormData files []*File bodyType bodyType RawRequest *fasthttp.Request } ``` ## REST Methods ### Get **Get** sends a GET request to the specified URL. It sets the URL and HTTP method, then dispatches the request to the server. ```go title="Signature" func (r *Request) Get(url string) (*Response, error) ``` ### Post **Post** sends a POST request. It sets the URL and method to POST, then sends the request. ```go title="Signature" func (r *Request) Post(url string) (*Response, error) ``` ### Put **Put** sends a PUT request. It sets the URL and method to PUT, then sends the request. ```go title="Signature" func (r *Request) Put(url string) (*Response, error) ``` ### Patch **Patch** sends a PATCH request. It sets the URL and method to PATCH, then sends the request. 
```go title="Signature" func (r *Request) Patch(url string) (*Response, error) ``` ### Delete **Delete** sends a DELETE request. It sets the URL and method to DELETE, then sends the request. ```go title="Signature" func (r *Request) Delete(url string) (*Response, error) ``` ### Head **Head** sends a HEAD request. It sets the URL and method to HEAD, then sends the request. ```go title="Signature" func (r *Request) Head(url string) (*Response, error) ``` ### Options **Options** sends an OPTIONS request. It sets the URL and method to OPTIONS, then sends the request. ```go title="Signature" func (r *Request) Options(url string) (*Response, error) ``` ### Custom **Custom** sends a request using a custom HTTP method. For example, you can use this to send a TRACE or CONNECT request. ```go title="Signature" func (r *Request) Custom(url, method string) (*Response, error) ``` ## AcquireRequest **AcquireRequest** returns a new pooled `Request`. Call `ReleaseRequest` when you're finished to return it to the pool and limit allocations. ```go title="Signature" func AcquireRequest() *Request ``` ## ReleaseRequest **ReleaseRequest** returns the `Request` to the pool. Do not use it after releasing; doing so may cause data races. ```go title="Signature" func ReleaseRequest(req *Request) ``` ## Method **Method** returns the current HTTP method set for the request. ```go title="Signature" func (r *Request) Method() string ``` ## SetMethod **SetMethod** sets the HTTP method for the `Request` object. Typically, you should use the specialized request methods (`Get`, `Post`, etc.) instead of calling `SetMethod` directly. ```go title="Signature" func (r *Request) SetMethod(method string) *Request ``` ## URL **URL** returns the current URL set in the `Request`. ```go title="Signature" func (r *Request) URL() string ``` ## SetURL **SetURL** sets the URL for the `Request` object. 
```go title="Signature" func (r *Request) SetURL(url string) *Request ``` ## Client **Client** retrieves the `Client` instance associated with the `Request`. ```go title="Signature" func (r *Request) Client() *Client ``` ## SetClient **SetClient** assigns a `Client` to the `Request`. If the provided client is `nil`, it will panic. ```go title="Signature" func (r *Request) SetClient(c *Client) *Request ``` ## Context **Context** returns the `context.Context` of the request, or `context.Background()` if none is set. ```go title="Signature" func (r *Request) Context() context.Context ``` ## SetContext **SetContext** sets the `context.Context` for the request, allowing you to cancel or time out the request. See the [Go blog](https://blog.golang.org/context) and [context](https://pkg.go.dev/context) docs for more details. ```go title="Signature" func (r *Request) SetContext(ctx context.Context) *Request ``` ## Header **Header** returns all values for the specified header key. It searches all header fields stored in the request. ```go title="Signature" func (r *Request) Header(key string) []string ``` ### Headers **Headers** returns an iterator over all headers in the request. Use `maps.Collect()` to transform them into a map if needed. The returned values are valid only until the request is released. Make copies as required. ```go title="Signature" func (r *Request) Headers() iter.Seq2[string, []string] ```
Example ```go title="Example" req := client.AcquireRequest() req.AddHeader("Golang", "Fiber") req.AddHeader("Test", "123456") req.AddHeader("Test", "654321") for k, v := range req.Headers() { fmt.Printf("Header Key: %s, Header Value: %v\n", k, v) } ``` ```sh Header Key: Golang, Header Value: [Fiber] Header Key: Test, Header Value: [123456 654321] ```
Example with maps.Collect() ```go title="Example with maps.Collect()" req := client.AcquireRequest() req.AddHeader("Golang", "Fiber") req.AddHeader("Test", "123456") req.AddHeader("Test", "654321") headers := maps.Collect(req.Headers()) // Collect all headers into a map for k, v := range headers { fmt.Printf("Header Key: %s, Header Value: %v\n", k, v) } ``` ```sh Header Key: Golang, Header Value: [Fiber] Header Key: Test, Header Value: [123456 654321] ```
### AddHeader **AddHeader** adds a single header field and its value to the request. ```go title="Signature" func (r *Request) AddHeader(key, val string) *Request ```
Example ```go title="Example" req := client.AcquireRequest() defer client.ReleaseRequest(req) req.AddHeader("Golang", "Fiber") req.AddHeader("Test", "123456") req.AddHeader("Test", "654321") resp, err := req.Get("https://httpbin.org/headers") if err != nil { panic(err) } fmt.Println(resp.String()) ``` ```json { "headers": { "Golang": "Fiber", "Host": "httpbin.org", "Referer": "", "Test": "123456,654321", "User-Agent": "fiber", "X-Amzn-Trace-Id": "Root=1-664105d2-033cf7173457adb56d9e7193" } } ```
### SetHeader **SetHeader** sets a single header field and its value, overriding any previously set header with the same key. ```go title="Signature" func (r *Request) SetHeader(key, val string) *Request ```
Example ```go title="Example" req := client.AcquireRequest() defer client.ReleaseRequest(req) req.SetHeader("Test", "123456") req.SetHeader("Test", "654321") resp, err := req.Get("https://httpbin.org/headers") if err != nil { panic(err) } fmt.Println(resp.String()) ``` ```json { "headers": { "Host": "httpbin.org", "Referer": "", "Test": "654321", "User-Agent": "fiber", "X-Amzn-Trace-Id": "Root=1-664105e5-5d676ba348450cdb62847f04" } } ```
### AddHeaders **AddHeaders** adds multiple headers at once from a map of string slices. ```go title="Signature" func (r *Request) AddHeaders(h map[string][]string) *Request ``` ### SetHeaders **SetHeaders** sets multiple headers at once from a map of strings, overriding any previously set headers. ```go title="Signature" func (r *Request) SetHeaders(h map[string]string) *Request ``` ## Param **Param** returns all values associated with a given query parameter key. ```go title="Signature" func (r *Request) Param(key string) []string ``` ### Params **Params** returns an iterator over all query parameters. Use `maps.Collect()` if you need them in a map. The returned values are valid only until the request is released. ```go title="Signature" func (r *Request) Params() iter.Seq2[string, []string] ``` ### AddParam **AddParam** adds a single query parameter key-value pair. ```go title="Signature" func (r *Request) AddParam(key, val string) *Request ```
Example ```go title="Example" req := client.AcquireRequest() defer client.ReleaseRequest(req) req.AddParam("name", "john") req.AddParam("hobbies", "football") req.AddParam("hobbies", "basketball") resp, err := req.Get("https://httpbin.org/response-headers") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ``` ```json { "Content-Length": "145", "Content-Type": "application/json", "hobbies": [ "football", "basketball" ], "name": "john" } ```
### SetParam **SetParam** sets a single query parameter key-value pair, overriding any previously set values for that key. ```go title="Signature" func (r *Request) SetParam(key, val string) *Request ``` ### AddParams **AddParams** adds multiple query parameters from a map of string slices. ```go title="Signature" func (r *Request) AddParams(m map[string][]string) *Request ``` ### SetParams **SetParams** sets multiple query parameters from a map of strings, overriding previously set values. ```go title="Signature" func (r *Request) SetParams(m map[string]string) *Request ``` ### SetParamsWithStruct **SetParamsWithStruct** sets multiple query parameters from a struct. Nested structs are not supported. ```go title="Signature" func (r *Request) SetParamsWithStruct(v any) *Request ```
Example ```go title="Example" req := client.AcquireRequest() defer client.ReleaseRequest(req) req.SetParamsWithStruct(struct { Name string `json:"name"` Hobbies []string `json:"hobbies"` }{ Name: "John Doe", Hobbies: []string{ "Football", "Basketball", }, }) resp, err := req.Get("https://httpbin.org/response-headers") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ``` ```json { "Content-Length": "147", "Content-Type": "application/json", "Hobbies": [ "Football", "Basketball" ], "Name": "John Doe" } ```
### DelParams **DelParams** removes one or more query parameters by their keys. ```go title="Signature" func (r *Request) DelParams(key ...string) *Request ``` ## UserAgent **UserAgent** returns the user agent currently set in the request. ```go title="Signature" func (r *Request) UserAgent() string ``` ## SetUserAgent **SetUserAgent** sets the user agent header for the request, overriding the one set at the client level if any. ```go title="Signature" func (r *Request) SetUserAgent(ua string) *Request ``` ## Boundary **Boundary** returns the multipart boundary used by the request. ```go title="Signature" func (r *Request) Boundary() string ``` ## SetBoundary **SetBoundary** sets the multipart boundary for file uploads. ```go title="Signature" func (r *Request) SetBoundary(b string) *Request ``` ## Referer **Referer** returns the Referer header value currently set in the request. ```go title="Signature" func (r *Request) Referer() string ``` ## SetReferer **SetReferer** sets the Referer header for the request, overriding the one set at the client level if any. ```go title="Signature" func (r *Request) SetReferer(referer string) *Request ``` ## Cookie **Cookie** returns the value of the specified cookie. If the cookie does not exist, it returns an empty string. ```go title="Signature" func (r *Request) Cookie(key string) string ``` ### Cookies **Cookies** returns an iterator over all cookies set in the request. Use `maps.Collect()` to gather them into a map. ```go title="Signature" func (r *Request) Cookies() iter.Seq2[string, string] ``` ### SetCookie **SetCookie** sets a single cookie key-value pair, overriding any previously set cookie with the same key. ```go title="Signature" func (r *Request) SetCookie(key, val string) *Request ``` ### SetCookies **SetCookies** sets multiple cookies from a map, overriding previously set values. ```go title="Signature" func (r *Request) SetCookies(m map[string]string) *Request ```
Example ```go title="Example" req := client.AcquireRequest() defer client.ReleaseRequest(req) req.SetCookies(map[string]string{ "cookie1": "value1", "cookie2": "value2", }) resp, err := req.Get("https://httpbin.org/cookies") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ``` ```json { "cookies": { "cookie1": "value1", "cookie2": "value2" } } ```
### SetCookiesWithStruct **SetCookiesWithStruct** sets multiple cookies from a struct. ```go title="Signature" func (r *Request) SetCookiesWithStruct(v any) *Request ``` ### DelCookies **DelCookies** removes one or more cookies by their keys. ```go title="Signature" func (r *Request) DelCookies(key ...string) *Request ``` ## PathParam **PathParam** returns the value of a named path parameter. If not found, returns an empty string. ```go title="Signature" func (r *Request) PathParam(key string) string ``` ### PathParams **PathParams** returns an iterator over all path parameters in the request. Use `maps.Collect()` to convert them into a map. ```go title="Signature" func (r *Request) PathParams() iter.Seq2[string, string] ``` ### SetPathParam **SetPathParam** sets a single path parameter key-value pair, overriding previously set values. ```go title="Signature" func (r *Request) SetPathParam(key, val string) *Request ```
Example ```go title="Example" req := client.AcquireRequest() defer client.ReleaseRequest(req) req.SetPathParam("base64", "R29maWJlcg==") resp, err := req.Get("https://httpbin.org/base64/:base64") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ``` ```plaintext Gofiber ```
### SetPathParams **SetPathParams** sets multiple path parameters at once, overriding previously set values. ```go title="Signature" func (r *Request) SetPathParams(m map[string]string) *Request ``` ### SetPathParamsWithStruct **SetPathParamsWithStruct** sets multiple path parameters from a struct. ```go title="Signature" func (r *Request) SetPathParamsWithStruct(v any) *Request ``` ### DelPathParams **DelPathParams** deletes one or more path parameters by their keys. ```go title="Signature" func (r *Request) DelPathParams(key ...string) *Request ``` ### ResetPathParams **ResetPathParams** deletes all path parameters. ```go title="Signature" func (r *Request) ResetPathParams() *Request ``` ## SetJSON **SetJSON** sets the request body to a JSON-encoded payload. ```go title="Signature" func (r *Request) SetJSON(v any) *Request ``` ## SetXML **SetXML** sets the request body to an XML-encoded payload. ```go title="Signature" func (r *Request) SetXML(v any) *Request ``` ## SetCBOR **SetCBOR** sets the request body to a CBOR-encoded payload. It automatically sets the `Content-Type` to `application/cbor`. ```go title="Signature" func (r *Request) SetCBOR(v any) *Request ``` ## SetRawBody **SetRawBody** sets the request body to raw bytes. ```go title="Signature" func (r *Request) SetRawBody(v []byte) *Request ``` ## FormData **FormData** returns all values associated with the given form data field. ```go title="Signature" func (r *Request) FormData(key string) []string ``` ### AllFormData **AllFormData** returns an iterator over all form data fields. Use `maps.Collect()` if needed. ```go title="Signature" func (r *Request) AllFormData() iter.Seq2[string, []string] ``` ### AddFormData **AddFormData** adds a single form data key-value pair. ```go title="Signature" func (r *Request) AddFormData(key, val string) *Request ```
Example ```go title="Example" req := client.AcquireRequest() defer client.ReleaseRequest(req) req.AddFormData("points", "80") req.AddFormData("points", "90") req.AddFormData("points", "100") resp, err := req.Post("https://httpbin.org/post") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ``` ```json { "args": {}, "data": "", "files": {}, "form": { "points": [ "80", "90", "100" ] }, // ... } ```
### SetFormData **SetFormData** sets a single form data field, overriding any previously set values. ```go title="Signature" func (r *Request) SetFormData(key, val string) *Request ```
Example ```go title="Example" req := client.AcquireRequest() defer client.ReleaseRequest(req) req.SetFormData("name", "john") req.SetFormData("email", "john@doe.com") resp, err := req.Post("https://httpbin.org/post") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ``` ```json { "args": {}, "data": "", "files": {}, "form": { "email": "john@doe.com", "name": "john" }, // ... } ```
### AddFormDataWithMap **AddFormDataWithMap** adds multiple form data fields and values from a map of string slices. ```go title="Signature" func (r *Request) AddFormDataWithMap(m map[string][]string) *Request ``` ### SetFormDataWithMap **SetFormDataWithMap** sets multiple form data fields from a map of strings. ```go title="Signature" func (r *Request) SetFormDataWithMap(m map[string]string) *Request ``` ### SetFormDataWithStruct **SetFormDataWithStruct** sets multiple form data fields from a struct. ```go title="Signature" func (r *Request) SetFormDataWithStruct(v any) *Request ``` ### DelFormData **DelFormData** deletes one or more form data fields by their keys. ```go title="Signature" func (r *Request) DelFormData(key ...string) *Request ``` ## File **File** returns a file from the request by its name. If no name was provided, it attempts to match by path. ```go title="Signature" func (r *Request) File(name string) *File ``` ### Files **Files** returns all files in the request as a slice. The returned slice is valid only until the request is released. ```go title="Signature" func (r *Request) Files() []*File ``` ### FileByPath **FileByPath** returns a file from the request by its file path. ```go title="Signature" func (r *Request) FileByPath(path string) *File ``` ### AddFile **AddFile** adds a single file to the request from a file path. ```go title="Signature" func (r *Request) AddFile(path string) *Request ```
Example ```go title="Example" req := client.AcquireRequest() defer client.ReleaseRequest(req) req.AddFile("test.txt") resp, err := req.Post("https://httpbin.org/post") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ``` ```json { "args": {}, "data": "", "files": { "file1": "This is an empty file!\n" }, "form": {}, // ... } ```
### AddFileWithReader **AddFileWithReader** adds a single file to the request from an `io.ReadCloser`. ```go title="Signature" func (r *Request) AddFileWithReader(name string, reader io.ReadCloser) *Request ```
Example ```go title="Example" req := client.AcquireRequest() defer client.ReleaseRequest(req) buf := bytes.NewBuffer([]byte("Hello, World!")) req.AddFileWithReader("test.txt", io.NopCloser(buf)) resp, err := req.Post("https://httpbin.org/post") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ``` ```json { "args": {}, "data": "", "files": { "file1": "Hello, World!" }, "form": {}, // ... } ```
### AddFiles **AddFiles** adds multiple files to the request at once. ```go title="Signature" func (r *Request) AddFiles(files ...*File) *Request ``` ## Timeout **Timeout** returns the timeout duration set in the request. ```go title="Signature" func (r *Request) Timeout() time.Duration ``` ## SetTimeout **SetTimeout** sets a timeout for the request, overriding any timeout set at the client level. ```go title="Signature" func (r *Request) SetTimeout(t time.Duration) *Request ```
Example 1 ```go title="Example 1" req := client.AcquireRequest() defer client.ReleaseRequest(req) req.SetTimeout(5 * time.Second) resp, err := req.Get("https://httpbin.org/delay/4") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ``` ```json { "args": {}, "data": "", "files": {}, "form": {}, // ... } ```
Example 2 ```go title="Example 2" req := client.AcquireRequest() defer client.ReleaseRequest(req) req.SetTimeout(5 * time.Second) resp, err := req.Get("https://httpbin.org/delay/6") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ``` ```shell panic: timeout or cancel goroutine 1 [running]: main.main() main.go:18 +0xeb exit status 2 ```
## MaxRedirects **MaxRedirects** returns the maximum number of redirects allowed for the request. ```go title="Signature" func (r *Request) MaxRedirects() int ``` ## SetMaxRedirects **SetMaxRedirects** sets the maximum number of redirects for the request, overriding the client's setting. ```go title="Signature" func (r *Request) SetMaxRedirects(count int) *Request ``` ## Send **Send** executes the HTTP request and returns a `Response`. ```go title="Signature" func (r *Request) Send() (*Response, error) ``` ## Reset **Reset** clears the `Request` object, making it ready for reuse. This is used by `ReleaseRequest`. ```go title="Signature" func (r *Request) Reset() ``` ## Header **Header** is a wrapper around `fasthttp.RequestHeader`, storing headers for both the client and request. ```go type Header struct { *fasthttp.RequestHeader } ``` ### PeekMultiple **PeekMultiple** returns multiple values associated with the same header key. ```go title="Signature" func (h *Header) PeekMultiple(key string) []string ``` ### AddHeaders **AddHeaders** adds multiple headers from a map of string slices. ```go title="Signature" func (h *Header) AddHeaders(r map[string][]string) ``` ### SetHeaders **SetHeaders** sets multiple headers from a map of strings, overriding previously set headers. ```go title="Signature" func (h *Header) SetHeaders(r map[string]string) ``` ## QueryParam **QueryParam** is a wrapper around `fasthttp.Args`, storing query parameters. ```go type QueryParam struct { *fasthttp.Args } ``` ### Keys **Keys** returns all keys in the query parameters. ```go title="Signature" func (p *QueryParam) Keys() []string ``` ### AddParams **AddParams** adds multiple query parameters from a map of string slices. ```go title="Signature" func (p *QueryParam) AddParams(r map[string][]string) ``` ### SetParams **SetParams** sets multiple query parameters from a map of strings, overriding previously set values. 
```go title="Signature" func (p *QueryParam) SetParams(r map[string]string) ``` ### SetParamsWithStruct **SetParamsWithStruct** sets multiple query parameters from a struct. Nested structs are not supported. ```go title="Signature" func (p *QueryParam) SetParamsWithStruct(v any) ``` ## Cookie **Cookie** is a map that stores cookies. ```go type Cookie map[string]string ``` ### Add **Add** adds a cookie key-value pair. ```go title="Signature" func (c Cookie) Add(key, val string) ``` ### Del **Del** removes a cookie by its key. ```go title="Signature" func (c Cookie) Del(key string) ``` ### SetCookie **SetCookie** sets a single cookie key-value pair, overriding previously set values. ```go title="Signature" func (c Cookie) SetCookie(key, val string) ``` ### SetCookies **SetCookies** sets multiple cookies from a map of strings. ```go title="Signature" func (c Cookie) SetCookies(m map[string]string) ``` ### SetCookiesWithStruct **SetCookiesWithStruct** sets multiple cookies from a struct. ```go title="Signature" func (c Cookie) SetCookiesWithStruct(v any) ``` ### DelCookies **DelCookies** deletes one or more cookies by their keys. ```go title="Signature" func (c Cookie) DelCookies(key ...string) ``` ### All **All** returns an iterator over all cookies. The key and value returned should not be retained after the loop ends. ```go title="Signature" func (c Cookie) All() iter.Seq2[string, string] ``` ### Reset **Reset** clears all cookies. ```go title="Signature" func (c Cookie) Reset() ``` ## PathParam **PathParam** is a map that stores path parameters. ```go type PathParam map[string]string ``` ### Add **Add** adds a path parameter key-value pair. ```go title="Signature" func (p PathParam) Add(key, val string) ``` ### Del **Del** removes a path parameter by its key. ```go title="Signature" func (p PathParam) Del(key string) ``` ### SetParam **SetParam** sets a single path parameter key-value pair, overriding previously set values. 
```go title="Signature" func (p PathParam) SetParam(key, val string) ``` ### SetParams **SetParams** sets multiple path parameters from a map of strings. ```go title="Signature" func (p PathParam) SetParams(m map[string]string) ``` ### SetParamsWithStruct **SetParamsWithStruct** sets multiple path parameters from a struct. ```go title="Signature" func (p PathParam) SetParamsWithStruct(v any) ``` ### DelParams **DelParams** deletes one or more path parameters by their keys. ```go title="Signature" func (p PathParam) DelParams(key ...string) ``` ### All **All** returns an iterator over all path parameters. The key and value returned should not be retained after the loop ends. ```go title="Signature" func (p PathParam) All() iter.Seq2[string, string] ``` ### Reset **Reset** clears all path parameters. ```go title="Signature" func (p PathParam) Reset() ``` ## FormData **FormData** is a wrapper around `fasthttp.Args`, used to handle URL-encoded and form-data (multipart) request bodies. ```go type FormData struct { *fasthttp.Args } ``` ### Keys **Keys** returns all form data keys. ```go title="Signature" func (f *FormData) Keys() []string ``` ### Add **Add** adds a single form field key-value pair. ```go title="Signature" func (f *FormData) Add(key, val string) ``` ### Set **Set** sets a single form field key-value pair, overriding any previously set values. ```go title="Signature" func (f *FormData) Set(key, val string) ``` ### AddWithMap **AddWithMap** adds multiple form fields from a map of string slices. ```go title="Signature" func (f *FormData) AddWithMap(m map[string][]string) ``` ### SetWithMap **SetWithMap** sets multiple form fields from a map of strings. ```go title="Signature" func (f *FormData) SetWithMap(m map[string]string) ``` ### SetWithStruct **SetWithStruct** sets multiple form fields from a struct. ```go title="Signature" func (f *FormData) SetWithStruct(v any) ``` ### DelData **DelData** deletes one or more form fields by their keys. 
```go title="Signature" func (f *FormData) DelData(key ...string) ``` ### Reset **Reset** clears all form data fields. ```go title="Signature" func (f *FormData) Reset() ``` ## File **File** represents a file to be uploaded. It can be specified by name, path, or an `io.ReadCloser`. ```go type File struct { name string fieldName string path string reader io.ReadCloser } ``` ### AcquireFile **AcquireFile** returns a `File` from the pool and applies any provided `SetFileFunc` functions to it. Release it with `ReleaseFile` when done. ```go title="Signature" func AcquireFile(setter ...SetFileFunc) *File ``` ### ReleaseFile **ReleaseFile** returns the `File` to the pool. Do not use the file afterward. ```go title="Signature" func ReleaseFile(f *File) ``` ### SetName **SetName** sets the file's name. ```go title="Signature" func (f *File) SetName(n string) ``` ### SetFieldName **SetFieldName** sets the field name of the file in the multipart form. ```go title="Signature" func (f *File) SetFieldName(n string) ``` ### SetPath **SetPath** sets the file's path. ```go title="Signature" func (f *File) SetPath(p string) ``` ### SetReader **SetReader** sets the file's `io.ReadCloser`. The reader is closed automatically when the request body is parsed. ```go title="Signature" func (f *File) SetReader(r io.ReadCloser) ``` ### Reset **Reset** clears the file's fields. ```go title="Signature" func (f *File) Reset() ``` ================================================ FILE: docs/client/response.md ================================================ --- id: response title: 📥 Response description: >- Response methods of Gofiber HTTP client. sidebar_position: 3 --- The `Response` struct in Fiber's HTTP client represents the server's reply and exposes: - **Status Code**: The HTTP status code returned by the server (e.g., `200 OK`, `404 Not Found`). - **Headers**: All HTTP headers returned by the server, providing additional response-related information. 
- **Body**: The response body content, which can be JSON, XML, plain text, or other formats. - **Cookies**: Any cookies the server sent along with the response. It makes it easy to inspect and handle data returned by the server. ```go type Response struct { client *Client request *Request cookie []*fasthttp.Cookie RawResponse *fasthttp.Response } ``` ## AcquireResponse **AcquireResponse** returns a new pooled `Response`. Call `ReleaseResponse` when you're done to return it to the pool and limit allocations. ```go title="Signature" func AcquireResponse() *Response ``` ## ReleaseResponse **ReleaseResponse** puts the `Response` back into the pool. Do not use it after releasing; doing so can trigger data races. ```go title="Signature" func ReleaseResponse(resp *Response) ``` ## Status **Status** returns the HTTP status message (e.g., `OK`, `Not Found`) associated with the response. ```go title="Signature" func (r *Response) Status() string ``` ## StatusCode **StatusCode** returns the numeric HTTP status code of the response. ```go title="Signature" func (r *Response) StatusCode() int ``` ## Protocol **Protocol** returns the HTTP protocol used (e.g., `HTTP/1.1`, `HTTP/2`) for the response. ```go title="Signature" func (r *Response) Protocol() string ```
Example ```go title="Example" resp, err := client.Get("https://httpbin.org/get") if err != nil { panic(err) } fmt.Println(resp.Protocol()) ``` **Output:** ```text HTTP/1.1 ```
## Header **Header** retrieves the value of a specific response header by key. If multiple values exist for the same header, this returns the first one. ```go title="Signature" func (r *Response) Header(key string) string ``` ## Headers **Headers** returns an iterator over all response headers. Use `maps.Collect()` to convert them into a map if desired. The returned values are only valid until the response is released, so make copies if needed. ```go title="Signature" func (r *Response) Headers() iter.Seq2[string, []string] ```
Example ```go title="Example" resp, err := client.Get("https://httpbin.org/get") if err != nil { panic(err) } for key, values := range resp.Headers() { fmt.Printf("%s => %s\n", key, strings.Join(values, ", ")) } ``` **Output:** ```text Date => Wed, 04 Dec 2024 15:28:29 GMT Connection => keep-alive Access-Control-Allow-Origin => * Access-Control-Allow-Credentials => true ```
Example with maps.Collect() ```go title="Example with maps.Collect()" resp, err := client.Get("https://httpbin.org/get") if err != nil { panic(err) } headers := maps.Collect(resp.Headers()) for key, values := range headers { fmt.Printf("%s => %s\n", key, strings.Join(values, ", ")) } ``` **Output:** ```text Date => Wed, 04 Dec 2024 15:28:29 GMT Connection => keep-alive Access-Control-Allow-Origin => * Access-Control-Allow-Credentials => true ```
## Cookies **Cookies** returns a slice of all cookies set by the server in this response. The slice is only valid until the response is released. ```go title="Signature" func (r *Response) Cookies() []*fasthttp.Cookie ```
Example ```go title="Example" resp, err := client.Get("https://httpbin.org/cookies/set/go/fiber") if err != nil { panic(err) } cookies := resp.Cookies() for _, cookie := range cookies { fmt.Printf("%s => %s\n", string(cookie.Key()), string(cookie.Value())) } ``` **Output:** ```text go => fiber ```
## Body **Body** returns the raw response body as a byte slice. ```go title="Signature" func (r *Response) Body() []byte ``` ## BodyStream **BodyStream** returns the response body as an `io.Reader`, allowing incremental reading without loading the entire body into memory. This is particularly useful when `Client.SetStreamResponseBody(true)` is enabled. When streaming is enabled, the underlying stream from fasthttp is returned directly. When streaming is not enabled, a `bytes.Reader` wrapping the body is returned as a fallback. :::note When using `BodyStream()`, the response body is consumed as you read. Calling `Body()` afterward may return an empty slice if the stream has been fully read. ::: ```go title="Signature" func (r *Response) BodyStream() io.Reader ```
Example ```go title="Example" cc := client.New() cc.SetStreamResponseBody(true) resp, err := cc.Get("https://httpbin.org/bytes/1024") if err != nil { panic(err) } defer resp.Close() buf := make([]byte, 256) total, err := io.CopyBuffer(io.Discard, resp.BodyStream(), buf) if err != nil { panic(err) } fmt.Printf("Read %d bytes\n", total) ``` **Output:** ```text Read 1024 bytes ```
## IsStreaming **IsStreaming** returns `true` if the response body is being streamed (i.e., when `Client.SetStreamResponseBody(true)` was set and the underlying transport provided a stream). ```go title="Signature" func (r *Response) IsStreaming() bool ```
Example ```go title="Example" cc := client.New() cc.SetStreamResponseBody(true) resp, err := cc.Get("https://httpbin.org/get") if err != nil { panic(err) } defer resp.Close() if resp.IsStreaming() { fmt.Println("Response is streaming") // Use resp.BodyStream() to read incrementally } else { fmt.Println("Response is buffered") // Use resp.Body() for direct access } ```
## String **String** returns the response body as a trimmed string. ```go title="Signature" func (r *Response) String() string ``` ## JSON **JSON** unmarshals the response body into the provided variable `v` using JSON decoding. `v` should be a pointer to a struct or another type that JSON can be unmarshaled into. ```go title="Signature" func (r *Response) JSON(v any) error ```
Example ```go title="Example" type Body struct { Slideshow struct { Author string `json:"author"` Date string `json:"date"` Title string `json:"title"` } `json:"slideshow"` } var out Body resp, err := client.Get("https://httpbin.org/json") if err != nil { panic(err) } if err = resp.JSON(&out); err != nil { panic(err) } fmt.Printf("%+v\n", out) ``` **Output:** ```text {Slideshow:{Author:Yours Truly Date:date of publication Title:Sample Slide Show}} ```
## XML **XML** unmarshals the response body into the provided variable `v` using XML decoding. ```go title="Signature" func (r *Response) XML(v any) error ``` ## CBOR **CBOR** unmarshals the response body into `v` using CBOR decoding. ```go title="Signature" func (r *Response) CBOR(v any) error ``` ## Save **Save** writes the response body to a file or an `io.Writer`. If `v` is a string, it interprets it as a file path, creates the file (and directories if needed), and writes the response to it. If `v` is an `io.Writer`, it writes directly to it. ```go title="Signature" func (r *Response) Save(v any) error ``` ## Reset **Reset** clears the `Response` object, making it ready for reuse by `ReleaseResponse`. ```go title="Signature" func (r *Response) Reset() ``` ## Close **Close** releases both the associated `Request` and `Response` objects back to their pools. :::warning After calling `Close`, any attempt to use the request or response may result in data races or undefined behavior. Ensure all processing is complete before closing. ::: ```go title="Signature" func (r *Response) Close() ``` ================================================ FILE: docs/client/rest.md ================================================ --- id: rest title: 🖥️ REST description: >- HTTP client for Fiber. sidebar_position: 1 toc_max_heading_level: 5 --- The Fiber Client is a high-performance HTTP client built on FastHTTP. It handles both internal service calls and external requests with minimal overhead. ## Features - **Lightweight and fast**: built on FastHTTP for minimal overhead. - **Flexible configuration**: set global defaults like timeouts or headers and override them per request. - **Connection pooling**: reuses persistent connections instead of opening new ones. - **Timeouts and retries**: supports per-request deadlines and retry policies for transient errors.
## Usage Create a client with any required configuration, then send requests: ```go package main import ( "fmt" "time" "github.com/gofiber/fiber/v3/client" ) func main() { cc := client.New() cc.SetTimeout(10 * time.Second) // Send a GET request resp, err := cc.Get("https://httpbin.org/get") if err != nil { panic(err) } fmt.Printf("Status: %d\n", resp.StatusCode()) fmt.Printf("Body: %s\n", string(resp.Body())) } ``` See [examples](examples.md) for more detailed usage. ```go type Client struct { mu sync.RWMutex fasthttp *fasthttp.Client baseURL string userAgent string referer string header *Header params *QueryParam cookies *Cookie path *PathParam debug bool timeout time.Duration // user-defined request hooks userRequestHooks []RequestHook // client package-defined request hooks builtinRequestHooks []RequestHook // user-defined response hooks userResponseHooks []ResponseHook // client package-defined response hooks builtinResponseHooks []ResponseHook jsonMarshal utils.JSONMarshal jsonUnmarshal utils.JSONUnmarshal xmlMarshal utils.XMLMarshal xmlUnmarshal utils.XMLUnmarshal cborMarshal utils.CBORMarshal cborUnmarshal utils.CBORUnmarshal cookieJar *CookieJar // proxy proxyURL string // retry retryConfig *RetryConfig // logger logger log.CommonLogger } ``` ### New **New** creates and returns a new Client object. ```go title="Signature" func New() *Client ``` ### NewWithClient **NewWithClient** creates and returns a new Client object from an existing `fasthttp.Client`. ```go title="Signature" func NewWithClient(c *fasthttp.Client) *Client ``` ## REST Methods These helpers mirror axios-style method names and send HTTP requests using the configured client: ### Get Sends a GET request. ```go title="Signature" func (c *Client) Get(url string, cfg ...Config) (*Response, error) ``` ### Post Sends a POST request. ```go title="Signature" func (c *Client) Post(url string, cfg ...Config) (*Response, error) ``` ### Put Sends a PUT request. 
```go title="Signature" func (c *Client) Put(url string, cfg ...Config) (*Response, error) ``` ### Patch Sends a PATCH request. ```go title="Signature" func (c *Client) Patch(url string, cfg ...Config) (*Response, error) ``` ### Delete Sends a DELETE request. ```go title="Signature" func (c *Client) Delete(url string, cfg ...Config) (*Response, error) ``` ### Head Sends a HEAD request. ```go title="Signature" func (c *Client) Head(url string, cfg ...Config) (*Response, error) ``` ### Options Sends an OPTIONS request. ```go title="Signature" func (c *Client) Options(url string, cfg ...Config) (*Response, error) ``` ### Custom Sends a request with any HTTP method. ```go title="Signature" func (c *Client) Custom(url, method string, cfg ...Config) (*Response, error) ``` ## Request Configuration The `Config` type holds per-request parameters. JSON is used to serialize the body by default. If multiple body sources are set, precedence is: 1. Body 2. FormData 3. File ```go type Config struct { Ctx context.Context UserAgent string Referer string Header map[string]string Param map[string]string Cookie map[string]string PathParam map[string]string Timeout time.Duration MaxRedirects int Body any FormData map[string]string File []*File } ``` ## R **R** gets a `Request` object from the pool. Call `ReleaseRequest` when finished. ```go title="Signature" func (c *Client) R() *Request ``` ## Hooks Hooks allow you to add custom logic before a request is sent or after a response is received. ### RequestHook **RequestHook** returns user-defined request hooks. ```go title="Signature" func (c *Client) RequestHook() []RequestHook ``` ### ResponseHook **ResponseHook** returns user-defined response hooks. ```go title="Signature" func (c *Client) ResponseHook() []ResponseHook ``` ### AddRequestHook Adds one or more user-defined request hooks. ```go title="Signature" func (c *Client) AddRequestHook(h ...RequestHook) *Client ``` ### AddResponseHook Adds one or more user-defined response hooks. 
```go title="Signature" func (c *Client) AddResponseHook(h ...ResponseHook) *Client ``` ## JSON ### JSONMarshal Returns the JSON marshaler function used by the client. ```go title="Signature" func (c *Client) JSONMarshal() utils.JSONMarshal ``` ### JSONUnmarshal Returns the JSON unmarshaler function used by the client. ```go title="Signature" func (c *Client) JSONUnmarshal() utils.JSONUnmarshal ``` ### SetJSONMarshal Sets a custom JSON marshaler. ```go title="Signature" func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client ``` ### SetJSONUnmarshal Sets a custom JSON unmarshaler. ```go title="Signature" func (c *Client) SetJSONUnmarshal(f utils.JSONUnmarshal) *Client ``` ## XML ### XMLMarshal Returns the XML marshaler function used by the client. ```go title="Signature" func (c *Client) XMLMarshal() utils.XMLMarshal ``` ### XMLUnmarshal Returns the XML unmarshaler function used by the client. ```go title="Signature" func (c *Client) XMLUnmarshal() utils.XMLUnmarshal ``` ### SetXMLMarshal Sets a custom XML marshaler. ```go title="Signature" func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client ``` ### SetXMLUnmarshal Sets a custom XML unmarshaler. ```go title="Signature" func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client ``` ## CBOR ### CBORMarshal Returns the CBOR marshaler function used by the client. ```go title="Signature" func (c *Client) CBORMarshal() utils.CBORMarshal ``` ### CBORUnmarshal Returns the CBOR unmarshaler function used by the client. ```go title="Signature" func (c *Client) CBORUnmarshal() utils.CBORUnmarshal ``` ### SetCBORMarshal Sets a custom CBOR marshaler. ```go title="Signature" func (c *Client) SetCBORMarshal(f utils.CBORMarshal) *Client ``` ### SetCBORUnmarshal Sets a custom CBOR unmarshaler. ```go title="Signature" func (c *Client) SetCBORUnmarshal(f utils.CBORUnmarshal) *Client ``` ## TLS ### TLSConfig Returns the client's TLS configuration. 
If none is set, it initializes a new configuration with `MinVersion` defaulting to TLS 1.2. ```go title="Signature" func (c *Client) TLSConfig() *tls.Config ``` ### SetTLSConfig Sets the TLS configuration for the client. ```go title="Signature" func (c *Client) SetTLSConfig(config *tls.Config) *Client ``` ### SetCertificates Adds client certificates to the TLS configuration. ```go title="Signature" func (c *Client) SetCertificates(certs ...tls.Certificate) *Client ``` ### SetRootCertificate Adds one or more root certificates to the client's trust store. ```go title="Signature" func (c *Client) SetRootCertificate(path string) *Client ``` ### SetRootCertificateFromString Adds one or more root certificates from a string. ```go title="Signature" func (c *Client) SetRootCertificateFromString(pem string) *Client ``` ## SetProxyURL Sets a proxy URL for the client. All subsequent requests will use this proxy. ```go title="Signature" func (c *Client) SetProxyURL(proxyURL string) error ``` ## Response Streaming ### StreamResponseBody Returns whether response body streaming is enabled. When enabled, the response body is not fully loaded into memory and can be read as a stream using `Response.BodyStream()`. This is useful for handling large responses or server-sent events (SSE). ```go title="Signature" func (c *Client) StreamResponseBody() bool ``` ### SetStreamResponseBody Enables or disables response body streaming. When enabled, responses can be consumed incrementally without loading the entire body into memory. 
```go title="Signature" func (c *Client) SetStreamResponseBody(enable bool) *Client ``` **Example:** ```go title="Example" cc := client.New() cc.SetStreamResponseBody(true) resp, err := cc.Get("https://example.com/large-file") if err != nil { panic(err) } defer resp.Close() // Check if response is streaming if resp.IsStreaming() { // Read body incrementally reader := resp.BodyStream() buf := make([]byte, 4096) for { n, err := reader.Read(buf) if n > 0 { // Process chunk... } if err == io.EOF { break } if err != nil { panic(err) } } } else { // Regular body access body := resp.Body() fmt.Println(string(body)) } ``` **Server-Sent Events Example:** ```go title="SSE Example" cc := client.New() cc.SetStreamResponseBody(true) resp, err := cc.Get("https://example.com/events") if err != nil { panic(err) } defer resp.Close() reader := bufio.NewReader(resp.BodyStream()) for { line, err := reader.ReadString('\n') if err == io.EOF { break } if err != nil { panic(err) } fmt.Print(line) // Process SSE event } ``` ## RetryConfig Returns the retry configuration of the client. ```go title="Signature" func (c *Client) RetryConfig() *RetryConfig ``` ## SetRetryConfig Sets the retry configuration for the client. ```go title="Signature" func (c *Client) SetRetryConfig(config *RetryConfig) *Client ``` ## BaseURL ### BaseURL **BaseURL** returns the base URL currently set in the client. ```go title="Signature" func (c *Client) BaseURL() string ``` ### SetBaseURL Sets a base URL prefix for all requests made by the client. ```go title="Signature" func (c *Client) SetBaseURL(url string) *Client ``` **Example:** ```go title="Example" cc := client.New() cc.SetBaseURL("https://httpbin.org/") resp, err := cc.Get("/get") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ```
Click here to see the result ```json { "args": {}, ... } ```
## Headers ### Header Retrieves all values of a header key at the client level. The returned values apply to all requests. ```go title="Signature" func (c *Client) Header(key string) []string ``` ### AddHeader Adds a single header to all requests initiated by this client. ```go title="Signature" func (c *Client) AddHeader(key, val string) *Client ``` ### SetHeader Sets a single header, overriding any existing headers with the same key. ```go title="Signature" func (c *Client) SetHeader(key, val string) *Client ``` ### AddHeaders Adds multiple headers at once; they apply to all future requests from this client. ```go title="Signature" func (c *Client) AddHeaders(h map[string][]string) *Client ``` ### SetHeaders Sets multiple headers at once, overriding previously set headers. ```go title="Signature" func (c *Client) SetHeaders(h map[string]string) *Client ``` ## Query Parameters ### Param Returns the values for a given query parameter key. ```go title="Signature" func (c *Client) Param(key string) []string ``` ### AddParam Adds a single query parameter for all requests. ```go title="Signature" func (c *Client) AddParam(key, val string) *Client ``` ### SetParam Sets a single query parameter, overriding previously set values. ```go title="Signature" func (c *Client) SetParam(key, val string) *Client ``` ### AddParams Adds multiple query parameters from a map of string slices. ```go title="Signature" func (c *Client) AddParams(m map[string][]string) *Client ``` ### SetParams Sets multiple query parameters from a map, overriding previously set values. ```go title="Signature" func (c *Client) SetParams(m map[string]string) *Client ``` ### SetParamsWithStruct Sets multiple query parameters from a struct. Nested structs are not currently supported. ```go title="Signature" func (c *Client) SetParamsWithStruct(v any) *Client ``` ### DelParams Deletes one or more query parameters.
```go title="Signature" func (c *Client) DelParams(key ...string) *Client ``` ## UserAgent & Referer ### SetUserAgent Sets the user agent header for all requests. ```go title="Signature" func (c *Client) SetUserAgent(ua string) *Client ``` ### SetReferer Sets the referer header for all requests. ```go title="Signature" func (c *Client) SetReferer(r string) *Client ``` ## Path Parameters ### PathParam Returns the value of a named path parameter, if set. ```go title="Signature" func (c *Client) PathParam(key string) string ``` ### SetPathParam Sets a single path parameter. ```go title="Signature" func (c *Client) SetPathParam(key, val string) *Client ``` ### SetPathParams Sets multiple path parameters at once. ```go title="Signature" func (c *Client) SetPathParams(m map[string]string) *Client ``` ### SetPathParamsWithStruct Sets multiple path parameters from a struct. ```go title="Signature" func (c *Client) SetPathParamsWithStruct(v any) *Client ``` ### DelPathParams Deletes one or more path parameters. ```go title="Signature" func (c *Client) DelPathParams(key ...string) *Client ``` ## Cookies ### Cookie Returns the value of a named cookie if set at the client level. ```go title="Signature" func (c *Client) Cookie(key string) string ``` ### SetCookie Sets a single cookie for all requests. ```go title="Signature" func (c *Client) SetCookie(key, val string) *Client ``` **Example:** ```go title="Example" cc := client.New() cc.SetCookie("john", "doe") resp, err := cc.Get("https://httpbin.org/cookies") if err != nil { panic(err) } fmt.Println(string(resp.Body())) ```
Click here to see the result ```json { "cookies": { "john": "doe" } } ```
### SetCookies Sets multiple cookies at once. ```go title="Signature" func (c *Client) SetCookies(m map[string]string) *Client ``` ### SetCookiesWithStruct Sets multiple cookies from a struct. ```go title="Signature" func (c *Client) SetCookiesWithStruct(v any) *Client ``` ### DelCookies Deletes one or more cookies. ```go title="Signature" func (c *Client) DelCookies(key ...string) *Client ``` ## Timeout ### SetTimeout Sets a default timeout for all requests, which can be overridden per request. ```go title="Signature" func (c *Client) SetTimeout(t time.Duration) *Client ``` ## Debugging ### Debug Enables debug-level logging output. ```go title="Signature" func (c *Client) Debug() *Client ``` ### DisableDebug Disables debug-level logging output. ```go title="Signature" func (c *Client) DisableDebug() *Client ``` ## Cookie Jar ### SetCookieJar Assigns a cookie jar to the client to store and manage cookies across requests. ```go title="Signature" func (c *Client) SetCookieJar(cookieJar *CookieJar) *Client ``` ## Dial & Logger ### SetDial Sets a custom dial function. ```go title="Signature" func (c *Client) SetDial(dial fasthttp.DialFunc) *Client ``` ### SetLogger Sets the logger instance used by the client. ```go title="Signature" func (c *Client) SetLogger(logger log.CommonLogger) *Client ``` ### Logger Returns the current logger instance. ```go title="Signature" func (c *Client) Logger() log.CommonLogger ``` ## Reset ### Reset Clears and resets the client to its default state and reinstates the default `fasthttp.Client` transport. ```go title="Signature" func (c *Client) Reset() ``` ## Default Client Fiber provides a default client (created with `New()`). You can configure it or replace it as needed. ### C **C** returns the default client. ```go title="Signature" func C() *Client ``` ### Get Get is a convenience method that sends a GET request using the `defaultClient`. 
```go title="Signature" func Get(url string, cfg ...Config) (*Response, error) ``` ### Post Post is a convenience method that sends a POST request using the `defaultClient`. ```go title="Signature" func Post(url string, cfg ...Config) (*Response, error) ``` ### Put Put is a convenience method that sends a PUT request using the `defaultClient`. ```go title="Signature" func Put(url string, cfg ...Config) (*Response, error) ``` ### Patch Patch is a convenience method that sends a PATCH request using the `defaultClient`. ```go title="Signature" func Patch(url string, cfg ...Config) (*Response, error) ``` ### Delete Delete is a convenience method that sends a DELETE request using the `defaultClient`. ```go title="Signature" func Delete(url string, cfg ...Config) (*Response, error) ``` ### Head Head is a convenience method that sends a HEAD request using the `defaultClient`. ```go title="Signature" func Head(url string, cfg ...Config) (*Response, error) ``` ### Options Options is a convenience method that sends an OPTIONS request using the `defaultClient`. ```go title="Signature" func Options(url string, cfg ...Config) (*Response, error) ``` ### Replace **Replace** replaces the default client with a new one. It returns a function that can restore the old client. :::caution Do not modify the default client concurrently. ::: ```go title="Signature" func Replace(c *Client) func() ``` ================================================ FILE: docs/extra/_category_.json ================================================ { "label": "\uD83E\uDDE9 Extra", "position": 8, "link": { "type": "generated-index", "description": "Extra contents for Fiber." } } ================================================ FILE: docs/extra/benchmarks.md ================================================ --- id: benchmarks title: 📊 Benchmarks description: >- These benchmarks aim to compare the performance of Fiber and other web frameworks.
sidebar_position: 2 --- ## TechEmpower [TechEmpower](https://www.techempower.com/benchmarks/#section=test&runid=1d5bfc8a-5c4a-4fb2-a792-ad967f1eb138) provides a performance comparison of many web application frameworks that execute fundamental tasks such as JSON serialization, database access, and server-side template rendering. Each framework runs under a realistic production configuration. Results are recorded on both cloud instances and physical hardware. The test implementations are community-contributed and live in the [FrameworkBenchmarks repository](https://github.com/TechEmpower/FrameworkBenchmarks). * Fiber `v3.0.0` * 56 Cores Intel(R) Xeon(R) Gold 6330 CPU @ 2.00GHz (Three homogeneous ProLiant DL360 Gen10 Plus) * 64GB RAM * Enterprise SSD * Ubuntu * Mellanox Technologies MT28908 Family ConnectX-6 40Gbps Ethernet ### Plaintext The Plaintext test measures basic request routing and demonstrates the capacity of high-performance platforms. Requests are pipelined, and the tiny response body demands high throughput to saturate the benchmark's network link. See [Plaintext requirements](https://github.com/TechEmpower/FrameworkBenchmarks/wiki/Project-Information-Framework-Tests-Overview#plaintext). **Fiber** - **11,987,976** responses per second with an average latency of **1.0** ms. **Express** - **1,204,969** responses per second with an average latency of **8.8** ms. ![](/img/v3/plaintext.png) ![Fiber vs Express](/img/v3/plaintext_express.png) ### Data Updates **Fiber** handled **29,984** responses per second with an average latency of **16.9** ms. **Express** handled **54,887** responses per second with an average latency of **9.2** ms. ![](/img/v3/data_updates.png) ![Fiber vs Express](/img/v3/data_updates_express.png) ### Multiple Queries **Fiber** handled **54,002** responses per second with an average latency of **9.4** ms. **Express** handled **85,011** responses per second with an average latency of **6.0** ms.
![](/img/v3/multiple_queries.png) ![Fiber vs Express](/img/v3/multiple_queries_express.png) ### Single Query **Fiber** handled **953,016** responses per second with an average latency of **0.6** ms. **Express** handled **441,543** responses per second with an average latency of **1.3** ms. ![](/img/v3/single_query.png) ![Fiber vs Express](/img/v3/single_query_express.png) ### JSON Serialization **Fiber** handled **2,363,294** responses per second with an average latency of **0.2** ms. **Express** handled **949,717** responses per second with an average latency of **0.5** ms. ![](/img/v3/json.png) ![Fiber vs Express](/img/v3/json_express.png) ================================================ FILE: docs/extra/faq.md ================================================ --- id: faq title: 🤔 FAQ description: >- Frequently asked questions. Open an issue if you have another question to add. sidebar_position: 1 --- ## How should I structure my application? There's no single answer; the ideal structure depends on your application's scale and team. Fiber makes no assumptions about project layout. Routes and other application logic can live in any files or directories. For inspiration, see: * [gofiber/boilerplate](https://github.com/gofiber/boilerplate) * [thomasvvugt/fiber-boilerplate](https://github.com/thomasvvugt/fiber-boilerplate) * [YouTube - Building a REST API using Gorm and Fiber](https://www.youtube.com/watch?v=Iq2qT0fRhAA) * [embedmode/fiberseed](https://github.com/embedmode/fiberseed) ## How do I handle custom 404 responses? If you're using v2.32.0 or later, implement a custom error handler (see [How do I set up an error handler?](#how-do-i-set-up-an-error-handler) below) or read more at [Error Handling](../guide/error-handling.md#custom-error-handler). If you're using v2.31.0 or earlier, the error handler will not capture 404 errors.
Instead, add a middleware function at the very bottom of the stack \(below all other functions\) to handle a 404 response: ```go title="Example" app.Use(func(c fiber.Ctx) error { return c.Status(fiber.StatusNotFound).SendString("Sorry can't find that!") }) ``` ## How can I use live reload? [Air](https://github.com/air-verse/air) automatically restarts your Go application when source files change, speeding development. To use Air in a Fiber project, follow these steps: * Install Air by downloading the appropriate binary for your operating system from the GitHub release page or by building the tool from source. * Create a configuration file for Air in your project directory, such as `.air.toml` or `air.conf`. Here's a sample configuration file that works with Fiber: ```toml # .air.toml root = "." tmp_dir = "tmp" [build] cmd = "go build -o ./tmp/main ." bin = "./tmp/main" delay = 1000 # ms exclude_dir = ["assets", "tmp", "vendor"] include_ext = ["go", "tpl", "tmpl", "html"] exclude_regex = ["_test\\.go"] ``` * Start your Fiber application with Air by running the following command: ```sh air ``` As you edit source files, Air detects the changes and restarts the application. A complete example is available in the [Fiber Recipes repository](https://github.com/gofiber/recipes/tree/master/air) and shows how to configure Air for a Fiber project. ## How do I set up an error handler? To override the default error handler, provide a custom one in the [Config](../api/fiber.md#errorhandler) when creating a new [Fiber instance](../api/fiber.md#new). ```go title="Example" app := fiber.New(fiber.Config{ ErrorHandler: func(c fiber.Ctx, err error) error { return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) }, }) ``` For a detailed explanation of how error handling works in Fiber, see [Error Handling](../guide/error-handling.md). ## Which template engines does Fiber support?
Fiber currently supports 9 template engines in our [gofiber/template](https://docs.gofiber.io/template/) package: * [ace](https://docs.gofiber.io/template/ace/) * [amber](https://docs.gofiber.io/template/amber/) * [django](https://docs.gofiber.io/template/django/) * [handlebars](https://docs.gofiber.io/template/handlebars/) * [html](https://docs.gofiber.io/template/html/) * [jet](https://docs.gofiber.io/template/jet/) * [mustache](https://docs.gofiber.io/template/mustache/) * [pug](https://docs.gofiber.io/template/pug/) * [slim](https://docs.gofiber.io/template/slim/) To learn more about using templates in Fiber, see [Templates](../guide/templates.md). ## Does Fiber have a community chat? Yes, we have a [Discord](https://gofiber.io/discord) server with rooms for every topic. If you have questions or just want to chat, join us via this [invite link](https://gofiber.io/discord). ![](/img/support-discord.png) ## Does Fiber support subdomain routing? Yes. Here is an example:
Example ```go package main import ( "log" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/logger" ) type Host struct { Fiber *fiber.App } func main() { // Hosts hosts := map[string]*Host{} //----- // API //----- api := fiber.New() api.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) hosts["api.localhost:3000"] = &Host{api} api.Get("/", func(c fiber.Ctx) error { return c.SendString("API") }) //------ // Blog //------ blog := fiber.New() blog.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) hosts["blog.localhost:3000"] = &Host{blog} blog.Get("/", func(c fiber.Ctx) error { return c.SendString("Blog") }) //--------- // Website //--------- site := fiber.New() site.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) hosts["localhost:3000"] = &Host{site} site.Get("/", func(c fiber.Ctx) error { return c.SendString("Website") }) // Server app := fiber.New() app.Use(func(c fiber.Ctx) error { host := hosts[c.Hostname()] if host == nil { return c.SendStatus(fiber.StatusNotFound) } else { host.Fiber.Handler()(c.Context()) return nil } }) log.Fatal(app.Listen(":3000")) } ```
For more information, see issue [#750](https://github.com/gofiber/fiber/issues/750). ## How can I handle conversions between Fiber and net/http? Fiber can register common `net/http` handlers directly—just pass an `http.Handler`, `http.HandlerFunc`, compatible function, or even a native `fasthttp.RequestHandler` to your routing method. For other interoperability scenarios, the `adaptor` middleware provides utilities for converting between Fiber and `net/http`. It allows seamless integration of `net/http` handlers, middleware, and requests into Fiber applications, and vice versa. :::caution Performance trade-offs Converted `net/http` handlers run through a compatibility layer. They won't expose `fiber.Ctx` or Fiber-specific helpers, and the extra adaptation work makes them slower than native Fiber handlers. Use them when interoperability matters, but prefer Fiber handlers for maximum performance. ::: For details on how to: * Convert `net/http` handlers to Fiber handlers * Convert Fiber handlers to `net/http` handlers * Convert `fiber.Ctx` to `http.Request` See the dedicated documentation: [Adaptor Documentation](../middleware/adaptor.md). ================================================ FILE: docs/extra/internal.md ================================================ --- title: 🏗️ Internal Architecture description: >- Learn about the internal architecture of Fiber, including the overall structure, request handling flow, routing, and path parsing. sidebar_position: 3 --- ## Overall Architecture At the heart of Fiber is the **App** struct. It is responsible for configuring the server, managing a pool of Contexts (either our default implementation, **DefaultCtx**, or a user‑supplied **CustomCtx**), and holding the router stack with all registered routes and groups. In addition, the App contains mount fields to support sub‑applications and hooks that allow developers to run custom code at key stages (e.g. when registering routes or starting the server). 
```mermaid flowchart TD A[App] B["Configuration (Config)"] C[Context Pool] D["DefaultCtx \/ CustomCtx"] E[Router Stack] F["Groups & Routes"] G["MountFields (Sub‑Apps)"] H[Hooks] A --> B A --> C C --> D A --> E E --> F A --> G A --> H ``` ### Explanation - App: The central object that bootstraps and runs the Fiber server. - Configuration (Config): Contains settings for body limits, timeouts, TLS options, routing behavior (e.g. case‑sensitivity, strict routing), and more. - Context Pool: A synchronized pool from which Contexts are acquired per request. This design minimizes allocations by recycling DefaultCtx (or CustomCtx) instances. - Router Stack: Organizes all registered routes. It is later processed into a tree structure for fast route‑matching. - MountFields: Support for mounting sub‑applications so that large APIs can be segmented into independent routers. - Hooks: Allow for custom behavior at critical points (e.g., on route registration, route naming, on listen, on shutdown, etc.). ## Request Processing Flow Fiber’s request processing is designed for performance and minimal overhead. When an HTTP request is received by the underlying fasthttp server, the flow is as follows: 1. Request Arrival: The fasthttp server receives the HTTP request. 2. Context Acquisition: The App calls AcquireCtx() to fetch a Context from the pool. 3. Context Reset: The acquired Context is reset (via DefaultCtx.Reset()) with the new request’s data. 4. Request Handling: The request handler (default or custom) is invoked. 5. Route Matching: The framework uses the next() (or nextCustom()) function to traverse the pre‑built route tree and find a matching route based on the URL and HTTP method. 6. Middleware Chain Execution: The matched route’s handler chain is executed in sequence. 7. Error Handling (if required): Any errors encountered trigger the registered error handler. 8. Response Generation: The response is sent back to the client. 9. 
Context Release: Finally, the Context is cleaned up and returned to the pool. ```mermaid flowchart LR R["HTTP Request (fasthttp)"] A["App.RequestHandler
(default or custom)"] C["Acquire Context
(from Pool)"] X["Reset Context
(DefaultCtx.Reset())"] N["Route Matching
(next() \/ nextCustom())"] M["Handler Chain Execution"] EH["Error Handling
(if needed)"] S["HTTP Response"] RC["Release Context
(to Pool)"] R --> A A --> C C --> X X --> N N --> M M --> EH EH --> S S --> RC ``` ### Additional Note Fiber minimizes memory allocations by reusing Context objects and uses an optimized route‑matching algorithm to rapidly determine the correct handler chain. ## Routing & Path Parsing Fiber allows you to register routes using helper methods (e.g. Get(), Post()) or by creating groups and sub‑routers. Internally, the route pattern is parsed by the parseRoute() function. This function decomposes the route string into segments: - Constant Segments: Fixed parts of the path (e.g. `/api`). - Parameter Segments: Dynamic parts that begin with a colon. For example, a route may be defined as `/api/:userId<int>`. Here, the segment `:userId<int>` is a parameter segment with a type constraint (an integer). - Constraints: Constraints (such as int, bool, datetime, or even regular expressions) are extracted from the parameter part and stored in the route’s metadata for validation at runtime. ```mermaid flowchart TD P["Route Pattern String
(e.g., '/api/\\:userId\\<int>')"] PA["parseRoute()"] RP[routeParser] RS["routeSegment(s)"] C["Constraints
(e.g., int, datetime, regex)"] PARAM[Extracted Parameter Names] P --> PA PA --> RP RP --> RS RS --> C RP --> PARAM ``` ### Explanation - parseRoute(): Takes a route string and returns a routeParser struct that includes a list of routeSegment objects. - routeSegment: Represents a portion of the route. If it is a parameter segment, it may include constraints that determine the allowed format (for example, ensuring that a parameter is an integer). - Extracted Parameter Names: These are later used to populate the request’s Context with the actual values parsed from the URL. ## Route Matching and Parameter Extraction When a request is processed, Fiber uses its pre‑computed route tree (the treeStack) to efficiently match the incoming URL against registered routes. 1. Normalization: The URL is normalized (converted to lowercase, trailing slashes trimmed) to create a “detection path.” 2. Tree Traversal: The route tree, grouped by common prefixes, is traversed based on the HTTP method. 3. Matching: Constant segments are compared exactly, while parameter segments extract dynamic values. 4. Constraint Validation: Extracted parameter values are validated against any defined constraints. ```mermaid flowchart TD A["Incoming Request URL
(e.g., '/api/john')"] B["Normalize URL
(lowercase, trim trailing slashes)"] C["Detection Path"] D["Traverse Route Tree
(treeStack based on method)"] E["Match Constant Segments"] F["Identify Parameter Segments
(e.g., ':userId')"] G["Extract Parameter Values"] H["Validate Constraints
(e.g., 'int', 'datetime', 'regex')"] I["Route Found"] A --> B B --> C C --> D D --> E E --> F F --> G G --> H H --> I ``` ### Insight This efficient matching mechanism leverages pre‑grouped routes to minimize comparisons, while dynamic segments allow for flexible URL structures and runtime validation. ## Middleware Chain Execution Once a matching route is found, Fiber executes the chain of middleware and route handlers sequentially. The process is as follows: 1. Initial Handler Execution: The first handler of the matched route is invoked. 2. Calling Next(): Each handler calls Ctx.Next() to pass control to the next handler in the chain. 3. Termination: When no further handlers remain, the chain terminates and the response is sent. ```mermaid flowchart TD A[Matched Route] B[Handler 1] C[Handler 2] D[Handler 3] E[Response Generation] A --> B B -- "Calls C via Next()" --> C C -- "Calls D via Next()" --> D D -- "No Next() available" --> E ``` ### Explanation - Each handler in the chain can perform operations (e.g. authentication, logging, transformation) before calling Next() to forward control. - This sequential processing ensures that middleware are executed in the order they were registered. - If an error occurs or a handler does not call Next(), the chain may be terminated early, and an error handler may be invoked. ### Observations Middleware are executed in the order they are registered. This sequential design allows each handler to perform tasks such as authentication, logging, or transformation before delegating to the next handler. ## Sub-Application Mounting & Grouping Fiber allows mounting sub‑applications (or sub‑routers) under specific path prefixes. This enables modular design of large APIs. The mounting process works as follows: 1. Defining a Mount Point: A parent application (or group) calls `Use` with a sub-app, which triggers the internal mount path logic. 2. 
Merging Mount Fields: The sub‑app’s mount fields are updated with the prefix of the parent, and its routes are integrated into the parent’s routing structure. 3. Processing Sub‑App Routes: During startup, the parent app collects routes from mounted sub‑apps and builds a unified route tree. ```mermaid flowchart TD A[Parent App] B["Sub-App (Mounted)"] C["Define Mount Point
(e.g. \'/admin\')"] D["Update MountFields
(assign mount path)"] E["Merge Sub-App Routes
(append to Router Stack)"] F[Generate Unified Route Tree] A --> C C --> B B --> D D --> E E --> F ``` ### Impact This mechanism enables large APIs to be broken down into smaller, maintainable modules while still benefiting from Fiber’s optimized routing and request handling. ## Route Tree Building Fiber builds a route tree (the treeStack) to optimize route matching. This involves grouping routes based on a prefix (usually the first few characters) to reduce the number of comparisons during a request. 1. Iterating Over the Router Stack: Each registered route is examined. 2. Computing the Tree Key: A key is computed from the route’s normalized path (e.g. the first 3 characters). 3. Grouping Routes: Routes are added to the appropriate branch of the tree. 4. Sorting: Within each group, routes are sorted based on their registration order (or position) to ensure the correct match is found. ```mermaid flowchart TD A["Router Stack
(All Registered Routes)"] B["Compute Tree Key
(e.g. first 3 characters)"] C["Group Routes by Key
(treeStack)"] D["Merge Global Routes
(key \'\' for global matches)"] E[Sort Routes within Groups] F[Optimized Route Tree] A --> B B --> C C --> D D --> E E --> F ``` ### Explanation - Building a route tree is an optimization step that reduces the matching overhead by limiting the search space to a subset of routes that share a common prefix. - The tree is rebuilt whenever new routes are registered, ensuring that the latest routing configuration is always used for matching. ## Context Lifecycle Management Fiber minimizes allocations by pooling Context objects. The lifecycle of a Context is as follows: 1. **Acquisition:** When a new HTTP request arrives, a Context is retrieved from the pool via `App.AcquireCtx()`. 2. **Reset:** The acquired Context is reset with the current `fasthttp.RequestCtx` to clear previous data and initialize new request‑specific values. 3. **Processing:** The Context is passed along the middleware and handler chain. 4. **Release:** After processing the request (or when an error occurs), the Context is released back to the pool via `App.ReleaseCtx()`, making it available for reuse. ```mermaid flowchart TD A["HTTP Request
(fasthttp)"] B["Acquire Context
(App.AcquireCtx())"] C["Reset Context
(DefaultCtx.Reset())"] D["Process Request
(Handlers & Middleware)"] E["Error Handling
(if needed)"] F["Release Context
(App.ReleaseCtx())"] A --> B B --> C C --> D D --> E E --> F ``` ### Key Benefit Reusing Context objects significantly reduces garbage collection overhead, ensuring Fiber remains fast and memory‑efficient even under heavy load. ## Preforking Mechanism To take full advantage of multi‑core systems, Fiber offers a prefork mode. In this mode, the master process spawns several child processes that listen on the same port using OS features such as SO_REUSEPORT (or fall back to SO_REUSEADDR). ```mermaid flowchart LR M["Master Process (App)"] C[Child Processes] GOMAX["Set GOMAXPROCS(1)"] REQ[Handle HTTP Requests] WM["watchMaster()"] M -->|Spawns| C C --> GOMAX C -->|Processes| REQ C --> WM ``` ### Explanation - Master Process: The main process determines the number of available CPU cores and spawns that many child processes. - Child Processes: Each child sets GOMAXPROCS(1) to run on a single CPU core and listens on the shared port. - watchMaster(): Each child process runs a watchdog routine to monitor the master process; if the master exits (or its parent process ID becomes 1 on Unix‑like systems), the child terminates gracefully. ### Detailed Preforking Workflow Fiber’s prefork mode uses OS‑level mechanisms to allow multiple processes to listen on the same port. Here’s a more detailed look: 1. Master Process Spawning: The master process detects the number of CPU cores and spawns that many child processes. 2. Child Process Initialization: Each child process sets GOMAXPROCS(1) so that it runs on a single core. 3. Binding to Port: Child processes use packages like reuseport to bind to the same address and port. 4. Parent Monitoring: Each child runs a watchdog function (watchMaster()) to monitor the master process; if the master terminates, children exit. 5. Request Handling: Each child independently handles incoming HTTP requests. ```mermaid flowchart TD A[Master Process] B[Determine CPU Cores] C[Spawn Child Processes] D["Child Process Initialization
(GOMAXPROCS(1))"] E["Bind to Port
(reuseport)"] F["Run watchMaster()
(Monitor Parent)"] G[Handle HTTP Requests] A --> B B --> C C --> D D --> E E --> F F --> G ``` #### Explanation - Preforking improves performance by allowing multiple processes to handle requests concurrently. - Using reuseport (or a fallback) ensures that all child processes can listen on the same port without conflicts. - The watchdog routine in each child ensures that they exit if the master process is no longer running, maintaining process integrity. ## Redirection & Flash Messages Fiber’s redirection mechanism is implemented via the Redirect struct. This structure allows not only setting a new location for redirection but also passing along flash messages and old input data via a special cookie. ```mermaid flowchart LR R[Redirect Struct] RP[redirectPool] FM["Flash Messages \/ Old Inputs"] M["Methods:
To(), Route(), Back()"] LH[Set Location Header] CK["Flash Cookie
(fiber\_flash)"] R -->|Acquired from| RP R --> FM R --> M M --> LH FM -->|Serialized| CK ``` ### Explanation - Redirect Struct: Retrieved from a pool (to minimize allocations), it stores redirection settings such as the HTTP status code (defaulting to 303 See Other) and any flash messages. - Flash Messages & Old Inputs: These are collected via methods like With() or WithInput() and then serialized and stored in a cookie named fiber_flash. - Redirection Methods: The To(), Route(), and Back() methods determine the target URL and set the Location header accordingly. ### Flash Message Handling in Redirection When performing redirections, Fiber can send flash messages or preserve old input data. This process involves: 1. Collecting Flash Data: When a redirect is initiated, developers can add flash messages via Redirect.With() or old input data via Redirect.WithInput(). 2. Serialization: The flash messages and input data are serialized (using a fast marshalling method) into a byte sequence. 3. Setting a Cookie: The serialized data is stored in a special cookie (named fiber_flash) that will be sent to the client. 4. Retrieval & Clearing: On the subsequent request, the flash data is read from the cookie, deserialized, and then cleared. ```mermaid flowchart TD A[Initiate Redirect] B["Add Flash Messages
(With(), WithInput())"] C[Serialize Flash Data] D["Set Flash Cookie
(\'fiber\_flash\')"] E[Client Receives Redirect] F[Next Request Reads Flash Cookie] G["Deserialize & Clear Flash Data"] A --> B B --> C C --> D D --> E E --> F F --> G ``` #### Explanation - Flash messages provide a way to pass transient data (such as notifications or error messages) to the next request after a redirect. - The data is stored temporarily in a cookie, which is then read and cleared upon processing the next request. - This mechanism is essential for implementing post‑redirect‑get patterns and ensuring a smooth user experience. ## Hooks, Error Handling & Context Lifecycle ### Hooks Fiber provides a comprehensive hook system that allows you to run custom functions at key moments: - OnRoute: Called when a route is registered. - OnName: Invoked when a route is assigned a name. - OnGroup: Triggered when a group is created. - OnListen: Runs when the server starts listening. - OnShutdown: Called during graceful shutdown. - OnFork: Invoked when a child process is forked. - OnMount: Used when a sub‑application is mounted. ```mermaid flowchart TD H[Hooks] OR[OnRoute] ON[OnName] OG[OnGroup] OL[OnListen] OS[OnShutdown] OF[OnFork] OM[OnMount] H --> OR H --> ON H --> OG H --> OL H --> OS H --> OF H --> OM ``` #### Explanation - Hooks provide extension points for developers and maintainers to inject custom logic without modifying the core Fiber code. - They are executed at various stages (for example, every time a new route is registered, the OnRoute hooks are executed to allow for logging, validation, or transformation of the route). ### Error Handling & Context Lifecycle Fiber’s DefaultCtx (or CustomCtx) represents the per‑request context. The lifecycle is as follows: - Acquire: A Context is obtained from the pool at the beginning of a request. - Processing: The context is passed along to the route handlers and middleware. 
- Error Handling: If an error occurs (e.g., route not found, method not allowed, or a handler returning an error), Fiber calls the registered error handler. Panics only reach the error handler when the recover middleware converts them into errors. Errors such as ErrMethodNotAllowed or ErrNotFound are generated as needed. - Release: Once the request is processed, the Context is released back into the pool for reuse. ```mermaid flowchart LR AC["Acquire Context
(from Pool)"] HP["Handle Request
(Handlers & Middleware)"] EH["Error Handling
(if needed)"] RC["Release Context
(to Pool)"] AC --> HP HP --> EH EH --> RC ``` #### Explanation - This lifecycle ensures that Fiber minimizes allocations by reusing Context objects. - Errors are propagated and handled consistently, and the context is properly reset after every request. ================================================ FILE: docs/extra/learning-resources.md ================================================ --- id: learning-resources title: 📚 Learning Resources description: >- Interactive learning platforms and community resources to help you learn Fiber concepts through hands-on practice. sidebar_position: 3 --- ## Interactive Learning Platforms Looking to practice Fiber concepts through hands-on exercises? Here are some community-driven learning resources: ### Go Interview Practice - Fiber Challenges A comprehensive platform offering progressive Fiber challenges that complement the official documentation. ![Learning Path Overview](/img/learning-resources/fiber-learning-path.png) **What You'll Learn:** - **High-Performance APIs** - Build ultra-fast RESTful APIs with zero-allocation routing - **Middleware & Security** - Implement custom middleware, rate limiting, CORS, and authentication - **Request Validation** - Input validation, error handling, and data transformation - **Authentication & JWT** - Secure authentication systems with JWT tokens and API key validation ![Challenge Interface](/img/learning-resources/fiber-challenge-interface.png) **Challenge Roadmap:** 1. **Basic Routing** - Set up Fiber, routes, and handlers (Beginner) 2. **Middleware & CORS** - Custom middleware and rate limiting (Intermediate) 3. **Validation & Errors** - Input validation and error handling (Intermediate) 4.
**Authentication** - JWT tokens and API key validation (Advanced) ![Fiber Framework Overview](/img/learning-resources/fiber-framework-overview.png) ![Interactive Learning Experience](/img/learning-resources/fiber-learning-experience.png) [Explore Fiber Challenges →](https://rezasi.github.io/go-interview-practice/fiber) | [GitHub Repository →](https://github.com/RezaSi/go-interview-practice) ================================================ FILE: docs/guide/_category_.json ================================================ { "label": "\uD83D\uDCD6 Guide", "position": 7, "link": { "type": "generated-index", "description": "Guides for Fiber." } } ================================================ FILE: docs/guide/advance-format.md ================================================ --- id: advance-format title: 🐛 Advanced Format description: >- Learn how to use MessagePack (MsgPack) and CBOR for efficient binary serialization in Fiber applications. sidebar_position: 9 --- ## MsgPack Fiber lets you use MessagePack for efficient binary serialization. Use one of the popular Go libraries below to encode and decode data in handlers. - Fiber can bind requests with the `application/vnd.msgpack` content type out of the box. See the [Binding documentation](../api/bind.md#msgpack) for details. - Use `Bind().MsgPack()` to bind data to structs, similar to JSON. `Ctx.AutoFormat()` responds with MsgPack when the `Accept` header is `application/vnd.msgpack`. See the [AutoFormat documentation](../api/ctx.md#autoformat) for more. ### Recommended Libraries - [github.com/vmihailenco/msgpack](https://pkg.go.dev/github.com/vmihailenco/msgpack) — A widely used, feature-rich MsgPack library. - [github.com/shamaton/msgpack/v3](https://pkg.go.dev/github.com/shamaton/msgpack/v3) — High-performance MsgPack library. 
### Installation Install either library using: ```bash go get github.com/vmihailenco/msgpack # or go get github.com/shamaton/msgpack/v3 ``` > **Note:** Fiber doesn't bundle a MsgPack implementation because MsgPack support is not part of the Go standard library. Pick one of the libraries from the ecosystem; the two listed above are widely used and well maintained. ### Example: Using `shamaton/msgpack/v3` ```go import ( "github.com/gofiber/fiber/v3" "github.com/shamaton/msgpack/v3" ) type User struct { Name string `msgpack:"name"` // tag may vary depending on your MsgPack library Age int `msgpack:"age"` } func main() { app := fiber.New(fiber.Config{ // Optional: Set custom MsgPack encoder/decoder MsgPackEncoder: msgpack.Marshal, MsgPackDecoder: msgpack.Unmarshal, }) app.Post("/msgpack", func(c fiber.Ctx) error { var user User if err := c.Bind().MsgPack(&user); err != nil { return err } // Content type will be set automatically to application/vnd.msgpack return c.MsgPack(user) }) app.Listen(":3000") } ``` ## CBOR Fiber doesn't ship with a CBOR implementation. Use a library such as [fxamacker/cbor](https://github.com/fxamacker/cbor) to add encoding and decoding. - Use `Bind().CBOR()` to bind CBOR to structs. `Ctx.AutoFormat()` replies with CBOR when the `Accept` header is `application/cbor`. See the [AutoFormat documentation](../api/ctx.md#autoformat) for details.
```bash go get github.com/fxamacker/cbor/v2 ``` Configure Fiber with the chosen library: ```go import ( "github.com/gofiber/fiber/v3" "github.com/fxamacker/cbor/v2" ) func main() { app := fiber.New(fiber.Config{ CBOREncoder: cbor.Marshal, CBORDecoder: cbor.Unmarshal, }) type User struct { Name string `cbor:"name"` Age int `cbor:"age"` } app.Post("/cbor", func(c fiber.Ctx) error { var user User if err := c.Bind().CBOR(&user); err != nil { return err } // Content type will be set automatically to application/cbor return c.CBOR(user) }) app.Listen(":3000") } ``` ================================================ FILE: docs/guide/context.md ================================================ --- id: go-context title: "\U0001F9E0 Go Context" description: >- Learn how Fiber's Ctx integrates with Go's context.Context, how to interact with the underlying fasthttp RequestCtx, and how to use the available context helpers. sidebar_position: 6 toc_max_heading_level: 4 --- import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; ## Fiber Context as `context.Context` Fiber's [`Ctx`](../api/ctx.md) implements Go's [`context.Context`](https://pkg.go.dev/context#Context) interface. You can pass `c` directly to functions that expect a `context.Context` without adapters. However, `fasthttp` doesn't support cancellation yet, so `Deadline`, `Done`, and `Err` are no-ops. :::caution The `fiber.Ctx` instance is only valid within the lifetime of the handler. It is reused for subsequent requests, so avoid storing `c` or using it in goroutines that outlive the handler. For asynchronous work, call `c.Context()` inside the handler to obtain a `context.Context` that can safely be used after the handler returns. By default, this returns `context.Background()` unless a custom context was provided with `c.SetContext`. ::: ```go title="Example" func doSomething(ctx context.Context) { // ... 
your logic here } app.Get("/", func(c fiber.Ctx) error { doSomething(c) // c satisfies context.Context return nil }) ``` ### Using context outside the handler `fiber.Ctx` is recycled after each request. If you need a context that lives longer—for example, for work performed in a new goroutine—obtain it with `c.Context()` before returning from the handler. ```go title="Async work" app.Get("/job", func(c fiber.Ctx) error { ctx := c.Context() go performAsync(ctx) return c.SendStatus(fiber.StatusAccepted) }) ``` You can customize the base context by calling `c.SetContext` before requesting it: ```go app.Get("/job", func(c fiber.Ctx) error { c.SetContext(context.WithValue(context.Background(), "requestID", "123")) ctx := c.Context() go performAsync(ctx) return nil }) ``` ### Retrieving Values `Ctx.Value` is backed by [Locals](../api/ctx.md#locals). Values stored with `c.Locals` are accessible through `c.Value`. They are also visible from contexts derived from `c` with standard helpers such as `context.WithValue` or `context.WithTimeout`, because lookups that miss in the derived context fall back to `c.Value`. ```go title="Locals and Value" app.Get("/", func(c fiber.Ctx) error { c.Locals("role", "admin") role := c.Value("role") // returns "admin" return c.SendString(role.(string)) }) ``` ## Working with `RequestCtx` and `fasthttpctx` The underlying [`fasthttp.RequestCtx`](https://pkg.go.dev/github.com/valyala/fasthttp#RequestCtx) can be accessed via `c.RequestCtx()`. This exposes low-level APIs and the extra context support provided by `fasthttpctx`. ```go title="Accessing RequestCtx" app.Get("/raw", func(c fiber.Ctx) error { fctx := c.RequestCtx() // use fasthttp APIs directly fctx.Response.Header.Set("X-Engine", "fasthttp") return nil }) ``` `fasthttpctx` enables `fasthttp` to satisfy the `context.Context` interface. `Deadline` always reports no deadline, `Done` is closed when the client connection ends, and once it fires, `Err` reports `context.Canceled`. This means handlers can detect client disconnects while still passing `c.RequestCtx()` into APIs that expect a `context.Context`.
## Context Helpers Fiber and its middleware expose a number of helper functions that retrieve request-scoped values from the context. ### Request ID The RequestID middleware stores the generated identifier in the context. Use `requestid.FromContext` to read it later. ```go app.Use(requestid.New()) app.Get("/", func(c fiber.Ctx) error { id := requestid.FromContext(c) return c.SendString(id) }) ``` ### CSRF The CSRF middleware provides helpers to fetch the token or the handler attached to the current context. ```go app.Use(csrf.New()) app.Get("/form", func(c fiber.Ctx) error { token := csrf.TokenFromContext(c) return c.SendString(token) }) ``` ```go title="Deleting a token" app.Post("/logout", func(c fiber.Ctx) error { handler := csrf.HandlerFromContext(c) if handler != nil { // Invalidate the token on logout _ = handler.DeleteToken(c) } // ... other logout logic return c.SendString("Logged out") }) ``` ### Session Sessions are stored on the context and can be retrieved via `session.FromContext`. ```go app.Use(session.New()) app.Get("/", func(c fiber.Ctx) error { sess := session.FromContext(c) count := sess.Get("visits") return c.JSON(fiber.Map{"visits": count}) }) ``` ### Basic Authentication After successful authentication, the username is available with `basicauth.UsernameFromContext`. Passwords in `Users` must be pre-hashed. ```go app.Use(basicauth.New(basicauth.Config{ Users: map[string]string{ // "secret" hashed using SHA-256 "admin": "{SHA256}K7gNU3sdo+OL0wNhqoVWhr3g6s1xYv72ol/pe/Unols=", }, })) app.Get("/", func(c fiber.Ctx) error { user := basicauth.UsernameFromContext(c) return c.SendString(user) }) ``` ### Key Authentication For API key authentication, the extracted token is stored in the context and accessible via `keyauth.TokenFromContext`. 
```go app.Use(keyauth.New()) app.Get("/", func(c fiber.Ctx) error { token := keyauth.TokenFromContext(c) return c.SendString(token) }) ``` ## Using `context.WithValue` and Friends Since `fiber.Ctx` conforms to `context.Context`, standard helpers such as `context.WithValue`, `context.WithTimeout`, or `context.WithCancel` can wrap the request context when needed. ```go app.Get("/job", func(c fiber.Ctx) error { ctx, cancel := context.WithTimeout(c, 5*time.Second) defer cancel() // pass ctx to async operations that honor cancellation if err := doWork(ctx); err != nil { return err } return c.SendStatus(fiber.StatusOK) }) ``` ### Context Cancellation with Goroutines in Fiber When starting asynchronous work inside a handler, Fiber does not cancel the base `fiber.Ctx` automatically. By wrapping the request context with `context.WithTimeout`, you can create a derived context that honors deadlines and cancellation signals. The goroutine checks `ctx.Done()` before sending a result. If the timeout expires first, the goroutine exits early instead of leaking. By default, `c.Context()` returns `context.Background()`, so a client disconnect alone does not cancel `ctx`. The handler then waits for either: - a result from the goroutine, or - the context timeout, in which case it responds with 504 Gateway Timeout This pattern keeps long-running operations (database queries, external API calls, background tasks) from continuing after the request has ended, as long as those operations honor `ctx`. ```go func Handler(c fiber.Ctx) error { ctx, cancel := context.WithTimeout(c.Context(), 2*time.Second) defer cancel() resultChan := make(chan string, 1) go func() { select { case <-time.After(3 * time.Second): select { case <-ctx.Done(): return case resultChan <- "done": } case <-ctx.Done(): return } }() select { case res := <-resultChan: return c.SendString(res) case <-ctx.Done(): return c.Status(fiber.StatusGatewayTimeout).SendString("timeout") } } ``` This approach provides safe cancellation semantics for goroutine-based work while allowing you to integrate Fiber handlers with context-aware APIs.
## Summary - `fiber.Ctx` satisfies `context.Context` but its `Deadline`, `Done`, and `Err` methods are currently no-ops. - `RequestCtx` exposes the raw `fasthttp` context, whose `Done` channel closes when the client connection ends. - Use `fiber.StoreInContext(c, key, value)` to store request-scoped values in both `c.Locals()` and `c.Context()` when values must be available through either API. - Middleware helpers like `requestid.FromContext` or `session.FromContext` make it easy to retrieve request-scoped data. - Standard helpers such as `context.WithTimeout` can wrap `fiber.Ctx` to create fully featured derived contexts inside handlers. - `fiber.Config.PassLocalsToContext` controls whether Fiber context helpers also propagate values into the request `context.Context` for Fiber-backed contexts when using `StoreInContext`. It defaults to `false` for backward compatibility, while `ValueFromContext` keeps reading from `c.Locals()`. - Use `c.Context()` to obtain a `context.Context` that can outlive the handler, and `c.SetContext()` to customize it with additional values or deadlines. With these tools, you can seamlessly integrate Fiber applications with Go's context-based APIs and manage request-scoped data effectively. ================================================ FILE: docs/guide/error-handling.md ================================================ --- id: error-handling title: 🐛 Error Handling description: >- Fiber supports centralized error handling: handlers return errors so you can log them or send a custom HTTP response to the client. sidebar_position: 4 --- import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; ## Catching Errors Return errors from route handlers and middleware so Fiber can handle them centrally. ```go app.Get("/", func(c fiber.Ctx) error { // Pass error to Fiber return c.SendFile("file-does-not-exist") }) ``` Fiber does not recover from [panics](https://go.dev/blog/defer-panic-and-recover) by default. 
Add the `Recover` middleware to catch panics in any handler: ```go title="Example" package main import ( "log" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/recover" ) func main() { app := fiber.New() app.Use(recover.New()) app.Get("/", func(c fiber.Ctx) error { panic("This panic is caught by fiber") }) log.Fatal(app.Listen(":3000")) } ``` Use `fiber.NewError()` to create an error with a status code. If you omit the message, Fiber uses the standard status text (for example, `404` becomes `Not Found`). ```go title="Example" app.Get("/", func(c fiber.Ctx) error { // 503 Service Unavailable return fiber.ErrServiceUnavailable // 503 On vacation! return fiber.NewError(fiber.StatusServiceUnavailable, "On vacation!") }) ``` ## Default Error Handler Fiber ships with a default error handler that sends **500 Internal Server Error** for generic errors. If the error is a [`*fiber.Error`](https://pkg.go.dev/github.com/gofiber/fiber/v3#Error), the response uses the embedded status code and message. ```go title="Example" // Default error handler var DefaultErrorHandler = func(c fiber.Ctx, err error) error { // Status code defaults to 500 code := fiber.StatusInternalServerError // Retrieve the custom status code if it's a *fiber.Error var e *fiber.Error if errors.As(err, &e) { code = e.Code } // Set Content-Type: text/plain; charset=utf-8 c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8) // Return status code with error message return c.Status(code).SendString(err.Error()) } ``` ## Custom Error Handler Set a custom error handler in [`fiber.Config`](../api/fiber.md#errorhandler) when creating a new app. The default handler covers most cases, but a custom handler lets you react to specific error types—for example, by logging to a service or sending a tailored JSON or HTML response. The following example shows how to display error pages for different types of errors.
```go title="Example" // Create a new fiber instance with custom config app := fiber.New(fiber.Config{ // Override default error handler ErrorHandler: func(ctx fiber.Ctx, err error) error { // Status code defaults to 500 code := fiber.StatusInternalServerError // Retrieve the custom status code if it's a *fiber.Error var e *fiber.Error if errors.As(err, &e) { code = e.Code } // Send custom error page err = ctx.Status(code).SendFile(fmt.Sprintf("./%d.html", code)) if err != nil { // In case the SendFile fails return ctx.Status(fiber.StatusInternalServerError).SendString("Internal Server Error") } // Return from handler return nil }, }) // ... ``` > Special thanks to the [Echo](https://echo.labstack.com/) and [Express](https://expressjs.com/) frameworks for inspiring parts of this error-handling approach. ================================================ FILE: docs/guide/extractors.md ================================================ --- id: extractors title: 🔬 Extractors description: Learn how to use extractors in Fiber middleware sidebar_position: 8.5 toc_max_heading_level: 4 --- The extractors package provides shared value extraction utilities for Fiber middleware packages. It helps reduce code duplication across middleware packages while ensuring consistent behavior and security practices. ## Overview The `github.com/gofiber/fiber/v3/extractors` module provides standardized value extraction utilities integrated into Fiber's middleware ecosystem. This approach: - **Reduces Code Duplication**: Eliminates redundant extractor implementations across middleware packages - **Ensures Consistency**: Maintains identical behavior and security practices across all extractors - **Simplifies Maintenance**: Changes to extraction logic only need to be made in one place - **Enables Direct Usage**: Middleware can import and use extractors directly - **Improves Performance**: Shared, optimized extraction functions reduce overhead ## What Are Extractors? 
Extractors are utilities that middleware uses to get values from different parts of HTTP requests: ### Available Extractors - `FromAuthHeader(authScheme string)`: Extract from Authorization header with optional scheme - `FromCookie(key string)`: Extract from HTTP cookies - `FromParam(param string)`: Extract from URL path parameters - `FromForm(param string)`: Extract from form data - `FromHeader(header string)`: Extract from custom HTTP headers - `FromQuery(param string)`: Extract from URL query parameters - `FromCustom(key string, fn func(fiber.Ctx) (string, error))`: Define custom extraction logic with metadata - `Chain(extractors ...Extractor)`: Chain multiple extractors with fallback logic ### Extractor Structure Each `Extractor` contains: ```go type Extractor struct { Extract func(fiber.Ctx) (string, error) // Extraction function Key string // Parameter/header name Source Source // Source type for inspection AuthScheme string // Auth scheme (FromAuthHeader) Chain []Extractor // Chained extractors } ``` Together, the built-in extractors cover these parts of a request: - **Headers**: `Authorization`, `X-API-Key`, custom headers - **Cookies**: Session cookies, authentication tokens - **Query Parameters**: URL parameters like `?token=abc123` - **Form Data**: POST body form fields - **URL Parameters**: Route parameters like `/users/:id` ### Chain Behavior The `Chain` function creates extractors that try multiple sources in order: - Returns the first successful extraction (non-empty value with no error) - If all extractors fail, returns the last error encountered or `ErrNotFound` - **Robust error handling**: Skips extractors with `nil` Extract functions - Preserves the source and key from the first extractor for metadata - Stores a defensive copy of all chained extractors for introspection via the `Chain` field ## Why Middleware Uses Extractors Middleware needs to extract values from requests for authentication, authorization, and other purposes.
Extractors provide: - **Security Awareness**: Different sources have different security implications - **Fallback Support**: Try multiple sources if the first one doesn't have the value - **Consistency**: Same extraction logic across all middleware packages - **Source Tracking**: Know where values came from for security decisions ## Usage Examples ### Basic Usage ```go // KeyAuth middleware extracts key from header app.Use(keyauth.New(keyauth.Config{ Extractor: extractors.FromHeader("Middleware-Key"), })) ``` ### Fallback Chains ```go // Try multiple sources in order tokenExtractor := extractors.Chain( extractors.FromHeader("Middleware-Key"), // Try header first extractors.FromCookie("middleware_key"), // Then cookie extractors.FromQuery("middleware_key"), // Finally query param ) app.Use(keyauth.New(keyauth.Config{ Extractor: tokenExtractor, })) ``` ## Configuring Middleware That Uses Extractors ### Authentication Middleware ```go // KeyAuth middleware (default: FromAuthHeader) app.Use(keyauth.New(keyauth.Config{ // Default extracts from Authorization header // Extractor: extractors.FromAuthHeader("Bearer"), })) // Custom header extraction app.Use(keyauth.New(keyauth.Config{ Extractor: extractors.FromHeader("X-API-Key"), })) // Multiple sources with secure fallback app.Use(keyauth.New(keyauth.Config{ Extractor: extractors.Chain( extractors.FromAuthHeader("Bearer"), // Secure first extractors.FromHeader("X-API-Key"), // Then custom header extractors.FromQuery("api_key"), // Least secure last ), })) ``` ### Session Middleware ```go // Session middleware (default: FromCookie) app.Use(session.New(session.Config{ // Default extracts from session_id cookie // Extractor: extractors.FromCookie("session_id"), })) // Custom cookie name app.Use(session.New(session.Config{ Extractor: extractors.FromCookie("my_session"), })) ``` ### CSRF Middleware ```go // CSRF middleware (default: FromHeader) app.Use(csrf.New(csrf.Config{ // Default extracts from X-CSRF-Token header // 
Extractor: extractors.FromHeader("X-CSRF-Token"), })) // Form-based CSRF (less secure, use only if needed) app.Use(csrf.New(csrf.Config{ Extractor: extractors.Chain( extractors.FromHeader("X-CSRF-Token"), // Secure first extractors.FromForm("_csrf"), // Form fallback ), })) ``` ## Security Considerations ### Source Characteristics Different extraction sources have different security properties and use cases: #### Headers (Generally Preferred) - **Authorization Header**: Standard for authentication tokens, widely supported - **Custom Headers**: Application-specific, less likely to be logged by default - **Considerations**: Can be intercepted without HTTPS, may be stripped by proxies #### Cookies (Good for Sessions) - **Session Cookies**: Designed for secure client-side storage - **Considerations**: Require proper `Secure`, `HttpOnly`, and `SameSite` flags - **Best for**: Session management, remember-me tokens #### Query Parameters (Use Sparingly) - **Query parameters**: Convenient for simple APIs and debugging - **Considerations**: Always visible in URLs, logged by servers/proxies, stored in browser history - **Best for**: Non-sensitive parameters, public identifiers #### Form Data (Context Dependent) - **POST Bodies**: Suitable for form submissions and API requests - **Considerations**: Avoid putting sensitive data in query strings; ensure request bodies aren’t logged and use the correct content type - **Best for**: User-generated content, file uploads ### Security Best Practices 1. **Use HTTPS**: Encrypt all traffic to protect extracted values in transit 2. **Validate Input**: Always validate and sanitize extracted values 3. **Log Carefully**: Avoid logging sensitive values from any source 4. **Choose Appropriate Sources**: Match the source to your security requirements 5. **Test Thoroughly**: Verify extraction works in your environment 6. 
**Monitor Security**: Watch for extraction failures or unusual patterns ### Chain Ordering Strategy When using multiple sources, order them by your security preferences: ```go // Example: Prefer headers, fall back to cookies, then query extractors.Chain( extractors.FromAuthHeader("Bearer"), // Standard auth extractors.FromCookie("auth_token"), // Secure storage extractors.FromQuery("token"), // Public fallback ) ``` The "best" source depends on your specific use case, security requirements, and application architecture. ### Common Security Issues #### Leaky URLs ```go // ❌ DON'T: API keys in URLs (visible in logs, history, bookmarks) app.Use(keyauth.New(keyauth.Config{ Extractor: extractors.FromQuery("api_key"), // PROBLEMATIC })) // ✅ DO: API keys in headers (not visible in URLs) app.Use(keyauth.New(keyauth.Config{ Extractor: extractors.FromHeader("X-API-Key"), // BETTER })) ``` #### Session Tokens in Query Parameters ```go // ❌ DON'T: Session tokens in URLs (can be bookmarked, leaked) app.Use(session.New(session.Config{ Extractor: extractors.FromQuery("session"), // PROBLEMATIC })) // ✅ DO: Session tokens in cookies (designed for this purpose) app.Use(session.New(session.Config{ Extractor: extractors.FromCookie("session_id"), // BETTER })) ``` #### Form-Only CSRF Tokens While the default extractor uses headers, some implementations use form fields, which is fine if you don't have AJAX or API clients: ```go // ❌ DON'T: CSRF tokens only in forms (breaks AJAX, API calls) app.Use(csrf.New(csrf.Config{ Extractor: extractors.FromForm("_csrf"), // LIMITED })) // ✅ DO: Header-first with form fallback (works everywhere) app.Use(csrf.New(csrf.Config{ Extractor: extractors.Chain( extractors.FromHeader("X-CSRF-Token"), // PREFERRED extractors.FromForm("_csrf"), // FALLBACK ), })) ``` ### Understanding Trade-offs **No extractor is universally "secure" - security depends on:** - Whether you're using HTTPS - How you configure cookies (Secure, HttpOnly, SameSite flags) - Your 
logging and monitoring setup - The sensitivity of the data being extracted - Your threat model and security requirements Choose extractors based on your specific use case and security needs, not blanket "secure" vs "insecure" labels. ## Standards Compliance ### Authorization Header (RFC 9110 & RFC 7235) The `FromAuthHeader` extractor provides comprehensive RFC compliance with strict security validation: #### RFC 9110 Compliance (Authorization Header Format) - **Section 11.6.2 Format**: Enforces `credentials = auth-scheme 1*SP token68` structure - **1*SP Requirement**: Validates exactly one or more spaces between auth-scheme and token - **Case-insensitive scheme matching**: `Bearer`, `bearer`, `BEARER` all work correctly - **Proper whitespace handling**: Rejects tabs between scheme and token (only spaces allowed) #### RFC 7235 Token68 Validation The extractor implements strict token68 character validation per RFC 7235: - **Allowed characters**: `A-Z`, `a-z`, `0-9`, `-`, `.`, `_`, `~`, `+`, `/`, `=` - **Padding rules**: `=` characters only allowed at the end of tokens - **Security validation**: Prevents tokens starting with `=` or having non-padding characters after `=` - **Whitespace rejection**: Rejects tokens containing spaces, tabs, or any other whitespace #### Security Features - **Header injection prevention**: Strict parsing prevents malformed authorization headers from bypassing authentication - **Token validation**: Ensures extracted tokens conform to standards, preventing authentication bypass - **Consistent error handling**: Returns `ErrNotFound` for all invalid cases #### Examples ```go // Standard usage - strict validation extractor := extractors.FromAuthHeader("Bearer") // ✅ Valid cases: // "Bearer abc123" -> "abc123" // "bearer ABC123" -> "ABC123" (case-insensitive scheme) // "Bearer token123=" -> "token123=" (valid padding) // "Bearer token==" -> "token==" (valid multiple padding) // ❌ Invalid cases (all return ErrNotFound): // "Bearer abc def" -> 
rejected (space in token) // "Bearer abc\tdef" -> rejected (tab in token) // "Bearer =abc" -> rejected (padding at start) // "Bearer ab=cd" -> rejected (padding in middle) // "Bearer token" -> rejected (multiple spaces after scheme) // "Bearer\ttoken" -> rejected (tab after scheme) // "Bearertoken" -> rejected (no space after scheme) // Raw header extraction (no validation) rawExtractor := extractors.FromAuthHeader("") // "CustomAuth anything goes here" -> "CustomAuth anything goes here" ``` #### Benefits - **Standards Compliance**: Full adherence to HTTP authentication RFCs - **Security Hardening**: Prevents common authentication bypass vulnerabilities - **Consistent Behavior**: Reliable parsing across different client implementations - **Developer Confidence**: Clear validation rules reduce authentication bugs ## Troubleshooting ### Extraction Fails **Problem**: Middleware returns "value not found" or authentication fails **Solutions**: 1. Check if the expected header/cookie/query parameter is present 2. Verify the key name matches exactly (headers are case-insensitive; params/cookies/query keys are case-sensitive) 3. Ensure the request uses the correct HTTP method (GET vs POST) 4. Check if middleware is configured with the right extractor **Debug Example**: ```go // Add simple debug logging (avoid logging secrets in production) app.Use(func(c fiber.Ctx) error { hdr := c.Get("X-API-Key") cookie := c.Cookies("session_id") if hdr != "" || cookie != "" { log.Printf("debug: X-API-Key present=%t, session_id present=%t", hdr != "", cookie != "") } return c.Next() }) ``` ### Wrong Source Used **Problem**: Values extracted from unexpected sources **Solutions**: 1. Check middleware configuration order 2. Verify chain order (first successful extraction wins) 3. Use more specific extractors when needed ### Security Warnings **Problem**: Getting security warnings in logs **Solutions**: 1. Switch to more secure sources (headers/cookies) 2. Use HTTPS to encrypt traffic 3. 
Review if sensitive data should be in that source ## Advanced Usage ### Custom Extraction Logic Extractors support custom extractors for complex scenarios: ```go // Extract from custom logic (rarely needed) customExtractor := extractors.FromCustom("my-source", func(c fiber.Ctx) (string, error) { // Complex extraction logic if value := c.Locals("computed_token"); value != nil { return value.(string), nil } return "", extractors.ErrNotFound }) ``` :::warning **Custom extractors break source awareness.** When you use `FromCustom`, middleware cannot determine where the value came from, which means: - **No automatic security warnings** for potentially insecure sources - **No source-based logging** or monitoring capabilities - **Developer responsibility** for ensuring the extraction is secure and appropriate **Only use `FromCustom` when:** - Standard extractors don't meet your needs - You've carefully evaluated the security implications - You're confident in the security of your custom extraction logic - You understand that middleware cannot provide source-aware security guidance **Note:** If you pass `nil` as the function parameter, `FromCustom` will return an extractor that always fails with `ErrNotFound`. ::: ### Multiple Middleware Coordination When using multiple middleware that extract values, ensure they don't conflict: ```go // Good: Different sources for different purposes app.Use(keyauth.New(keyauth.Config{ Extractor: extractors.FromHeader("X-API-Key"), })) app.Use(session.New(session.Config{ Extractor: extractors.FromCookie("session_id"), })) // Avoid: Same source for different middleware app.Use(keyauth.New(keyauth.Config{ Extractor: extractors.FromCookie("token"), // API auth })) app.Use(session.New(session.Config{ Extractor: extractors.FromCookie("token"), // Session - CONFLICT! 
})) ``` ================================================ FILE: docs/guide/faster-fiber.md ================================================ --- id: faster-fiber title: ⚡ Make Fiber Faster sidebar_position: 7 --- ## Custom JSON Encoder/Decoder Fiber defaults to the standard `encoding/json` for stability and reliability. If you need more speed, consider these libraries: - [goccy/go-json](https://github.com/goccy/go-json) - [bytedance/sonic](https://github.com/bytedance/sonic) - [segmentio/encoding](https://github.com/segmentio/encoding) - [minio/simdjson-go](https://github.com/minio/simdjson-go) ```go title="Example" package main import "github.com/gofiber/fiber/v3" import "github.com/goccy/go-json" func main() { app := fiber.New(fiber.Config{ JSONEncoder: json.Marshal, JSONDecoder: json.Unmarshal, }) // ... } ``` ### References - [Set custom JSON encoder for client](../client/rest.md#setjsonmarshal) - [Set custom JSON decoder for client](../client/rest.md#setjsonunmarshal) - [Set custom JSON encoder for application](../api/fiber.md#jsonencoder) - [Set custom JSON decoder for application](../api/fiber.md#jsondecoder) ================================================ FILE: docs/guide/grouping.md ================================================ --- id: grouping title: 🎭 Grouping sidebar_position: 2 --- :::info Grouping works like Express.js. Groups are virtual; routes are flattened with the group's prefix and executed in declaration order, mirroring Express.js. ::: ## Paths Groups can use path prefixes to organize related routes. ```go func main() { app := fiber.New() api := app.Group("/api", middleware) // /api v1 := api.Group("/v1", middleware) // /api/v1 v1.Get("/list", handler) // /api/v1/list v1.Get("/user", handler) // /api/v1/user v2 := api.Group("/v2", middleware) // /api/v2 v2.Get("/list", handler) // /api/v2/list v2.Get("/user", handler) // /api/v2/user log.Fatal(app.Listen(":3000")) } ``` :::note Group prefixes follow the same slash-boundary rule as `app.Use`. 
A prefix must either match the full path or stop at a `/`, so `/api` applies to `/api` and `/api/v1` but not `/apiv1`. Parameter markers (for example `:id`, `:id?`, `*`, and `+`) are processed before checking the boundary. ::: Group handlers are optional, so you can also create groups without any middleware: ```go func main() { app := fiber.New() api := app.Group("/api") // /api v1 := api.Group("/v1") // /api/v1 v1.Get("/list", handler) // /api/v1/list v1.Get("/user", handler) // /api/v1/user v2 := api.Group("/v2") // /api/v2 v2.Get("/list", handler) // /api/v2/list v2.Get("/user", handler) // /api/v2/user log.Fatal(app.Listen(":3000")) } ``` :::caution Accessing `/api`, `/api/v1`, or `/api/v2` directly returns a **404** because no route is registered for those exact paths, so add handlers or a custom 404 handler as needed. ::: ## Group Handlers Group handlers can act as routing paths but must call `Next` to continue the flow. ```go func main() { app := fiber.New() handler := func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) } api := app.Group("/api") // /api v1 := api.Group("/v1", func(c fiber.Ctx) error { // middleware for /api/v1 c.Set("Version", "v1") return c.Next() }) v1.Get("/list", handler) // /api/v1/list v1.Get("/user", handler) // /api/v1/user log.Fatal(app.Listen(":3000")) } ``` ================================================ FILE: docs/guide/reverse-proxy.md ================================================ --- id: reverse-proxy title: 🔄 Reverse Proxy Configuration description: >- Learn how to set up reverse proxies like Nginx or Traefik to enable modern HTTP capabilities in your Fiber application, including HTTP/2 and HTTP/3 (QUIC) support. This guide also covers basic reverse proxy configuration and links to external documentation. sidebar_position: 4 --- ## Reverse Proxies Running Fiber behind a reverse proxy is a common production setup.
Reverse proxies can handle: - **HTTPS/TLS termination** (offloading SSL certificates) - **Protocol upgrades** (HTTP/2, HTTP/3 support) - **Request routing & load balancing** - **Caching & compression** - **Security features** (rate limiting, WAF, DDoS mitigation) Some Fiber features (like [`SendEarlyHints`](../api/ctx.md#sendearlyhints)) work most reliably over **HTTP/2 or newer**, which is easiest to enable using a reverse proxy. ### Popular Reverse Proxies - [Nginx](https://nginx.org/) - [Traefik](https://traefik.io/) - [HAProxy](https://www.haproxy.com/) - [Caddy](https://caddyserver.com/) ## Getting the Real Client IP Address When your Fiber application is behind a reverse proxy, the TCP connection comes from the proxy server, not the actual client. To get the real client IP address, you need to configure Fiber to read it from proxy headers like `X-Forwarded-For`. :::warning Security Warning Proxy headers can be easily spoofed by malicious clients. **Always** configure `TrustProxyConfig` to validate the proxy IP address, otherwise attackers can forge headers to bypass IP-based access controls, rate limiting, or geolocation features. In addition, your reverse proxy should be configured to **set or overwrite** the forwarding header you choose (for example, `X-Forwarded-For`) based on the real client connection, or to use its real IP / PROXY protocol features. Do not simply pass through client-supplied forwarding headers, or `c.IP()` may still be controlled by an attacker even when `TrustProxyConfig` is correct. ::: ### Configuration To enable reading the client IP from proxy headers, you must configure **three settings**: 1. **`TrustProxy`** - Enable proxy header trust (must be `true`) 2. **`ProxyHeader`** - Specify which header contains the client IP 3.
**`TrustProxyConfig`** - Define which proxy IPs to trust ```go title="Example - App Behind Nginx" app := fiber.New(fiber.Config{ // Enable proxy support TrustProxy: true, // Read client IP from X-Forwarded-For header ProxyHeader: fiber.HeaderXForwardedFor, // Trust requests from your Nginx proxy TrustProxyConfig: fiber.TrustProxyConfig{ // Option 1: Trust specific proxy IPs Proxies: []string{"10.10.0.58", "192.168.1.0/24"}, // Option 2: Or trust all private IPs (useful for internal load balancers) // Private: true, }, }) ``` ### Common Proxy Headers Different proxies use different headers: | Proxy/Service | Recommended Header | Config Value | |---------------|-------------------|--------------| | Nginx, HAProxy, Apache | X-Forwarded-For | `fiber.HeaderXForwardedFor` | | Cloudflare | CF-Connecting-IP | `"CF-Connecting-IP"` | | Fastly | Fastly-Client-IP | `"Fastly-Client-IP"` | | Generic | X-Real-IP | `"X-Real-IP"` | ### TrustProxyConfig Options The `TrustProxyConfig` struct provides multiple ways to specify trusted proxies: ```go TrustProxyConfig: fiber.TrustProxyConfig{ // Specific IPs or CIDR ranges Proxies: []string{ "10.10.0.58", // Single IP "192.168.0.0/24", // CIDR range "2001:db8::/32", // IPv6 range }, // Or use convenience flags: Loopback: true, // Trust 127.0.0.0/8, ::1/128 LinkLocal: true, // Trust 169.254.0.0/16, fe80::/10 Private: true, // Trust 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, fc00::/7 UnixSocket: true, // Trust Unix domain socket connections }, ``` ### Complete Example with Nginx ```nginx title="nginx.conf" server { listen 443 ssl; http2 on; server_name example.com; ssl_certificate /etc/ssl/certs/example.crt; ssl_certificate_key /etc/ssl/private/example.key; location / { proxy_pass http://127.0.0.1:3000; proxy_http_version 1.1; proxy_set_header Connection ""; proxy_set_header Host $host; proxy_set_header X-Forwarded-For $remote_addr; proxy_set_header X-Forwarded-Proto $scheme; } } ``` This configuration overwrites `X-Forwarded-For` with `$remote_addr` instead of appending to a client-supplied value (as `$proxy_add_x_forwarded_for` would), so a client cannot inject a spoofed address for `c.IP()` to report. If Nginx itself sits behind another trusted proxy or CDN, use its `real_ip` module to recover the client address first. ```go title="main.go" package main import
( "log" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New(fiber.Config{ TrustProxy: true, ProxyHeader: fiber.HeaderXForwardedFor, EnableIPValidation: true, TrustProxyConfig: fiber.TrustProxyConfig{ // Trust localhost since Nginx is on the same machine Loopback: true, }, }) app.Get("/", func(c fiber.Ctx) error { // This will now return the real client IP from X-Forwarded-For // instead of 127.0.0.1 return c.SendString("Your IP: " + c.IP()) }) log.Fatal(app.Listen(":3000")) } ``` ### Testing Your Configuration You can verify your configuration is working: ```go app.Get("/debug", func(c fiber.Ctx) error { return c.JSON(fiber.Map{ "c.IP()": c.IP(), // Should show real client IP "X-Forwarded-For": c.Get("X-Forwarded-For"), // Raw header value "IsProxyTrusted": c.IsProxyTrusted(), // Should be true "RemoteIP": c.RequestCtx().RemoteIP().String(), // Proxy IP }) }) ``` ## Enabling HTTP/2 Popular choices include Nginx and Traefik.
Nginx Example See the [Complete Example with Nginx](#complete-example-with-nginx) above for a full configuration with HTTP/2 enabled.
Traefik Example ```yaml title="traefik.yaml" entryPoints: websecure: address: ":443" http: routers: app: rule: "Host(`example.com`)" entryPoints: - websecure service: app tls: {} services: app: loadBalancer: servers: - url: "http://127.0.0.1:3000" ``` With this configuration, Traefik terminates TLS and serves your app over HTTP/2.
## HTTP/3 (QUIC) Support Early Hints (103 responses) are defined for HTTP and can be delivered over HTTP/1.1 as well as HTTP/2 and HTTP/3. In practice, browsers process 103 responses most reliably over HTTP/2 and HTTP/3. Many reverse proxies also support HTTP/3 (QUIC): - **Nginx** - **Traefik** Enabling HTTP/3 is optional but can provide lower latency and improved performance for clients that support it. If you enable HTTP/3, your Early Hints responses will still work as expected. For more details, see the official documentation: - [Nginx QUIC / HTTP/3](https://nginx.org/en/docs/quic.html) - [Traefik HTTP/3](https://doc.traefik.io/traefik/reference/install-configuration/entrypoints/#http3) ================================================ FILE: docs/guide/routing.md ================================================ --- id: routing title: 🔌 Routing description: >- Routing refers to how an application's endpoints (URIs) respond to client requests. sidebar_position: 1 toc_max_heading_level: 4 --- import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; import RoutingHandler from './../partials/routing/handler.md'; ## Handlers ## Automatic HEAD routes Fiber automatically registers a `HEAD` route for every `GET` route you add. The generated handler chain mirrors the `GET` chain, so `HEAD` requests reuse middleware, status codes, and headers while the response body is suppressed. ```go title="GET handlers automatically expose HEAD" app := fiber.New() app.Get("/users/:id", func(c fiber.Ctx) error { c.Set("X-User", c.Params("id")) return c.SendStatus(fiber.StatusOK) }) // HEAD /users/:id now returns the same headers and status without a body.
``` You can still register dedicated `HEAD` handlers—even with auto-registration enabled—and Fiber replaces the generated route so your implementation wins: ```go title="Override the generated HEAD handler" app.Head("/users/:id", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusNoContent) }) ``` To opt out globally, start the app with `DisableHeadAutoRegister`: ```go title="Disable automatic HEAD registration" handler := func(c fiber.Ctx) error { c.Set("X-User", c.Params("id")) return c.SendStatus(fiber.StatusOK) } app := fiber.New(fiber.Config{DisableHeadAutoRegister: true}) app.Get("/users/:id", handler) // HEAD /users/:id now returns 405 unless you add it manually. ``` Auto-generated `HEAD` routes participate in every router scope, including `Group` hierarchies, mounted sub-apps, parameterized and wildcard paths, and static file helpers. They also appear in route listings such as `app.Stack()` so tooling sees both the `GET` and `HEAD` entries. ## Paths A route path paired with an HTTP method defines an endpoint. It can be a plain **string** or a **pattern**. ### Examples of route paths based on strings ```go // This route path will match requests to the root route, "/": app.Get("/", func(c fiber.Ctx) error { return c.SendString("root") }) // This route path will match requests to "/about": app.Get("/about", func(c fiber.Ctx) error { return c.SendString("about") }) // This route path will match requests to "/random.txt": app.Get("/random.txt", func(c fiber.Ctx) error { return c.SendString("random.txt") }) ``` As with the Express.js framework, the order in which routes are declared matters. Routes are evaluated sequentially, so more specific paths should appear before those with variables. :::info Place routes with variable parameters after fixed paths to avoid unintended matches. ::: ## Parameters Route parameters are dynamic segments in a path, either named or unnamed, used to capture values from the URL. 
Retrieve them with the [Params](../api/ctx.md#params) function using the parameter name or, for unnamed parameters, the wildcard (`*`) or plus (`+`) symbol with an index. The characters `:`, `+`, and `*` introduce parameters. Use `*` or `+` to capture segments greedily. You can define optional parameters by appending `?` to a named segment. The `+` sign is greedy and required, while `*` acts as an optional greedy wildcard. ### Example of defining routes with route parameters ```go // Parameters app.Get("/user/:name/books/:title", func(c fiber.Ctx) error { fmt.Fprintf(c, "%s\n", c.Params("name")) fmt.Fprintf(c, "%s\n", c.Params("title")) return nil }) // Plus - greedy - not optional app.Get("/user/+", func(c fiber.Ctx) error { return c.SendString(c.Params("+")) }) // Optional parameter app.Get("/user/:name?", func(c fiber.Ctx) error { return c.SendString(c.Params("name")) }) // Wildcard - greedy - optional app.Get("/user/*", func(c fiber.Ctx) error { return c.SendString(c.Params("*")) }) // This route path will match requests to "/v1/some/resource/name:customVerb", since the parameter character is escaped app.Get(`/v1/some/resource/name\:customVerb`, func(c fiber.Ctx) error { return c.SendString("Hello, Community") }) ``` :::info The hyphen \(`-`\) and dot \(`.`\) are treated literally, so you can combine them with route parameters. ::: :::info Escape special parameter characters with a backslash (`\`) to treat them literally. This technique is useful for custom methods like those in the [Google API Design Guide](https://cloud.google.com/apis/design/custom_methods). Write such routes as raw string literals (in backticks) so a single `\` is enough; inside a double-quoted Go string the backslash itself must be escaped as `\\`.
::: ```go // http://localhost:3000/plantae/prunus.persica app.Get("/plantae/:genus.:species", func(c fiber.Ctx) error { fmt.Fprintf(c, "%s.%s\n", c.Params("genus"), c.Params("species")) return nil // prunus.persica }) ``` ```go // http://localhost:3000/flights/LAX-SFO app.Get("/flights/:from-:to", func(c fiber.Ctx) error { fmt.Fprintf(c, "%s-%s\n", c.Params("from"), c.Params("to")) return nil // LAX-SFO }) ``` Fiber's router detects when these characters belong to the literal path and handles them accordingly. ```go // http://localhost:3000/shop/product/color:blue/size:xs app.Get("/shop/product/color::color/size::size", func(c fiber.Ctx) error { fmt.Fprintf(c, "%s:%s\n", c.Params("color"), c.Params("size")) return nil // blue:xs }) ``` You can chain multiple named or unnamed parameters—including wildcard and plus segments—giving the router greater flexibility. ```go // GET /@v1 // Params: "sign" -> "@", "param" -> "v1" app.Get("/:sign:param", handler) // GET /api-v1 // Params: "name" -> "v1" app.Get("/api-:name", handler) // GET /customer/v1/cart/proxy // Params: "*1" -> "customer/", "*2" -> "/cart" app.Get("/*v1*/proxy", handler) // GET /v1/brand/4/shop/blue/xs // Params: "*1" -> "brand/4", "*2" -> "blue/xs" app.Get("/v1/*/shop/*", handler) ``` Fiber's routing is inspired by Express but intentionally omits regular expression routes due to their performance cost. You can try similar patterns using the Express route tester (v0.1.7). ### Constraints Route constraints are evaluated after the incoming URL has matched a route and its path has been split into parameter values. The feature was introduced in `v2.37.0` and inspired by [.NET Core](https://docs.microsoft.com/en-us/aspnet/core/fundamentals/routing?view=aspnetcore-6.0#route-constraints). :::caution Constraints aren't validation for parameters. If a parameter value does not satisfy its constraints, the route is skipped as if it had not matched, so the request typically ends in a **404 Not Found** response.
::: | Constraint | Example | Example matches | | ----------------- | -------------------------------- | ------------------------------------------------------------------------------------------- | | int | `:id<int>` | 123456789, -123456789 | | bool | `:active<bool>` | true,false | | guid | `:id<guid>` | CD2C1638-1638-72D5-1638-DEADBEEF1638 | | float | `:weight<float>` | 1.234, -1001.01e8 | | minLen(value) | `:username<minLen(4)>` | Test (must be at least 4 characters) | | maxLen(value) | `:filename<maxLen(8)>` | MyFile (must be no more than 8 characters) | | len(length) | `:filename<len(12)>` | somefile.txt (exactly 12 characters) | | min(value) | `:age<min(18)>` | 19 (Integer value must be at least 18) | | max(value) | `:age<max(120)>` | 91 (Integer value must be no more than 120) | | range(min,max) | `:age<range(18,120)>` | 91 (Integer value must be at least 18 but no more than 120) | | alpha | `:name<alpha>` | Rick (String must consist of one or more alphabetical characters, a-z and case-insensitive) | | datetime | `:dob<datetime(2006\\-01\\-02)>` | 2005-11-01 | | regex(expression) | `:date<regex(\d{4}-\d{2}-\d{2})>` | 2022-08-27 (Must match regular expression) | #### Examples ```go app.Get("/:test<min(5)>", func(c fiber.Ctx) error { return c.SendString(c.Params("test")) }) // curl -X GET http://localhost:3000/12 // 12 // curl -X GET http://localhost:3000/1 // Not Found ``` You can use `;` for multiple constraints. ```go app.Get("/:test<min(100);maxLen(5)>", func(c fiber.Ctx) error { return c.SendString(c.Params("test")) }) // curl -X GET http://localhost:3000/120000 // Not Found // curl -X GET http://localhost:3000/1 // Not Found // curl -X GET http://localhost:3000/250 // 250 ``` Fiber compiles regex constraints once when the route is registered, so no regex compilation happens per request (the pattern is still matched against each request's parameter value).
```go app.Get(`/:date<regex(\d{4}-\d{2}-\d{2})>`, func(c fiber.Ctx) error { return c.SendString(c.Params("date")) }) // curl -X GET http://localhost:3000/125 // Not Found // curl -X GET http://localhost:3000/test // Not Found // curl -X GET http://localhost:3000/2022-08-27 // 2022-08-27 ``` :::caution When using the datetime constraint, escape routing characters in the layout (`*`, `+`, `?`, `:`, `/`, `<`, `>`, `;`, `(`, `)`) with `\\` in double-quoted strings (or a single `\` in raw string literals) to avoid misparsing. ::: #### Optional Parameter Example You can impose constraints on optional parameters as well. ```go app.Get("/:test<int>?", func(c fiber.Ctx) error { return c.SendString(c.Params("test")) }) // curl -X GET http://localhost:3000/42 // 42 // curl -X GET http://localhost:3000/ // // curl -X GET http://localhost:3000/7.0 // Not Found ``` #### Custom Constraint Custom constraints can be added to Fiber using the `app.RegisterCustomConstraint` method. Your constraints have to be compatible with the `CustomConstraint` interface. :::caution Custom constraints can override built-in constraints: if a custom constraint has the same name as a built-in one, the custom constraint is used instead. This gives you more flexibility in defining route parameter constraints. ::: Add external constraints when you need stricter rules, such as verifying that a parameter is a valid ULID. ```go // CustomConstraint is an interface for custom constraints type CustomConstraint interface { // Name returns the name of the constraint. // This name is used in the constraint matching. Name() string // Execute executes the constraint. // It returns true if the constraint is matched and right. // param is the parameter value to check. // args are the constraint arguments.
Execute(param string, args ...string) bool } ``` You can check the example below: ```go type UlidConstraint struct { fiber.CustomConstraint } func (*UlidConstraint) Name() string { return "ulid" } func (*UlidConstraint) Execute(param string, args ...string) bool { _, err := ulid.Parse(param) return err == nil } func main() { app := fiber.New() app.RegisterCustomConstraint(&UlidConstraint{}) app.Get("/login/:id<ulid>", func(c fiber.Ctx) error { return c.SendString("...") }) app.Listen(":3000") // /login/01HK7H9ZE5BFMK348CPYP14S0Z -> 200 // /login/12345 -> 404 } ``` ## Middleware Functions that are designed to make changes to the request or response are called **middleware functions**. [Next](../api/ctx.md#next) is a **Fiber** router function that, when called, executes the **next** function that **matches** the current route. ### Example of a middleware function ```go app.Use(func(c fiber.Ctx) error { // Set a custom header on all responses: c.Set("X-Custom-Header", "Hello, World") // Go to next middleware: return c.Next() }) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) ``` The path passed to `Use` is a **mount** (or **prefix**) path: the middleware only runs for requests whose path begins with it. :::note Prefix matches must now end at a slash boundary (or be an exact match). For example, `/api` runs for `/api` and `/api/users` but no longer for `/apiv2`. Parameter tokens such as `:name`, `:name?`, `*`, and `+` are still expanded before this boundary check runs. ::: ### Constraints on Adding Routes Dynamically :::caution Adding routes dynamically after the application has started is not supported due to design and performance considerations. Make sure to define all your routes before the application starts. ::: ## Grouping If you have many endpoints, you can organize your routes using `Group`.
```go func main() { app := fiber.New() api := app.Group("/api", middleware) // /api v1 := api.Group("/v1", middleware) // /api/v1 v1.Get("/list", handler) // /api/v1/list v1.Get("/user", handler) // /api/v1/user v2 := api.Group("/v2", middleware) // /api/v2 v2.Get("/list", handler) // /api/v2/list v2.Get("/user", handler) // /api/v2/user log.Fatal(app.Listen(":3000")) } ``` More information is available in our [Grouping Guide](./grouping.md). ================================================ FILE: docs/guide/templates.md ================================================ --- id: templates title: 📝 Templates description: Fiber supports server-side template engines. sidebar_position: 3 --- import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; Templates render dynamic content without requiring a separate frontend framework. ## Template Engines Fiber accepts a custom template engine during app initialization. ```go app := fiber.New(fiber.Config{ // Provide a template engine Views: engine, // Default layout applied to every view, overridden by the layouts argument of Render() ViewsLayout: "layouts/main", // Enables/Disables access to `ctx.Locals()` entries in rendered views // (defaults to false) PassLocalsToViews: false, }) ``` ### Supported Engines Fiber maintains a [templates](https://docs.gofiber.io/template) package that wraps several engines: * [ace](https://docs.gofiber.io/template/ace/) * [amber](https://docs.gofiber.io/template/amber/) * [django](https://docs.gofiber.io/template/django/) * [handlebars](https://docs.gofiber.io/template/handlebars) * [html](https://docs.gofiber.io/template/html) * [jet](https://docs.gofiber.io/template/jet) * [mustache](https://docs.gofiber.io/template/mustache) * [pug](https://docs.gofiber.io/template/pug) * [slim](https://docs.gofiber.io/template/slim) :::info Custom engines implement the `Views` interface to work with Fiber.
::: ```go title="Views interface" type Views interface { // Fiber executes Load() on app initialization to load/parse the templates Load() error // Writes the named template to the provided writer using the bound data // and optional layouts Render(io.Writer, string, interface{}, ...string) error } ``` :::note The `Render` method powers [**ctx.Render\(\)**](../api/ctx.md#render), which accepts a template name and data to bind. ::: ## Rendering Templates After configuring an engine, handlers call [**ctx.Render\(\)**](../api/ctx.md#render) with a template name and data to send the rendered output. ```go title="Signature" func (c Ctx) Render(name string, bind Map, layouts ...string) error ``` :::info By default, [**ctx.Render\(\)**](../api/ctx.md#render) renders the template inside the layout configured by `ViewsLayout`. Pass one or more layouts in the `layouts` argument to override it for a single call. ::: ```go app.Get("/", func(c fiber.Ctx) error { return c.Render("index", fiber.Map{ "Title": "Hello, World!", }) }) ``` ```html

{{.Title}}

```
:::caution When `PassLocalsToViews` is enabled, all values set using `ctx.Locals(key, value)` are passed to the template. Use unique keys to avoid collisions. ::: ## Advanced Templating ### Custom Functions Fiber supports adding custom functions to templates. #### AddFunc Adds a global function to all templates. ```go title="Signature" func (e *Engine) AddFunc(name string, fn interface{}) IEngineCore ``` ```go // Add `ToUpper` to engine engine := html.New("./views", ".html") engine.AddFunc("ToUpper", func(s string) string { return strings.ToUpper(s) }) // Initialize Fiber App app := fiber.New(fiber.Config{ Views: engine, }) app.Get("/", func(c fiber.Ctx) error { return c.Render("index", fiber.Map{ "Content": "hello, World!", }) }) ``` ```html

This will be in {{ToUpper "all caps"}}:

{{ToUpper .Content}}

```
#### AddFuncMap Adds a Map of functions (keyed by name) to all templates. ```go title="Signature" func (e *Engine) AddFuncMap(m map[string]interface{}) IEngineCore ``` ```go // Add `ToUpper` to engine engine := html.New("./views", ".html") engine.AddFuncMap(map[string]interface{}{ "ToUpper": func(s string) string { return strings.ToUpper(s) }, }) // Initialize Fiber App app := fiber.New(fiber.Config{ Views: engine, }) app.Get("/", func(c fiber.Ctx) error { return c.Render("index", fiber.Map{ "Content": "hello, world!", }) }) ``` ```html

This will be in {{ToUpper "all caps"}}:

{{ToUpper .Content}}

```
* For more advanced template documentation, please visit the [gofiber/template GitHub Repository](https://github.com/gofiber/template). ## Full Example ```go package main import ( "log" "github.com/gofiber/fiber/v3" "github.com/gofiber/template/html/v2" ) func main() { // Initialize standard Go html template engine engine := html.New("./views", ".html") // If you want to use another engine, // just replace with following: // Create a new engine with django // engine := django.New("./views", ".django") app := fiber.New(fiber.Config{ Views: engine, }) app.Get("/", func(c fiber.Ctx) error { // Render index template return c.Render("index", fiber.Map{ "Title": "Go Fiber Template Example", "Description": "An example template", "Greeting": "Hello, World!", }); }) log.Fatal(app.Listen(":3000")) } ``` ```html {{.Title}}

{{.Title}}

{{.Greeting}}

```
================================================ FILE: docs/guide/utils.md ================================================ --- id: utils title: 🧰 Utils sidebar_position: 8 toc_max_heading_level: 4 --- ## Generics ### Convert Converts a string to a specific type while handling errors and optional defaults. It wraps conversion and fallback logic to keep your code clean and consistent. ```go title="Signature" func Convert[T any](value string, converter func(string) (T, error), defaultValue ...T) (T, error) ``` ```go title="Example" // GET http://example.com/id/bb70ab33-d455-4a03-8d78-d3c1dacae9ff app.Get("/id/:id", func(c fiber.Ctx) error { fiber.Convert(c.Params("id"), uuid.Parse) // UUID(bb70ab33-d455-4a03-8d78-d3c1dacae9ff), nil return nil }) // GET http://example.com/search?id=65f6f54221fb90e6a6b76db7 app.Get("/search", func(c fiber.Ctx) error { fiber.Convert(c.Query("id"), mongo.ParseObjectID) // objectid(65f6f54221fb90e6a6b76db7), nil fiber.Convert(c.Query("id"), uuid.Parse) // uuid.Nil, error(cannot parse given uuid) fiber.Convert(c.Query("id"), uuid.Parse, uuid.New()) // uuid.Parse fails, so the default (a new random UUID) is returned with a nil error. return nil }) // ... ``` ### GetReqHeader Retrieves an HTTP request header as a specific type using generics. ```go title="Signature" func GetReqHeader[V GenericType](c Ctx, key string, defaultValue ...V) V ``` ```go title="Example" app.Get("/search", func(c fiber.Ctx) error { // curl -X GET http://example.com/search -H "X-Request-ID: 12345" -H "X-Request-Name: John" fiber.GetReqHeader[int](c, "X-Request-ID") // => returns 12345 as integer. fiber.GetReqHeader[string](c, "X-Request-Name") // => returns "John" as string. fiber.GetReqHeader[string](c, "unknownParam", "default") // => returns "default" as string. // ... }) ``` ### Locals Reads or writes local values in the request context using generics.
```go title="Signature" // Set a value func Locals[V any](c Ctx, key any, value ...V) V // Get a value func Locals[V any](c Ctx, key any) V ``` ```go title="Example" app.Use("/user/:user/:id", func(c fiber.Ctx) error { // set local values fiber.Locals[string](c, "user", "john") fiber.Locals[int](c, "id", 25) // ... return c.Next() }) app.Get("/user/*", func(c fiber.Ctx) error { // get local values name := fiber.Locals[string](c, "user") // john age := fiber.Locals[int](c, "id") // 25 // ... }) ``` ### Params Retrieves route parameters as a specific type. ```go title="Signature" func Params[V GenericType](c Ctx, key string, defaultValue ...V) V ``` ```go title="Example" app.Get("/user/:user/:id", func(c fiber.Ctx) error { // http://example.com/user/john/25 fiber.Params[int](c, "id") // => returns 25 as integer. fiber.Params[int](c, "unknownParam", 99) // => returns the default 99 as integer. // ... return c.SendString("Hello, " + fiber.Params[string](c, "user")) }) ``` ### Query Retrieves query parameters as a specific type. ```go title="Signature" func Query[V GenericType](c Ctx, key string, defaultValue ...V) V ``` ```go title="Example" app.Get("/search", func(c fiber.Ctx) error { // http://example.com/search?name=john&age=25 fiber.Query[string](c, "name") // => returns "john" fiber.Query[int](c, "age") // => returns 25 as integer. fiber.Query[string](c, "unknownParam", "default") // => returns "default" as string. // ... }) ``` ### RoutePatternMatch Checks whether a given path matches a Fiber route pattern. Useful for testing patterns without registering them. Patterns may contain parameters, wildcards and optional segments. An optional `Config` allows control over case sensitivity and strict routing. 
```go title="Signature" func RoutePatternMatch(path, pattern string, cfg ...Config) bool ``` ```go title="Example" fiber.RoutePatternMatch("/user/john", "/user/:name") // true fiber.RoutePatternMatch( "/User/john", "/user/:name", fiber.Config{CaseSensitive: true}, ) // false ``` ================================================ FILE: docs/guide/validation.md ================================================ --- id: validation title: 🔎 Validation sidebar_position: 5 --- ## Validator package Fiber's [Bind](../api/bind.md#validation) function binds request data to a struct and validates it. ```go title="Basic Example" import "github.com/go-playground/validator/v10" type structValidator struct { validate *validator.Validate } // Validator needs to implement the Validate method func (v *structValidator) Validate(out any) error { return v.validate.Struct(out) } // Set up your validator in the config app := fiber.New(fiber.Config{ StructValidator: &structValidator{validate: validator.New()}, }) // Note: // StructValidator runs only for struct destinations (or pointers to structs). // Binding into maps and other non-struct types skips validation. type User struct { Name string `json:"name" form:"name" query:"name" validate:"required"` Age int `json:"age" form:"age" query:"age" validate:"gte=0,lte=100"` } app.Post("/", func(c fiber.Ctx) error { user := new(User) // Works with all bind methods—Body, Query, Form, etc. 
if err := c.Bind().Body(user); err != nil { // validation errors are returned here return err } return c.JSON(user) }) ``` ```go title="Advanced Validation Example" type User struct { Name string `json:"name" validate:"required,min=3,max=32"` Email string `json:"email" validate:"required,email"` Age int `json:"age" validate:"gte=0,lte=100"` Password string `json:"password" validate:"required,min=8"` Website string `json:"website" validate:"url"` } // Custom validation error messages type UserWithCustomMessages struct { Name string `json:"name" validate:"required,min=3,max=32" message:"Name is required and must be between 3 and 32 characters"` Email string `json:"email" validate:"required,email" message:"Valid email is required"` Age int `json:"age" validate:"gte=0,lte=100" message:"Age must be between 0 and 100"` } app.Post("/user", func(c fiber.Ctx) error { user := new(User) if err := c.Bind().Body(user); err != nil { // Handle validation errors if validationErrors, ok := err.(validator.ValidationErrors); ok { for _, e := range validationErrors { // e.Field() - field name // e.Tag() - validation tag // e.Value() - invalid value // e.Param() - validation parameter return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ "field": e.Field(), "error": e.Error(), }) } } return err } return c.JSON(user) }) ``` ```go title="Custom Validator Example" // Custom validator for password strength type PasswordValidator struct { validate *validator.Validate } func (v *PasswordValidator) Validate(out any) error { if err := v.validate.Struct(out); err != nil { return err } // Custom password validation logic if user, ok := out.(*User); ok { if len(user.Password) < 8 { return errors.New("password must be at least 8 characters") } // Add more password validation rules here } return nil } // Usage app := fiber.New(fiber.Config{ StructValidator: &PasswordValidator{validate: validator.New()}, }) ``` ================================================ FILE: docs/intro.md 
================================================ --- slug: / id: welcome title: 👋 Welcome sidebar_position: 1 --- Welcome to Fiber's online API documentation, complete with examples to help you start building web applications right away! **Fiber** is an [Express](https://github.com/expressjs/express)-inspired **web framework** built on top of [Fasthttp](https://github.com/valyala/fasthttp), the **fastest** HTTP engine for [Go](https://go.dev/doc/). It is designed to facilitate rapid development with **zero memory allocations** and a strong focus on **performance**. These docs cover **Fiber v3**. Looking to practice Fiber concepts hands-on? Check out our [Learning Resources](./extra/learning-resources) for interactive challenges and tutorials. ### Installation First, [download](https://go.dev/dl/) and install Go. Version `1.25` or higher is required. Install Fiber using the [`go get`](https://pkg.go.dev/cmd/go/#hdr-Add_dependencies_to_current_module_and_install_them) command: ```bash go get github.com/gofiber/fiber/v3 ``` ### Zero Allocation Fiber is optimized for **high performance**, meaning values returned from **fiber.Ctx** are **not** immutable by default and **will** be reused across requests. As a rule of thumb, you should use context values only within the handler and **must not** keep any references. Once you return from the handler, any values obtained from the context will be reused in future requests. Here is an example: ```go func handler(c fiber.Ctx) error { // Variable is only valid within this handler result := c.Params("foo") // ... } ``` If you need to persist such values outside the handler, make copies of their **underlying buffer** using the [copy](https://pkg.go.dev/builtin/#copy) builtin. 
Here is an example of persisting a string: ```go func handler(c fiber.Ctx) error { // Variable is only valid within this handler result := c.Params("foo") // Make a copy buffer := make([]byte, len(result)) copy(buffer, result) resultCopy := string(buffer) // Variable is now valid indefinitely // ... } ``` Fiber provides `GetString` and `GetBytes` methods on the app that detach values when `Immutable` is enabled and the data isn't already read-only. If it's disabled, use `utils.CopyString` and `utils.CopyBytes` to allocate only when necessary. ```go app.Get("/:foo", func(c fiber.Ctx) error { // Detach if necessary when Immutable is enabled result := c.App().GetString(c.Params("foo")) // ... }) ``` Alternatively, you can enable the `Immutable` setting. This makes all values returned from the context immutable, allowing you to persist them anywhere. Note that this comes at the cost of performance. ```go app := fiber.New(fiber.Config{ Immutable: true, }) ``` For more information, please refer to [#426](https://github.com/gofiber/fiber/issues/426), [#185](https://github.com/gofiber/fiber/issues/185), and [#3012](https://github.com/gofiber/fiber/issues/3012). ### Hello, World Here is the simplest **Fiber** application you can create: ```go package main import "github.com/gofiber/fiber/v3" func main() { app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) app.Listen(":3000") } ``` ```bash go run server.go ``` Browse to `http://localhost:3000` and you should see `Hello, World!` displayed on the page. ### Basic Routing Routing determines how an application responds to a client request at a particular endpoint—a combination of path and HTTP request method (`GET`, `PUT`, `POST`, etc.). Each route can have **multiple handler functions** that are executed when the route is matched. 
Route definitions follow the structure below: ```go // Function signature app.Method(path string, ...func(fiber.Ctx) error) ``` - `app` is an instance of **Fiber** - `Method` is an [HTTP request method](./api/app#route-handlers): `GET`, `PUT`, `POST`, etc. - `path` is a virtual path on the server - `func(fiber.Ctx) error` is a callback function that receives the [Context](./api/ctx) and is executed when the route is matched #### Simple Route ```go // Respond with "Hello, World!" on root path "/" app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) ``` #### Parameters ```go // GET http://localhost:3000/hello%20world app.Get("/:value", func(c fiber.Ctx) error { return c.SendString("value: " + c.Params("value")) // => Response: "value: hello world" }) ``` #### Optional Parameter ```go // GET http://localhost:3000/john app.Get("/:name?", func(c fiber.Ctx) error { if c.Params("name") != "" { return c.SendString("Hello " + c.Params("name")) // => Response: "Hello john" } return c.SendString("Where is john?") // => Response: "Where is john?" }) ``` #### Wildcards ```go // GET http://localhost:3000/api/user/john app.Get("/api/*", func(c fiber.Ctx) error { return c.SendString("API path: " + c.Params("*")) // => Response: "API path: user/john" }) ``` ### Static Files To serve static files such as **images**, **CSS**, and **JavaScript** files, use the [static middleware](./middleware/static.md).
Use the following code to serve files in a directory named `./public`: ```go package main import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/static" ) func main() { app := fiber.New() app.Use("/", static.New("./public")) app.Listen(":3000") } ``` Now, you can access the files in the `./public` directory via your browser: ```bash http://localhost:3000/hello.html http://localhost:3000/js/jquery.js http://localhost:3000/css/style.css ``` ================================================ FILE: docs/middleware/_category_.json ================================================ { "label": "\uD83E\uDDEC Middleware", "position": 4, "collapsed": true, "link": { "type": "generated-index", "description": "Middleware is a function chained in the HTTP request cycle with access to the Context which it uses to perform a specific action, for example, logging every request or enabling CORS." } } ================================================ FILE: docs/middleware/adaptor.md ================================================ --- id: adaptor --- # Adaptor The `adaptor` package converts between Fiber and `net/http`, letting you reuse handlers, middleware, and requests across both frameworks. :::tip Fiber can register plain `net/http` handlers directly—just pass an `http.Handler`, `http.HandlerFunc`, or `func(http.ResponseWriter, *http.Request)` to any router method and it will be adapted automatically. The adaptor helpers remain valuable when you need to convert middleware, swap handler directions, or transform requests explicitly. ::: :::caution Fiber features are unavailable Even when you register them directly, adapted `net/http` handlers still run with standard library semantics. They don't have access to `fiber.Ctx`, and the compatibility layer comes with additional overhead compared to native Fiber handlers. Use them for interop and legacy scenarios, but prefer Fiber handlers when performance or Fiber-specific APIs matter. 
::: ## Features - Convert `net/http` handlers and middleware to Fiber handlers - Convert Fiber handlers to `net/http` handlers - Convert a Fiber context (`fiber.Ctx`) into an `*http.Request` - Copy values stored in a `context.Context` onto a `fasthttp.RequestCtx` :::note Body size limits when running Fiber from net/http When Fiber is executed from a `net/http` server through `FiberHandler`, `FiberHandlerFunc`, or `FiberApp`, the adaptor enforces the app's configured `BodyLimit`. The app's `BodyLimit` defaults to **4 MiB** if a non-positive value is provided during configuration. Requests exceeding the active limit receive `413 Request Entity Too Large`. ::: ## API Reference | Name | Signature | Description | |-------------------------------|-------------------------------------------------------------------------------|-------------------------------------------------------------------------------| | `HTTPHandler` | `HTTPHandler(h http.Handler) fiber.Handler` | Converts `http.Handler` to `fiber.Handler` | | `HTTPHandlerWithContext` | `HTTPHandlerWithContext(h http.Handler) fiber.Handler` | Converts `http.Handler` to `fiber.Handler`, propagating Fiber's local context | | `HTTPHandlerFunc` | `HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler` | Converts `http.HandlerFunc` to `fiber.Handler` | | `HTTPMiddleware` | `HTTPMiddleware(mw func(http.Handler) http.Handler) fiber.Handler` | Converts `http.Handler` middleware to `fiber.Handler` middleware | | `FiberHandler` | `FiberHandler(h fiber.Handler) http.Handler` | Converts `fiber.Handler` to `http.Handler` | | `FiberHandlerFunc` | `FiberHandlerFunc(h fiber.Handler) http.HandlerFunc` | Converts `fiber.Handler` to `http.HandlerFunc` | | `FiberApp` | `FiberApp(app *fiber.App) http.HandlerFunc` | Converts an entire Fiber app to an `http.HandlerFunc` | | `ConvertRequest` | `ConvertRequest(c fiber.Ctx, forServer bool) (*http.Request, error)` | Converts `fiber.Ctx` into an `*http.Request` | | `LocalContextFromHTTPRequest` |
`LocalContextFromHTTPRequest(r *http.Request) (context.Context, bool)` | Extracts the propagated `context.Context` from an adapted `http.Request` | | `CopyContextToFiberContext` | `CopyContextToFiberContext(context any, requestContext *fasthttp.RequestCtx)` | Copies `context.Context` to `fasthttp.RequestCtx` | --- ## Usage Examples ### 1. Using `net/http` handlers in Fiber (`HTTPHandler`, `HTTPHandlerFunc`) Run standard `net/http` handlers inside Fiber. Fiber can auto-adapt them, or you can explicitly convert them when you want to cache or share the converted handler. ```go package main import ( "fmt" "net/http" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/adaptor" ) func main() { app := fiber.New() // Fiber adapts net/http handlers for you during registration. app.Get("/", http.HandlerFunc(helloHandler)) // You can also convert and reuse the handler manually. cached := adaptor.HTTPHandler(http.HandlerFunc(helloHandler)) app.Get("/cached", cached) // When you already have an http.HandlerFunc, convert it directly. app.Get("/func", adaptor.HTTPHandlerFunc(helloHandler)) app.Listen(":3000") } func helloHandler(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "Hello from net/http!") } ``` ### 2. Using `net/http` middleware with Fiber (`HTTPMiddleware`) Middleware written for `net/http` can run inside Fiber: ```go package main import ( "log" "net/http" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/adaptor" ) func main() { app := fiber.New() // Apply an http middleware in Fiber app.Use(adaptor.HTTPMiddleware(loggingMiddleware)) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello Fiber!") }) app.Listen(":3000") } func loggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Println("Request received") next.ServeHTTP(w, r) }) } ``` ### 3. 
Using Fiber handlers in `net/http` (`FiberHandler`) You can use Fiber handlers from `net/http`: ```go package main import ( "net/http" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/adaptor" ) func main() { // Convert a Fiber handler to an http.Handler http.Handle("/", adaptor.FiberHandler(helloFiber)) // Convert a Fiber handler to an http.HandlerFunc http.HandleFunc("/func", adaptor.FiberHandlerFunc(helloFiber)) http.ListenAndServe(":3000", nil) } func helloFiber(c fiber.Ctx) error { return c.SendString("Hello from Fiber!") } ``` ### 4. Converting Fiber handlers to `http.HandlerFunc` (`FiberHandlerFunc`) When you specifically need an `http.HandlerFunc`, wrap the Fiber handler directly: ```go package main import ( "net/http" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/adaptor" ) func main() { http.HandleFunc("/func-only", adaptor.FiberHandlerFunc(helloFiber)) http.ListenAndServe(":3000", nil) } func helloFiber(c fiber.Ctx) error { return c.SendString("Hello from Fiber!") } ``` ### 5. Running a full Fiber app inside `net/http` (`FiberApp`) You can wrap a full Fiber app inside `net/http`: ```go package main import ( "net/http" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/adaptor" ) func main() { app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello from Fiber!") }) // Run Fiber inside an http server http.ListenAndServe(":3000", adaptor.FiberApp(app)) } ``` ### 6. Converting `fiber.Ctx` to `*http.Request` (`ConvertRequest`) Create an `*http.Request` from a `fiber.Ctx`. 
The `forServer` parameter determines how server-oriented fields are populated: - Use `forServer = true` when the converted request will be passed into a `net/http` handler (sets `RequestURI`, `RemoteAddr`, and `TLS` fields for server-side handling) - Use `forServer = false` when creating a request for client-side use (e.g., making an outbound HTTP request with `http.Client`) ```go package main import ( "net/http" "net/http/httptest" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/adaptor" ) func main() { app := fiber.New() app.Get("/request", handleRequest) app.Listen(":3000") } func handleRequest(c fiber.Ctx) error { // Use forServer = true when passing to a net/http handler httpReq, err := adaptor.ConvertRequest(c, true) if err != nil { return err } // Pass the request to a net/http handler. recorder := httptest.NewRecorder() http.DefaultServeMux.ServeHTTP(recorder, httpReq) return c.SendString("Converted Request URL: " + httpReq.URL.String()) } ``` ### 7. Passing Fiber user context into `net/http` This example shows a realistic flow: a Fiber middleware sets a request-scoped `context.Context` (with a `request_id`) on the Fiber context, then an adapted `net/http` handler retrieves it via `LocalContextFromHTTPRequest`. ```go package main import ( "context" "fmt" "net/http" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/adaptor" ) type ctxKey string const requestIDKey ctxKey = "request_id" func main() { app := fiber.New() // 1) Create a request-scoped context in Fiber (e.g., request id, auth claims, trace span). app.Use(func(c fiber.Ctx) error { reqID := c.Get("X-Request-ID") ctx := context.WithValue(context.Background(), requestIDKey, reqID) // Fiber stores request-scoped context as "user context". c.SetContext(ctx) return c.Next() }) // 2) Run a standard net/http handler with Fiber's user context propagated.
app.Get("/hello", adaptor.HTTPHandlerWithContext(http.HandlerFunc(handleRequest))) app.Listen(":3000") } func handleRequest(w http.ResponseWriter, r *http.Request) { ctx, ok := adaptor.LocalContextFromHTTPRequest(r) if !ok || ctx == nil { http.Error(w, "missing propagated context", http.StatusInternalServerError) return } reqID, _ := ctx.Value(requestIDKey).(string) fmt.Fprintf(w, "Hello from net/http (request_id=%s)\n", reqID) } ``` ### 8. Copying context values onto `fasthttp.RequestCtx` (`CopyContextToFiberContext`) `CopyContextToFiberContext` copies values stored in a `context.Context` onto a `fasthttp.RequestCtx`. The function is marked deprecated in code because it uses reflection and unsafe operations—prefer explicit parameter passing when possible. When you do need it, call it immediately after you add values to the `net/http` context so Fiber can read them via `c.Context()`: ```go package main import ( "context" "net/http" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/adaptor" ) type contextKey string func main() { app := fiber.New() app.Use(func(c fiber.Ctx) error { // Convert the Fiber context to an http.Request so we can attach context values. httpReq, err := adaptor.ConvertRequest(c, true) if err != nil { return err } // Add context data and push it back to the Fiber context. 
enriched := httpReq.WithContext(context.WithValue(httpReq.Context(), contextKey("requestID"), "req-123")) adaptor.CopyContextToFiberContext(enriched.Context(), c.RequestCtx()) return c.Next() }) app.Get("/", func(c fiber.Ctx) error { if id, ok := c.Context().Value(contextKey("requestID")).(string); ok { return c.SendString("Request ID: " + id) } return c.SendStatus(fiber.StatusNotFound) }) app.Listen(":3000") } ``` --- ## Summary The `adaptor` package lets Fiber and `net/http` interoperate so you can: - Convert handlers and middleware in both directions - Run Fiber apps inside `net/http` - Convert `fiber.Ctx` to `http.Request` - Propagate Fiber's user context into adapted `net/http` handlers This makes it straightforward to integrate Fiber with existing Go projects or migrate between frameworks. ================================================ FILE: docs/middleware/basicauth.md ================================================ --- id: basicauth --- # BasicAuth Basic Authentication middleware for [Fiber](https://github.com/gofiber/fiber) that provides HTTP basic auth. It calls the next handler for valid credentials and returns [`401 Unauthorized`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/401) for missing or invalid credentials, [`400 Bad Request`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/400) for malformed `Authorization` headers, or [`431 Request Header Fields Too Large`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/431) when the header exceeds size limits. Credentials may omit Base64 padding as permitted by RFC 7235's `token68` syntax. The default unauthorized response includes the header `WWW-Authenticate: Basic realm="Restricted", charset="UTF-8"`, sets `Cache-Control: no-store`, and adds a `Vary: Authorization` header. Only the `UTF-8` charset is supported; any other value will panic. 
## Signatures ```go func New(config Config) fiber.Handler func UsernameFromContext(ctx any) string ``` `UsernameFromContext` accepts a `fiber.CustomCtx`, `fiber.Ctx`, a `*fasthttp.RequestCtx`, or a `context.Context`. ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/basicauth" ) ``` Once your Fiber app is initialized, choose one of the following approaches: ```go // Provide a minimal config app.Use(basicauth.New(basicauth.Config{ Users: map[string]string{ // "doe" hashed using SHA-256 "john": "{SHA256}eZ75KhGvkY4/t0HfQpNPO1aO0tk6wd908bjUGieTKm8=", // "123456" hashed using bcrypt "admin": "$2a$10$gTYwCN66/tBRoCr3.TXa1.v1iyvwIF7GRBqxzv7G.AHLMt/owXrp.", }, })) // Or extend your config for customization app.Use(basicauth.New(basicauth.Config{ Users: map[string]string{ // "doe" hashed using SHA-256 "john": "{SHA256}eZ75KhGvkY4/t0HfQpNPO1aO0tk6wd908bjUGieTKm8=", // "123456" hashed using bcrypt "admin": "$2a$10$gTYwCN66/tBRoCr3.TXa1.v1iyvwIF7GRBqxzv7G.AHLMt/owXrp.", }, Realm: "Forbidden", Authorizer: func(user, pass string, c fiber.Ctx) bool { // custom validation logic return (user == "john" || user == "admin") }, Unauthorized: func(c fiber.Ctx) error { return c.SendFile("./unauthorized.html") }, })) ``` ### Password hashes Passwords must be supplied in pre-hashed form. The middleware detects the hashing algorithm from a prefix: - `"{SHA512}"` or `"{SHA256}"` followed by a base64-encoded digest - standard bcrypt strings beginning with `$2` If no prefix is present, the value is interpreted as a SHA-256 digest encoded in hex or base64. Plaintext passwords are rejected. 
#### Generating SHA-256 and SHA-512 passwords Create a digest, encode it in base64, and prefix it with `{SHA256}` or `{SHA512}` before adding it to `Users`: ```bash # SHA-256 printf 'secret' | openssl dgst -binary -sha256 | base64 # SHA-512 printf 'secret' | openssl dgst -binary -sha512 | base64 ``` Include the prefix in your config: ```go Users: map[string]string{ "john": "{SHA256}K7gNU3sdo+OL0wNhqoVWhr3g6s1xYv72ol/pe/Unols=", "admin": "{SHA512}vSsar3708Jvp9Szi2NWZZ02Bqp1qRCFpbcTZPdBhnWgs5WtNZKnvCXdhztmeD2cmW192CF5bDufKRpayrW/isg==", } ``` ## Config | Property | Type | Description | Default | |:----------------|:----------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------| | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | | Users | `map[string]string` | Users maps usernames to **hashed** passwords (e.g. bcrypt, `{SHA256}`). | `map[string]string{}` | | Realm | `string` | Realm is a string to define the realm attribute of BasicAuth. The realm identifies the system to authenticate against and can be used by clients to save credentials. | `"Restricted"` | | Charset | `string` | Charset sent in the `WWW-Authenticate` header. Only `"UTF-8"` is supported (case-insensitive). | `"UTF-8"` | | HeaderLimit | `int` | Maximum allowed length of the `Authorization` header. Requests exceeding this limit are rejected. | `8192` | | Authorizer | `func(string, string, fiber.Ctx) bool` | Authorizer defines a function to check the credentials. It will be called with a username, password, and the current context and is expected to return true or false to indicate approval. | `nil` | | Unauthorized | `fiber.Handler` | Unauthorized defines the response body for unauthorized responses. 
| `nil` | | BadRequest | `fiber.Handler` | BadRequest defines the response for malformed `Authorization` headers. | `nil` | ## Default Config ```go var ConfigDefault = Config{ Next: nil, Users: map[string]string{}, Realm: "Restricted", Charset: "UTF-8", HeaderLimit: 8192, Authorizer: nil, Unauthorized: nil, BadRequest: nil, } ``` ================================================ FILE: docs/middleware/cache.md ================================================ --- id: cache --- # Cache Cache middleware for [Fiber](https://github.com/gofiber/fiber) that intercepts responses and stores the body, `Content-Type`, and status code under a key derived from the request path and method. Special thanks to [@codemicro](https://github.com/codemicro/fiber-cache) for contributing this middleware to Fiber core. By default, cached responses expire after five minutes and the middleware stores up to 1 MB of response bodies. Request directives - `Cache-Control: no-cache` returns the latest response while still caching it, so the status is always `miss`. - `Cache-Control: no-store` skips caching and always forwards a fresh response. If the response includes a `Cache-Control: max-age` directive, its value sets the cache entry's expiration. Cacheable status codes The middleware caches these RFC 7231 status codes: - `200: OK` - `203: Non-Authoritative Information` - `204: No Content` - `206: Partial Content` - `300: Multiple Choices` - `301: Moved Permanently` - `404: Not Found` - `405: Method Not Allowed` - `410: Gone` - `414: URI Too Long` - `501: Not Implemented` Responses with other status codes result in an `unreachable` cache status. 
For more about cacheable status codes and RFC 7231, see: - [Cacheable - MDN Web Docs](https://developer.mozilla.org/en-US/docs/Glossary/Cacheable) - [RFC7231 - Hypertext Transfer Protocol (HTTP/1.1): Semantics and Content](https://datatracker.ietf.org/doc/html/rfc7231) ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the middleware package: ```go import ( "strconv" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/cache" "github.com/gofiber/utils/v2" ) ``` Once your Fiber app is initialized, register the middleware: ```go // Initialize default config app.Use(cache.New()) // Or extend the config for customization app.Use(cache.New(cache.Config{ Next: func(c fiber.Ctx) bool { return fiber.Query[bool](c, "noCache") }, Expiration: 30 * time.Minute, DisableCacheControl: true, })) ``` Customize the cache key and expiration; the HTTP method is appended automatically: ```go app.Use(cache.New(cache.Config{ ExpirationGenerator: func(c fiber.Ctx, cfg *cache.Config) time.Duration { newCacheTime, _ := strconv.Atoi(c.GetRespHeader("Cache-Time", "600")) return time.Second * time.Duration(newCacheTime) }, KeyGenerator: func(c fiber.Ctx) string { return utils.CopyString(c.Path()) }, })) app.Get("/", func(c fiber.Ctx) error { c.Response().Header.Add("Cache-Time", "6000") return c.SendString("hi") }) ``` Use `CacheInvalidator` to invalidate entries programmatically: ```go app.Use(cache.New(cache.Config{ CacheInvalidator: func(c fiber.Ctx) bool { return fiber.Query[bool](c, "invalidateCache") }, })) ``` `CacheInvalidator` defines custom invalidation rules. Return `true` to invalidate the existing cache entry for the request. In the example above, setting the `invalidateCache` query parameter to `true` invalidates the entry. Cache keys are masked in logs and error messages by default. Set `DisableValueRedaction` to `true` if you explicitly need the raw key for debugging.
## Config | Property | Type | Description | Default | | :------------------- | :--------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------- | | Next | `func(fiber.Ctx) bool` | Next defines a function that is executed before creating the cache entry and can be used to execute the request without cache creation. If an entry already exists, it will be used. If you want to completely bypass the cache functionality in certain cases, you should use the [skip middleware](skip.md). | `nil` | | Expiration | `time.Duration` | Expiration is the time that a cached response will live. | `5 * time.Minute` | | CacheHeader | `string` | CacheHeader is the header on the response header that indicates the cache status, with the possible return values "hit," "miss," or "unreachable." | `X-Cache` | | DisableCacheControl | `bool` | DisableCacheControl omits the `Cache-Control` header when set to `true`. | `false` | | CacheInvalidator | `func(fiber.Ctx) bool` | CacheInvalidator defines a function that is executed before checking the cache entry. It can be used to invalidate the existing cache manually by returning true. | `nil` | | DisableValueRedaction | `bool` | Turns off cache key redaction in logs and error messages when set to `true`. | `false` | | KeyGenerator | `func(fiber.Ctx) string` | KeyGenerator allows you to generate custom keys. The HTTP method is appended automatically. | `func(c fiber.Ctx) string { return utils.CopyString(c.Path()) }` | | ExpirationGenerator | `func(fiber.Ctx, *cache.Config) time.Duration` | ExpirationGenerator allows you to generate custom expiration keys based on the request. 
| `nil` | | Storage | `fiber.Storage` | Storage is used to store the state of the middleware. | In-memory store | | StoreResponseHeaders | `bool` | StoreResponseHeaders allows you to store additional headers generated by subsequent middleware and handlers. | `false` | | MaxBytes | `uint` | MaxBytes is the maximum number of bytes of response bodies simultaneously stored in the cache. | `1 * 1024 * 1024` (~1 MB) | | Methods | `[]string` | Methods specifies the HTTP methods to cache. | `[]string{fiber.MethodGet, fiber.MethodHead}` | ## Default Config ```go var ConfigDefault = Config{ Next: nil, Expiration: 5 * time.Minute, CacheHeader: "X-Cache", DisableCacheControl: false, CacheInvalidator: nil, DisableValueRedaction: false, KeyGenerator: func(c fiber.Ctx) string { return utils.CopyString(c.Path()) }, ExpirationGenerator: nil, StoreResponseHeaders: false, Storage: nil, MaxBytes: 1 * 1024 * 1024, Methods: []string{fiber.MethodGet, fiber.MethodHead}, } ``` ================================================ FILE: docs/middleware/compress.md ================================================ --- id: compress --- # Compress Compression middleware for [Fiber](https://github.com/gofiber/fiber) that automatically compresses responses with `gzip`, `deflate`, `brotli`, or `zstd` based on the client's [Accept-Encoding](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding) header. :::note Bodies smaller than 200 bytes remain uncompressed because compression would likely increase their size and waste CPU cycles. [See the fasthttp source](https://github.com/valyala/fasthttp/blob/497922a21ef4b314f393887e9c6147b8c3e3eda4/http.go#L1713-L1715). ::: ## Behavior - Skips compression for responses that already define `Content-Encoding`, for range requests, `206` responses, status codes without bodies, or when either side sends `Cache-Control: no-transform`.
- `HEAD` requests negotiate compression so `Content-Encoding`, `Content-Length`, `ETag`, and `Vary` reflect the encoded representation, but the body is removed before sending. - When compression runs, strong `ETag` values are recomputed from the compressed bytes; when skipped, `Accept-Encoding` is still merged into `Vary` unless the header is `*` or already present. ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/compress" ) ``` Once your Fiber app is initialized, use the middleware like this: ```go // Initialize default config app.Use(compress.New()) // Or extend your config for customization app.Use(compress.New(compress.Config{ Level: compress.LevelBestSpeed, // 1 })) // Skip middleware for specific routes app.Use(compress.New(compress.Config{ Next: func(c fiber.Ctx) bool { return c.Path() == "/dont_compress" }, Level: compress.LevelBestSpeed, // 1 })) ``` ## Config | Property | Type | Description | Default | |:-------- |:-----------------------|:------------------------------------------------------------|:-------------------| | Next | `func(fiber.Ctx) bool` | Skips this middleware when the function returns `true`. | `nil` | | Level | `Level` | Compression level to use. | `LevelDefault (0)` | Possible values for the "Level" field are: - `LevelDisabled (-1)`: Compression is disabled. - `LevelDefault (0)`: Default compression level. - `LevelBestSpeed (1)`: Best compression speed. - `LevelBestCompression (2)`: Best compression. 
## Default Config ```go var ConfigDefault = Config{ Next: nil, Level: LevelDefault, } ``` ## Constants ```go // Compression levels const ( LevelDisabled = -1 LevelDefault = 0 LevelBestSpeed = 1 LevelBestCompression = 2 ) ``` ================================================ FILE: docs/middleware/cors.md ================================================ --- id: cors --- # CORS CORS (Cross-Origin Resource Sharing) middleware for [Fiber](https://github.com/gofiber/fiber) lets servers control who can access resources and how. It isn't a security feature; it merely relaxes the browser's same-origin policy so cross-origin requests can succeed. Learn more on [MDN](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS). It adds CORS headers to responses, listing allowed origins, methods, and headers, and handles preflight checks. Use the `AllowOrigins` option to define which origins may send cross-origin requests. It accepts single origins, lists, subdomain patterns, wildcards, and supports dynamic validation with `AllowOriginsFunc`. The middleware normalizes `AllowOrigins`, verifies HTTP/HTTPS schemes, and strips trailing slashes. Invalid origins cause a panic. Panic messages and logs redact misconfigured origins by default; set `DisableValueRedaction` to `true` if you need the raw value for troubleshooting. Avoid [common pitfalls](#common-pitfalls) such as using wildcard origins with credentials, overly permissive origin lists, or skipping validation with `AllowOriginsFunc`, as misconfiguration can create security risks. ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/cors" ) ``` Once your Fiber app is initialized, apply the middleware in one of the following ways: ### Basic usage To use the default configuration, simply use `cors.New()`. 
This allows all origins (the wildcard `"*"`), the default methods (`GET`, `POST`, `HEAD`, `PUT`, `DELETE`, `PATCH`), no credentials, and no allowed or exposed headers. ```go app.Use(cors.New()) ``` ### Custom configuration (specific origins, headers, etc.) ```go // Initialize default config app.Use(cors.New()) // Or extend your config for customization app.Use(cors.New(cors.Config{ AllowOrigins: []string{"https://gofiber.io", "https://gofiber.net"}, AllowHeaders: []string{"Origin", "Content-Type", "Accept"}, })) ``` ### Dynamic origin validation You can use `AllowOriginsFunc` to programmatically determine whether to allow a request based on its origin. This is useful when you need to validate origins against a database or other dynamic sources. The function should return `true` if the origin is allowed, and `false` otherwise. Be sure to review the [security considerations](#security-considerations) when using `AllowOriginsFunc`. :::caution Never allow `AllowOriginsFunc` to return `true` for all origins. This is particularly crucial when `AllowCredentials` is set to `true`. Doing so can bypass the restriction of using a wildcard origin with credentials, exposing your application to serious security threats. If you need to allow wildcard origins, use `AllowOrigins` with a wildcard `"*"` instead of `AllowOriginsFunc`. ::: ```go // dbCheckOrigin checks if the origin is in the list of allowed origins in the database. func dbCheckOrigin(db *sql.DB, origin string) bool { // Placeholder query - adjust according to your database schema and query needs query := "SELECT COUNT(*) FROM allowed_origins WHERE origin = $1" var count int err := db.QueryRow(query, origin).Scan(&count) if err != nil { // Handle error (e.g., log it); for simplicity, we return false here return false } return count > 0 } // ... app.Use(cors.New(cors.Config{ AllowOriginsFunc: func(origin string) bool { return dbCheckOrigin(db, origin) }, })) ``` ### Prohibited usage The following example is prohibited because it can expose your application to security risks.
It sets `AllowOrigins` to `"*"` (a wildcard) and `AllowCredentials` to `true`. ```go app.Use(cors.New(cors.Config{ AllowOrigins: []string{"*"}, AllowCredentials: true, })) ``` This will result in the following panic: ```text panic: [CORS] Configuration error: When 'AllowCredentials' is set to true, 'AllowOrigins' cannot contain a wildcard origin '*'. Please specify allowed origins explicitly or adjust 'AllowCredentials' setting. ``` ## Config | Property | Type | Description | Default | |:---------------------|:----------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------| | AllowCredentials | `bool` | AllowCredentials indicates whether or not the response to the request can be exposed when the credentials flag is true. When used as part of a response to a preflight request, this indicates whether or not the actual request can be made using credentials. Note: If true, AllowOrigins cannot be set to a wildcard (`"*"`) to prevent security vulnerabilities. | `false` | | AllowHeaders | `[]string` | AllowHeaders defines a list of request headers that can be used when making the actual request. This is in response to a preflight request. | `[]` | | AllowMethods | `[]string` | AllowMethods defines a list of methods allowed when accessing the resource. This is used in response to a preflight request. | `"GET, POST, HEAD, PUT, DELETE, PATCH"` | | AllowOrigins | `[]string` | AllowOrigins defines a list of origins that may access the resource. This supports subdomain matching, so you can use a value like "https://*.example.com" to allow any subdomain of example.com to submit requests. 
If the special wildcard `"*"` is present in the list, all origins will be allowed. | `["*"]` | | AllowOriginsFunc | `func(origin string) bool` | `AllowOriginsFunc` is a function that dynamically determines whether to allow a request based on its origin. If this function returns `true`, the 'Access-Control-Allow-Origin' response header will be set to the request's 'origin' header. This function is only used if the request's origin doesn't match any origin in `AllowOrigins`. | `nil` | | AllowPrivateNetwork | `bool` | Indicates whether the `Access-Control-Allow-Private-Network` response header should be set to `true`, allowing requests from private networks. This aligns with modern security practices for web applications interacting with private networks. | `false` | | DisableValueRedaction | `bool` | Disables redaction of misconfigured origins and settings in panics and logs. | `false` | | ExposeHeaders | `[]string` | ExposeHeaders defines an allowlist of headers that clients are allowed to access. | `[]` | | MaxAge | `int` | MaxAge indicates how long (in seconds) the results of a preflight request can be cached. If MaxAge is 0, the Access-Control-Max-Age header is not added and the browser uses its default of 5 seconds. To disable caching completely, pass a negative MaxAge value; this sets the Access-Control-Max-Age header to 0. | `0` | | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | :::note If AllowOrigins is a zero value `[]string{}`, and AllowOriginsFunc is provided, the middleware will not default to allowing all origins with the wildcard value "*". Instead, it will rely on the AllowOriginsFunc to dynamically determine whether to allow a request based on its origin. This provides more flexibility and control over which origins are allowed.
::: ## Default Config ```go var ConfigDefault = Config{ Next: nil, AllowOriginsFunc: nil, AllowOrigins: []string{"*"}, DisableValueRedaction: false, AllowMethods: []string{ fiber.MethodGet, fiber.MethodPost, fiber.MethodHead, fiber.MethodPut, fiber.MethodDelete, fiber.MethodPatch, }, AllowHeaders: []string{}, AllowCredentials: false, ExposeHeaders: []string{}, MaxAge: 0, AllowPrivateNetwork: false, } ``` ## Subdomain Matching The `AllowOrigins` configuration supports matching subdomains at any level. This means you can use a value like `"https://*.example.com"` to allow any subdomain of `example.com` to submit requests, including multiple subdomain levels such as `"https://sub.sub.example.com"`. ### Example If you want to allow CORS requests from any subdomain of `example.com`, including nested subdomains, you can configure `AllowOrigins` like so: ```go app.Use(cors.New(cors.Config{ AllowOrigins: []string{"https://*.example.com"}, })) ``` ## How It Works The CORS middleware works by adding the necessary CORS headers to responses from your Fiber application. These headers tell browsers what origins, methods, and headers are allowed for cross-origin requests. When a request arrives, the middleware first checks whether it is a preflight request—a CORS mechanism that determines if the actual request is safe to send. Preflight requests are HTTP OPTIONS requests with specific CORS headers. If the request is preflight, the middleware responds with the appropriate CORS headers and ends the request. :::note Preflight requests are typically sent by browsers before making actual cross-origin requests, especially for methods other than GET, HEAD, or POST, or when custom headers or non-safelisted `Content-Type` values (such as `application/json`) are used. A preflight request is an HTTP OPTIONS request that includes the `Origin`, `Access-Control-Request-Method`, and optionally `Access-Control-Request-Headers` headers. The browser sends this request to check if the server allows the actual request method and headers.
::: If the request is not preflight, the middleware adds the CORS headers to the response and passes the request to the next handler. The actual CORS headers added depend on the configuration of the middleware. The `AllowOrigins` option controls which origins can make cross-origin requests. The middleware handles different `AllowOrigins` configurations as follows: - **Single origin:** If `AllowOrigins` is set to a single origin like `[]string{"http://www.example.com"}`, and that origin matches the origin of the incoming request, the middleware adds the header `Access-Control-Allow-Origin: http://www.example.com` to the response. - **Multiple origins:** If `AllowOrigins` is set to multiple origins like `[]string{"https://example.com", "https://www.example.com"}`, the middleware picks the origin that matches the origin of the incoming request. - **Subdomain matching:** If `AllowOrigins` includes `"https://*.example.com"`, a subdomain like `https://sub.example.com` will be matched and `"https://sub.example.com"` will be used as the header value. This will also match `https://sub.sub.example.com` and so on, but not `https://example.com`. - **Wildcard origin:** If `AllowOrigins` is set to `[]string{"*"}`, the middleware uses that and adds the header `Access-Control-Allow-Origin: *` to the response. In all cases above, except the **Wildcard origin**, the middleware will either add the `Access-Control-Allow-Origin` header to the response matching the origin of the incoming request, or it will not add the header at all if the origin is not allowed. - **Programmatic origin validation:** The middleware also handles the `AllowOriginsFunc` option, which allows you to programmatically determine if an origin is allowed. If `AllowOriginsFunc` returns `true` for an origin, the middleware sets the `Access-Control-Allow-Origin` header to that origin. - **Null origin handling:** The middleware accepts the special literal value `"null"` as a valid origin.
As described in the [MDN CORS guide](https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS#origin), browsers send `"null"` as the origin for certain privacy-sensitive contexts, such as: - Requests from sandboxed iframes - Requests from `file://` URLs - Requests from `data:` URLs - Cross-origin redirects When using `AllowOriginsFunc`, if the function returns `true` for the literal string `"null"`, the middleware will set `Access-Control-Allow-Origin: null` in the response. The `"null"` origin is case-sensitive and must be lowercase. The `AllowMethods` option controls which HTTP methods are allowed. For example, if `AllowMethods` is set to `[]string{"GET", "POST"}`, the middleware adds the header `Access-Control-Allow-Methods: GET, POST` to the response. The `AllowHeaders` option specifies which headers are allowed in the actual request. The middleware sets the `Access-Control-Allow-Headers` response header to the value of `AllowHeaders`. This informs the client which headers it can use in the actual request. The `AllowCredentials` option indicates whether the response to the request can be exposed when the credentials flag is true. If `AllowCredentials` is set to `true`, the middleware adds the header `Access-Control-Allow-Credentials: true` to the response. To prevent security vulnerabilities, `AllowCredentials` cannot be set to `true` if `AllowOrigins` is set to a wildcard (`*`). The `ExposeHeaders` option defines an allowlist of headers that clients are allowed to access. If `ExposeHeaders` is set to `[]string{"X-Custom-Header"}`, the middleware adds the header `Access-Control-Expose-Headers: X-Custom-Header` to the response. The `MaxAge` option indicates how long the results of a preflight request can be cached. If `MaxAge` is set to `3600`, the middleware adds the header `Access-Control-Max-Age: 3600` to the response. The `Vary` header helps caches store the correct response. For simple requests the middleware sets `Vary: Origin` unless all origins are allowed.
Preflight responses add `Vary: Origin, Access-Control-Request-Method, Access-Control-Request-Headers` (and `Access-Control-Request-Private-Network` when enabled and requested). This ensures caches know when to reuse a response and when to revalidate with the server. ## Infrastructure Considerations When deploying Fiber applications behind infrastructure components like CDNs, API gateways, load balancers, or reverse proxies, you have two main options for handling CORS: ### Option 1: Use Infrastructure-Level CORS (Recommended) **For most production deployments, it is often preferable to handle CORS at the infrastructure level** rather than in your Fiber application. This approach offers several advantages: - **Better Performance**: CORS headers are added at the edge, closer to the client - **Reduced Server Load**: Preflight requests are handled without reaching your application - **Centralized Configuration**: Manage CORS policies alongside other infrastructure settings - **Built-in Caching**: Infrastructure providers optimize CORS response caching **Common infrastructure CORS solutions:** - **CDNs**: CloudFront, CloudFlare, Azure CDN - handle CORS at edge locations - **API Gateways**: AWS API Gateway, Google Cloud API Gateway - centralized CORS management - **Load Balancers**: Application Load Balancers with CORS rules - **Reverse Proxies**: Nginx, Apache with CORS modules If using infrastructure-level CORS, **disable Fiber's CORS middleware** to avoid conflicts: ```go // Don't use both - choose one approach // app.Use(cors.New()) // Remove this line when using infrastructure CORS ``` ### Option 2: Application-Level CORS (Fiber Middleware) Use Fiber's CORS middleware when you need: - **Dynamic origin validation** based on application logic - **Fine-grained control** over CORS policies per route - **Integration with application state** (database-driven origins, etc.) 
- **Development environments** where infrastructure CORS isn't available If choosing this approach, ensure that **all CORS headers reach your Fiber application unchanged**. ### Required Headers for CORS Preflight Requests For CORS preflight requests to work correctly, these headers **must not be stripped or modified by caching layers**: - `Origin` - Required to identify the requesting origin - `Access-Control-Request-Method` - Required to identify the HTTP method for the actual request - `Access-Control-Request-Headers` - Optional, contains custom headers the actual request will use - `Access-Control-Request-Private-Network` - Optional, for private network access requests :::warning Critical Preflight Requirement If the `Access-Control-Request-Method` header is missing from an OPTIONS request, Fiber will not recognize it as a CORS preflight request. Instead, it will be treated as a regular OPTIONS request, which typically returns `405 Method Not Allowed` since most applications don't define explicit OPTIONS handlers.
::: ### CORS Response Headers (Set by Fiber) The middleware sets these response headers based on your configuration: **For all CORS requests:** - `Access-Control-Allow-Origin` - Set to the allowed origin or "*" - `Access-Control-Allow-Credentials` - Set to "true" when `AllowCredentials: true` - `Access-Control-Expose-Headers` - Lists headers the client can access - `Vary` - Set to "Origin" (unless wildcard origins are used) **For preflight responses only:** - `Access-Control-Allow-Methods` - Lists allowed HTTP methods - `Access-Control-Allow-Headers` - Lists allowed request headers (or echoes the request's `Access-Control-Request-Headers` when `AllowHeaders` is empty) - `Access-Control-Max-Age` - Cache duration for preflight results (set when MaxAge is non-zero; a negative MaxAge sends 0) - `Access-Control-Allow-Private-Network` - Set to "true" when private network access is allowed - `Vary` - Set to "Access-Control-Request-Method, Access-Control-Request-Headers, Origin" ### Common Infrastructure Issues **CDNs (CloudFront, CloudFlare, etc.)**: - Configure cache policies to forward all CORS headers - Ensure OPTIONS requests are not cached inappropriately or cache them correctly with proper Vary headers - Don't strip or modify CORS request headers **API Gateways**: - Choose either gateway-level CORS OR application-level CORS, not both - If using gateway CORS, disable Fiber's CORS middleware - If forwarding to Fiber, ensure all headers pass through unchanged **Load Balancers/Reverse Proxies**: - Preserve all HTTP headers, especially CORS-related ones - Don't modify or strip `Origin`, `Access-Control-Request-*` headers **WAFs/Security Services**: - Allowlist CORS headers in security rules - Ensure OPTIONS requests with CORS headers aren't blocked ### Debugging CORS Issues Add this middleware **before** your CORS configuration to debug what headers Fiber receives: ```go // Debug middleware to log CORS preflight requests // Only use in development or testing environments app.Use(func(c fiber.Ctx) error { if c.Method() == "OPTIONS" { fmt.Printf("OPTIONS %s\n", c.Path())
fmt.Printf(" Origin: %s\n", c.Get("Origin")) fmt.Printf(" Access-Control-Request-Method: %s\n", c.Get("Access-Control-Request-Method")) fmt.Printf(" Access-Control-Request-Headers: %s\n", c.Get("Access-Control-Request-Headers")) } return c.Next() }) app.Use(cors.New(cors.Config{ AllowOrigins: []string{"https://yourdomain.com"}, AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, })) ``` Test CORS preflight directly with curl: ```bash # Test preflight request curl -X OPTIONS https://your-app.com/api/test \ -H "Origin: https://yourdomain.com" \ -H "Access-Control-Request-Method: POST" \ -H "Access-Control-Request-Headers: Content-Type" \ -v # Test simple CORS request curl -X GET https://your-app.com/api/test \ -H "Origin: https://yourdomain.com" \ -v ``` ### Caching Considerations The middleware sets appropriate `Vary` headers to ensure proper caching: - **Non-wildcard origins**: `Vary: Origin` is set to cache responses per origin - **Preflight requests**: `Vary: Access-Control-Request-Method, Access-Control-Request-Headers, Origin` - **OPTIONS without preflight headers**: `Vary: Origin` to avoid cache poisoning Ensure your infrastructure respects these `Vary` headers for correct caching behavior. 
### Choosing the Right Approach | Scenario | Recommended Approach | |----------|---------------------| | Production with CDN/API Gateway | Infrastructure-level CORS | | Dynamic origin validation needed | Application-level CORS | | Microservices with different CORS policies | Application-level CORS | | Simple static origins | Infrastructure-level CORS | | Development/testing | Application-level CORS | | High traffic applications | Infrastructure-level CORS | :::tip Infrastructure CORS Configuration Most cloud providers offer comprehensive CORS documentation: - [AWS CloudFront CORS](https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/header-caching.html#header-caching-web-cors) - [Google Cloud CORS](https://cloud.google.com/storage/docs/cross-origin) - [Azure CDN CORS](https://docs.microsoft.com/en-us/azure/cdn/cdn-cors) - [CloudFlare CORS](https://developers.cloudflare.com/cache/cache-security/cors/) Configure CORS at the infrastructure level when possible for optimal performance and reduced complexity. ::: ## Security Considerations When configuring CORS, misconfiguration can potentially expose your application to various security risks. Here are some secure configurations and common pitfalls to avoid: ### Secure Configurations - **Specify Allowed Origins**: Instead of using a wildcard (`"*"`), specify the exact domains allowed to make requests. For example, `AllowOrigins: []string{"https://www.example.com", "https://api.example.com"}` ensures only these domains can make cross-origin requests to your application. - **Use Credentials Carefully**: If your application needs to support credentials in cross-origin requests, ensure `AllowCredentials` is set to `true` and specify exact origins in `AllowOrigins`. Do not use a wildcard origin in this case. - **Limit Exposed Headers**: Only allowlist headers that are necessary for the client-side application by setting `ExposeHeaders` appropriately.
This minimizes the risk of exposing sensitive information. ### Common Pitfalls - **Wildcard Origin with Credentials**: Setting `AllowOrigins` to `"*"` (a wildcard) and `AllowCredentials` to `true` is a common misconfiguration. This combination is prohibited because it can expose your application to security risks. - **Overly Permissive Origins**: Specifying too many origins or using overly broad patterns (e.g., `https://*.example.com`) can inadvertently allow malicious sites to interact with your application. Be as specific as possible with allowed origins. - **Inadequate `AllowOriginsFunc` Validation**: When using `AllowOriginsFunc` for dynamic origin validation, ensure the function includes robust checks to prevent unauthorized origins from being accepted. Overly permissive validation can lead to security vulnerabilities. Never allow `AllowOriginsFunc` to return `true` for all origins. This is particularly crucial when `AllowCredentials` is set to `true`. Doing so can bypass the restriction of using a wildcard origin with credentials, exposing your application to serious security threats. If you need to allow wildcard origins, use `AllowOrigins` with a wildcard `"*"` instead of `AllowOriginsFunc`. Remember, the key to secure CORS configuration is specificity and caution. By carefully selecting which origins, methods, and headers are allowed, you can help protect your application from cross-origin attacks. ================================================ FILE: docs/middleware/csrf.md ================================================ --- id: csrf --- # CSRF The CSRF middleware protects against [Cross-Site Request Forgery](https://en.wikipedia.org/wiki/Cross-site_request_forgery) attacks by validating tokens on unsafe HTTP methods such as POST, PUT, and DELETE. It responds with 403 Forbidden when validation fails. 
## Table of Contents - [Quick Start](#quick-start) - [Best Practices & Production Requirements](#best-practices--production-requirements) - [Configuration by Application Type](#configuration-by-application-type) - [Recipes for Common Use Cases](#recipes-for-common-use-cases) - [Using CSRF Tokens](#using-csrf-tokens) - [Security Model](#security-model) - [Token Extractors](#token-extractors) - [Advanced Configuration](#advanced-configuration) - [API Reference](#api-reference) - [Config Properties](#config-properties) - [Error Types](#error-types) - [Constants](#constants) ## Quick Start ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" "github.com/gofiber/fiber/v3/middleware/csrf" ) // Default config (development only) app.Use(csrf.New()) // Production config app.Use(csrf.New(csrf.Config{ CookieName: "__Host-csrf_", CookieSecure: true, CookieHTTPOnly: true, // false for SPAs CookieSameSite: "Lax", CookieSessionOnly: true, Extractor: extractors.FromHeader("X-Csrf-Token"), Session: sessionStore, // Redaction is enabled by default. Set DisableValueRedaction when you must expose tokens or storage keys in diagnostics. // DisableValueRedaction: true, })) ``` ## Best Practices & Production Requirements :::danger Production Requirements - `CookieSecure: true` (HTTPS only) - `CookieSameSite: "Lax"` or `"Strict"` - Use `Session` store for better security ::: 1. **Always use HTTPS** in production 2. **Use sessions** for authenticated applications 3. **Set `CookieSecure: true`** and appropriate SameSite values 4. **Implement XSS protection** alongside CSRF 5. **Regenerate tokens** after auth changes 6. **Use `__Host-` cookie prefix** when possible :::warning BREACH Protection To mitigate BREACH attacks, ensure your pages are served over HTTPS, disable HTTP compression, and implement rate limiting for requests. 
Because the CSRF token is also embedded in rendered pages (for example, in hidden form fields), an attacker may be able to extract it if a page that includes the token is vulnerable to BREACH. ::: ## Configuration by Application Type ### Server-Side Rendered Apps ```go app.Use(csrf.New(csrf.Config{ CookieName: "__Host-csrf_", CookieSecure: true, CookieHTTPOnly: true, // Secure - blocks JavaScript CookieSameSite: "Lax", CookieSessionOnly: true, Extractor: extractors.FromForm("_csrf"), Session: sessionStore, })) ``` ### Single Page Applications (SPAs) ```go app.Use(csrf.New(csrf.Config{ CookieName: "__Host-csrf_", CookieSecure: true, CookieHTTPOnly: false, // Required for JavaScript access to tokens CookieSameSite: "Lax", CookieSessionOnly: true, Extractor: extractors.FromHeader("X-Csrf-Token"), Session: sessionStore, })) ``` :::warning SPA Security Trade-off SPAs require `CookieHTTPOnly: false` to access tokens via JavaScript. This slightly increases XSS risk but is necessary for SPA functionality. ::: ## Recipes for Common Use Cases - **Without Sessions**: [CSRF Recipe](https://github.com/gofiber/recipes/tree/master/csrf) - Simple Double Submit Cookie pattern - **With Sessions**: [CSRF with Session Recipe](https://github.com/gofiber/recipes/tree/master/csrf-with-session) - More secure Synchronizer Token pattern ## Using CSRF Tokens ### Server-Side Forms ```go func formHandler(c fiber.Ctx) error { token := csrf.TokenFromContext(c) c.Type("html") return c.SendString(fmt.Sprintf(` <form method="POST" action="/submit"> <input type="hidden" name="_csrf" value="%s"> <input type="text" name="data"> <button type="submit">Submit</button> </form>
`, token)) } ``` ### Single Page Applications ```go func apiHandler(c fiber.Ctx) error { token := csrf.TokenFromContext(c) return c.JSON(fiber.Map{ "csrf_token": token, "data": "your data", }) } ``` ```javascript // Get CSRF token from cookie function getCsrfToken() { const value = `; ${document.cookie}`; const parts = value.split(`; __Host-csrf_=`); if (parts.length === 2) return parts.pop().split(';').shift(); } // Use with fetch API async function makeRequest(url, data) { const csrfToken = getCsrfToken(); const response = await fetch(url, { method: 'POST', headers: { 'Content-Type': 'application/json', 'X-Csrf-Token': csrfToken }, body: JSON.stringify(data) }); if (!response.ok) { throw new Error(`HTTP ${response.status}: ${response.statusText}`); } return response.json(); } ``` ## Security Model The middleware employs a robust, defense-in-depth strategy to protect against CSRF attacks. The primary defense is token-based validation, which operates in one of two modes depending on your configuration. This is supplemented by a mandatory secondary check on the request's origin. ### Fetch Metadata Guardrails - **Sec-Fetch-Site**: For unsafe methods, the middleware inspects the [`Sec-Fetch-Site`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site) header when present. If the header value is not one of "same-origin", "none", "same-site", or "cross-site", the request is rejected with `ErrFetchSiteInvalid`. If the header is valid or absent, the request proceeds to the standard origin and token validation checks. This provides an early check to block requests with invalid `Sec-Fetch-Site` values, while allowing legitimate same-site and cross-site requests to be validated by the existing mechanisms. ### 1. Token Validation Patterns #### Double Submit Cookie (Default Mode) This is the default pattern, used when a `Session` store is **not** configured. 
It is a "semi-stateless" approach; while it doesn't tie tokens to a specific user session, the server still maintains a record of all validly issued tokens. - **How it Works:** 1. On a user's first visit (or a safe request like `GET`), the middleware generates a unique token. 2. This token is sent to the client in a `Set-Cookie` header. 3. The server also stores this token (in memory by default or in the configured `Storage`). This record confirms that the token is server-generated and still valid, but it is not tied to a specific user. 4. For subsequent unsafe requests (e.g., `POST`, `PUT`), the client must read the token from the cookie and echo it in a different location, such as the `X-Csrf-Token` header. - **Validation:** The middleware validates three things: that the token from the header/form **exactly matches** the token from the cookie, that the token **exists** in the server-side storage, and that it **has not expired**. - **Why it is secure:** Attackers on a malicious domain cannot read the victim's cookie to forge a matching header. They also cannot invent a token because it wouldn't exist in the server's storage registry. #### Synchronizer Token (Session-Based Mode) This is a more secure, stateful pattern that is **automatically enabled** when you provide a `Session` store in the configuration. - **How it Works:** 1. A unique token is generated and stored directly within the user's session data on the server. 2. The token is also sent to the client as a cookie. 3. For unsafe requests, the client sends the token back in a header or form field. - **Validation:** The middleware performs a multi-step validation: 1. It first performs the standard **Double Submit Cookie check**: the token from the header/form must exactly match the token from the cookie. This is a fast and efficient first line of defense, and there is little benefit in skipping it. 2. It then validates that this token exists and is valid within the user's **server-side session**.
This is the authoritative check that ties the token to the authenticated user. - **Why it is more secure:** Tying the token to the server-side session provides the strongest CSRF protection, as the token is then guaranteed to have been generated for the specific user. While browsers automatically send the required cookie, custom API clients must remember to include the cookie with their requests for validation to succeed. ```go // Enable the more secure Synchronizer Token pattern app.Use(csrf.New(csrf.Config{ Session: sessionStore, // Providing a session store activates this mode })) ``` ### 2. Origin & Referer Validation As a crucial second layer of defense, the middleware **always** performs `Origin` and `Referer` header checks for unsafe requests (when the connection is HTTPS). - The request's `Origin` (for cross-origin requests) or `Referer` (for same-origin requests) header **must** match the application's `Host` header or be explicitly allowed in the `TrustedOrigins` list. - This check is performed *in addition* to token validation and provides strong protection because these headers are reliably set by browsers and cannot be programmatically controlled by an attacker from a malicious site. ## Token Extractors This middleware uses the shared `extractors` package for token extraction. For full details on extractor types, chaining, security, and advanced usage, see the [Extractors Guide](../guide/extractors). **Extractor Source Constants:** Extractor source constants (such as `SourceHeader`, `SourceForm`, etc.) are defined in the shared extractors package, not in the CSRF middleware itself. Refer to the Extractors Guide for their definitions and usage. 
### CSRF-Specific Extractor Notes For CSRF protection, prefer secure extraction methods: - **Headers** (`extractors.FromHeader("X-Csrf-Token")`) – Most secure, not logged in URLs - **Form data** (`extractors.FromForm("_csrf")`) – Secure for form submissions - **Avoid URL parameters** – Query/param extractors expose tokens in logs and browser history :::note What about cookies? **Cookies are generally not a secure source for CSRF tokens.** The middleware will panic if you configure an extractor that reads from cookies with the same name as your CSRF cookie. This is because reading the CSRF token from a cookie with the same name as the CSRF cookie defeats CSRF protection entirely, as the extracted token will always match the cookie value, allowing any CSRF attack to succeed. **Advanced usage:** In rare cases, you may securely extract a CSRF token from a cookie if: - You read from a different cookie (not the CSRF cookie itself) - You use multiple cookies for custom validation - You implement custom logic across different cookie sources If you do this, set the extractor’s `Source` to `SourceCookie` so that the middleware can verify the extractor's cookie name differs from your CSRF cookie name. The middleware panics if the two names are the same. **Warning:** Cookie-based extraction is strongly discouraged, as it is easy to misconfigure and creates security risks. Prefer extracting tokens from headers or form fields for robust CSRF protection. See the [Extractors Guide](../guide/extractors#security-considerations) for more details.
::: ### Route-Specific Configuration You can configure different extraction methods for different routes: ```go // API routes - header extraction for AJAX/fetch requests api := app.Group("/api") api.Use(csrf.New(csrf.Config{ Extractor: extractors.FromHeader("X-Csrf-Token"), })) // Form routes - form field extraction for traditional forms forms := app.Group("/forms") forms.Use(csrf.New(csrf.Config{ Extractor: extractors.FromForm("_csrf"), })) ``` ### Custom CSRF Extractors For specialized CSRF token extraction needs, you can create custom extractors. See the [Extractors Guide](../guide/extractors#custom-extraction-logic) for advanced patterns and security notes. :::danger Never Extract from Cookies **NEVER create custom extractors that read from cookies using the same `CookieName` as your CSRF configuration.** This completely defeats CSRF protection by making the extracted token always match the cookie value, allowing any CSRF attack to succeed. ```go // ❌ NEVER DO THIS - Completely defeats CSRF protection badExtractor := extractors.Extractor{ Extract: func(c fiber.Ctx) (string, error) { return c.Cookies("csrf_"), nil // Always passes validation! }, Source: extractors.SourceCustom, // Custom source: the middleware cannot detect the cookie read Key: "csrf_", } // ✅ DO THIS - Extract from different source than cookie app.Use(csrf.New(csrf.Config{ CookieName: "csrf_", Extractor: extractors.FromHeader("X-Csrf-Token"), // Header vs cookie comparison })) ``` The middleware uses the **Double Submit Cookie** pattern – it compares the extracted token against the cookie value. If you configure an extractor whose `Source` is `SourceCookie` and whose key matches `CookieName`, the middleware panics, because the two values would always match and provide zero CSRF protection. A custom extractor like the one above hides the cookie read from the middleware, so it cannot be detected automatically. You must avoid this pattern yourself.
### Fallback Extraction For applications that need to support both AJAX and form submissions: ```go // Try header first (AJAX), fall back to form (traditional forms) app.Use(csrf.New(csrf.Config{ Extractor: extractors.Chain( extractors.FromHeader("X-Csrf-Token"), extractors.FromForm("_csrf"), ), })) ``` :::warning Chaining extractors increases complexity. Use only when you need to support multiple client types. See the [Extractors Guide](../guide/extractors#chain-ordering-strategy) for details and security notes. ::: ## Advanced Configuration ### Trusted Origins ```go app.Use(csrf.New(csrf.Config{ TrustedOrigins: []string{ "https://trusted.example.com", "https://*.example.com", // Wildcard subdomains }, })) ``` ### Custom Error Handler ```go app.Use(csrf.New(csrf.Config{ ErrorHandler: func(c fiber.Ctx, err error) error { accepts := c.Accepts("html", "json") path := c.Path() if accepts == "json" || strings.HasPrefix(path, "/api/") { return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ "error": "Forbidden", }) } return c.Status(fiber.StatusForbidden).Render("error", fiber.Map{ "Title": "Forbidden", "Status": fiber.StatusForbidden, }, "layouts/main") }, })) ``` ### Custom Storage/Database You can use any storage from our [storage](https://github.com/gofiber/storage/) package. ```go storage := sqlite3.New() // From github.com/gofiber/storage/sqlite3/v2 app.Use(csrf.New(csrf.Config{ Storage: storage, })) ``` ### Token Management ```go // Delete token (e.g., on logout) handler := csrf.HandlerFromContext(c) if handler != nil { if err := handler.DeleteToken(c); err != nil { // handle error, e.g. log it } } // With session middleware // Destroying the session will also remove the CSRF token if using session-based CSRF. 
session.Destroy() ``` ## API Reference ```go // Create middleware func New(config ...csrf.Config) fiber.Handler // Get token from context func TokenFromContext(ctx any) string // Get handler from context func HandlerFromContext(ctx any) *csrf.Handler // Delete token func (h *csrf.Handler) DeleteToken(c fiber.Ctx) error ``` `TokenFromContext` and `HandlerFromContext` accept a `fiber.CustomCtx`, `fiber.Ctx`, a `*fasthttp.RequestCtx`, or a `context.Context`. ## Config Properties | Property | Type | Description | Default | |:------------------|:-----------------------------------|:------------------------------------------------------------------------------------------------------------------------------|:-----------------------------| | Next | `func(fiber.Ctx) bool` | Skip middleware when returns true | `nil` | | CookieName | `string` | CSRF cookie name | `"csrf_"` | | CookieDomain | `string` | CSRF cookie domain | `""` | | CookiePath | `string` | CSRF cookie path | `""` | | CookieSecure | `bool` | HTTPS only cookie (**required for production**) | `false` | | CookieHTTPOnly | `bool` | Prevent JavaScript access (**use `false` for SPAs**) | `false` | | CookieSameSite | `string` | SameSite attribute (**use "Lax" or "Strict"**) | `"Lax"` | | CookieSessionOnly | `bool` | Session-only cookie (expires on browser close) | `false` | | IdleTimeout | `time.Duration` | Token expiration time | `30 * time.Minute` | | KeyGenerator | `func() string` | Token generation function | `utils.SecureToken` | | ErrorHandler | `fiber.ErrorHandler` | Custom error handler | `defaultErrorHandler` | | Extractor | `extractors.Extractor` | Token extraction method with metadata | `extractors.FromHeader("X-Csrf-Token")` | | DisableValueRedaction | `bool` | Disables redaction of tokens and storage keys in logs and error messages. 
| `false` | | Session | `*session.Store` | Session store (**recommended for production**) | `nil` | | Storage | `fiber.Storage` | Token storage (overridden by Session) | `nil` | | TrustedOrigins | `[]string` | Trusted origins for cross-origin requests | `[]` | | SingleUseToken | `bool` | Generate new token after each use | `false` | ## Error Types ```go var ( ErrTokenNotFound = errors.New("csrf: token not found") ErrTokenInvalid = errors.New("csrf: token invalid") ErrRefererNotFound = errors.New("csrf: referer header missing") ErrRefererInvalid = errors.New("csrf: referer header invalid") ErrRefererNoMatch = errors.New("csrf: referer does not match host or trusted origins") ErrOriginInvalid = errors.New("csrf: origin header invalid") ErrOriginNoMatch = errors.New("csrf: origin does not match host or trusted origins") ) ``` ## Constants ```go const ( HeaderName = "X-Csrf-Token" ) ``` ================================================ FILE: docs/middleware/earlydata.md ================================================ --- id: earlydata --- # EarlyData The Early Data middleware adds TLS 1.3 "0-RTT" support to [Fiber](https://github.com/gofiber/fiber). When the client and server share a PSK, TLS 1.3 lets the client send data with the first flight and skip the initial round trip. Enable Fiber's `TrustProxy` option before using this middleware to avoid spoofed client headers. When `TrustProxy` is disabled (the default) or the remote address is not trusted by your proxy configuration, requests carrying the `Early-Data` header are rejected with `425 Too Early` to prevent 0-RTT spoofing from direct clients. Enabling early data in a reverse proxy (for example, `ssl_early_data on;` in nginx) makes requests replayable. 
Review these resources before proceeding: - [datatracker](https://datatracker.ietf.org/doc/html/rfc8446#section-8) - [trailofbits](https://blog.trailofbits.com/2019/03/25/what-application-developers-need-to-know-about-tls-early-data-0rtt) By default, the middleware permits early data only for safe methods (`GET`, `HEAD`, `OPTIONS`, `TRACE`) and rejects other requests before your handler runs. Override this behavior with the `AllowEarlyData` option. ## Signatures ```go func New(config ...Config) fiber.Handler func IsEarly(c fiber.Ctx) bool ``` `IsEarly` returns `true` when a request used early data and the middleware allowed it to proceed. ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/earlydata" ) ``` Once your Fiber app is initialized, use the middleware like this: ```go // Initialize default config app.Use(earlydata.New()) // Or extend your config for customization app.Use(earlydata.New(earlydata.Config{ Error: fiber.ErrTooEarly, // ... })) ``` ## Config | Property | Type | Description | Default | |:---------------|:------------------------|:-----------|:-------------------------------------------------------| | Next | `func(fiber.Ctx) bool` | Skip this middleware when the function returns true. | `nil` | | IsEarlyData | `func(fiber.Ctx) bool` | Reports whether the request used early data. | Function checking if "Early-Data" header equals "1" | | AllowEarlyData | `func(fiber.Ctx) bool` | Decides if an early-data request should be allowed. | Function rejecting on unsafe and allowing safe methods | | Error | `error` | Returned when an early-data request is rejected. 
| `fiber.ErrTooEarly` | ## Default Config ```go var ConfigDefault = Config{ IsEarlyData: func(c fiber.Ctx) bool { return c.Get(DefaultHeaderName) == DefaultHeaderTrueValue }, AllowEarlyData: func(c fiber.Ctx) bool { return fiber.IsMethodSafe(c.Method()) }, Error: fiber.ErrTooEarly, } ``` ## Constants ```go const ( DefaultHeaderName = "Early-Data" DefaultHeaderTrueValue = "1" ) ``` ================================================ FILE: docs/middleware/encryptcookie.md ================================================ --- id: encryptcookie --- # Encrypt Cookie The Encrypt Cookie middleware for [Fiber](https://github.com/gofiber/fiber) encrypts cookie values for secure storage. :::note This middleware encrypts cookie values but not cookie names. ::: ## Signatures ```go // Initializes the middleware func New(config ...Config) fiber.Handler // GenerateKey returns a random string of 16, 24, or 32 bytes. // The length of the key determines the AES encryption algorithm used: // 16 bytes for AES-128, 24 bytes for AES-192, and 32 bytes for AES-256-GCM. func GenerateKey(length int) string ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/encryptcookie" ) ``` Once your Fiber app is initialized, register the middleware: ```go // Provide a minimal configuration app.Use(encryptcookie.New(encryptcookie.Config{ Key: "secret-32-character-string", })) // Retrieve the encrypted cookie value app.Get("/", func(c fiber.Ctx) error { return c.SendString("value=" + c.Cookies("test")) }) // Create an encrypted cookie app.Post("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test", Value: "SomeThing", }) return nil }) ``` :::note Use an encoded key of 16, 24, or 32 bytes to select AES‑128, AES‑192, or AES‑256‑GCM. Generate a stable key with `openssl rand -base64 32` or `encryptcookie.GenerateKey(32)` and store it securely. Generating a new key on each startup renders existing cookies unreadable. 
::: ## Config | Property | Type | Description | Default | |:----------|:----------------------------------------------------|:------------------------------------------------------------------------------------------------------|:-----------------------------| | Next | `func(fiber.Ctx) bool` | A function to skip this middleware when it returns true. | `nil` | | Except | `[]string` | Array of cookie keys that should not be encrypted. | `[]` | | Key | `string` | A base64-encoded unique key to encode & decode cookies. Required. The decoded key must be 16, 24, or 32 bytes long. | (No default, required field) | | Encryptor | `func(name, decryptedString, key string) (string, error)` | A custom function to encrypt cookies. | `EncryptCookie` | | Decryptor | `func(name, encryptedString, key string) (string, error)` | A custom function to decrypt cookies. | `DecryptCookie` | ### Encryptor and Decryptor parameters Custom encryptor and decryptor functions receive three arguments: - `name`: The cookie name. The default helpers bind this value as additional authenticated data (AAD) so encrypted values can only be decrypted for the same cookie. - `decryptedString` / `encryptedString`: The cookie payload. An `Encryptor` (such as `EncryptCookie`) receives the plaintext value in `decryptedString` and returns ciphertext, while a `Decryptor` (such as `DecryptCookie`) receives the ciphertext in `encryptedString` and must return the plaintext value. - `key`: The base64-encoded key pulled from the middleware configuration. Use it to derive or validate any encryption keys your implementation requires. ## Default Config ```go var ConfigDefault = Config{ Next: nil, Except: []string{}, Key: "", Encryptor: EncryptCookie, Decryptor: DecryptCookie, } ``` ## Use with Other Middleware That Reads or Modifies Cookies Place `encryptcookie` before middleware that reads or writes cookies. If you use the CSRF middleware, register `encryptcookie` first so the CSRF middleware sees decrypted cookie values. Exclude cookies from encryption by listing them in `Except`.
If a frontend framework such as Angular reads the CSRF token from a cookie, add that name to the `Except` array: ```go app.Use(encryptcookie.New(encryptcookie.Config{ Key: "secret-thirty-2-character-string", Except: []string{csrf.ConfigDefault.CookieName}, // exclude CSRF cookie })) app.Use(csrf.New(csrf.Config{ Extractor: extractors.FromHeader(csrf.HeaderName), CookieSameSite: "Lax", CookieSecure: true, CookieHTTPOnly: false, })) ``` ## Encryption Algorithms The default Encryptor and Decryptor functions use AES-GCM, and the key length selects the variant; a 32-byte key gives `AES-256-GCM`. If you need to use `AES-128` or `AES-192` instead, you can do so by changing the length of the key when calling `encryptcookie.GenerateKey(length)` or by providing a key of one of the following lengths: - AES-128 requires a 16-byte key. - AES-192 requires a 24-byte key. - AES-256 requires a 32-byte key. For example, to generate a key for AES-128: ```go key := encryptcookie.GenerateKey(16) ``` And for AES-192: ```go key := encryptcookie.GenerateKey(24) ``` ================================================ FILE: docs/middleware/envvar.md ================================================ --- id: envvar --- # EnvVar EnvVar middleware for [Fiber](https://github.com/gofiber/fiber) exposes environment variables with configurable options. ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/envvar" ) ``` Once your Fiber app is initialized, configure the middleware as shown: ```go // Initialize default config (exports no variables) app.Use("/expose/envvars", envvar.New()) // Or extend your config for customization app.Use("/expose/envvars", envvar.New( envvar.Config{ ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"}, }), ) ``` :::note Mount the middleware on a path; it cannot be used without one.
::: ## Response Sample response: ```json { "vars": { "someEnvVariable": "someValue", "anotherEnvVariable": "anotherValue" } } ``` ## Config | Property | Type | Description | Default | |:------------|:--------------------|:-----------------------------------------------------------------------------|:--------| | ExportVars | `map[string]string` | ExportVars lists the environment variables to expose. | `nil` | ## Default Config ```go Config{} // Exports no environment variables ``` ================================================ FILE: docs/middleware/etag.md ================================================ --- id: etag --- # ETag ETag middleware for [Fiber](https://github.com/gofiber/fiber) that helps caches validate responses and saves bandwidth by avoiding full retransmits when content is unchanged. ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/etag" ) ``` Once your Fiber app is initialized, use the middleware like this: ```go // Initialize default config app.Use(etag.New()) // GET / -> ETag: "13-1831710635" app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) // Or extend your config for customization app.Use(etag.New(etag.Config{ Weak: true, })) // GET / -> ETag: W/"13-1831710635" app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) ``` Entity tags in requests must be quoted per RFC 9110. For example: ```text If-None-Match: "example-etag" ``` ## Config | Property | Type | Description | Default | |:---------|:------------------------|:-------------------------------------------------------------------------------------------------------------------|:--------| | Weak | `bool` | Enables weak validators. Weak ETags are easier to generate but less reliable for comparisons. 
| `false` | | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | ## Default Config ```go var ConfigDefault = Config{ Next: nil, Weak: false, } ``` ================================================ FILE: docs/middleware/expvar.md ================================================ --- id: expvar --- # ExpVar The ExpVar middleware exposes runtime variables over HTTP in JSON. Using it (e.g., `app.Use(expvarmw.New())`) registers handlers on `/debug/vars`. ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the middleware package (the example below also uses the standard `expvar` and `fmt` packages): ```go import ( "expvar" "fmt" "github.com/gofiber/fiber/v3" expvarmw "github.com/gofiber/fiber/v3/middleware/expvar" ) ``` Once your Fiber app is initialized, use the middleware as shown: ```go var count = expvar.NewInt("count") app.Use(expvarmw.New()) app.Get("/", func(c fiber.Ctx) error { count.Add(1) return c.SendString(fmt.Sprintf("hello expvar count %d", count.Value())) }) ``` Visit `/debug/vars` to see all variables, and append `?r=<pattern>` to show only the variables whose names match the pattern. ```bash curl 127.0.0.1:3000 hello expvar count 1 curl 127.0.0.1:3000/debug/vars { "cmdline": ["xxx"], "count": 1, "expvarHandlerCalls": 33, "expvarRegexpErrors": 0, "memstats": {...} } curl 127.0.0.1:3000/debug/vars?r=c { "cmdline": ["xxx"], "count": 1 } ``` ## Config | Property | Type | Description | Default | |:---------|:------------------------|:--------------------------------------------------------------------|:--------| | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | ## Default Config ```go var ConfigDefault = Config{ Next: nil, } ``` ================================================ FILE: docs/middleware/favicon.md ================================================ --- id: favicon --- # Favicon Favicon middleware for [Fiber](https://github.com/gofiber/fiber) that drops repeated `/favicon.ico` requests or serves a cached icon from memory.
Mount it before your logger to suppress noisy requests and avoid disk reads. It handles only `GET`, `HEAD`, and `OPTIONS` to the configured URL; other methods return `405 Method Not Allowed`. :::note This middleware only serves the default `/favicon.ico` (or a [custom URL](#config)). For multiple icons, use the Static middleware. ::: ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/favicon" ) ``` Once your Fiber app is initialized, use the middleware like this: ```go // Initialize default config app.Use(favicon.New()) // Or extend your config for customization app.Use(favicon.New(favicon.Config{ File: "./favicon.ico", URL: "/favicon.ico", })) ``` ## Config | Property | Type | Description | Default | |:-------------|:------------------------|:---------------------------------------------------------------------------------|:---------------------------| | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | | Data | `[]byte` | Raw data of the favicon file. This can be used instead of `File`. | `nil` | | File | `string` | File holds the path to an actual favicon that will be cached. | "" | | URL | `string` | URL for favicon handler. | "/favicon.ico" | | FileSystem | `fs.FS` | FileSystem is an optional alternate filesystem from which to load the favicon file (e.g. using `os.DirFS` or an `embed.FS`). | `nil` | | CacheControl | `string` | CacheControl defines how the Cache-Control header in the response should be set. | "public, max-age=31536000" | | MaxBytes | `int64` | MaxBytes limits the maximum size of the cached favicon asset. 
| `1048576` | ## Default Config ```go var ConfigDefault = Config{ Next: nil, File: "", URL: fPath, CacheControl: "public, max-age=31536000", MaxBytes: 1024 * 1024, } ``` ================================================ FILE: docs/middleware/healthcheck.md ================================================ --- id: healthcheck --- # Health Check Middleware that adds liveness, readiness, and startup probes to [Fiber](https://github.com/gofiber/fiber) apps. It provides a generic handler you can mount on any route, with constants for the conventional `/livez`, `/readyz`, and `/startupz` endpoints. ## Overview Register the middleware on any endpoint you want to expose a probe on. The package exports constants for the conventional liveness, readiness, and startup endpoints: ```go app.Get(healthcheck.LivenessEndpoint, healthcheck.New()) app.Get(healthcheck.ReadinessEndpoint, healthcheck.New()) app.Get(healthcheck.StartupEndpoint, healthcheck.New()) ``` By default the probe returns `true`, so each endpoint responds with `200 OK`; returning `false` yields `503 Service Unavailable`. - **Liveness**: Checks if the server is running. - **Readiness**: Checks if the application is ready to handle requests. - **Startup**: Checks if the application has completed its startup sequence. 
## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/healthcheck" ) ``` After your app is initialized, register the middleware on the endpoints you want to expose: ```go // Use the default probe on the conventional endpoints app.Get(healthcheck.LivenessEndpoint, healthcheck.New()) app.Get(healthcheck.ReadinessEndpoint, healthcheck.New(healthcheck.Config{ Probe: func(c fiber.Ctx) bool { return serviceA.Ready() && serviceB.Ready() }, })) app.Get(healthcheck.StartupEndpoint, healthcheck.New()) // Register a custom endpoint app.Get("/healthz", healthcheck.New()) ``` The middleware responds only to GET. Use `app.All` to expose a probe on every method; other methods fall through to the next handler: ```go app.All("/healthz", healthcheck.New()) ``` ## Config ```go type Config struct { // Next defines a function to skip this middleware when it returns true. If this function returns true // and no other handlers are defined for the route, Fiber will return a status 404 Not Found, since // no other handlers were defined to return a different status. // // Optional. Default: nil Next func(fiber.Ctx) bool // Probe is executed to determine the current health state. It can be used for // liveness, readiness or startup checks. Returning true indicates the application // is healthy. // // Optional. Default: func(c fiber.Ctx) bool { return true } Probe func(fiber.Ctx) bool } ``` ## Default Config The default configuration used by this middleware is defined as follows: ```go func defaultProbe(_ fiber.Ctx) bool { return true } var ConfigDefault = Config{ Next: nil, Probe: defaultProbe, } ``` ================================================ FILE: docs/middleware/helmet.md ================================================ --- id: helmet --- # Helmet Helmet secures your app by adding common security headers. 
## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Once your Fiber app is initialized, add the middleware: ```go package main import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/helmet" ) func main() { app := fiber.New() app.Use(helmet.New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Welcome!") }) app.Listen(":3000") } ``` ## Test ```bash curl -I http://localhost:3000 ``` ## Config | Property | Type | Description | Default | |:--------------------------|:------------------------|:--------------------------------------------|:-----------------| | Next | `func(fiber.Ctx) bool` | Skips the middleware when the function returns `true`. | `nil` | | XSSProtection | `string` | Value for the `X-XSS-Protection` header. | "0" | | ContentTypeNosniff | `string` | Value for the `X-Content-Type-Options` header. | "nosniff" | | XFrameOptions | `string` | Value for the `X-Frame-Options` header. | "SAMEORIGIN" | | HSTSMaxAge | `int` | `max-age` value for `Strict-Transport-Security`. | 0 | | HSTSExcludeSubdomains | `bool` | Disables HSTS on subdomains when `true`. | false | | ContentSecurityPolicy | `string` | Value for the `Content-Security-Policy` header. | "" | | CSPReportOnly | `bool` | Enables report-only mode for CSP. | false | | HSTSPreloadEnabled | `bool` | Adds the `preload` directive to HSTS. | false | | ReferrerPolicy | `string` | Value for the `Referrer-Policy` header. | "no-referrer" | | PermissionPolicy | `string` | Value for the `Permissions-Policy` header. | "" | | CrossOriginEmbedderPolicy | `string` | Value for the `Cross-Origin-Embedder-Policy` header. | "require-corp" | | CrossOriginOpenerPolicy | `string` | Value for the `Cross-Origin-Opener-Policy` header. | "same-origin" | | CrossOriginResourcePolicy | `string` | Value for the `Cross-Origin-Resource-Policy` header. | "same-origin" | | OriginAgentCluster | `string` | Value for the `Origin-Agent-Cluster` header. 
| "?1" | | XDNSPrefetchControl | `string` | Value for the `X-DNS-Prefetch-Control` header. | "off" | | XDownloadOptions | `string` | Value for the `X-Download-Options` header. | "noopen" | | XPermittedCrossDomain | `string` | Value for the `X-Permitted-Cross-Domain-Policies` header. | "none" | ## Default Config ```go var ConfigDefault = Config{ XSSProtection: "0", ContentTypeNosniff: "nosniff", XFrameOptions: "SAMEORIGIN", ReferrerPolicy: "no-referrer", CrossOriginEmbedderPolicy: "require-corp", CrossOriginOpenerPolicy: "same-origin", CrossOriginResourcePolicy: "same-origin", OriginAgentCluster: "?1", XDNSPrefetchControl: "off", XDownloadOptions: "noopen", XPermittedCrossDomain: "none", } ``` ================================================ FILE: docs/middleware/idempotency.md ================================================ --- id: idempotency --- # Idempotency The Idempotency middleware helps build fault-tolerant APIs. Duplicate requests—such as retries after network issues—won't trigger the same action twice on the server. Refer to [IETF RFC 7231 §4.2.2](https://tools.ietf.org/html/rfc7231#section-4.2.2) for definitions of safe and idempotent HTTP methods. ## HTTP Method Categories * **Safe Methods** (do not modify server state): `GET`, `HEAD`, `OPTIONS`, `TRACE` * **Idempotent Methods** (identical requests have the same effect as a single one): all safe methods **plus** `PUT` and `DELETE` > According to the RFC, safe methods never change server state, while idempotent methods may change state but remain safe to repeat. 
## Signatures ```go func New(config ...Config) fiber.Handler func IsFromCache(c fiber.Ctx) bool func WasPutToCache(c fiber.Ctx) bool ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/idempotency" ) ``` Once your Fiber app is initialized, configure the middleware: ### Default Config (Skip **Safe** Methods) By default, the `Next` function skips middleware for safe methods only: ```go app.Use(idempotency.New()) ``` ### Skip **Idempotent** Methods Instead Skip all idempotent methods (including `PUT` and `DELETE`) by overriding `Next`: ```go app.Use(idempotency.New(idempotency.Config{ Next: func(c fiber.Ctx) bool { // Skip middleware for idempotent methods (safe + PUT, DELETE) return fiber.IsMethodIdempotent(c.Method()) }, })) ``` ### Custom Config ```go app.Use(idempotency.New(idempotency.Config{ Lifetime: 42 * time.Minute, // ... })) ``` ## Config Idempotency keys are hidden in logs and error messages by default. Set `DisableValueRedaction` to `true` only when you need to expose them for debugging. | Property | Type | Description | Default | |:--------------------|:-----------------------|:----------------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------| | Next | `func(fiber.Ctx) bool` | Function to skip this middleware when it returns `true`; use `IsMethodSafe` or `IsMethodIdempotent`. | `func(c fiber.Ctx) bool { return fiber.IsMethodSafe(c.Method()) }` | | Lifetime | `time.Duration` | Maximum lifetime of an idempotency key. | `30 * time.Minute` | | KeyHeader | `string` | Header name containing the idempotency key. | `"X-Idempotency-Key"` | | KeyHeaderValidate | `func(string) error` | Function to validate idempotency header syntax (e.g., UUID). 
| UUID length check (`36` characters) | | KeepResponseHeaders | `[]string` | List of headers to preserve from original response. | `nil` (keep all headers) | | DisableValueRedaction | `bool` | Disables idempotency key redaction in logs and error messages. | `false` | | Lock | `Locker` | Locks an idempotency key to prevent race conditions. | In-memory locker | | Storage | `fiber.Storage` | Stores response data by idempotency key. | In-memory storage | ## Default Config Values ```go var ConfigDefault = Config{ Next: func(c fiber.Ctx) bool { // Skip middleware for safe methods per RFC 7231 §4.2.2 return fiber.IsMethodSafe(c.Method()) }, Lifetime: 30 * time.Minute, KeyHeader: "X-Idempotency-Key", KeyHeaderValidate: func(k string) error { if l, wl := len(k), 36; l != wl { // UUID length is 36 chars return fmt.Errorf("%w: invalid length: %d != %d", ErrInvalidIdempotencyKey, l, wl) } return nil }, KeepResponseHeaders: nil, Lock: nil, // Set in configDefault so we don't allocate data here. Storage: nil, // Set in configDefault so we don't allocate data here. DisableValueRedaction: false, } ``` ================================================ FILE: docs/middleware/keyauth.md ================================================ --- id: keyauth --- # KeyAuth The KeyAuth middleware implements API key authentication. ## Signatures ```go func New(config ...Config) fiber.Handler func TokenFromContext(ctx any) string ``` `TokenFromContext` accepts a `fiber.CustomCtx`, `fiber.Ctx`, a `*fasthttp.RequestCtx`, or a `context.Context`. ## Examples ### Basic example This example registers KeyAuth with an API key stored in a cookie. 
```go package main import ( "crypto/sha256" "crypto/subtle" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" "github.com/gofiber/fiber/v3/middleware/keyauth" ) var ( apiKey = "correct horse battery staple" ) func validateAPIKey(c fiber.Ctx, key string) (bool, error) { hashedAPIKey := sha256.Sum256([]byte(apiKey)) hashedKey := sha256.Sum256([]byte(key)) if subtle.ConstantTimeCompare(hashedAPIKey[:], hashedKey[:]) == 1 { return true, nil } return false, keyauth.ErrMissingOrMalformedAPIKey } func main() { app := fiber.New() // Register middleware before the routes that need it app.Use(keyauth.New(keyauth.Config{ Extractor: extractors.FromCookie("access_token"), Validator: validateAPIKey, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Successfully authenticated!") }) app.Listen(":3000") } ``` **Test:** ```bash # No API key specified -> 401 Missing or invalid API Key curl http://localhost:3000 #> Missing or invalid API Key # Correct API key -> 200 OK curl --cookie "access_token=correct horse battery staple" http://localhost:3000 #> Successfully authenticated! # Incorrect API key -> 401 Missing or invalid API Key curl --cookie "access_token=Clearly A Wrong Key" http://localhost:3000 #> Missing or invalid API Key ``` For a more detailed example, see the [`fiber-envoy-extauthz`](https://github.com/gofiber/recipes/tree/master/fiber-envoy-extauthz) recipe in the `gofiber/recipes` repository. ### Authenticate only certain endpoints Use the `Next` function to run KeyAuth only on selected routes.
```go package main import ( "crypto/sha256" "crypto/subtle" "regexp" "strings" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" "github.com/gofiber/fiber/v3/middleware/keyauth" ) var ( apiKey = "correct horse battery staple" protectedURLs = []*regexp.Regexp{ regexp.MustCompile("^/authenticated$"), regexp.MustCompile("^/auth2$"), } ) func validateAPIKey(c fiber.Ctx, key string) (bool, error) { hashedAPIKey := sha256.Sum256([]byte(apiKey)) hashedKey := sha256.Sum256([]byte(key)) if subtle.ConstantTimeCompare(hashedAPIKey[:], hashedKey[:]) == 1 { return true, nil } return false, keyauth.ErrMissingOrMalformedAPIKey } func authFilter(c fiber.Ctx) bool { // Match on the path only: OriginalURL includes the query string, so // "/authenticated?x=1" would not match "^/authenticated$" and would skip auth. path := strings.ToLower(c.Path()) for _, pattern := range protectedURLs { if pattern.MatchString(path) { // Run middleware for protected routes return false } } // Skip middleware for non-protected routes return true } func main() { app := fiber.New() app.Use(keyauth.New(keyauth.Config{ Next: authFilter, Extractor: extractors.FromCookie("access_token"), Validator: validateAPIKey, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Welcome") }) app.Get("/authenticated", func(c fiber.Ctx) error { return c.SendString("Successfully authenticated!") }) app.Get("/auth2", func(c fiber.Ctx) error { return c.SendString("Successfully authenticated 2!") }) app.Listen(":3000") } ``` **Test:** ```bash # / doesn't require authentication curl http://localhost:3000 #> Welcome # /authenticated requires authentication curl --cookie "access_token=correct horse battery staple" http://localhost:3000/authenticated #> Successfully authenticated! # /auth2 requires authentication too curl --cookie "access_token=correct horse battery staple" http://localhost:3000/auth2 #> Successfully authenticated 2! ``` ### Apply middleware in the handler You can apply the middleware to specific routes or groups instead of globally. This example uses the default extractor (`FromAuthHeader`).
```go package main import ( "crypto/sha256" "crypto/subtle" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/keyauth" ) const ( apiKey = "my-super-secret-key" ) func main() { app := fiber.New() authMiddleware := keyauth.New(keyauth.Config{ Validator: func(c fiber.Ctx, key string) (bool, error) { hashedAPIKey := sha256.Sum256([]byte(apiKey)) hashedKey := sha256.Sum256([]byte(key)) if subtle.ConstantTimeCompare(hashedAPIKey[:], hashedKey[:]) == 1 { return true, nil } return false, keyauth.ErrMissingOrMalformedAPIKey }, }) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Welcome") }) app.Get("/allowed", authMiddleware, func(c fiber.Ctx) error { return c.SendString("Successfully authenticated!") }) app.Listen(":3000") } ``` **Test:** ```bash # / doesn't require authentication curl http://localhost:3000 #> Welcome # /allowed requires authentication curl --header "Authorization: Bearer my-super-secret-key" http://localhost:3000/allowed #> Successfully authenticated! ``` ## Key Extractors KeyAuth uses an `Extractor` from the shared [extractors](../guide/extractors) package to retrieve the API key from the request. You can specify one or more extractors in the configuration. For a full list of extractors, chaining, and advanced usage, see the [Extractors Guide](../guide/extractors). ### Typical Usage Specify the extractor in the config. 
For example, to extract from a cookie: ```go app.Use(keyauth.New(keyauth.Config{ Extractor: extractors.FromCookie("access_token"), Validator: validateAPIKey, })) ``` To use the default (Authorization header with Bearer scheme): ```go app.Use(keyauth.New(keyauth.Config{ Validator: validateAPIKey, // Extractor defaults to FromAuthHeader("Bearer") })) ``` To try multiple sources (header, then query): ```go app.Use(keyauth.New(keyauth.Config{ Extractor: extractors.Chain( extractors.FromHeader("X-API-Key"), extractors.FromQuery("api_key"), ), Validator: validateAPIKey, })) ``` For custom logic, use `extractors.FromCustom`: ```go app.Use(keyauth.New(keyauth.Config{ Extractor: extractors.FromCustom(func(c fiber.Ctx) (string, error) { return c.Get("X-My-API-Key"), nil }), Validator: validateAPIKey, })) ``` Refer to the [Extractors Guide](../guide/extractors) for details, security notes, and advanced configuration. ## Config | Property | Type | Description | Default | |:----------------|:-----------------------------------------|:-------------------------------------------------------------------------------------------------------|:------------------------------| | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | | SuccessHandler | `fiber.Handler` | SuccessHandler defines a function which is executed for a valid key. | `c.Next()` | | ErrorHandler | `fiber.ErrorHandler` | ErrorHandler defines a function which is executed for an invalid key. By default a 401 response with a `WWW-Authenticate` challenge is sent. | Default error handler | | Validator | `func(fiber.Ctx, string) (bool, error)` | **Required.** Validator is a function to validate the key. | `nil` (panic) | | Extractor | `extractors.Extractor` | Extractor defines how to retrieve the key from the request. Use helper functions from the shared extractors package, e.g. `extractors.FromAuthHeader("Bearer")` or `extractors.FromCookie("access_token")`. 
| `extractors.FromAuthHeader("Bearer")` | | Realm | `string` | Realm specifies the protected area name used in the `WWW-Authenticate` header. | `"Restricted"` | | Challenge | `string` | Value of the `WWW-Authenticate` header when no `Authorization` scheme is present. | `ApiKey realm="Restricted"` | | Error | `string` | Error code appended as the `error` parameter in Bearer challenges. Must be `invalid_request`, `invalid_token`, or `insufficient_scope`. | `""` | | ErrorDescription| `string` | Human-readable text for the `error_description` parameter in Bearer challenges. Requires `Error`. | `""` | | ErrorURI | `string` | URI identifying a human-readable web page with information about the `error` in Bearer challenges. Requires `Error` and must be an absolute URI. | `""` | | Scope | `string` | Space-delimited list of scopes for the `scope` parameter in Bearer challenges. Each token must conform to the RFC 6750 `scope-token` syntax and requires `Error` set to `insufficient_scope`. | `""` | ## Default Config ```go var ConfigDefault = Config{ SuccessHandler: func(c fiber.Ctx) error { return c.Next() }, ErrorHandler: func(c fiber.Ctx, _ error) error { return c.Status(fiber.StatusUnauthorized).SendString(ErrMissingOrMalformedAPIKey.Error()) }, Realm: "Restricted", Extractor: extractors.FromAuthHeader("Bearer"), } ``` ================================================ FILE: docs/middleware/limiter.md ================================================ --- id: limiter --- # Limiter The Limiter middleware for [Fiber](https://github.com/gofiber/fiber) throttles repeated requests to public APIs or endpoints such as password resets. It's also useful for API clients, web crawlers, or other tasks that need rate limiting. Limiter redacts request keys in error paths by default so storage identifiers and rate-limit keys don't leak into logs. Set `DisableValueRedaction` to `true` when you explicitly need the raw key for troubleshooting. 
:::note This middleware uses our [Storage](https://github.com/gofiber/storage) package to support various databases through a single interface. The default configuration for this middleware saves data to memory, see the examples below for other databases. ::: :::note This module does not share state with other processes/servers by default. ::: ## Signatures ```go func New(config ...Config) fiber.Handler type Handler interface { New(config *Config) fiber.Handler } ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/limiter" ) ``` Once your Fiber app is initialized, use the middleware like this: ```go // Initialize default config app.Use(limiter.New()) // Or extend your config for customization app.Use(limiter.New(limiter.Config{ Next: func(c fiber.Ctx) bool { return c.IP() == "127.0.0.1" }, Max: 20, MaxFunc: func(c fiber.Ctx) int { return 20 }, Expiration: 30 * time.Second, ExpirationFunc: func(c fiber.Ctx) time.Duration { // Use longer expiration for sensitive endpoints if c.Path() == "/login" { return 60 * time.Second } return 30 * time.Second }, KeyGenerator: func(c fiber.Ctx) string { return c.Get("x-forwarded-for") }, LimitReached: func(c fiber.Ctx) error { return c.SendFile("./toofast.html") }, Storage: myCustomStorage{}, })) ``` ## Sliding window Instead of using the standard fixed window algorithm, you can enable the [sliding window](https://en.wikipedia.org/wiki/Sliding_window_protocol) algorithm. An example configuration is: ```go app.Use(limiter.New(limiter.Config{ Max: 20, Expiration: 30 * time.Second, LimiterMiddleware: limiter.SlidingWindow{}, })) ``` Each new window also considers the previous one (if any). 
The rate is calculated as: ```text weightOfPreviousWindow = previousWindowRequests * ((Expiration - elapsedInCurrentWindow) / Expiration) rate = weightOfPreviousWindow + currentWindowRequests ``` ## Dynamic limit You can also calculate the limit dynamically using the `MaxFunc` parameter. It receives the request context and allows you to compute a different limit for each request. Example: ```go app.Use(limiter.New(limiter.Config{ MaxFunc: func(c fiber.Ctx) int { return getUserLimit(c.Params("id")) }, Expiration: 30 * time.Second, })) ``` ## Dynamic expiration You can also calculate the expiration dynamically using the `ExpirationFunc` parameter. It receives the request context and allows you to set a different expiration window for each request. Example: ```go app.Use(limiter.New(limiter.Config{ Max: 20, ExpirationFunc: func(c fiber.Ctx) time.Duration { return getExpirationForRoute(c.Path()) }, })) ``` ## Config | Property | Type | Description | Default | |:-----------------------|:--------------------------|:--------------------------------------------------------------------------------------------|:-----------------------------------------| | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | | Max | `int` | Maximum number of recent requests within the `Expiration` window before sending a 429 response. | 5 | | MaxFunc | `func(fiber.Ctx) int` | Function that calculates the maximum number of recent requests within the `Expiration` window before sending a 429 response. | A function that returns `cfg.Max` | | KeyGenerator | `func(fiber.Ctx) string` | Function to generate custom keys; uses `c.IP()` by default. | A function using `c.IP()` as the default | | Expiration | `time.Duration` | Duration to keep request records in memory. | 1 * time.Minute | | ExpirationFunc | `func(fiber.Ctx) time.Duration` | Function that calculates the expiration duration dynamically.
| A function that returns `cfg.Expiration` | | LimitReached | `fiber.Handler` | Called when a request exceeds the limit. | A function sending a 429 response | | SkipFailedRequests | `bool` | When set to `true`, requests with status code ≥ 400 aren't counted. | false | | SkipSuccessfulRequests | `bool` | When set to `true`, requests with status code < 400 aren't counted. | false | | DisableHeaders | `bool` | When set to `true`, the middleware omits rate limit headers (`X-RateLimit-*` and `Retry-After`). | false | | DisableValueRedaction | `bool` | Disables redaction of limiter keys in error messages and logs. | false | | Storage | `fiber.Storage` | Persists middleware state. | An in-memory store for this process only | | LimiterMiddleware | `limiter.Handler` | Selects the algorithm implementation. Implementations now receive a pointer to the active config when their `New` method is invoked. | A new Fixed Window Rate Limiter | :::note A custom store can be used if it implements the `Storage` interface - more details and an example can be found in `store.go`. ::: ## Default Config ```go var ConfigDefault = Config{ Max: 5, MaxFunc: func(c fiber.Ctx) int { return 5 }, Expiration: 1 * time.Minute, // ExpirationFunc defaults to nil and is set dynamically to return cfg.Expiration KeyGenerator: func(c fiber.Ctx) string { return c.IP() }, LimitReached: func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTooManyRequests) }, SkipFailedRequests: false, SkipSuccessfulRequests: false, DisableHeaders: false, DisableValueRedaction: false, LimiterMiddleware: FixedWindow{}, } ``` ### Custom Storage/Database You can use any storage from our [storage](https://github.com/gofiber/storage/) package. 
```go storage := sqlite3.New() // From github.com/gofiber/storage/sqlite3/v2 app.Use(limiter.New(limiter.Config{ Storage: storage, })) ``` ================================================ FILE: docs/middleware/logger.md ================================================ --- id: logger --- # Logger Logger middleware for [Fiber](https://github.com/gofiber/fiber) that logs HTTP requests and responses. ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/logger" ) ``` :::tip Registration order matters: only routes added after the logger are logged, so register it early. ::: Once your Fiber app is initialized, use the middleware like this: ```go // Initialize default config app.Use(logger.New()) // Or extend your config for customization // Log remote IP and port app.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) // Logging Request ID app.Use(requestid.New()) // Ensure requestid middleware is used before the logger app.Use(logger.New(logger.Config{ CustomTags: map[string]logger.LogFunc{ "requestid": func(output logger.Buffer, c fiber.Ctx, data *logger.Data, extraParam string) (int, error) { return output.WriteString(requestid.FromContext(c)) }, }, // For more options, see the Config section // Use the custom tag ${requestid} as defined above. 
Format: "${pid} ${requestid} ${status} - ${method} ${path}\n", })) // Changing TimeZone & TimeFormat app.Use(logger.New(logger.Config{ Format: "${pid} ${status} - ${method} ${path}\n", TimeFormat: "02-Jan-2006", TimeZone: "America/New_York", })) // Custom File Writer accessLog, err := os.OpenFile("./access.log", os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) if err != nil { log.Fatalf("error opening access.log file: %v", err) } defer accessLog.Close() app.Use(logger.New(logger.Config{ Stream: accessLog, })) // Add Custom Tags app.Use(logger.New(logger.Config{ CustomTags: map[string]logger.LogFunc{ "custom_tag": func(output logger.Buffer, c fiber.Ctx, data *logger.Data, extraParam string) (int, error) { return output.WriteString("it is a custom tag") }, }, })) // Callback after log is written app.Use(logger.New(logger.Config{ TimeFormat: time.RFC3339Nano, TimeZone: "Asia/Shanghai", Done: func(c fiber.Ctx, logString []byte) { if c.Response().StatusCode() != fiber.StatusOK { reporter.SendToSlack(logString) } }, })) // Disable colors when outputting to default format app.Use(logger.New(logger.Config{ DisableColors: true, })) // Force the use of colors app.Use(logger.New(logger.Config{ ForceColors: true, })) // Use predefined formats app.Use(logger.New(logger.Config{ Format: logger.CommonFormat, })) app.Use(logger.New(logger.Config{ Format: logger.CombinedFormat, })) app.Use(logger.New(logger.Config{ Format: logger.JSONFormat, })) app.Use(logger.New(logger.Config{ Format: logger.ECSFormat, })) ``` ### Use Logger Middleware with Other Loggers To combine the logger middleware with loggers like Zerolog, Zap, or Logrus, use the `LoggerToWriter` helper to adapt them to an `io.Writer`. 
```go package main import ( "github.com/gofiber/contrib/fiberzap/v2" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/middleware/logger" ) func main() { // Create a new Fiber instance app := fiber.New() // Create a new zap logger which is compatible with Fiber's AllLogger interface zap := fiberzap.NewLogger(fiberzap.LoggerConfig{ ExtraKeys: []string{"request_id"}, }) // Use the logger middleware with the zap logger app.Use(logger.New(logger.Config{ Stream: logger.LoggerToWriter(zap, log.LevelDebug), })) // Define a route app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) // Start server on http://localhost:3000 app.Listen(":3000") } ``` :::tip Writing to `os.File` is goroutine-safe, but custom streams may require locking to serialize writes. ::: ## Config | Property | Type | Description | Default | | :------------ | :------------------------------------------------ | :-------------------------------------------------------------------------------------------------------------------------------------------- | :-------------------------------------------------------------------- | | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | | Skip | `func(fiber.Ctx) bool` | Skip is a function to determine if logging is skipped or written to Stream. | `nil` | | Done | `func(fiber.Ctx, []byte)` | Done is a function that is called after the log string for a request is written to Stream, and receives the log string as a parameter. | `nil` | | CustomTags | `map[string]LogFunc` | CustomTags maps custom tag names to the functions that render them. | `map[string]LogFunc` | | Format | `string` | Defines the logging tags. See more in [Predefined Formats](#predefined-formats), or create your own using [Tags](#constants).
| `[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n` (same as `DefaultFormat`) | | TimeFormat | `string` | TimeFormat defines the time format for log timestamps. | `15:04:05` | | TimeZone | `string` | TimeZone sets the time zone used for log timestamps, such as "UTC", "America/New_York", or "Asia/Chongqing". | `"Local"` | | TimeInterval | `time.Duration` | TimeInterval is the delay before the timestamp is updated. | `500 * time.Millisecond` | | Stream | `io.Writer` | Stream is a writer where logs are written. | `os.Stdout` | | LoggerFunc | `func(c fiber.Ctx, data *Data, cfg *Config) error` | Custom logger function for integration with logging libraries (Zerolog, Zap, Logrus, etc.). Defaults to Fiber's default logger if not defined. | `see default_logger.go defaultLoggerInstance` | | DisableColors | `bool` | DisableColors disables colorized log output when set to `true`. | `false` | | ForceColors | `bool` | ForceColors forces colorized log output even when the output is not a terminal. | `false` | ## Default Config ```go var ConfigDefault = Config{ Next: nil, Skip: nil, Done: nil, Format: DefaultFormat, TimeFormat: "15:04:05", TimeZone: "Local", TimeInterval: 500 * time.Millisecond, Stream: os.Stdout, BeforeHandlerFunc: beforeHandlerFunc, LoggerFunc: defaultLoggerInstance, enableColors: true, } ``` ## Predefined Formats Logger provides predefined formats that you can use by name or directly by specifying the format string. | **Format Constant** | **Format String** | **Description** | |---------------------|--------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------| | `DefaultFormat` | `"[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n"` | Fiber's default logger format.
| | `CommonFormat` | `"${ip} - - [${time}] "${method} ${url} ${protocol}" ${status} ${bytesSent}\n"` | Common Log Format (CLF) used in web server logs. | | `CombinedFormat` | `"${ip} - - [${time}] "${method} ${url} ${protocol}" ${status} ${bytesSent} "${referer}" "${ua}"\n"` | CLF format plus the `referer` and `user agent` fields. | | `JSONFormat` | `"{time: ${time}, ip: ${ip}, method: ${method}, url: ${url}, status: ${status}, bytesSent: ${bytesSent}}\n"` | JSON format for structured logging. | | `ECSFormat` | `"{\"@timestamp\":\"${time}\",\"ecs\":{\"version\":\"1.6.0\"},\"client\":{\"ip\":\"${ip}\"},\"http\":{\"request\":{\"method\":\"${method}\",\"url\":\"${url}\",\"protocol\":\"${protocol}\"},\"response\":{\"status_code\":${status},\"body\":{\"bytes\":${bytesSent}}}},\"log\":{\"level\":\"INFO\",\"logger\":\"fiber\"},\"message\":\"${method} ${url} responded with ${status}\"}\n"` | Elastic Common Schema (ECS) format for structured logging. | ## Constants ```go // Logger variables const ( TagPid = "pid" TagTime = "time" TagReferer = "referer" TagProtocol = "protocol" TagPort = "port" TagIP = "ip" TagIPs = "ips" TagHost = "host" TagMethod = "method" TagPath = "path" TagURL = "url" TagUA = "ua" TagLatency = "latency" TagStatus = "status" // response status TagResBody = "resBody" // response body TagReqHeaders = "reqHeaders" TagQueryStringParams = "queryParams" // request query parameters TagBody = "body" // request body TagBytesSent = "bytesSent" TagBytesReceived = "bytesReceived" TagRoute = "route" TagError = "error" TagReqHeader = "reqHeader:" // request header TagRespHeader = "respHeader:" // response header TagQuery = "query:" // request query TagForm = "form:" // request form TagCookie = "cookie:" // request cookie TagLocals = "locals:" // colors TagBlack = "black" TagRed = "red" TagGreen = "green" TagYellow = "yellow" TagBlue = "blue" TagMagenta = "magenta" TagCyan = "cyan" TagWhite = "white" TagReset = "reset" ) ``` 
================================================ FILE: docs/middleware/paginate.md ================================================ --- id: paginate --- # Paginate Pagination middleware for [Fiber](https://github.com/gofiber/fiber) that extracts pagination parameters from query strings and stores them in the request context. Supports page-based, offset-based, and cursor-based pagination with multi-field sorting. ## Signatures ```go func New(config ...Config) fiber.Handler func FromContext(ctx any) (*PageInfo, bool) ``` `FromContext` accepts `fiber.CustomCtx`, `fiber.Ctx`, `*fasthttp.RequestCtx`, or `context.Context`. ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/paginate" ) ``` Once your Fiber app is initialized, choose one of the following approaches: ### Basic Usage ```go app.Use(paginate.New()) app.Get("/users", func(c fiber.Ctx) error { pageInfo, ok := paginate.FromContext(c) if !ok { return fiber.ErrBadRequest } // Use pageInfo.Page, pageInfo.Limit, pageInfo.Start() // GET /users?page=2&limit=20 → Page: 2, Limit: 20, Start(): 20 return c.JSON(pageInfo) }) ``` ### Sorting ```go app.Use(paginate.New(paginate.Config{ SortKey: "sort", DefaultSort: "id", AllowedSorts: []string{"id", "name", "created_at"}, })) // GET /users?sort=name,-created_at // → Sort: [{Field: "name", Order: "asc"}, {Field: "created_at", Order: "desc"}] ``` ### Cursor Pagination ```go app.Use(paginate.New()) app.Get("/feed", func(c fiber.Ctx) error { pageInfo, ok := paginate.FromContext(c) if !ok { return fiber.ErrBadRequest } if pageInfo.Cursor != "" { // Decode the cursor to get keyset values values := pageInfo.CursorValues() // Use values["id"], values["created_at"], etc. 
for WHERE clause } // results is a slice of items from your database query // After fetching results, set the next cursor for the client if len(results) > 0 { lastItem := results[len(results)-1] if err := pageInfo.SetNextCursor(map[string]any{ "id": lastItem.ID, "created_at": lastItem.CreatedAt, }); err != nil { return err } } return c.JSON(fiber.Map{ "data": results, "has_more": pageInfo.HasMore, "next_cursor": pageInfo.NextCursor, }) }) // First request: GET /feed?limit=20 // Next request: GET /feed?cursor=<next_cursor>&limit=20 ``` ### Custom Configuration ```go app.Use(paginate.New(paginate.Config{ PageKey: "p", LimitKey: "size", DefaultPage: 1, DefaultLimit: 25, SortKey: "order_by", DefaultSort: "created_at", AllowedSorts: []string{"created_at", "name", "email"}, CursorKey: "after", CursorParam: "starting_after", })) ``` ## Config | Property | Type | Description | Default | |:-------------|:-------------------------|:-------------------------------------------------------------------|:-----------| | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | | PageKey | `string` | Query string key for page number. | `"page"` | | DefaultPage | `int` | Default page number. | `1` | | LimitKey | `string` | Query string key for limit. | `"limit"` | | DefaultLimit | `int` | Default items per page. | `10` | | SortKey | `string` | Query string key for sort. | `""` | | DefaultSort | `string` | Default sort field. | `"id"` | | AllowedSorts | `[]string` | Whitelist of allowed sort fields. If nil, all fields are allowed. | `nil` | | OffsetKey | `string` | Query string key for offset. | `"offset"` | | CursorKey | `string` | Query string key for cursor-based pagination. | `"cursor"` | | CursorParam | `string` | Optional alias for the cursor query key. | `""` | | MaxLimit | `int` | Maximum items per page.
| `100` | ## Default Config ```go var ConfigDefault = Config{ Next: nil, PageKey: "page", DefaultPage: 1, LimitKey: "limit", DefaultLimit: 10, MaxLimit: 100, DefaultSort: "id", OffsetKey: "offset", CursorKey: "cursor", } ``` ## PageInfo The `PageInfo` struct is stored in the request context and provides: | Method | Description | |:------------------------------------------------|:---------------------------------------------------------------| | `Start() int` | Returns calculated start index (from page/limit or offset) | | `SortBy(field, order)` | Adds a sort field (chainable) | | `NextPageURL(baseURL)` | Generates next page URL with default keys | | `NextPageURLWithKeys(baseURL, pageKey, limitKey)` | Generates next page URL with custom query keys | | `PreviousPageURL(baseURL)` | Generates previous page URL (empty on page 1) | | `PreviousPageURLWithKeys(baseURL, pageKey, limitKey)` | Generates previous page URL with custom query keys | | `NextCursorURL(baseURL)` | Generates next cursor URL (empty if no more) | | `NextCursorURLWithKeys(baseURL, cursorKey, limitKey)` | Generates next cursor URL with custom query keys | | `CursorValues()` | Decodes cursor token into key-value map | | `SetNextCursor(values)` | Encodes values into cursor token, sets HasMore; returns error | ## Safety - Limit is capped at `MaxLimit` (default: 100, configurable) to prevent excessive memory usage - Page values below 1 reset to 1 - Negative offsets reset to 0 - Sort fields are validated against `AllowedSorts` - Cursor tokens exceeding 2048 characters are rejected with `400 Bad Request` - `SetNextCursor` returns an error if the encoded token would exceed 2048 characters, preventing the server from issuing cursors it would later reject - Invalid cursor tokens return `400 Bad Request` via Fiber's error handler - If `DefaultSort` is not included in `AllowedSorts`, it falls back to the first allowed sort field - URL helpers preserve existing query parameters when building pagination links 
================================================ FILE: docs/middleware/pprof.md ================================================ --- id: pprof --- # Pprof Pprof middleware exposes runtime profiling data for analysis with the Go `pprof` tool. Once added with `app.Use`, it serves its handlers under `/debug/pprof/`. ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/pprof" ) ``` Once your Fiber app is initialized, use the middleware as shown: ```go // Initialize default config app.Use(pprof.New()) // Or customize the config // For multi-ingress systems, add a URL prefix: app.Use(pprof.New(pprof.Config{Prefix: "/endpoint-prefix"})) // The resulting URL is "/endpoint-prefix/debug/pprof/" ``` ## Config | Property | Type | Description | Default | |:---------|:------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-------:| | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | | Prefix | `string` | Prefix adds a segment before `/debug/pprof`; it must start with a slash and omit the trailing slash. Example: `/federated-fiber` | `""` | ## Default Config ```go var ConfigDefault = Config{ Next: nil, } ``` ================================================ FILE: docs/middleware/proxy.md ================================================ --- id: proxy --- # Proxy The Proxy middleware forwards requests to one or more upstream servers. ## Signatures ```go // Balancer creates a load balancer among multiple upstream servers. func Balancer(config ...Config) fiber.Handler // Forward performs the given http request and fills the given http response.
func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler // Do performs the given http request and fills the given http response. func Do(c fiber.Ctx, addr string, clients ...*fasthttp.Client) error // DoRedirects performs the given http request and fills the given http response while following up to maxRedirectsCount redirects. func DoRedirects(c fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error // DoDeadline performs the given request and waits for a response until the given deadline. func DoDeadline(c fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error // DoTimeout performs the given request and waits for a response during the given timeout duration. func DoTimeout(c fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error // DomainForward performs the given http request based on the provided domain and fills the given http response. func DomainForward(hostname string, addr string, clients ...*fasthttp.Client) fiber.Handler // BalancerForward performs the given http request using a round-robin balancer and fills the given http response. func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/proxy" ) ``` Once your Fiber app is initialized, you can use the middleware as shown: ```go // Use proxy.WithClient to set a global custom client. proxy.WithClient(&fasthttp.Client{ NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, // Allow self-signed certificates when proxying to HTTPS targets. // Warning: this disables TLS certificate verification; avoid it in production. TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, }) // Forward requests for a specific domain with proxy.DomainForward.
app.Get("/payments", proxy.DomainForward("docs.gofiber.io", "http://localhost:8000")) // Forward to a URL using a custom client app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif", &fasthttp.Client{ NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, })) // Make a proxied request within a handler app.Get("/:id", func(c fiber.Ctx) error { url := "https://i.imgur.com/" + c.Params("id") + ".gif" if err := proxy.Do(c, url); err != nil { return err } // Remove Server header from response c.Response().Header.Del(fiber.HeaderServer) return nil }) // Proxy requests while following redirects app.Get("/proxy", func(c fiber.Ctx) error { if err := proxy.DoRedirects(c, "http://google.com", 3); err != nil { return err } // Remove Server header from response c.Response().Header.Del(fiber.HeaderServer) return nil }) // Proxy requests and wait up to five seconds before timing out app.Get("/proxy", func(c fiber.Ctx) error { if err := proxy.DoTimeout(c, "http://localhost:3000", time.Second * 5); err != nil { return err } // Remove Server header from response c.Response().Header.Del(fiber.HeaderServer) return nil }) // Proxy requests with a deadline one minute from now app.Get("/proxy", func(c fiber.Ctx) error { if err := proxy.DoDeadline(c, "http://localhost", time.Now().Add(time.Minute)); err != nil { return err } // Remove Server header from response c.Response().Header.Del(fiber.HeaderServer) return nil }) // Minimal round-robin balancer app.Use(proxy.Balancer(proxy.Config{ Servers: []string{ "http://localhost:3001", "http://localhost:3002", "http://localhost:3003", }, })) // Keep the Connection header when proxying app.Use(proxy.Balancer(proxy.Config{ Servers: []string{ "http://localhost:3001", }, KeepConnectionHeader: true, })) // Or extend your balancer for customization app.Use(proxy.Balancer(proxy.Config{ Servers: []string{ "http://localhost:3001", "http://localhost:3002", "http://localhost:3003", }, ModifyRequest: func(c fiber.Ctx) error { 
c.Request().Header.Add("X-Real-IP", c.IP()) return nil }, ModifyResponse: func(c fiber.Ctx) error { c.Response().Header.Del(fiber.HeaderServer) return nil }, })) // Or use BalancerForward when the balancer serves HTTPS but the destination servers only speak HTTP. app.Use(proxy.BalancerForward([]string{ "http://localhost:3001", "http://localhost:3002", "http://localhost:3003", })) // Create a round-robin balancer with IPv6 support. app.Use(proxy.Balancer(proxy.Config{ Servers: []string{ "http://[::1]:3001", "http://127.0.0.1:3002", "http://localhost:3003", }, // Enable TCP4 and TCP6 network stacks. DialDualStack: true, })) ``` ## Config | Property | Type | Description | Default | |:----------------|:-----------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------| | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | | Servers | `[]string` | Servers defines a list of `<scheme>://<host>` HTTP servers, which are used in a round-robin manner. e.g.: "[https://foobar.com](https://foobar.com), [http://www.foobar.com](http://www.foobar.com)" | (Required) | | ModifyRequest | `fiber.Handler` | ModifyRequest allows you to alter the request. | `nil` | | ModifyResponse | `fiber.Handler` | ModifyResponse allows you to alter the response. | `nil` | | Timeout | `time.Duration` | Timeout is the request timeout used when calling the proxy client. | 1 second | | ReadBufferSize | `int` | Per-connection buffer size for requests' reading. This also limits the maximum header size. Increase this buffer if your clients send multi-KB RequestURIs and/or multi-KB headers (for example, BIG cookies). | (Not specified) | | WriteBufferSize | `int` | Per-connection buffer size for responses' writing.
| (Not specified) | | KeepConnectionHeader | `bool` | Keeps the `Connection` header when set to `true`. By default the header is removed to comply with RFC 7230 §6.1 and avoid proxy loops. | `false` | | TLSConfig | `*tls.Config` | TLS config for the HTTP client. | `nil` | | DialDualStack | `bool` | Client will attempt to connect to both IPv4 and IPv6 host addresses if set to true. | `false` | | Client | `*fasthttp.LBClient` | Client is a custom client when client config is complex. | `nil` | ## Default Config ```go var ConfigDefault = Config{ Next: nil, ModifyRequest: nil, ModifyResponse: nil, Timeout: fasthttp.DefaultLBClientTimeout, KeepConnectionHeader: false, } ``` ================================================ FILE: docs/middleware/recover.md ================================================ --- id: recover --- # Recover The Recover middleware for [Fiber](https://github.com/gofiber/fiber) intercepts panics and forwards them to the central [ErrorHandler](../guide/error-handling). ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" recoverer "github.com/gofiber/fiber/v3/middleware/recover" ) ``` Once your Fiber app is initialized, use the middleware like this: ```go // Initialize default config app.Use(recoverer.New()) // Panics in subsequent handlers are caught by the middleware app.Get("/", func(c fiber.Ctx) error { panic("I'm an error") }) ``` ## Config | Property | Type | Description | Default | |:------------------|:-----------------------------|:------------------------------------------------------|:---------------------------| | Next | `func(fiber.Ctx) bool` | Skip when the function returns `true`. | `nil` | | PanicHandler | `func(fiber.Ctx, any) error` | Customize the error returned from a recovered panic. | `DefaultPanicHandler` | | EnableStackTrace | `bool` | Capture and include a stack trace in error responses. 
| `false` | | StackTraceHandler | `func(fiber.Ctx, any)` | Handle the captured stack trace when enabled. | `defaultStackTraceHandler` | ## Default Config ```go var ConfigDefault = recoverer.Config{ Next: nil, PanicHandler: DefaultPanicHandler, StackTraceHandler: defaultStackTraceHandler, EnableStackTrace: false, } // Set up a PanicHandler to hide internals. app.Use(recoverer.New(recoverer.Config{PanicHandler: func(c fiber.Ctx, r any) error { return fiber.ErrInternalServerError }})) // In more elaborate scenarios you can also create a custom error which can be processed differently in the fiber.ErrorHandler. // See the tests for an example of such an ErrorHandler. // You could also just wrap the default handler's error, e.g. fmt.Errorf("[RECOVERED]: %w", recoverer.DefaultPanicHandler(c, r)) app.Use(recoverer.New(recoverer.Config{PanicHandler: func(c fiber.Ctx, r any) error { return &MyCustomRecoveredFromPanicError { Inner: recoverer.DefaultPanicHandler(c, r), } }})) ``` ================================================ FILE: docs/middleware/redirect.md ================================================ --- id: redirect --- # Redirect Redirect middleware maps old URLs to new ones using simple rules. 
## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples ```go package main import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/redirect" ) func main() { app := fiber.New() app.Use(redirect.New(redirect.Config{ Rules: map[string]string{ "/old": "/new", "/old/*": "/new/$1", }, StatusCode: fiber.StatusMovedPermanently, })) app.Get("/new", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) app.Get("/new/*", func(c fiber.Ctx) error { return c.SendString("Wildcard: " + c.Params("*")) }) app.Listen(":3000") } ``` ## Test ```bash curl http://localhost:3000/old curl http://localhost:3000/old/hello ``` ## Config | Property | Type | Description | Default | |:-----------|:--------------------|:------------------------------------------|:-----------------------| | Next | `func(fiber.Ctx) bool` | Skip when the function returns `true`. | nil | | Rules | `map[string]string` | Map paths to new ones; `$1`, `$2` insert params. | Required | | StatusCode | `int` | HTTP code for redirects. | 302 Found | ## Default Config ```go var ConfigDefault = Config{ StatusCode: fiber.StatusFound, } ``` ================================================ FILE: docs/middleware/requestid.md ================================================ --- id: requestid --- # RequestID The RequestID middleware generates or propagates a request identifier, adding it to the response headers and request context. ## Signatures ```go func New(config ...Config) fiber.Handler func FromContext(ctx any) string ``` `FromContext` accepts a `fiber.CustomCtx`, `fiber.Ctx`, a `*fasthttp.RequestCtx`, or a `context.Context`.
## Examples Import the middleware package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/requestid" ) ``` Once your Fiber app is initialized, add the middleware like this: ```go // Initialize default config app.Use(requestid.New()) // Or extend your config for customization app.Use(requestid.New(requestid.Config{ Header: "X-Custom-Header", Generator: func() string { return "static-id" }, })) ``` If the request already includes the configured header, that value is reused instead of generating a new one. The middleware rejects IDs containing characters outside the visible ASCII range (for example, control characters or obs-text bytes) and will regenerate the value using up to three attempts from the configured generator (or SecureToken when no generator is set). When a custom generator fails to produce a valid ID, the middleware falls back to SecureToken to keep headers RFC-compliant across transports. Retrieve the request ID ```go func handler(c fiber.Ctx) error { id := requestid.FromContext(c) log.Printf("Request ID: %s", id) return c.SendString("Hello, World!") } ``` ## Config | Property | Type | Description | Default | |:----------|:---------------------|:-----------------------------------------|:---------------| | Next | `func(fiber.Ctx) bool` | Skip when the function returns `true`. | `nil` | | Header | `string` | Header key used to store the request ID. | "X-Request-ID" | | Generator | `func() string` | Function that generates the identifier. | utils.SecureToken | ## Default Config The default config uses a cryptographically secure token generator for better security and privacy. 
```go var ConfigDefault = Config{ Next: nil, Header: fiber.HeaderXRequestID, Generator: utils.SecureToken, } ``` ================================================ FILE: docs/middleware/responsetime.md ================================================ --- id: responsetime --- # ResponseTime Response time middleware for [Fiber](https://github.com/gofiber/fiber) that measures the time spent handling a request and exposes it via a response header. ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Examples Import the package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/responsetime" ) ``` ### Default config ```go app.Use(responsetime.New()) ``` ### Custom header ```go app.Use(responsetime.New(responsetime.Config{ Header: "X-Elapsed", })) ``` ### Skip logic ```go app.Use(responsetime.New(responsetime.Config{ Next: func(c fiber.Ctx) bool { return c.Path() == "/healthz" }, })) ``` ## Config | Property | Type | Description | Default | | :------- | :--- | :---------- | :------ | | Next | `func(c fiber.Ctx) bool` | Defines a function to skip this middleware when it returns `true`. | `nil` | | Header | `string` | Header key used to store the measured response time. If left empty, the default header is used. | `"X-Response-Time"` | ================================================ FILE: docs/middleware/rewrite.md ================================================ --- id: rewrite --- # Rewrite The Rewrite middleware remaps the request path using custom rules, helping with backward compatibility and cleaner URLs. ## Signatures ```go func New(config ...Config) fiber.Handler ``` ## Config | Property | Type | Description | Default | |:---------|:----------------------|:------------------------------------------------------|:-----------| | Next | `func(fiber.Ctx) bool` | Skip when function returns `true`. 
| `nil` | | Rules | `map[string]string` | Map paths to new values; use `$1`, `$2` for wildcard captures.| (Required) | :::note Rules are stored in a map, so iteration order is undefined. Avoid overlapping patterns if precedence matters. ::: ### Examples ```go package main import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/rewrite" ) func main() { app := fiber.New() app.Use(rewrite.New(rewrite.Config{ Rules: map[string]string{ "/old": "/new", "/old/*": "/new/$1", }, })) app.Get("/new", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) app.Get("/new/*", func(c fiber.Ctx) error { return c.SendString("Wildcard: " + c.Params("*")) }) app.Listen(":3000") } ``` ## Test ```bash curl http://localhost:3000/old curl http://localhost:3000/old/hello ``` ================================================ FILE: docs/middleware/session.md ================================================ --- id: session --- # Session The Session middleware adds session management to Fiber apps through the [Storage](https://github.com/gofiber/storage) package, which offers a unified interface for multiple databases. By default, sessions live in memory, but you can plug in any storage backend. 
## Table of Contents - [Quick Start](#quick-start) - [Usage Patterns](#usage-patterns) - [Session Security](#session-security) - [Session ID Extractors](#session-id-extractors) - [Configuration](#configuration) - [Migration Guide](#migration-guide) - [API Reference](#api-reference) - [Examples](#examples) ## Quick Start ```go import ( "fmt" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/session" ) // Basic usage app.Use(session.New()) app.Get("/", func(c fiber.Ctx) error { sess := session.FromContext(c) // Get and update visits count var visits int if v := sess.Get("visits"); v != nil { // Use type assertion with an ok check to prevent a panic if vInt, ok := v.(int); ok { visits = vInt } } visits++ sess.Set("visits", visits) return c.SendString(fmt.Sprintf("Visits: %d", visits)) }) ``` ### Production Configuration ```go import ( "time" "github.com/gofiber/fiber/v3/extractors" "github.com/gofiber/storage/redis/v3" ) storage := redis.New(redis.Config{ Host: "localhost", Port: 6379, }) app.Use(session.New(session.Config{ Storage: storage, CookieSecure: true, // HTTPS only CookieHTTPOnly: true, // Block JavaScript access (mitigates XSS cookie theft) CookieSameSite: "Lax", // CSRF protection IdleTimeout: 30 * time.Minute, // Session timeout AbsoluteTimeout: 24 * time.Hour, // Maximum session life Extractor: extractors.FromCookie("__Host-session_id"), })) ``` Notes: - `AbsoluteTimeout` must be greater than or equal to `IdleTimeout`; otherwise, the middleware panics during configuration. - If `CookieSameSite` is set to `"None"`, the middleware automatically forces `CookieSecure=true` when setting the cookie. ## Usage Patterns ### Middleware Pattern (Recommended) This pattern automatically manages the session lifecycle and is recommended for most applications.
```go // Setup middleware app.Use(session.New()) // Use in handlers app.Post("/login", func(c fiber.Ctx) error { sess := session.FromContext(c) // Session is automatically saved when handler returns sess.Set("user_id", 123) sess.Set("authenticated", true) return c.Redirect("/dashboard") }) ``` **Benefits:** - Automatic session saving - Automatic resource cleanup - No manual lifecycle management - Thread-safe operations ### Store Pattern (Advanced) Use the store pattern for background tasks or when you need direct access to sessions. ```go import ( "context" "log" "time" ) store := session.NewStore() // In background tasks func backgroundTask(sessionID string) { sess, err := store.GetByID(context.Background(), sessionID) if err != nil { return } defer sess.Release() // Important: Manual cleanup required // Modify session sess.Set("last_task", time.Now()) // Manual save required if err := sess.Save(); err != nil { log.Printf("Failed to save session: %v", err) } } ``` **Requirements:** - Must call `sess.Release()` when done - Must call `sess.Save()` to persist changes - Handle errors manually ## Session Security ### Authentication Flow Understanding session lifecycle during authentication is crucial for security. 
#### Basic Login/Logout ```go app.Post("/login", func(c fiber.Ctx) error { sess := session.FromContext(c) email := c.FormValue("email") password := c.FormValue("password") // Simple credential validation (use proper authentication in production) if email == "admin@example.com" && password == "secret" { // Important: Regenerate the session ID to prevent fixation // This changes the session ID while preserving existing data if err := sess.Regenerate(); err != nil { return c.Status(500).SendString("Session error") } // Add authentication data to existing session sess.Set("user_id", 1) sess.Set("authenticated", true) return c.Redirect("/dashboard") } return c.Status(401).SendString("Invalid credentials") }) app.Post("/logout", func(c fiber.Ctx) error { sess := session.FromContext(c) // Complete session reset (clears all data + new session ID) if err := sess.Reset(); err != nil { return c.Status(500).SendString("Session error") } return c.Redirect("/") }) ``` #### Cart Preservation During Login ```go app.Post("/login", func(c fiber.Ctx) error { sess := session.FromContext(c) // Validate credentials (implement your own validation) email := c.FormValue("email") password := c.FormValue("password") if !isValidUser(email, password) { return c.Status(401).JSON(fiber.Map{"error": "Invalid credentials"}) } // Important: Regenerate the session ID to prevent fixation // This changes the session ID while preserving existing data if err := sess.Regenerate(); err != nil { return c.Status(500).JSON(fiber.Map{"error": "Session error"}) } // Add authentication data to existing session sess.Set("user_id", getUserID(email)) sess.Set("authenticated", true) sess.Set("login_time", time.Now()) return c.JSON(fiber.Map{"status": "logged in"}) }) ``` ### Security Methods Comparison | Method | Session ID | Session Data | Use Case | |--------|------------|--------------|----------| | `Regenerate()` | ✅ Changes | ✅ Preserved | Login, privilege escalation | | `Reset()` | ✅ Changes | ❌ Cleared | 
Logout, security breach | | `Destroy()` | ⚪ Unchanged | ❌ Cleared | Clear data only | ### Common Security Mistakes ❌ **Session Fixation Vulnerability:** ```go // DANGEROUS: Keeping same session ID after login app.Post("/login", func(c fiber.Ctx) error { sess := session.FromContext(c) // Validate user... sess.Set("user_id", userID) // Attacker can hijack this session! return c.Redirect("/dashboard") }) ``` ✅ **Secure Implementation:** ```go // SECURE: Always regenerate session ID after authentication app.Post("/login", func(c fiber.Ctx) error { sess := session.FromContext(c) // Validate user... if err := sess.Regenerate(); err != nil { // Prevents session fixation return err } sess.Set("user_id", userID) return c.Redirect("/dashboard") }) ``` ### Authentication Middleware This is a basic example of an authentication middleware that checks if a user is logged in before accessing protected routes. ```go // Authentication check middleware func RequireAuth(c fiber.Ctx) error { sess := session.FromContext(c) if sess == nil { return c.Redirect("/login") } // Check if user is authenticated if sess.Get("authenticated") != true { return c.Redirect("/login") } return c.Next() } // Usage app.Use("/dashboard", RequireAuth) app.Use("/admin", RequireAuth) ``` ### Automatic Session Expiration Sessions automatically expire based on your configuration: ```go app.Use(session.New(session.Config{ IdleTimeout: 30 * time.Minute, // Auto-expire after 30 min of inactivity AbsoluteTimeout: 24 * time.Hour, // Force expire after 24 hours regardless of activity })) ``` **How it works:** - `IdleTimeout`: Storage automatically removes sessions after inactivity period - Any route that uses the middleware will reset the idle timer - Calling `sess.Save()` will also reset the idle timer - `AbsoluteTimeout`: Sessions are forcibly expired after maximum duration - No manual cleanup required - the storage layer handles this ## Session ID Extractors This middleware uses the shared extractors module for 
session ID extraction. See the [Extractors Guide](../guide/extractors) for more details. ### Built-in Extractors ```go // Cookie-based (recommended for web apps) extractors.FromCookie("session_id") // Header-based (recommended for APIs) extractors.FromHeader("X-Session-ID") // Authorization header (read-only) extractors.FromAuthHeader("Bearer") // Form data extractors.FromForm("session_id") // URL query parameter extractors.FromQuery("session_id") // URL path parameter extractors.FromParam("id") ``` **Session Response Behavior:** - Cookie extractors: set cookie in the response - Header extractors (non-Authorization): set header in the response - Authorization header, Query, Form, Param, Custom: read-only (no response values are set) ### Multiple Sources with Fallback ```go app.Use(session.New(session.Config{ Extractor: extractors.Chain( extractors.FromCookie("session_id"), // Try cookie first extractors.FromHeader("X-Session-ID"), // Then header extractors.FromQuery("session_id"), // Finally query ), })) ``` **Response Behavior with Chained Extractors:** Only cookie and non-Authorization header extractors contribute to response setting. Others are read-only. - Cookie + Header (non-Auth) extractors: both cookie and header are set - Only Cookie extractors: only cookie is set - Only Header (non-Auth) extractors: only header is set - Any mix that includes Authorization/Query/Form/Param/Custom: those sources are read-only ```go // This will set both cookie and header in response extractors.Chain( extractors.FromCookie("session_id"), extractors.FromHeader("X-Session-ID"), ) // This will set only cookie in response extractors.Chain( extractors.FromCookie("session_id"), extractors.FromQuery("session_id"), // Ignored for response ) // This will set nothing in response (read-only mode) extractors.Chain( extractors.FromQuery("session_id"), extractors.FromForm("session_id"), ) ``` ### Custom Extractors (Session-specific) Prefer the helper constructors from the extractors module.
See the Extractors Guide for the full API; below are session-specific examples and notes. ```go // Authorization Bearer tokens (read-only for sessions) // The session middleware will NOT set Authorization back in the response. app.Use(session.New(session.Config{ Extractor: extractors.FromAuthHeader("Bearer"), })) ``` ```go // Custom read-only header via FromCustom (read-only for sessions) app.Use(session.New(session.Config{ Extractor: extractors.FromCustom("X-Custom-Session", func(c fiber.Ctx) (string, error) { v := c.Get("X-Custom-Session") if v == "" { return "", extractors.ErrNotFound } return v, nil }), })) ``` ## Configuration ### Storage Options ```go import ( "github.com/gofiber/storage/redis/v3" "github.com/gofiber/storage/postgres/v3" ) // Redis (recommended for production) redisStorage := redis.New(redis.Config{ Host: "localhost", Port: 6379, Password: "", Database: 0, }) // PostgreSQL pgStorage := postgres.New(postgres.Config{ Host: "localhost", Port: 5432, Database: "sessions", Username: "user", Password: "pass", }) app.Use(session.New(session.Config{ Storage: redisStorage, })) ``` ### Production Security Settings ```go import ( "log" "time" "github.com/gofiber/utils/v2" "github.com/gofiber/fiber/v3/extractors" ) app.Use(session.New(session.Config{ // Storage Storage: redisStorage, // Security CookieSecure: true, // HTTPS only (required in production) CookieHTTPOnly: true, // No JavaScript access (mitigates XSS cookie theft) CookieSameSite: "Lax", // CSRF protection // Session Management IdleTimeout: 30 * time.Minute, // Inactivity timeout AbsoluteTimeout: 24 * time.Hour, // Maximum session duration // Cookie Settings CookiePath: "/", CookieDomain: "example.com", CookieSessionOnly: false, // Persist across browser restarts // Session ID // Use the __Secure- prefix here: __Host- cookies must not carry a Domain attribute, and CookieDomain is set above Extractor: extractors.FromCookie("__Secure-session_id"), KeyGenerator: utils.SecureToken, // Error Handling ErrorHandler: func(c fiber.Ctx, err error) { log.Printf("Session error: %v", err) }, })) ``` ### Custom Types Session data
supports basic Go types by default: - `string`, `int`, `int8`, `int16`, `int32`, `int64` - `uint`, `uint8`, `uint16`, `uint32`, `uint64` - `bool`, `float32`, `float64` - `[]byte`, `complex64`, `complex128` - `interface{}` For custom types (structs, maps, slices), you must register them for encoding/decoding: ```go import "fmt" type User struct { ID int `json:"id"` Name string `json:"name"` Role string `json:"role"` } // Method 1: Using NewWithStore func main() { app := fiber.New() sessionMiddleware, store := session.NewWithStore() store.RegisterType(User{}) // Register custom type app.Use(sessionMiddleware) app.Get("/", func(c fiber.Ctx) error { sess := session.FromContext(c) // Use custom type sess.Set("user", User{ID: 123, Name: "John", Role: "admin"}) user, ok := sess.Get("user").(User) if ok { return c.JSON(fiber.Map{"user": user.Name, "role": user.Role}) } return c.SendString("No user found") }) app.Listen(":3000") } ``` ```go // Method 2: Using separate store store := session.NewStore() store.RegisterType(User{}) app.Use(session.New(session.Config{ Store: store, })) // Usage in handlers sess.Set("user", User{ID: 123, Name: "John", Role: "admin"}) user, ok := sess.Get("user").(User) if ok { fmt.Printf("User: %s (Role: %s)", user.Name, user.Role) } ``` **Important Notes:** - Custom types must be registered before using them in sessions - Registration must happen during application startup - All instances of the application must register the same types - Types are encoded using Go's `gob` package ## Migration Guide ### v2 to v3 Breaking Changes 1. **Function Signature**: `session.New()` now returns middleware handler, not store 2. **Session ID Extraction**: `KeyLookup` replaced with `Extractor` functions 3. **Lifecycle Management**: Manual `Release()` required for store pattern 4. 
**Timeout Handling**: `Expiration` split into `IdleTimeout` and `AbsoluteTimeout` ### Migration Examples **v2 Code:** ```go store := session.New(session.Config{ KeyLookup: "cookie:session_id", }) app.Get("/", func(c *fiber.Ctx) error { sess, err := store.Get(c) if err != nil { return err } sess.Set("key", "value") return sess.Save() // Manual save required in v2 }) ``` **v3 Middleware Pattern (Recommended):** ```go app.Use(session.New(session.Config{ Extractor: extractors.FromCookie("session_id"), })) app.Get("/", func(c fiber.Ctx) error { sess := session.FromContext(c) // Session automatically saved and released sess.Set("key", "value") return nil }) ``` **v3 Store Pattern (Advanced):** ```go store := session.NewStore(session.Config{ Extractor: extractors.FromCookie("session_id"), }) app.Get("/", func(c fiber.Ctx) error { sess, err := store.Get(c) if err != nil { return err } defer sess.Release() // Manual cleanup required sess.Set("key", "value") return sess.Save() // Manual save required }) ``` ### KeyLookup to Extractor Migration | v2 KeyLookup | v3 Extractor | |---------------------------------|------------------------------------------------------------------------------------| | `"cookie:session_id"` | `extractors.FromCookie("session_id")` | | `"header:X-Session-ID"` | `extractors.FromHeader("X-Session-ID")` | | `"query:session_id"` | `extractors.FromQuery("session_id")` | | `"form:session_id"` | `extractors.FromForm("session_id")` | | `"cookie:sid,header:X-Sid"` | `extractors.Chain(extractors.FromCookie("sid"), extractors.FromHeader("X-Sid"))` | ## API Reference ### Middleware Methods (Recommended) ```go sess := session.FromContext(c) // Data operations sess.Get(key any) any sess.Set(key, value any) sess.Delete(key any) sess.Keys() []any // Session management sess.ID() string sess.Fresh() bool sess.Regenerate() error // Change ID, keep data sess.Reset() error // Change ID, clear data sess.Destroy() error // Keep ID, clear data // Store access
sess.Store() *session.Store ``` `FromContext` accepts a `fiber.CustomCtx`, `fiber.Ctx`, a `*fasthttp.RequestCtx`, or a `context.Context`. ### Store Methods ```go store := session.NewStore() // Store operations store.Get(c fiber.Ctx) (*session.Session, error) store.GetByID(ctx context.Context, sessionID string) (*session.Session, error) store.Reset(ctx context.Context) error store.Delete(ctx context.Context, sessionID string) error // Type registration store.RegisterType(interface{}) ``` ### Session Methods (Store Pattern) ```go sess, err := store.Get(c) defer sess.Release() // Required! // Same methods as middleware, plus: sess.Save() error // Manual save required sess.SetIdleTimeout(duration) // Per-session timeout sess.Release() // Manual cleanup required ``` ### Extractor Functions ```go // Built-in extractors (import "github.com/gofiber/fiber/v3/extractors") extractors.FromCookie(key string) extractors.Extractor extractors.FromHeader(key string) extractors.Extractor extractors.FromQuery(key string) extractors.Extractor extractors.FromForm(key string) extractors.Extractor extractors.FromParam(key string) extractors.Extractor // Chaining extractors.Chain(extractors ...extractors.Extractor) extractors.Extractor ``` ### Config Properties | Property | Type | Description | Default | |---------------------|-----------------------------|-----------------------------|--------------------------------------------| | `Store` | `*session.Store` | Pre-built session store (use when you need to share/register types) | `nil` (auto-created) | | `Storage` | `fiber.Storage` | Session storage backend (used when creating a store if `Store` is nil) | `memory.New()` | | `Extractor` | `extractors.Extractor` | Session ID extraction | `extractors.FromCookie("session_id")` | | `KeyGenerator` | `func() string` | Session ID generator | `utils.SecureToken` | | `IdleTimeout` | `time.Duration` | Inactivity timeout | `30 * time.Minute` | | `AbsoluteTimeout` | `time.Duration` | Maximum session 
duration | `0` (unlimited) | | `CookieSecure` | `bool` | HTTPS only | `false` | | `CookieHTTPOnly` | `bool` | No JavaScript access | `false` | | `CookieSameSite` | `string` | SameSite attribute | `"Lax"` | | `CookiePath` | `string` | Cookie path | `""` | | `CookieDomain` | `string` | Cookie domain | `""` | | `CookieSessionOnly` | `bool` | Session cookie | `false` | | `Next` | `func(fiber.Ctx) bool` | Skip middleware when returns true | `nil` | | `ErrorHandler` | `func(fiber.Ctx, error)` | Error callback | `DefaultErrorHandler` | ## Examples ### E-commerce with Cart Persistence ```go import ( "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/session" "github.com/gofiber/fiber/v3/extractors" "github.com/gofiber/storage/redis/v3" ) func main() { app := fiber.New() // Session middleware app.Use(session.New(session.Config{ Storage: redis.New(), CookieSecure: true, CookieHTTPOnly: true, CookieSameSite: "Lax", IdleTimeout: 30 * time.Minute, AbsoluteTimeout: 24 * time.Hour, Extractor: extractors.FromCookie("__Host-cart_session"), })) // Add to cart (anonymous user) app.Post("/cart/add", func(c fiber.Ctx) error { sess := session.FromContext(c) cart, _ := sess.Get("cart").([]string) cart = append(cart, c.FormValue("item_id")) sess.Set("cart", cart) return c.JSON(fiber.Map{"items": len(cart)}) }) // Login (preserve session data) app.Post("/login", func(c fiber.Ctx) error { sess := session.FromContext(c) // Simple validation (implement proper authentication) email := c.FormValue("email") password := c.FormValue("password") if email != "user@example.com" || password != "password" { return c.Status(401).JSON(fiber.Map{"error": "Invalid credentials"}) } // Regenerate session ID for security // This changes the session ID while preserving existing data if err := sess.Regenerate(); err != nil { return c.Status(500).JSON(fiber.Map{"error": "Session error"}) } sess.Set("user_id", 1) sess.Set("authenticated", true) return c.JSON(fiber.Map{"status": "logged 
in"}) }) // Logout (clear everything) app.Post("/logout", func(c fiber.Ctx) error { sess := session.FromContext(c) // Reset clears all data and generates new session ID if err := sess.Reset(); err != nil { return c.Status(500).JSON(fiber.Map{"error": "Session error"}) } return c.JSON(fiber.Map{"status": "logged out"}) }) app.Listen(":3000") } // Helper functions (implement these properly in production) func isValidUser(email, password string) bool { return email == "user@example.com" && password == "password" } func getUserID(email string) int { return 1 // Return actual user ID from database } ``` ### API with Header-based Sessions ```go import ( "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/session" "github.com/gofiber/fiber/v3/extractors" "github.com/gofiber/storage/redis/v3" ) func main() { app := fiber.New() // API session middleware with header extraction app.Use(session.New(session.Config{ Storage: redis.New(), Extractor: extractors.FromHeader("X-Session-Token"), IdleTimeout: time.Hour, })) // API endpoint app.Post("/api/data", func(c fiber.Ctx) error { sess := session.FromContext(c) // Track API usage count, _ := sess.Get("api_calls").(int) count++ sess.Set("api_calls", count) sess.Set("last_call", time.Now()) return c.JSON(fiber.Map{ "data": "some data", "calls": count, }) }) app.Listen(":3000") } ``` ### Multi-source Session ID Support ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/session" "github.com/gofiber/fiber/v3/extractors" ) func main() { app := fiber.New() // Support multiple sources with priority app.Use(session.New(session.Config{ Extractor: extractors.Chain( extractors.FromCookie("session_id"), // 1st: Cookie (web) extractors.FromHeader("X-Session-ID"), // 2nd: Header (API) extractors.FromQuery("session_id"), // 3rd: Query (fallback) ), })) app.Get("/", func(c fiber.Ctx) error { sess := session.FromContext(c) // Works with any of the above methods return c.JSON(fiber.Map{ 
"session_id": sess.ID(), "source": "multi-source", }) }) app.Listen(":3000") } ``` ================================================ FILE: docs/middleware/skip.md ================================================ --- id: skip --- # Skip The Skip middleware wraps a handler and bypasses it when the predicate returns `true` for the current request. ## Signatures ```go func New(handler fiber.Handler, exclude func(c fiber.Ctx) bool) fiber.Handler ``` ## Examples Import the package: ```go import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/skip" ) ``` `skip.New` accepts the handler to wrap and a predicate function. The predicate runs for every request, and returning `true` skips the wrapped handler and executes the next middleware in the chain. After you initialize your Fiber app, use `skip.New` like this: ```go func main() { app := fiber.New() app.Use(skip.New(BasicHandler, func(ctx fiber.Ctx) bool { return ctx.Method() == fiber.MethodGet })) app.Get("/", func(ctx fiber.Ctx) error { return ctx.SendString("It was a GET request!") }) log.Fatal(app.Listen(":3000")) } func BasicHandler(ctx fiber.Ctx) error { return ctx.SendString("It was not a GET request!") } ``` :::tip `app.Use` processes requests on any route and method. In the example above, the handler is skipped only for `GET`. ::: ================================================ FILE: docs/middleware/static.md ================================================ --- id: static --- # Static The Static middleware serves assets such as **images**, **CSS**, and **JavaScript**. :::info By default, it serves `index.html` when a directory is requested. Customize this behavior in the [Config](#config) options. ::: ## Signatures ```go func New(root string, cfg ...Config) fiber.Handler ``` ## Examples Import the package: ```go import( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/static" ) ``` ### Serving files from a directory ```go app.Get("/*", static.New("./public")) ```
Test ```sh curl http://localhost:3000/hello.html curl http://localhost:3000/css/style.css ```
### Serving files from a directory with `Use` ```go app.Use("/", static.New("./public")) ```
Test ```sh curl http://localhost:3000/hello.html curl http://localhost:3000/css/style.css ```
### Serving a single file ```go app.Use("/static", static.New("./public/hello.html")) ```
Test ```sh curl http://localhost:3000/static # will show hello.html curl http://localhost:3000/static/john/doe # will show hello.html ```
### Serving files using os.DirFS ```go app.Get("/files*", static.New("", static.Config{ FS: os.DirFS("files"), Browse: true, })) ```
Test ```sh curl http://localhost:3000/files/css/style.css curl http://localhost:3000/files/index.html ```
### Serving files using embed.FS ```go //go:embed path/to/files var myfiles embed.FS app.Get("/files*", static.New("", static.Config{ FS: myfiles, Browse: true, })) ```
Test ```sh curl http://localhost:3000/files/css/style.css curl http://localhost:3000/files/index.html ```
### SPA (Single Page Application) ```go app.Use("/web", static.New("", static.Config{ FS: os.DirFS("dist"), })) app.Get("/web*", func(c fiber.Ctx) error { return c.SendFile("dist/index.html") }) ```
Test ```sh curl http://localhost:3000/web/css/style.css curl http://localhost:3000/web/index.html curl http://localhost:3000/web ```
:::caution To define static routes using `Get`, append the wildcard (`*`) operator at the end of the route. ::: ## Config | Property | Type | Description | Default | |:-----------|:------------------------|:---------------------------------------------------------------------------------------------------------------------------|:-----------------------| | Next | `func(fiber.Ctx) bool` | Next defines a function to skip this middleware when it returns true. | `nil` | | FS | `fs.FS` | FS is the file system to serve the static files from.

You can use interfaces compatible with fs.FS like embed.FS, os.DirFS etc. | `nil` | | Compress | `bool` | When set to true, the server tries minimizing CPU usage by caching compressed files. The middleware will compress the response using `gzip`, `brotli`, or `zstd` compression depending on the [Accept-Encoding](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding) header.

This works differently from the `compress` middleware (`github.com/gofiber/fiber/v3/middleware/compress`). | `false` | | ByteRange | `bool` | When set to true, enables byte range requests. | `false` | | Browse | `bool` | When set to true, enables directory browsing. | `false` | | Download | `bool` | When set to true, enables direct download. | `false` | | IndexNames | `[]string` | The names of the index files for serving a directory. | `[]string{"index.html"}` | | CacheDuration | `time.Duration` | Expiration duration for inactive file handlers.

Use a negative time.Duration to disable it. | `10 * time.Second` | | MaxAge | `int` | The value for the Cache-Control HTTP-header that is set on the file response. MaxAge is defined in seconds. | `0` | | ModifyResponse | `fiber.Handler` | ModifyResponse defines a function that allows you to alter the response. | `nil` | | NotFoundHandler | `fiber.Handler` | NotFoundHandler defines a function to handle when the path is not found. | `nil` | When **Download** is enabled, the response includes a `Content-Disposition` header with the requested filename. Non-ASCII names use the `filename*` parameter as defined by [RFC 6266](https://www.rfc-editor.org/rfc/rfc6266) and [RFC 8187](https://www.rfc-editor.org/rfc/rfc8187). :::info You can set `CacheDuration` config property to `-1` to disable caching. ::: ## Default Config ```go var ConfigDefault = Config{ IndexNames: []string{"index.html"}, CacheDuration: 10 * time.Second, } ``` ================================================ FILE: docs/middleware/timeout.md ================================================ --- id: timeout --- # Timeout The timeout middleware enforces a deadline on handler execution. It wraps handlers with `context.WithTimeout`, exposes the derived context through `c.Context()`, and returns `408 Request Timeout` when the deadline is exceeded. ## How It Works When a timeout occurs, the middleware **returns immediately** without waiting for the handler to finish. This is achieved through Fiber's **Abandon mechanism**: 1. The handler runs in a goroutine with a timeout context 2. On timeout, the middleware marks the context as "abandoned" and returns `408` immediately 3. The handler goroutine can continue safely (e.g., for cleanup) without blocking the response 4. A background cleanup goroutine waits for the handler to finish and performs context cleanup Handlers can detect the timeout by listening on `c.Context().Done()` and return early. This is the recommended pattern for cooperative cancellation. 
If a handler panics, the middleware catches it and returns `500 Internal Server Error`. ## Known limitations - Timed-out requests abandon their `fiber.Ctx` to avoid data races with the core request handler (including the `ErrorHandler`). These contexts are **not** returned to the pool, so each timed-out request leaks a context. Calling `ForceRelease` is only safe if you can guarantee that no goroutine (including Fiber internals) will touch the context anymore; the timeout middleware intentionally does not call it. :::caution `timeout.New` wraps your final handler and can't be added with `app.Use` or used in a middleware chain. Register it per route and avoid calling `c.Next()` inside the wrapped handler—doing so will panic. ::: ## Signatures ```go func New(handler fiber.Handler, config ...timeout.Config) fiber.Handler ``` ## Examples ### Basic example The following program times out any request that takes longer than two seconds. The handler simulates work with `sleepWithContext`, which stops when the context is canceled: ```go package main import ( "context" "fmt" "log" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/timeout" ) func sleepWithContext(ctx context.Context, d time.Duration) error { select { case <-time.After(d): return nil case <-ctx.Done(): return ctx.Err() } } func main() { app := fiber.New() handler := func(c fiber.Ctx) error { delay, _ := time.ParseDuration(c.Params("delay") + "ms") if err := sleepWithContext(c.Context(), delay); err != nil { return fmt.Errorf("%w: execution error", err) } return c.SendString("finished") } app.Get("/sleep/:delay", timeout.New(handler, timeout.Config{ Timeout: 2 * time.Second, })) log.Fatal(app.Listen(":3000")) } ``` Use these requests to see the middleware in action: ```bash curl -i http://localhost:3000/sleep/1000 # finishes within the timeout curl -i http://localhost:3000/sleep/3000 # returns 408 Request Timeout ``` ## Config | Property | Type | Description | Default | 
|:------------|:-------------------|:---------------------------------------------------------------------|:-------| | Next | `func(fiber.Ctx) bool` | Function to skip this middleware when it returns `true`. | `nil` | | Timeout | `time.Duration` | Timeout duration for requests. `0` or a negative value disables the timeout. | `0` | | OnTimeout | `fiber.Handler` | Handler executed when a timeout occurs. Defaults to returning `fiber.ErrRequestTimeout`. | `nil` | | Errors | `[]error` | Custom errors treated as timeout errors. | `nil` | ### Use with a custom error ```go var ErrFooTimeOut = errors.New("foo context canceled") func main() { app := fiber.New() h := func(c fiber.Ctx) error { sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms") if err := sleepWithContextWithCustomError(c.Context(), sleepTime); err != nil { return fmt.Errorf("%w: execution error", err) } return nil } app.Get("/foo/:sleepTime", timeout.New(h, timeout.Config{Timeout: 2 * time.Second, Errors: []error{ErrFooTimeOut}})) log.Fatal(app.Listen(":3000")) } func sleepWithContextWithCustomError(ctx context.Context, d time.Duration) error { timer := time.NewTimer(d) select { case <-ctx.Done(): if !timer.Stop() { <-timer.C } return ErrFooTimeOut case <-timer.C: } return nil } ``` ### Sample usage with a database call ```go func main() { app := fiber.New() db, _ := gorm.Open(postgres.Open("postgres://localhost/foodb"), &gorm.Config{}) handler := func(ctx fiber.Ctx) error { tran := db.WithContext(ctx.Context()).Begin() if tran = tran.Exec("SELECT pg_sleep(50)"); tran.Error != nil { return tran.Error } if tran = tran.Commit(); tran.Error != nil { return tran.Error } return nil } app.Get("/foo", timeout.New(handler, timeout.Config{Timeout: 10 * time.Second})) log.Fatal(app.Listen(":3000")) } ``` ================================================ FILE: docs/partials/routing/handler.md ================================================ --- id: route-handlers title: Route Handlers --- import Reference 
from '@site/src/components/reference'; Registers a route bound to a specific [HTTP method](https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods). ```go title="Signatures" // HTTP methods func (app *App) Get(path string, handler any, handlers ...any) Router func (app *App) Head(path string, handler any, handlers ...any) Router func (app *App) Post(path string, handler any, handlers ...any) Router func (app *App) Put(path string, handler any, handlers ...any) Router func (app *App) Delete(path string, handler any, handlers ...any) Router func (app *App) Connect(path string, handler any, handlers ...any) Router func (app *App) Options(path string, handler any, handlers ...any) Router func (app *App) Trace(path string, handler any, handlers ...any) Router func (app *App) Patch(path string, handler any, handlers ...any) Router // Add allows you to specify multiple methods at once // The provided handlers are executed in order, starting with `handler` and then the variadic `handlers`. func (app *App) Add(methods []string, path string, handler any, handlers ...any) Router // All will register the route on all HTTP methods // Almost the same as app.Use but not bound to prefixes func (app *App) All(path string, handler any, handlers ...any) Router ``` Fiber's adapter converts a variety of handler shapes to native `func(fiber.Ctx) error` callbacks. It currently recognizes seventeen cases (the numbers below match the comments in `toFiberHandler` inside `adapter.go`). This lets you mix Fiber-style handlers with Express-style callbacks and even reuse `net/http` or `fasthttp` functions. ### Fiber-native handlers (cases 1–2) - **Case 1.** `fiber.Handler` — the canonical `func(fiber.Ctx) error` form. - **Case 2.** `func(fiber.Ctx)` — Fiber runs the function and treats it as if it returned `nil`. 
### Express-style request handlers (cases 3–12) - **Case 3.** `func(fiber.Req, fiber.Res) error` - **Case 4.** `func(fiber.Req, fiber.Res)` - **Case 5.** `func(fiber.Req, fiber.Res, func() error) error` - **Case 6.** `func(fiber.Req, fiber.Res, func() error)` - **Case 7.** `func(fiber.Req, fiber.Res, func()) error` - **Case 8.** `func(fiber.Req, fiber.Res, func())` - **Case 9.** `func(fiber.Req, fiber.Res, func(error))` - **Case 10.** `func(fiber.Req, fiber.Res, func(error)) error` - **Case 11.** `func(fiber.Req, fiber.Res, func(error) error)` - **Case 12.** `func(fiber.Req, fiber.Res, func(error) error) error` The adapter injects a `next` callback when your signature accepts one. Fiber propagates downstream errors from `c.Next()` back through the wrapper, so returning those errors remains optional. If you never call the injected `next` function, the handler chain stops, matching Express semantics. When you accept `next` callbacks that take an `error`, calling `next(nil)` continues the chain and passing a non-nil error short-circuits with that error. If the handler itself returns an error, Fiber prioritizes that value over any recorded `next` error. ### net/http handlers (cases 13–15) - **Case 13.** `http.HandlerFunc` - **Case 14.** `http.Handler` - **Case 15.** `func(http.ResponseWriter, *http.Request)` :::caution Compatibility overhead Fiber adapts these handlers through `fasthttpadaptor`. They do not receive `fiber.Ctx`, cannot call `c.Next()`, and therefore always terminate the handler chain. The compatibility layer also adds more overhead than running a native Fiber handler, so prefer the other forms when possible. ::: ### fasthttp handlers (cases 16–17) - **Case 16.** `fasthttp.RequestHandler` - **Case 17.** `func(*fasthttp.RequestCtx) error` fasthttp handlers run with full access to the underlying `fasthttp.RequestCtx`. They are expected to manage the response directly. 
Fiber will propagate any error returned by the `func(*fasthttp.RequestCtx) error` variant but otherwise does not inspect the context state. ```go title="Examples" // Simple GET handler (Fiber accepts both func(fiber.Ctx) and func(fiber.Ctx) error) app.Get("/api/list", func(c fiber.Ctx) error { return c.SendString("I'm a GET request!") }) // Reuse an existing net/http handler without manual adaptation httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) }) app.Get("/foo", httpHandler) // Align with Express-style handlers using fiber.Req and fiber.Res helpers (works // for middleware and routes alike) app.Use(func(req fiber.Req, res fiber.Res, next func() error) error { if req.IP() == "192.168.1.254" { return res.SendStatus(fiber.StatusForbidden) } return next() }) app.Get("/express", func(req fiber.Req, res fiber.Res) error { return res.SendString("Hello from Express-style handlers!") }) // Mount a fasthttp.RequestHandler directly app.Get("/bar", func(ctx *fasthttp.RequestCtx) { ctx.SetStatusCode(fiber.StatusAccepted) }) // Simple POST handler app.Post("/api/register", func(c fiber.Ctx) error { return c.SendString("I'm a POST request!") }) ``` ## Use Can be used for middleware packages and prefix catchers. Prefixes now require either an exact match or a slash boundary, so `/john` matches `/john` and `/john/doe` but not `/johnnnnn`. Parameter tokens like `:name`, `:name?`, `*`, and `+` are still expanded before the boundary check runs. 
```go title="Signature" func (app *App) Use(args ...any) Router // Fiber inspects args to support these common usage patterns: // - app.Use(handler, handlers ...any) // - app.Use(path string, handler, handlers ...any) // - app.Use(paths []string, handler, handlers ...any) // - app.Use(path string, subApp *App) ``` Each handler argument can independently be a Fiber handler (with or without an `error` return), an Express-style callback, a `net/http` handler, or any other supported shape including fasthttp callbacks that return errors. ```go title="Examples" // Match any request app.Use(func(c fiber.Ctx) error { return c.Next() }) // Match request starting with /api app.Use("/api", func(c fiber.Ctx) error { return c.Next() }) // Match requests starting with /api or /home (multiple-prefix support) app.Use([]string{"/api", "/home"}, func(c fiber.Ctx) error { return c.Next() }) // Attach multiple handlers app.Use("/api", func(c fiber.Ctx) error { c.Set("X-Custom-Header", random.String(32)) return c.Next() }, func(c fiber.Ctx) error { return c.Next() }) // Mount a sub-app app.Use("/api", api) ``` ================================================ FILE: docs/whats_new.md ================================================ --- id: whats_new title: 🆕 What's New in v3 sidebar_position: 2 toc_max_heading_level: 4 --- ## 🎉 Welcome We are excited to announce the release of Fiber v3! 🚀 In this guide, we'll walk you through the most important changes in Fiber `v3` and show you how to migrate your existing Fiber `v2` applications to Fiber `v3`. ### 🛠️ Migration tool Fiber v3 introduces a CLI-powered migration helper. Install the CLI and let it update your project automatically: ```bash go install github.com/gofiber/cli/fiber@latest fiber migrate --to v3 ``` See the [migration guide](#-migration-guide) for more details and options. 
Here's a quick overview of the changes in Fiber `v3`: - [🚀 App](#-app) - [🎣 Hooks](#-hooks) - [🚀 Listen](#-listen) - [🗺️ Router](#-router) - [🧠 Context](#-context) - [📎 Binding](#-binding) - [🔬 Extractors Package](#-extractors-package) - [🔄️ Redirect](#-redirect) - [🌎 Client package](#-client-package) - [🧰 Generic functions](#-generic-functions) - [🛠️ Utils](#utils) - [🧩 Services](#-services) - [📃 Log](#-log) - [📦 Storage Interface](#-storage-interface) - [🧬 Middlewares](#-middlewares) - [Important Change for Accessing Middleware Data](#important-change-for-accessing-middleware-data) - [Adaptor](#adaptor) - [BasicAuth](#basicauth) - [Cache](#cache) - [CORS](#cors) - [CSRF](#csrf) - [Compression](#compression) - [EncryptCookie](#encryptcookie) - [Favicon](#favicon) - [Filesystem](#filesystem) - [Healthcheck](#healthcheck) - [KeyAuth](#keyauth) - [Logger](#logger) - [Monitor](#monitor) - [Proxy](#proxy) - [Recover](#recover) - [Session](#session) - [🔌 Addons](#-addons) - [📋 Migration guide](#-migration-guide) ## Dropping support for old Go versions Fiber `v3` requires Go `1.25` or later. Update your toolchain to `1.25+` before upgrading so the module `go` directive and standard library features align with the new minimum version. ## 🚀 App We have made several changes to the Fiber app, including: - **Listen**: The `Listen` method has been unified with the configuration, allowing for more streamlined setup. - **Static**: The `Static` method has been removed and its functionality has been moved to the [static middleware](./middleware/static.md). - **app.Config properties**: Several properties have been moved to the listen configuration: - `DisableStartupMessage` - `EnablePrefork` (previously `Prefork`) - `EnablePrintRoutes` - `ListenerNetwork` (previously `Network`) - **Trusted Proxy Configuration**: The `EnableTrustedProxyCheck` has been moved to `app.Config.TrustProxy`, and `TrustedProxies` has been moved to `TrustProxyConfig.Proxies`.
Additionally, `ProxyHeader` must be set to read client IPs from proxy headers (e.g., `X-Forwarded-For`). - **XMLDecoder Config Property**: The `XMLDecoder` property has been added to allow usage of 3rd-party XML libraries in XML binder. ### New Methods - **RegisterCustomBinder**: Allows for the registration of custom binders. - **RegisterCustomConstraint**: Allows for the registration of custom constraints. - **NewWithCustomCtx**: Initialize an app with a custom context in one step. - **State**: Provides a global state for the application, which can be used to store and retrieve data across the application. Check out the [State](./api/state) method for further details. - **NewErrorf**: Allows variadic parameters when creating formatted errors. - **GetBytes / GetString**: Helpers that detach values only when `Immutable` is enabled and the data still references request or response buffers. Access via `c.App().GetString` and `c.App().GetBytes`. - **ReloadViews**: Lets you re-run the configured view engine's `Load()` logic at runtime, including guard rails for missing or nil view engines so development hot-reload hooks can refresh templates safely. #### Custom Route Constraints Custom route constraints enable you to define your own validation rules for route parameters. Use `RegisterCustomConstraint` to add a constraint type that implements the `CustomConstraint` interface.
Example ```go type UlidConstraint struct { fiber.CustomConstraint } func (*UlidConstraint) Name() string { return "ulid" } func (*UlidConstraint) Execute(param string, args ...string) bool { _, err := ulid.Parse(param) return err == nil } app.RegisterCustomConstraint(&UlidConstraint{}) app.Get("/login/:id<ulid>", func(c fiber.Ctx) error { return c.SendString("User " + c.Params("id")) }) ```
### Removed Methods - **Mount**: Use `app.Use()` instead. - **ListenTLS**: Use `app.Listen()` with `tls.Config`. - **ListenTLSWithCertificate**: Use `app.Listen()` with `tls.Config`. - **ListenMutualTLS**: Use `app.Listen()` with `tls.Config`. - **ListenMutualTLSWithCertificate**: Use `app.Listen()` with `tls.Config`. ### Method Changes - **Test**: The `Test` method has replaced the timeout parameter with a configuration parameter. `0` or lower represents no timeout. - **Listen**: Now has a configuration parameter. - **Listener**: Now has a configuration parameter. ### Custom Ctx Interface in Fiber v3 Fiber v3 introduces a customizable `Ctx` interface, allowing developers to extend and modify the context to fit their needs. This feature provides greater flexibility and control over request handling. #### Idea Behind Custom Ctx Classes The idea behind custom `Ctx` classes is to give developers the ability to extend the default context with additional methods and properties tailored to the specific requirements of their application. This allows for better request handling and easier implementation of specific logic. #### NewWithCustomCtx `NewWithCustomCtx` creates the application and sets the custom context factory at initialization time. ```go title="Signature" func NewWithCustomCtx(fn func(app *App) CustomCtx, config ...Config) *App ```
Example ```go package main import ( "log" "github.com/gofiber/fiber/v3" ) type CustomCtx struct { fiber.DefaultCtx } func (c *CustomCtx) CustomMethod() string { return "custom value" } func main() { app := fiber.NewWithCustomCtx(func(app *fiber.App) fiber.CustomCtx { return &CustomCtx{ DefaultCtx: *fiber.NewDefaultCtx(app), } }) app.Get("/", func(c fiber.Ctx) error { customCtx := c.(*CustomCtx) return c.SendString(customCtx.CustomMethod()) }) log.Fatal(app.Listen(":3000")) } ``` This example creates a `CustomCtx` with an extra `CustomMethod` and initializes the app with `NewWithCustomCtx`.
### Configurable TLS Minimum Version We have added support for configuring the minimum TLS version. The `TLSMinVersion` field of `ListenConfig` sets the minimum TLS version for both TLSAutoCert and the server listener. ```go app.Listen(":444", fiber.ListenConfig{TLSMinVersion: tls.VersionTLS12}) ``` #### TLS AutoCert support (ACME / Let's Encrypt) We have added native support for automatic certificate management via Let's Encrypt and other ACME-based providers. ```go // Certificate manager certManager := &autocert.Manager{ Prompt: autocert.AcceptTOS, // Replace with your domain name HostPolicy: autocert.HostWhitelist("example.com"), // Folder to store the certificates Cache: autocert.DirCache("./certs"), } app.Listen(":444", fiber.ListenConfig{ AutoCertManager: certManager, }) ``` ### MIME Constants `MIMEApplicationJavaScript` and `MIMEApplicationJavaScriptCharsetUTF8` are deprecated. Use `MIMETextJavaScript` and `MIMETextJavaScriptCharsetUTF8` instead. ## 🎣 Hooks We have made several changes to the Fiber hooks, including: - Added new shutdown hooks to provide better control over the shutdown process: - `OnPreShutdown` - Executes before the server starts shutting down - `OnPostShutdown` - Executes after the server has shut down, receives any shutdown error - `OnPreStartupMessage` - Executes before the startup message is printed, allowing customization of the banner and info entries - `OnPostStartupMessage` - Executes after the startup message is printed, allowing post-startup logic - Deprecated `OnShutdown` in favor of the new pre/post shutdown hooks - Improved shutdown hook execution order and reliability - Added mutex protection for hook registration and execution Important: When using shutdown hooks, ensure `app.Listen()` is called in a separate goroutine: ```go // Correct usage go app.Listen(":3000") // ...
register shutdown hooks app.Shutdown() // Incorrect usage - hooks won't work app.Listen(":3000") // This blocks app.Shutdown() // Never reached ``` ## 🚀 Listen We have made several changes to the Fiber listen, including: - Removed `OnShutdownError` and `OnShutdownSuccess` from `ListenConfig` in favor of using the `OnPostShutdown` hook, which receives the shutdown error ```go app := fiber.New() // Before - using ListenConfig callbacks app.Listen(":3000", fiber.ListenConfig{ OnShutdownError: func(err error) { log.Printf("Shutdown error: %v", err) }, OnShutdownSuccess: func() { log.Println("Shutdown successful") }, }) // After - using OnPostShutdown hook app.Hooks().OnPostShutdown(func(err error) error { if err != nil { log.Printf("Shutdown error: %v", err) } else { log.Println("Shutdown successful") } return nil }) go app.Listen(":3000") ``` This change simplifies the shutdown handling by consolidating the shutdown callbacks into a single hook that receives the error status. - Added support for Unix domain sockets via `ListenerNetwork` and `UnixSocketFileMode` ```go // v2 - Requires manual deletion of old file and permissions change app := fiber.New(fiber.Config{ Network: "unix", }) os.Remove("app.sock") app.Hooks().OnListen(func(fiber.ListenData) error { return os.Chmod("app.sock", 0770) }) app.Listen("app.sock") // v3 - Fiber does it for you app := fiber.New() app.Listen("app.sock", fiber.ListenConfig{ ListenerNetwork: fiber.NetworkUnix, UnixSocketFileMode: 0770, }) ``` - Added `TLSConfig` to `ListenConfig` so external providers can supply certificates via `GetCertificate`. Prefer `TLSConfig` when configuring TLS; when set, it is cloned and takes precedence over other TLS fields. 
```go app := fiber.New() app.Listen(":443", fiber.ListenConfig{ TLSConfig: &tls.Config{ GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { return myProvider.Certificate(info.ServerName) }, }, }) ``` - Expanded `ListenData` with versioning, handler, process, and PID metadata, plus dedicated startup message hooks for customization. Check out the [Hooks](./api/hooks#startup-message-customization) documentation for further details. ```go title="Customize the startup message" package main import ( "fmt" "os" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Hooks().OnPreStartupMessage(func(sm *fiber.PreStartupMessageData) error { sm.BannerHeader = "FOOBER " + sm.Version + "\n-------" // Optional: you can also remove old entries // sm.ResetEntries() sm.AddInfo("git-hash", "Git hash", os.Getenv("GIT_HASH")) sm.AddInfo("prefork", "Prefork", fmt.Sprintf("%v", sm.Prefork), 15) return nil }) app.Hooks().OnPostStartupMessage(func(sm *fiber.PostStartupMessageData) error { if !sm.Disabled && !sm.IsChild && !sm.Prevented { fmt.Println("startup completed") } return nil }) app.Listen(":5000") } ``` ## 🗺 Router We have slightly adapted our router interface ### Handler compatibility Fiber now ships with a routing adapter (see `adapter.go`) that understands native Fiber handlers alongside `net/http` and `fasthttp` handlers. Route registration helpers accept a required `handler` argument plus optional additional `handlers`, all typed as `any`, and the adapter transparently converts supported handler styles so you can keep using the ecosystem functions you're familiar with. To align even closer with Express, you can also register handlers that accept the new `fiber.Req` and `fiber.Res` helper interfaces. The adapter understands both two-argument (`func(fiber.Req, fiber.Res)`) and three-argument (`func(fiber.Req, fiber.Res, func() error)`) callbacks, regardless of whether they return an `error`. 
It also accepts Express-style `next` callbacks that take an `error` (`func(error)` or `func(error) error`). When you include the optional `next` callback, Fiber wires it to `c.Next()` for you so middleware continues to behave as expected. Calling `next(nil)` continues the chain, while passing a non-nil error short-circuits and returns that error. If your handler returns an `error`, the value returned from the injected `next()` bubbles straight back to the caller. When your handler omits an `error` return, Fiber records the result of `next()` and returns it after your function exits so downstream failures still propagate. | Case | Handler signature | Notes | | ---- | ----------------- | ----- | | 1 | `fiber.Handler` | Native Fiber handler. | | 2 | `func(fiber.Ctx)` | Fiber handler without an error return. | | 3 | `func(fiber.Req, fiber.Res) error` | Express-style request handler with error return. | | 4 | `func(fiber.Req, fiber.Res)` | Express-style request handler without error return. | | 5 | `func(fiber.Req, fiber.Res, func() error) error` | Express-style middleware with an error-returning `next` callback and handler error return. | | 6 | `func(fiber.Req, fiber.Res, func() error)` | Express-style middleware with an error-returning `next` callback. | | 7 | `func(fiber.Req, fiber.Res, func()) error` | Express-style middleware with a no-argument `next` callback and handler error return. | | 8 | `func(fiber.Req, fiber.Res, func())` | Express-style middleware with a no-argument `next` callback. | | 9 | `func(fiber.Req, fiber.Res, func(error))` | Express-style middleware with an error-accepting `next` callback. | | 10 | `func(fiber.Req, fiber.Res, func(error)) error` | Express-style middleware with an error-accepting `next` callback and handler error return. | | 11 | `func(fiber.Req, fiber.Res, func(error) error)` | Express-style middleware with an error-accepting `next` callback that returns an error. 
| | 12 | `func(fiber.Req, fiber.Res, func(error) error) error` | Express-style middleware with an error-accepting `next` callback that returns an error and handler error return. | | 13 | `http.HandlerFunc` | Standard-library handler function adapted through `fasthttpadaptor`. | | 14 | `http.Handler` | Standard-library handler implementation; pointer receivers must be non-nil. | | 15 | `func(http.ResponseWriter, *http.Request)` | Standard-library function handlers via `fasthttpadaptor`. | | 16 | `fasthttp.RequestHandler` | Direct fasthttp handler without error return. | | 17 | `func(*fasthttp.RequestCtx) error` | fasthttp handler that returns an error to Fiber. | ### Route chaining `RouteChain` is a new helper inspired by [`Express`](https://expressjs.com/en/api.html#app.route) that makes it easy to declare a stack of handlers on the same path, while the existing `Route` helper stays available for prefix encapsulation. ```go RouteChain(path string) Register ```
Example ```go app.RouteChain("/api").RouteChain("/user/:id?").Get(func(c fiber.Ctx) error { // Get user return c.JSON(fiber.Map{"message": "Get user", "id": c.Params("id")}) }).Post(func(c fiber.Ctx) error { // Create user return c.JSON(fiber.Map{"message": "User created"}) }).Put(func(c fiber.Ctx) error { // Update user return c.JSON(fiber.Map{"message": "User updated", "id": c.Params("id")}) }).Delete(func(c fiber.Ctx) error { // Delete user return c.JSON(fiber.Map{"message": "User deleted", "id": c.Params("id")}) }) ```
You can find more information about `app.RouteChain` and `app.Route` in the API documentation ([RouteChain](./api/app#routechain), [Route](./api/app#route)). ### Automatic HEAD routes for GET Fiber now auto-registers a `HEAD` route whenever you add a `GET` route. The generated handler chain matches the `GET` chain so status codes and headers stay in sync while the response body remains empty, ensuring `HEAD` clients observe the same metadata as a `GET` consumer. ```go title="GET now enables HEAD automatically" app := fiber.New() app.Get("/health", func(c fiber.Ctx) error { c.Set("X-Service", "api") return c.SendString("OK") }) // HEAD /health reuses the GET middleware chain and returns headers only. ``` You can still register explicit `HEAD` handlers for any `GET` route, and they continue to win when you add them: ```go title="Override the generated HEAD handler" app.Head("/health", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusNoContent) }) ``` Prefer to manage `HEAD` routes yourself? Disable the feature through `fiber.Config.DisableHeadAutoRegister`: ```go title="Disable automatic HEAD registration" handler := func(c fiber.Ctx) error { c.Set("X-Service", "api") return c.SendString("OK") } app := fiber.New(fiber.Config{DisableHeadAutoRegister: true}) app.Get("/health", handler) // HEAD /health now returns 405 unless you add it manually. ``` Auto-generated `HEAD` routes appear in tooling such as `app.Stack()` and cover the same routing scenarios as their `GET` counterparts, including groups, mounted apps, dynamic parameters, and static file handlers. ### Middleware registration We have aligned our method for middlewares closer to [`Express`](https://expressjs.com/en/api.html#app.use) and now also support the [`Use`](./api/app#use) of multiple prefixes. Prefix matching is now stricter: partial matches must end at a slash boundary (or be an exact match). 
This keeps `/api` middleware from running on `/apiv1` while still allowing `/api/:version` style patterns that leverage route parameters, optional segments, or wildcards. Registering a subapp is now also possible via the [`Use`](./api/app#use) method instead of the old `app.Mount` method.
Example ```go // register multiple prefixes app.Use([]string{"/v1", "/v2"}, func(c fiber.Ctx) error { // Middleware for /v1 and /v2 return c.Next() }) // define subapp api := fiber.New() api.Get("/user", func(c fiber.Ctx) error { return c.SendString("User") }) // register subapp app.Use("/api", api) ```
To enable the routing changes above we had to slightly adjust the signature of the `Add` method. ```diff - Add(method, path string, handlers ...Handler) Router + Add(methods []string, path string, handler any, handlers ...any) Router ``` ### Test Config The `app.Test()` method now allows users to customize their test configurations:
Example ```go // Create a test app with a handler to test app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("hello world") }) // Define the HTTP request and custom TestConfig to test the handler req := httptest.NewRequest(fiber.MethodGet, "/", nil) testConfig := fiber.TestConfig{ Timeout: 0, FailOnTimeout: false, } // Test the handler using the request and testConfig resp, err := app.Test(req, testConfig) ```
To provide configurable testing capabilities, we had to change the signature of the `Test` method. ```diff - Test(req *http.Request, timeout ...time.Duration) (*http.Response, error) + Test(req *http.Request, config ...fiber.TestConfig) (*http.Response, error) ``` The `TestConfig` struct provides the following configuration options: - `Timeout`: The duration to wait before timing out the test. Use 0 for no timeout. - `FailOnTimeout`: Controls the behavior when a timeout occurs: - When true, the test will return an `os.ErrDeadlineExceeded` if the test exceeds the `Timeout` duration. - When false, the test will return the partial response received before timing out. If a custom `TestConfig` isn't provided, then the following will be used: ```go testConfig := fiber.TestConfig{ Timeout: time.Second, FailOnTimeout: true, } ``` **Note:** Using this default is **NOT** the same as providing an empty `TestConfig` as an argument to `app.Test()`. An empty `TestConfig` is the equivalent of: ```go testConfig := fiber.TestConfig{ Timeout: 0, FailOnTimeout: false, } ``` ## 🧠 Context ### New Features - Cookie now allows Partitioned cookies for [CHIPS](https://developers.google.com/privacy-sandbox/3pcd/chips) support. CHIPS (Cookies Having Independent Partitioned State) is a feature that improves privacy by allowing cookies to be partitioned by top-level site, mitigating cross-site tracking. - Cookie automatic security enforcement: When setting a cookie with `SameSite=None`, Fiber automatically sets `Secure=true` as required by RFC 6265bis and modern browsers (Chrome, Firefox, Safari). This ensures compliance with the "None" SameSite policy. See [Mozilla docs](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie#none) and [Chrome docs](https://developers.google.com/search/blog/2020/01/get-ready-for-new-samesitenone-secure) for details. 
- `Ctx` now implements the [context.Context](https://pkg.go.dev/context#Context) interface, replacing the former `UserContext` helpers. ### New Methods - **AutoFormat**: Similar to Express.js, automatically formats the response based on the request's `Accept` header. - **Deadline**: For implementing `context.Context`. - **Done**: For implementing `context.Context`. - **Err**: For implementing `context.Context`. - **Host**: Similar to Express.js, returns the host name of the request. - **Port**: Similar to Express.js, returns the port number of the request. - **IsProxyTrusted**: Checks the trustworthiness of the remote IP. - **Reset**: Resets context fields for server handlers. - **Schema**: Similar to Express.js, returns the schema (HTTP or HTTPS) of the request. - **SendEarlyHints**: Sends `HTTP 103 Early Hints` status code with `Link` headers so browsers can preload resources while the final response is being prepared. - **SendStream**: Similar to Express.js, sends a stream as the response. - **SendStreamWriter**: Sends a stream using a writer function. - **SendString**: Similar to Express.js, sends a string as the response. - **String**: Similar to Express.js, converts a value to a string. - **Value**: For implementing `context.Context`. Returns request-scoped value from Locals. - **Context()**: Returns a `context.Context` that can be used outside the handler. - **SetContext**: Sets the base `context.Context` returned by `Context()` for propagating deadlines or values. - **ViewBind**: Binds data to a view, replacing the old `Bind` method. - **CBOR**: Introducing [CBOR](https://cbor.io/) binary encoding format for both request & response body. CBOR is a binary data serialization format which is both compact and efficient, making it ideal for use in web applications. - **MsgPack**: Introducing [MsgPack](https://msgpack.org/) binary encoding format for both request & response body. 
MsgPack is a binary serialization format that is more efficient than JSON, making it ideal for high-performance applications. - **Drop**: Terminates the client connection silently without sending any HTTP headers or response body. This can be used for scenarios where you want to block certain requests without notifying the client, such as mitigating DDoS attacks or protecting sensitive endpoints from unauthorized access. - **End**: Similar to Express.js, immediately flushes the current response and closes the underlying connection. - **AcceptsLanguagesExtended**: Matches language ranges using RFC 4647 Extended Filtering with wildcard subtags. - **FullURL**: Returns the full request URL (scheme + host + original URL). - **RequestID**: Returns the request identifier from the response or request headers. - **UserAgent**: Returns the `User-Agent` request header. - **Referer**: Returns the `Referer` request header. - **AcceptLanguage**: Returns the `Accept-Language` request header. - **AcceptEncoding**: Returns the `Accept-Encoding` request header. - **HasHeader**: Reports whether the request includes a header with the given key. - **MediaType**: Returns the MIME type from the `Content-Type` header without parameters. - **Charset**: Returns the `charset` parameter from the `Content-Type` header. - **IsJSON**: Reports whether the `Content-Type` header is JSON. - **IsForm**: Reports whether the `Content-Type` header is form-encoded. - **IsMultipart**: Reports whether the `Content-Type` header is multipart form data. - **AcceptsJSON**: Reports whether the `Accept` header allows JSON. - **AcceptsHTML**: Reports whether the `Accept` header allows HTML. - **AcceptsXML**: Reports whether the `Accept` header allows XML. - **AcceptsEventStream**: Reports whether the `Accept` header allows `text/event-stream`. - **Matched**: Detects when the current request path matched a registered route. - **IsMiddleware**: Indicates if the current handler was registered as middleware. 
- **HasBody**: Quickly checks whether the request includes a body. - **OverrideParam**: Overwrites the value of an existing route parameter, or does nothing if the parameter does not exist. - **IsWebSocket**: Reports if the request attempts a WebSocket upgrade. - **IsPreflight**: Identifies CORS preflight requests before handlers run. ### Removed Methods - **AllParams**: Use `c.Bind().URI()` instead. - **ParamsInt**: Use `Params` with generic types. - **QueryBool**: Use `Query` with generic types. - **QueryFloat**: Use `Query` with generic types. - **QueryInt**: Use `Query` with generic types. - **BodyParser**: Use `c.Bind().Body()` instead. - **CookieParser**: Use `c.Bind().Cookie()` instead. - **ParamsParser**: Use `c.Bind().URI()` instead. - **RedirectToRoute**: Use `c.Redirect().Route()` instead. - **RedirectBack**: Use `c.Redirect().Back()` instead. - **ReqHeaderParser**: Use `c.Bind().Header()` instead. - **UserContext**: Removed. `Ctx` itself now satisfies `context.Context`; pass `c` directly where a `context.Context` is required. - **SetUserContext**: Removed. Use `SetContext` and `Context()` or `context.WithValue` on `c` to store additional request-scoped values. ### Changed Methods - **Bind**: Now used for binding instead of view binding. Use `c.ViewBind()` for view binding. - **Format**: Parameter changed from `body interface{}` to `handlers ...ResFmt`. - **Redirect**: Use `c.Redirect().To()` instead. - **SendFile**: Now supports different configurations using a config parameter. - **Attachment and Download**: Non-ASCII filenames now use `filename*` as specified by [RFC 6266](https://www.rfc-editor.org/rfc/rfc6266) and [RFC 8187](https://www.rfc-editor.org/rfc/rfc8187). - **Context()**: The v2 `Context()` method, which returned the underlying `*fasthttp.RequestCtx`, has been renamed to `RequestCtx()`. `Context()` now returns a standard `context.Context` instead.
### SendEarlyHints `SendEarlyHints` sends an informational [`103 Early Hints`](https://developer.chrome.com/docs/web-platform/early-hints) response with `Link` headers based on the provided `hints` argument. This allows a browser to start preloading assets while the server is still preparing the final response. ```go hints := []string{"<https://cdn.com/app.js>; rel=preload; as=script"} app.Get("/early", func(c fiber.Ctx) error { if err := c.SendEarlyHints(hints); err != nil { return err } return c.SendString("done") }) ``` Older HTTP/1.1 clients may ignore these interim responses or handle them inconsistently. ### SendStreamWriter In v3, we introduced support for buffered streaming with the addition of the `SendStreamWriter` method: ```go func (c Ctx) SendStreamWriter(streamWriter func(w *bufio.Writer)) error ``` With this new method, you can implement: - Server-Sent Events (SSE) - Large file downloads - Live data streaming ```go app.Get("/sse", func(c fiber.Ctx) error { c.Set("Content-Type", "text/event-stream") c.Set("Cache-Control", "no-cache") c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") return c.SendStreamWriter(func(w *bufio.Writer) { for { fmt.Fprintf(w, "event: my-event\n") fmt.Fprintf(w, "data: Hello SSE\n\n") if err := w.Flush(); err != nil { log.Print("Client disconnected!") return } } }) }) ``` You can find more details about this feature in [/docs/api/ctx.md](./api/ctx.md). ### Drop In v3, we introduced support to silently terminate requests through `Drop`. ```go func (c Ctx) Drop() error ``` With this method, you can: - Block certain requests without notifying the client to mitigate DDoS attacks - Protect sensitive endpoints from unauthorized access without leaking errors. :::caution While this feature adds the ability to drop connections, it is still **highly recommended** to use additional measures (such as **firewalls**, **proxies**, etc.)
to further protect your server endpoints by blocking malicious connections before the server establishes a connection. ::: ```go app.Get("/", func(c fiber.Ctx) error { if c.IP() == "192.168.1.1" { return c.Drop() } return c.SendString("Hello World!") }) ``` You can find more details about this feature in [/docs/api/ctx.md](./api/ctx.md). ### End In v3, we introduced a new method to match the Express.js API's `res.end()` method. ```go func (c Ctx) End() error ``` With this method, you can: - Stop middleware from controlling the connection after a handler further up the method chain by immediately flushing the current response and closing the connection. - Use `return c.End()` as an alternative to `return nil` ```go app.Use(func (c fiber.Ctx) error { err := c.Next() if err != nil { log.Printf("Got error: %v", err) return c.SendString(err.Error()) // Will be unsuccessful since the response ended below } return nil }) app.Get("/hello", func (c fiber.Ctx) error { query := c.Query("name", "") if query == "" { _ = c.SendString("You don't have a name?") _ = c.End() // Closes the underlying connection; errors intentionally ignored return errors.New("No name provided") } return c.SendString("Hello, " + query + "!") }) ``` --- ## 📎 Binding Fiber v3 introduces a new binding mechanism that simplifies the process of binding request data to structs. The new binding system supports binding from various sources such as URL parameters, query parameters, headers, and request bodies. This unified approach makes it easier to handle different types of request data in a consistent manner. ### New Features - Unified binding from URL parameters, query parameters, headers, and request bodies. - Support for custom binders and constraints. - Improved error handling and validation. - Support multipart file binding for `*multipart.FileHeader`, `*[]*multipart.FileHeader`, and `[]*multipart.FileHeader` field types.
- Support for unified binding (`Bind().All()`) with defined precedence order: (URI -> Body -> Query -> Headers -> Cookies). [Learn more](./api/bind.md#all). - Support MsgPack binding for request body.
Example ```go type User struct { ID int `uri:"id"` Name string `json:"name"` Email string `json:"email"` } app.Post("/user/:id", func(c fiber.Ctx) error { var user User if err := c.Bind().Body(&user); err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) } return c.JSON(user) }) ``` In this example, the `Bind` method is used to bind the request body to the `User` struct. The `Body` method of the binder returned by `c.Bind()` performs the actual binding. Note that `Body` only reads the request body, so the `ID` field tagged with `uri` is not populated here; use `c.Bind().URI()` (or `c.Bind().All()`) to bind route parameters as well.
## 🔬 Extractors Package Fiber v3 introduces a new shared `extractors` package that consolidates value extraction utilities previously duplicated across middleware packages. This package provides a unified API for extracting values from headers, cookies, query parameters, form data, and URL parameters with built-in chain/fallback logic and security considerations. ### Key Features - **Unified API**: Single package for extracting values from headers, cookies, query parameters, form data, and URL parameters - **Chain Logic**: Built-in fallback mechanism to try multiple extraction sources in order - **Source Awareness**: Source inspection capabilities for security-sensitive operations - **Type Safety**: Strongly typed extraction with proper error handling - **Performance**: Optimized extraction functions with minimal overhead ### Available Extractors - `FromAuthHeader(authScheme string)`: Extract from Authorization header with scheme support - `FromCookie(key string)`: Extract from HTTP cookies - `FromParam(param string)`: Extract from URL path parameters - `FromForm(param string)`: Extract from form data - `FromHeader(header string)`: Extract from custom HTTP headers - `FromQuery(param string)`: Extract from URL query parameters - `FromCustom(key string, extractor func(c fiber.Ctx) (string, error))`: Define custom extraction logic with metadata - `Chain(extractors ...Extractor)`: Chain multiple extractors with fallback logic ### Usage Example ```go import "github.com/gofiber/fiber/v3/extractors" // Extract API key from multiple sources with fallback apiKeyExtractor := extractors.Chain( extractors.FromHeader("X-API-Key"), extractors.FromQuery("api_key"), extractors.FromCookie("api_key"), ) app.Use(func(c fiber.Ctx) error { apiKey, err := apiKeyExtractor.Extract(c) if err != nil { return c.Status(401).SendString("API key required") } // Use apiKey for authentication return c.Next() }) ``` ### Migration from Middleware-Specific Extractors Middleware packages in Fiber v3 
now use the shared extractors package instead of maintaining their own extraction logic. This provides: - **Code Deduplication**: Eliminates ~500+ lines of duplicated extraction code - **Consistency**: Standardized extraction behavior across all middleware - **Maintainability**: Single source of truth for extraction logic - **Security**: Unified security considerations and warnings ## 🔄 Redirect Fiber v3 enhances the redirect functionality by introducing new methods and improving existing ones. The new redirect methods provide more flexibility and control over the redirection process. ### New Methods - `Redirect().To()`: Redirects to a specific URL. - `Redirect().Route()`: Redirects to a named route. - `Redirect().Back()`: Redirects to the previous URL.
Example ```go app.Get("/old", func(c fiber.Ctx) error { return c.Redirect().To("/new") }) app.Get("/new", func(c fiber.Ctx) error { return c.SendString("Welcome to the new route!") }) ```
### Changed behavior :::info The default redirect status code has been updated from `302 Found` to `303 See Other` to ensure more consistent behavior across different browsers. ::: ## 🌎 Client package The Fiber client has been completely rebuilt. It includes numerous new features such as a cookie jar, request/response hooks, and more. See the [client docs](./client/rest.md) for everything that's new in the client. ### Configuration improvements The v3 client centralizes common configuration on the client instance and lets you override it per request with `client.Config`. You can define base URLs, defaults (headers, cookies, path parameters, timeouts), and toggle path normalization once, while still using axios-style helpers for each call. ```go cc := client.New(). SetBaseURL("https://api.service.local"). AddHeader("Authorization", "Bearer <token>"). SetTimeout(5 * time.Second). SetPathParam("tenant", "acme") resp, err := cc.Get("/users/:tenant/:id", client.Config{ PathParam: map[string]string{"id": "42"}, Param: map[string]string{"include": "profile"}, DisablePathNormalizing: true, }) if err != nil { panic(err) } defer resp.Close() fmt.Println(resp.StatusCode(), resp.String()) ``` ### Fasthttp transport integration - `client.NewWithHostClient` and `client.NewWithLBClient` allow you to plug existing `fasthttp` clients directly into Fiber while keeping retries, redirects, and hook logic consistent. - Dialer, TLS, and proxy helpers now update every host client inside a load balancer, so complex pools inherit the same configuration. - The Fiber client exposes `Do`, `DoTimeout`, `DoDeadline`, and `CloseIdleConnections`, matching the surface area of the wrapped fasthttp transports. ## 🧰 Generic functions Fiber v3 introduces new generic functions that provide additional utility and flexibility for developers. These functions are designed to simplify common tasks and improve code readability.
### New Generic Functions - **StoreInContext**: Stores request-scoped values in both `c.Locals()` and the request `context.Context`, so the same value can be read through middleware `FromContext` helpers and direct locals access. - **Convert**: Converts a value with a specified converter function and default value. - **Locals**: Retrieves or sets local values within a request context. - **Params**: Retrieves route parameters and can handle various types of route parameters. - **Query**: Retrieves the value of a query parameter from the request URI and can handle various types of query parameters. - **GetReqHeader**: Returns the HTTP request header specified by the field and can handle various types of header values. `fiber.Config.PassLocalsToContext` is now available to control whether `StoreInContext` also synchronizes values with request `context.Context` for Fiber-backed contexts. The default is `false` for backward compatibility. `ValueFromContext` continues reading Fiber-backed values from `c.Locals()`. ### Example
Convert ```go package main import ( "strconv" "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Get("/convert", func(c fiber.Ctx) error { value, err := fiber.Convert[int](c.Query("value"), strconv.Atoi, 0) if err != nil { return c.Status(fiber.StatusBadRequest).SendString(err.Error()) } return c.JSON(value) }) app.Listen(":3000") } ``` ```sh curl "http://localhost:3000/convert?value=123" # Output: 123 curl "http://localhost:3000/convert?value=abc" # Output: "failed to convert: strconv.Atoi: parsing \"abc\": invalid syntax" ```
Locals ```go package main import ( "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Use("/user/:id", func(c fiber.Ctx) error { // ask database for user // ... // set local values from database fiber.Locals[string](c, "user", "john") fiber.Locals[int](c, "age", 25) // ... return c.Next() }) app.Get("/user/*", func(c fiber.Ctx) error { // get local values name := fiber.Locals[string](c, "user") age := fiber.Locals[int](c, "age") // ... return c.JSON(fiber.Map{"name": name, "age": age}) }) app.Listen(":3000") } ``` ```sh curl "http://localhost:3000/user/5" # Output: {"name":"john","age":25} ```
Params ```go package main import ( "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Get("/params/:id", func(c fiber.Ctx) error { id := fiber.Params[int](c, "id", 0) return c.JSON(id) }) app.Listen(":3000") } ``` ```sh curl "http://localhost:3000/params/123" # Output: 123 curl "http://localhost:3000/params/abc" # Output: 0 ```
Query ```go package main import ( "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Get("/query", func(c fiber.Ctx) error { age := fiber.Query[int](c, "age", 0) return c.JSON(age) }) app.Listen(":3000") } ``` ```sh curl "http://localhost:3000/query?age=25" # Output: 25 curl "http://localhost:3000/query?age=abc" # Output: 0 ```
GetReqHeader ```go package main import ( "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() app.Get("/header", func(c fiber.Ctx) error { userAgent := fiber.GetReqHeader[string](c, "User-Agent", "Unknown") return c.JSON(userAgent) }) app.Listen(":3000") } ``` ```sh curl -H "User-Agent: CustomAgent" "http://localhost:3000/header" # Output: "CustomAgent" curl "http://localhost:3000/header" # Output: "Unknown" ```
## 🛠️ Utils {#utils} Fiber v3 removes the built-in `utils` directory and now imports utility helpers from the separate [`github.com/gofiber/utils/v2`](https://github.com/gofiber/utils) module. See the [migration guide](#utils-migration) for detailed replacement steps and examples. The `github.com/gofiber/utils/v2` module also introduces new helpers like `ParseInt`, `ParseUint`, `Walk`, `ReadFile`, and `Timestamp`. ## 🧩 Services Fiber v3 introduces a new feature called Services. This feature allows developers to quickly start services that the application depends on, removing the need to manually provision things like database servers, caches, or message brokers, to name a few. ### Example
Adding a service ```go package main import ( "context" "github.com/gofiber/fiber/v3" ) type myService struct { img string // ... } // Start initializes and starts the service. It implements the [fiber.Service] interface. func (s *myService) Start(ctx context.Context) error { // start the service return nil } // String returns a string representation of the service. // It is used to print a human-readable name of the service in the startup message. // It implements the [fiber.Service] interface. func (s *myService) String() string { return s.img } // State returns the current state of the service. // It implements the [fiber.Service] interface. func (s *myService) State(ctx context.Context) (string, error) { return "running", nil } // Terminate stops and removes the service. It implements the [fiber.Service] interface. func (s *myService) Terminate(ctx context.Context) error { // stop the service return nil } func main() { cfg := &fiber.Config{} cfg.Services = append(cfg.Services, &myService{img: "postgres:latest"}) cfg.Services = append(cfg.Services, &myService{img: "redis:latest"}) app := fiber.New(*cfg) // ... } ```
Output ```sh $ go run . -v _______ __ / ____(_) /_ ___ _____ / /_ / / __ \/ _ \/ ___/ / __/ / / /_/ / __/ / /_/ /_/_.___/\___/_/ v3.0.0 -------------------------------------------------- INFO Server started on: http://127.0.0.1:3000 (bound on host 0.0.0.0 and port 3000) INFO Services: 2 INFO 🧩 [ RUNNING ] postgres:latest INFO 🧩 [ RUNNING ] redis:latest INFO Total handlers count: 2 INFO Prefork: Disabled INFO PID: 12279 INFO Total process count: 1 ```
## 📃 Log The `fiber.AllLogger[T]` interface now has a new generic type parameter `T` and a method called `Logger`. This method can be used to get the underlying logger instance from the Fiber logger middleware. This is useful when you want to configure the logger middleware with a custom logger and still want to access the underlying logger instance with the appropriate type. You can find more details about this feature in [/docs/api/log.md](./api/log.md#logger). `logger.Config` now supports a new field called `ForceColors`. This field allows you to force the logger to always use colors, even if the output is not a terminal. This is useful when you want to ensure that the logs are always colored, regardless of the output destination. ```go package main import "github.com/gofiber/fiber/v3/middleware/logger" app.Use(logger.New(logger.Config{ ForceColors: true, })) ``` ## 📦 Storage Interface The storage interface has been updated to include a new subset of methods with the `WithContext` suffix. These methods allow you to pass a context to the storage operations, enabling better control over timeouts and cancellation if needed. This is particularly useful when storage implementations are used outside of the Fiber core, such as in background jobs or long-running tasks. **New Method Signatures:** ```go // GetWithContext gets the value for the given key with a context. // `nil, nil` is returned when the key does not exist GetWithContext(ctx context.Context, key string) ([]byte, error) // SetWithContext stores the given value for the given key // with an expiration value, 0 means no expiration. // Empty key or value will be ignored without an error. SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error // DeleteWithContext deletes the value for the given key with a context.
// It returns no error if the storage does not contain the key, DeleteWithContext(ctx context.Context, key string) error // ResetWithContext resets the storage and deletes all keys with a context. ResetWithContext(ctx context.Context) error ``` ## 🧬 Middlewares ### Important Change for Accessing Middleware Data In Fiber v3, many middlewares that previously set values in `c.Locals()` using string keys (e.g., `c.Locals("requestid")`) have been updated. To align with Go's context best practices and prevent key collisions, these middlewares now store their specific data in the request's context using unexported keys of custom types. This means that directly accessing these values via `c.Locals("some_string_key")` will no longer work for such middleware-provided data. **How to Access Middleware Data in v3:** Each affected middleware now provides dedicated exported functions to retrieve its specific data from the context. You should use these functions instead of relying on string-based lookups in `c.Locals()`. Examples include: - `requestid.FromContext(c)` - `csrf.TokenFromContext(c)` - `csrf.HandlerFromContext(c)` - `session.FromContext(c)` - `basicauth.UsernameFromContext(c)` - `keyauth.TokenFromContext(c)` When used with the Logger middleware, the recommended approach is to use the `CustomTags` feature of the logger, which allows you to call these specific `FromContext` functions. See the [Logger](#logger) section for more details. ### Adaptor The adaptor middleware has been significantly optimized for performance and efficiency. Key improvements include reduced response times, lower memory usage, and fewer memory allocations. These changes make the middleware more reliable and capable of handling higher loads effectively. Enhancements include the introduction of a `sync.Pool` for managing `fasthttp.RequestCtx` instances and better HTTP request and response handling between net/http and fasthttp contexts. 
Incoming body sizes now respect the Fiber app's configured `BodyLimit` (falling back to the default when unset) when running Fiber from `net/http` through the adaptor, returning `413 Request Entity Too Large` for oversized payloads. | Payload Size | Metric | V2 | V3 | Percent Change | | ------------ | -------------- | ------------ | ----------- | -------------- | | 100KB | Execution Time | 1056 ns/op | 588.6 ns/op | -44.25% | | | Memory Usage | 2644 B/op | 254 B/op | -90.39% | | | Allocations | 16 allocs/op | 5 allocs/op | -68.75% | | 500KB | Execution Time | 1061 ns/op | 562.9 ns/op | -46.94% | | | Memory Usage | 2644 B/op | 248 B/op | -90.62% | | | Allocations | 16 allocs/op | 5 allocs/op | -68.75% | | 1MB | Execution Time | 1080 ns/op | 629.7 ns/op | -41.68% | | | Memory Usage | 2646 B/op | 267 B/op | -89.91% | | | Allocations | 16 allocs/op | 5 allocs/op | -68.75% | | 5MB | Execution Time | 1093 ns/op | 540.3 ns/op | -50.58% | | | Memory Usage | 2654 B/op | 254 B/op | -90.43% | | | Allocations | 16 allocs/op | 5 allocs/op | -68.75% | | 10MB | Execution Time | 1044 ns/op | 533.1 ns/op | -48.94% | | | Memory Usage | 2665 B/op | 258 B/op | -90.32% | | | Allocations | 16 allocs/op | 5 allocs/op | -68.75% | | 25MB | Execution Time | 1069 ns/op | 540.7 ns/op | -49.42% | | | Memory Usage | 2706 B/op | 289 B/op | -89.32% | | | Allocations | 16 allocs/op | 5 allocs/op | -68.75% | | 50MB | Execution Time | 1137 ns/op | 554.6 ns/op | -51.21% | | | Memory Usage | 2734 B/op | 298 B/op | -89.10% | | | Allocations | 16 allocs/op | 5 allocs/op | -68.75% | ### BasicAuth The BasicAuth middleware now validates the `Authorization` header more rigorously and sets security-focused response headers. Passwords must be provided in **hashed** form (e.g. SHA-256 or bcrypt) rather than plaintext. The default challenge includes the `charset="UTF-8"` parameter and disables caching. Responses also set a `Vary: Authorization` header to prevent caching based on credentials. 
Passwords are no longer stored in the request context. A new `Charset` option controls the value used in the challenge header. A new `HeaderLimit` option restricts the maximum length of the `Authorization` header (default: `8192` bytes). The `Authorizer` function now receives the current `fiber.Ctx` as a third argument, allowing credential checks to incorporate request context. ### Cache We are excited to introduce a new option in our caching middleware: Cache Invalidator. This feature provides greater control over cache management, allowing you to define custom conditions for invalidating cache entries. The middleware now emits `Cache-Control` headers by default (set the new `DisableCacheControl` flag to turn this off), increases the default `Expiration` from `1 minute` to `5 minutes`, and applies a new `MaxBytes` limit of `1 MB` (previously unlimited). Additionally, the caching middleware has been optimized to avoid caching non-cacheable status codes, as defined by the [HTTP standards](https://datatracker.ietf.org/doc/html/rfc7231#section-6.1). This improvement enhances cache accuracy and reduces unnecessary cache storage usage. Cached responses now include an RFC-compliant `Age` header, providing a standardized indication of how long a response has been stored in cache since it was originally generated. This enhancement improves HTTP compliance and facilitates better client-side caching strategies. Cache keys are now redacted in logs and error messages by default, and a `DisableValueRedaction` boolean (default `false`) lets you opt out when you need the raw value for troubleshooting. :::note The deprecated `Store` and `Key` options have been removed in v3. Use `Storage` and `KeyGenerator` instead. ::: ### ResponseTime A new response time middleware measures how long each request takes to process and adds the duration to the response headers. By default it writes the elapsed time to `X-Response-Time`, and you can change the header name.
A `Next` hook lets you skip endpoints such as health checks. ### CORS We've made some changes to the CORS middleware to improve its functionality and flexibility. Here's what's new: #### New Struct Fields - `Config.AllowPrivateNetwork`: This new field is a boolean that allows you to control whether private networks are allowed. This is related to the [Private Network Access (PNA)](https://wicg.github.io/private-network-access/) specification from the [Web Incubator Community Group (WICG)](https://wicg.io/). When set to `true`, the CORS middleware will allow CORS preflight requests from private networks and respond with the `Access-Control-Allow-Private-Network: true` header. This could be useful in development environments or specific use cases, but should be done with caution due to potential security risks. #### Updated Struct Fields We've updated several fields from a single string (containing comma-separated values) to slices, allowing for more explicit declaration of multiple values. Here are the updated fields: - `Config.AllowOrigins`: Now accepts a slice of strings, each representing an allowed origin. - `Config.AllowMethods`: Now accepts a slice of strings, each representing an allowed method. - `Config.AllowHeaders`: Now accepts a slice of strings, each representing an allowed header. - `Config.ExposeHeaders`: Now accepts a slice of strings, each representing an exposed header. Additionally, panic messages and logs redact misconfigured origins by default, and a `DisableValueRedaction` flag (default `false`) lets you reveal them when necessary. ### Compression - Added support for `zstd` compression alongside `gzip`, `deflate`, and `brotli`. - Strong `ETag` values are now recomputed for compressed payloads so validators remain accurate. - Compression is bypassed for responses that already specify `Content-Encoding`, for range requests or `206` statuses, and when either side sends `Cache-Control: no-transform`. 
- `HEAD` requests still negotiate compression so `Content-Encoding`, `Content-Length`, `ETag`, and `Vary` match a corresponding `GET`, but the body is omitted. - `Vary: Accept-Encoding` is merged into responses even when compression is skipped, preventing caches from mixing encoded and unencoded variants. ### CSRF The `Expiration` field in the CSRF middleware configuration has been renamed to `IdleTimeout` to better describe its functionality. Additionally, the default value has been reduced from 1 hour to 30 minutes. CSRF now redacts tokens and storage keys by default and exposes a `DisableValueRedaction` toggle (default `false`) if you must surface those values in diagnostics. The CSRF middleware now validates the [`Sec-Fetch-Site`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site) header for unsafe HTTP methods. When present, requests with invalid `Sec-Fetch-Site` values (not one of "same-origin", "none", "same-site", or "cross-site") are rejected with `ErrFetchSiteInvalid`. Valid or absent headers proceed to standard origin and token validation checks, providing an early gate to catch malformed requests while maintaining compatibility with legitimate cross-site traffic. ### Idempotency Idempotency middleware now redacts keys by default and offers a `DisableValueRedaction` configuration flag (default `false`) to expose them when debugging. ### EncryptCookie - Added support for specifying key length when using `encryptcookie.GenerateKey(length)`. Keys must be base64-encoded and may be 16, 24, or 32 bytes when decoded, supporting AES-128, AES-192, and AES-256 (default). - Custom encryptor and decryptor callbacks now receive the cookie name. The default AES-GCM helpers bind it as additional authenticated data (AAD) so ciphertext cannot be replayed under a different cookie. - **Breaking change:** Custom encryptor/decryptor hooks now accept the cookie name as their first argument. 
Update overrides like: ```go // Before Encryptor func(value, key string) (string, error) Decryptor func(value, key string) (string, error) // After Encryptor func(name, value, key string) (string, error) Decryptor func(name, value, key string) (string, error) ``` ### Favicon The favicon middleware now caps cached favicon assets with a configurable `MaxBytes` limit (default `1 MiB`) and uses a limited reader to guard against oversized files when loading from disk. ### EnvVar The `ExcludeVars` field has been removed from the EnvVar middleware configuration. When upgrading, remove any references to this field and explicitly list the variables you wish to expose using `ExportVars`. ### Filesystem The filesystem middleware was removed to reduce confusion with the static middleware. The static middleware now covers the functionality of both. Review the [static middleware](./middleware/static.md) docs or the [migration guide](#-migration-guide) for the updated usage. ### Healthcheck The healthcheck middleware has been simplified into a single generic probe handler. No endpoints are registered automatically. Register the middleware on each route you need—using helpers like `healthcheck.LivenessEndpoint`, `healthcheck.ReadinessEndpoint`, or `healthcheck.StartupEndpoint`—and optionally supply a `Probe` function to determine the service's health. This approach lets you expose any number of health check routes. Refer to the [healthcheck middleware migration guide](./middleware/healthcheck.md) or the [general migration guide](#-migration-guide) to review the changes. ### KeyAuth The keyauth middleware was updated to introduce a configurable `Realm` field for the `WWW-Authenticate` header. The old string-based `KeyLookup` configuration has been replaced with an `Extractor` field. Use helper functions like `keyauth.FromHeader`, `keyauth.FromAuthHeader`, or `keyauth.FromCookie` to define where the key should be retrieved from. Multiple sources can be combined with `keyauth.Chain`. 
See the migration guide below. New `Challenge`, `Error`, `ErrorDescription`, `ErrorURI`, and `Scope` fields allow customizing the `WWW-Authenticate` header, returning Bearer error details, and specifying required scopes. `ErrorURI` values are validated to be absolute URIs, a default `ApiKey` challenge is emitted when using non-Authorization extractors, Bearer `error` values are validated, credentials must conform to RFC 7235 `token68` syntax, and `scope` values are checked against RFC 6750's `scope-token` format. The header is also emitted only after the status code is finalized. ### Logger A new helper function, `LoggerToWriter`, has been added to the logger middleware. It lets you use third-party loggers such as `logrus` or `zap` with the Fiber logger middleware without an extra adapter; the example below shows how to use it with `zap`. Logger configuration now uses `Stream` instead of `Output` for the destination writer, so update your logger middleware configuration when migrating to v3. Custom logger integrations should update any `LoggerFunc` implementations to the new signature, which receives a pointer to the middleware config: `func(c fiber.Ctx, data *logger.Data, cfg *logger.Config) error`.
Example ```go package main import ( "github.com/gofiber/contrib/fiberzap/v2" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/middleware/logger" ) func main() { // Create a new Fiber instance app := fiber.New() // Create a new zap logger which is compatible with Fiber AllLogger interface zap := fiberzap.NewLogger(fiberzap.LoggerConfig{ ExtraKeys: []string{"request_id"}, }) // Use the logger middleware with the zap logger app.Use(logger.New(logger.Config{ Stream: logger.LoggerToWriter(zap, log.LevelDebug), })) // Define a route app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) // Start server on http://localhost:3000 app.Listen(":3000") } ```
:::note The deprecated `TagHeader` constant was removed. Use `TagReqHeader` when you need to log request headers. ::: #### Logging Middleware Values (e.g., Request ID) In Fiber v3, middleware (like `requestid`) now stores values in the request context using unexported keys of custom types. This aligns with Go's context best practices to prevent key collisions between packages. As a result, directly accessing these values using string keys with `c.Locals("your_key")` or in the logger format string with `${locals:your_key}` (e.g., `${locals:requestid}`) will no longer work for values set by such middleware. **Recommended Solution: `CustomTags`** The cleanest and most maintainable way to include these middleware-specific values in your logs is by using the `CustomTags` option in the logger middleware configuration. This allows you to define a custom function to retrieve the value correctly from the context.
Example: Logging Request ID with CustomTags ```go package main import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/logger" "github.com/gofiber/fiber/v3/middleware/requestid" ) func main() { app := fiber.New() // Ensure requestid middleware is used before the logger app.Use(requestid.New()) app.Use(logger.New(logger.Config{ CustomTags: map[string]logger.LogFunc{ "requestid": func(output logger.Buffer, c fiber.Ctx, data *logger.Data, extraParam string) (int, error) { // Retrieve the request ID using the middleware's specific function return output.WriteString(requestid.FromContext(c)) }, }, // Use the custom tag in your format string Format: "[${time}] ${ip} - ${requestid} - ${status} ${method} ${path}\n", })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) app.Listen(":3000") } ```
**Alternative: Manually Copying to `Locals`** If you have existing logging patterns that rely on `c.Locals` or prefer to manage these values in `Locals` for other reasons, you can manually copy the value from the context to `c.Locals` in a preceding middleware:
Example: Manually setting requestid in Locals ```go app.Use(requestid.New()) // Request ID middleware app.Use(func(c fiber.Ctx) error { // Manually copy the request ID to Locals c.Locals("requestid", requestid.FromContext(c)) return c.Next() }) app.Use(logger.New(logger.Config{ // Now ${locals:requestid} can be used, but CustomTags is generally preferred Format: "[${time}] ${ip} - ${locals:requestid} - ${status} ${method} ${path}\n", })) ```
Both approaches ensure your logger can access these values while respecting Go's context practices. `Skip` is a function that determines whether logging is skipped for a request or the entry is written to `Stream`.
Example Usage ```go app.Use(logger.New(logger.Config{ Skip: func(c fiber.Ctx) bool { // Skip logging HTTP 200 requests return c.Response().StatusCode() == fiber.StatusOK }, })) ``` ```go app.Use(logger.New(logger.Config{ Skip: func(c fiber.Ctx) bool { // Only log errors, similar to an error.log return c.Response().StatusCode() < 400 }, })) ```
#### Predefined Formats Logger provides predefined formats that you can use by name or directly by specifying the format string.
Example Usage ```go app.Use(logger.New(logger.Config{ Format: logger.FormatCombined, })) ``` See more in [Logger](./middleware/logger.md#predefined-formats)
### Limiter The limiter middleware uses a new Fixed Window Rate Limiter implementation. Custom limiter algorithms should now implement the updated `limiter.Handler` interface, whose `New` method receives a pointer to the active config: `New(cfg *limiter.Config) fiber.Handler`. Limiter now redacts request keys in error paths by default. A new `DisableValueRedaction` boolean (default `false`) lets you reveal the raw limiter key if diagnostics require it. :::note Deprecated fields `Duration`, `Store`, and `Key` have been removed in v3. Use `Expiration`, `Storage`, and `KeyGenerator` instead. ::: ### Monitor The monitor middleware has been migrated to the [Contrib package](https://github.com/gofiber/contrib/tree/main/monitor) with [PR #1172](https://github.com/gofiber/contrib/pull/1172). ### Proxy The proxy middleware has been updated to improve consistency with Go naming conventions. The `TlsConfig` field in the configuration struct has been renamed to `TLSConfig`. Additionally, the `WithTlsConfig` method has been removed; you should now configure TLS directly via the `TLSConfig` property within the `Config` struct. The new `KeepConnectionHeader` option (default `false`) controls whether the `Connection` header is retained; unless it is enabled, the header is dropped. `proxy.Balancer` now accepts an optional variadic configuration: call `proxy.Balancer()` to use defaults or continue passing a `proxy.Config` value as before. ### Recover The Recover middleware now allows customizing the error it returns. Set a `PanicHandler` in its `Config` to change the default behavior. ### Session The Session middleware has undergone key changes in v3 to improve functionality and flexibility. While v2 methods remain available for backward compatibility, we now recommend using the new middleware handler for session management. #### Key Updates The session middleware has undergone significant improvements in v3, focusing on type safety, flexibility, and better developer experience.
#### Key Changes - **Extractor Pattern**: The string-based `KeyLookup` configuration has been replaced with a more flexible and type-safe `Extractor` function pattern. - **New Middleware Handler**: The `New` function now returns a middleware handler instead of a `*Store`. To access the session store, use the `Store` method on the middleware, or opt for `NewStore` or `NewWithStore` for custom store integration. - **Manual Session Release**: Session instances are no longer automatically released after being saved. To ensure proper lifecycle management, you must manually call `sess.Release()`. - **Idle Timeout**: The `Expiration` field has been replaced with `IdleTimeout`, which handles session inactivity. If the session is idle for the specified duration, it will expire. The idle timeout is updated when the session is saved. If you are using the middleware handler, the idle timeout will be updated automatically. - **Absolute Timeout**: The `AbsoluteTimeout` field has been added. If you need to set an absolute session timeout, you can use this field to define the duration. The session will expire after the specified duration, regardless of activity. - **Default KeyGenerator**: Changed from `utils.UUIDv4` to `utils.SecureToken`, producing base64-encoded tokens instead of UUID format. For more details on these changes and migration instructions, check the [Session Middleware Migration Guide](./middleware/session.md#migration-guide). ### Timeout The timeout middleware is now configurable. A new `Config` struct allows customizing the timeout duration, defining a handler that runs when a timeout occurs, and specifying errors to treat as timeouts. The `New` function now accepts a `Config` value instead of a duration. **Behavioral changes:** - **Immediate return on timeout**: The middleware now returns immediately when a timeout occurs, without waiting for the handler to finish. 
This is achieved through the new **Abandon mechanism**, which marks the context as abandoned so it won't be returned to the pool while the handler is still running. - **Context propagation**: The timeout context is properly propagated to the handler. Handlers can detect timeouts by listening on `c.Context().Done()` and return early. - **Panic handling**: Panics in the handler are caught and converted to `500 Internal Server Error` responses. - **Race-free design**: The implementation uses fasthttp's `TimeoutErrorWithCode` combined with Fiber's Abandon mechanism to ensure complete race-freedom between the middleware, handler goroutine, and context pooling. **New Ctx methods for the Abandon mechanism:** - `Abandon()`: Marks the context as abandoned - `IsAbandoned()`: Returns true if the context was abandoned - `ForceRelease()`: Releases an abandoned context back to the pool (for advanced use) **Migration:** Replace calls like `timeout.New(handler, 2*time.Second)` with `timeout.New(handler, timeout.Config{Timeout: 2 * time.Second})`. ## 🔌 Addons In v3, Fiber introduced Addons. Addons are additional useful packages that can be used in Fiber. ### Retry The Retry addon implements a retry mechanism for unsuccessful network operations using an exponential backoff algorithm with jitter. It calls the function repeatedly until it succeeds; if every attempt fails, it returns an error. Jitter is added at each retry step to break synchronization across clients and avoid collisions.
Example ```go package main import ( "fmt" "github.com/gofiber/fiber/v3/addon/retry" "github.com/gofiber/fiber/v3/client" ) func main() { expBackoff := retry.NewExponentialBackoff(retry.Config{}) // Local variables that will be used inside of Retry var resp *client.Response var err error // Retry a network request and return an error to signify to try again err = expBackoff.Retry(func() error { client := client.New() resp, err = client.Get("https://gofiber.io") if err != nil { return fmt.Errorf("GET gofiber.io failed: %w", err) } if resp.StatusCode() != 200 { return fmt.Errorf("GET gofiber.io did not return OK 200") } return nil }) // If all retries failed, panic if err != nil { panic(err) } fmt.Printf("GET gofiber.io succeeded with status code %d\n", resp.StatusCode()) } ```
## 📋 Migration guide To streamline upgrades between Fiber versions, the Fiber CLI ships with a `migrate` command: ```bash go install github.com/gofiber/cli/fiber@latest fiber migrate --to v3 ``` ### Options - `-t, --to string` migrate to a specific version, e.g. `v3.0.0` - `-f, --force` force migration even if already on that version - `-s, --skip_go_mod` skip running `go mod tidy`, `go mod download`, and `go mod vendor` ### Changes Overview - [🚀 App](#-app-1) - [🎣 Hooks](#-hooks-1) - [🚀 Listen](#-listen-1) - [🗺 Router](#-router-1) - [🧠 Context](#-context-1) - [📎 Binding (was Parser)](#-parser) - [🔄 Redirect](#-redirect-1) - [🧾 Log](#-log-1) - [🌎 Client package](#-client-package-1) - [🛠️ Utils](#utils-migration) - [🧬 Middlewares](#-middlewares-1) - [Important Change for Accessing Middleware Data](#important-change-for-accessing-middleware-data) - [BasicAuth](#basicauth-1) - [Cache](#cache-1) - [CORS](#cors-1) - [CSRF](#csrf-1) - [Filesystem](#filesystem-1) - [EnvVar](#envvar-1) - [Favicon](#favicon) - [Healthcheck](#healthcheck-1) - [Monitor](#monitor-1) - [Proxy](#proxy-1) - [Session](#session-1) ### 🚀 App #### Static Since `app.Static()` has been removed, migrate those calls to the static middleware as shown below: ```go // Before app.Static("/", "./public") app.Static("/prefix", "./public") app.Static("/prefix", "./public", fiber.Static{ Index: "index.htm", }) app.Static("*", "./public/index.html") ``` ```go // After app.Get("/*", static.New("./public")) app.Get("/prefix*", static.New("./public")) app.Get("/prefix*", static.New("./public", static.Config{ IndexNames: []string{"index.htm", "index.html"}, })) app.Get("*", static.New("./public/index.html")) ``` :::caution You must append `*` to the route if you don't register the static route with `app.Use`. ::: #### Trusted Proxies We've renamed `EnableTrustedProxyCheck` to `TrustProxy` and moved `TrustedProxies` to `TrustProxyConfig`.
**Important:** To use proxy headers like `X-Forwarded-For` with `c.IP()`, you must configure **all** of `TrustProxy`, `ProxyHeader`, and a trusted proxy via `TrustProxyConfig`. If the proxy is not trusted (for example, if you set only `ProxyHeader` or only `TrustProxy` without configuring `TrustProxyConfig`), proxy headers are ignored and `c.IP()` will return the remote TCP IP instead. ```go // Before app := fiber.New(fiber.Config{ // EnableTrustedProxyCheck enables the trusted proxy check. EnableTrustedProxyCheck: true, // TrustedProxies is a list of trusted proxy IP ranges/addresses. TrustedProxies: []string{"0.8.0.0", "127.0.0.0/8", "::1/128"}, }) ``` ```go // After app := fiber.New(fiber.Config{ // TrustProxy enables the trusted proxy check TrustProxy: true, // ProxyHeader specifies which header to read the real client IP from ProxyHeader: fiber.HeaderXForwardedFor, // TrustProxyConfig allows for configuring trusted proxies. TrustProxyConfig: fiber.TrustProxyConfig{ // Proxies is a list of trusted proxy IP ranges/addresses. Proxies: []string{"0.8.0.0"}, // Trust all loop-back IP addresses (127.0.0.0/8, ::1/128) Loopback: true, // Trust Unix domain socket connections UnixSocket: true, }, }) ``` For detailed proxy configuration guidance, see the [reverse proxy guide](./guide/reverse-proxy.md). ### 🎣 Hooks `OnShutdown` has been replaced by two hooks: `OnPreShutdown` and `OnPostShutdown`. Use them to run cleanup code before and after the server shuts down. When handling shutdown errors, register an `OnPostShutdown` hook and call `app.Listen()` in a goroutine. ```go // Before app.OnShutdown(func() { // Code to run before shutdown }) ``` ```go // After app.Hooks().OnPreShutdown(func() error { // Code to run before shutdown return nil }) ``` ### 🚀 Listen The `Listen` helpers (`ListenTLS`, `ListenMutualTLS`, etc.) were removed. Use `app.Listen()` with `fiber.ListenConfig` and a `tls.Config` when TLS is required. 
Options such as `ListenerNetwork` and `UnixSocketFileMode` are now configured via this struct. Prefer `TLSConfig` when you need full control, or use `CertFile` and `CertKeyFile` for quick TLS setup. ```go // Before app.ListenTLS(":3000", "cert.pem", "key.pem") ``` ```go // After app.Listen(":3000", fiber.ListenConfig{ CertFile: "./cert.pem", CertKeyFile: "./key.pem", }) ``` ### 🗺 Router #### Direct `net/http` handlers Route registration helpers now accept native `net/http` handlers. Pass an `http.Handler`, `http.HandlerFunc`, or compatible function directly to methods such as `app.Get`, `Group`, or `RouteChain`, and Fiber will adapt it at registration time. Manual wrapping through the adaptor middleware is no longer required for these common cases. :::note Compatibility considerations Adapted handlers stick to `net/http` semantics. They do not interact with `fiber.Ctx` and are slower than native Fiber handlers because of the extra conversion layer. Use them to ease migrations, but prefer Fiber handlers in performance-critical paths. ::: ```go httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("served by net/http")); err != nil { panic(err) } }) app.Get("/", httpHandler) ``` #### Middleware Registration The signatures for [`Add`](#middleware-registration) and [`Route`](#route-chaining) have been changed. To migrate [`Add`](#middleware-registration), pass the `methods` as a slice. ```go // Before app.Add(fiber.MethodPost, "/api", myHandler) ``` ```go // After app.Add([]string{fiber.MethodPost}, "/api", myHandler) ``` #### Mounting In this release, the `Mount` method has been removed. Instead, you can use the `Use` method to achieve similar functionality. ```go // Before app.Mount("/api", apiApp) ``` ```go // After app.Use("/api", apiApp) ``` #### Route Chaining Refer to the [route chaining](#route-chaining) section for details on the new `RouteChain` helper.
The `Route` function now matches its v2 behavior for prefix encapsulation. ```go // Before app.Route("/api", func(apiGrp Router) { apiGrp.Route("/user/:id?", func(userGrp Router) { userGrp.Get("/", func(c fiber.Ctx) error { // Get user return c.JSON(fiber.Map{"message": "Get user", "id": c.Params("id")}) }) userGrp.Post("/", func(c fiber.Ctx) error { // Create user return c.JSON(fiber.Map{"message": "User created"}) }) }) }) ``` ```go // After app.RouteChain("/api").RouteChain("/user/:id?"). Get(func(c fiber.Ctx) error { // Get user return c.JSON(fiber.Map{"message": "Get user", "id": c.Params("id")}) }). Post(func(c fiber.Ctx) error { // Create user return c.JSON(fiber.Map{"message": "User created"}) }) ``` ### 🗺 RebuildTree We introduced a new method that enables rebuilding the route tree stack at runtime. This allows you to add routes dynamically while your application is running and update the route tree to make the new routes available for use. For more details, refer to the [app documentation](./api/app.md#rebuildtree): #### Example Usage ```go app.Get("/define", func(c fiber.Ctx) error { // Define a new route dynamically app.Get("/dynamically-defined", func(c fiber.Ctx) error { // Adding a dynamically defined route return c.SendStatus(http.StatusOK) }) app.RebuildTree() // Rebuild the route tree to register the new route return c.SendStatus(http.StatusOK) }) ``` In this example, a new route is defined, and `RebuildTree()` is called to ensure the new route is registered and available. Note: Use this method with caution. It is **not** thread-safe and can be very performance-intensive. Therefore, it should be used sparingly and primarily in development mode. It must not be invoked concurrently.
#### RemoveRoute - **RemoveRoute**: Removes route by path - **RemoveRouteByName**: Removes route by name - **RemoveRouteFunc**: Removes route by a function having `*Route` parameter For more details, refer to the [app documentation](./api/app.md#removeroute): ### 🧠 Context Fiber v3 introduces several new features and changes to the Ctx interface, enhancing its functionality and flexibility. - **ParamsInt**: Use `Params` with generic types. - **QueryBool**: Use `Query` with generic types. - **QueryFloat**: Use `Query` with generic types. - **QueryInt**: Use `Query` with generic types. - **Bind**: Now used for binding instead of view binding. Use `c.ViewBind()` for view binding. In Fiber v3, the `Ctx` parameter in handlers is now an interface, which means the `*` symbol is no longer used. Here is an example demonstrating this change:
Example **Before**: ```go package main import ( "github.com/gofiber/fiber/v2" ) func main() { app := fiber.New() // Route Handler with *fiber.Ctx app.Get("/", func(c *fiber.Ctx) error { return c.SendString("Hello, World!") }) app.Listen(":3000") } ``` **After**: ```go package main import ( "github.com/gofiber/fiber/v3" ) func main() { app := fiber.New() // Route Handler without *fiber.Ctx app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) app.Listen(":3000") } ``` **Explanation**: In this example, the `Ctx` parameter in the handler is used as an interface (`fiber.Ctx`) instead of a pointer (`*fiber.Ctx`). This change allows for more flexibility and customization in Fiber v3.
#### 📎 Parser The `Parser` section in Fiber v3 has undergone significant changes to improve functionality and flexibility. ##### Migration Instructions 1. **BodyParser**: Use `c.Bind().Body()` instead of `c.BodyParser()`.
Example ```go // Before app.Post("/user", func(c *fiber.Ctx) error { var user User if err := c.BodyParser(&user); err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) } return c.JSON(user) }) ``` ```go // After app.Post("/user", func(c fiber.Ctx) error { var user User if err := c.Bind().Body(&user); err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) } return c.JSON(user) }) ```
2. **ParamsParser**: Use `c.Bind().URI()` instead of `c.ParamsParser()`. Note that the struct tag has changed from `params` to `uri`.
Example ```go // Before type Params struct { ID int `params:"id"` } app.Get("/user/:id", func(c *fiber.Ctx) error { var params Params if err := c.ParamsParser(&params); err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) } return c.JSON(params) }) ``` ```go // After type Params struct { ID int `uri:"id"` } app.Get("/user/:id", func(c fiber.Ctx) error { var params Params if err := c.Bind().URI(&params); err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) } return c.JSON(params) }) ```
3. **QueryParser**: Use `c.Bind().Query()` instead of `c.QueryParser()`.
Example ```go // Before app.Get("/search", func(c *fiber.Ctx) error { var query Query if err := c.QueryParser(&query); err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) } return c.JSON(query) }) ``` ```go // After app.Get("/search", func(c fiber.Ctx) error { var query Query if err := c.Bind().Query(&query); err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) } return c.JSON(query) }) ```
4. **CookieParser**: Use `c.Bind().Cookie()` instead of `c.CookieParser()`.
Example ```go // Before app.Get("/cookie", func(c *fiber.Ctx) error { var cookie Cookie if err := c.CookieParser(&cookie); err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) } return c.JSON(cookie) }) ``` ```go // After app.Get("/cookie", func(c fiber.Ctx) error { var cookie Cookie if err := c.Bind().Cookie(&cookie); err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) } return c.JSON(cookie) }) ```
#### 🔄 Redirect Fiber v3 enhances the redirect functionality by introducing new methods and improving existing ones. The new redirect methods provide more flexibility and control over the redirection process. ##### Migration Instructions 1. **RedirectToRoute**: Use `c.Redirect().Route()` instead of `c.RedirectToRoute()`.
Example ```go // Before app.Get("/old", func(c *fiber.Ctx) error { return c.RedirectToRoute("newRoute") }) ``` ```go // After app.Get("/old", func(c fiber.Ctx) error { return c.Redirect().Route("newRoute") }) ```
2. **RedirectBack**: Use `c.Redirect().Back()` instead of `c.RedirectBack()`.
Example ```go // Before app.Get("/back", func(c *fiber.Ctx) error { return c.RedirectBack() }) ``` ```go // After app.Get("/back", func(c fiber.Ctx) error { return c.Redirect().Back() }) ```
3. **Redirect**: Use `c.Redirect().To()` instead of `c.Redirect()`.
Example ```go // Before app.Get("/old", func(c *fiber.Ctx) error { return c.Redirect("/new") }) ``` ```go // After app.Get("/old", func(c fiber.Ctx) error { return c.Redirect().To("/new") }) ```
#### 🧾 Log The `ConfigurableLogger` and `AllLogger` interfaces now use generics. You can specify the underlying logger type when implementing these interfaces. While `any` can be used for maximum flexibility in some contexts, when retrieving the concrete logger via `log.DefaultLogger`, you must specify the exact underlying logger type, for example `log.DefaultLogger[*MyLogger]().Logger()`. ### 🌎 Client package Fiber v3 introduces a completely rebuilt client package with numerous new features such as Cookiejar, request/response hooks, and more. Here is a guide to help you migrate from Fiber v2 to Fiber v3. #### New Features - **Cookiejar**: Manage cookies automatically. - **Request/Response Hooks**: Customize request and response handling. - **Improved Error Handling**: Better error management and reporting. #### Migration Instructions **Import Path**: Update the import path to the new client package.
Before ```go import "github.com/gofiber/fiber/v2/client" ```
After ```go import "github.com/gofiber/fiber/v3/client" ```
**Common migrations**: 1. **Shared defaults instead of per-call mutation**: Move headers and timeouts into the reusable client and override with `client.Config` when needed.
Example ```go // Before status, body, errs := fiber.Get("https://api.example.com/users"). Set("Authorization", "Bearer "+token). Timeout(5 * time.Second). String() if len(errs) > 0 { return fmt.Errorf("request failed: %v", errs) } fmt.Println(status, body) ``` ```go // After cli := client.New(). AddHeader("Authorization", "Bearer "+token). SetTimeout(5 * time.Second) resp, err := cli.Get("https://api.example.com/users") if err != nil { return err } defer resp.Close() fmt.Println(resp.StatusCode(), resp.String()) ```
2. **Body handling**: Replace `Agent.JSON(...).Struct(&dst)` with request bodies through `client.Config` (or `Request.SetJSON`) and decode the response via `Response.JSON`.
Example ```go // Before var created user status, _, errs := fiber.Post("https://api.example.com/users"). JSON(payload). Struct(&created) if len(errs) > 0 { return fmt.Errorf("request failed: %v", errs) } fmt.Println(status, created) ``` ```go // After cli := client.New() resp, err := cli.Post("https://api.example.com/users", client.Config{ Body: payload, }) if err != nil { return err } defer resp.Close() var created user if err := resp.JSON(&created); err != nil { return fmt.Errorf("decode failed: %w", err) } fmt.Println(resp.StatusCode(), created) ```
3. **Path and query parameters**: Use the new path/query helpers instead of manually formatting URLs.
Example ```go // Before code, body, errs := fiber.Get(fmt.Sprintf("https://api.example.com/users/%s", id)). QueryString("active=true"). String() if len(errs) > 0 { return fmt.Errorf("request failed: %v", errs) } fmt.Println(code, body) ``` ```go // After cli := client.New().SetBaseURL("https://api.example.com") resp, err := cli.Get("/users/:id", client.Config{ PathParam: map[string]string{"id": id}, Param: map[string]string{"active": "true"}, }) if err != nil { return err } defer resp.Close() fmt.Println(resp.StatusCode(), resp.String()) ```
4. **Agent helpers**: `Agent.Bytes`, `AcquireAgent`, and `Agent.Parse` have been removed. Reuse a `client.Client` instance (or pool requests/responses directly) and access response data through the new typed helpers.
Example ```go // Before agent := fiber.AcquireAgent() status, body, errs := agent.Get("https://api.example.com/users").Bytes() fiber.ReleaseAgent(agent) if len(errs) > 0 { return fmt.Errorf("request failed: %v", errs) } var users []user if err := fiber.Parse(body, &users); err != nil { return fmt.Errorf("parse failed: %w", err) } fmt.Println(status, len(users)) ``` ```go // After cli := client.New() resp, err := cli.Get("https://api.example.com/users") if err != nil { return err } defer resp.Close() var users []user if err := resp.JSON(&users); err != nil { return fmt.Errorf("decode failed: %w", err) } fmt.Println(resp.StatusCode(), len(users)) ``` :::tip If you need pooling, use `client.AcquireRequest`, `client.AcquireResponse`, and their corresponding release functions around a long-lived `client.Client` instead of the removed agent pool. :::
5. **Fiber-level shortcuts**: The `fiber.Get`, `fiber.Post`, and similar top-level helpers are no longer exposed from the main module. Use the client package equivalents (`client.Get`, `client.Post`, etc.) which call the shared default client (or pass your own client instance for custom defaults).
Example ```go // Before status, body, errs := fiber.Get("https://api.example.com/health").String() if len(errs) > 0 { return fmt.Errorf("request failed: %v", errs) } fmt.Println(status, body) ``` ```go // After resp, err := client.Get("https://api.example.com/health") if err != nil { return err } defer resp.Close() fmt.Println(resp.StatusCode(), resp.String()) ``` :::note The `client.Get`/`client.Post` helpers use `client.C()` (the default shared client). For custom defaults, construct a client with `client.New()` and invoke its methods instead. :::
#### Complete API Migration Reference
Click to expand full v2 → v3 API mapping tables ##### Core Concepts | Description | v2 | v3 | |-------------|----|----| | Import | `github.com/gofiber/fiber/v2` | `github.com/gofiber/fiber/v3/client` | | Client Concept | `*fiber.Agent` | `*client.Client` + `*client.Request` | | Response Concept | `(code int, body []byte, errs []error)` | `(*client.Response, error)` | ##### Client/Agent Creation | Description | v2 | v3 | |-------------|----|----| | Create Agent/Client | `fiber.AcquireAgent()` | `client.New()` | | Get from pool | `fiber.AcquireAgent()` | `client.AcquireRequest()` | | Release | `fiber.ReleaseAgent(a)` | `client.ReleaseRequest(req)` | | With fasthttp.Client | - | `client.NewWithClient(c)` | | With HostClient | - | `client.NewWithHostClient(hc)` | | With LBClient | - | `client.NewWithLBClient(lb)` | | Get Request object | `a.Request()` | `c.R()` | | Default client | - | `client.C()` | | Replace default | - | `client.Replace(c)` | ##### HTTP Methods | Description | v2 | v3 (Client) | v3 (Request) | |-------------|----|----|--------------| | GET | `fiber.Get(url)` | `c.Get(url, cfg...)` | `req.Get(url)` | | POST | `fiber.Post(url)` | `c.Post(url, cfg...)` | `req.Post(url)` | | PUT | `fiber.Put(url)` | `c.Put(url, cfg...)` | `req.Put(url)` | | PATCH | `fiber.Patch(url)` | `c.Patch(url, cfg...)` | `req.Patch(url)` | | DELETE | `fiber.Delete(url)` | `c.Delete(url, cfg...)` | `req.Delete(url)` | | HEAD | `fiber.Head(url)` | `c.Head(url, cfg...)` | `req.Head(url)` | | OPTIONS | - | `c.Options(url, cfg...)` | `req.Options(url)` | | Custom | - | `c.Custom(url, method, cfg...)` | `req.Custom(url, method)` | ##### URL & Method | Description | v2 | v3 | |-------------|----|----| | Set URL | `req.SetRequestURI(url)` | `req.SetURL(url)` | | Get URL | `req.URI().String()` | `req.URL()` | | Set Method | `req.Header.SetMethod(method)` | `req.SetMethod(method)` | | Set Base URL | - | `c.SetBaseURL(url)` | ##### Request Execution & Response | Description | v2 | v3 | 
|-------------|----|----| | Parse Request | `a.Parse()` | Not needed | | Execute (bytes) | `a.Bytes()` → `(code, body, errs)` | `req.Send()` → `(*Response, error)` | | Execute (string) | `a.String()` | `resp.String()` | | Execute (struct) | `a.Struct(&v)` | `resp.JSON(&v)` / `resp.XML(&v)` | | Status Code | Return value `code` | `resp.StatusCode()` | | Status Text | - | `resp.Status()` | | Body (bytes) | Return value `body` | `resp.Body()` | | Response Header | `resp.Header.Peek(key)` | `resp.Header(key)` | | All Headers | `resp.Header.VisitAll(fn)` | `resp.Headers()` | | Cookies | - | `resp.Cookies()` | | Save to file | - | `resp.Save(path)` | | Close | - | `resp.Close()` | ##### Headers | Description | v2 | v3 (Client) | v3 (Request) | |-------------|----|----|--------------| | Set Header | `a.Set(k, v)` | `c.SetHeader(k, v)` | `req.SetHeader(k, v)` | | Add Header | `a.Add(k, v)` | `c.AddHeader(k, v)` | `req.AddHeader(k, v)` | | Multiple Headers | - | `c.SetHeaders(map)` | `req.SetHeaders(map)` | | Bytes variants | `a.SetBytesK/V/KV()` | - | - | ##### User-Agent, Referer, Content-Type, Host | Description | v2 | v3 (Client) | v3 (Request) | |-------------|----|----|--------------| | User-Agent | `a.UserAgent(ua)` | `c.SetUserAgent(ua)` | `req.SetUserAgent(ua)` | | Referer | `a.Referer(ref)` | `c.SetReferer(ref)` | `req.SetReferer(ref)` | | Content-Type | `a.ContentType(ct)` | - | `req.SetHeader("Content-Type", ct)` | | Host | `a.Host(host)` | - | `req.SetHeader("Host", host)` | | Connection Close | `a.ConnectionClose()` | - | `req.SetHeader("Connection", "close")` | ##### Cookies | Description | v2 | v3 (Client) | v3 (Request) | |-------------|----|----|--------------| | Set Cookie | `a.Cookie(k, v)` | `c.SetCookie(k, v)` | `req.SetCookie(k, v)` | | Multiple | `a.Cookies(k1, v1, ...)` | `c.SetCookies(map)` | `req.SetCookies(map)` | | With Struct | - | `c.SetCookiesWithStruct(v)` | `req.SetCookiesWithStruct(v)` | | Cookie Jar | - | `c.SetCookieJar(jar)` | - | ##### 
Query Parameters | Description | v2 | v3 (Client) | v3 (Request) | |-------------|----|----|--------------| | Query String | `a.QueryString(qs)` | - | - | | Add Param | - | `c.AddParam(k, v)` | `req.AddParam(k, v)` | | Set Param | - | `c.SetParam(k, v)` | `req.SetParam(k, v)` | | With Struct | - | `c.SetParamsWithStruct(v)` | `req.SetParamsWithStruct(v)` | ##### Path Parameters (NEW) | Description | v2 | v3 (Client) | v3 (Request) | |-------------|----|----|--------------| | Set Path Param | - | `c.SetPathParam(k, v)` | `req.SetPathParam(k, v)` | | Multiple | - | `c.SetPathParams(map)` | `req.SetPathParams(map)` | | With Struct | - | `c.SetPathParamsWithStruct(v)` | `req.SetPathParamsWithStruct(v)` | ##### Request Body | Description | v2 | v3 | |-------------|----|----| | Body (bytes) | `a.Body(body)` | `req.SetRawBody(body)` | | Body (string) | `a.BodyString(body)` | `req.SetRawBody([]byte(body))` | | Body Stream | `a.BodyStream(r, size)` | - | | JSON | `a.JSON(v)` | `req.SetJSON(v)` | | XML | `a.XML(v)` | `req.SetXML(v)` | | CBOR (NEW) | - | `req.SetCBOR(v)` | ##### Form Data | Description | v2 | v3 | |-------------|----|----| | Create Args | `fiber.AcquireArgs()` | Direct on Request | | Send Form | `a.Form(args)` | `req.SetFormData(k, v)` | | Add Form Data | `args.Set(k, v)` | `req.AddFormData(k, v)` | | With Map | - | `req.SetFormDataWithMap(map)` | | With Struct | - | `req.SetFormDataWithStruct(v)` | ##### File Upload | Description | v2 | v3 | |-------------|----|----| | Multipart Form | `a.MultipartForm(args)` | Automatic | | Boundary | `a.Boundary(b)` | `req.SetBoundary(b)` | | Send File | `a.SendFile(f, field...)` | `req.AddFile(path)` | | Multiple Files | `a.SendFiles(...)` | `req.AddFiles(files...)` | | With Reader | - | `req.AddFileWithReader(name, r)` | | FileData | `a.FileData(files...)` | `req.AddFiles(files...)` | ##### Timeout & TLS | Description | v2 | v3 (Client) | v3 (Request) | |-------------|----|----|--------------| | Timeout | `a.Timeout(d)` 
| `c.SetTimeout(d)` | `req.SetTimeout(d)` | | Max Redirects | `a.MaxRedirectsCount(n)` | Via Config | `req.SetMaxRedirects(n)` | | TLS Config | `a.TLSConfig(cfg)` | `c.SetTLSConfig(cfg)` | - | | Skip Verify | `a.InsecureSkipVerify()` | Via `tls.Config` | - | | Certificates | - | `c.SetCertificates(...)` | - | | Root Cert | - | `c.SetRootCertificate(path)` | - | ##### JSON/XML Encoder | Description | v2 | v3 | |-------------|----|----| | JSON Encoder | `a.JSONEncoder(fn)` | `c.SetJSONMarshal(fn)` | | JSON Decoder | `a.JSONDecoder(fn)` | `c.SetJSONUnmarshal(fn)` | | XML Encoder | - | `c.SetXMLMarshal(fn)` | | XML Decoder | - | `c.SetXMLUnmarshal(fn)` | | CBOR (NEW) | - | `c.SetCBORMarshal/Unmarshal(fn)` | ##### Authentication | Description | v2 | v3 | |-------------|----|----| | Basic Auth | `a.BasicAuth(user, pass)` | Via Header (Base64) | ##### Debug & Retry | Description | v2 | v3 | |-------------|----|----| | Debug | `a.Debug(w...)` | `c.Debug()` | | Disable Debug | - | `c.DisableDebug()` | | Logger | - | `c.SetLogger(logger)` | | Retry | `a.RetryIf(fn)` | `c.SetRetryConfig(cfg)` | ##### Reuse & Reset | Description | v2 | v3 | |-------------|----|----| | Reuse Agent | `a.Reuse()` | Use pool | | Reset Client | - | `c.Reset()` | | Dest Buffer | `a.Dest(dest)` | - | ##### NEW in v3 | Feature | v3 API | |---------|--------| | Request Hooks | `c.AddRequestHook(fn)` | | Response Hooks | `c.AddResponseHook(fn)` | | Proxy | `c.SetProxyURL(url)` | | Context | `req.SetContext(ctx)` | | Dial Function | `c.SetDial(fn)` | | Raw Request | `req.RawRequest` | | Raw Response | `resp.RawResponse` | ##### Key Differences 1. **Architecture**: v2 `Agent` → v3 separate `Client`, `Request`, `Response` 2. **Error Handling**: v2 `[]error` → v3 single `error` 3. **Response**: v2 tuple `(code, body, errs)` → v3 `*Response` object 4. **No Parse()**: v3 auto-initializes requests 5. **Hooks**: v3 adds request/response middleware 6. **Path Params**: v3 native `:param` support 7. 
**Cookie Jar**: v3 built-in session management 8. **CBOR**: v3 adds CBOR encoding 9. **Context**: v3 native cancellation support 10. **Iterators**: v3 uses `iter.Seq2` for collections 11. **Bytes variants removed**: v2 `*Bytes*` methods gone
### 🛠️ Utils {#utils-migration} Fiber v3 removes the in-repo `utils` package in favor of the external [`github.com/gofiber/utils/v2`](https://github.com/gofiber/utils) module. 1. Replace imports: ```go - import "github.com/gofiber/fiber/v2/utils" + import "github.com/gofiber/utils/v2" ``` 1. Review function changes: | v2 function | v3 replacement | | --- | --- | | `AssertEqual` | removed; use testing libraries like [`github.com/stretchr/testify/assert`](https://pkg.go.dev/github.com/stretchr/testify/assert) | | `ToLowerBytes` | `utils.ToLowerBytes` | | `ToUpperBytes` | `utils.ToUpperBytes` | | `TrimRightBytes` | `utils.TrimRight` | | `TrimLeftBytes` | `utils.TrimLeft` | | `TrimBytes` | `utils.Trim` | | `EqualFoldBytes` | `utils.EqualFold` | | `UUID` | `utils.UUID` | | `UUIDv4` | `utils.UUIDv4` | | `FunctionName` | `utils.FunctionName` | | `GetArgument` | `utils.GetArgument` | | `IncrementIPRange` | `utils.IncrementIPRange` | | `ConvertToBytes` | `utils.ConvertToBytes` | | `CopyString` | `utils.CopyString` | | `CopyBytes` | `utils.CopyBytes` | | `ByteSize` | `utils.ByteSize` | | `ToString` | `utils.ToString` | | `UnsafeString` | `utils.UnsafeString` | | `UnsafeBytes` | `utils.UnsafeBytes` | | `GetString` | removed; use `utils.ToString` or the standard library | | `GetBytes` | removed; use `utils.CopyBytes` or `[]byte(s)` | | `ImmutableString` | removed; strings are already immutable | | `GetMIME` | `utils.GetMIME` | | `ParseVendorSpecificContentType` | `utils.ParseVendorSpecificContentType` | | `StatusMessage` | `utils.StatusMessage` | | `IsIPv4` | `utils.IsIPv4` | | `IsIPv6` | `utils.IsIPv6` | | `ToLower` | `utils.ToLower` | | `ToUpper` | `utils.ToUpper` | | `TrimLeft` | `strings.TrimLeft` | | `Trim` | `strings.Trim` | | `TrimRight` | `strings.TrimRight` | | `EqualFold` | `strings.EqualFold` | | `StartTimeStampUpdater` | `utils.StartTimeStampUpdater` (new `utils.Timestamp` provides the current value) | 1. Update your code. 
For example: ```go // v2 import oldutils "github.com/gofiber/fiber/v2/utils" func demo() { b := oldutils.TrimBytes([]byte(" fiber ")) id := oldutils.UUIDv4() s := oldutils.GetString([]byte("foo")) } // v3 import ( "github.com/gofiber/utils/v2" "strings" ) func demo() { s := utils.TrimSpace(" fiber ") id := utils.UUIDv4() str := utils.ToString([]byte("foo")) t := strings.TrimRight("bar ", " ") } ``` The `github.com/gofiber/utils/v2` module also introduces new helpers like `ParseInt`, `ParseUint`, `Walk`, `ReadFile`, and `Timestamp`. ### 🧬 Middlewares #### Important Change for Accessing Middleware Data **Change:** In Fiber v2, some middlewares set data in `c.Locals()` using string keys (e.g., `c.Locals("requestid")`). In Fiber v3, to align with Go's context best practices and prevent key collisions, these middlewares now store their specific data in the request's context using unexported keys of custom types. **Impact:** Directly accessing these middleware-provided values via `c.Locals("some_string_key")` will no longer work. **Migration Action:** The `ContextKey` configuration option has been removed from all middlewares. Values are no longer stored under user-defined keys. You must update your code to use the dedicated exported functions provided by each affected middleware to retrieve its data from the context. **Examples of new helper functions to use:** - `requestid.FromContext(c)` - `csrf.TokenFromContext(c)` - `csrf.HandlerFromContext(c)` - `session.FromContext(c)` - `basicauth.UsernameFromContext(c)` - `keyauth.TokenFromContext(c)` **For logging these values:** The recommended approach is to use the `CustomTags` feature of the Logger middleware, which allows you to call these specific `FromContext` functions. Refer to the [Logger section in "What's New"](#logger) for detailed examples. :::note If you were manually setting and retrieving your own application-specific values in `c.Locals()` using string keys, that functionality remains unchanged. 
This change specifically pertains to how Fiber's built-in (and some contrib) middlewares expose their data. ::: #### BasicAuth The `Authorizer` callback now receives the current request context. Update custom functions from: ```go Authorizer: func(user, pass string) bool { // v2 style return user == "admin" && pass == "secret" } ``` to: ```go Authorizer: func(user, pass string, _ fiber.Ctx) bool { // v3 style with access to the Fiber context return user == "admin" && pass == "secret" } ``` Passwords configured for BasicAuth must now be pre-hashed. If no prefix is supplied, the middleware expects a hex-encoded SHA-256 digest. Prefixed hashes such as `{SHA256}` and `{SHA512}`, as well as bcrypt hashes, are also supported. Plaintext passwords are no longer accepted. Unauthorized responses also include a `Vary: Authorization` header for correct caching behavior. You can also set the optional `HeaderLimit` and `Charset` options to further control authentication behavior. #### KeyAuth The keyauth middleware was updated to introduce a configurable `Realm` field for the `WWW-Authenticate` header. The old string-based `KeyLookup` configuration has been replaced with an `Extractor` field, and the `AuthScheme` field has been removed. The auth scheme is now inferred from the extractor used (e.g., `keyauth.FromAuthHeader`). Use helper functions like `keyauth.FromHeader`, `keyauth.FromAuthHeader`, or `keyauth.FromCookie` to define where the key should be retrieved from. Multiple sources can be combined with `keyauth.Chain`. New `Challenge`, `Error`, `ErrorDescription`, `ErrorURI`, and `Scope` options let you customize challenge responses, include Bearer error parameters, and specify required scopes. `ErrorURI` values are validated as absolute, credentials containing whitespace are rejected, and when multiple authorization extractors are chained, all schemes are advertised in the `WWW-Authenticate` header.
The middleware defers emitting `WWW-Authenticate` until a 401 status is final, and `FromAuthHeader` now trims surrounding whitespace. ```go // Before app.Use(keyauth.New(keyauth.Config{ KeyLookup: "header:Authorization", AuthScheme: "Bearer", Validator: validateAPIKey, })) // After app.Use(keyauth.New(keyauth.Config{ Extractor: keyauth.FromAuthHeader(fiber.HeaderAuthorization, "Bearer"), Validator: validateAPIKey, })) ``` Combine multiple sources with `keyauth.Chain()` when needed. #### Cache The deprecated `Store` and `Key` fields were removed. Use `Storage` and `KeyGenerator` instead to configure caching backends and cache keys. Defaults also changed: the middleware now emits `Cache-Control` headers, the default `Expiration` increased to `5 minutes` (from `1 minute`), and a new `MaxBytes` limit of `1 MB` (previously unlimited) now caps cached payloads. To restore v2 behavior: - Set `DisableCacheControl` to `true` to suppress automatic `Cache-Control` headers. - Configure `Expiration` to `1*time.Minute`. - Set `MaxBytes` to `0` (or a higher value) when caching large responses. #### CORS The CORS middleware has been updated to use slices instead of strings for the `AllowOrigins`, `AllowMethods`, `AllowHeaders`, and `ExposeHeaders` fields. Here's how you can update your code: ```go // Before app.Use(cors.New(cors.Config{ AllowOrigins: "https://example.com,https://example2.com", AllowMethods: strings.Join([]string{fiber.MethodGet, fiber.MethodPost}, ","), AllowHeaders: "Content-Type", ExposeHeaders: "Content-Length", })) // After app.Use(cors.New(cors.Config{ AllowOrigins: []string{"https://example.com", "https://example2.com"}, AllowMethods: []string{fiber.MethodGet, fiber.MethodPost}, AllowHeaders: []string{"Content-Type"}, ExposeHeaders: []string{"Content-Length"}, })) ``` #### CSRF - **Field Renaming**: The `Expiration` field in the CSRF middleware configuration has been renamed to `IdleTimeout` to better describe its functionality. 
Additionally, the default value has been reduced from 1 hour to 30 minutes. Update your code as follows: ```go // Before app.Use(csrf.New(csrf.Config{ Expiration: 10 * time.Minute, })) // After app.Use(csrf.New(csrf.Config{ IdleTimeout: 10 * time.Minute, })) ``` - **Session Key Removal**: The `SessionKey` field has been removed from the CSRF middleware configuration. The session key is now an unexported constant within the middleware to avoid potential key collisions in the session store. - **KeyLookup Field Removal**: The `KeyLookup` field has been removed from the CSRF middleware configuration. This field was deprecated and is no longer needed as the middleware now uses a more secure approach for token management. - **DisableValueRedaction Toggle**: CSRF redacts tokens and storage keys by default; set `DisableValueRedaction` to `true` when diagnostics require the raw values. - **Default KeyGenerator**: Changed from `utils.UUIDv4` to `utils.SecureToken`, producing base64-encoded tokens instead of UUID format. ```go // Before app.Use(csrf.New(csrf.Config{ KeyLookup: "header:X-Csrf-Token", // other config... })) // After - use Extractor instead app.Use(csrf.New(csrf.Config{ Extractor: csrf.FromHeader("X-Csrf-Token"), // other config... })) ``` - **FromCookie Extractor Removal**: The `csrf.FromCookie` extractor has been intentionally removed for security reasons. Using cookie-based extraction defeats the purpose of CSRF protection by making the extracted token always match the cookie value. ```go // Before - This was a security vulnerability app.Use(csrf.New(csrf.Config{ Extractor: csrf.FromCookie("csrf_token"), // ❌ Insecure! 
})) // After - Use secure extractors instead app.Use(csrf.New(csrf.Config{ Extractor: csrf.FromHeader("X-Csrf-Token"), // ✅ Secure // or Extractor: csrf.FromForm("_csrf"), // ✅ Secure // or Extractor: csrf.FromQuery("csrf_token"), // ✅ Acceptable })) ``` **Security Note**: The removal of `FromCookie` prevents a common misconfiguration that would completely bypass CSRF protection. The middleware uses the Double Submit Cookie pattern, which requires the token to be submitted through a different channel than the cookie to provide meaningful protection. #### Idempotency - **DisableValueRedaction Toggle**: The idempotency middleware now hides keys in logs and error paths by default, with a `DisableValueRedaction` boolean (default `false`) to reveal them when needed. #### Timeout The timeout middleware now accepts a configuration struct instead of a duration. Update your code as follows: ```go // Before app.Use(timeout.New(handler, 2*time.Second)) // After app.Use(timeout.New(handler, timeout.Config{Timeout: 2 * time.Second})) ``` **Important behavioral changes:** - The middleware now returns immediately on timeout without waiting for the handler (using the new Abandon mechanism). - Handlers can detect timeouts by listening on `c.Context().Done()` and return early. - Panics in the handler are caught and converted to `500 Internal Server Error`. #### Filesystem The filesystem middleware has been removed from the core. Migrate to the static middleware instead, as shown below: ```go // Before app.Use(filesystem.New(filesystem.Config{ Root: http.Dir("./assets"), })) app.Use(filesystem.New(filesystem.Config{ Root: http.Dir("./assets"), Browse: true, Index: "index.html", MaxAge: 3600, })) ``` ```go // After app.Use(static.New("", static.Config{ FS: os.DirFS("./assets"), })) app.Use(static.New("", static.Config{ FS: os.DirFS("./assets"), Browse: true, IndexNames: []string{"index.html"}, MaxAge: 3600, })) ``` #### EnvVar The `ExcludeVars` option has been removed.
Remove any references to it and use `ExportVars` to explicitly list environment variables that should be exposed. #### Healthcheck Previously, the Healthcheck middleware was configured with a combined setup for liveness and readiness probes: ```go //before app.Use(healthcheck.New(healthcheck.Config{ LivenessProbe: func(c fiber.Ctx) bool { return true }, LivenessEndpoint: "/live", ReadinessProbe: func(c fiber.Ctx) bool { return serviceA.Ready() && serviceB.Ready() && ... }, ReadinessEndpoint: "/ready", })) ``` With the new version, each health check endpoint is configured separately, allowing for more flexibility: ```go // after // Default liveness endpoint configuration app.Get(healthcheck.LivenessEndpoint, healthcheck.New(healthcheck.Config{ Probe: func(c fiber.Ctx) bool { return true }, })) // Default readiness endpoint configuration app.Get(healthcheck.ReadinessEndpoint, healthcheck.New()) // New default startup endpoint configuration // Default endpoint is /startupz app.Get(healthcheck.StartupEndpoint, healthcheck.New(healthcheck.Config{ Probe: func(c fiber.Ctx) bool { return serviceA.Ready() && serviceB.Ready() && ... }, })) // Custom liveness endpoint configuration app.Get("/live", healthcheck.New()) ``` #### Monitor Since v3 the Monitor middleware has been moved to the [Contrib package](https://github.com/gofiber/contrib/tree/main/monitor) ```go // Before import "github.com/gofiber/fiber/v2/middleware/monitor" app.Use("/metrics", monitor.New()) ``` You only need to change the import path to the contrib package. ```go // After import "github.com/gofiber/contrib/monitor" app.Use("/metrics", monitor.New()) ``` #### Proxy In previous versions, TLS settings for the proxy middleware were set using the `WithTlsConfig` method. This method has been removed in favor of a more idiomatic configuration via the `TLSConfig` field in the `Config` struct. 
#### Before (v2 usage) ```go proxy.WithTlsConfig(&tls.Config{ InsecureSkipVerify: true, }) // Forward to url app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif")) ``` #### After (v3 usage) ```go proxy.WithClient(&fasthttp.Client{ TLSConfig: &tls.Config{InsecureSkipVerify: true}, }) // Forward to url app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif")) ``` `proxy.Balancer` also adopts the common middleware signature pattern and now accepts an optional variadic config: call `proxy.Balancer()` to use the defaults or continue passing a single `proxy.Config` value as in v2. #### Session `session.New()` now returns a middleware handler. When using the store pattern, create a store with `session.NewStore()` or call `Store()` on the middleware. Sessions obtained from a store must be released manually via `sess.Release()`. Additionally, replace the deprecated `KeyLookup` option with extractor functions such as `session.FromCookie()` or `session.FromHeader()`. Multiple extractors can be combined with `session.Chain()`. ```go // Before app.Use(session.New(session.Config{ KeyLookup: "cookie:session_id", Store: session.NewStore(), })) ``` ```go // After app.Use(session.New(session.Config{ Extractor: session.FromCookie("session_id"), Store: session.NewStore(), })) ``` See the [Session Middleware Migration Guide](./middleware/session.md#migration-guide) for complete details. ================================================ FILE: error.go ================================================ package fiber import ( "encoding/json" "errors" "github.com/gofiber/schema" ) // Wrap and return this for unreachable code if panicking is undesirable (i.e., in a handler). // Unexported because users will hopefully never need to see it. 
var errUnreachable = errors.New("fiber: unreachable code, please create an issue at github.com/gofiber/fiber") // General errors var ( ErrGracefulTimeout = errors.New("shutdown: graceful timeout has been reached, exiting") // ErrNotRunning indicates that a Shutdown method was called when the server was not running. ErrNotRunning = errors.New("shutdown: server is not running") // ErrHandlerExited is returned by App.Test if a handler panics or calls runtime.Goexit(). ErrHandlerExited = errors.New("runtime.Goexit() called in handler or server panic") // ErrNoViewEngineConfigured indicates that a helper requiring a view engine was invoked without one configured. ErrNoViewEngineConfigured = errors.New("fiber: no view engine configured") // ErrAutoCertWithCertFile indicates AutoCertManager cannot be used with CertFile/CertKeyFile. ErrAutoCertWithCertFile = errors.New("tls: AutoCertManager cannot be combined with CertFile/CertKeyFile") ) // Fiber redirection errors var ( ErrRedirectBackNoFallback = NewError(StatusInternalServerError, "Referer not found, you have to enter fallback URL for redirection.") ) // Range errors var ( ErrRangeMalformed = errors.New("range: malformed range header string") ErrRangeTooLarge = NewError(StatusRequestedRangeNotSatisfiable, "range: too many ranges") ErrRangeUnsatisfiable = errors.New("range: unsatisfiable range") ) // Binder errors var ErrCustomBinderNotFound = errors.New("binder: custom binder not found, please be sure to enter the right name") // Format errors var ( // ErrNoHandlers is returned when c.Format is called with no arguments. ErrNoHandlers = errors.New("format: at least one handler is required, but none were set") ) // gofiber/schema errors type ( // ConversionError Conversion error exposes the internal schema.ConversionError for public use. ConversionError = schema.ConversionError // UnknownKeyError error exposes the internal schema.UnknownKeyError for public use. 
UnknownKeyError = schema.UnknownKeyError // EmptyFieldError error exposes the internal schema.EmptyFieldError for public use. EmptyFieldError = schema.EmptyFieldError // MultiError error exposes the internal schema.MultiError for public use. MultiError = schema.MultiError ) // encoding/json errors type ( // InvalidUnmarshalError describes an invalid argument passed to Unmarshal. // (The argument to Unmarshal must be a non-nil pointer.) InvalidUnmarshalError = json.InvalidUnmarshalError // MarshalerError represents an error from calling a MarshalJSON or MarshalText method. MarshalerError = json.MarshalerError // SyntaxError is a description of a JSON syntax error. SyntaxError = json.SyntaxError // UnmarshalTypeError describes a JSON value that was // not appropriate for a value of a specific Go type. UnmarshalTypeError = json.UnmarshalTypeError // UnsupportedTypeError is returned by Marshal when attempting // to encode an unsupported value type. UnsupportedTypeError = json.UnsupportedTypeError // UnsupportedValueError exposes json.UnsupportedValueError to describe unsupported values encountered during encoding. 
UnsupportedValueError = json.UnsupportedValueError ) ================================================ FILE: error_test.go ================================================ package fiber import ( "encoding/json" "errors" "testing" "github.com/gofiber/schema" "github.com/stretchr/testify/require" ) func Test_ConversionError(t *testing.T) { t.Parallel() ok := errors.As(ConversionError{}, &schema.ConversionError{}) require.True(t, ok) } func Test_UnknownKeyError(t *testing.T) { t.Parallel() ok := errors.As(UnknownKeyError{}, &schema.UnknownKeyError{}) require.True(t, ok) } func Test_EmptyFieldError(t *testing.T) { t.Parallel() ok := errors.As(EmptyFieldError{}, &schema.EmptyFieldError{}) require.True(t, ok) } func Test_MultiError(t *testing.T) { t.Parallel() ok := errors.As(MultiError{}, &schema.MultiError{}) require.True(t, ok) } func Test_InvalidUnmarshalError(t *testing.T) { t.Parallel() var e *json.InvalidUnmarshalError ok := errors.As(&InvalidUnmarshalError{}, &e) require.True(t, ok) } func Test_MarshalerError(t *testing.T) { t.Parallel() var e *json.MarshalerError ok := errors.As(&MarshalerError{}, &e) require.True(t, ok) } func Test_SyntaxError(t *testing.T) { t.Parallel() var e *json.SyntaxError ok := errors.As(&SyntaxError{}, &e) require.True(t, ok) } func Test_UnmarshalTypeError(t *testing.T) { t.Parallel() var e *json.UnmarshalTypeError ok := errors.As(&UnmarshalTypeError{}, &e) require.True(t, ok) } func Test_UnsupportedTypeError(t *testing.T) { t.Parallel() var e *json.UnsupportedTypeError ok := errors.As(&UnsupportedTypeError{}, &e) require.True(t, ok) } func Test_UnsupportedValeError(t *testing.T) { t.Parallel() var e *json.UnsupportedValueError ok := errors.As(&UnsupportedValueError{}, &e) require.True(t, ok) } ================================================ FILE: errors_internal.go ================================================ package fiber import ( "errors" ) var ( errBindPoolTypeAssertion = errors.New("failed to type-assert to *Bind") 
errCustomCtxTypeAssertion = errors.New("failed to type-assert to CustomCtx") errInvalidEscapeSequence = errors.New("invalid escape sequence") errRedirectTypeAssertion = errors.New("failed to type-assert to *Redirect") ) ================================================ FILE: extractors/README.md ================================================ # Extractors Package Package providing shared value extraction utilities for Fiber middleware packages. ## Audience **This README is targeted at middleware developers and contributors.** If you are a Fiber framework user looking to use extractors in your application, please refer to the [Extractors Guide](https://docs.gofiber.io/guide/extractors) instead. ## Architecture ### Core Types - `Extractor`: Core extraction function with metadata - `Source`: Enumeration of extraction sources (Header, AuthHeader, Query, Form, Param, Cookie, Custom) - `ErrNotFound`: Standardized error for missing values ### Extractor Structure ```go type Extractor struct { Extract func(fiber.Ctx) (string, error) Key string // The parameter/header name used for extraction AuthScheme string // The auth scheme used, e.g., "Bearer" Chain []Extractor // For chained extractors, stores all extractors in the chain Source Source // The type of source being extracted from } ``` ### Available Functions - `FromAuthHeader(authScheme string)`: Extract from Authorization header with optional scheme - `FromCookie(key string)`: Extract from HTTP cookies - `FromParam(param string)`: Extract from URL path parameters - `FromForm(param string)`: Extract from form data - `FromHeader(header string)`: Extract from custom HTTP headers - `FromQuery(param string)`: Extract from URL query parameters - `FromCustom(key string, fn func(fiber.Ctx) (string, error))`: Define custom extraction logic with metadata - `Chain(extractors ...Extractor)`: Chain multiple extractors with fallback ### Source Inspection The `Source` field provides **security-aware extraction** by explicitly 
identifying the origin of extracted values. This enables middleware to enforce security policies based on data source: ```go switch extractor.Source { case SourceAuthHeader: // Authorization header - commonly used for authentication tokens case SourceHeader: // Custom HTTP headers - application-specific data case SourceCookie: // HTTP cookies - client-side stored data case SourceQuery: // URL query parameters - visible in URLs and logs (security consideration) case SourceForm: // Form data - POST body data case SourceParam: // URL path parameters - route-based data case SourceCustom: // Custom extraction logic } ``` ### Chain Behavior The `Chain` function implements fallback logic: - Returns first successful extraction (non-empty value, no error) - If all extractors fail, returns the last error encountered or `ErrNotFound` - **Skips extractors with `nil` Extract functions** (graceful error handling) - Preserves metadata from first extractor for introspection - Stores defensive copy for runtime inspection via the `Chain` field ## Security Considerations ### Source Awareness and Custom Extractors As described in the [Source Inspection](#source-inspection) section, the `Source` field enables middleware to enforce security policies based on data source: - **CSRF Protection**: The double-submit-cookie pattern requires tokens to be submitted in both a cookie AND a form field/header. Source awareness allows CSRF middleware to verify that tokens come from both expected sources rather than, for example, only from cookies - **Authentication**: Security middleware can enforce source-specific policies (e.g., auth tokens from headers, not query parameters) - **Audit Trails**: Source information enables security analysis and compliance reporting However, when using `FromCustom`, middleware cannot determine the source of the extracted value, which can limit the middleware's ability to warn about potential security risks.
Documentation and examples should clearly warn about these risks when using custom extractors. ================================================ FILE: extractors/extractors.go ================================================ package extractors // Package extractors provides shared value extraction utilities for Fiber middleware. // This package helps reduce code duplication across middleware packages // while ensuring consistent behavior, security practices, and RFC compliance. // It can extract string values from various HTTP request sources including // headers, cookies, query parameters, form data, and URL parameters. // // Example usage: // // import "github.com/gofiber/fiber/v3/extractors" // // // Extract from Authorization header // authExtractor := extractors.FromAuthHeader("Bearer") // // // Chain multiple sources with fallback // tokenExtractor := extractors.Chain( // extractors.FromHeader("X-API-Key"), // extractors.FromCookie("api_key"), // extractors.FromQuery("token"), // ) // // Security considerations: // - Query parameters and form data can leak sensitive information // - Use HTTPS to protect extracted values in transit // - Consider source-specific security policies for your use case import ( "errors" "net/url" "github.com/gofiber/fiber/v3" "github.com/gofiber/utils/v2" ) // Source represents the type of source from which an API key is extracted. // This is informational metadata that helps developers understand the extractor behavior. type Source int const ( // SourceHeader indicates the value is extracted from an HTTP header. SourceHeader Source = iota // SourceAuthHeader indicates the value is extracted from the Authorization header. SourceAuthHeader // SourceForm indicates the value is extracted from form data. SourceForm // SourceQuery indicates the value is extracted from URL query parameters. SourceQuery // SourceParam indicates the value is extracted from URL path parameters. 
SourceParam // SourceCookie indicates the value is extracted from cookies. SourceCookie // SourceCustom indicates the value is extracted using a custom extractor function. SourceCustom ) // ErrNotFound is returned when the requested value is missing or empty. var ErrNotFound = errors.New("value not found") // Extractor defines a value extraction method with metadata. type Extractor struct { Extract func(fiber.Ctx) (string, error) Key string // The parameter/header name used for extraction AuthScheme string // The auth scheme used, e.g., "Bearer" Chain []Extractor // For chained extractors, stores all extractors in the chain Source Source // The type of source being extracted from } // FromAuthHeader extracts a value from the Authorization header with an optional prefix. // This function implements RFC 9110 compliant Authorization header parsing with strict token68 validation. // // RFC Compliance: // - Follows RFC 9110 Section 11.6.2 for Authorization header format // - Enforces 1*SP (one or more spaces) between auth-scheme and credentials // - Implements RFC 7235 token68 character validation for extracted tokens // - Case-insensitive auth scheme matching per HTTP standards // // Token68 Validation: // - Only allows characters: A-Z, a-z, 0-9, -, ., _, ~, +, /, = // - Rejects tokens containing spaces, tabs, or other whitespace // - Validates proper padding: = only at end, no characters after padding starts // - Prevents tokens starting with = (invalid padding) // // Security Features: // - Strict validation prevents header injection attacks // - Rejects malformed tokens that could bypass authentication // - Consistent error handling for missing or invalid credentials // // Parameters: // - authScheme: The auth scheme to strip from the header value (e.g., "Bearer", "Basic"). // If empty, the entire header value is returned without validation. // // Returns: // // An Extractor that attempts to retrieve and parse the Authorization header. 
// Returns ErrNotFound if the header is missing, malformed, or doesn't match the expected scheme. // // Examples: // // // Extract Bearer token with validation // extractor := FromAuthHeader("Bearer") // // Input: "Bearer abc123" -> Output: "abc123" // // Input: "Bearer abc def" -> Output: ErrNotFound (space in token) // // Input: "Basic dXNlcjpwYXNz" -> Output: ErrNotFound (wrong scheme) // // // Extract raw header value (no validation) // extractor := FromAuthHeader("") // // Input: "CustomAuth token123" -> Output: "CustomAuth token123" func FromAuthHeader(authScheme string) Extractor { return Extractor{ Extract: func(c fiber.Ctx) (string, error) { authHeader := c.Get(fiber.HeaderAuthorization) if authHeader == "" { return "", ErrNotFound } // Check if the header starts with the specified auth scheme if authScheme != "" { schemeLen := len(authScheme) if len(authHeader) <= schemeLen || !utils.EqualFold(authHeader[:schemeLen], authScheme) { return "", ErrNotFound } rest := authHeader[schemeLen:] if rest == "" || rest[0] != ' ' { return "", ErrNotFound } // Extract token after the required space token := rest[1:] if token == "" { return "", ErrNotFound } if !isValidToken68(token) { return "", ErrNotFound } return token, nil } return authHeader, nil }, Key: fiber.HeaderAuthorization, Source: SourceAuthHeader, AuthScheme: authScheme, } } // FromCookie creates an Extractor that retrieves a value from a specified cookie in the request. // // The function: // - Retrieves the cookie value using the specified name // - Returns ErrNotFound if the cookie is missing // // Parameters: // - key: The name of the cookie from which to extract the value. // // Returns: // // An Extractor that attempts to retrieve the value from the specified cookie. // Returns ErrNotFound if the cookie is not present. // // Security Note: // // Cookies are generally more secure than query parameters for sensitive data // as they are not logged in access logs or visible in browser history. 
// However, ensure cookies are properly secured with appropriate flags. // // Example: // // extractor := FromCookie("session_id") // // Cookie: "session_id=abc123" -> Output: "abc123" // // Missing cookie -> Output: ErrNotFound func FromCookie(key string) Extractor { return Extractor{ Extract: func(c fiber.Ctx) (string, error) { value := c.Cookies(key) if value == "" { return "", ErrNotFound } return value, nil }, Key: key, Source: SourceCookie, } } // FromParam creates an Extractor that retrieves a value from a specified URL parameter in the request. // URL parameters are extracted from the route path (e.g., /users/:id). // // SECURITY WARNING: Extracting values from URL parameters can leak sensitive information through: // - Server access logs and error logs // - Browser referrer headers when following links // - Proxy and intermediary server logs // - Browser history and bookmarks // - Network monitoring tools // // For sensitive data, prefer FromAuthHeader, FromCookie, or FromHeader instead. // // Parameters: // - param: The name of the URL parameter from which to extract the value. // // Returns: // // An Extractor that attempts to retrieve the value from the specified URL parameter. // Returns ErrNotFound if the parameter is not present. // // Example: // // // Route: GET /users/:userId/posts/:postId // userExtractor := FromParam("userId") // postExtractor := FromParam("postId") // // URL: /users/123/posts/456 -> userId: "123", postId: "456" func FromParam(param string) Extractor { return Extractor{ Extract: func(c fiber.Ctx) (string, error) { value := c.Params(param) if value == "" { return "", ErrNotFound } unescapedValue, err := url.PathUnescape(value) if err != nil { return "", ErrNotFound } return unescapedValue, nil }, Key: param, Source: SourceParam, } } // FromForm creates an Extractor that retrieves a value from a specified form field in the request. 
// Form data is typically submitted via POST requests with content-type application/x-www-form-urlencoded. // // SECURITY WARNING: Extracting values from form data can leak sensitive information through: // - Server access logs and error logs // - Browser referrer headers (especially if form is submitted via GET) // - Proxy and intermediary server logs // - Browser history (if form uses GET method) // // For sensitive data, prefer FromAuthHeader or FromCookie instead. // If using form data, ensure the form uses POST method and HTTPS. // // Parameters: // - param: The name of the form field from which to extract the value. // // Returns: // // An Extractor that attempts to retrieve the value from the specified form field. // Returns ErrNotFound if the field is not present. // // Example: // // extractor := FromForm("username") // // Form data: "username=john_doe&password=secret" -> Output: "john_doe" // // Missing field -> Output: ErrNotFound func FromForm(param string) Extractor { return Extractor{ Extract: func(c fiber.Ctx) (string, error) { value := c.FormValue(param) if value == "" { return "", ErrNotFound } return value, nil }, Key: param, Source: SourceForm, } } // FromHeader creates an Extractor that retrieves a value from a specified HTTP header in the request. // HTTP headers are commonly used for API keys, tokens, and other metadata. // // The function: // - Retrieves the header value using the specified name // - Returns ErrNotFound if the header is missing // // Parameters: // - header: The name of the HTTP header from which to extract the value. // // Returns: // // An Extractor that attempts to retrieve the value from the specified HTTP header. // Returns ErrNotFound if the header is not present. // // Security Note: // // Headers are generally secure for sensitive data as they are not logged // in access logs by default. However, be aware that some proxies may log headers. 
// // Example: // // extractor := FromHeader("X-API-Key") // // Header: "X-API-Key: abc123" -> Output: "abc123" // // Missing header -> Output: ErrNotFound func FromHeader(header string) Extractor { return Extractor{ Extract: func(c fiber.Ctx) (string, error) { value := c.Get(header) if value == "" { return "", ErrNotFound } return value, nil }, Key: header, Source: SourceHeader, } } // FromQuery creates an Extractor that retrieves a value from a specified query parameter in the request. // Query parameters are extracted from the URL query string (e.g., ?key=value&foo=bar). // // SECURITY WARNING: Extracting values from URL query parameters can leak sensitive information through: // - Server access logs and error logs // - Browser referrer headers when following links // - Proxy and intermediary server logs // - Browser history and bookmarks // - Network monitoring tools and packet sniffers // - Web browser developer tools // // For sensitive data, prefer FromAuthHeader, FromCookie, or FromHeader instead. // If query parameters must be used, ensure HTTPS is enforced. // // Parameters: // - param: The name of the query parameter from which to extract the value. // // Returns: // // An Extractor that attempts to retrieve the value from the specified query parameter. // Returns ErrNotFound if the parameter is not present. // // Example: // // extractor := FromQuery("token") // // URL: /api/data?token=abc123&format=json -> Output: "abc123" // // URL: /api/data?format=json -> Output: ErrNotFound func FromQuery(param string) Extractor { return Extractor{ Extract: func(c fiber.Ctx) (string, error) { value := c.Query(param) if value == "" { return "", ErrNotFound } return value, nil }, Key: param, Source: SourceQuery, } } // FromCustom creates an Extractor using a provided function. // This allows for custom extraction logic beyond the built-in extractors. 
// // The function: // - Accepts a custom extraction function with signature func(fiber.Ctx) (string, error) // - Handles nil functions gracefully by returning ErrNotFound // - Preserves the custom function for execution // // Parameters: // - key: A descriptive identifier for the custom extractor. // Used for debugging, logging, and Chain metadata. Should be meaningful for introspection. // Examples: "X-Custom-Header", "Database-Lookup", "Cache-Key" // - fn: The custom function to extract the value from the fiber.Ctx. // If nil, the extractor will return ErrNotFound when executed. // The function should return (value, nil) on success or ("", error) on failure. // // Returns: // // An Extractor that uses the provided function for extraction. // If fn is nil, the returned extractor will always return ErrNotFound. // // Examples: // // // Custom header with transformation // extractor := FromCustom("X-API-Key", func(c fiber.Ctx) (string, error) { // value := c.Get("X-API-Key") // if value == "" { // return "", ErrNotFound // } // return strings.ToUpper(value), nil // }) // // // Database lookup (pseudo-code) // userExtractor := FromCustom("user-from-db", func(c fiber.Ctx) (string, error) { // userID := c.Params("userId") // user, err := db.GetUser(userID) // if err != nil { // return "", err // } // return user.Name, nil // }) // // // Conditional extraction // smartExtractor := FromCustom("smart-auth", func(c fiber.Ctx) (string, error) { // if c.Get("X-Service-Auth") != "" { // return c.Get("X-Service-Auth"), nil // } // return c.Cookies("session"), nil // }) func FromCustom(key string, fn func(fiber.Ctx) (string, error)) Extractor { if fn == nil { fn = func(fiber.Ctx) (string, error) { return "", ErrNotFound } } return Extractor{ Extract: fn, Key: key, Source: SourceCustom, } } // Chain creates an Extractor that tries multiple extractors in order until one succeeds. // This implements a fallback pattern where multiple extraction sources are attempted in sequence. 
// // The function: // - Tries each extractor in the order provided // - Returns the first successful extraction (non-empty value with no error) // - Skips extractors with nil Extract functions // - Returns the last error encountered if all extractors fail // - Returns ErrNotFound if no extractors are provided or all return empty values // // Parameters: // - extractors: A variadic list of Extractor instances to try in sequence. // The order matters - more secure/preferred sources should be listed first. // // Returns: // // An Extractor that attempts each provided extractor in order. // The returned extractor uses the Source and Key from the first extractor for metadata. // // Behavior: // - Success: Returns the first non-empty value with no error // - Partial failure: Continues to next extractor if current returns error or empty value // - Total failure: Returns last error encountered, or ErrNotFound if no errors // - Empty chain: Always returns ErrNotFound // // Examples: // // // Try header first, then cookie, then query param // extractor := Chain( // FromHeader("Authorization"), // FromCookie("auth_token"), // FromQuery("token"), // ) // // // API key from multiple possible sources // apiKeyExtractor := Chain( // FromHeader("X-API-Key"), // FromQuery("api_key"), // FromForm("apiKey"), // ) // // Security Note: // // Order extractors by security preference. Most secure sources (headers, cookies) // should be attempted before less secure ones (query params, form data). 
func Chain(extractors ...Extractor) Extractor { if len(extractors) == 0 { return Extractor{ Extract: func(fiber.Ctx) (string, error) { return "", ErrNotFound }, Source: SourceCustom, Key: "", Chain: []Extractor{}, } } // Use the source and key from the first extractor as the primary primarySource := extractors[0].Source primaryKey := extractors[0].Key return Extractor{ Extract: func(c fiber.Ctx) (string, error) { var lastErr error // last error encountered (including ErrNotFound) for _, extractor := range extractors { if extractor.Extract == nil { continue } v, err := extractor.Extract(c) if err == nil && v != "" { return v, nil } if err != nil { lastErr = err } } if lastErr != nil { return "", lastErr } return "", ErrNotFound }, Source: primarySource, Key: primaryKey, Chain: append([]Extractor(nil), extractors...), // Defensive copy for introspection } } // isValidToken68 checks if a string is a valid token68 per RFC 7235/9110. func isValidToken68(token string) bool { if token == "" { return false } paddingStarted := false for i := 0; i < len(token); i++ { c := token[i] switch { case (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '.' 
|| c == '_' || c == '~' || c == '+' || c == '/': if paddingStarted { return false // No characters allowed after padding starts } case c == '=': if i == 0 { return false // Cannot start with padding } paddingStarted = true default: return false // Invalid character } } return true } ================================================ FILE: extractors/extractors_test.go ================================================ package extractors import ( "context" "net/http" "strings" "testing" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) // go test -run Test_Extractors_Missing func Test_Extractors_Missing(t *testing.T) { t.Parallel() app := fiber.New() // Add a route to test the missing param app.Get("/test", func(c fiber.Ctx) error { token, err := FromParam("token").Extract(c) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) return nil }) _, err := app.Test(newRequest(fiber.MethodGet, "/test")) require.NoError(t, err) ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) // Missing form token, err := FromForm("token").Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) // Missing query token, err = FromQuery("token").Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) // Missing header token, err = FromHeader("X-Token").Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) // Missing Auth header token, err = FromAuthHeader("Bearer").Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) // Missing cookie token, err = FromCookie("token").Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) } // newRequest creates a new *http.Request for Fiber's app.Test func newRequest(method, target string) *http.Request { req, err := http.NewRequestWithContext(context.Background(), method, target, http.NoBody) if err != nil { panic(err) } return req } // go test -run Test_Extractors func 
Test_Extractors(t *testing.T) { t.Parallel() t.Run("FromParam", func(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/test/:token", func(c fiber.Ctx) error { token, err := FromParam("token").Extract(c) require.NoError(t, err) require.Equal(t, "token_from_param", token) return nil }) _, err := app.Test(newRequest(fiber.MethodGet, "/test/token_from_param")) require.NoError(t, err) }) t.Run("FromForm", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.SetContentType(fiber.MIMEApplicationForm) ctx.Request().Header.SetMethod(fiber.MethodPost) ctx.Request().SetBodyString("token=token_from_form") token, err := FromForm("token").Extract(ctx) require.NoError(t, err) require.Equal(t, "token_from_form", token) }) t.Run("FromQuery", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().SetRequestURI("/?token=token_from_query") token, err := FromQuery("token").Extract(ctx) require.NoError(t, err) require.Equal(t, "token_from_query", token) }) t.Run("FromHeader", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set("X-Token", "token_from_header") token, err := FromHeader("X-Token").Extract(ctx) require.NoError(t, err) require.Equal(t, "token_from_header", token) }) t.Run("FromAuthHeader", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set(fiber.HeaderAuthorization, "Bearer token_from_auth_header") token, err := FromAuthHeader("Bearer").Extract(ctx) require.NoError(t, err) require.Equal(t, "token_from_auth_header", token) }) t.Run("FromCookie", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := 
app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.SetCookie("token", "token_from_cookie") token, err := FromCookie("token").Extract(ctx) require.NoError(t, err) require.Equal(t, "token_from_cookie", token) }) } // go test -run Test_Extractor_Chain func Test_Extractor_Chain(t *testing.T) { t.Parallel() t.Run("no_extractors", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) token, err := Chain().Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) }) t.Run("first_extractor_succeeds", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set("X-Token", "token_from_header") ctx.Request().SetRequestURI("/?token=token_from_query") token, err := Chain(FromHeader("X-Token"), FromQuery("token")).Extract(ctx) require.NoError(t, err) require.Equal(t, "token_from_header", token) }) t.Run("second_extractor_succeeds", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().SetRequestURI("/?token=token_from_query") token, err := Chain(FromHeader("X-Token"), FromQuery("token")).Extract(ctx) require.NoError(t, err) require.Equal(t, "token_from_query", token) }) t.Run("all_extractors_fail", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) token, err := Chain(FromHeader("X-Token"), FromQuery("token")).Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) }) t.Run("empty_extractor_returns_not_found", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) // This extractor will return "", nil dummyExtractor := Extractor{ 
Extract: func(_ fiber.Ctx) (string, error) { return "", nil }, Source: SourceCustom, Key: "token", } token, err := Chain(dummyExtractor).Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) }) } // go test -run Test_Extractor_FromAuthHeader_EdgeCases func Test_Extractor_FromAuthHeader_EdgeCases(t *testing.T) { t.Parallel() t.Run("wrong_scheme", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set(fiber.HeaderAuthorization, "Basic dXNlcjpwYXNz") // Basic auth instead of Bearer token, err := FromAuthHeader("Bearer").Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) }) t.Run("missing_space_after_scheme", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set(fiber.HeaderAuthorization, "Bearertoken") // Missing space after Bearer token, err := FromAuthHeader("Bearer").Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) }) t.Run("case_insensitive_scheme_matching", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set(fiber.HeaderAuthorization, "bearer token") // lowercase bearer token, err := FromAuthHeader("Bearer").Extract(ctx) require.NoError(t, err) require.Equal(t, "token", token) }) } // go test -run Test_Extractor_Chain_Introspection func Test_Extractor_Chain_Introspection(t *testing.T) { t.Parallel() // Test chain introspection extractor1 := FromHeader("X-Token") extractor2 := FromQuery("token") extractor3 := FromCookie("auth") chainExtractor := Chain(extractor1, extractor2, extractor3) // Verify chain metadata require.Equal(t, SourceHeader, chainExtractor.Source) require.Equal(t, "X-Token", chainExtractor.Key) require.Len(t, chainExtractor.Chain, 3) // Verify 
individual extractors in chain require.Equal(t, SourceHeader, chainExtractor.Chain[0].Source) require.Equal(t, "X-Token", chainExtractor.Chain[0].Key) require.Equal(t, SourceQuery, chainExtractor.Chain[1].Source) require.Equal(t, "token", chainExtractor.Chain[1].Key) require.Equal(t, SourceCookie, chainExtractor.Chain[2].Source) require.Equal(t, "auth", chainExtractor.Chain[2].Key) } // go test -run Test_Extractor_FromCustom func Test_Extractor_FromCustom(t *testing.T) { t.Parallel() t.Run("successful_extraction", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set("X-Custom", "custom-value") customExtractor := FromCustom("X-Custom", func(c fiber.Ctx) (string, error) { value := c.Get("X-Custom") if value == "" { return "", ErrNotFound } return strings.ToUpper(value), nil }) token, err := customExtractor.Extract(ctx) require.NoError(t, err) require.Equal(t, "CUSTOM-VALUE", token) // Verify metadata require.Equal(t, SourceCustom, customExtractor.Source) require.Equal(t, "X-Custom", customExtractor.Key) require.Empty(t, customExtractor.AuthScheme) }) t.Run("extraction_with_error", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) errorExtractor := FromCustom("test", func(_ fiber.Ctx) (string, error) { return "", fiber.NewError(fiber.StatusBadRequest, "Custom error") }) token, err := errorExtractor.Extract(ctx) require.Empty(t, token) require.Error(t, err) require.Contains(t, err.Error(), "Custom error") }) t.Run("extraction_returning_empty_string", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) emptyExtractor := FromCustom("empty", func(_ fiber.Ctx) (string, error) { return "", nil }) token, err := emptyExtractor.Extract(ctx) require.Empty(t, token) require.NoError(t, err) 
// Should return empty string with no error }) t.Run("nil_function", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) nilExtractor := FromCustom("nil", nil) token, err := nilExtractor.Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) // Should return ErrNotFound for nil function }) } // go test -run Test_Extractor_Chain_Error_Propagation func Test_Extractor_Chain_Error_Propagation(t *testing.T) { t.Parallel() app := fiber.New() // Create extractors that return different errors errorExtractor1 := Extractor{ Extract: func(_ fiber.Ctx) (string, error) { return "", fiber.NewError(fiber.StatusBadRequest, "First error") }, Key: "error1", Source: SourceCustom, } errorExtractor2 := Extractor{ Extract: func(_ fiber.Ctx) (string, error) { return "", fiber.NewError(fiber.StatusUnauthorized, "Second error") }, Key: "error2", Source: SourceCustom, } chainExtractor := Chain(errorExtractor1, errorExtractor2) ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) token, err := chainExtractor.Extract(ctx) require.Empty(t, token) require.Error(t, err) require.Contains(t, err.Error(), "Second error") // Should return the last error var fe *fiber.Error require.ErrorAs(t, err, &fe) require.Equal(t, fiber.StatusUnauthorized, fe.Code) } // go test -run Test_Extractor_Chain_With_Success func Test_Extractor_Chain_With_Success(t *testing.T) { t.Parallel() app := fiber.New() // First extractor fails, second succeeds failingExtractor := Extractor{ Extract: func(_ fiber.Ctx) (string, error) { return "", ErrNotFound }, Key: "fail", Source: SourceCustom, } successExtractor := Extractor{ Extract: func(_ fiber.Ctx) (string, error) { return "success-token", nil }, Key: "success", Source: SourceCustom, } chainExtractor := Chain(failingExtractor, successExtractor) ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) 
token, err := chainExtractor.Extract(ctx) require.NoError(t, err) require.Equal(t, "success-token", token) } // go test -run Test_Extractor_FromAuthHeader_CustomScheme func Test_Extractor_FromAuthHeader_CustomScheme(t *testing.T) { t.Parallel() app := fiber.New() // Test with custom auth scheme ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set(fiber.HeaderAuthorization, "CustomScheme my-token") extractor := FromAuthHeader("CustomScheme") token, err := extractor.Extract(ctx) require.NoError(t, err) require.Equal(t, "my-token", token) // Verify metadata require.Equal(t, SourceAuthHeader, extractor.Source) require.Equal(t, fiber.HeaderAuthorization, extractor.Key) require.Equal(t, "CustomScheme", extractor.AuthScheme) } // go test -run Test_Extractor_FromAuthHeader_WhitespaceToken func Test_Extractor_FromAuthHeader_WhitespaceToken(t *testing.T) { t.Parallel() app := fiber.New() // Test with token containing whitespace (should be rejected per RFC 7235 token68 spec) ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set(fiber.HeaderAuthorization, "Bearer token with spaces and\ttabs") extractor := FromAuthHeader("Bearer") token, err := extractor.Extract(ctx) require.Error(t, err) require.ErrorIs(t, err, ErrNotFound) require.Empty(t, token) // Verify metadata require.Equal(t, SourceAuthHeader, extractor.Source) require.Equal(t, fiber.HeaderAuthorization, extractor.Key) require.Equal(t, "Bearer", extractor.AuthScheme) } // go test -run Test_Extractor_FromAuthHeader_RFC_Compliance func Test_Extractor_FromAuthHeader_RFC_Compliance(t *testing.T) { t.Parallel() testCases := []struct { name string header string expectedToken string description string shouldFail bool }{ { name: "tab_after_scheme", header: "Bearer\ttoken", shouldFail: true, description: "tab character after scheme should be rejected - RFC specifies 1*SP, not tabs", }, { name: 
"single_space_after_scheme", header: "Bearer token", shouldFail: false, expectedToken: "token", description: "single space after scheme should be accepted - standard format", }, { name: "multiple_spaces_after_scheme", header: "Bearer token", shouldFail: true, description: "multiple spaces after scheme rejected for simplicity - single space is standard", }, { name: "mixed_whitespace_after_scheme", header: "Bearer \t \ttoken", shouldFail: true, description: "mixed whitespace after scheme should be rejected - RFC specifies 1*SP, not tabs", }, { name: "no_whitespace_after_scheme", header: "Bearertoken", shouldFail: true, description: "no whitespace after scheme should fail", }, { name: "header_too_short", header: "Bearer", shouldFail: true, description: "header too short for scheme + space + token", }, { name: "only_whitespace_after_scheme", header: "Bearer \t ", shouldFail: true, description: "only whitespace after scheme should fail", }, { name: "case_insensitive_scheme", header: "BEARER token", shouldFail: false, expectedToken: "token", description: "case-insensitive scheme matching should work", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set(fiber.HeaderAuthorization, tc.header) token, err := FromAuthHeader("Bearer").Extract(ctx) if tc.shouldFail { require.Error(t, err, "Expected error for %s", tc.description) require.ErrorIs(t, err, ErrNotFound) require.Empty(t, token) } else { require.NoError(t, err, "Expected no error for %s", tc.description) require.Equal(t, tc.expectedToken, token) } }) } // Special case for case-insensitive scheme matching with different extractor scheme t.Run("case_insensitive_extractor_scheme", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) 
ctx.Request().Header.Set(fiber.HeaderAuthorization, "BEARER token") token, err := FromAuthHeader("bearer").Extract(ctx) // lowercase extractor scheme require.NoError(t, err) require.Equal(t, "token", token) }) } // go test -run Test_Extractor_FromAuthHeader_Token68_Validation func Test_Extractor_FromAuthHeader_Token68_Validation(t *testing.T) { t.Parallel() app := fiber.New() // Test valid token68 characters (should pass) ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set(fiber.HeaderAuthorization, "Bearer ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~+/=") token, err := FromAuthHeader("Bearer").Extract(ctx) require.NoError(t, err) require.Equal(t, "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~+/=", token) // Test tokens with spaces (should fail) testCases := []struct { name string header string description string shouldFail bool }{ {name: "space_in_token", header: "Bearer abc def", shouldFail: true, description: "space in token"}, {name: "space_after_scheme", header: "Bearer abc", shouldFail: true, description: "multiple spaces after scheme"}, {name: "no_space_after_scheme", header: "Bearertoken", shouldFail: true, description: "no space after scheme"}, {name: "only_scheme", header: "Bearer", shouldFail: true, description: "only scheme, no token"}, {name: "tab_after_scheme", header: "Bearer\ttoken", shouldFail: true, description: "tab after scheme"}, {name: "tab_in_token", header: "Bearer abc\tdef", shouldFail: true, description: "tab in token"}, {name: "newline_in_token", header: "Bearer abc\ndef", shouldFail: true, description: "newline in token"}, {name: "leading_space_in_token", header: "Bearer abc", shouldFail: true, description: "leading space in token after scheme space"}, {name: "trailing_space_in_token", header: "Bearer abc ", shouldFail: true, description: "trailing space in token"}, {name: "comma_in_token", header: "Bearer abc,def", shouldFail: true, 
description: "comma in token"}, {name: "semicolon_in_token", header: "Bearer abc;def", shouldFail: true, description: "semicolon in token"}, {name: "quote_in_token", header: "Bearer abc\"def", shouldFail: true, description: "quote in token"}, {name: "bracket_in_token", header: "Bearer abc[def", shouldFail: true, description: "bracket in token"}, {name: "equals_at_start", header: "Bearer =abc", shouldFail: true, description: "equals at start of token"}, {name: "equals_in_middle", header: "Bearer ab=cd", shouldFail: true, description: "equals in middle of token"}, {name: "valid_equals_at_end", header: "Bearer abc=", shouldFail: false, description: "valid equals at end"}, {name: "valid_double_equals", header: "Bearer abc==", shouldFail: false, description: "valid double equals at end"}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set(fiber.HeaderAuthorization, tc.header) token, err := FromAuthHeader("Bearer").Extract(ctx) if tc.shouldFail { require.Error(t, err, "Expected error for %s", tc.description) require.ErrorIs(t, err, ErrNotFound) require.Empty(t, token) } else { require.NoError(t, err, "Expected no error for %s", tc.description) require.NotEmpty(t, token) } }) } } // go test -run Test_Extractor_FromAuthHeader_NoScheme func Test_Extractor_FromAuthHeader_NoScheme(t *testing.T) { t.Parallel() t.Run("returns_header_value", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.Set(fiber.HeaderAuthorization, "some-token-value") extractor := FromAuthHeader("") // No scheme token, err := extractor.Extract(ctx) require.NoError(t, err) require.Equal(t, "some-token-value", token) // Verify metadata require.Equal(t, SourceAuthHeader, extractor.Source) require.Equal(t, fiber.HeaderAuthorization, extractor.Key) 
require.Empty(t, extractor.AuthScheme) }) t.Run("empty_header_returns_not_found", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) // No Authorization header set extractor := FromAuthHeader("") // No scheme token, err := extractor.Extract(ctx) require.Empty(t, token) require.ErrorIs(t, err, ErrNotFound) }) } // go test -run Test_Extractor_Chain_NilFunctions func Test_Extractor_Chain_NilFunctions(t *testing.T) { t.Parallel() app := fiber.New() // Test chain with nil extractor functions nilExtractor := Extractor{ Extract: nil, Key: "nil", Source: SourceCustom, } validExtractor := Extractor{ Extract: func(_ fiber.Ctx) (string, error) { return "valid-token", nil }, Key: "valid", Source: SourceCustom, } chainExtractor := Chain(nilExtractor, validExtractor) ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) token, err := chainExtractor.Extract(ctx) require.NoError(t, err) require.Equal(t, "valid-token", token) } // go test -run Test_Extractor_Chain_AllErrors func Test_Extractor_Chain_AllErrors(t *testing.T) { t.Parallel() app := fiber.New() // Test chain where all extractors return errors errorExtractor1 := Extractor{ Extract: func(_ fiber.Ctx) (string, error) { return "", fiber.NewError(fiber.StatusUnauthorized, "First auth error") }, Key: "error1", Source: SourceCustom, } errorExtractor2 := Extractor{ Extract: func(_ fiber.Ctx) (string, error) { return "", fiber.NewError(fiber.StatusForbidden, "Second auth error") }, Key: "error2", Source: SourceCustom, } chainExtractor := Chain(errorExtractor1, errorExtractor2) ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) token, err := chainExtractor.Extract(ctx) require.Empty(t, token) require.Error(t, err) require.Contains(t, err.Error(), "Second auth error") // Should return last error var fe *fiber.Error require.ErrorAs(t, err, &fe) require.Equal(t, 
fiber.StatusForbidden, fe.Code) } // go test -run Test_Extractor_Chain_MixedScenarios func Test_Extractor_Chain_MixedScenarios(t *testing.T) { t.Parallel() // Define reusable extractors failingExtractor := Extractor{ Extract: func(_ fiber.Ctx) (string, error) { return "", ErrNotFound }, Key: "fail", Source: SourceCustom, } errorExtractor := Extractor{ Extract: func(_ fiber.Ctx) (string, error) { return "", fiber.NewError(fiber.StatusBadRequest, "Bad request") }, Key: "error", Source: SourceCustom, } successExtractor := Extractor{ Extract: func(_ fiber.Ctx) (string, error) { return "success", nil }, Key: "success", Source: SourceCustom, } t.Run("error_then_success", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) chain := Chain(errorExtractor, successExtractor) token, err := chain.Extract(ctx) require.NoError(t, err) require.Equal(t, "success", token) }) t.Run("fail_then_error_then_success", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) chain := Chain(failingExtractor, errorExtractor, successExtractor) token, err := chain.Extract(ctx) require.NoError(t, err) require.Equal(t, "success", token) }) t.Run("fail_then_error", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) chain := Chain(failingExtractor, errorExtractor) token, err := chain.Extract(ctx) require.Empty(t, token) require.Error(t, err) require.Contains(t, err.Error(), "Bad request") }) } // go test -run Test_Extractor_SourceTypes func Test_Extractor_SourceTypes(t *testing.T) { t.Parallel() t.Run("individual_extractor_sources", func(t *testing.T) { t.Parallel() // Test that all source types are properly set require.Equal(t, SourceHeader, FromHeader("test").Source) require.Equal(t, SourceAuthHeader, FromAuthHeader("Bearer").Source) 
require.Equal(t, SourceAuthHeader, FromAuthHeader("").Source) // Empty scheme should still be SourceAuthHeader require.Equal(t, SourceForm, FromForm("test").Source) require.Equal(t, SourceQuery, FromQuery("test").Source) require.Equal(t, SourceParam, FromParam("test").Source) require.Equal(t, SourceCookie, FromCookie("test").Source) require.Equal(t, SourceCustom, FromCustom("test", func(_ fiber.Ctx) (string, error) { return "test", nil }).Source) }) t.Run("chain_source_metadata", func(t *testing.T) { t.Parallel() // Test chain source (should use first extractor's source) chain := Chain(FromHeader("X-Test"), FromQuery("test")) require.Equal(t, SourceHeader, chain.Source) require.Equal(t, "X-Test", chain.Key) }) } // go test -run Test_Extractor_URL_Encoded func Test_Extractor_URL_Encoded(t *testing.T) { t.Parallel() t.Run("FromQuery_with_spaces", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().SetRequestURI("/?token=token%20with%20spaces") token, err := FromQuery("token").Extract(ctx) require.NoError(t, err) require.Equal(t, "token with spaces", token) // Should be URL-decoded automatically by fasthttp }) t.Run("FromForm_with_plus", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) ctx.Request().Header.SetContentType(fiber.MIMEApplicationForm) ctx.Request().Header.SetMethod(fiber.MethodPost) ctx.Request().SetBodyString("token=token%2Bwith%2Bplus") token, err := FromForm("token").Extract(ctx) require.NoError(t, err) require.Equal(t, "token+with+plus", token) // URL-decoded }) t.Run("FromQuery_base64_encoded", func(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) t.Cleanup(func() { app.ReleaseCtx(ctx) }) base64Value := "cGFzc3dvcmQ%3D" // URL-encoded base64 "cGFzc3dvcmQ=" ctx.Request().SetRequestURI("/?token=" + base64Value) 
token, err := FromQuery("token").Extract(ctx) require.NoError(t, err) require.Equal(t, "cGFzc3dvcmQ=", token) // Should be URL-decoded }) t.Run("FromParam_with_slashes", func(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/test/:token", func(c fiber.Ctx) error { token, extractErr := FromParam("token").Extract(c) require.NoError(t, extractErr) require.Equal(t, "token/with/slashes", token) return nil }) _, err := app.Test(newRequest(fiber.MethodGet, "/test/token%2Fwith%2Fslashes")) require.NoError(t, err) }) } func Test_isValidToken68(t *testing.T) { t.Parallel() cases := []struct { name string token string want bool }{ {name: "empty string", token: "", want: false}, {name: "single uppercase", token: "A", want: true}, {name: "single lowercase", token: "a", want: true}, {name: "single digit", token: "0", want: true}, {name: "all allowed symbols except =", token: "-._~+/", want: true}, {name: "letters and digits", token: "token68", want: true}, {name: "equals at end", token: "token=", want: true}, {name: "multiple equals", token: "token==", want: true}, {name: "equals at start", token: "=token", want: false}, {name: "equals in middle", token: "tok=en", want: false}, {name: "equals not at end with other chars", token: "token=extra", want: false}, {name: "space in token", token: "token space", want: false}, {name: "tab character in token", token: "token\ttab", want: false}, {name: "invalid symbol", token: "token@", want: false}, {name: "valid token68", token: "token68", want: true}, {token: "token68=", want: true, name: "valid token68 with equals at end"}, {token: "token68==", want: true, name: "multiple equals at end"}, {token: "token68=extra", want: false, name: "equals followed by extra chars"}, {token: "T0ken-._~+/=", want: true, name: "all allowed chars with equals at end"}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { t.Parallel() got := isValidToken68(tc.token) if got != tc.want { t.Errorf("isValidToken68(%q) = %v, want %v", tc.token, 
got, tc.want) } }) } } ================================================ FILE: go.mod ================================================ module github.com/gofiber/fiber/v3 go 1.25.0 require ( github.com/gofiber/schema v1.7.0 github.com/gofiber/utils/v2 v2.0.2 github.com/google/uuid v1.6.0 github.com/mattn/go-colorable v0.1.14 github.com/mattn/go-isatty v0.0.20 github.com/shamaton/msgpack/v3 v3.1.0 github.com/stretchr/testify v1.11.1 github.com/tinylib/msgp v1.6.3 github.com/valyala/bytebufferpool v1.0.0 github.com/valyala/fasthttp v1.69.0 golang.org/x/crypto v0.49.0 ) require ( github.com/andybalholm/brotli v1.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fxamacker/cbor/v2 v2.9.0 // direct github.com/klauspost/compress v1.18.4 // indirect github.com/philhofer/fwd v1.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/x448/float16 v0.8.4 // indirect golang.org/x/net v0.52.0 golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) ================================================ FILE: go.sum ================================================ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= github.com/gofiber/schema v1.7.0 h1:yNM+FNRZjyYEli9Ey0AXRBrAY9jTnb+kmGs3lJGPvKg= github.com/gofiber/schema v1.7.0/go.mod h1:A/X5Ffyru4p9eBdp99qu+nzviHzQiZ7odLT+TwxWhbk= github.com/gofiber/utils/v2 v2.0.2 h1:ShRRssz0F3AhTlAQcuEj54OEDtWF7+HJDwEi/aa6QLI= github.com/gofiber/utils/v2 v2.0.2/go.mod 
h1:+9Ub4NqQ+IaJoTliq5LfdmOJAA/Hzwf4pXOxOa3RrJ0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/philhofer/fwd v1.2.0 h1:e6DnBTl7vGY+Gz322/ASL4Gyp1FspeMvx1RNDoToZuM= github.com/philhofer/fwd v1.2.0/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/shamaton/msgpack/v3 v3.1.0 h1:jsk0vEAqVvvS9+fTZ5/EcQ9tz860c9pWxJ4Iwecz8gU= github.com/shamaton/msgpack/v3 v3.1.0/go.mod h1:DcQG8jrdrQCIxr3HlMYkiXdMhK+KfN2CitkyzsQV4uc= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tinylib/msgp v1.6.3 h1:bCSxiTz386UTgyT1i0MSCvdbWjVW+8sG3PjkGsZQt4s= github.com/tinylib/msgp v1.6.3/go.mod h1:RSp0LW9oSxFut3KzESt5Voq4GVWyS+PSulT77roAqEA= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI= github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw= github.com/x448/float16 v0.8.4 
h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
================================================
FILE: group.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io

package fiber

import (
	"fmt"
	"reflect"
)

// Group represents a collection of routes that share middleware and a common
// path prefix.
type Group struct {
	app             *App   // application the group's routes are registered on
	parentGroup     *Group // enclosing group; nil when the group has no parent group
	name            string // group name, prefixed with the parent's name (set by Name)
	Prefix          string // path prefix joined (via getGroupPath) onto every route path
	anyRouteDefined bool   // set once Use/Add registers a route; afterwards Name names routes instead of the group
}

// Name assigns a name either to the group itself or to a route registered on it.
// // If this method is used before any route added to group, it'll set group name and OnGroupNameHook will be used. // Otherwise, it'll set route name and OnName hook will be used. func (grp *Group) Name(name string) Router { if grp.anyRouteDefined { grp.app.Name(name) return grp } grp.app.mutex.Lock() if grp.parentGroup != nil { grp.name = grp.parentGroup.name + name } else { grp.name = name } if err := grp.app.hooks.executeOnGroupNameHooks(*grp); err != nil { panic(err) } grp.app.mutex.Unlock() return grp } // Use registers a middleware route that will match requests // with the provided prefix (which is optional and defaults to "/"). // Also, you can pass another app instance as a sub-router along a routing path. // It's very useful to split up a large API as many independent routers and // compose them as a single service using Use. The fiber's error handler and // any of the fiber's sub apps are added to the application's error handlers // to be invoked on errors that happen within the prefix route. // // app.Use(func(c fiber.Ctx) error { // return c.Next() // }) // app.Use("/api", func(c fiber.Ctx) error { // return c.Next() // }) // app.Use("/api", handler, func(c fiber.Ctx) error { // return c.Next() // }) // subApp := fiber.New() // app.Use("/mounted-path", subApp) // // This method will match all HTTP verbs: GET, POST, PUT, HEAD etc... 
func (grp *Group) Use(args ...any) Router { var subApp *App var prefix string var prefixes []string var handlers []Handler for i := range args { switch arg := args[i].(type) { case string: prefix = arg case *App: subApp = arg case []string: prefixes = arg default: handler, ok := toFiberHandler(arg) if !ok { panic(fmt.Sprintf("use: invalid handler %v\n", reflect.TypeOf(arg))) } handlers = append(handlers, handler) } } if len(prefixes) == 0 { prefixes = append(prefixes, prefix) } for _, prefix := range prefixes { if subApp != nil { return grp.mount(prefix, subApp) } grp.app.register([]string{methodUse}, getGroupPath(grp.Prefix, prefix), grp, handlers...) } if !grp.anyRouteDefined { grp.anyRouteDefined = true } return grp } // Get registers a route for GET methods that requests a representation // of the specified resource. Requests using GET should only retrieve data. func (grp *Group) Get(path string, handler any, handlers ...any) Router { return grp.Add([]string{MethodGet}, path, handler, handlers...) } // Head registers a route for HEAD methods that asks for a response identical // to that of a GET request, but without the response body. func (grp *Group) Head(path string, handler any, handlers ...any) Router { return grp.Add([]string{MethodHead}, path, handler, handlers...) } // Post registers a route for POST methods that is used to submit an entity to the // specified resource, often causing a change in state or side effects on the server. func (grp *Group) Post(path string, handler any, handlers ...any) Router { return grp.Add([]string{MethodPost}, path, handler, handlers...) } // Put registers a route for PUT methods that replaces all current representations // of the target resource with the request payload. func (grp *Group) Put(path string, handler any, handlers ...any) Router { return grp.Add([]string{MethodPut}, path, handler, handlers...) } // Delete registers a route for DELETE methods that deletes the specified resource. 
func (grp *Group) Delete(path string, handler any, handlers ...any) Router { return grp.Add([]string{MethodDelete}, path, handler, handlers...) } // Connect registers a route for CONNECT methods that establishes a tunnel to the // server identified by the target resource. func (grp *Group) Connect(path string, handler any, handlers ...any) Router { return grp.Add([]string{MethodConnect}, path, handler, handlers...) } // Options registers a route for OPTIONS methods that is used to describe the // communication options for the target resource. func (grp *Group) Options(path string, handler any, handlers ...any) Router { return grp.Add([]string{MethodOptions}, path, handler, handlers...) } // Trace registers a route for TRACE methods that performs a message loop-back // test along the path to the target resource. func (grp *Group) Trace(path string, handler any, handlers ...any) Router { return grp.Add([]string{MethodTrace}, path, handler, handlers...) } // Patch registers a route for PATCH methods that is used to apply partial // modifications to a resource. func (grp *Group) Patch(path string, handler any, handlers ...any) Router { return grp.Add([]string{MethodPatch}, path, handler, handlers...) } // Add allows you to specify multiple HTTP methods to register a route. // The provided handlers are executed in order, starting with `handler` and then the variadic `handlers`. func (grp *Group) Add(methods []string, path string, handler any, handlers ...any) Router { converted := collectHandlers("group", append([]any{handler}, handlers...)...) grp.app.register(methods, getGroupPath(grp.Prefix, path), grp, converted...) if !grp.anyRouteDefined { grp.anyRouteDefined = true } return grp } // All will register the handler on all HTTP methods func (grp *Group) All(path string, handler any, handlers ...any) Router { _ = grp.Add(grp.app.config.RequestMethods, path, handler, handlers...) 
return grp } // Group is used for Routes with common prefix to define a new sub-router with optional middleware. // // api := app.Group("/api") // api.Get("/users", handler) func (grp *Group) Group(prefix string, handlers ...any) Router { prefix = getGroupPath(grp.Prefix, prefix) if len(handlers) > 0 { converted := collectHandlers("group", handlers...) grp.app.register([]string{methodUse}, prefix, grp, converted...) } // Create new group newGrp := &Group{Prefix: prefix, app: grp.app, parentGroup: grp} if err := grp.app.hooks.executeOnGroupHooks(*newGrp); err != nil { panic(err) } return newGrp } // RouteChain creates a Registering instance scoped to the group's prefix, // allowing chained route declarations for the same path. func (grp *Group) RouteChain(path string) Register { // Create new group register := &Registering{app: grp.app, group: grp, path: getGroupPath(grp.Prefix, path)} return register } // Route is used to define routes with a common prefix inside the supplied // function. It mirrors the legacy helper and reuses the Group method to create // a sub-router. 
func (grp *Group) Route(prefix string, fn func(router Router), name ...string) Router { if fn == nil { panic("route handler 'fn' cannot be nil") } // Create new group group := grp.Group(prefix) if len(name) > 0 { group.Name(name[0]) } // Define routes fn(group) return group } ================================================ FILE: helpers.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 🤖 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io package fiber import ( "bytes" "context" "crypto/tls" "errors" "fmt" "io" "net" "os" "path/filepath" "reflect" "slices" "strconv" "strings" "sync" "time" "unsafe" "github.com/gofiber/utils/v2" utilsbytes "github.com/gofiber/utils/v2/bytes" "github.com/gofiber/fiber/v3/log" "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" ) // acceptedType is a struct that holds the parsed value of an Accept header // along with quality, specificity, parameters, and order. // Used for sorting accept headers. type acceptedType struct { params headerParams spec string quality float64 specificity int order int } const noCacheValue = "no-cache" // Pre-allocated byte slices for accept header parsing var ( semicolonQEquals = []byte(";q=") wildcardAll = []byte("*/*") wildcardSuffix = []byte("/*") ) type headerParams map[string][]byte // ValueFromContext retrieves a value stored under key from supported context types. 
// // Supported context types: // - Ctx (including CustomCtx implementations) // - *fasthttp.RequestCtx // - context.Context func ValueFromContext[T any](ctx, key any) (T, bool) { switch typed := ctx.(type) { case Ctx: val, ok := typed.Locals(key).(T) return val, ok case *fasthttp.RequestCtx: val, ok := typed.UserValue(key).(T) return val, ok case context.Context: val, ok := typed.Value(key).(T) return val, ok default: var zero T return zero, false } } // StoreInContext stores key/value in both Fiber locals and request context. // // This is useful when values need to be available via both c.Locals() and // context.Context lookups throughout middleware and handlers. func StoreInContext(c Ctx, key, value any) { c.Locals(key, value) if c.App().config.PassLocalsToContext { c.SetContext(context.WithValue(c.Context(), key, value)) } } // getTLSConfig returns a net listener's tls config func getTLSConfig(ln net.Listener) *tls.Config { if ln == nil { return nil } type tlsConfigProvider interface { TLSConfig() *tls.Config } type configProvider interface { Config() *tls.Config } if provider, ok := ln.(tlsConfigProvider); ok { return provider.TLSConfig() } if provider, ok := ln.(configProvider); ok { return provider.Config() } pointer := reflect.ValueOf(ln) if !pointer.IsValid() { return nil } // Reflection fallback for listeners that do not expose a TLS config method. val := reflect.Indirect(pointer) if !val.IsValid() { return nil } field := val.FieldByName("config") if !field.IsValid() { return nil } if field.Type() != reflect.TypeFor[*tls.Config]() { return nil } if field.CanInterface() { if cfg, ok := field.Interface().(*tls.Config); ok { return cfg } return nil } if !field.CanAddr() { return nil } value := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem() //nolint:gosec // Access to unexported field is required for listeners that don't expose TLS config methods. 
if !value.IsValid() { return nil } cfg, ok := value.Interface().(*tls.Config) if !ok { return nil } return cfg } // readContent opens a named file and read content from it func readContent(rf io.ReaderFrom, name string) (int64, error) { // Read file f, err := os.Open(filepath.Clean(name)) if err != nil { return 0, fmt.Errorf("failed to open: %w", err) } defer func() { if err = f.Close(); err != nil { log.Errorf("Error closing file: %s", err) } }() n, readErr := rf.ReadFrom(f) if readErr != nil { return n, fmt.Errorf("failed to read: %w", readErr) } return n, nil } // quoteString escapes special characters using percent-encoding. // Non-ASCII bytes are encoded as well so the result is always ASCII. func (app *App) quoteString(raw string) string { bb := bytebufferpool.Get() quoted := app.toString(fasthttp.AppendQuotedArg(bb.B, app.toBytes(raw))) bytebufferpool.Put(bb) return quoted } // quoteRawString escapes only characters that need quoting according to // https://www.rfc-editor.org/rfc/rfc9110#section-5.6.4 so the result may // contain non-ASCII bytes. func (app *App) quoteRawString(raw string) string { const hex = "0123456789ABCDEF" bb := bytebufferpool.Get() defer bytebufferpool.Put(bb) for i := 0; i < len(raw); i++ { c := raw[i] switch { case c == '\\' || c == '"': // escape backslash and quote bb.B = append(bb.B, '\\', c) case c == '\n': bb.B = append(bb.B, '\\', 'n') case c == '\r': bb.B = append(bb.B, '\\', 'r') case c < 0x20 || c == 0x7f: // percent-encode control and DEL bb.B = append(bb.B, '%', hex[c>>4], hex[c&0x0f], ) default: bb.B = append(bb.B, c) } } return app.toString(bb.B) } // isASCII reports whether the provided string contains only ASCII characters. 
// See: https://www.rfc-editor.org/rfc/rfc0020 func (*App) isASCII(s string) bool { for i := 0; i < len(s); i++ { if s[i] > 127 { return false } } return true } // uniqueRouteStack drop all not unique routes from the slice func uniqueRouteStack(stack []*Route) []*Route { m := make(map[*Route]struct{}, len(stack)) unique := make([]*Route, 0, len(stack)) for _, v := range stack { if _, ok := m[v]; !ok { m[v] = struct{}{} unique = append(unique, v) } } return unique } // defaultString returns the value or a default value if it is set func defaultString(value string, defaultValue []string) string { if value == "" && len(defaultValue) > 0 { return defaultValue[0] } return value } func getGroupPath(prefix, path string) string { if path == "" { return prefix } if path[0] != '/' { path = "/" + path } return utils.TrimRight(prefix, '/') + path } // acceptsOffer determines if an offer matches a given specification. // It supports a trailing '*' wildcard and performs case-insensitive exact matching. // Returns true if the offer matches the specification, false otherwise. func acceptsOffer(spec, offer string, _ headerParams) bool { if len(spec) >= 1 && spec[len(spec)-1] == '*' { prefix := spec[:len(spec)-1] if len(offer) < len(prefix) { return false } return utils.EqualFold(prefix, offer[:len(prefix)]) } return utils.EqualFold(spec, offer) } // acceptsLanguageOfferBasic determines if a language tag offer matches a range // according to RFC 4647 Basic Filtering. // A match occurs if the range exactly equals the tag or is a prefix of the tag // followed by a hyphen. The comparison is case-insensitive. Only a single "*" // as the entire range is allowed. Any "*" appearing after a hyphen renders the // range invalid and will not match. 
func acceptsLanguageOfferBasic(spec, offer string, _ headerParams) bool { if spec == "*" { return true } if strings.IndexByte(spec, '*') >= 0 { return false } if utils.EqualFold(spec, offer) { return true } return len(offer) > len(spec) && utils.EqualFold(offer[:len(spec)], spec) && offer[len(spec)] == '-' } // acceptsLanguageOfferExtended determines if a language tag offer matches a // range according to RFC 4647 Extended Filtering (§3.3.2). // - Case-insensitive comparisons // - '*' matches zero or more subtags (can "slide") // - Unspecified subtags are treated like '*' (so trailing/extraneous tag subtags are fine) // - Matching fails if sliding encounters a singleton (incl. 'x') func acceptsLanguageOfferExtended(spec, offer string, _ headerParams) bool { if spec == "*" { return true } if spec == "" || offer == "" { return false } // Use stack-allocated arrays to avoid heap allocations for typical language tags var rsBuf, tsBuf [8]string rs := rsBuf[:0] ts := tsBuf[:0] // Parse spec subtags without allocation for typical cases for s := range strings.SplitSeq(spec, "-") { rs = append(rs, s) } // Parse offer subtags without allocation for typical cases for s := range strings.SplitSeq(offer, "-") { ts = append(ts, s) } // Step 2: first subtag must match (or be '*') if rs[0] != "*" && !utils.EqualFold(rs[0], ts[0]) { return false } i, j := 1, 1 // i = range index, j = tag index for i < len(rs) { if rs[i] == "*" { // 3.A: '*' matches zero or more subtags i++ continue } if j >= len(ts) { // 3.B: ran out of tag subtags return false } if utils.EqualFold(rs[i], ts[j]) { // 3.C: exact subtag match i++ j++ continue } // 3.D: singleton barrier (one letter or digit, incl. 'x') if len(ts[j]) == 1 { return false } // 3.E: slide forward in the tag and try again j++ } // 4: matched all range subtags return true } // acceptsOfferType This function determines if an offer type matches a given specification. 
// It checks if the specification is equal to */* (i.e., all types are accepted). // It gets the MIME type of the offer (either from the offer itself or by its file extension). // It checks if the offer MIME type matches the specification MIME type or if the specification is of the form /* and the offer MIME type has the same MIME type. // It checks if the offer contains every parameter present in the specification. // Returns true if the offer type matches the specification, false otherwise. func acceptsOfferType(spec, offerType string, specParams headerParams) bool { var offerMime, offerParams string if i := strings.IndexByte(offerType, ';'); i == -1 { offerMime = offerType } else { offerMime = offerType[:i] offerParams = offerType[i:] } // Accept: */* if spec == "*/*" { return paramsMatch(specParams, offerParams) } var mimetype string if strings.IndexByte(offerMime, '/') != -1 { mimetype = offerMime // MIME type } else { mimetype = utils.GetMIME(offerMime) // extension } if spec == mimetype { // Accept: / return paramsMatch(specParams, offerParams) } s := strings.IndexByte(mimetype, '/') specSlash := strings.IndexByte(spec, '/') // Accept: /* if s != -1 && specSlash != -1 { if utils.EqualFold(spec[:specSlash], mimetype[:s]) && (spec[specSlash:] == "/*" || mimetype[s:] == "/*") { return paramsMatch(specParams, offerParams) } } return false } // paramsMatch returns whether offerParams contains all parameters present in specParams. // Matching is case-insensitive, and surrounding quotes are stripped. // To align with the behavior of res.format from Express, the order of parameters is // ignored, and if a parameter is specified twice in the incoming Accept, the last // provided value is given precedence. // In the case of quoted values, RFC 9110 says that we must treat any character escaped // by a backslash as equivalent to the character itself (e.g., "a\aa" is equivalent to "aaa"). // For the sake of simplicity, we forgo this and compare the value as-is. 
Besides, it would // be highly unusual for a client to escape something other than a double quote or backslash. // See https://www.rfc-editor.org/rfc/rfc9110#name-parameters func paramsMatch(specParamStr headerParams, offerParams string) bool { if len(specParamStr) == 0 { return true } allSpecParamsMatch := true for specParam, specVal := range specParamStr { foundParam := false fasthttp.VisitHeaderParams(utils.UnsafeBytes(offerParams), func(key, value []byte) bool { if utils.EqualFold(specParam, utils.UnsafeString(key)) { foundParam = true unescaped, err := unescapeHeaderValue(value) if err != nil { allSpecParamsMatch = false return false } allSpecParamsMatch = utils.EqualFold(specVal, unescaped) return false } return true }) if !foundParam || !allSpecParamsMatch { return false } } return allSpecParamsMatch } // getSplicedStrList function takes a string and a string slice as an argument, divides the string into different // elements divided by ',' and stores these elements in the string slice. // It returns the populated string slice as an output. // // If the given slice hasn't enough space, it will allocate more and return. 
func getSplicedStrList(headerValue string, dst []string) []string { if headerValue == "" { return nil } dst = dst[:0] segmentStart := 0 for i := 0; i < len(headerValue); i++ { if headerValue[i] == ',' { dst = append(dst, utils.TrimSpace(headerValue[segmentStart:i])) segmentStart = i + 1 } } dst = append(dst, utils.TrimSpace(headerValue[segmentStart:])) return dst } func joinHeaderValues(headers [][]byte) []byte { switch len(headers) { case 0: return nil case 1: return headers[0] default: return bytes.Join(headers, []byte{','}) } } func unescapeHeaderValue(v []byte) ([]byte, error) { if bytes.IndexByte(v, '\\') == -1 { return v, nil } res := make([]byte, 0, len(v)) escaping := false for i, c := range v { if escaping { res = append(res, c) escaping = false continue } if c == '\\' { // invalid escape at end of string if i == len(v)-1 { return nil, errInvalidEscapeSequence } escaping = true continue } res = append(res, c) } if escaping { return nil, errInvalidEscapeSequence } return res, nil } // forEachMediaRange parses an Accept or Content-Type header, calling functor // on each media range. // See: https://www.rfc-editor.org/rfc/rfc9110#name-content-negotiation-fields func forEachMediaRange(header []byte, functor func([]byte)) { hasDQuote := bytes.IndexByte(header, '"') != -1 for len(header) > 0 { n := 0 header = utils.TrimLeft(header, ' ') quotes := 0 escaping := false if hasDQuote { // Complex case. We need to keep track of quotes and quoted-pairs (i.e., characters escaped with \ ) loop: for n < len(header) { switch header[n] { case ',': if quotes%2 == 0 { break loop } case '"': if !escaping { quotes++ } case '\\': if quotes%2 == 1 { escaping = !escaping } default: // all other characters are ignored } n++ } } else { // Simple case. Just look for the next comma. if n = bytes.IndexByte(header, ','); n == -1 { n = len(header) } } functor(header[:n]) if n >= len(header) { return } header = header[n+1:] } } // Pool for headerParams instances. 
The headerParams object *must* // be cleared before being returned to the pool. var headerParamPool = sync.Pool{ New: func() any { return make(headerParams) }, } // getOffer return valid offer for header negotiation. func getOffer(header []byte, isAccepted func(spec, offer string, specParams headerParams) bool, offers ...string) string { if len(offers) == 0 { return "" } if len(header) == 0 { return offers[0] } acceptedTypes := make([]acceptedType, 0, 8) order := 0 // Parse header and get accepted types with their quality and specificity // See: https://www.rfc-editor.org/rfc/rfc9110#name-content-negotiation-fields forEachMediaRange(header, func(accept []byte) { order++ spec, quality := accept, 1.0 var params headerParams if i := bytes.IndexByte(accept, ';'); i != -1 { spec = accept[:i] // Optimized quality parsing qIndex := i + 3 if bytes.HasPrefix(accept[i:], semicolonQEquals) && bytes.IndexByte(accept[qIndex:], ';') == -1 { if q, err := fasthttp.ParseUfloat(accept[qIndex:]); err == nil { quality = q } } else { params, _ = headerParamPool.Get().(headerParams) //nolint:errcheck // only contains headerParams for k := range params { delete(params, k) } fasthttp.VisitHeaderParams(accept[i:], func(key, value []byte) bool { if len(key) == 1 && key[0] == 'q' { if q, err := fasthttp.ParseUfloat(value); err == nil { quality = q } return false } lowerKey := utils.UnsafeString(utilsbytes.UnsafeToLower(key)) val, err := unescapeHeaderValue(value) if err != nil { return true } params[lowerKey] = val return true }) } // Skip this accept type if quality is 0.0 // See: https://www.rfc-editor.org/rfc/rfc9110#quality.values if quality == 0.0 { return } } spec = utils.TrimSpace(spec) // Determine specificity var specificity int // check for wildcard this could be a mime */* or a wildcard character * switch { case len(spec) == 1 && spec[0] == '*': specificity = 1 case bytes.Equal(spec, wildcardAll): specificity = 1 case bytes.HasSuffix(spec, wildcardSuffix): specificity = 2 case 
bytes.IndexByte(spec, '/') != -1: specificity = 3 default: specificity = 4 } // Add to accepted types acceptedTypes = append(acceptedTypes, acceptedType{ spec: utils.UnsafeString(spec), quality: quality, specificity: specificity, order: order, params: params, }) }) if len(acceptedTypes) > 1 { // Sort accepted types by quality and specificity, preserving order of equal elements sortAcceptedTypes(acceptedTypes) } // Find the first offer that matches the accepted types for _, acceptedType := range acceptedTypes { for _, offer := range offers { if offer == "" { continue } if isAccepted(acceptedType.spec, offer, acceptedType.params) { if acceptedType.params != nil { headerParamPool.Put(acceptedType.params) } return offer } } if acceptedType.params != nil { headerParamPool.Put(acceptedType.params) } } return "" } // sortAcceptedTypes sorts accepted types by quality and specificity, preserving order of equal elements // A type with parameters has higher priority than an equivalent one without parameters. // e.g., text/html;a=1;b=2 comes before text/html;a=1 // See: https://www.rfc-editor.org/rfc/rfc9110#name-content-negotiation-fields func sortAcceptedTypes(at []acceptedType) { for i := 1; i < len(at); i++ { lo, hi := 0, i-1 for lo <= hi { mid := (lo + hi) / 2 if at[i].quality < at[mid].quality || (at[i].quality == at[mid].quality && at[i].specificity < at[mid].specificity) || (at[i].quality == at[mid].quality && at[i].specificity == at[mid].specificity && len(at[i].params) < len(at[mid].params)) || (at[i].quality == at[mid].quality && at[i].specificity == at[mid].specificity && len(at[i].params) == len(at[mid].params) && at[i].order > at[mid].order) { lo = mid + 1 } else { hi = mid - 1 } } for j := i; j > lo; j-- { at[j-1], at[j] = at[j], at[j-1] } } } // normalizeEtag validates an entity tag and returns the // value without quotes. weak is true if the tag has the "W/" prefix. 
func normalizeEtag(t string) (value string, weak, ok bool) { //nolint:nonamedreturns // gocritic unnamedResult requires naming the parsed ETag components weak = strings.HasPrefix(t, "W/") if weak { t = t[2:] } if len(t) < 2 || t[0] != '"' || t[len(t)-1] != '"' { return "", weak, false } return t[1 : len(t)-1], weak, true } // matchEtag performs a weak comparison of entity tags according to // RFC 9110 §8.8.3.2. The weak indicator ("W/") is ignored, but both tags must // be properly quoted. Invalid tags result in a mismatch. func matchEtag(s, etag string) bool { n1, _, ok1 := normalizeEtag(s) n2, _, ok2 := normalizeEtag(etag) if !ok1 || !ok2 { return false } return n1 == n2 } // matchEtagStrong performs a strong entity-tag comparison following // RFC 9110 §8.8.3.1. A weak tag never matches a strong one, even if the quoted // values are identical. func matchEtagStrong(s, etag string) bool { n1, w1, ok1 := normalizeEtag(s) n2, w2, ok2 := normalizeEtag(etag) if !ok1 || !ok2 || w1 || w2 { return false } return n1 == n2 } // isEtagStale reports whether a response with the given ETag would be considered // stale when presented with the raw If-None-Match header value. Comparison is // weak as defined by RFC 9110 §8.8.3.2. func (app *App) isEtagStale(etag string, noneMatchBytes []byte) bool { var start, end int header := utils.TrimSpace(app.toString(noneMatchBytes)) // Short-circuit the wildcard case: "*" never counts as stale. 
if header == "*" { return false } // Adapted from: // https://github.com/jshttp/fresh/blob/master/index.js#L110 for i := range noneMatchBytes { switch noneMatchBytes[i] { case 0x20: if start == end { start = i + 1 end = i + 1 } case 0x2c: if matchEtag(app.toString(noneMatchBytes[start:end]), etag) { return false } start = i + 1 end = i + 1 default: end = i + 1 } } return !matchEtag(app.toString(noneMatchBytes[start:end]), etag) } func parseAddr(raw string) (host, port string) { //nolint:nonamedreturns // gocritic unnamedResult requires naming host and port parts for clarity if raw == "" { return "", "" } raw = utils.TrimSpace(raw) // Handle IPv6 addresses enclosed in brackets as defined by RFC 3986 if strings.HasPrefix(raw, "[") { if end := strings.IndexByte(raw, ']'); end != -1 { host = raw[:end+1] // keep the closing ] if len(raw) > end+1 && raw[end+1] == ':' { return host, raw[end+2:] } return host, "" } } // Everything else with a colon if i := strings.LastIndexByte(raw, ':'); i != -1 { host, port = raw[:i], raw[i+1:] // If “host” still contains ':', we must have hit an un-bracketed IPv6 // literal. In that form a port is impossible, so treat the whole thing // as host. if strings.IndexByte(host, ':') >= 0 { return raw, "" } return host, port } // No colon, nothing to split return raw, "" } // isNoCache checks if the cacheControl header value contains a `no-cache` directive. // Per RFC 9111 §5.2.2.4, no-cache can appear as either: // - "no-cache" (applies to entire response) // - "no-cache=field-name" (applies to specific header field) // Both forms indicate the response should not be served from cache without revalidation. 
func isNoCache(cacheControl string) bool { n := len(cacheControl) if n < len(noCacheValue) { return false } const noCacheLen = len(noCacheValue) const asciiCaseFold = byte(0x20) for i := 0; i <= n-noCacheLen; i++ { if (cacheControl[i] | asciiCaseFold) != 'n' { continue } if !matchNoCacheToken(cacheControl, i) { continue } if i > 0 && !isNoCacheDelimiter(cacheControl[i-1]) { continue } // Handle: "no-cache", "no-cache, ...", "no-cache=...", "no-cache ," if i+noCacheLen == n { return true } if isNoCacheDelimiter(cacheControl[i+noCacheLen]) || cacheControl[i+noCacheLen] == '=' { return true } } return false } func isNoCacheDelimiter(c byte) bool { return c == ' ' || c == '\t' || c == ',' } func matchNoCacheToken(s string, i int) bool { // ASCII-only case-insensitive compare for "no-cache". const asciiCaseFold = byte(0x20) b := s[i:] return (b[0]|asciiCaseFold) == 'n' && (b[1]|asciiCaseFold) == 'o' && b[2] == '-' && (b[3]|asciiCaseFold) == 'c' && (b[4]|asciiCaseFold) == 'a' && (b[5]|asciiCaseFold) == 'c' && (b[6]|asciiCaseFold) == 'h' && (b[7]|asciiCaseFold) == 'e' } var errTestConnClosed = errors.New("testConn is closed") type testConn struct { r bytes.Buffer w bytes.Buffer isClosed bool sync.Mutex } // Read implements net.Conn by reading from the buffered input. func (c *testConn) Read(b []byte) (int, error) { c.Lock() defer c.Unlock() return c.r.Read(b) //nolint:wrapcheck // This must not be wrapped } // Write implements net.Conn by appending to the buffered output. func (c *testConn) Write(b []byte) (int, error) { c.Lock() defer c.Unlock() if c.isClosed { return 0, errTestConnClosed } return c.w.Write(b) //nolint:wrapcheck // This must not be wrapped } // Close marks the connection as closed and prevents further writes. func (c *testConn) Close() error { c.Lock() defer c.Unlock() c.isClosed = true return nil } // LocalAddr implements net.Conn and returns a placeholder address. 
func (*testConn) LocalAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4zero, Port: 0, Zone: ""}
}

// RemoteAddr implements net.Conn; like LocalAddr it reports the unspecified
// IPv4 address with port 0.
func (*testConn) RemoteAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4zero, Port: 0, Zone: ""}
}

// SetDeadline implements net.Conn. Deadlines are meaningless for the in-memory
// buffers, so the call is ignored and always succeeds.
func (*testConn) SetDeadline(_ time.Time) error {
	return nil
}

// SetReadDeadline implements net.Conn and, like SetDeadline, does nothing.
func (*testConn) SetReadDeadline(_ time.Time) error {
	return nil
}

// SetWriteDeadline implements net.Conn and, like SetDeadline, does nothing.
func (*testConn) SetWriteDeadline(_ time.Time) error {
	return nil
}

// toStringImmutable returns a copy of b as a string, so the result stays valid
// even if b is modified afterwards.
func toStringImmutable(b []byte) string {
	return string(b)
}

// toBytesImmutable returns a freshly allocated copy of s as a byte slice.
func toBytesImmutable(s string) []byte {
	return []byte(s)
}

// HTTP methods and their unique INTs
// methodInt maps an HTTP method name to its index, or -1 when it is unknown.
func (app *App) methodInt(s string) int {
	// Custom method sets: the index is the method's position in the effective config.
	if len(app.configured.RequestMethods) != 0 {
		return slices.Index(app.config.RequestMethods, s)
	}

	// Default method set: resolve with a switch for better performance.
	switch s {
	case MethodGet:
		return methodGet
	case MethodHead:
		return methodHead
	case MethodPost:
		return methodPost
	case MethodPut:
		return methodPut
	case MethodDelete:
		return methodDelete
	case MethodConnect:
		return methodConnect
	case MethodOptions:
		return methodOptions
	case MethodTrace:
		return methodTrace
	case MethodPatch:
		return methodPatch
	}

	return -1
}

// method is the inverse of methodInt: it returns the configured method name at
// index methodInt (it panics if the index is out of range).
func (app *App) method(methodInt int) string {
	return app.config.RequestMethods[methodInt]
}

// IsMethodSafe reports whether the HTTP method is considered safe.
// See https://datatracker.ietf.org/doc/html/rfc9110#section-9.2.1
func IsMethodSafe(m string) bool {
	return m == MethodGet || m == MethodHead || m == MethodOptions || m == MethodTrace
}

// IsMethodIdempotent reports whether the HTTP method is considered idempotent.
// See https://datatracker.ietf.org/doc/html/rfc9110#section-9.2.2 func IsMethodIdempotent(m string) bool { if IsMethodSafe(m) { return true } switch m { case MethodPut, MethodDelete: return true default: return false } } // Convert a string value to a specified type, handling errors and optional default values. func Convert[T any](value string, converter func(string) (T, error), defaultValue ...T) (T, error) { converted, err := converter(value) if err != nil { if len(defaultValue) > 0 { return defaultValue[0], nil } return converted, fmt.Errorf("failed to convert: %w", err) } return converted, nil } var ( errParsedEmptyString = errors.New("parsed result is empty string") errParsedEmptyBytes = errors.New("parsed result is empty bytes") errParsedType = errors.New("unsupported generic type") ) func genericParseType[V GenericType](str string) (V, error) { var v V switch any(v).(type) { case int: result, err := utils.ParseInt(str) if err != nil { return v, fmt.Errorf("failed to parse int: %w", err) } return any(int(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed case int8: result, err := utils.ParseInt8(str) if err != nil { return v, fmt.Errorf("failed to parse int8: %w", err) } return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case int16: result, err := utils.ParseInt16(str) if err != nil { return v, fmt.Errorf("failed to parse int16: %w", err) } return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case int32: result, err := utils.ParseInt32(str) if err != nil { return v, fmt.Errorf("failed to parse int32: %w", err) } return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case int64: result, err := utils.ParseInt(str) if err != nil { return v, fmt.Errorf("failed to parse int64: %w", err) } return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case uint: result, err := utils.ParseUint(str) if err != nil { return v, fmt.Errorf("failed to parse uint: %w", err) } 
return any(uint(result)).(V), nil //nolint:errcheck,forcetypeassert // not needed case uint8: result, err := utils.ParseUint8(str) if err != nil { return v, fmt.Errorf("failed to parse uint8: %w", err) } return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case uint16: result, err := utils.ParseUint16(str) if err != nil { return v, fmt.Errorf("failed to parse uint16: %w", err) } return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case uint32: result, err := utils.ParseUint32(str) if err != nil { return v, fmt.Errorf("failed to parse uint32: %w", err) } return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case uint64: result, err := utils.ParseUint(str) if err != nil { return v, fmt.Errorf("failed to parse uint64: %w", err) } return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case float32: result, err := utils.ParseFloat32(str) if err != nil { return v, fmt.Errorf("failed to parse float32: %w", err) } return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case float64: result, err := utils.ParseFloat64(str) if err != nil { return v, fmt.Errorf("failed to parse float64: %w", err) } return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case bool: result, err := strconv.ParseBool(str) if err != nil { return v, fmt.Errorf("failed to parse bool: %w", err) } return any(result).(V), nil //nolint:errcheck,forcetypeassert // not needed case string: if str == "" { return v, errParsedEmptyString } return any(str).(V), nil //nolint:errcheck,forcetypeassert // not needed case []byte: if str == "" { return v, errParsedEmptyBytes } return any([]byte(str)).(V), nil //nolint:errcheck,forcetypeassert // not needed default: return v, errParsedType } } // GenericType enumerates the values that can be parsed from strings by the // generic helper functions. 
type GenericType interface { GenericTypeInteger | GenericTypeFloat | bool | string | []byte } // GenericTypeInteger is the union of all supported integer types. type GenericTypeInteger interface { GenericTypeIntegerSigned | GenericTypeIntegerUnsigned } // GenericTypeIntegerSigned is the union of supported signed integer types. type GenericTypeIntegerSigned interface { int | int8 | int16 | int32 | int64 } // GenericTypeIntegerUnsigned is the union of supported unsigned integer types. type GenericTypeIntegerUnsigned interface { uint | uint8 | uint16 | uint32 | uint64 } // GenericTypeFloat is the union of supported floating-point types. type GenericTypeFloat interface { float32 | float64 } ================================================ FILE: helpers_fuzz_test.go ================================================ //go:build go1.18 package fiber import ( "testing" ) // go test -v -run=^$ -fuzz=FuzzUtilsGetOffer func FuzzUtilsGetOffer(f *testing.F) { inputs := []string{ `application/json; v=1; foo=bar; q=0.938; extra=param, text/plain;param="big fox"; q=0.43`, `text/html, application/xhtml+xml, application/xml;q=0.9, */*;q=0.8`, `*/*`, `text/plain; q=0.5, text/html, text/x-dvi; q=0.8, text/x-c`, } for _, input := range inputs { f.Add(input) } f.Fuzz(func(_ *testing.T, spec string) { getOffer([]byte(spec), acceptsOfferType, `application/json;version=1;v=1;foo=bar`, `text/plain;param="big fox"`) }) } ================================================ FILE: helpers_test.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 📝 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io package fiber import ( "bytes" "context" "crypto/tls" "math" "net" "os" "strconv" "testing" "time" "unsafe" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func Test_Utils_GetOffer(t *testing.T) { t.Parallel() require.Empty(t, 
getOffer([]byte("hello"), acceptsOffer)) require.Equal(t, "1", getOffer([]byte(""), acceptsOffer, "1")) require.Empty(t, getOffer([]byte("2"), acceptsOffer, "1")) require.Empty(t, getOffer([]byte(""), acceptsOfferType)) require.Empty(t, getOffer([]byte("text/html"), acceptsOfferType)) require.Empty(t, getOffer([]byte("text/html"), acceptsOfferType, "application/json")) require.Empty(t, getOffer([]byte("text/html;q=0"), acceptsOfferType, "text/html")) require.Empty(t, getOffer([]byte("application/json, */*; q=0"), acceptsOfferType, "image/png")) require.Equal(t, "application/xml", getOffer([]byte("text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"), acceptsOfferType, "application/xml", "application/json")) require.Equal(t, "text/html", getOffer([]byte("text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"), acceptsOfferType, "text/html")) require.Equal(t, "application/pdf", getOffer([]byte("text/plain;q=0,application/pdf;q=0.9,*/*;q=0.000"), acceptsOfferType, "application/pdf", "application/json")) require.Equal(t, "application/pdf", getOffer([]byte("text/plain;q=0,application/pdf;q=0.9,*/*;q=0.000"), acceptsOfferType, "application/pdf", "application/json")) require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain;a=1"), acceptsOfferType, "text/plain;a=1")) require.Empty(t, getOffer([]byte("text/plain;a=1;b=2"), acceptsOfferType, "text/plain;b=2")) // Spaces, quotes, out of order params, and case insensitivity require.Equal(t, "text/plain", getOffer([]byte("text/plain "), acceptsOfferType, "text/plain")) require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.4 "), acceptsOfferType, "text/plain")) require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.4 ;"), acceptsOfferType, "text/plain")) require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.4 ; p=foo"), acceptsOfferType, "text/plain")) require.Equal(t, "text/plain;b=2;a=1", getOffer([]byte("text/plain ;a=1;b=2"), acceptsOfferType, "text/plain;b=2;a=1")) 
require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain; a=1 "), acceptsOfferType, "text/plain;a=1")) require.Equal(t, `text/plain;a="1;b=2\",text/plain"`, getOffer([]byte(`text/plain;a="1;b=2\",text/plain";q=0.9`), acceptsOfferType, `text/plain;a=1;b=2`, `text/plain;a="1;b=2\",text/plain"`)) require.Equal(t, "text/plain;A=CAPS", getOffer([]byte(`text/plain;a="caPs"`), acceptsOfferType, "text/plain;A=CAPS")) // Priority require.Equal(t, "text/plain", getOffer([]byte("text/plain"), acceptsOfferType, "text/plain", "text/plain;a=1")) require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain"), acceptsOfferType, "text/plain;a=1", "", "text/plain")) require.Equal(t, "text/plain;a=1", getOffer([]byte("text/plain,text/plain;a=1"), acceptsOfferType, "text/plain", "text/plain;a=1")) require.Equal(t, "text/plain", getOffer([]byte("text/plain;q=0.899,text/plain;a=1;q=0.898"), acceptsOfferType, "text/plain", "text/plain;a=1")) require.Equal(t, "text/plain;a=1;b=2", getOffer([]byte("text/plain,text/plain;a=1,text/plain;a=1;b=2"), acceptsOfferType, "text/plain", "text/plain;a=1", "text/plain;a=1;b=2")) // Takes the last value specified require.Equal(t, "text/plain;a=1;b=2", getOffer([]byte("text/plain;a=1;b=1;B=2"), acceptsOfferType, "text/plain;a=1;b=1", "text/plain;a=1;b=2")) require.Empty(t, getOffer([]byte("utf-8, iso-8859-1;q=0.5"), acceptsOffer)) require.Empty(t, getOffer([]byte("utf-8, iso-8859-1;q=0.5"), acceptsOffer, "ascii")) require.Equal(t, "utf-8", getOffer([]byte("utf-8, iso-8859-1;q=0.5"), acceptsOffer, "utf-8")) require.Equal(t, "iso-8859-1", getOffer([]byte("utf-8;q=0, iso-8859-1;q=0.5"), acceptsOffer, "utf-8", "iso-8859-1")) // Accept-Charset wildcard coverage require.Equal(t, "utf-8", getOffer([]byte("utf-*"), acceptsOffer, "utf-8")) require.Equal(t, "UTF-16", getOffer([]byte("utf-*"), acceptsOffer, "UTF-16", "iso-8859-1")) require.Empty(t, getOffer([]byte("utf-*"), acceptsOffer, "iso-8859-1")) require.Empty(t, getOffer([]byte("utf-*"), 
acceptsOffer, "utf")) require.Empty(t, getOffer([]byte("utf-*"), acceptsOffer, "x-utf-8")) // Complex wildcard negotiation require.Equal(t, "utf-16le", getOffer([]byte("utf-8;q=0.4, utf-*;q=0.8, iso-8859-1;q=0.6"), acceptsOffer, "iso-8859-1", "utf-16le")) require.Equal(t, "iso-8859-1", getOffer([]byte("utf-*;q=0.9, iso-8859-1;q=1"), acceptsOffer, "x-utf-16", "iso-8859-1")) require.Empty(t, getOffer([]byte("utf-*;q=0.5, iso-8859-1;q=0.4"), acceptsOffer, "ascii", "us-ascii")) require.Equal(t, "deflate", getOffer([]byte("gzip, deflate"), acceptsOffer, "deflate")) require.Empty(t, getOffer([]byte("gzip, deflate;q=0"), acceptsOffer, "deflate")) // Accept-Language Basic Filtering require.True(t, acceptsLanguageOfferBasic("en", "en-US", nil)) require.False(t, acceptsLanguageOfferBasic("en-US", "en", nil)) require.True(t, acceptsLanguageOfferBasic("EN", "en-us", nil)) require.False(t, acceptsLanguageOfferBasic("en", "en_US", nil)) require.Equal(t, "en-US", getOffer([]byte("fr-CA;q=0.8, en-US"), acceptsLanguageOfferBasic, "en-US", "fr-CA")) require.Empty(t, getOffer([]byte("xx"), acceptsLanguageOfferBasic, "en")) require.False(t, acceptsLanguageOfferBasic("en-*", "en-US", nil)) require.True(t, acceptsLanguageOfferBasic("*", "en-US", nil)) // Accept-Language Extended Filtering require.True(t, acceptsLanguageOfferExtended("en", "en-US", nil)) require.True(t, acceptsLanguageOfferExtended("en", "en-Latn-US", nil)) require.True(t, acceptsLanguageOfferExtended("en-*", "en-US", nil)) require.True(t, acceptsLanguageOfferExtended("*-US", "en-US", nil)) require.True(t, acceptsLanguageOfferExtended("en-US-*", "en-US", nil)) require.True(t, acceptsLanguageOfferExtended("en-*", "en-US-CA", nil)) require.False(t, acceptsLanguageOfferExtended("en-US", "en-GB", nil)) require.False(t, acceptsLanguageOfferExtended("fr", "en-US", nil)) require.False(t, acceptsLanguageOfferExtended("", "en-US", nil)) require.False(t, acceptsLanguageOfferExtended("en", "", nil)) require.True(t, 
acceptsLanguageOfferExtended("*", "en-US", nil)) require.True(t, acceptsLanguageOfferExtended("en-*", "en", nil)) require.Equal(t, "en-US", getOffer([]byte("fr-CA;q=0.8, en-*"), acceptsLanguageOfferExtended, "en-US", "fr-CA")) // Sliding and singleton barriers require.True(t, acceptsLanguageOfferExtended("de-*-DE", "de-DE", nil)) require.True(t, acceptsLanguageOfferExtended("de-*-DE", "de-DE-x-goethe", nil)) require.True(t, acceptsLanguageOfferExtended("de-*-DE", "de-Latn-DE-1996", nil)) require.False(t, acceptsLanguageOfferExtended("de-*-DE", "de", nil)) require.False(t, acceptsLanguageOfferExtended("de-*-DE", "de-x-DE", nil)) require.True(t, acceptsLanguageOfferExtended("*-CH", "de-CH", nil)) require.True(t, acceptsLanguageOfferExtended("*-CH", "de-Latn-CH", nil)) } func Test_ReadContentReturnsBytes(t *testing.T) { t.Parallel() content := []byte("fiber read content test") tempFile, err := os.CreateTemp("", "fiber-read-content-*.txt") require.NoError(t, err) t.Cleanup(func() { require.NoError(t, os.Remove(tempFile.Name())) }) _, err = tempFile.Write(content) require.NoError(t, err) require.NoError(t, tempFile.Close()) var buffer bytes.Buffer n, err := readContent(&buffer, tempFile.Name()) require.NoError(t, err) require.Equal(t, int64(len(content)), n) require.Equal(t, content, buffer.Bytes()) } type wrappedListener struct { net.Listener } type tlsConfigMethodListener struct { net.Listener cfg *tls.Config } func (ln *tlsConfigMethodListener) TLSConfig() *tls.Config { return ln.cfg } type configMethodListener struct { net.Listener cfg *tls.Config } func (ln *configMethodListener) Config() *tls.Config { return ln.cfg } func Test_GetTLSConfig(t *testing.T) { t.Parallel() t.Run("tls listener", func(t *testing.T) { t.Parallel() base := newLocalListener(t) cfg := &tls.Config{MinVersion: tls.VersionTLS12} tlsListener := tls.NewListener(base, cfg) t.Cleanup(func() { require.NoError(t, tlsListener.Close()) }) require.Same(t, cfg, getTLSConfig(tlsListener), "*tls.Listener 
should expose its TLS config") }) t.Run("wrapped tls listener", func(t *testing.T) { t.Parallel() base := newLocalListener(t) cfg := &tls.Config{MinVersion: tls.VersionTLS13} tlsListener := tls.NewListener(base, cfg) wrapped := &wrappedListener{Listener: tlsListener} t.Cleanup(func() { require.NoError(t, wrapped.Close()) }) require.Nil(t, getTLSConfig(wrapped), "wrapping without Config()-like methods should return nil") }) t.Run("listener with tls config method", func(t *testing.T) { t.Parallel() base := newLocalListener(t) cfg := &tls.Config{MinVersion: tls.VersionTLS13} listener := &tlsConfigMethodListener{Listener: base, cfg: cfg} t.Cleanup(func() { require.NoError(t, listener.Close()) }) require.Same(t, cfg, getTLSConfig(listener), "TLSConfig() should be preferred for TLS discovery") }) t.Run("listener with config method", func(t *testing.T) { t.Parallel() base := newLocalListener(t) cfg := &tls.Config{MinVersion: tls.VersionTLS12} listener := &configMethodListener{Listener: base, cfg: cfg} t.Cleanup(func() { require.NoError(t, listener.Close()) }) require.Same(t, cfg, getTLSConfig(listener), "Config() should be preferred for TLS discovery") }) t.Run("non tls listener", func(t *testing.T) { t.Parallel() base := newLocalListener(t) t.Cleanup(func() { require.NoError(t, base.Close()) }) require.Nil(t, getTLSConfig(base), "plain listeners should not report TLS config") }) } func newLocalListener(t *testing.T) net.Listener { t.Helper() ln, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) return ln } // go test -v -run=^$ -bench=Benchmark_Utils_GetOffer -benchmem -count=4 func Benchmark_Utils_GetOffer(b *testing.B) { testCases := []struct { description string accept string offers []string }{ { description: "simple", accept: "application/json", offers: []string{"application/json"}, }, { description: "6 offers", accept: "text/plain", offers: []string{"junk/a", "junk/b", "junk/c", "junk/d", "junk/e", "text/plain"}, }, { description: "1 parameter", 
accept: "application/json; version=1", offers: []string{"application/json;version=1"}, }, { description: "2 parameters", accept: "application/json; version=1; foo=bar", offers: []string{"application/json;version=1;foo=bar"}, }, { description: "3 parameters", accept: "application/json; version=1; foo=bar; charset=utf-8", offers: []string{"application/json;version=1;foo=bar;charset=utf-8"}, }, { description: "10 parameters", accept: "text/plain;a=1;b=2;c=3;d=4;e=5;f=6;g=7;h=8;i=9;j=10", offers: []string{"text/plain;a=1;b=2;c=3;d=4;e=5;f=6;g=7;h=8;i=9;j=10"}, }, { description: "6 offers w/params", accept: "text/plain; format=flowed", offers: []string{ "junk/a;a=b", "junk/b;b=c", "junk/c;c=d", "text/plain; format=justified", "text/plain; format=flat", "text/plain; format=flowed", }, }, { description: "mime extension", accept: "utf-8, iso-8859-1;q=0.5", offers: []string{"utf-8"}, }, { description: "mime extension", accept: "utf-8, iso-8859-1;q=0.5", offers: []string{"iso-8859-1"}, }, { description: "mime extension", accept: "utf-8, iso-8859-1;q=0.5", offers: []string{"iso-8859-1", "utf-8"}, }, { description: "mime extension", accept: "gzip, deflate", offers: []string{"deflate"}, }, { description: "web browser", accept: "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8", offers: []string{"text/html", "application/xml", "application/xml+xhtml"}, }, } b.ReportAllocs() for _, tc := range testCases { accept := []byte(tc.accept) b.Run(tc.description, func(b *testing.B) { for b.Loop() { getOffer(accept, acceptsOfferType, tc.offers...) 
} }) } } func Test_Utils_ParamsMatch(t *testing.T) { testCases := []struct { description string accept headerParams offer string match bool }{ { description: "empty accept and offer", accept: nil, offer: "", match: true, }, { description: "accept is empty, offer has params", accept: make(headerParams), offer: ";foo=bar", match: true, }, { description: "offer is empty, accept has params", accept: headerParams{"foo": []byte("bar")}, offer: "", match: false, }, { description: "accept has extra parameters", accept: headerParams{"foo": []byte("bar"), "a": []byte("1")}, offer: ";foo=bar", match: false, }, { description: "matches regardless of order", accept: headerParams{"b": []byte("2"), "a": []byte("1")}, offer: ";b=2;a=1", match: true, }, { description: "case-insensitive", accept: headerParams{"ParaM": []byte("FoO")}, offer: ";pAram=foO", match: true, }, } for _, tc := range testCases { require.Equal(t, tc.match, paramsMatch(tc.accept, tc.offer), tc.description) } } func Benchmark_Utils_ParamsMatch(b *testing.B) { var match bool specParams := headerParams{ "appLe": []byte("orange"), "param": []byte("foo"), } b.ReportAllocs() for b.Loop() { match = paramsMatch(specParams, `;param=foo; apple=orange`) } require.True(b, match) } func Test_Utils_AcceptsOfferType(t *testing.T) { t.Parallel() testCases := []struct { description string spec string specParams headerParams offerType string accepts bool }{ { description: "no params, matching", spec: "application/json", offerType: "application/json", accepts: true, }, { description: "no params, mismatch", spec: "application/json", offerType: "application/xml", accepts: false, }, { description: "mismatch with subtype prefix", spec: "application/json", offerType: "application/json+xml", accepts: false, }, { description: "params match", spec: "application/json", specParams: headerParams{"format": []byte("foo"), "version": []byte("1")}, offerType: "application/json;version=1;format=foo;q=0.1", accepts: true, }, { description: "spec 
has extra params", spec: "text/html", specParams: headerParams{"charset": []byte("utf-8")}, offerType: "text/html", accepts: false, }, { description: "offer has extra params", spec: "text/html", offerType: "text/html;charset=utf-8", accepts: true, }, { description: "ignores optional whitespace", spec: "application/json", specParams: headerParams{"format": []byte("foo"), "version": []byte("1")}, offerType: "application/json; version=1 ; format=foo ", accepts: true, }, { description: "ignores optional whitespace", spec: "application/json", specParams: headerParams{"format": []byte("foo bar"), "version": []byte("1")}, offerType: `application/json;version="1";format="foo bar"`, accepts: true, }, } for _, tc := range testCases { accepts := acceptsOfferType(tc.spec, tc.offerType, tc.specParams) require.Equal(t, tc.accepts, accepts, tc.description) } } func Test_Utils_GetSplicedStrList(t *testing.T) { t.Parallel() testCases := []struct { description string headerValue string expectedList []string }{ { description: "normal case", headerValue: "gzip, deflate,br", expectedList: []string{"gzip", "deflate", "br"}, }, { description: "no matter the value", headerValue: " gzip,deflate, br, zip", expectedList: []string{"gzip", "deflate", "br", "zip"}, }, { description: "comma with trailing spaces around values", headerValue: "gzip , br", expectedList: []string{"gzip", "br"}, }, { description: "comma with tabbed whitespace", headerValue: "gzip\t,br", expectedList: []string{"gzip", "br"}, }, { description: "headerValue is empty", headerValue: "", expectedList: nil, }, { description: "has a comma without element", headerValue: "gzip,", expectedList: []string{"gzip", ""}, }, { description: "has a space between words", headerValue: " foo bar, hello world", expectedList: []string{"foo bar", "hello world"}, }, { description: "single comma", headerValue: ",", expectedList: []string{"", ""}, }, { description: "multiple comma", headerValue: ",,", expectedList: []string{"", "", ""}, }, { 
description: "comma with space", headerValue: ", ,", expectedList: []string{"", "", ""}, }, } for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { tc := tc // create a new 'tc' variable for the goroutine t.Parallel() dst := make([]string, 10) result := getSplicedStrList(tc.headerValue, dst) require.Equal(t, tc.expectedList, result) }) } } func Benchmark_Utils_GetSplicedStrList(b *testing.B) { destination := make([]string, 5) result := destination const input = `deflate, gzip,br,brotli,zstd` b.ReportAllocs() for b.Loop() { result = getSplicedStrList(input, destination) } require.Equal(b, []string{"deflate", "gzip", "br", "brotli", "zstd"}, result) } func Test_Utils_SortAcceptedTypes(t *testing.T) { t.Parallel() acceptedTypes := []acceptedType{ {spec: "text/html", quality: 1, specificity: 3, order: 0}, {spec: "text/*", quality: 0.5, specificity: 2, order: 1}, {spec: "*/*", quality: 0.1, specificity: 1, order: 2}, {spec: "application/xml", quality: 1, specificity: 3, order: 4}, {spec: "application/pdf", quality: 1, specificity: 3, order: 5}, {spec: "image/png", quality: 1, specificity: 3, order: 6}, {spec: "image/jpeg", quality: 1, specificity: 3, order: 7}, {spec: "image/*", quality: 1, specificity: 2, order: 8}, {spec: "image/gif", quality: 1, specificity: 3, order: 9}, {spec: "text/plain", quality: 1, specificity: 3, order: 10}, {spec: "application/json", quality: 0.999, specificity: 3, params: headerParams{"a": []byte("1")}, order: 11}, {spec: "application/json", quality: 0.999, specificity: 3, order: 3}, } sortAcceptedTypes(acceptedTypes) require.Equal(t, []acceptedType{ {spec: "text/html", quality: 1, specificity: 3, order: 0}, {spec: "application/xml", quality: 1, specificity: 3, order: 4}, {spec: "application/pdf", quality: 1, specificity: 3, order: 5}, {spec: "image/png", quality: 1, specificity: 3, order: 6}, {spec: "image/jpeg", quality: 1, specificity: 3, order: 7}, {spec: "image/gif", quality: 1, specificity: 3, order: 9}, {spec: 
"text/plain", quality: 1, specificity: 3, order: 10}, {spec: "image/*", quality: 1, specificity: 2, order: 8}, {spec: "application/json", quality: 0.999, specificity: 3, params: headerParams{"a": []byte("1")}, order: 11}, {spec: "application/json", quality: 0.999, specificity: 3, order: 3}, {spec: "text/*", quality: 0.5, specificity: 2, order: 1}, {spec: "*/*", quality: 0.1, specificity: 1, order: 2}, }, acceptedTypes) } // go test -v -run=^$ -bench=Benchmark_Utils_SortAcceptedTypes_Sorted -benchmem -count=4 func Benchmark_Utils_SortAcceptedTypes_Sorted(b *testing.B) { acceptedTypes := make([]acceptedType, 3) b.ReportAllocs() for b.Loop() { acceptedTypes[0] = acceptedType{spec: "text/html", quality: 1, specificity: 1, order: 0} acceptedTypes[1] = acceptedType{spec: "text/*", quality: 0.5, specificity: 1, order: 1} acceptedTypes[2] = acceptedType{spec: "*/*", quality: 0.1, specificity: 1, order: 2} sortAcceptedTypes(acceptedTypes) } require.Equal(b, "text/html", acceptedTypes[0].spec) require.Equal(b, "text/*", acceptedTypes[1].spec) require.Equal(b, "*/*", acceptedTypes[2].spec) } // go test -v -run=^$ -bench=Benchmark_Utils_SortAcceptedTypes_Unsorted -benchmem -count=4 func Benchmark_Utils_SortAcceptedTypes_Unsorted(b *testing.B) { acceptedTypes := make([]acceptedType, 11) b.ReportAllocs() for b.Loop() { acceptedTypes[0] = acceptedType{spec: "text/html", quality: 1, specificity: 3, order: 0} acceptedTypes[1] = acceptedType{spec: "text/*", quality: 0.5, specificity: 2, order: 1} acceptedTypes[2] = acceptedType{spec: "*/*", quality: 0.1, specificity: 1, order: 2} acceptedTypes[3] = acceptedType{spec: "application/json", quality: 0.999, specificity: 3, order: 3} acceptedTypes[4] = acceptedType{spec: "application/xml", quality: 1, specificity: 3, order: 4} acceptedTypes[5] = acceptedType{spec: "application/pdf", quality: 1, specificity: 3, order: 5} acceptedTypes[6] = acceptedType{spec: "image/png", quality: 1, specificity: 3, order: 6} acceptedTypes[7] = 
acceptedType{spec: "image/jpeg", quality: 1, specificity: 3, order: 7} acceptedTypes[8] = acceptedType{spec: "image/*", quality: 1, specificity: 2, order: 8} acceptedTypes[9] = acceptedType{spec: "image/gif", quality: 1, specificity: 3, order: 9} acceptedTypes[10] = acceptedType{spec: "text/plain", quality: 1, specificity: 3, order: 10} sortAcceptedTypes(acceptedTypes) } require.Equal(b, []acceptedType{ {spec: "text/html", quality: 1, specificity: 3, order: 0}, {spec: "application/xml", quality: 1, specificity: 3, order: 4}, {spec: "application/pdf", quality: 1, specificity: 3, order: 5}, {spec: "image/png", quality: 1, specificity: 3, order: 6}, {spec: "image/jpeg", quality: 1, specificity: 3, order: 7}, {spec: "image/gif", quality: 1, specificity: 3, order: 9}, {spec: "text/plain", quality: 1, specificity: 3, order: 10}, {spec: "image/*", quality: 1, specificity: 2, order: 8}, {spec: "application/json", quality: 0.999, specificity: 3, order: 3}, {spec: "text/*", quality: 0.5, specificity: 2, order: 1}, {spec: "*/*", quality: 0.1, specificity: 1, order: 2}, }, acceptedTypes) } func Test_Utils_UniqueRouteStack(t *testing.T) { t.Parallel() route1 := &Route{} route2 := &Route{} route3 := &Route{} require.Equal( t, []*Route{ route1, route2, route3, }, uniqueRouteStack([]*Route{ route1, route1, route1, route2, route2, route2, route3, route3, route3, route1, route2, route3, })) } func Test_Utils_getGroupPath(t *testing.T) { t.Parallel() res := getGroupPath("/v1", "/") require.Equal(t, "/v1/", res) res = getGroupPath("/v1/", "/") require.Equal(t, "/v1/", res) res = getGroupPath("/", "/") require.Equal(t, "/", res) res = getGroupPath("/v1/api/", "/") require.Equal(t, "/v1/api/", res) res = getGroupPath("/v1/api", "group") require.Equal(t, "/v1/api/group", res) res = getGroupPath("/v1/api", "") require.Equal(t, "/v1/api", res) } // go test -v -run=^$ -bench=Benchmark_Utils_ -benchmem -count=3 func Benchmark_Utils_getGroupPath(b *testing.B) { var res string b.ReportAllocs() 
for b.Loop() { _ = getGroupPath("/v1/long/path/john/doe", "/why/this/name/is/so/awesome") _ = getGroupPath("/v1", "/") _ = getGroupPath("/v1", "/api") res = getGroupPath("/v1", "/api/register/:project") } require.Equal(b, "/v1/api/register/:project", res) } func Benchmark_Utils_Unescape(b *testing.B) { unescaped := "" dst := make([]byte, 0) b.ReportAllocs() for b.Loop() { source := "/cr%C3%A9er" pathBytes := utils.UnsafeBytes(source) pathBytes = fasthttp.AppendUnquotedArg(dst[:0], pathBytes) unescaped = utils.UnsafeString(pathBytes) } require.Equal(b, "/créer", unescaped) } func Test_Utils_Parse_Address(t *testing.T) { t.Parallel() testCases := []struct { addr, host, port string }{ {addr: "[::1]:3000", host: "[::1]", port: "3000"}, {addr: "127.0.0.1:3000", host: "127.0.0.1", port: "3000"}, {addr: "[::1]", host: "[::1]", port: ""}, {addr: "2001:db8::1", host: "2001:db8::1", port: ""}, {addr: "/path/to/unix/socket", host: "/path/to/unix/socket", port: ""}, {addr: "127.0.0.1", host: "127.0.0.1", port: ""}, {addr: "localhost:8080", host: "localhost", port: "8080"}, {addr: "example.com", host: "example.com", port: ""}, {addr: "[fe80::1%lo0]:1234", host: "[fe80::1%lo0]", port: "1234"}, {addr: "[fe80::1%lo0]", host: "[fe80::1%lo0]", port: ""}, {addr: ":9090", host: "", port: "9090"}, {addr: " 127.0.0.1:8080 ", host: "127.0.0.1", port: "8080"}, {addr: "", host: "", port: ""}, } for _, c := range testCases { host, port := parseAddr(c.addr) require.Equal(t, c.host, host, "addr host: %q", c.addr) require.Equal(t, c.port, port, "addr port: %q", c.addr) } } func Test_Utils_TestConn_Deadline(t *testing.T) { t.Parallel() conn := &testConn{} require.NoError(t, conn.SetDeadline(time.Time{})) require.NoError(t, conn.SetReadDeadline(time.Time{})) require.NoError(t, conn.SetWriteDeadline(time.Time{})) } func Test_Utils_TestConn_ReadWrite(t *testing.T) { t.Parallel() conn := &testConn{} // Verify read of request _, err := conn.r.Write([]byte("Request")) require.NoError(t, err) req := 
make([]byte, 7) _, err = conn.Read(req) require.NoError(t, err) require.Equal(t, []byte("Request"), req) // Verify write of response _, err = conn.Write([]byte("Response")) require.NoError(t, err) res := make([]byte, 8) _, err = conn.w.Read(res) require.NoError(t, err) require.Equal(t, []byte("Response"), res) } func Test_Utils_TestConn_Closed_Write(t *testing.T) { t.Parallel() conn := &testConn{} // Verify write of response _, err := conn.Write([]byte("Response 1\n")) require.NoError(t, err) // Close early, write should fail conn.Close() //nolint:errcheck // It is fine to ignore the error here _, err = conn.Write([]byte("Response 2\n")) require.ErrorIs(t, err, errTestConnClosed) res := make([]byte, 11) _, err = conn.w.Read(res) require.NoError(t, err) require.Equal(t, []byte("Response 1\n"), res) } func Test_Utils_IsNoCache(t *testing.T) { t.Parallel() testCases := []struct { string bool }{ {string: "public", bool: false}, {string: "no-cache", bool: true}, {string: "public, no-cache, max-age=30", bool: true}, {string: "public,no-cache", bool: true}, {string: "public,no-cacheX", bool: false}, {string: "no-cache, public", bool: true}, {string: "Xno-cache, public", bool: false}, {string: "max-age=30, no-cache,public", bool: true}, {string: "NO-CACHE", bool: true}, {string: "public, NO-CACHE", bool: true}, // RFC 9111 §5.2.2.4: no-cache with field-name argument {string: "no-cache=\"Set-Cookie\"", bool: true}, {string: "public, no-cache=\"Set-Cookie, Set-Cookie2\"", bool: true}, {string: "no-cache=Set-Cookie", bool: true}, // Edge cases with spaces {string: "no-cache ,public", bool: true}, {string: "public, no-cache =field", bool: true}, } for _, c := range testCases { ok := isNoCache(c.string) require.Equal(t, c.bool, ok, "want %t, got isNoCache(%s)=%t", c.bool, c.string, ok) } } // go test -v -run=^$ -bench=Benchmark_Utils_IsNoCache -benchmem -count=4 func Benchmark_Utils_IsNoCache(b *testing.B) { var ok bool b.ReportAllocs() for b.Loop() { _ = isNoCache("public") _ 
= isNoCache("no-cache") _ = isNoCache("public, no-cache, max-age=30") _ = isNoCache("public,no-cache") _ = isNoCache("no-cache, public") ok = isNoCache("max-age=30, no-cache,public") } require.True(b, ok) } // go test -run Test_HeaderContainsValue func Test_HeaderContainsValue(t *testing.T) { t.Parallel() testCases := []struct { header string value string expected bool }{ // Exact match {header: "gzip", value: "gzip", expected: true}, {header: "gzip", value: "deflate", expected: false}, // Prefix match (value at start with comma) {header: "gzip, deflate", value: "gzip", expected: true}, {header: "gzip,deflate", value: "gzip", expected: true}, // Suffix match (value at end) {header: "deflate, gzip", value: "gzip", expected: true}, {header: "deflate,gzip", value: "gzip", expected: true}, // No space - OWS is optional per RFC 9110 {header: "br, gzip", value: "gzip", expected: true}, // Middle match (value in middle) {header: "deflate, gzip, br", value: "gzip", expected: true}, {header: "deflate,gzip,br", value: "gzip", expected: true}, // No spaces - OWS is optional per RFC 9110 // No match - similar but not equal {header: "gzip2", value: "gzip", expected: false}, {header: "2gzip", value: "gzip", expected: false}, {header: "gzip2, deflate", value: "gzip", expected: false}, // Whitespace handling (OWS per RFC 9110) {header: " gzip , deflate ", value: "gzip", expected: true}, {header: "deflate, gzip ", value: "gzip", expected: true}, // Empty cases {header: "", value: "gzip", expected: false}, {header: "gzip", value: "", expected: false}, {header: "", value: "", expected: false}, // Both empty - should return false } for _, tc := range testCases { result := headerContainsValue(tc.header, tc.value) require.Equal(t, tc.expected, result, "headerContainsValue(%q, %q) = %v, want %v", tc.header, tc.value, result, tc.expected) } } // go test -v -run=^$ -bench=Benchmark_HeaderContainsValue -benchmem -count=4 func Benchmark_HeaderContainsValue(b *testing.B) { var ok bool 
b.ReportAllocs() for b.Loop() { _ = headerContainsValue("gzip", "gzip") _ = headerContainsValue("gzip, deflate, br", "deflate") _ = headerContainsValue("deflate, gzip", "gzip") ok = headerContainsValue("deflate, gzip, br", "gzip") } require.True(b, ok) } type testGenericParseTypeIntCase struct { value int64 bits int } // go test -run Test_GenericParseTypeInts func Test_GenericParseTypeInts(t *testing.T) { t.Parallel() ints := []testGenericParseTypeIntCase{ { value: 0, bits: 8, }, { value: 1, bits: 8, }, { value: 2, bits: 8, }, { value: 3, bits: 8, }, { value: 4, bits: 8, }, { value: -1, bits: 8, }, { value: math.MaxInt8, bits: 8, }, { value: math.MinInt8, bits: 8, }, { value: math.MaxInt16, bits: 16, }, { value: math.MinInt16, bits: 16, }, { value: math.MaxInt32, bits: 32, }, { value: math.MinInt32, bits: 32, }, { value: math.MaxInt64, bits: 64, }, { value: math.MinInt64, bits: 64, }, } testGenericTypeInt[int8](t, "test_genericParseTypeInt8s", ints) testGenericTypeInt[int16](t, "test_genericParseTypeInt16s", ints) testGenericTypeInt[int32](t, "test_genericParseTypeInt32s", ints) testGenericTypeInt[int64](t, "test_genericParseTypeInt64s", ints) testGenericTypeInt[int](t, "test_genericParseTypeInts", ints) } func testGenericTypeInt[V GenericTypeInteger](t *testing.T, name string, cases []testGenericParseTypeIntCase) { t.Helper() t.Run(name, func(t *testing.T) { t.Parallel() for _, test := range cases { v, err := genericParseType[V](strconv.FormatInt(test.value, 10)) if test.bits <= int(unsafe.Sizeof(V(0)))*8 { require.NoError(t, err) require.Equal(t, V(test.value), v) } else { require.ErrorIs(t, err, strconv.ErrRange) } } testGenericParseError[V](t) }) } type testGenericParseTypeUintCase struct { value uint64 bits int } // go test -run Test_GenericParseTypeUints func Test_GenericParseTypeUints(t *testing.T) { t.Parallel() uints := []testGenericParseTypeUintCase{ { value: 0, bits: 8, }, { value: 1, bits: 8, }, { value: 2, bits: 8, }, { value: 3, bits: 8, }, { value: 
4, bits: 8, }, { value: math.MaxUint8, bits: 8, }, { value: math.MaxUint16, bits: 16, }, { value: math.MaxUint32, bits: 32, }, { value: math.MaxUint64, bits: 64, }, } testGenericTypeUint[uint8](t, "test_genericParseTypeUint8s", uints) testGenericTypeUint[uint16](t, "test_genericParseTypeUint16s", uints) testGenericTypeUint[uint32](t, "test_genericParseTypeUint32s", uints) testGenericTypeUint[uint64](t, "test_genericParseTypeUint64s", uints) testGenericTypeUint[uint](t, "test_genericParseTypeUints", uints) } func testGenericTypeUint[V GenericTypeInteger](t *testing.T, name string, cases []testGenericParseTypeUintCase) { t.Helper() t.Run(name, func(t *testing.T) { t.Parallel() for _, test := range cases { v, err := genericParseType[V](strconv.FormatUint(test.value, 10)) if test.bits <= int(unsafe.Sizeof(V(0)))*8 { require.NoError(t, err) require.Equal(t, V(test.value), v) } else { require.ErrorIs(t, err, strconv.ErrRange) } } testGenericParseError[V](t) }) } // go test -run Test_GenericParseTypeFloats func Test_GenericParseTypeFloats(t *testing.T) { t.Parallel() floats := []struct { str string value float64 }{ { value: 3.1415, str: "3.1415", }, { value: 1.234, str: "1.234", }, { value: 2, str: "2", }, { value: 3, str: "3", }, } t.Run("test_genericParseTypeFloat32s", func(t *testing.T) { t.Parallel() for _, test := range floats { v, err := genericParseType[float32](test.str) require.NoError(t, err) require.InEpsilon(t, float32(test.value), v, epsilon) } testGenericParseError[float32](t) }) t.Run("test_genericParseTypeFloat64s", func(t *testing.T) { t.Parallel() for _, test := range floats { v, err := genericParseType[float64](test.str) require.NoError(t, err) require.InEpsilon(t, test.value, v, epsilon) } testGenericParseError[float64](t) }) } // go test -run Test_GenericParseTypeBytes func Test_GenericParseTypeBytes(t *testing.T) { t.Parallel() cases := []struct { str string err error value []byte }{ { value: []byte("alex"), str: "alex", }, { value: []byte("32.23"), 
str: "32.23", }, { value: []byte("john"), str: "john", }, { value: []byte(nil), str: "", err: errParsedEmptyBytes, }, } t.Run("test_genericParseTypeBytes", func(t *testing.T) { t.Parallel() for _, test := range cases { v, err := genericParseType[[]byte](test.str) if test.err == nil { require.NoError(t, err) } else { require.ErrorIs(t, err, test.err) } require.Equal(t, test.value, v) } }) } // go test -run Test_GenericParseTypeString func Test_GenericParseTypeString(t *testing.T) { t.Parallel() tests := []string{"john", "doe", "hello", "fiber"} for _, test := range tests { t.Run("test_genericParseTypeString", func(t *testing.T) { t.Parallel() v, err := genericParseType[string](test) require.NoError(t, err) require.Equal(t, test, v) }) } } // go test -run Test_GenericParseTypeBoolean func Test_GenericParseTypeBoolean(t *testing.T) { t.Parallel() bools := []struct { str string value bool }{ { str: "True", value: true, }, { str: "False", value: false, }, { str: "true", value: true, }, { str: "false", value: false, }, } t.Run("test_genericParseTypeBoolean", func(t *testing.T) { t.Parallel() for _, test := range bools { v, err := genericParseType[bool](test.str) require.NoError(t, err) if test.value { require.True(t, v) } else { require.False(t, v) } } testGenericParseError[bool](t) }) } func testGenericParseError[V GenericType](t *testing.T) { t.Helper() var expected V v, err := genericParseType[V]("invalid-string") require.Error(t, err) require.Equal(t, expected, v) } // go test -v -run=^$ -bench=Benchmark_GenericParseTypeInts -benchmem -count=4 func Benchmark_GenericParseTypeInts(b *testing.B) { b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)") ints := []testGenericParseTypeIntCase{ { value: 0, bits: 8, }, { value: 1, bits: 8, }, { value: 2, bits: 8, }, { value: 3, bits: 8, }, { value: 4, bits: 8, }, { value: -1, bits: 8, }, { value: math.MaxInt8, bits: 8, }, { value: math.MinInt8, bits: 8, }, { value: math.MaxInt16, bits: 16, }, { 
value: math.MinInt16, bits: 16, }, { value: math.MaxInt32, bits: 32, }, { value: math.MinInt32, bits: 32, }, { value: math.MaxInt64, bits: 64, }, { value: math.MinInt64, bits: 64, }, } for _, test := range ints { benchGenericParseTypeInt[int8](b, "bench_genericParseTypeInt8s", test) benchGenericParseTypeInt[int16](b, "bench_genericParseTypeInt16s", test) benchGenericParseTypeInt[int32](b, "bench_genericParseTypeInt32s", test) benchGenericParseTypeInt[int64](b, "bench_genericParseTypeInt64s", test) benchGenericParseTypeInt[int](b, "bench_genericParseTypeInts", test) } } func benchGenericParseTypeInt[V GenericTypeInteger](b *testing.B, name string, test testGenericParseTypeIntCase) { b.Helper() b.Run(name, func(t *testing.B) { var v V var err error b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { v, err = genericParseType[V](strconv.FormatInt(test.value, 10)) } }) if test.bits <= int(unsafe.Sizeof(V(0)))*8 { require.NoError(t, err) require.Equal(t, V(test.value), v) } else { require.ErrorIs(t, err, strconv.ErrRange) } }) } // go test -v -run=^$ -bench=Benchmark_GenericParseTypeUints -benchmem -count=4 func Benchmark_GenericParseTypeUints(b *testing.B) { b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)") uints := []struct { value uint64 bits int }{ { value: 0, bits: 8, }, { value: 1, bits: 8, }, { value: 2, bits: 8, }, { value: 3, bits: 8, }, { value: 4, bits: 8, }, { value: math.MaxUint8, bits: 8, }, { value: math.MaxUint16, bits: 16, }, { value: math.MaxUint16, bits: 16, }, { value: math.MaxUint32, bits: 32, }, { value: math.MaxUint64, bits: 64, }, } for _, test := range uints { benchGenericParseTypeUInt[uint8](b, "benchmark_genericParseTypeUint8s", test) benchGenericParseTypeUInt[uint16](b, "benchmark_genericParseTypeUint16s", test) benchGenericParseTypeUInt[uint32](b, "benchmark_genericParseTypeUint32s", test) benchGenericParseTypeUInt[uint64](b, "benchmark_genericParseTypeUint64s", test) 
benchGenericParseTypeUInt[uint](b, "benchmark_genericParseTypeUints", test) } } func benchGenericParseTypeUInt[V GenericTypeInteger](b *testing.B, name string, test testGenericParseTypeUintCase) { b.Helper() b.Run(name, func(t *testing.B) { var v V var err error b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { v, err = genericParseType[V](strconv.FormatUint(test.value, 10)) } }) if test.bits <= int(unsafe.Sizeof(V(0)))*8 { require.NoError(t, err) require.Equal(t, V(test.value), v) } else { require.ErrorIs(t, err, strconv.ErrRange) } }) } // go test -v -run=^$ -bench=Benchmark_GenericParseTypeFloats -benchmem -count=4 func Benchmark_GenericParseTypeFloats(b *testing.B) { b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)") floats := []struct { str string value float64 }{ { value: 3.1415, str: "3.1415", }, { value: 1.234, str: "1.234", }, { value: 2, str: "2", }, { value: 3, str: "3", }, } for _, test := range floats { b.Run("benchmark_genericParseTypeFloat32s", func(t *testing.B) { var v float32 var err error b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { v, err = genericParseType[float32](test.str) } }) require.NoError(t, err) require.InEpsilon(t, float32(test.value), v, epsilon) }) } for _, test := range floats { b.Run("benchmark_genericParseTypeFloat64s", func(t *testing.B) { var v float64 var err error b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { v, err = genericParseType[float64](test.str) } }) require.NoError(t, err) require.InEpsilon(t, test.value, v, epsilon) }) } } // go test -v -run=^$ -bench=Benchmark_GenericParseTypeBytes -benchmem -count=4 func Benchmark_GenericParseTypeBytes(b *testing.B) { b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)") cases := []struct { str string err error value []byte }{ { value: []byte("alex"), str: "alex", }, { value: []byte("32.23"), str: 
"32.23", }, { value: []byte("john"), str: "john", }, { value: []byte(nil), str: "", err: errParsedEmptyBytes, }, } for _, test := range cases { b.Run("benchmark_genericParseTypeBytes", func(b *testing.B) { var v []byte var err error b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { v, err = genericParseType[[]byte](test.str) } }) if test.err == nil { require.NoError(b, err) } else { require.ErrorIs(b, err, test.err) } require.Equal(b, test.value, v) }) } } // go test -v -run=^$ -bench=Benchmark_GenericParseTypeString -benchmem -count=4 func Benchmark_GenericParseTypeString(b *testing.B) { b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)") tests := []string{"john", "doe", "hello", "fiber"} for _, test := range tests { b.Run("benchmark_genericParseTypeString", func(b *testing.B) { var v string var err error b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { v, err = genericParseType[string](test) } }) require.NoError(b, err) require.Equal(b, test, v) }) } } // go test -v -run=^$ -bench=Benchmark_GenericParseTypeBoolean -benchmem -count=4 func Benchmark_GenericParseTypeBoolean(b *testing.B) { b.Skip("Skipped: too fast to compare reliably (results in sub-ns range are unstable)") bools := []struct { str string value bool }{ { str: "True", value: true, }, { str: "False", value: false, }, { str: "true", value: true, }, { str: "false", value: false, }, } for _, test := range bools { b.Run("benchmark_genericParseTypeBoolean", func(b *testing.B) { var v bool var err error b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { v, err = genericParseType[bool](test.str) } }) require.NoError(b, err) if test.value { require.True(b, v) } else { require.False(b, v) } }) } } func Test_UnescapeHeaderValue(t *testing.T) { t.Parallel() cases := []struct { in string out []byte ok bool }{ {in: "abc", out: []byte("abc"), ok: true}, {in: "a\\\"b", out: 
[]byte("a\"b"), ok: true}, {in: "c\\\\d", out: []byte("c\\d"), ok: true}, {in: "bad\\", ok: false}, } for _, tc := range cases { out, err := unescapeHeaderValue([]byte(tc.in)) if tc.ok { require.NoError(t, err, tc.in) require.Equal(t, tc.out, out, tc.in) } else { require.Error(t, err, tc.in) } } } func Test_JoinHeaderValues(t *testing.T) { t.Parallel() require.Nil(t, joinHeaderValues(nil)) require.Equal(t, []byte("a"), joinHeaderValues([][]byte{[]byte("a")})) require.Equal(t, []byte("a,b"), joinHeaderValues([][]byte{[]byte("a"), []byte("b")})) } func Test_ParamsMatch_InvalidEscape(t *testing.T) { t.Parallel() match := paramsMatch(headerParams{"foo": []byte("bar")}, `;foo="bar\\`) require.False(t, match) } func Test_MatchEtag(t *testing.T) { t.Parallel() require.True(t, matchEtag(`"a"`, `"a"`)) require.True(t, matchEtag(`W/"a"`, `"a"`)) require.True(t, matchEtag(`"a"`, `W/"a"`)) require.False(t, matchEtag(`"a"`, `"b"`)) require.False(t, matchEtag(`a`, `"a"`)) require.False(t, matchEtag(`"a"`, `b`)) } func Test_MatchEtagStrong(t *testing.T) { t.Parallel() require.True(t, matchEtagStrong(`"a"`, `"a"`)) require.False(t, matchEtagStrong(`W/"a"`, `"a"`)) require.False(t, matchEtagStrong(`"a"`, `W/"a"`)) require.False(t, matchEtagStrong(`"a"`, `"b"`)) require.False(t, matchEtagStrong(`a`, `"a"`)) require.False(t, matchEtagStrong(`"a"`, `b`)) } func Test_IsEtagStale(t *testing.T) { t.Parallel() app := New() // Invalid/unquoted tags are considered a mismatch, so it's stale require.True(t, app.isEtagStale(`"a"`, []byte("b"))) require.True(t, app.isEtagStale(`"a"`, []byte("a"))) // Matching tags, not stale require.False(t, app.isEtagStale(`"a"`, []byte(`"a"`))) require.False(t, app.isEtagStale(`W/"a"`, []byte(`"a"`))) // List of tags, not stale require.False(t, app.isEtagStale(`"c"`, []byte(`"a", "b", "c"`))) require.False(t, app.isEtagStale(`W/"c"`, []byte(`"a", "b", "c"`))) require.False(t, app.isEtagStale(`"c"`, []byte(`"a", "b", W/"c"`))) require.False(t, 
app.isEtagStale(`"c"`, []byte(`"c", "b", "a"`))) require.False(t, app.isEtagStale(`"c"`, []byte(` "a", "c" , "b" `))) // List of tags, stale require.True(t, app.isEtagStale(`"d"`, []byte(`"a", "b", "c"`))) require.True(t, app.isEtagStale(`W/"d"`, []byte(`"a", "b", "c"`))) // Wildcard require.False(t, app.isEtagStale(`"a"`, []byte("*"))) require.False(t, app.isEtagStale(`"a"`, []byte(" * "))) require.False(t, app.isEtagStale(`W/"a"`, []byte("*"))) // Empty case require.True(t, app.isEtagStale(`"a"`, []byte(""))) require.True(t, app.isEtagStale(`"a"`, []byte(" "))) // Weak vs. weak require.False(t, app.isEtagStale(`W/"a"`, []byte(`W/"a"`))) } func Test_App_quoteRawString(t *testing.T) { t.Parallel() cases := []struct { name string in string out string }{ {"empty", "", ""}, {"simple", "simple", "simple"}, {"backslash", "A\\B", "A\\\\B"}, {"quote", `He said "Yo"`, `He said \"Yo\"`}, {"newline", "Hello\n", "Hello\\n"}, {"carriage", "Hello\r", "Hello\\r"}, {"controls", string([]byte{0, 31, 127}), "%00%1F%7F"}, {"mixed", "test \"A\n\r" + string([]byte{1}) + "\\", `test \"A\n\r%01\\`}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { t.Parallel() app := New() require.Equal(t, tc.out, app.quoteRawString(tc.in)) }) } } func TestStoreInContext(t *testing.T) { t.Parallel() app := New(Config{PassLocalsToContext: true}) raw := &fasthttp.RequestCtx{} c := app.AcquireCtx(raw) defer app.ReleaseCtx(c) StoreInContext(c, "key", "value") localValue, ok := c.Locals("key").(string) require.True(t, ok) require.Equal(t, "value", localValue) contextValue, ok := c.Context().Value("key").(string) require.True(t, ok) require.Equal(t, "value", contextValue) } func TestValueFromContext(t *testing.T) { t.Parallel() t.Run("fiber.Ctx", func(t *testing.T) { t.Parallel() app := New() raw := &fasthttp.RequestCtx{} c := app.AcquireCtx(raw) defer app.ReleaseCtx(c) c.Locals("key", "value") value, ok := ValueFromContext[string](c, "key") require.True(t, ok) require.Equal(t, "value", value) 
}) t.Run("fiber.CustomCtx", func(t *testing.T) { t.Parallel() app := NewWithCustomCtx(func(app *App) CustomCtx { return &customCtx{DefaultCtx: *NewDefaultCtx(app)} }) raw := &fasthttp.RequestCtx{} c := app.AcquireCtx(raw) defer app.ReleaseCtx(c) c.Locals("key", "value") value, ok := ValueFromContext[string](c, "key") require.True(t, ok) require.Equal(t, "value", value) }) t.Run("fasthttp request ctx", func(t *testing.T) { t.Parallel() raw := &fasthttp.RequestCtx{} raw.SetUserValue("key", "value") value, ok := ValueFromContext[string](raw, "key") require.True(t, ok) require.Equal(t, "value", value) }) t.Run("context.Context", func(t *testing.T) { t.Parallel() type testContextKey struct{} ctx := context.WithValue(context.Background(), testContextKey{}, "value") value, ok := ValueFromContext[string](ctx, testContextKey{}) require.True(t, ok) require.Equal(t, "value", value) }) t.Run("unsupported ctx", func(t *testing.T) { t.Parallel() value, ok := ValueFromContext[string](42, "key") require.False(t, ok) require.Empty(t, value) }) } ================================================ FILE: hooks.go ================================================ package fiber import ( "slices" "github.com/gofiber/fiber/v3/log" ) type ( // OnRouteHandler defines the hook signature invoked whenever a route is registered. OnRouteHandler = func(Route) error // OnNameHandler shares the OnRouteHandler signature for route naming callbacks. OnNameHandler = OnRouteHandler // OnGroupHandler defines the hook signature invoked whenever a group is registered. OnGroupHandler = func(Group) error // OnGroupNameHandler shares the OnGroupHandler signature for group naming callbacks. OnGroupNameHandler = OnGroupHandler // OnListenHandler runs when the application begins listening and receives the listener details. OnListenHandler = func(ListenData) error // OnPreStartupMessageHandler runs before Fiber prints the startup banner. 
OnPreStartupMessageHandler = func(*PreStartupMessageData) error // OnPostStartupMessageHandler runs after Fiber prints (or skips) the startup banner. OnPostStartupMessageHandler = func(*PostStartupMessageData) error // OnPreShutdownHandler runs before the application shuts down. OnPreShutdownHandler = func() error // OnPostShutdownHandler runs after shutdown and receives the shutdown result. OnPostShutdownHandler = func(error) error // OnForkHandler runs inside a forked worker process and receives the worker ID. OnForkHandler = func(int) error // OnMountHandler runs after a sub-application mounts to a parent and receives the parent app reference. OnMountHandler = func(*App) error ) // Hooks is a struct to use it with App. type Hooks struct { // Embed app app *App // Hooks onRoute []OnRouteHandler onName []OnNameHandler onGroup []OnGroupHandler onGroupName []OnGroupNameHandler onListen []OnListenHandler onPreStartup []OnPreStartupMessageHandler onPostStartup []OnPostStartupMessageHandler onPreShutdown []OnPreShutdownHandler onPostShutdown []OnPostShutdownHandler onFork []OnForkHandler onMount []OnMountHandler } type StartupMessageLevel int const ( // StartupMessageLevelInfo represents informational startup message entries. StartupMessageLevelInfo StartupMessageLevel = iota // StartupMessageLevelWarning represents warning startup message entries. StartupMessageLevelWarning // StartupMessageLevelError represents error startup message entries. StartupMessageLevelError ) const errString = "ERROR" // startupMessageEntry represents a single line of startup message information. type startupMessageEntry struct { key string title string value string priority int level StartupMessageLevel } // ListenData contains the listener metadata provided to OnListenHandler. 
type ListenData struct { ColorScheme Colors Host string Port string Version string AppName string ChildPIDs []int HandlerCount int ProcessCount int PID int TLS bool Prefork bool } // PreStartupMessageData contains metadata exposed to OnPreStartupMessage hooks. type PreStartupMessageData struct { *ListenData // BannerHeader allows overriding the ASCII art banner displayed at startup. BannerHeader string entries []startupMessageEntry // PreventDefault, when set to true, suppresses the default startup message. PreventDefault bool } // AddInfo adds an informational entry to the startup message with "INFO" label. func (sm *PreStartupMessageData) AddInfo(key, title, value string, priority ...int) { pri := -1 if len(priority) > 0 { pri = priority[0] } sm.addEntry(key, title, value, pri, StartupMessageLevelInfo) } // AddWarning adds a warning entry to the startup message with "WARNING" label. func (sm *PreStartupMessageData) AddWarning(key, title, value string, priority ...int) { pri := -1 if len(priority) > 0 { pri = priority[0] } sm.addEntry(key, title, value, pri, StartupMessageLevelWarning) } // AddError adds an error entry to the startup message with "ERROR" label. func (sm *PreStartupMessageData) AddError(key, title, value string, priority ...int) { pri := -1 if len(priority) > 0 { pri = priority[0] } sm.addEntry(key, title, value, pri, StartupMessageLevelError) } // EntryKeys returns all entry keys currently present in the startup message. func (sm *PreStartupMessageData) EntryKeys() []string { keys := make([]string, 0, len(sm.entries)) for _, entry := range sm.entries { keys = append(keys, entry.key) } return keys } // ResetEntries removes all existing entries from the startup message. func (sm *PreStartupMessageData) ResetEntries() { sm.entries = sm.entries[:0] } // DeleteEntry removes a specific entry from the startup message by its key. 
func (sm *PreStartupMessageData) DeleteEntry(key string) { if sm.entries == nil { return } for i, entry := range sm.entries { if entry.key == key { sm.entries = append(sm.entries[:i], sm.entries[i+1:]...) return } } } func (sm *PreStartupMessageData) addEntry(key, title, value string, priority int, level StartupMessageLevel) { if sm.entries == nil { sm.entries = make([]startupMessageEntry, 0, 8) } for i, entry := range sm.entries { if entry.key != key { continue } sm.entries[i].value = value sm.entries[i].title = title sm.entries[i].level = level sm.entries[i].priority = priority return } sm.entries = append(sm.entries, startupMessageEntry{ key: key, title: title, value: value, priority: priority, level: level, }) } func newPreStartupMessageData(listenData *ListenData) *PreStartupMessageData { return &PreStartupMessageData{ListenData: listenData} } // PostStartupMessageData contains metadata exposed to OnPostStartupMessage hooks. type PostStartupMessageData struct { *ListenData // Disabled indicates whether the startup message was disabled via configuration. Disabled bool // IsChild indicates whether the current process is a child in prefork mode. IsChild bool // Prevented indicates whether the startup message was suppressed by a pre-startup hook using PreventDefault property. 
Prevented bool } func newPostStartupMessageData(listenData *ListenData, disabled, isChild, prevented bool) *PostStartupMessageData { clone := *listenData if len(listenData.ChildPIDs) > 0 { clone.ChildPIDs = slices.Clone(listenData.ChildPIDs) } return &PostStartupMessageData{ ListenData: &clone, Disabled: disabled, IsChild: isChild, Prevented: prevented, } } func newHooks(app *App) *Hooks { return &Hooks{ app: app, onRoute: make([]OnRouteHandler, 0), onGroup: make([]OnGroupHandler, 0), onGroupName: make([]OnGroupNameHandler, 0), onName: make([]OnNameHandler, 0), onListen: make([]OnListenHandler, 0), onPreStartup: make([]OnPreStartupMessageHandler, 0), onPostStartup: make([]OnPostStartupMessageHandler, 0), onPreShutdown: make([]OnPreShutdownHandler, 0), onPostShutdown: make([]OnPostShutdownHandler, 0), onFork: make([]OnForkHandler, 0), onMount: make([]OnMountHandler, 0), } } // OnRoute is a hook to execute user functions on each route registration. // Also you can get route properties by route parameter. func (h *Hooks) OnRoute(handler ...OnRouteHandler) { h.app.mutex.Lock() h.onRoute = append(h.onRoute, handler...) h.app.mutex.Unlock() } // OnName is a hook to execute user functions on each route naming. // Also you can get route properties by route parameter. // // WARN: OnName only works with naming routes, not groups. func (h *Hooks) OnName(handler ...OnNameHandler) { h.app.mutex.Lock() h.onName = append(h.onName, handler...) h.app.mutex.Unlock() } // OnGroup is a hook to execute user functions on each group registration. // Also you can get group properties by group parameter. func (h *Hooks) OnGroup(handler ...OnGroupHandler) { h.app.mutex.Lock() h.onGroup = append(h.onGroup, handler...) h.app.mutex.Unlock() } // OnGroupName is a hook to execute user functions on each group naming. // Also you can get group properties by group parameter. // // WARN: OnGroupName only works with naming groups, not routes. 
func (h *Hooks) OnGroupName(handler ...OnGroupNameHandler) {
	h.app.mutex.Lock()
	h.onGroupName = append(h.onGroupName, handler...)
	h.app.mutex.Unlock()
}

// OnListen is a hook to execute user functions on Listen or Listener.
func (h *Hooks) OnListen(handler ...OnListenHandler) {
	h.app.mutex.Lock()
	h.onListen = append(h.onListen, handler...)
	h.app.mutex.Unlock()
}

// OnPreStartupMessage is a hook to execute user functions before the startup message is printed.
func (h *Hooks) OnPreStartupMessage(handler ...OnPreStartupMessageHandler) {
	h.app.mutex.Lock()
	h.onPreStartup = append(h.onPreStartup, handler...)
	h.app.mutex.Unlock()
}

// OnPostStartupMessage is a hook to execute user functions after the startup message is printed (or skipped).
func (h *Hooks) OnPostStartupMessage(handler ...OnPostStartupMessageHandler) {
	h.app.mutex.Lock()
	h.onPostStartup = append(h.onPostStartup, handler...)
	h.app.mutex.Unlock()
}

// OnPreShutdown is a hook to execute user functions before Shutdown.
func (h *Hooks) OnPreShutdown(handler ...OnPreShutdownHandler) {
	h.app.mutex.Lock()
	h.onPreShutdown = append(h.onPreShutdown, handler...)
	h.app.mutex.Unlock()
}

// OnPostShutdown is a hook to execute user functions after Shutdown.
func (h *Hooks) OnPostShutdown(handler ...OnPostShutdownHandler) {
	h.app.mutex.Lock()
	h.onPostShutdown = append(h.onPostShutdown, handler...)
	h.app.mutex.Unlock()
}

// OnFork is a hook to execute user function after fork process.
func (h *Hooks) OnFork(handler ...OnForkHandler) {
	h.app.mutex.Lock()
	h.onFork = append(h.onFork, handler...)
	h.app.mutex.Unlock()
}

// OnMount is a hook to execute user function after mounting process.
// The mount event is fired when sub-app is mounted on a parent app. The parent app is passed as a parameter.
// It works for app and group mounting.
func (h *Hooks) OnMount(handler ...OnMountHandler) {
	h.app.mutex.Lock()
	h.onMount = append(h.onMount, handler...)
	h.app.mutex.Unlock()
}

// executeOnRouteHooks runs the OnRoute handlers in registration order and
// returns the first error, skipping the remaining handlers. A nil route is a
// no-op.
func (h *Hooks) executeOnRouteHooks(route *Route) error {
	if route == nil {
		return nil
	}
	// Handlers receive a copy; the path rewrite below only affects that copy.
	cloned := *route

	// Check mounting
	// When this app is mounted, handlers see the path including the mount prefix.
	if h.app.mountFields.mountPath != "" {
		cloned.path = h.app.mountFields.mountPath + cloned.path
		cloned.Path = cloned.path
	}

	for _, v := range h.onRoute {
		if err := v(cloned); err != nil {
			return err
		}
	}

	return nil
}

// executeOnNameHooks runs the OnName handlers in registration order and returns
// the first error, skipping the remaining handlers. A nil route is a no-op.
func (h *Hooks) executeOnNameHooks(route *Route) error {
	if route == nil {
		return nil
	}
	// Handlers receive a copy; the path rewrite below only affects that copy.
	cloned := *route

	// Check mounting
	if h.app.mountFields.mountPath != "" {
		cloned.path = h.app.mountFields.mountPath + cloned.path
		cloned.Path = cloned.path
	}

	for _, v := range h.onName {
		if err := v(cloned); err != nil {
			return err
		}
	}

	return nil
}

// executeOnGroupHooks runs the OnGroup handlers in registration order and
// returns the first error. group is received by value, so prefixing the mount
// path does not alter the caller's Group.
func (h *Hooks) executeOnGroupHooks(group Group) error {
	// Check mounting
	if h.app.mountFields.mountPath != "" {
		group.Prefix = h.app.mountFields.mountPath + group.Prefix
	}

	for _, v := range h.onGroup {
		if err := v(group); err != nil {
			return err
		}
	}

	return nil
}

// executeOnGroupNameHooks runs the OnGroupName handlers in registration order
// and returns the first error. group is received by value, so prefixing the
// mount path does not alter the caller's Group.
func (h *Hooks) executeOnGroupNameHooks(group Group) error {
	// Check mounting
	if h.app.mountFields.mountPath != "" {
		group.Prefix = h.app.mountFields.mountPath + group.Prefix
	}

	for _, v := range h.onGroupName {
		if err := v(group); err != nil {
			return err
		}
	}

	return nil
}

// executeOnListenHooks passes a dereferenced copy of *listenData to each OnListen
// handler and returns the first error. listenData must not be nil.
func (h *Hooks) executeOnListenHooks(listenData *ListenData) error {
	for _, v := range h.onListen {
		if err := v(*listenData); err != nil {
			return err
		}
	}

	return nil
}

// executeOnPreStartupMessageHooks runs the OnPreStartupMessage handlers on the
// same *PreStartupMessageData, so changes made by earlier handlers are visible
// to later ones. It stops at the first error.
func (h *Hooks) executeOnPreStartupMessageHooks(data *PreStartupMessageData) error {
	for _, handler := range h.onPreStartup {
		if err := handler(data); err != nil {
			return err
		}
	}

	return nil
}

// executeOnPostStartupMessageHooks runs the OnPostStartupMessage handlers in
// registration order and stops at the first error.
func (h *Hooks) executeOnPostStartupMessageHooks(data *PostStartupMessageData) error {
	for _, handler := range h.onPostStartup {
		if err := handler(data); err != nil {
			return err
		}
	}

	return nil
}

// executeOnPreShutdownHooks runs every OnPreShutdown handler; errors are logged
// and do not stop the remaining handlers.
func (h *Hooks) executeOnPreShutdownHooks() {
	for _, v := range h.onPreShutdown {
		if err := v(); err != nil {
			log.Errorf("failed to call pre shutdown hook: %v", err)
		}
	}
}

// executeOnPostShutdownHooks passes the shutdown result err to every
// OnPostShutdown handler; handler errors are logged and do not stop the rest.
func (h *Hooks) executeOnPostShutdownHooks(err error) {
	for _, v := range h.onPostShutdown {
		if hookErr := v(err); hookErr != nil {
			log.Errorf("failed to call post shutdown hook: %v", hookErr)
		}
	}
}

// executeOnForkHooks passes pid to every OnFork handler; handler errors are
// logged and do not stop the rest.
func (h *Hooks) executeOnForkHooks(pid int) {
	for _, v := range h.onFork {
		if err := v(pid); err != nil {
			log.Errorf("failed to call fork hook: %v", err)
		}
	}
}

// executeOnMountHooks passes the parent app to each OnMount handler and returns
// the first error, skipping the remaining handlers.
func (h *Hooks) executeOnMountHooks(app *App) error {
	for _, v := range h.onMount {
		if err := v(app); err != nil {
			return err
		}
	}
	return nil
}

================================================
FILE: hooks_test.go
================================================
package fiber

import (
	"bytes"
	"errors"
	"os"
	"runtime"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/valyala/bytebufferpool"

	"github.com/gofiber/fiber/v3/log"
)

// testMountPath is the mount prefix used by the execute*_ErrorWithMount tests.
const testMountPath = "/api"

// testSimpleHandler is a trivial route handler shared by the hook tests.
func testSimpleHandler(c Ctx) error {
	return c.SendString("simple")
}

// Test_Hook_OnRoute checks that OnRoute handlers see the route before it is named.
func Test_Hook_OnRoute(t *testing.T) {
	t.Parallel()
	app := New()

	app.Hooks().OnRoute(func(r Route) error {
		require.Empty(t, r.Name)

		return nil
	})

	app.Get("/", testSimpleHandler).Name("x")

	subApp := New()
	subApp.Get("/test", testSimpleHandler)

	app.Use("/sub", subApp)
}

// Test_Hook_OnRoute_Mount checks that a mounted sub-app's OnRoute handlers see
// paths including the mount prefix.
func Test_Hook_OnRoute_Mount(t *testing.T) {
	t.Parallel()
	app := New()
	subApp := New()
	app.Use("/sub", subApp)

	subApp.Hooks().OnRoute(func(r Route) error {
		require.Equal(t, "/sub/test", r.Path)

		return nil
	})

	app.Hooks().OnRoute(func(r Route) error {
		require.Equal(t, "/", r.Path)

		return nil
	})

	app.Get("/", testSimpleHandler).Name("x")
	subApp.Get("/test", testSimpleHandler)
}

// Test_Hook_OnName checks that OnName fires only for explicitly named routes.
func Test_Hook_OnName(t *testing.T) {
	t.Parallel()
	app := New()

	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)

	app.Hooks().OnName(func(r Route) error {
		_, err := buf.WriteString(r.Name)
		require.NoError(t, err)

		return nil
	})

	app.Get("/", testSimpleHandler).Name("index")

	subApp := New()
	subApp.Get("/test", testSimpleHandler)
	subApp.Get("/test2", testSimpleHandler)

	app.Use("/sub", subApp)

	require.Equal(t, "index", buf.String())
}

// Test_Hook_OnName_Error checks that an OnName error surfaces as a panic from Name.
func Test_Hook_OnName_Error(t *testing.T) {
	t.Parallel()
	app := New()

	app.Hooks().OnName(func(_ Route) error {
		return errors.New("unknown error")
	})

	require.PanicsWithError(t, "unknown error", func() {
		app.Get("/", testSimpleHandler).Name("index")
	})
}

// Test_Hook_OnGroup checks the prefixes reported for nested groups.
func Test_Hook_OnGroup(t *testing.T) {
	t.Parallel()
	app := New()

	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)

	app.Hooks().OnGroup(func(g Group) error {
		_, err := buf.WriteString(g.Prefix)
		require.NoError(t, err)
		return nil
	})

	grp := app.Group("/x").Name("x.")
	grp.Group("/a")

	require.Equal(t, "/x/x/a", buf.String())
}

// Test_Hook_OnGroup_Mount checks that group prefixes include the mount path.
func Test_Hook_OnGroup_Mount(t *testing.T) {
	t.Parallel()
	app := New()
	micro := New()
	micro.Use("/john", app)

	app.Hooks().OnGroup(func(g Group) error {
		require.Equal(t, "/john/v1", g.Prefix)
		return nil
	})

	v1 := app.Group("/v1")
	v1.Get("/doe", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})
}

// Test_Hook_OnGroupName checks that group naming and route naming hooks fire
// independently.
func Test_Hook_OnGroupName(t *testing.T) {
	t.Parallel()
	app := New()

	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)

	buf2 := bytebufferpool.Get()
	defer bytebufferpool.Put(buf2)

	app.Hooks().OnGroupName(func(g Group) error {
		_, err := buf.WriteString(g.name)
		require.NoError(t, err)

		return nil
	})

	app.Hooks().OnName(func(r Route) error {
		_, err := buf2.WriteString(r.Name)
		require.NoError(t, err)

		return nil
	})

	grp := app.Group("/x").Name("x.")
	grp.Get("/test", testSimpleHandler).Name("test")
	grp.Get("/test2", testSimpleHandler)

	require.Equal(t, "x.", buf.String())
	require.Equal(t, "x.test", buf2.String())
}

// Test_Hook_OnGroupName_Error checks that an OnGroupName error surfaces as a panic.
func Test_Hook_OnGroupName_Error(t *testing.T) {
	t.Parallel()
	app := New()

	app.Hooks().OnGroupName(func(_ Group) error {
		return errors.New("unknown error")
	})

	require.PanicsWithError(t, "unknown error", func() {
		_ = app.Group("/x").Name("x.")
	})
}

// Test_Hook_OnPreShutdown checks that OnPreShutdown runs during Shutdown.
func Test_Hook_OnPreShutdown(t *testing.T) {
	t.Parallel()
	app := New()

	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)

	app.Hooks().OnPreShutdown(func() error {
		_, err := buf.WriteString("pre-shutdown")
		require.NoError(t, err)

		return nil
	})

	require.NoError(t, app.Shutdown())
	require.Equal(t, "pre-shutdown", buf.String())
}

// Test_Hook_OnPostShutdown covers error forwarding, ordering and error
// tolerance of the post-shutdown hooks.
func Test_Hook_OnPostShutdown(t *testing.T) {
	t.Run("should execute post shutdown hook with error", func(t *testing.T) {
		app := New()
		expectedErr := errors.New("test shutdown error")

		// Buffered so the hook never blocks when invoked again during Shutdown.
		hookCalled := make(chan error, 1)
		defer close(hookCalled)

		app.Hooks().OnPostShutdown(func(err error) error {
			hookCalled <- err
			return nil
		})

		go func() {
			if err := app.Listen(":0"); err != nil {
				return
			}
		}()

		// NOTE(review): fixed sleep waiting for Listen to start; may be flaky on slow CI.
		time.Sleep(100 * time.Millisecond)

		app.hooks.executeOnPostShutdownHooks(expectedErr)

		select {
		case err := <-hookCalled:
			require.Equal(t, expectedErr, err)
		case <-time.After(time.Second):
			t.Fatal("hook execution timeout")
		}

		require.NoError(t, app.Shutdown())
	})

	t.Run("should execute multiple hooks in order", func(t *testing.T) {
		app := New()

		execution := make([]int, 0)

		app.Hooks().OnPostShutdown(func(_ error) error {
			execution = append(execution, 1)
			return nil
		})

		app.Hooks().OnPostShutdown(func(_ error) error {
			execution = append(execution, 2)
			return nil
		})

		app.hooks.executeOnPostShutdownHooks(nil)

		require.Len(t, execution, 2, "expected 2 hooks to execute")
		require.Equal(t, []int{1, 2}, execution, "hooks executed in wrong order")
	})

	t.Run("should handle hook error", func(_ *testing.T) {
		app := New()
		hookErr := errors.New("hook error")

		app.Hooks().OnPostShutdown(func(_ error) error {
			return hookErr
		})

		// Should not panic
		app.hooks.executeOnPostShutdownHooks(nil)
	})
}

// Test_Hook_OnListen checks that OnListen fires when the server starts.
func Test_Hook_OnListen(t *testing.T) {
	t.Parallel()
	app := New()

	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)

	app.Hooks().OnListen(func(_ ListenData) error {
		_, err := buf.WriteString("ready")
		require.NoError(t, err)

		return nil
	})

	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()
	require.NoError(t, app.Listen(":0"))

	require.Equal(t, "ready", buf.String())
}

// Test_ListenDataMetadata checks every ListenData field as seen by OnListen and
// OnPreStartupMessage hooks, and that startup entries added by a hook are kept.
func Test_ListenDataMetadata(t *testing.T) {
	t.Parallel()

	app := New(Config{AppName: "meta"})
	app.handlersCount = 42

	cfg := ListenConfig{EnablePrefork: true}
	childPIDs := []int{11, 22}
	listenData := app.prepareListenData(":3030", true, &cfg, childPIDs)

	app.Hooks().OnListen(func(data ListenData) error {
		require.Equal(t, globalIpv4Addr, data.Host)
		require.Equal(t, "3030", data.Port)
		require.True(t, data.TLS)
		require.Equal(t, Version, data.Version)
		require.Equal(t, "meta", data.AppName)
		require.Equal(t, 42, data.HandlerCount)
		require.Equal(t, runtime.GOMAXPROCS(0), data.ProcessCount)
		require.Equal(t, os.Getpid(), data.PID)
		require.True(t, data.Prefork)
		require.Equal(t, childPIDs, data.ChildPIDs)
		require.Equal(t, app.config.ColorScheme, data.ColorScheme)

		return nil
	})

	app.runOnListenHooks(listenData)

	app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
		require.Equal(t, globalIpv4Addr, data.Host)
		require.Equal(t, "3030", data.Port)
		require.True(t, data.TLS)
		require.Equal(t, Version, data.Version)
		require.Equal(t, "meta", data.AppName)
		require.Equal(t, 42, data.HandlerCount)
		require.Equal(t, runtime.GOMAXPROCS(0), data.ProcessCount)
		require.Equal(t, os.Getpid(), data.PID)
		require.True(t, data.Prefork)
		require.Equal(t, childPIDs, data.ChildPIDs)
		require.Equal(t, app.config.ColorScheme, data.ColorScheme)

		data.ResetEntries()

		data.AddInfo("custom", "Custom Info", "value", 3)
		data.AddInfo("other", "Other Info", "value", 2)

		return nil
	})

	pre := newPreStartupMessageData(listenData)
	require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))

	require.Equal(t, "value", pre.entries[0].value)
	require.Equal(t, "Custom Info", pre.entries[0].title)
	require.Equal(t, 3, pre.entries[0].priority)

	require.Equal(t, "value", pre.entries[1].value)
	require.Equal(t, "Other Info", pre.entries[1].title)
	require.Equal(t, 2, pre.entries[1].priority)

	require.False(t, pre.PreventDefault)
}

// Test_ListenData_Hook_HelperFunctions exercises the PreStartupMessageData
// entry helpers from inside an OnPreStartupMessage hook.
func Test_ListenData_Hook_HelperFunctions(t *testing.T) {
	t.Parallel()

	t.Run("EntryKeys", func(t *testing.T) {
		t.Parallel()
		app := New()

		app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
			data.ResetEntries()
			data.AddInfo("key1", "Title 1", "Value 1", 1)
			data.AddInfo("key2", "Title 2", "Value 2", 2)

			keys := data.EntryKeys()
			require.Len(t, keys, 2)
			require.Equal(t, "key1", keys[0])
			require.Equal(t, "key2", keys[1])

			return nil
		})

		pre := newPreStartupMessageData(&ListenData{})
		require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
	})

	t.Run("ResetEntries", func(t *testing.T) {
		t.Parallel()
		app := New()

		app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
			data.ResetEntries()
			data.AddInfo("key1", "Title 1", "Value 1", 1)
			data.AddInfo("key2", "Title 2", "Value 2", 2)

			require.Len(t, data.entries, 2)

			data.ResetEntries()
			require.Empty(t, data.entries)

			return nil
		})

		pre := newPreStartupMessageData(&ListenData{})
		require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
	})

	t.Run("AddInfo", func(t *testing.T) {
		t.Parallel()
		app := New()

		app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
			data.ResetEntries()
			data.AddInfo("key1", "Title 1", "Value 1", 1)

			require.Len(t, data.entries, 1)
			require.Equal(t, "key1", data.entries[0].key)
			require.Equal(t, "Title 1", data.entries[0].title)
			require.Equal(t, "Value 1", data.entries[0].value)
			require.Equal(t, 1, data.entries[0].priority)

			return nil
		})

		pre := newPreStartupMessageData(&ListenData{})
		require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
	})

	t.Run("AddWarning", func(t *testing.T) {
		t.Parallel()
		app := New()

		app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
			data.ResetEntries()
			data.AddWarning("key1", "Title 1", "Value 1", 1)

			require.Len(t, data.entries, 1)
			require.Equal(t, "key1", data.entries[0].key)
			require.Equal(t, "Title 1", data.entries[0].title)
			require.Equal(t, "Value 1", data.entries[0].value)
			require.Equal(t, 1, data.entries[0].priority)

			return nil
		})

		pre := newPreStartupMessageData(&ListenData{})
		require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
	})

	t.Run("AddError", func(t *testing.T) {
		t.Parallel()
		app := New()

		app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
			data.ResetEntries()
			data.AddError("key1", "Title 1", "Value 1", 1)

			require.Len(t, data.entries, 1)
			require.Equal(t, "key1", data.entries[0].key)
			require.Equal(t, "Title 1", data.entries[0].title)
			require.Equal(t, "Value 1", data.entries[0].value)
			require.Equal(t, 1, data.entries[0].priority)

			return nil
		})

		pre := newPreStartupMessageData(&ListenData{})
		require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
	})

	t.Run("AddInfo-UpdateExisting", func(t *testing.T) {
		t.Parallel()
		app := New()

		app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
			data.ResetEntries()
			data.AddInfo("key1", "Title 1", "Value 1", 1)
			data.AddInfo("key1", "Updated Title", "Updated Value", 2)

			require.Len(t, data.entries, 1)
			require.Equal(t, "key1", data.entries[0].key)
			require.Equal(t, "Updated Title", data.entries[0].title)
			require.Equal(t, "Updated Value", data.entries[0].value)
			require.Equal(t, 2, data.entries[0].priority)

			return nil
		})

		pre := newPreStartupMessageData(&ListenData{})
		require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
	})

	t.Run("DeleteEntry", func(t *testing.T) {
		t.Parallel()
		app := New()

		app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error {
			data.ResetEntries()
			data.AddInfo("key1", "Title 1", "Value 1", 1)
			data.AddInfo("key2", "Title 2", "Value 2", 2)

			require.Len(t, data.entries, 2)

			data.DeleteEntry("key1")
			require.Len(t, data.entries, 1)
			require.Equal(t, "key2", data.entries[0].key)

			data.DeleteEntry("key-not-exist") // should not panic
			require.Len(t, data.entries, 1)

			return nil
		})

		pre := newPreStartupMessageData(&ListenData{})
		require.NoError(t, app.hooks.executeOnPreStartupMessageHooks(pre))
	})
}

// Test_Hook_OnListenPrefork checks that OnListen fires in prefork mode.
func Test_Hook_OnListenPrefork(t *testing.T) {
	t.Parallel()
	app := New()

	buf := bytebufferpool.Get()
	defer bytebufferpool.Put(buf)

	app.Hooks().OnListen(func(_ ListenData) error {
		_, err := buf.WriteString("ready")
		require.NoError(t, err)

		return nil
	})

	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()

	require.NoError(t, app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true}))
	require.Equal(t, "ready", buf.String())
}

// Test_Hook_OnHook checks that OnFork receives the forked worker ID. It sets
// the package-level testPreforkMaster/testOnPrefork switches, which is
// presumably why it does not call t.Parallel.
func Test_Hook_OnHook(t *testing.T) {
	app := New()

	// Reset test var
	testPreforkMaster = true
	testOnPrefork = true

	go func() {
		time.Sleep(1000 * time.Millisecond)
		assert.NoError(t, app.Shutdown())
	}()

	app.Hooks().OnFork(func(pid int) error {
		require.Equal(t, 1, pid)
		return nil
	})

	require.NoError(t, app.prefork(":0", nil, &ListenConfig{DisableStartupMessage: true, EnablePrefork: true}))
}

// Test_Hook_OnMount checks that OnMount receives the parent app.
func Test_Hook_OnMount(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/", testSimpleHandler).Name("x")

	subApp := New()
	subApp.Get("/test", testSimpleHandler)

	subApp.Hooks().OnMount(func(parent *App) error {
		require.Empty(t, parent.mountFields.mountPath)

		return nil
	})

	app.Use("/sub", subApp)
}

// Test_executeOnRouteHooks_ErrorWithMount checks both mount-path prefixing and
// error propagation for OnRoute.
func Test_executeOnRouteHooks_ErrorWithMount(t *testing.T) {
	t.Parallel()
	app := New()
	app.mountFields.mountPath = testMountPath

	var received string
	app.Hooks().OnRoute(func(r Route) error {
		received = r.Path
		return errors.New("hook error")
	})

	err := app.hooks.executeOnRouteHooks(&Route{Path: "/foo", path: "/foo"})
	require.Equal(t, testMountPath+"/foo", received)
	require.EqualError(t, err, "hook error")
}

// Test_executeOnNameHooks_ErrorWithMount checks both mount-path prefixing and
// error propagation for OnName.
func Test_executeOnNameHooks_ErrorWithMount(t *testing.T) {
	t.Parallel()
	app := New()
	app.mountFields.mountPath = testMountPath

	var received string
	app.Hooks().OnName(func(r Route) error {
		received = r.Path
		return errors.New("name error")
	})

	err := app.hooks.executeOnNameHooks(&Route{Path: "/bar", path: "/bar"})
	require.Equal(t, testMountPath+"/bar", received)
	require.EqualError(t, err, "name error")
}

// Test_executeOnGroupHooks_ErrorWithMount checks both mount-path prefixing and
// error propagation for OnGroup.
func Test_executeOnGroupHooks_ErrorWithMount(t *testing.T) {
	t.Parallel()
	app := New()
	app.mountFields.mountPath = testMountPath

	var prefix string
	app.Hooks().OnGroup(func(g Group) error {
		prefix = g.Prefix
		return errors.New("group error")
	})

	err := app.hooks.executeOnGroupHooks(Group{Prefix: "/grp"})
	require.Equal(t, testMountPath+"/grp", prefix)
	require.EqualError(t, err, "group error")
}

// Test_executeOnGroupNameHooks_ErrorWithMount checks both mount-path prefixing
// and error propagation for OnGroupName.
func Test_executeOnGroupNameHooks_ErrorWithMount(t *testing.T) {
	t.Parallel()
	app := New()
	app.mountFields.mountPath = testMountPath

	var prefix string
	app.Hooks().OnGroupName(func(g Group) error {
		prefix = g.Prefix
		return errors.New("group name error")
	})

	err := app.hooks.executeOnGroupNameHooks(Group{Prefix: "/grp"})
	require.Equal(t, testMountPath+"/grp", prefix)
	require.EqualError(t, err, "group name error")
}

// Test_executeOnListenHooks_Error checks that an OnListen error is returned.
func Test_executeOnListenHooks_Error(t *testing.T) {
	t.Parallel()
	app := New()

	app.Hooks().OnListen(func(_ ListenData) error {
		return errors.New("listen error")
	})

	err := app.hooks.executeOnListenHooks(&ListenData{Host: "127.0.0.1", Port: "0"})
	require.EqualError(t, err, "listen error")
}

// Test_executeOnPreStartupMessageHooks_Error checks that an OnPreStartupMessage
// error is returned.
func Test_executeOnPreStartupMessageHooks_Error(t *testing.T) {
	t.Parallel()
	app := New()

	app.Hooks().OnPreStartupMessage(func(_ *PreStartupMessageData) error {
		return errors.New("pre startup message error")
	})

	err := app.hooks.executeOnPreStartupMessageHooks(newPreStartupMessageData(&ListenData{}))
	require.EqualError(t, err, "pre startup message error")
}

// Test_executeOnPostStartupMessageHooks_Error checks that an OnPostStartupMessage
// error is returned.
func Test_executeOnPostStartupMessageHooks_Error(t *testing.T) {
	t.Parallel()
	app := New()

	app.Hooks().OnPostStartupMessage(func(_ *PostStartupMessageData) error {
		return errors.New("post startup message error")
	})

	err := app.hooks.executeOnPostStartupMessageHooks(newPostStartupMessageData(&ListenData{}, false, false, false))
	require.EqualError(t, err, "post startup message error")
}

// Test_executeOnPreShutdownHooks_Error checks that a pre-shutdown hook error is
// logged rather than returned.
//
// NOTE(review): this test runs with t.Parallel() but calls log.SetOutput, which
// presumably swaps a process-wide writer, and never restores it. That can race
// with Test_executeOnForkHooks_Error (which does the same) and other parallel
// tests that log; confirm and consider dropping t.Parallel() here.
func Test_executeOnPreShutdownHooks_Error(t *testing.T) {
	t.Parallel()
	app := New()

	app.Hooks().OnPreShutdown(func() error {
		return errors.New("pre error")
	})

	var buf bytes.Buffer
	log.SetOutput(&buf)
	app.hooks.executeOnPreShutdownHooks()

	require.NotZero(t, buf.Len())
}

// Test_executeOnForkHooks_Error checks that a fork hook error is logged rather
// than returned.
//
// NOTE(review): same concern as Test_executeOnPreShutdownHooks_Error: a
// parallel test that replaces the global log output without restoring it.
func Test_executeOnForkHooks_Error(t *testing.T) {
	t.Parallel()
	app := New()

	app.Hooks().OnFork(func(pid int) error {
		require.Equal(t, 1, pid)
		return errors.New("fork error")
	})

	var buf bytes.Buffer
	log.SetOutput(&buf)
	app.hooks.executeOnForkHooks(1)

	require.NotZero(t, buf.Len())
}

// Test_executeOnMountHooks_Error checks that an OnMount error is returned and
// that the hook receives the parent app.
func Test_executeOnMountHooks_Error(t *testing.T) {
	t.Parallel()
	app := New()
	parent := New()

	app.Hooks().OnMount(func(a *App) error {
		require.Equal(t, parent, a)
		return errors.New("mount error")
	})

	err := app.hooks.executeOnMountHooks(parent)
	require.EqualError(t, err, "mount error")
}

================================================
FILE: internal/memory/memory.go
================================================
// Package memory provides a high-performance in-memory storage that can store
// any type without encoding overhead. Unlike the standard storage interface,
// this storage works directly with Go types for maximum speed.
//
// # Safety Considerations
//
// This storage automatically performs defensive copying for:
//   - String keys: Copied to prevent corruption from pooled buffers
//   - []byte values: Copied on both Set and Get to prevent external mutation
//
// For other types (structs, ints, etc.), Go's value semantics provide natural
// protection. However, if storing pointers or slices of non-byte types,
// callers are responsible for not mutating the underlying data.
//
// This storage is primarily used internally by middleware for performance-
// critical operations where the stored data types are known and controlled.
package memory

import (
	"sync"
	"time"

	"github.com/gofiber/utils/v2"
)

// Storage is an in-memory map from string keys to arbitrary values with an
// optional per-entry expiry, guarded by a read/write mutex.
type Storage struct {
	data map[string]item // stored entries, keyed by a private copy of the key
	mu   sync.RWMutex
}

// item is a stored value together with its expiry.
type item struct {
	v any // val
	// e is the expiry timestamp in seconds; 0 means the entry never expires.
	// max value is 4294967295 -> Sun Feb 07 2106 06:28:15 GMT+0000
	e uint32 // exp
}

// New constructs an in-memory Storage initialized with a background GC loop.
func New() *Storage { store := &Storage{ data: make(map[string]item), } utils.StartTimeStampUpdater() go store.gc(1 * time.Second) return store } // Get retrieves the value stored under key, returning nil when the entry does // not exist or has expired. // // For []byte values, this returns a defensive copy to prevent callers from // mutating the stored data. Other types are returned as-is. func (s *Storage) Get(key string) any { s.mu.RLock() v, ok := s.data[key] s.mu.RUnlock() if !ok || v.e != 0 && v.e <= utils.Timestamp() { return nil } // Defensive copy for byte slices to prevent external mutation if b, ok := v.v.([]byte); ok { return utils.CopyBytes(b) } return v.v } // Set stores val under key and applies the optional ttl before expiring the // entry. A non-positive ttl keeps the item forever. // // String keys are defensively copied to prevent corruption from pooled buffers. // []byte values are also copied to prevent external mutation of stored data. // Other types are stored as-is (structs are copied by value automatically). func (s *Storage) Set(key string, val any, ttl time.Duration) { var exp uint32 if ttl > 0 { exp = uint32(ttl.Seconds()) + utils.Timestamp() } // Defensive copies to prevent unsafe reuse from sync.Pool keyCopy := utils.CopyString(key) // Copy byte slices to prevent external mutation if b, ok := val.([]byte); ok { val = utils.CopyBytes(b) } i := item{e: exp, v: val} s.mu.Lock() s.data[keyCopy] = i s.mu.Unlock() } // Delete removes key and its associated value from the storage. func (s *Storage) Delete(key string) { s.mu.Lock() delete(s.data, key) s.mu.Unlock() } // Reset clears the storage by dropping every stored key. 
func (s *Storage) Reset() {
	// Allocate the replacement map before taking the lock to keep the critical
	// section short.
	nd := make(map[string]item)
	s.mu.Lock()
	s.data = nd
	s.mu.Unlock()
}

// gc wakes up every sleep interval and deletes entries whose non-zero expiry is
// at or before the current timestamp. It loops for as long as the ticker fires:
// there is no stop signal, so the goroutine lives until the process exits.
func (s *Storage) gc(sleep time.Duration) {
	ticker := time.NewTicker(sleep)
	defer ticker.Stop()
	// expired is reused across ticks to avoid reallocating the key list.
	var expired []string

	for range ticker.C {
		ts := utils.Timestamp()
		expired = expired[:0]
		// First pass: collect candidates while holding only the read lock.
		s.mu.RLock()
		for key, v := range s.data {
			if v.e != 0 && v.e <= ts {
				expired = append(expired, key)
			}
		}
		s.mu.RUnlock()

		if len(expired) == 0 {
			// avoid locking if nothing to delete
			continue
		}

		s.mu.Lock()
		// Double-checked locking.
		// We might have replaced the item in the meantime.
		for i := range expired {
			v := s.data[expired[i]]
			if v.e != 0 && v.e <= ts {
				delete(s.data, expired[i])
			}
		}
		s.mu.Unlock()
	}
}

================================================
FILE: internal/memory/memory_test.go
================================================
package memory

import (
	"testing"
	"time"

	"github.com/gofiber/utils/v2"
	"github.com/stretchr/testify/require"
)

// Test_Memory walks through Set/Get/expiry/Delete/Reset on a single store.
// go test -run Test_Memory -v -race
func Test_Memory(t *testing.T) {
	t.Parallel()
	store := New()
	var (
		key     = "john-internal"
		val any = []byte("doe")
		exp     = 1 * time.Second
	)

	// Set key with value
	store.Set(key, val, 0)
	result := store.Get(key)
	require.Equal(t, val, result)

	// Get non-existing key
	result = store.Get("empty")
	require.Nil(t, result)

	// Set key with value and ttl
	// Expiry is tracked in whole seconds, so sleep slightly past the TTL.
	store.Set(key, val, exp)
	time.Sleep(1100 * time.Millisecond)
	result = store.Get(key)
	require.Nil(t, result)

	// Set key with value and no expiration
	store.Set(key, val, 0)
	result = store.Get(key)
	require.Equal(t, val, result)

	// Delete key
	store.Delete(key)
	result = store.Get(key)
	require.Nil(t, result)

	// Reset all keys
	store.Set("john-reset", val, 0)
	store.Set("doe-reset", val, 0)
	store.Reset()

	// Check if all keys are deleted
	result = store.Get("john-reset")
	require.Nil(t, result)
	result = store.Get("doe-reset")
	require.Nil(t, result)
}

// Benchmark_Memory measures a Set/Get/Delete round over 1000 random keys.
// go test -v -run=^$ -bench=Benchmark_Memory -benchmem -count=4
func Benchmark_Memory(b *testing.B) {
	keyLength := 1000
	keys := make([]string, keyLength)
	for i := range keyLength {
		keys[i] = utils.UUIDv4()
	}
	value := []byte("joe")

	ttl := 2 * time.Second
	b.Run("fiber_memory", func(b *testing.B) {
		d := New()
		b.ReportAllocs()

		for b.Loop() {
			for _, key := range keys {
				d.Set(key, value, ttl)
			}

			for _, key := range keys {
				_ = d.Get(key)
			}

			for _, key := range keys {
				d.Delete(key)
			}
		}
	})
}

================================================
FILE: internal/storage/memory/config.go
================================================
package memory

import (
	"time"
)

// Config defines the config for storage.
type Config struct {
	// GCInterval is how often the background collector sweeps expired keys.
	//
	// Default is 10 * time.Second
	GCInterval time.Duration
}

// ConfigDefault is the default config
var ConfigDefault = Config{
	GCInterval: 10 * time.Second,
}

// configDefault is a helper function to set default values
func configDefault(config ...Config) Config {
	// Return default config if nothing provided
	if len(config) < 1 {
		return ConfigDefault
	}

	// Override default config
	cfg := config[0]

	// Set default values
	// The comparison truncates to whole seconds, so any interval below one
	// second (not only zero or negative values) falls back to the default.
	if int(cfg.GCInterval.Seconds()) <= 0 {
		cfg.GCInterval = ConfigDefault.GCInterval
	}

	return cfg
}

================================================
FILE: internal/storage/memory/memory.go
================================================
// Package memory is a copy of the in-memory storage from the external storage
// package, kept here so unit tests can exercise behavior that depends on those
// storages.
package memory

import (
	"context"
	"fmt"
	"sync"
	"time"

	"github.com/gofiber/utils/v2"
)

// Storage provides an in-memory implementation of the storage interface for
// testing purposes.
type Storage struct {
	db         map[string]Entry // stored entries by key
	done       chan struct{}    // receives the stop signal for the gc goroutine
	gcInterval time.Duration    // period of the background expiry sweep
	mux        sync.RWMutex     // guards db
}

// Entry represents a value stored in memory along with its expiration.
type Entry struct {
	data []byte
	// expiry is a timestamp in seconds after which the entry is treated as
	// expired; 0 means it never expires.
	// max value is 4294967295 -> Sun Feb 07 2106 06:28:15 GMT+0000
	expiry uint32
}

// New creates a new memory storage.
func New(config ...Config) *Storage { // Set default config cfg := configDefault(config...) // Create storage store := &Storage{ db: make(map[string]Entry), gcInterval: cfg.GCInterval, done: make(chan struct{}), } // Start garbage collector utils.StartTimeStampUpdater() go store.gc() return store } // Get returns the stored value for key, ignoring missing or expired entries by // returning nil. func (s *Storage) Get(key string) ([]byte, error) { if key == "" { return nil, nil } s.mux.RLock() v, ok := s.db[key] s.mux.RUnlock() if !ok || v.expiry != 0 && v.expiry <= utils.Timestamp() { return nil, nil } // Return a copy to prevent callers from mutating stored data return utils.CopyBytes(v.data), nil } // GetWithContext retrieves the value for the given key while honoring context // cancellation. func (s *Storage) GetWithContext(ctx context.Context, key string) ([]byte, error) { if err := wrapContextError(ctx, "get"); err != nil { return nil, err } return s.Get(key) } // Set saves val under key and schedules it to expire after exp. A zero exp keeps // the entry indefinitely. func (s *Storage) Set(key string, val []byte, exp time.Duration) error { // Ain't Nobody Got Time For That if key == "" || len(val) == 0 { return nil } var expire uint32 if exp != 0 { expire = uint32(exp.Seconds()) + utils.Timestamp() } // Copy both key and value to avoid unsafe reuse from sync.Pool keyCopy := utils.CopyString(key) valCopy := utils.CopyBytes(val) e := Entry{data: valCopy, expiry: expire} s.mux.Lock() s.db[keyCopy] = e s.mux.Unlock() return nil } // SetWithContext sets the value for the given key while honoring context // cancellation. func (s *Storage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { if err := wrapContextError(ctx, "set"); err != nil { return err } return s.Set(key, val, exp) } // Delete removes the value stored for key. 
func (s *Storage) Delete(key string) error { // Ain't Nobody Got Time For That if key == "" { return nil } s.mux.Lock() delete(s.db, key) s.mux.Unlock() return nil } // DeleteWithContext removes the value for the given key while honoring // context cancellation. func (s *Storage) DeleteWithContext(ctx context.Context, key string) error { if err := wrapContextError(ctx, "delete"); err != nil { return err } return s.Delete(key) } // Reset clears all keys and values from the storage map. func (s *Storage) Reset() error { ndb := make(map[string]Entry) s.mux.Lock() s.db = ndb s.mux.Unlock() return nil } // ResetWithContext clears all stored keys while honoring context // cancellation. func (s *Storage) ResetWithContext(ctx context.Context) error { if err := wrapContextError(ctx, "reset"); err != nil { return err } return s.Reset() } // Close stops the background garbage collector and releases resources // associated with the storage instance. func (s *Storage) Close() error { s.done <- struct{}{} return nil } func (s *Storage) gc() { ticker := time.NewTicker(s.gcInterval) defer ticker.Stop() var expired []string for { select { case <-s.done: return case <-ticker.C: ts := utils.Timestamp() expired = expired[:0] s.mux.RLock() for id, v := range s.db { if v.expiry != 0 && v.expiry < ts { expired = append(expired, id) } } s.mux.RUnlock() if len(expired) == 0 { // avoid locking if nothing to delete continue } s.mux.Lock() // Double-checked locking. // We might have replaced the item in the meantime. for i := range expired { v := s.db[expired[i]] if v.expiry != 0 && v.expiry <= ts { delete(s.db, expired[i]) } } s.mux.Unlock() } } } // Conn returns the underlying storage map. The map must not be modified by // callers. func (s *Storage) Conn() map[string]Entry { s.mux.RLock() defer s.mux.RUnlock() return s.db } // Keys returns all keys stored in the memory storage. 
//
// Entries with a non-zero expiry at or before the current timestamp are
// skipped. Each returned key is a fresh []byte copy. The read lock is held
// for the whole scan. When no live keys exist, Keys returns nil rather than
// an empty slice.
func (s *Storage) Keys() ([][]byte, error) {
	s.mux.RLock()
	defer s.mux.RUnlock()

	if len(s.db) == 0 {
		return nil, nil
	}

	ts := utils.Timestamp()
	keys := make([][]byte, 0, len(s.db))
	for key, v := range s.db {
		// Filter out the expired keys
		if v.expiry == 0 || v.expiry > ts {
			keys = append(keys, []byte(key))
		}
	}

	// Double check if no valid keys were found
	if len(keys) == 0 {
		return nil, nil
	}
	return keys, nil
}

// wrapContextError returns nil while ctx is still active. Otherwise it wraps
// ctx.Err() with the operation name ("memory storage <op>: ..."), using %w so
// callers can still match it with errors.Is (e.g. context.Canceled).
func wrapContextError(ctx context.Context, op string) error {
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("memory storage %s: %w", op, err)
	}
	return nil
}

================================================
FILE: internal/storage/memory/memory_test.go
================================================
package memory

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func Test_Storage_Memory_Set(t *testing.T) {
	t.Parallel()
	var (
		testStore = New()
		key       = "john"
		val       = []byte("doe")
	)

	err := testStore.Set(key, val, 0)
	require.NoError(t, err)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)
}

func Test_Storage_Memory_SetWithContext(t *testing.T) {
	t.Parallel()
	var (
		testStore = New()
		key       = "john"
		val       = []byte("doe")
	)

	// An already-cancelled context must make SetWithContext fail without storing anything.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	err := testStore.SetWithContext(ctx, key, val, 0)
	require.ErrorIs(t, err, context.Canceled)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}

func Test_Storage_Memory_Set_Override(t *testing.T) {
	t.Parallel()
	var (
		testStore = New()
		key       = "john"
		val       = []byte("doe")
	)

	err := testStore.Set(key, val, 0)
	require.NoError(t, err)

	err = testStore.Set(key, val, 0)
	require.NoError(t, err)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)
}

func Test_Storage_Memory_Get(t *testing.T) {
	t.Parallel()
	var (
		testStore = New()
		key       = "john"
		val       = []byte("doe")
	)

	err := testStore.Set(key, val, 0)
	require.NoError(t, err)

	result, err := testStore.Get(key)
	require.NoError(t, err)
	require.Equal(t, val, result)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)
}

func Test_Storage_Memory_GetWithContext(t *testing.T) {
	t.Parallel()
	var (
		testStore = New()
		key       = "john"
		val       = []byte("doe")
	)

	err := testStore.Set(key, val, 0)
	require.NoError(t, err)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	result, err := testStore.GetWithContext(ctx, key)
	require.ErrorIs(t, err, context.Canceled)
	require.Nil(t, result)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)
}

func Test_Storage_Memory_Set_Expiration(t *testing.T) {
	t.Parallel()
	var (
		testStore = New(Config{
			GCInterval: 300 * time.Millisecond,
		})
		key = "john"
		val = []byte("doe")
		exp = 1 * time.Second
	)

	err := testStore.Set(key, val, exp)
	require.NoError(t, err)

	// interval + expire + buffer
	// Note: Get and Keys filter expired entries themselves, so these
	// assertions hold whether or not a gc pass has already run.
	time.Sleep(1500 * time.Millisecond)

	result, err := testStore.Get(key)
	require.NoError(t, err)
	require.Empty(t, result)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}

func Test_Storage_Memory_Set_Long_Expiration_with_Keys(t *testing.T) {
	t.Parallel()
	var (
		testStore = New()
		key       = "john"
		val       = []byte("doe")
		exp       = 3 * time.Second
	)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)

	err = testStore.Set(key, val, exp)
	require.NoError(t, err)

	time.Sleep(1100 * time.Millisecond)

	keys, err = testStore.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)

	// Sleeps about 5.1s in total; this is one of the slowest tests in the package.
	time.Sleep(4000 * time.Millisecond)
	result, err := testStore.Get(key)
	require.NoError(t, err)
	require.Empty(t, result)

	keys, err = testStore.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}

func Test_Storage_Memory_Get_NotExist(t *testing.T) {
	t.Parallel()
	testStore := New()
	result, err := testStore.Get("notexist")
	require.NoError(t, err)
	require.Empty(t, result)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}

func Test_Storage_Memory_Delete(t *testing.T) {
	t.Parallel()
	var (
		testStore = New()
		key       = "john"
		val       = []byte("doe")
	)

	err := testStore.Set(key, val, 0)
	require.NoError(t, err)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)

	err = testStore.Delete(key)
	require.NoError(t, err)

	result, err := testStore.Get(key)
	require.NoError(t, err)
	require.Empty(t, result)

	keys, err = testStore.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}

func Test_Storage_Memory_DeleteWithContext(t *testing.T) {
	t.Parallel()
	var (
		testStore = New()
		key       = "john"
		val       = []byte("doe")
	)

	err := testStore.Set(key, val, 0)
	require.NoError(t, err)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	err = testStore.DeleteWithContext(ctx, key)
	require.ErrorIs(t, err, context.Canceled)

	result, err := testStore.Get(key)
	require.NoError(t, err)
	require.Equal(t, val, result)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 1)
}

func Test_Storage_Memory_Reset(t *testing.T) {
	t.Parallel()
	testStore := New()
	val := []byte("doe")

	err := testStore.Set("john1", val, 0)
	require.NoError(t, err)

	err = testStore.Set("john2", val, 0)
	require.NoError(t, err)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 2)

	err = testStore.Reset()
	require.NoError(t, err)

	result, err := testStore.Get("john1")
	require.NoError(t, err)
	require.Empty(t, result)

	result, err = testStore.Get("john2")
	require.NoError(t, err)
	require.Empty(t, result)

	keys, err = testStore.Keys()
	require.NoError(t, err)
	require.Nil(t, keys)
}

func Test_Storage_Memory_ResetWithContext(t *testing.T) {
	t.Parallel()
	testStore := New()
	val := []byte("doe")

	err := testStore.Set("john1", val, 0)
	require.NoError(t, err)

	err = testStore.Set("john2", val, 0)
	require.NoError(t, err)

	keys, err := testStore.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 2)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	err = testStore.ResetWithContext(ctx)
	require.ErrorIs(t, err, context.Canceled)

	result, err := testStore.Get("john1")
	require.NoError(t, err)
	require.Equal(t, val, result)

	result, err = testStore.Get("john2")
	require.NoError(t, err)
	require.Equal(t, val, result)

	keys, err = testStore.Keys()
	require.NoError(t, err)
	require.Len(t, keys, 2)
}

func Test_Storage_Memory_Close(t *testing.T) {
	t.Parallel()
	testStore := New()
	require.NoError(t, testStore.Close())
}

func Test_Storage_Memory_Conn(t *testing.T) {
	t.Parallel()
	testStore := New()
	require.NotNil(t, testStore.Conn())
}

// Benchmarks for Set operation
func Benchmark_Memory_Set(b *testing.B) {
	testStore := New()
	b.ReportAllocs()

	for b.Loop() {
		_ = testStore.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
	}
}

func Benchmark_Memory_Set_Parallel(b *testing.B) {
	testStore := New()
	b.ReportAllocs()
	b.ResetTimer()

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			_ = testStore.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
		}
	})
}

func Benchmark_Memory_Set_Asserted(b *testing.B) {
	testStore := New()
	b.ReportAllocs()

	for b.Loop() {
		err := testStore.Set("john", []byte("doe"), 0)
		require.NoError(b, err)
	}
}

func Benchmark_Memory_Set_Asserted_Parallel(b *testing.B) {
	testStore := New()
	b.ReportAllocs()
	b.ResetTimer()

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			err := testStore.Set("john", []byte("doe"), 0)
			require.NoError(b, err)
		}
	})
}

// Benchmarks for Get operation
func Benchmark_Memory_Get(b *testing.B) {
	testStore := New()
	err := testStore.Set("john", []byte("doe"), 0)
	require.NoError(b, err)

	b.ReportAllocs()

	for b.Loop() {
		_, _ = testStore.Get("john") //nolint:errcheck // error not needed for benchmark
	}
}

func Benchmark_Memory_Get_Parallel(b *testing.B) {
	testStore := New()
	err := testStore.Set("john", []byte("doe"), 0)
	require.NoError(b, err)

	b.ReportAllocs()
	b.ResetTimer()

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			_, _ = testStore.Get("john") //nolint:errcheck // error not needed for benchmark
		}
	})
}

func Benchmark_Memory_Get_Asserted(b *testing.B) {
	testStore := New()
	err := testStore.Set("john", []byte("doe"), 0)
	require.NoError(b, err)

	b.ReportAllocs()

	for b.Loop() {
		_, err := testStore.Get("john")
		require.NoError(b, err)
	}
}

func Benchmark_Memory_Get_Asserted_Parallel(b *testing.B) {
	testStore := New()
	err := testStore.Set("john", []byte("doe"), 0)
	require.NoError(b, err)

	b.ReportAllocs()
	b.ResetTimer()

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			_, err := testStore.Get("john")
			require.NoError(b, err)
		}
	})
}

// Benchmarks for SetAndDelete operation
func Benchmark_Memory_SetAndDelete(b *testing.B) {
	testStore := New()
	b.ReportAllocs()

	for b.Loop() {
		_ = testStore.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
		_ = testStore.Delete("john")                //nolint:errcheck // error not needed for benchmark
	}
}

func Benchmark_Memory_SetAndDelete_Parallel(b *testing.B) {
	testStore := New()
	b.ReportAllocs()
	b.ResetTimer()

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			_ = testStore.Set("john", []byte("doe"), 0) //nolint:errcheck // error not needed for benchmark
			_ = testStore.Delete("john")                //nolint:errcheck // error not needed for benchmark
		}
	})
}

func Benchmark_Memory_SetAndDelete_Asserted(b *testing.B) {
	testStore := New()
	b.ReportAllocs()

	for b.Loop() {
		err := testStore.Set("john", []byte("doe"), 0)
		require.NoError(b, err)

		err = testStore.Delete("john")
		require.NoError(b, err)
	}
}

func Benchmark_Memory_SetAndDelete_Asserted_Parallel(b *testing.B) {
	testStore := New()
	b.ReportAllocs()
	b.ResetTimer()

	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			err := testStore.Set("john", []byte("doe"), 0)
			require.NoError(b, err)

			err = testStore.Delete("john")
			require.NoError(b, err)
		}
	})
}

================================================
FILE: internal/tlstest/tls.go
================================================
package tlstest

import (
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"errors"
	"fmt"
	"math/big"
	"net"
	"time"
)

// errAppendCACert is returned when the freshly generated CA PEM cannot be
// added to the client's root pool.
var errAppendCACert = errors.New("failed to append CA certificate to certificate pool")

// GetTLSConfigs generates TLS configurations for a test server and client that
// trust each other using an in-memory certificate authority.
//
// The server presents a leaf certificate signed by the CA. The client config
// trusts only that CA and carries no client certificate. Both configs require
// at least TLS 1.2.
//
// The leaf certificate lists only the IP SANs 127.0.0.1 and ::1 and no DNS
// names, so clients should dial an IP address rather than a hostname such as
// "localhost".
//
// Each call generates two 4096-bit RSA keys, which is comparatively slow.
//
// NOTE(review): the CA and the leaf share serial number 2021 and an identical
// subject. Go's verifier accepts this, but stricter TLS stacks may reject it;
// confirm before reusing these certificates outside Go tests.
func GetTLSConfigs() (serverTLSConf, clientTLSConf *tls.Config, err error) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming server and client TLS configurations along with the error
	// set up our CA certificate
	ca := &x509.Certificate{
		SerialNumber: big.NewInt(2021),
		Subject: pkix.Name{
			Organization:  []string{"Fiber"},
			Country:       []string{"NL"},
			Province:      []string{""},
			Locality:      []string{"Amsterdam"},
			StreetAddress: []string{"Huidenstraat"},
			PostalCode:    []string{"1011 AA"},
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().AddDate(10, 0, 0),
		IsCA:                  true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
	}

	// create our private and public key
	caPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		return nil, nil, fmt.Errorf("generate CA key: %w", err)
	}

	// create the CA (self-signed: template and parent are the same)
	caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivateKey.PublicKey, caPrivateKey)
	if err != nil {
		return nil, nil, fmt.Errorf("create CA certificate: %w", err)
	}

	// pem encode
	var caPEM bytes.Buffer
	if err = pem.Encode(&caPEM, &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: caBytes,
	}); err != nil {
		return nil, nil, fmt.Errorf("encode CA cert: %w", err)
	}

	// NOTE(review): caPrivKeyPEM is encoded but never read afterwards in this
	// function.
	var caPrivKeyPEM bytes.Buffer
	if err = pem.Encode(&caPrivKeyPEM, &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(caPrivateKey),
	}); err != nil {
		return nil, nil, fmt.Errorf("encode CA private key: %w", err)
	}

	// set up our server certificate
	cert := &x509.Certificate{
		SerialNumber: big.NewInt(2021),
		Subject: pkix.Name{
			Organization:  []string{"Fiber"},
			Country:       []string{"NL"},
			Province:      []string{""},
			Locality:      []string{"Amsterdam"},
			StreetAddress: []string{"Huidenstraat"},
			PostalCode:    []string{"1011 AA"},
		},
		IPAddresses:  []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
		NotBefore:    time.Now(),
		NotAfter:     time.Now().AddDate(10, 0, 0),
		SubjectKeyId: []byte{1, 2, 3, 4, 6},
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
		KeyUsage:     x509.KeyUsageDigitalSignature,
	}

	certPrivateKey, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		return nil, nil, fmt.Errorf("generate server key: %w", err)
	}

	// sign the server certificate with the CA key
	certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivateKey.PublicKey, caPrivateKey)
	if err != nil {
		return nil, nil, fmt.Errorf("create server certificate: %w", err)
	}

	var certPEM bytes.Buffer
	if err = pem.Encode(&certPEM, &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certBytes,
	}); err != nil {
		return nil, nil, fmt.Errorf("encode server cert: %w", err)
	}

	var certPrivateKeyPEM bytes.Buffer
	if err = pem.Encode(&certPrivateKeyPEM, &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(certPrivateKey),
	}); err != nil {
		return nil, nil, fmt.Errorf("encode server private key: %w", err)
	}

	serverCert, err := tls.X509KeyPair(certPEM.Bytes(), certPrivateKeyPEM.Bytes())
	if err != nil {
		return nil, nil, fmt.Errorf("load server key pair: %w", err)
	}

	serverTLSConf = &tls.Config{
		Certificates: []tls.Certificate{serverCert},
		MinVersion:   tls.VersionTLS12,
	}

	// the client trusts only the in-memory CA
	certPool := x509.NewCertPool()
	if ok := certPool.AppendCertsFromPEM(caPEM.Bytes()); !ok {
		return nil, nil, errAppendCACert
	}
	clientTLSConf = &tls.Config{
		RootCAs:    certPool,
		MinVersion: tls.VersionTLS12,
	}

	return serverTLSConf, clientTLSConf, nil
}

================================================
FILE: listen.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io

package fiber

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"net"
	"os"
	"path/filepath"
	"reflect"
	"runtime"
	"slices"
	"sort"
	"strings"
	"text/tabwriter"
	"time"

	"github.com/mattn/go-colorable"
	"github.com/mattn/go-isatty"
	"golang.org/x/crypto/acme/autocert"

	"github.com/gofiber/fiber/v3/log"
)

// Figlet text to show Fiber ASCII art on startup message.
// The %s placeholder is filled in by the caller (not visible here).
// NOTE(review): the banner's line breaks/spacing appear collapsed in this
// copy; confirm against the original ASCII art.
var figletFiberText = ` _______ __ / ____(_) /_ ___ _____ / /_ / / __ \/ _ \/ ___/ / __/ / / /_/ / __/ / /_/ /_/_.___/\___/_/ %s`

const (
	globalIpv4Addr = "0.0.0.0"
)

// ListenConfig is a struct to customize startup of Fiber.
type ListenConfig struct {
	// GracefulContext is a field to shutdown Fiber by given context gracefully.
	//
	// Default: nil
	GracefulContext context.Context `json:"graceful_context"` //nolint:containedctx // It's needed to set context inside Listen.

	// TLSConfigFunc allows customizing tls.Config as you want.
	// Listen calls it only when it builds the tls.Config itself (from
	// CertFile/CertKeyFile or AutoCertManager); it is not called when
	// TLSConfig is set.
	//
	// Default: nil
	TLSConfigFunc func(tlsConfig *tls.Config) `json:"tls_config_func"`

	// TLSConfig allows providing a tls.Config used as the base for TLS settings.
	// This enables external certificate providers via GetCertificate.
	// When set, Listen uses a clone of it and ignores CertFile, CertKeyFile,
	// CertClientFile, AutoCertManager and TLSConfigFunc.
	//
	// Default: nil
	TLSConfig *tls.Config `json:"tls_config"`

	// ListenerAddrFunc is called with the address of the listener that Listen
	// creates, right after it starts listening. App.Listener (custom listener)
	// does not call it.
	//
	// Default: nil
	ListenerAddrFunc func(addr net.Addr) `json:"listener_addr_func"`

	// BeforeServeFunc allows customizing and accessing fiber app before serving the app.
	// If it returns an error, serving is aborted and that error is returned.
	//
	// Default: nil
	BeforeServeFunc func(app *App) error `json:"before_serve_func"`

	// AutoCertManager manages TLS certificates automatically using the ACME protocol.
	// It enables integration with Let's Encrypt or other ACME-compatible providers.
	// It cannot be combined with CertFile/CertKeyFile.
	//
	// Default: nil
	AutoCertManager *autocert.Manager `json:"auto_cert_manager"`

	// Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only), "unix" (Unix Domain Sockets)
	// WARNING: When prefork is set to true, only "tcp4" and "tcp6" can be chosen.
	//
	// Default: NetworkTCP4
	ListenerNetwork string `json:"listener_network"`

	// CertFile is a path of certificate file.
	// If you want to use TLS, you have to enter this field.
	//
	// Default : ""
	CertFile string `json:"cert_file"`

	// CertKeyFile is a path of certificate's private key.
	// If you want to use TLS, you have to enter this field.
	//
	// Default : ""
	CertKeyFile string `json:"cert_key_file"`

	// CertClientFile is a path of client certificate.
	// If you want to use mTLS, you have to enter this field.
	//
	// Default : ""
	CertClientFile string `json:"cert_client_file"`

	// When the graceful shutdown begins, use this field to set the timeout
	// duration. If the timeout is reached, OnPostShutdown will be called with the error.
	// Set to 0 to disable the timeout and wait indefinitely.
	//
	// Default: 10 * time.Second
	// NOTE: the 10s default is only applied when no ListenConfig is passed at
	// all; an explicit ListenConfig that leaves this field at 0 disables the
	// timeout.
	ShutdownTimeout time.Duration `json:"shutdown_timeout"`

	// FileMode to set for Unix Domain Socket (ListenerNetwork must be "unix")
	//
	// Default: 0770
	UnixSocketFileMode os.FileMode `json:"unix_socket_file_mode"`

	// TLSMinVersion allows to set TLS minimum version.
	//
	// Default: tls.VersionTLS12
	// WARNING: TLS1.0 and TLS1.1 versions are not supported; any value other
	// than tls.VersionTLS12 or tls.VersionTLS13 makes Listen panic.
	TLSMinVersion uint16 `json:"tls_min_version"`

	// When set to true, it will not print out the «Fiber» ASCII art and listening address.
	//
	// Default: false
	DisableStartupMessage bool `json:"disable_startup_message"`

	// When set to true, this will spawn multiple Go processes listening on the same port.
	//
	// Default: false
	EnablePrefork bool `json:"enable_prefork"`

	// If set to true, will print all routes with their method, path and handler.
	//
	// Default: false
	EnablePrintRoutes bool `json:"enable_print_routes"`
}

// listenConfigDefault is a function to set default values of ListenConfig.
func listenConfigDefault(config ...ListenConfig) ListenConfig { if len(config) < 1 { return ListenConfig{ TLSMinVersion: tls.VersionTLS12, ListenerNetwork: NetworkTCP4, UnixSocketFileMode: 0o770, ShutdownTimeout: 10 * time.Second, } } cfg := config[0] if cfg.ListenerNetwork == "" { cfg.ListenerNetwork = NetworkTCP4 } if cfg.UnixSocketFileMode == 0 { cfg.UnixSocketFileMode = 0o770 } if cfg.TLSMinVersion == 0 { cfg.TLSMinVersion = tls.VersionTLS12 } if cfg.TLSMinVersion != tls.VersionTLS12 && cfg.TLSMinVersion != tls.VersionTLS13 { panic("unsupported TLS version, please use tls.VersionTLS12 or tls.VersionTLS13") } return cfg } // Listen serves HTTP requests from the given addr. // You should enter custom ListenConfig to customize startup. (TLS, mTLS, prefork...) // // app.Listen(":8080") // app.Listen("127.0.0.1:8080") // app.Listen(":8080", ListenConfig{EnablePrefork: true}) func (app *App) Listen(addr string, config ...ListenConfig) error { cfg := listenConfigDefault(config...) // Configure TLS var tlsConfig *tls.Config if cfg.TLSConfig != nil { tlsConfig = cfg.TLSConfig.Clone() } else { switch { case cfg.AutoCertManager != nil && (cfg.CertFile != "" || cfg.CertKeyFile != ""): return ErrAutoCertWithCertFile case cfg.CertFile != "" && cfg.CertKeyFile != "": cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.CertKeyFile) if err != nil { return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %w", cfg.CertFile, cfg.CertKeyFile, err) } tlsHandler := &TLSHandler{} tlsConfig = &tls.Config{ //nolint:gosec // This is a user input MinVersion: cfg.TLSMinVersion, Certificates: []tls.Certificate{ cert, }, GetCertificate: tlsHandler.GetClientInfo, } if cfg.CertClientFile != "" { clientCACert, err := os.ReadFile(filepath.Clean(cfg.CertClientFile)) if err != nil { return fmt.Errorf("failed to read file: %w", err) } clientCertPool := x509.NewCertPool() clientCertPool.AppendCertsFromPEM(clientCACert) tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert 
tlsConfig.ClientCAs = clientCertPool } // Attach the tlsHandler to the config app.SetTLSHandler(tlsHandler) case cfg.AutoCertManager != nil: tlsConfig = &tls.Config{ //nolint:gosec // This is a user input MinVersion: cfg.TLSMinVersion, GetCertificate: cfg.AutoCertManager.GetCertificate, NextProtos: []string{"http/1.1", "acme-tls/1"}, } default: } if tlsConfig != nil && cfg.TLSConfigFunc != nil { cfg.TLSConfigFunc(tlsConfig) } } // Graceful shutdown if cfg.GracefulContext != nil { ctx, cancel := context.WithCancel(cfg.GracefulContext) defer cancel() go app.gracefulShutdown(ctx, &cfg) } // Start prefork if cfg.EnablePrefork { return app.prefork(addr, tlsConfig, &cfg) } // Configure Listener ln, err := app.createListener(addr, tlsConfig, &cfg) if err != nil { return fmt.Errorf("failed to listen: %w", err) } // prepare the server for the start app.startupProcess() listenData := app.prepareListenData(ln.Addr().String(), getTLSConfig(ln) != nil, &cfg, nil) // run hooks app.runOnListenHooks(listenData) // Print startup message & routes app.printMessages(&cfg, listenData) // Serve if cfg.BeforeServeFunc != nil { if err := cfg.BeforeServeFunc(app); err != nil { return err } } return app.server.Serve(ln) } // Listener serves HTTP requests from the given listener. // You should enter custom ListenConfig to customize startup. (prefork, startup message, graceful shutdown...) func (app *App) Listener(ln net.Listener, config ...ListenConfig) error { cfg := listenConfigDefault(config...) 
// Graceful shutdown if cfg.GracefulContext != nil { ctx, cancel := context.WithCancel(cfg.GracefulContext) defer cancel() go app.gracefulShutdown(ctx, &cfg) } // prepare the server for the start app.startupProcess() listenData := app.prepareListenData(ln.Addr().String(), getTLSConfig(ln) != nil, &cfg, nil) // run hooks app.runOnListenHooks(listenData) // Print startup message & routes app.printMessages(&cfg, listenData) // Serve if cfg.BeforeServeFunc != nil { if err := cfg.BeforeServeFunc(app); err != nil { return err } } // Prefork is not supported for custom listeners if cfg.EnablePrefork { log.Warn("Prefork isn't supported for custom listeners.") } return app.server.Serve(ln) } // Create listener function. func (*App) createListener(addr string, tlsConfig *tls.Config, cfg *ListenConfig) (net.Listener, error) { if cfg == nil { cfg = &ListenConfig{} } var listener net.Listener var err error // Remove previously created socket, to make sure it's possible to listen if cfg.ListenerNetwork == NetworkUnix { if err = os.Remove(addr); err != nil && !os.IsNotExist(err) { return nil, fmt.Errorf("unexpected error when trying to remove unix socket file %q: %w", addr, err) } } if tlsConfig != nil { listener, err = tls.Listen(cfg.ListenerNetwork, addr, tlsConfig) } else { listener, err = net.Listen(cfg.ListenerNetwork, addr) } // Check for error before using the listener if err != nil { // Wrap the error from tls.Listen/net.Listen return nil, fmt.Errorf("failed to listen: %w", err) } if cfg.ListenerNetwork == NetworkUnix { if err = os.Chmod(addr, cfg.UnixSocketFileMode); err != nil { return nil, fmt.Errorf("cannot chmod %#o for %q: %w", cfg.UnixSocketFileMode, addr, err) } } if cfg.ListenerAddrFunc != nil { cfg.ListenerAddrFunc(listener.Addr()) } return listener, nil } func (app *App) printMessages(cfg *ListenConfig, listenData *ListenData) { app.startupMessage(listenData, cfg) if cfg.EnablePrintRoutes { app.printRoutesMessage() } } // prepareListenData creates a ListenData 
instance populated with the application metadata. func (app *App) prepareListenData(addr string, isTLS bool, cfg *ListenConfig, childPIDs []int) *ListenData { //revive:disable-line:flag-parameter // Accepting a bool param named isTLS is fine here host, port := parseAddr(addr) if host == "" { if cfg.ListenerNetwork == NetworkTCP6 { host = "[::1]" } else { host = globalIpv4Addr } } processCount := 1 if cfg.EnablePrefork { processCount = runtime.GOMAXPROCS(0) } var clonedPIDs []int if len(childPIDs) > 0 { clonedPIDs = slices.Clone(childPIDs) } return &ListenData{ Host: host, Port: port, Version: Version, AppName: app.config.AppName, ColorScheme: app.config.ColorScheme, ChildPIDs: clonedPIDs, HandlerCount: int(app.handlersCount), ProcessCount: processCount, PID: os.Getpid(), TLS: isTLS, Prefork: cfg.EnablePrefork, } } // startupMessage renders the startup banner using the provided listener metadata and configuration. func (app *App) startupMessage(listenData *ListenData, cfg *ListenConfig) { preData := newPreStartupMessageData(listenData) colors := listenData.ColorScheme out := colorable.NewColorableStdout() if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) { out = colorable.NewNonColorable(os.Stdout) } // Add default entries scheme := schemeHTTP if listenData.TLS { scheme = schemeHTTPS } if listenData.Host == globalIpv4Addr { preData.AddInfo("server_address", "Server started on", fmt.Sprintf("%s%s://127.0.0.1:%s%s (bound on host 0.0.0.0 and port %s)", colors.Blue, scheme, listenData.Port, colors.Reset, listenData.Port), 10) } else { preData.AddInfo("server_address", "Server started on", fmt.Sprintf("%s%s://%s:%s%s", colors.Blue, scheme, listenData.Host, listenData.Port, colors.Reset), 10) } if listenData.AppName != "" { preData.AddInfo("app_name", "Application name", fmt.Sprintf("\t%s%s%s", colors.Blue, listenData.AppName, colors.Reset), 9) } 
preData.AddInfo("total_handlers", "Total handlers", fmt.Sprintf("\t%s%d%s", colors.Blue, listenData.HandlerCount, colors.Reset), 8) if listenData.Prefork { preData.AddInfo("prefork", "Prefork", fmt.Sprintf("\t\t%sEnabled%s", colors.Blue, colors.Reset), 7) } else { preData.AddInfo("prefork", "Prefork", fmt.Sprintf("\t\t%sDisabled%s", colors.Red, colors.Reset), 6) } preData.AddInfo("pid", "PID", fmt.Sprintf("\t\t%s%d%s", colors.Blue, listenData.PID, colors.Reset), 5) preData.AddInfo("process_count", "Total process count", fmt.Sprintf("%s%d%s", colors.Blue, listenData.ProcessCount, colors.Reset), 4) if err := app.hooks.executeOnPreStartupMessageHooks(preData); err != nil { log.Errorf("failed to call pre startup message hook: %v", err) } disabled := cfg.DisableStartupMessage isChild := IsChild() prevented := preData != nil && preData.PreventDefault defer func() { postData := newPostStartupMessageData(listenData, disabled, isChild, prevented) if err := app.hooks.executeOnPostStartupMessageHooks(postData); err != nil { log.Errorf("failed to call post startup message hook: %v", err) } }() if preData == nil || disabled || isChild || prevented { return } if preData.BannerHeader != "" { header := preData.BannerHeader fmt.Fprint(out, header) if !strings.HasSuffix(header, "\n") { fmt.Fprintln(out) } } else { fmt.Fprintf(out, "%s\n", fmt.Sprintf(figletFiberText, colors.Red+"v"+listenData.Version+colors.Reset)) fmt.Fprintln(out, strings.Repeat("-", 50)) } printStartupEntries(out, &colors, preData.entries) if err := app.logServices(app.servicesStartupCtx(), out, &colors); err != nil { log.Errorf("failed to log services: %v", err) } if listenData.Prefork && len(listenData.ChildPIDs) > 0 { fmt.Fprintf(out, "%sINFO%s Child PIDs: \t\t%s", colors.Green, colors.Reset, colors.Blue) totalPIDs := len(listenData.ChildPIDs) rowTotalPidCount := 10 for i := 0; i < totalPIDs; i += rowTotalPidCount { start := i end := min(i+rowTotalPidCount, totalPIDs) for idx, pid := range 
listenData.ChildPIDs[start:end] { fmt.Fprintf(out, "%d", pid) if idx+1 != len(listenData.ChildPIDs[start:end]) { fmt.Fprint(out, ", ") } } fmt.Fprintf(out, "\n%s", colors.Reset) } } fmt.Fprintf(out, "\n%s", colors.Reset) } func printStartupEntries(out io.Writer, colors *Colors, entries []startupMessageEntry) { // Sort entries by priority (higher priority first) sort.Slice(entries, func(i, j int) bool { return entries[i].priority > entries[j].priority }) for _, entry := range entries { var label string var color string switch entry.level { case StartupMessageLevelWarning: label, color = "WARN", colors.Yellow case StartupMessageLevelError: label, color = errString, colors.Red default: label, color = "INFO", colors.Green } fmt.Fprintf(out, "%s%s%s %s: \t%s%s%s\n", color, label, colors.Reset, entry.title, colors.Blue, entry.value, colors.Reset) } } // printRoutesMessage print all routes with method, path, name and handlers // in a format of table, like this: // method | path | name | handlers // GET | / | routeName | github.com/gofiber/fiber/v3.emptyHandler // HEAD | / | | github.com/gofiber/fiber/v3.emptyHandler func (app *App) printRoutesMessage() { // ignore child processes if IsChild() { return } // Alias colors colors := app.config.ColorScheme var routes []RouteMessage for _, routeStack := range app.stack { for _, route := range routeStack { var newRoute RouteMessage newRoute.name = route.Name newRoute.method = route.Method newRoute.path = route.Path for _, handler := range route.Handlers { newRoute.handlers += runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name() + " " } routes = append(routes, newRoute) } } out := colorable.NewColorableStdout() if os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd())) { out = colorable.NewNonColorable(os.Stdout) } w := tabwriter.NewWriter(out, 1, 1, 1, ' ', 0) // Sort routes by path sort.Slice(routes, func(i, j int) bool { return 
routes[i].path < routes[j].path }) fmt.Fprintf(w, "%smethod\t%s| %spath\t%s| %sname\t%s| %shandlers\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset) fmt.Fprintf(w, "%s------\t%s| %s----\t%s| %s----\t%s| %s--------\t%s\n", colors.Blue, colors.White, colors.Green, colors.White, colors.Cyan, colors.White, colors.Yellow, colors.Reset) for _, route := range routes { fmt.Fprintf(w, "%s%s\t%s| %s%s\t%s| %s%s\t%s| %s%s%s\n", colors.Blue, route.method, colors.White, colors.Green, route.path, colors.White, colors.Cyan, route.name, colors.White, colors.Yellow, route.handlers, colors.Reset) } _ = w.Flush() //nolint:errcheck // It is fine to ignore the error here } // shutdown goroutine func (app *App) gracefulShutdown(ctx context.Context, cfg *ListenConfig) { <-ctx.Done() var err error if cfg != nil && cfg.ShutdownTimeout != 0 { err = app.ShutdownWithTimeout(cfg.ShutdownTimeout) //nolint:contextcheck // TODO: Implement it } else { err = app.Shutdown() //nolint:contextcheck // TODO: Implement it } if err != nil { app.hooks.executeOnPostShutdownHooks(err) return } app.hooks.executeOnPostShutdownHooks(nil) } ================================================ FILE: listen_test.go ================================================ package fiber import ( "bytes" "context" "crypto/tls" "errors" "fmt" "io" "log" //nolint:depguard // TODO: Required to capture output, use internal log package instead "net" "os" "path/filepath" "sync" "testing" "time" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" "golang.org/x/crypto/acme/autocert" ) // go test -run Test_Listen func Test_Listen(t *testing.T) { app := New() require.Error(t, app.Listen(":99999")) go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", 
ListenConfig{DisableStartupMessage: true})) } // go test -run Test_Listen_Graceful_Shutdown func Test_Listen_Graceful_Shutdown(t *testing.T) { t.Run("Basic Graceful Shutdown", func(t *testing.T) { testGracefulShutdown(t, 0) }) t.Run("Shutdown With Timeout", func(t *testing.T) { testGracefulShutdown(t, 500*time.Millisecond) }) t.Run("Shutdown With Timeout Error", func(t *testing.T) { testGracefulShutdown(t, 1*time.Nanosecond) }) } func testGracefulShutdown(t *testing.T, shutdownTimeout time.Duration) { t.Helper() var mu sync.Mutex var shutdown bool var receivedErr error app := New() app.Get("/", func(c Ctx) error { time.Sleep(10 * time.Millisecond) return c.SendString(c.Hostname()) }) ln := fasthttputil.NewInmemoryListener() errs := make(chan error, 1) app.hooks.OnPostShutdown(func(err error) error { mu.Lock() defer mu.Unlock() shutdown = true receivedErr = err return nil }) go func() { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() errs <- app.Listener(ln, ListenConfig{ DisableStartupMessage: true, GracefulContext: ctx, ShutdownTimeout: shutdownTimeout, }) }() require.Eventually(t, func() bool { conn, err := ln.Dial() if err == nil { if err := conn.Close(); err != nil { t.Logf("error closing connection: %v", err) } return true } return false }, time.Second, 100*time.Millisecond, "Server failed to become ready") client := fasthttp.HostClient{ Dial: func(_ string) (net.Conn, error) { return ln.Dial() }, } type testCase struct { expectedErr error expectedBody string name string waitTime time.Duration expectedStatusCode int closeConnection bool } testCases := []testCase{ { name: "Server running normally", waitTime: 500 * time.Millisecond, expectedBody: "example.com", expectedStatusCode: StatusOK, expectedErr: nil, closeConnection: true, }, { name: "Server shutdown complete", waitTime: 3 * time.Second, expectedBody: "", expectedStatusCode: StatusOK, expectedErr: fasthttputil.ErrInmemoryListenerClosed, closeConnection: true, }, } 
for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { time.Sleep(tc.waitTime) req := fasthttp.AcquireRequest() defer fasthttp.ReleaseRequest(req) req.SetRequestURI("http://example.com") resp := fasthttp.AcquireResponse() defer fasthttp.ReleaseResponse(resp) err := client.Do(req, resp) if tc.expectedErr == nil { require.NoError(t, err) require.Equal(t, tc.expectedStatusCode, resp.StatusCode()) require.Equal(t, tc.expectedBody, utils.UnsafeString(resp.Body())) } else { require.ErrorIs(t, err, tc.expectedErr) } }) } mu.Lock() require.True(t, shutdown) if shutdownTimeout == 1*time.Nanosecond { require.Error(t, receivedErr) require.ErrorIs(t, receivedErr, context.DeadlineExceeded) } require.NoError(t, <-errs) mu.Unlock() } // go test -run Test_Listen_Prefork func Test_Listen_Prefork(t *testing.T) { testPreforkMaster = true app := New() require.NoError(t, app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true})) } // go test -run Test_Listen_TLSMinVersion func Test_Listen_TLSMinVersion(t *testing.T) { testPreforkMaster = true app := New() // Invalid TLSMinVersion require.Panics(t, func() { _ = app.Listen(":0", ListenConfig{TLSMinVersion: tls.VersionTLS10}) //nolint:errcheck // ignore error }) require.Panics(t, func() { _ = app.Listen(":0", ListenConfig{TLSMinVersion: tls.VersionTLS11}) //nolint:errcheck // ignore error }) // Prefork require.Panics(t, func() { _ = app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true, TLSMinVersion: tls.VersionTLS10}) //nolint:errcheck // ignore error }) require.Panics(t, func() { _ = app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true, TLSMinVersion: tls.VersionTLS11}) //nolint:errcheck // ignore error }) // Valid TLSMinVersion go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{TLSMinVersion: tls.VersionTLS13})) // Valid TLSMinVersion with Prefork go func() { 
time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{DisableStartupMessage: true, EnablePrefork: true, TLSMinVersion: tls.VersionTLS13})) } // go test -run Test_Listen_TLS func Test_Listen_TLS(t *testing.T) { app := New() // invalid port require.Error(t, app.Listen(":99999", ListenConfig{ CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", })) go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", })) } // go test -run Test_Listen_TLS_Prefork func Test_Listen_TLS_Prefork(t *testing.T) { testPreforkMaster = true app := New() // invalid key file content require.Error(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, EnablePrefork: true, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/template.tmpl", })) go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, EnablePrefork: true, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", })) } // go test -run Test_Listen_MutualTLS func Test_Listen_MutualTLS(t *testing.T) { app := New() // invalid port require.Error(t, app.Listen(":99999", ListenConfig{ CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", CertClientFile: "./.github/testdata/ca-chain.cert.pem", })) go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", CertClientFile: "./.github/testdata/ca-chain.cert.pem", })) } // go test -run Test_Listen_MutualTLS_Prefork func Test_Listen_MutualTLS_Prefork(t *testing.T) { testPreforkMaster = true 
app := New() // invalid key file content require.Error(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, EnablePrefork: true, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/template.html", CertClientFile: "./.github/testdata/ca-chain.cert.pem", })) go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, EnablePrefork: true, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", CertClientFile: "./.github/testdata/ca-chain.cert.pem", })) } // go test -run Test_Listener func Test_Listener(t *testing.T) { app := New() go func() { time.Sleep(500 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() ln := fasthttputil.NewInmemoryListener() require.NoError(t, app.Listener(ln)) } func Test_App_Listener_TLS_Listener(t *testing.T) { // Create tls certificate cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") if err != nil { require.NoError(t, err) } //nolint:gosec // We're in a test so using old ciphers is fine config := &tls.Config{Certificates: []tls.Certificate{cer}} //nolint:gosec // We're in a test so listening on all interfaces is fine ln, err := tls.Listen(NetworkTCP4, ":0", config) require.NoError(t, err) app := New() go func() { time.Sleep(time.Millisecond * 500) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listener(ln)) } // go test -run Test_Listen_TLSConfigFunc func Test_Listen_TLSConfigFunc(t *testing.T) { var callTLSConfig bool app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, TLSConfigFunc: func(_ *tls.Config) { callTLSConfig = true }, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", })) require.True(t, callTLSConfig) } // go test -run Test_Listen_TLSConfig func 
Test_Listen_TLSConfig(t *testing.T) { t.Parallel() cert, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") require.NoError(t, err) run := func(name string, cfg ListenConfig) { t.Run(name, func(t *testing.T) { t.Parallel() app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", cfg)) }) } run("TLSConfig with certificates", ListenConfig{ DisableStartupMessage: true, TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{cert}, }, }) run("TLSConfig with GetCertificate", ListenConfig{ DisableStartupMessage: true, TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { return &cert, nil }, }, }) run("TLSConfig ignores other TLS fields", ListenConfig{ DisableStartupMessage: true, TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{cert}, }, CertFile: "./.github/testdata/does-not-exist.pem", CertKeyFile: "./.github/testdata/does-not-exist.key", CertClientFile: "./.github/testdata/does-not-exist-ca.pem", AutoCertManager: &autocert.Manager{ Prompt: autocert.AcceptTOS, }, }) } // go test -run Test_Listen_TLSCertFiles func Test_Listen_TLSCertFiles(t *testing.T) { t.Parallel() app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", CertClientFile: "./.github/testdata/ssl.pem", })) } // go test -run Test_Listen_TLSConfig_WithTLSConfigFunc func Test_Listen_TLSConfig_WithTLSConfigFunc(t *testing.T) { t.Parallel() cert, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") require.NoError(t, err) var calledTLSConfigFunc bool app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, 
app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, TLSConfig: &tls.Config{ MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{cert}, }, TLSConfigFunc: func(_ *tls.Config) { calledTLSConfigFunc = true }, })) require.False(t, calledTLSConfigFunc) } // go test -run Test_Listen_AutoCert_Conflicts func Test_Listen_AutoCert_Conflicts(t *testing.T) { t.Parallel() app := New() err := app.Listen(":0", ListenConfig{ AutoCertManager: &autocert.Manager{}, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", }) require.ErrorIs(t, err, ErrAutoCertWithCertFile) } // go test -run Test_Listen_ListenerAddrFunc func Test_Listen_ListenerAddrFunc(t *testing.T) { var network string app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, ListenerAddrFunc: func(addr net.Addr) { network = addr.Network() }, CertFile: "./.github/testdata/ssl.pem", CertKeyFile: "./.github/testdata/ssl.key", })) require.Equal(t, "tcp", network) } // go test -run Test_Listen_BeforeServeFunc func Test_Listen_BeforeServeFunc(t *testing.T) { var handlers uint32 app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() wantErr := errors.New("test") require.ErrorIs(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, BeforeServeFunc: func(fiber *App) error { handlers = fiber.HandlersCount() return wantErr }, }), wantErr) require.Zero(t, handlers) } // go test -run Test_Listen_ListenerNetwork func Test_Listen_ListenerNetwork(t *testing.T) { var network string app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, ListenerNetwork: NetworkTCP6, ListenerAddrFunc: func(addr net.Addr) { network = addr.String() }, })) require.Contains(t, 
network, "[::]:") go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(":0", ListenConfig{ DisableStartupMessage: true, ListenerNetwork: NetworkTCP4, ListenerAddrFunc: func(addr net.Addr) { network = addr.String() }, })) require.Contains(t, network, "0.0.0.0:") } // go test -run Test_Listen_ListenerNetwork_Unix func Test_Listen_ListenerNetwork_Unix(t *testing.T) { app := New() app.Get("/test", func(c Ctx) error { return c.SendString("all good") }) var ( f os.FileInfo network string reqErr error resp = &fasthttp.Response{} ) // Create temporary directory for storing socket in tmp, err := os.MkdirTemp(os.TempDir(), "fiber-test") require.NoError(t, err) sock := filepath.Join(tmp, "fiber-test.sock") // Make sure temporary directory is cleaned up defer func() { assert.NoError(t, os.RemoveAll(tmp)) }() // Send request through socket go func() { time.Sleep(1000 * time.Millisecond) client := &fasthttp.HostClient{ Addr: sock, Dial: func(addr string) (net.Conn, error) { return net.Dial("unix", addr) }, } req := &fasthttp.Request{} req.SetRequestURI("http://host/test") reqErr = client.Do(req, resp) assert.NoError(t, app.Shutdown()) }() require.NoError(t, app.Listen(sock, ListenConfig{ DisableStartupMessage: true, ListenerNetwork: NetworkUnix, UnixSocketFileMode: 0o666, ListenerAddrFunc: func(addr net.Addr) { network = addr.String() f, err = os.Stat(network) }, })) // Verify that listening and setting permissions works correctly require.Equal(t, sock, network) require.NoError(t, err) require.Equal(t, os.FileMode(0o666), f.Mode().Perm()) // Verify that request was successful require.NoError(t, reqErr) require.Equal(t, 200, resp.StatusCode()) require.Equal(t, "all good", string(resp.Body())) } // go test -run Test_Listen_Master_Process_Show_Startup_Message func Test_Listen_Master_Process_Show_Startup_Message(t *testing.T) { cfg := ListenConfig{ EnablePrefork: true, } ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0") 
require.NoError(t, err) addr, ok := ln.Addr().(*net.TCPAddr) require.True(t, ok) port := addr.Port require.NoError(t, ln.Close()) childTemplate := []int{11111, 22222, 33333, 44444, 55555, 60000} childPIDs := make([]int, 0, len(childTemplate)*10) for range 10 { childPIDs = append(childPIDs, childTemplate...) } app := New() listenData := app.prepareListenData(fmt.Sprintf(":%d", port), true, &cfg, childPIDs) startupMessage := captureOutput(func() { app.startupMessage(listenData, &cfg) }) colors := Colors{} require.Contains(t, startupMessage, fmt.Sprintf("https://127.0.0.1:%d", port)) require.Contains(t, startupMessage, fmt.Sprintf("(bound on host 0.0.0.0 and port %d)", port)) require.Contains(t, startupMessage, "Child PIDs") require.Contains(t, startupMessage, "11111, 22222, 33333, 44444, 55555, 60000") require.Contains(t, startupMessage, fmt.Sprintf("Prefork: \t\t\t%sEnabled%s", colors.Blue, colors.Reset)) } // go test -run Test_Listen_Master_Process_Show_Startup_MessageWithAppName func Test_Listen_Master_Process_Show_Startup_MessageWithAppName(t *testing.T) { cfg := ListenConfig{ EnablePrefork: true, } app := New(Config{AppName: "Test App v3.0.0"}) ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) addr, ok := ln.Addr().(*net.TCPAddr) require.True(t, ok) port := addr.Port require.NoError(t, ln.Close()) childTemplate := []int{11111, 22222, 33333, 44444, 55555, 60000} childPIDs := make([]int, 0, len(childTemplate)*10) for range 10 { childPIDs = append(childPIDs, childTemplate...) 
} listenData := app.prepareListenData(fmt.Sprintf(":%d", port), true, &cfg, childPIDs) startupMessage := captureOutput(func() { app.startupMessage(listenData, &cfg) }) require.Equal(t, "Test App v3.0.0", app.Config().AppName) require.Contains(t, startupMessage, app.Config().AppName) } // go test -run Test_Listen_Master_Process_Show_Startup_MessageWithAppNameNonAscii func Test_Listen_Master_Process_Show_Startup_MessageWithAppNameNonAscii(t *testing.T) { cfg := ListenConfig{ EnablePrefork: true, } appName := "Serveur de vérification des données" app := New(Config{AppName: appName}) ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) addr, ok := ln.Addr().(*net.TCPAddr) require.True(t, ok) port := addr.Port require.NoError(t, ln.Close()) listenData := app.prepareListenData(fmt.Sprintf(":%d", port), false, &cfg, nil) startupMessage := captureOutput(func() { app.startupMessage(listenData, &cfg) }) require.Contains(t, startupMessage, "Serveur de vérification des données") } // go test -run Test_Listen_Master_Process_Show_Startup_MessageWithDisabledPreforkAndCustomEndpoint func Test_Listen_Master_Process_Show_Startup_MessageWithDisabledPreforkAndCustomEndpoint(t *testing.T) { cfg := ListenConfig{ EnablePrefork: false, } appName := "Fiber Example Application" app := New(Config{AppName: appName}) ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) addr, ok := ln.Addr().(*net.TCPAddr) require.True(t, ok) port := addr.Port require.NoError(t, ln.Close()) listenData := app.prepareListenData(fmt.Sprintf("server.com:%d", port), true, &cfg, nil) startupMessage := captureOutput(func() { app.startupMessage(listenData, &cfg) }) colors := Colors{} require.Contains(t, startupMessage, fmt.Sprintf("%sINFO%s", colors.Green, colors.Reset)) require.Contains(t, startupMessage, fmt.Sprintf("%s%s%s", colors.Blue, appName, colors.Reset)) expectedURL := fmt.Sprintf("https://server.com:%d", port) require.Contains(t, startupMessage, fmt.Sprintf("%s%s%s", 
colors.Blue, expectedURL, colors.Reset)) require.Contains(t, startupMessage, fmt.Sprintf("Prefork: \t\t\t%sDisabled%s", colors.Red, colors.Reset)) } func Test_StartupMessageCustomization(t *testing.T) { cfg := ListenConfig{} app := New() listenData := app.prepareListenData(":8080", false, &cfg, nil) app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error { data.BannerHeader = "FOOBER v98\n-------" data.ResetEntries() data.AddInfo("git_hash", "Git hash", "abc123", 3) data.AddInfo("version", "Version", "v98", 2) return nil }) var post PostStartupMessageData app.Hooks().OnPostStartupMessage(func(data *PostStartupMessageData) error { post = *data return nil }) startupMessage := captureOutput(func() { app.startupMessage(listenData, &cfg) }) require.Contains(t, startupMessage, "FOOBER v98") require.Contains(t, startupMessage, "Git hash: \tabc123") require.Contains(t, startupMessage, "Version: \tv98") require.NotContains(t, startupMessage, "Server started on:") require.NotContains(t, startupMessage, "Prefork:") require.False(t, post.Disabled) require.False(t, post.IsChild) require.False(t, post.Prevented) } func Test_StartupMessageDisabledPostHook(t *testing.T) { cfg := ListenConfig{DisableStartupMessage: true} app := New() listenData := app.prepareListenData(":7070", false, &cfg, nil) var post PostStartupMessageData app.Hooks().OnPostStartupMessage(func(data *PostStartupMessageData) error { post = *data return nil }) startupMessage := captureOutput(func() { app.startupMessage(listenData, &cfg) }) require.Empty(t, startupMessage) require.True(t, post.Disabled) require.False(t, post.IsChild) require.False(t, post.Prevented) } func Test_StartupMessagePreventedByHook(t *testing.T) { cfg := ListenConfig{} app := New() listenData := app.prepareListenData(":9090", false, &cfg, nil) app.Hooks().OnPreStartupMessage(func(data *PreStartupMessageData) error { data.PreventDefault = true return nil }) var post PostStartupMessageData 
app.Hooks().OnPostStartupMessage(func(data *PostStartupMessageData) error { post = *data return nil }) startupMessage := captureOutput(func() { app.startupMessage(listenData, &cfg) }) require.Empty(t, startupMessage) require.False(t, post.Disabled) require.False(t, post.IsChild) require.True(t, post.Prevented) } // go test -run Test_Listen_Print_Route func Test_Listen_Print_Route(t *testing.T) { app := New() app.Get("/", emptyHandler).Name("routeName") printRoutesMessage := captureOutput(func() { app.printRoutesMessage() }) require.Contains(t, printRoutesMessage, MethodGet) require.Contains(t, printRoutesMessage, "/") require.Contains(t, printRoutesMessage, "emptyHandler") require.Contains(t, printRoutesMessage, "routeName") } // go test -run Test_Listen_Print_Route_With_Group func Test_Listen_Print_Route_With_Group(t *testing.T) { app := New() app.Get("/", emptyHandler) v1 := app.Group("v1") v1.Get("/test", emptyHandler).Name("v1") v1.Post("/test/fiber", emptyHandler) v1.Put("/test/fiber/*", emptyHandler) printRoutesMessage := captureOutput(func() { app.printRoutesMessage() }) require.Contains(t, printRoutesMessage, MethodGet) require.Contains(t, printRoutesMessage, "/") require.Contains(t, printRoutesMessage, "emptyHandler") require.Contains(t, printRoutesMessage, "/v1/test") require.Contains(t, printRoutesMessage, "POST") require.Contains(t, printRoutesMessage, "/v1/test/fiber") require.Contains(t, printRoutesMessage, "PUT") require.Contains(t, printRoutesMessage, "/v1/test/fiber/*") } func captureOutput(f func()) string { reader, writer, err := os.Pipe() if err != nil { panic(err) } stdout := os.Stdout stderr := os.Stderr defer func() { os.Stdout = stdout os.Stderr = stderr log.SetOutput(os.Stderr) }() os.Stdout = writer os.Stderr = writer log.SetOutput(writer) out := make(chan string) go func() { var buf bytes.Buffer _, copyErr := io.Copy(&buf, reader) if copyErr != nil { panic(copyErr) } out <- buf.String() // this out channel helps in synchronization }() f() 
err = writer.Close() if err != nil { panic(err) } return <-out } func emptyHandler(_ Ctx) error { return nil } ================================================ FILE: log/default.go ================================================ package log import ( "context" "fmt" "io" "log" "os" "github.com/gofiber/utils/v2" "github.com/valyala/bytebufferpool" ) var _ AllLogger[*log.Logger] = (*defaultLogger)(nil) type defaultLogger struct { stdlog *log.Logger level Level depth int } // privateLog logs a message at a given level log the default logger. // when the level is fatal, it will exit the program. func (l *defaultLogger) privateLog(lv Level, fmtArgs []any) { if l.level > lv { return } level := lv.toString() buf := bytebufferpool.Get() buf.WriteString(level) fmt.Fprint(buf, fmtArgs...) _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error if lv == LevelPanic { panic(buf.String()) } buf.Reset() bytebufferpool.Put(buf) if lv == LevelFatal { os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called } } // privateLogf logs a formatted message at a given level log the default logger. // when the level is fatal, it will exit the program. func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []any) { if l.level > lv { return } level := lv.toString() buf := bytebufferpool.Get() buf.WriteString(level) if len(fmtArgs) > 0 { _, _ = fmt.Fprintf(buf, format, fmtArgs...) } else { _, _ = fmt.Fprint(buf, format) } _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error if lv == LevelPanic { panic(buf.String()) } buf.Reset() bytebufferpool.Put(buf) if lv == LevelFatal { os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called } } // privateLogw logs a message at a given level log the default logger. // when the level is fatal, it will exit the program. 
func (l *defaultLogger) privateLogw(lv Level, format string, keysAndValues []any) { if l.level > lv { return } level := lv.toString() buf := bytebufferpool.Get() buf.WriteString(level) // Write format privateLog buffer if format != "" { buf.WriteString(format) } // Write keys and values privateLog buffer if len(keysAndValues) > 0 { if (len(keysAndValues) & 1) == 1 { keysAndValues = append(keysAndValues, "KEYVALS UNPAIRED") } for i := 0; i < len(keysAndValues); i += 2 { if i > 0 || format != "" { buf.WriteByte(' ') } switch key := keysAndValues[i].(type) { case string: buf.WriteString(key) default: _, _ = fmt.Fprint(buf, key) } buf.WriteByte('=') buf.WriteString(utils.ToString(keysAndValues[i+1])) } } _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error if lv == LevelPanic { panic(buf.String()) } buf.Reset() bytebufferpool.Put(buf) if lv == LevelFatal { os.Exit(1) //nolint:revive // we want to exit the program when Fatal is called } } // Trace logs the given values at trace level. func (l *defaultLogger) Trace(v ...any) { l.privateLog(LevelTrace, v) } // Debug logs the given values at debug level. func (l *defaultLogger) Debug(v ...any) { l.privateLog(LevelDebug, v) } // Info logs the given values at info level. func (l *defaultLogger) Info(v ...any) { l.privateLog(LevelInfo, v) } // Warn logs the given values at warn level. func (l *defaultLogger) Warn(v ...any) { l.privateLog(LevelWarn, v) } // Error logs the given values at error level. func (l *defaultLogger) Error(v ...any) { l.privateLog(LevelError, v) } // Fatal logs the given values at fatal level and terminates the process. func (l *defaultLogger) Fatal(v ...any) { l.privateLog(LevelFatal, v) } // Panic logs the given values at panic level and panics. func (l *defaultLogger) Panic(v ...any) { l.privateLog(LevelPanic, v) } // Tracef formats according to a format specifier and logs at trace level. 
func (l *defaultLogger) Tracef(format string, args ...any) {
	l.privateLogf(LevelTrace, format, args)
}

// Debugf renders format with args and emits the result at LevelDebug.
func (l *defaultLogger) Debugf(format string, args ...any) {
	l.privateLogf(LevelDebug, format, args)
}

// Infof renders format with args and emits the result at LevelInfo.
func (l *defaultLogger) Infof(format string, args ...any) {
	l.privateLogf(LevelInfo, format, args)
}

// Warnf renders format with args and emits the result at LevelWarn.
func (l *defaultLogger) Warnf(format string, args ...any) {
	l.privateLogf(LevelWarn, format, args)
}

// Errorf renders format with args and emits the result at LevelError.
func (l *defaultLogger) Errorf(format string, args ...any) {
	l.privateLogf(LevelError, format, args)
}

// Fatalf renders format with args, emits the result at LevelFatal and then
// exits the process with status 1.
func (l *defaultLogger) Fatalf(format string, args ...any) {
	l.privateLogf(LevelFatal, format, args)
}

// Panicf renders format with args, emits the result at LevelPanic and then
// panics with the rendered line.
func (l *defaultLogger) Panicf(format string, args ...any) {
	l.privateLogf(LevelPanic, format, args)
}

// Tracew emits msg at LevelTrace followed by the key=value pairs in kv.
func (l *defaultLogger) Tracew(msg string, kv ...any) {
	l.privateLogw(LevelTrace, msg, kv)
}

// Debugw emits msg at LevelDebug followed by the key=value pairs in kv.
func (l *defaultLogger) Debugw(msg string, kv ...any) {
	l.privateLogw(LevelDebug, msg, kv)
}

// Infow emits msg at LevelInfo followed by the key=value pairs in kv.
func (l *defaultLogger) Infow(msg string, kv ...any) {
	l.privateLogw(LevelInfo, msg, kv)
}

// Warnw emits msg at LevelWarn followed by the key=value pairs in kv.
func (l *defaultLogger) Warnw(msg string, kv ...any) {
	l.privateLogw(LevelWarn, msg, kv)
}

// Errorw emits msg at LevelError followed by the given key/value pairs.
func (l *defaultLogger) Errorw(msg string, keysAndValues ...any) {
	l.privateLogw(LevelError, msg, keysAndValues)
}

// Fatalw logs at fatal level with a message and key/value pairs, then terminates the process.
func (l *defaultLogger) Fatalw(msg string, keysAndValues ...any) {
	l.privateLogw(LevelFatal, msg, keysAndValues)
}

// Panicw logs at panic level with a message and key/value pairs, then panics.
func (l *defaultLogger) Panicw(msg string, keysAndValues ...any) {
	l.privateLogw(LevelPanic, msg, keysAndValues)
}

// WithContext returns a new logger that shares l's underlying *log.Logger and
// copies its level, but uses a call depth one smaller than l. The context
// argument itself is ignored.
//
// The depth is reduced because the returned logger is called directly by user
// code, whereas the package-level helpers (Info, Infof, ...) add one extra
// wrapper frame that the base depth accounts for; this keeps the file:line
// reported under log.Lshortfile pointing at the caller.
//
// NOTE(review): the level is copied, so a later SetLevel on l does not affect
// loggers already returned by WithContext (SetOutput does, via the shared stdlog).
func (l *defaultLogger) WithContext(_ context.Context) CommonLogger {
	return &defaultLogger{
		stdlog: l.stdlog,
		level:  l.level,
		depth:  l.depth - 1,
	}
}

// SetLevel updates the minimum level that will be emitted by the logger.
// Messages below this level are dropped before any formatting happens.
func (l *defaultLogger) SetLevel(level Level) {
	l.level = level
}

// SetOutput replaces the underlying writer used by the logger. Loggers
// derived via WithContext share the same *log.Logger and are redirected too.
func (l *defaultLogger) SetOutput(writer io.Writer) {
	l.stdlog.SetOutput(writer)
}

// Logger returns the logger instance. It can be used to adjust the logger configurations in case of need.
func (l *defaultLogger) Logger() *log.Logger {
	return l.stdlog
}

// DefaultLogger returns the package-level logger if it implements
// AllLogger[T], and nil otherwise (for example when T does not match the type
// parameter of the logger installed with SetLogger).
func DefaultLogger[T any]() AllLogger[T] {
	if l, ok := logger.(AllLogger[T]); ok {
		return l
	}
	return nil
}

================================================
FILE: log/default_test.go
================================================
package log

import (
	"bytes"
	"context"
	"log"
	"os"
	"testing"

	"github.com/stretchr/testify/require"
)

const work = "work" // word interpolated by the format-based tests

func initDefaultLogger() { // replaces the package-level logger; tests using it must not run in parallel
	logger = &defaultLogger{
		stdlog: log.New(os.Stderr, "", 0),
		depth:  4,
	}
}

type byteSliceWriter struct { // in-memory io.Writer that records every byte written to it
	b []byte
}

func (w *byteSliceWriter) Write(p []byte) (int, error) {
	w.b = append(w.b, p...)
	return len(p), nil
}

func Test_WithContextCaller(t *testing.T) {
	logger = &defaultLogger{
		stdlog: log.New(os.Stderr, "", log.Lshortfile),
		depth:  4,
	}

	var w byteSliceWriter
	SetOutput(&w)
	ctx := context.TODO()

	WithContext(ctx).Info("")
	Info("")
	// The expected output embeds the source line numbers of the two calls
	// above (41 and 42); moving those calls requires updating this string.
	require.Equal(t, "default_test.go:41: [Info] \ndefault_test.go:42: [Info] \n", string(w.b))
}

// Test_DefaultLogger checks the prefix and message produced by each
// package-level Print-style helper; Panic must also panic after writing.
func Test_DefaultLogger(t *testing.T) {
	initDefaultLogger()

	var w byteSliceWriter
	SetOutput(&w)

	Trace("trace work")
	Debug("received work order")
	Info("starting work")
	Warn("work may fail")
	Error("work failed")

	require.Panics(t, func() {
		Panic("work panic")
	})

	require.Equal(t, "[Trace] trace work\n"+
		"[Debug] received work order\n"+
		"[Info] starting work\n"+
		"[Warn] work may fail\n"+
		"[Error] work failed\n"+
		"[Panic] work panic\n", string(w.b))
}

// Test_DefaultFormatLogger is the Printf-style counterpart of Test_DefaultLogger.
func Test_DefaultFormatLogger(t *testing.T) {
	initDefaultLogger()

	var w byteSliceWriter
	SetOutput(&w)

	Tracef("trace %s", work)
	Debugf("received %s order", work)
	Infof("starting %s", work)
	Warnf("%s may fail", work)
	Errorf("%s failed", work)

	require.Panics(t, func() {
		Panicf("%s panic", work)
	})

	require.Equal(t, "[Trace] trace work\n"+
		"[Debug] received work order\n"+
		"[Info] starting work\n"+
		"[Warn] work may fail\n"+
		"[Error] work failed\n"+
		"[Panic] work panic\n", string(w.b))
}

// Test_CtxLogger checks that a logger obtained from WithContext writes to the
// same output with the same formatting as the package-level helpers.
func Test_CtxLogger(t *testing.T) {
	initDefaultLogger()

	var w byteSliceWriter
	SetOutput(&w)

	ctx := context.Background()

	WithContext(ctx).Tracef("trace %s", work)
	WithContext(ctx).Debugf("received %s order", work)
	WithContext(ctx).Infof("starting %s", work)
	WithContext(ctx).Warnf("%s may fail", work)
	WithContext(ctx).Errorf("%s failed %d", work, 50)

	require.Panics(t, func() {
		WithContext(ctx).Panicf("%s panic", work)
	})

	require.Equal(t, "[Trace] trace work\n"+
		"[Debug] received work order\n"+
		"[Info] starting work\n"+
		"[Warn] work may fail\n"+
		"[Error] work failed 50\n"+
		"[Panic] work panic\n", string(w.b))
}

// Test_LogfKeyAndValues drives privateLogw directly, including the
// odd-length case that must render "KEYVALS UNPAIRED" as the missing value.
func Test_LogfKeyAndValues(t *testing.T) {
	tests := []struct {
		name          string
		format        string
		wantOutput    string
		fmtArgs       []any // NOTE(review): never read by the loop below; every case sets it to nil
		keysAndValues []any
		level         Level
	}{
		{
			name:          "test logf with debug level and key-values",
			level:         LevelDebug,
			format:        "",
			fmtArgs:       nil,
			keysAndValues: []any{"name", "Bob", "age", 30},
			wantOutput:    "[Debug] name=Bob age=30\n",
		},
		{
			name:          "test logf with info level and key-values",
			level:         LevelInfo,
			format:        "",
			fmtArgs:       nil,
			keysAndValues: []any{"status", "ok", "code", 200},
			wantOutput:    "[Info] status=ok code=200\n",
		},
		{
			name:          "test logf with warn level and key-values",
			level:         LevelWarn,
			format:        "",
			fmtArgs:       nil,
			keysAndValues: []any{"error", "not found", "id", 123},
			wantOutput:    "[Warn] error=not found id=123\n",
		},
		{
			name:          "test logf with format and key-values",
			level:         LevelWarn,
			format:        "test",
			fmtArgs:       nil,
			keysAndValues: []any{"error", "not found", "id", 123},
			wantOutput:    "[Warn] test error=not found id=123\n",
		},
		{
			name:          "test logf with one key",
			level:         LevelWarn,
			format:        "",
			fmtArgs:       nil,
			keysAndValues: []any{"error"},
			wantOutput:    "[Warn] error=KEYVALS UNPAIRED\n",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Each case gets its own buffer-backed logger so output is isolated.
			var buf bytes.Buffer
			l := &defaultLogger{
				stdlog: log.New(&buf, "", 0),
				level:  tt.level,
				depth:  4,
			}
			l.privateLogw(tt.level, tt.format, tt.keysAndValues)
			require.Equal(t, tt.wantOutput, buf.String())
		})
	}
}

// Test_SetLevel checks SetLevel for every defined level and the "[?N] "
// fallback prefix for an out-of-range level.
func Test_SetLevel(t *testing.T) {
	setLogger := &defaultLogger{
		stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds),
		depth:  4,
	}

	setLogger.SetLevel(LevelTrace)
	require.Equal(t, LevelTrace, setLogger.level)
	require.Equal(t, LevelTrace.toString(), setLogger.level.toString())

	setLogger.SetLevel(LevelDebug)
	require.Equal(t, LevelDebug, setLogger.level)
	require.Equal(t, LevelDebug.toString(), setLogger.level.toString())

	setLogger.SetLevel(LevelInfo)
	require.Equal(t, LevelInfo, setLogger.level)
	require.Equal(t, LevelInfo.toString(), setLogger.level.toString())

	setLogger.SetLevel(LevelWarn)
	require.Equal(t, LevelWarn, setLogger.level)
	require.Equal(t, LevelWarn.toString(), setLogger.level.toString())

	setLogger.SetLevel(LevelError)
	require.Equal(t, LevelError, setLogger.level)
	require.Equal(t, LevelError.toString(), setLogger.level.toString())

	setLogger.SetLevel(LevelFatal)
	require.Equal(t, LevelFatal, setLogger.level)
	require.Equal(t, LevelFatal.toString(), setLogger.level.toString())

	setLogger.SetLevel(LevelPanic)
	require.Equal(t, LevelPanic, setLogger.level)
	require.Equal(t, LevelPanic.toString(), setLogger.level.toString())

	// 8 is outside the defined range, so toString falls back to "[?8] ".
	setLogger.SetLevel(8)
	require.Equal(t, 8, int(setLogger.level))
	require.Equal(t, "[?8] ", setLogger.level.toString())
}

// Test_Logger checks that Logger exposes the same *log.Logger, so changes made
// through it are visible on the wrapper.
func Test_Logger(t *testing.T) {
	underlyingLogger := log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds)
	setLogger := &defaultLogger{
		stdlog: underlyingLogger,
		depth:  4,
	}

	require.Equal(t, underlyingLogger, setLogger.Logger())

	// NOTE(review): this local shadows the package-level logger variable.
	logger := setLogger.Logger()
	logger.SetFlags(log.LstdFlags | log.Lshortfile | log.Lmicroseconds)
	require.Equal(t, log.LstdFlags|log.Lshortfile|log.Lmicroseconds, setLogger.stdlog.Flags())
}

// Test_Debugw checks the "msg key=value ..." rendering of Debugw.
func Test_Debugw(t *testing.T) {
	initDefaultLogger()

	var w byteSliceWriter
	SetOutput(&w)

	msg := "debug work"
	keysAndValues := []any{"key1", "value1", "key2", "value2"}

	Debugw(msg, keysAndValues...)

	require.Equal(t, "[Debug] debug work key1=value1 key2=value2\n", string(w.b))
}

// Test_Infow checks the "msg key=value ..." rendering of Infow.
func Test_Infow(t *testing.T) {
	initDefaultLogger()

	var w byteSliceWriter
	SetOutput(&w)

	msg := "info work"
	keysAndValues := []any{"key1", "value1", "key2", "value2"}

	Infow(msg, keysAndValues...)

	require.Equal(t, "[Info] info work key1=value1 key2=value2\n", string(w.b))
}

// Test_Warnw checks the "msg key=value ..." rendering of Warnw.
func Test_Warnw(t *testing.T) {
	initDefaultLogger()

	var w byteSliceWriter
	SetOutput(&w)

	msg := "warning work"
	keysAndValues := []any{"key1", "value1", "key2", "value2"}

	Warnw(msg, keysAndValues...)

	require.Equal(t, "[Warn] warning work key1=value1 key2=value2\n", string(w.b))
}

// Test_Errorw checks the "msg key=value ..." rendering of Errorw.
func Test_Errorw(t *testing.T) {
	initDefaultLogger()

	var w byteSliceWriter
	SetOutput(&w)

	msg := "error work"
	keysAndValues := []any{"key1", "value1", "key2", "value2"}

	Errorw(msg, keysAndValues...)

	require.Equal(t, "[Error] error work key1=value1 key2=value2\n", string(w.b))
}

// Test_Panicw checks that Panicw writes the line before panicking.
func Test_Panicw(t *testing.T) {
	initDefaultLogger()

	var w byteSliceWriter
	SetOutput(&w)

	msg := "panic work"
	keysAndValues := []any{"key1", "value1", "key2", "value2"}

	require.Panics(t, func() {
		Panicw(msg, keysAndValues...)
	})

	require.Equal(t, "[Panic] panic work key1=value1 key2=value2\n", string(w.b))
}

// Test_Tracew checks the "msg key=value ..." rendering of Tracew.
func Test_Tracew(t *testing.T) {
	initDefaultLogger()

	var w byteSliceWriter
	SetOutput(&w)

	msg := "trace work"
	keysAndValues := []any{"key1", "value1", "key2", "value2"}

	Tracew(msg, keysAndValues...)

	require.Equal(t, "[Trace] trace work key1=value1 key2=value2\n", string(w.b))
}

// stringKey is a fmt.Stringer used to exercise non-string keys.
type stringKey struct {
	value string
}

func (k stringKey) String() string {
	return "key:" + k.value
}

// Test_DefaultLoggerNonStringKeys checks that int and fmt.Stringer keys are
// rendered with fmt rather than causing a panic. It uses private loggers, so
// it is safe to run in parallel.
func Test_DefaultLoggerNonStringKeys(t *testing.T) {
	t.Parallel()

	t.Run("Tracew with non-string keys", func(t *testing.T) {
		t.Parallel()

		var buf bytes.Buffer
		l := &defaultLogger{
			stdlog: log.New(&buf, "", 0),
			level:  LevelTrace,
			depth:  4,
		}

		require.NotPanics(t, func() {
			l.Tracew("trace", 123, "value", stringKey{value: "alpha"}, 42)
		})

		require.Equal(t, "[Trace] trace 123=value key:alpha=42\n", buf.String())
	})

	t.Run("Infow with non-string keys", func(t *testing.T) {
		t.Parallel()

		var buf bytes.Buffer
		l := &defaultLogger{
			stdlog: log.New(&buf, "", 0),
			level:  LevelTrace,
			depth:  4,
		}

		require.NotPanics(t, func() {
			l.Infow("info", 456, "value", stringKey{value: "beta"}, 7)
		})

		require.Equal(t, "[Info] info 456=value key:beta=7\n", buf.String())
	})
}

// Benchmark_LogfKeyAndValues measures privateLogw for the same inputs as
// Test_LogfKeyAndValues.
func Benchmark_LogfKeyAndValues(b *testing.B) {
	tests := []struct {
		name          string
		format        string
		keysAndValues []any
		level         Level
	}{
		{
			name:          "test logf with debug level and key-values",
			level:         LevelDebug,
			format:        "",
			keysAndValues: []any{"name", "Bob", "age", 30},
		},
		{
			name:          "test logf with info level and key-values",
			level:         LevelInfo,
			format:        "",
			keysAndValues: []any{"status", "ok", "code", 200},
		},
		{
			name:          "test logf with warn level and key-values",
			level:         LevelWarn,
			format:        "",
			keysAndValues: []any{"error", "not found", "id", 123},
		},
		{
			name:          "test logf with format and key-values",
			level:         LevelWarn,
			format:        "test",
			keysAndValues: []any{"error", "not found", "id", 123},
		},
		{
			name:          "test logf with one key",
			level:         LevelWarn,
			format:        "",
			keysAndValues: []any{"error"},
		},
	}

	for _, tt := range tests {
		b.Run(tt.name, func(bb *testing.B) {
			var buf bytes.Buffer
			l := &defaultLogger{
				stdlog: log.New(&buf, "", 0),
				level:  tt.level,
				depth:  4,
			}

			bb.ReportAllocs()

			for bb.Loop() {
				l.privateLogw(tt.level, tt.format, tt.keysAndValues)
			}
		})
	}
}

// Benchmark_LogfKeyAndValues_Parallel is the RunParallel variant; each
// goroutine builds its own logger and buffer.
func Benchmark_LogfKeyAndValues_Parallel(b *testing.B) {
	tests := []struct {
		name          string
		format        string
		keysAndValues []any
		level         Level
	}{
		{
			name:          "debug level with key-values",
			level:         LevelDebug,
			format:        "",
			keysAndValues: []any{"name", "Bob", "age", 30},
		},
		{
			name:          "info level with key-values",
			level:         LevelInfo,
			format:        "",
			keysAndValues: []any{"status", "ok", "code", 200},
		},
		{
			name:          "warn level with key-values",
			level:         LevelWarn,
			format:        "",
			keysAndValues: []any{"error", "not found", "id", 123},
		},
		{
			name:          "warn level with format and key-values",
			level:         LevelWarn,
			format:        "test",
			keysAndValues: []any{"error", "not found", "id", 123},
		},
		{
			name:          "warn level with one key",
			level:         LevelWarn,
			format:        "",
			keysAndValues: []any{"error"},
		},
	}

	for _, tt := range tests {
		b.Run(tt.name, func(bb *testing.B) {
			bb.ReportAllocs()
			bb.ResetTimer()
			bb.RunParallel(func(pb *testing.PB) {
				var buf bytes.Buffer
				l := &defaultLogger{
					stdlog: log.New(&buf, "", 0),
					level:  tt.level,
					depth:  4,
				}
				for pb.Next() {
					l.privateLogw(tt.level, tt.format, tt.keysAndValues)
				}
			})
		})
	}
}

================================================
FILE: log/fiberlog.go
================================================
package log

import (
	"context"
	"io"
)

// Fatal calls Fatal on the package-level logger (see SetLogger). The built-in
// default logger terminates the process with os.Exit(1) after writing.
func Fatal(v ...any) {
	logger.Fatal(v...)
}

// Error calls the package-level logger's Error method.
func Error(v ...any) {
	logger.Error(v...)
}

// Warn calls the package-level logger's Warn method.
func Warn(v ...any) {
	logger.Warn(v...)
}

// Info calls the package-level logger's Info method.
func Info(v ...any) {
	logger.Info(v...)
}

// Debug calls the package-level logger's Debug method.
func Debug(v ...any) {
	logger.Debug(v...)
}

// Trace calls the package-level logger's Trace method.
func Trace(v ...any) {
	logger.Trace(v...)
}

// Panic calls the package-level logger's Panic method. The built-in default
// logger panics with the formatted line after writing it.
func Panic(v ...any) {
	logger.Panic(v...)
}

// Fatalf calls the package-level logger's Fatalf method. The built-in default
// logger terminates the process with os.Exit(1) after writing.
func Fatalf(format string, v ...any) {
	logger.Fatalf(format, v...)
}

// Errorf calls the package-level logger's Errorf method.
func Errorf(format string, v ...any) {
	logger.Errorf(format, v...)
}

// Warnf calls the package-level logger's Warnf method.
func Warnf(format string, v ...any) {
	logger.Warnf(format, v...)
}

// Infof calls the package-level logger's Infof method.
func Infof(format string, v ...any) {
	logger.Infof(format, v...)
}

// Debugf calls the package-level logger's Debugf method.
func Debugf(format string, v ...any) {
	logger.Debugf(format, v...)
}

// Tracef calls the package-level logger's Tracef method.
func Tracef(format string, v ...any) {
	logger.Tracef(format, v...)
}

// Panicf calls the package-level logger's Panicf method.
func Panicf(format string, v ...any) {
	logger.Panicf(format, v...)
}

// Tracew logs a message with some additional context. The variadic arguments
// are interpreted as alternating key/value pairs.
func Tracew(msg string, keysAndValues ...any) {
	logger.Tracew(msg, keysAndValues...)
}

// Debugw logs a message with some additional context. The variadic arguments
// are interpreted as alternating key/value pairs.
func Debugw(msg string, keysAndValues ...any) { logger.Debugw(msg, keysAndValues...) } // Infow logs a message with some additional context. The variadic key-value // pairs are treated as they are privateLog With. func Infow(msg string, keysAndValues ...any) { logger.Infow(msg, keysAndValues...) } // Warnw logs a message with some additional context. The variadic key-value // pairs are treated as they are privateLog With. func Warnw(msg string, keysAndValues ...any) { logger.Warnw(msg, keysAndValues...) } // Errorw logs a message with some additional context. The variadic key-value // pairs are treated as they are privateLog With. func Errorw(msg string, keysAndValues ...any) { logger.Errorw(msg, keysAndValues...) } // Fatalw logs a message with some additional context. The variadic key-value // pairs are treated as they are privateLog With. func Fatalw(msg string, keysAndValues ...any) { logger.Fatalw(msg, keysAndValues...) } // Panicw logs a message with some additional context. The variadic key-value // pairs are treated as they are privateLog With. func Panicw(msg string, keysAndValues ...any) { logger.Panicw(msg, keysAndValues...) } // WithContext binds the default logger to the provided context and returns the // contextualized logger. func WithContext(ctx context.Context) CommonLogger { return logger.WithContext(ctx) } // SetLogger sets the default logger and the system logger. // Note that this method is not concurrent-safe and must not be called // after the use of DefaultLogger and global functions from this package. func SetLogger[T any](v AllLogger[T]) { logger = v } // SetOutput sets the output of default logger and system logger. By default, it is stderr. func SetOutput(w io.Writer) { logger.SetOutput(w) } // SetLevel sets the level of logs below which logs will not be output. // The default logger is LevelTrace. // Note that this method is not concurrent-safe. 
func SetLevel(lv Level) {
	logger.SetLevel(lv)
}

================================================
FILE: log/fiberlog_test.go
================================================
package log

import (
	"log"
	"os"
	"testing"

	"github.com/stretchr/testify/require"
)

// Test_DefaultSystemLogger checks that DefaultLogger[*log.Logger] returns the
// currently installed package-level logger.
func Test_DefaultSystemLogger(t *testing.T) {
	t.Parallel()
	defaultL := DefaultLogger[*log.Logger]()

	require.Equal(t, logger, defaultL)
}

// Test_SetLogger checks that SetLogger replaces the package-level logger.
func Test_SetLogger(t *testing.T) {
	setLog := &defaultLogger{
		stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds),
		depth:  6,
	}

	SetLogger(setLog)
	require.Equal(t, logger, setLog)
}

// Test_Fiberlog_SetLevel checks that the package-level SetLevel reaches the
// installed logger.
//
// NOTE(review): mockLogger has a nil stdlog and stays installed after this test;
// any later logging or SetOutput through the package helpers would likely
// panic until another test calls SetLogger or initDefaultLogger.
func Test_Fiberlog_SetLevel(t *testing.T) {
	mockLogger := &defaultLogger{}
	SetLogger(mockLogger)

	// Test cases
	testCases := []struct {
		name     string
		level    Level
		expected Level
	}{
		{
			name:     "Test case 1",
			level:    LevelDebug,
			expected: LevelDebug,
		},
		{
			name:     "Test case 2",
			level:    LevelInfo,
			expected: LevelInfo,
		},
		{
			name:     "Test case 3",
			level:    LevelWarn,
			expected: LevelWarn,
		},
		{
			name:     "Test case 4",
			level:    LevelError,
			expected: LevelError,
		},
		{
			name:     "Test case 5",
			level:    LevelFatal,
			expected: LevelFatal,
		},
	}

	// Run tests
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			SetLevel(tc.level)
			require.Equal(t, tc.expected, mockLogger.level)
		})
	}
}

// Benchmark_DefaultSystemLogger measures the type assertion in DefaultLogger.
func Benchmark_DefaultSystemLogger(b *testing.B) {
	b.ReportAllocs()
	for b.Loop() {
		_ = DefaultLogger[*log.Logger]()
	}
}

// Benchmark_SetLogger measures installing a logger with SetLogger.
func Benchmark_SetLogger(b *testing.B) {
	setLog := &defaultLogger{
		stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds),
		depth:  6,
	}

	b.ReportAllocs()
	for b.Loop() {
		SetLogger(setLog)
	}
}

// Benchmark_Fiberlog_SetLevel measures the package-level SetLevel per level.
func Benchmark_Fiberlog_SetLevel(b *testing.B) {
	mockLogger := &defaultLogger{}
	SetLogger(mockLogger)

	// Test cases
	testCases := []struct {
		name     string
		level    Level
		expected Level
	}{
		{
			name:     "Test case 1",
			level:    LevelDebug,
			expected: LevelDebug,
		},
		{
			name:     "Test case 2",
			level:    LevelInfo,
			expected: LevelInfo,
		},
		{
			name:     "Test case 3",
			level:    LevelWarn,
			expected: LevelWarn,
		},
		{
			name:     "Test case 4",
			level:    LevelError,
			expected: LevelError,
		},
		{
			name:     "Test case 5",
			level:    LevelFatal,
			expected: LevelFatal,
		},
	}

	for _, tc := range testCases {
		b.ReportAllocs()
		b.Run(tc.name, func(b *testing.B) {
			for b.Loop() {
				SetLevel(tc.level)
			}
		})
	}
}

// Benchmark_DefaultSystemLogger_Parallel is the RunParallel variant of
// Benchmark_DefaultSystemLogger (read-only access to the global).
func Benchmark_DefaultSystemLogger_Parallel(b *testing.B) {
	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			_ = DefaultLogger[*log.Logger]()
		}
	})
}

// Benchmark_SetLogger_Parallel calls SetLogger from many goroutines.
//
// NOTE(review): SetLogger is a plain assignment documented as not
// concurrent-safe; running this benchmark with -race is expected to report a race.
func Benchmark_SetLogger_Parallel(b *testing.B) {
	setLog := &defaultLogger{
		stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds),
		depth:  6,
	}

	b.ReportAllocs()
	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			SetLogger(setLog)
		}
	})
}

// Benchmark_Fiberlog_SetLevel_Parallel calls SetLevel from many goroutines.
//
// NOTE(review): SetLevel is documented as not concurrent-safe and writes the
// level field without synchronization; expect -race reports here.
func Benchmark_Fiberlog_SetLevel_Parallel(b *testing.B) {
	mockLogger := &defaultLogger{}
	SetLogger(mockLogger)

	// Test cases
	testCases := []struct {
		name     string
		level    Level
		expected Level
	}{
		{
			name:     "Test case 1",
			level:    LevelDebug,
			expected: LevelDebug,
		},
		{
			name:     "Test case 2",
			level:    LevelInfo,
			expected: LevelInfo,
		},
		{
			name:     "Test case 3",
			level:    LevelWarn,
			expected: LevelWarn,
		},
		{
			name:     "Test case 4",
			level:    LevelError,
			expected: LevelError,
		},
		{
			name:     "Test case 5",
			level:    LevelFatal,
			expected: LevelFatal,
		},
	}

	for _, tc := range testCases {
		b.Run(tc.name+"_Parallel", func(bb *testing.B) {
			bb.ReportAllocs()
			bb.ResetTimer()
			bb.RunParallel(func(pb *testing.PB) {
				for pb.Next() {
					SetLevel(tc.level)
				}
			})
		})
	}
}

================================================
FILE: log/log.go
================================================
package log

import (
	"context"
	"fmt"
	"io"
	"log"
	"os"
)

// baseLogger is the non-generic subset of AllLogger[T] that the package-level
// logger variable needs, so a logger of any type parameter T can be stored in it.
type baseLogger interface { CommonLogger SetLevel(Level) SetOutput(io.Writer) WithContext(ctx context.Context) CommonLogger } var logger baseLogger = &defaultLogger{ stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), depth: 4, } // Logger is a logger interface that provides logging function with levels. type Logger interface { Trace(v ...any) Debug(v ...any) Info(v ...any) Warn(v ...any) Error(v ...any) Fatal(v ...any) Panic(v ...any) } // FormatLogger is a logger interface that output logs with a format. type FormatLogger interface { Tracef(format string, v ...any) Debugf(format string, v ...any) Infof(format string, v ...any) Warnf(format string, v ...any) Errorf(format string, v ...any) Fatalf(format string, v ...any) Panicf(format string, v ...any) } // WithLogger is a logger interface that output logs with a message and key-value pairs. type WithLogger interface { Tracew(msg string, keysAndValues ...any) Debugw(msg string, keysAndValues ...any) Infow(msg string, keysAndValues ...any) Warnw(msg string, keysAndValues ...any) Errorw(msg string, keysAndValues ...any) Fatalw(msg string, keysAndValues ...any) Panicw(msg string, keysAndValues ...any) } // CommonLogger is the set of logging operations available across Fiber's // logging implementations. type CommonLogger interface { Logger FormatLogger WithLogger } // ConfigurableLogger provides methods to config a logger. type ConfigurableLogger[T any] interface { // SetLevel sets logging level. // // Available levels: Trace, Debug, Info, Warn, Error, Fatal, Panic. SetLevel(level Level) // SetOutput sets the logger output. SetOutput(w io.Writer) // Logger returns the logger instance. It can be used to adjust the logger configurations in case of need. Logger() T } // AllLogger is the combination of Logger, FormatLogger, CtxLogger and ConfigurableLogger. 
// Custom logger implementations can be plugged in by implementing AllLogger
// and installing them with SetLogger.
type AllLogger[T any] interface {
	CommonLogger
	ConfigurableLogger[T]

	// WithContext returns a new logger with the given context.
	WithContext(ctx context.Context) CommonLogger
}

// Level defines the priority of a log message.
// When a logger is configured with a level, any log message with a lower
// log level (smaller by integer comparison) will not be output.
type Level int

// The levels of logs, in increasing order of severity. LevelTrace is the
// zero value.
const (
	LevelTrace Level = iota
	LevelDebug
	LevelInfo
	LevelWarn
	LevelError
	LevelFatal
	LevelPanic
)

// strs holds the line prefix for each Level, indexed by the Level value; its
// order must match the const block above.
var strs = []string{
	"[Trace] ",
	"[Debug] ",
	"[Info] ",
	"[Warn] ",
	"[Error] ",
	"[Fatal] ",
	"[Panic] ",
}

// toString returns the line prefix for lv, or "[?N] " for a value outside
// LevelTrace..LevelPanic.
func (lv Level) toString() string {
	if lv >= LevelTrace && lv <= LevelPanic {
		return strs[lv]
	}
	return fmt.Sprintf("[?%d] ", lv)
}

================================================
FILE: middleware/adaptor/adaptor.go
================================================
package adaptor

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"reflect"
	"sync"
	"unsafe"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/utils/v2"
	"github.com/valyala/fasthttp"
	"github.com/valyala/fasthttp/fasthttpadaptor"
)

// disableLogger implements the fasthttp Logger interface and discards log output.
type disableLogger struct{}

// Printf implements the fasthttp Logger interface and discards log output.
func (*disableLogger) Printf(string, ...any) {
}

// ctxPool recycles fasthttp.RequestCtx values used to run Fiber handlers on
// behalf of net/http requests.
var ctxPool = sync.Pool{
	New: func() any {
		return new(fasthttp.RequestCtx)
	},
}

// localContextKey is the key used to store the user's context.Context in the fasthttp request context.
// Adapted http.Handler functions can retrieve this context using r.Context().Value(adaptor.LocalContextKey) var localContextKey = &struct{}{} const bufferSize = 32 * 1024 var bufferPool = sync.Pool{ New: func() any { return new([bufferSize]byte) }, } var ( ErrRemoteAddrEmpty = errors.New("remote address cannot be empty") ErrRemoteAddrTooLong = errors.New("remote address too long") ) // HTTPHandlerFunc wraps net/http handler func to fiber handler func HTTPHandlerFunc(h http.HandlerFunc) fiber.Handler { return HTTPHandler(h) } // HTTPHandler wraps net/http handler to fiber handler func HTTPHandler(h http.Handler) fiber.Handler { handler := fasthttpadaptor.NewFastHTTPHandler(h) return func(c fiber.Ctx) error { handler(c.RequestCtx()) return nil } } // HTTPHandlerWithContext is like HTTPHandler, but additionally stores Fiber’s user context in the request context func HTTPHandlerWithContext(h http.Handler) fiber.Handler { handler := fasthttpadaptor.NewFastHTTPHandler(h) return func(c fiber.Ctx) error { // Store the Fiber user context (c.Context()) in the fasthttp request context // so adapted net/http handlers can retrieve it via adaptor.LocalContextFromHTTPRequest(r) c.RequestCtx().SetUserValue(localContextKey, c.Context()) handler(c.RequestCtx()) return nil } } // LocalContextFromHTTPRequest extracts the Fiber user context previously stored into r.Context() by the adaptor. func LocalContextFromHTTPRequest(r *http.Request) (context.Context, bool) { if r == nil { return nil, false } ctx, err := r.Context().Value(localContextKey).(context.Context) return ctx, err } // ConvertRequest converts a fiber.Ctx to a http.Request. // forServer should be set to true when the http.Request is going to be passed to a http.Handler. 
func ConvertRequest(c fiber.Ctx, forServer bool) (*http.Request, error) { var req http.Request if err := fasthttpadaptor.ConvertRequest(c.RequestCtx(), &req, forServer); err != nil { return nil, err //nolint:wrapcheck // This must not be wrapped } return &req, nil } // CopyContextToFiberContext copies the values of context.Context to a fasthttp.RequestCtx. // This function safely handles struct fields, using unsafe operations only when necessary for unexported fields. // // Deprecated: This function uses reflection and unsafe pointers; consider using explicit context passing. func CopyContextToFiberContext(src any, requestContext *fasthttp.RequestCtx) { if requestContext == nil { return } v := reflect.ValueOf(src) if !v.IsValid() { return } // Deref pointer chains for v.Kind() == reflect.Ptr { if v.IsNil() { return } v = v.Elem() } t := v.Type() if t.Kind() != reflect.Struct { return } // Ensure addressable for safe unsafe-access of unexported fields if !v.CanAddr() { tmp := reflect.New(t) tmp.Elem().Set(v) v = tmp.Elem() } contextValues := v contextKeys := t var lastKey any for i := 0; i < contextValues.NumField(); i++ { reflectValue := contextValues.Field(i) reflectField := contextKeys.Field(i) if reflectField.Name == "noCopy" { break } // Avoid unsafe access for unexported fields; use safe reflection where possible if !reflectValue.CanInterface() { /* #nosec */ reflectValue = reflect.NewAt(reflectValue.Type(), unsafe.Pointer(reflectValue.UnsafeAddr())).Elem() } switch reflectField.Name { case "Context": CopyContextToFiberContext(reflectValue.Interface(), requestContext) case "key": lastKey = reflectValue.Interface() case "val": if lastKey != nil { requestContext.SetUserValue(lastKey, reflectValue.Interface()) lastKey = nil } default: continue } } } // HTTPMiddleware wraps net/http middleware to fiber middleware func HTTPMiddleware(mw func(http.Handler) http.Handler) fiber.Handler { return func(c fiber.Ctx) error { var next bool nextHandler := 
http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { next = true c.Request().Header.SetMethod(r.Method) c.Request().SetRequestURI(r.RequestURI) c.Request().SetHost(r.Host) c.Request().Header.SetHost(r.Host) // Remove all cookies before setting, see https://github.com/valyala/fasthttp/pull/1864 c.Request().Header.DelAllCookies() for key, val := range r.Header { for _, v := range val { c.Request().Header.Set(key, v) } } CopyContextToFiberContext(r.Context(), c.RequestCtx()) }) if err := HTTPHandler(mw(nextHandler))(c); err != nil { return err } if next { return c.Next() } return nil } } // FiberHandler wraps fiber handler to net/http handler func FiberHandler(h fiber.Handler) http.Handler { return FiberHandlerFunc(h) } // FiberHandlerFunc wraps fiber handler to net/http handler func func FiberHandlerFunc(h fiber.Handler) http.HandlerFunc { return handlerFunc(fiber.New(), h) } // FiberApp wraps fiber app to net/http handler func func FiberApp(app *fiber.App) http.HandlerFunc { return handlerFunc(app) } func isUnixNetwork(network string) bool { return network == "unix" || network == "unixgram" || network == "unixpacket" } func resolveRemoteAddr(remoteAddr string, localAddr any) (net.Addr, error) { if addr, ok := localAddr.(net.Addr); ok && isUnixNetwork(addr.Network()) { return addr, nil } // Validate input to prevent malformed addresses if remoteAddr == "" { return nil, ErrRemoteAddrEmpty } resolved, err := net.ResolveTCPAddr("tcp", remoteAddr) if err == nil { return resolved, nil } var addrErr *net.AddrError if errors.As(err, &addrErr) && addrErr.Err == "missing port in address" { if len(remoteAddr) > 253 { // Max hostname length return nil, ErrRemoteAddrTooLong } remoteAddr = net.JoinHostPort(remoteAddr, "80") resolved, err2 := net.ResolveTCPAddr("tcp", remoteAddr) if err2 != nil { return nil, fmt.Errorf("failed to resolve TCP address after adding port: %w", err2) } return resolved, nil } return nil, fmt.Errorf("failed to resolve TCP address: %w", err) } 
func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { req := fasthttp.AcquireRequest() defer fasthttp.ReleaseRequest(req) // Convert net/http -> fasthttp request with size limit maxBodySize := int64(app.Config().BodyLimit) if r.Body != nil { if r.ContentLength > maxBodySize { http.Error(w, utils.StatusMessage(fiber.StatusRequestEntityTooLarge), fiber.StatusRequestEntityTooLarge) return } limitedReader := io.LimitReader(r.Body, maxBodySize) n, err := io.Copy(req.BodyWriter(), limitedReader) req.Header.SetContentLength(int(n)) if err != nil { http.Error(w, utils.StatusMessage(fiber.StatusInternalServerError), fiber.StatusInternalServerError) return } } req.Header.SetMethod(r.Method) req.SetRequestURI(r.RequestURI) req.SetHost(r.Host) req.Header.SetHost(r.Host) for key, val := range r.Header { for _, v := range val { req.Header.Set(key, v) } } remoteAddr, err := resolveRemoteAddr(r.RemoteAddr, r.Context().Value(http.LocalAddrContextKey)) if err != nil { remoteAddr = nil // Fallback to nil } // New fasthttp Ctx from pool fctx := ctxPool.Get().(*fasthttp.RequestCtx) //nolint:forcetypeassert,errcheck // not needed fctx.Response.Reset() fctx.Request.Reset() defer ctxPool.Put(fctx) fctx.Init(req, remoteAddr, &disableLogger{}) if len(h) > 0 { // New fiber Ctx ctx := app.AcquireCtx(fctx) defer app.ReleaseCtx(ctx) // Execute fiber Ctx err := h[0](ctx) if err != nil { _ = app.Config().ErrorHandler(ctx, err) //nolint:errcheck // not needed } } else { // Execute fasthttp Ctx though app.Handler app.Handler()(fctx) } // Convert fasthttp Ctx -> net/http for k, v := range fctx.Response.Header.All() { w.Header().Add(string(k), string(v)) } w.WriteHeader(fctx.Response.StatusCode()) // Check if streaming is not possible or unnecessary. 
bodyStream := fctx.Response.BodyStream() flusher, ok := w.(http.Flusher) if !ok || bodyStream == nil { _, _ = w.Write(fctx.Response.Body()) //nolint:errcheck // not needed return } // Stream fctx.Response.BodyStream() -> w // in chunks. bufPtr, ok := bufferPool.Get().(*[bufferSize]byte) if !ok { panic(fmt.Errorf("failed to type-assert to *[%d]byte", bufferSize)) } defer bufferPool.Put(bufPtr) buf := bufPtr[:] for { n, err := bodyStream.Read(buf) if n > 0 { if _, writeErr := w.Write(buf[:n]); writeErr != nil { break } flusher.Flush() } if err != nil { break } } } } ================================================ FILE: middleware/adaptor/adaptor_test.go ================================================ //nolint:contextcheck,revive // Much easier to just ignore memory leaks in tests package adaptor import ( "bufio" "bytes" "context" "errors" "fmt" "io" "net" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "strings" "testing" "time" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) const ( expectedRequestURI = "/foo/bar?baz=123" expectedBody = "body 123 foo bar baz" expectedHost = "foobar.com" expectedRemoteAddr = "1.2.3.4:6789" ) func Test_HTTPHandler(t *testing.T) { t.Parallel() expectedMethod := fiber.MethodPost expectedProto := "HTTP/1.1" expectedProtoMajor := 1 expectedProtoMinor := 1 expectedContentLength := len(expectedBody) expectedHeader := map[string]string{ "Foo-Bar": "baz", "Abc": "defg", "XXX-Remote-Addr": "123.43.4543.345", } expectedURL, err := url.ParseRequestURI(expectedRequestURI) require.NoError(t, err) type contextKeyType string expectedContextKey := contextKeyType("contextKey") expectedContextValue := "contextValue" callsCount := 0 nethttpH := func(w http.ResponseWriter, r *http.Request) { callsCount++ assert.Equal(t, expectedMethod, r.Method, "Method") assert.Equal(t, expectedProto, r.Proto, "Proto") assert.Equal(t, expectedProtoMajor, 
r.ProtoMajor, "ProtoMajor") assert.Equal(t, expectedProtoMinor, r.ProtoMinor, "ProtoMinor") assert.Equal(t, expectedRequestURI, r.RequestURI, "RequestURI") assert.Equal(t, expectedContentLength, int(r.ContentLength), "ContentLength") assert.Empty(t, r.TransferEncoding, "TransferEncoding") assert.Equal(t, expectedHost, r.Host, "Host") assert.Equal(t, expectedRemoteAddr, r.RemoteAddr, "RemoteAddr") body, readErr := io.ReadAll(r.Body) assert.NoError(t, readErr) assert.Equal(t, expectedBody, string(body), "Body") assert.Equal(t, expectedURL, r.URL, "URL") assert.Equal(t, expectedContextValue, r.Context().Value(expectedContextKey), "Context") for k, expectedV := range expectedHeader { v := r.Header.Get(k) assert.Equal(t, expectedV, v, "Header") } w.Header().Set("Header1", "value1") w.Header().Set("Header2", "value2") w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, "request body is %q", body) } fiberH := HTTPHandlerFunc(http.HandlerFunc(nethttpH)) fiberH = setFiberContextValueMiddleware(fiberH, expectedContextKey, expectedContextValue) var fctx fasthttp.RequestCtx var req fasthttp.Request req.Header.SetMethod(expectedMethod) req.SetRequestURI(expectedRequestURI) req.Header.SetHost(expectedHost) req.BodyWriter().Write([]byte(expectedBody)) //nolint:errcheck // not needed for k, v := range expectedHeader { req.Header.Set(k, v) } remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr) require.NoError(t, err) fctx.Init(&req, remoteAddr, &disableLogger{}) app := fiber.New() ctx := app.AcquireCtx(&fctx) defer app.ReleaseCtx(ctx) err = fiberH(ctx) require.NoError(t, err) require.Equal(t, 1, callsCount, "callsCount") resp := &fctx.Response require.Equal(t, http.StatusBadRequest, resp.StatusCode(), "StatusCode") require.Equal(t, "value1", string(resp.Header.Peek("Header1")), "Header1") require.Equal(t, "value2", string(resp.Header.Peek("Header2")), "Header2") expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody) require.Equal(t, 
expectedResponseBody, string(resp.Body()), "Body") } func Test_HTTPHandler_Flush(t *testing.T) { t.Parallel() expectedMethod := fiber.MethodPost expectedProto := "HTTP/1.1" expectedProtoMajor := 1 expectedProtoMinor := 1 expectedContentLength := len(expectedBody) expectedHeader := map[string]string{ "Foo-Bar": "baz", "Abc": "defg", "XXX-Remote-Addr": "123.43.4543.345", } expectedURL, err := url.ParseRequestURI(expectedRequestURI) require.NoError(t, err) type contextKeyType string expectedContextKey := contextKeyType("contextKey") expectedContextValue := "contextValue" callsCount := 0 nethttpH := func(w http.ResponseWriter, r *http.Request) { callsCount++ assert.Equal(t, expectedMethod, r.Method, "Method") assert.Equal(t, expectedProto, r.Proto, "Proto") assert.Equal(t, expectedProtoMajor, r.ProtoMajor, "ProtoMajor") assert.Equal(t, expectedProtoMinor, r.ProtoMinor, "ProtoMinor") assert.Equal(t, expectedRequestURI, r.RequestURI, "RequestURI") assert.Equal(t, expectedContentLength, int(r.ContentLength), "ContentLength") assert.Empty(t, r.TransferEncoding, "TransferEncoding") assert.Equal(t, expectedHost, r.Host, "Host") assert.Equal(t, expectedRemoteAddr, r.RemoteAddr, "RemoteAddr") body, readErr := io.ReadAll(r.Body) assert.NoError(t, readErr) assert.Equal(t, expectedBody, string(body), "Body") assert.Equal(t, expectedURL, r.URL, "URL") assert.Equal(t, expectedContextValue, r.Context().Value(expectedContextKey), "Context") for k, expectedV := range expectedHeader { v := r.Header.Get(k) assert.Equal(t, expectedV, v, "Header") } w.Header().Set("Header1", "value1") w.Header().Set("Header2", "value2") w.WriteHeader(http.StatusBadRequest) fmt.Fprintf(w, "request body is ") flusher, ok := w.(http.Flusher) if !ok { t.Fatal("w does not implement http.Flusher") } flusher.Flush() fmt.Fprintf(w, "%q", body) } fiberH := HTTPHandlerFunc(http.HandlerFunc(nethttpH)) fiberH = setFiberContextValueMiddleware(fiberH, expectedContextKey, expectedContextValue) var fctx 
fasthttp.RequestCtx var req fasthttp.Request req.Header.SetMethod(expectedMethod) req.SetRequestURI(expectedRequestURI) req.Header.SetHost(expectedHost) req.BodyWriter().Write([]byte(expectedBody)) //nolint:errcheck // not needed for k, v := range expectedHeader { req.Header.Set(k, v) } remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr) require.NoError(t, err) fctx.Init(&req, remoteAddr, &disableLogger{}) app := fiber.New() ctx := app.AcquireCtx(&fctx) defer app.ReleaseCtx(ctx) err = fiberH(ctx) require.NoError(t, err) require.Equal(t, 1, callsCount, "callsCount") resp := &fctx.Response require.Equal(t, http.StatusBadRequest, resp.StatusCode(), "StatusCode") require.Equal(t, "value1", string(resp.Header.Peek("Header1")), "Header1") require.Equal(t, "value2", string(resp.Header.Peek("Header2")), "Header2") expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody) require.Equal(t, expectedResponseBody, string(resp.Body()), "Body") } func Test_HTTPHandler_Flush_App_Test(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/", HTTPHandlerFunc(func(w http.ResponseWriter, r *http.Request) { flusher, ok := w.(http.Flusher) if !ok { t.Fatal("w does not implement http.Flusher") } w.WriteHeader(fiber.StatusOK) fmt.Fprintf(w, "Hello ") flusher.Flush() time.Sleep(500 * time.Millisecond) fmt.Fprintf(w, "World!") })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // not needed require.Equal(t, fiber.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "Hello World!", string(body)) } func Test_HTTPHandler_App_Test_Interrupted(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/", HTTPHandlerFunc(func(w http.ResponseWriter, r *http.Request) { flusher, ok := w.(http.Flusher) if !ok { t.Fatalf("w does not implement http.Flusher") } w.WriteHeader(fiber.StatusOK) fmt.Fprintf(w, "Hello ") flusher.Flush() 
time.Sleep(500 * time.Millisecond) fmt.Fprintf(w, "World!") })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{ Timeout: 200 * time.Millisecond, FailOnTimeout: true, // Changed to true to test interrupted behavior }) // With FailOnTimeout: true, we should get a timeout error require.ErrorIs(t, err, os.ErrDeadlineExceeded) require.Nil(t, resp) } func Test_LocalContextFromHTTPRequest(t *testing.T) { t.Parallel() t.Run("nil request", func(t *testing.T) { t.Parallel() ctx, ok := LocalContextFromHTTPRequest(nil) require.False(t, ok) require.Nil(t, ctx) }) t.Run("request without stored context key", func(t *testing.T) { t.Parallel() req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) ctx, ok := LocalContextFromHTTPRequest(req) require.False(t, ok) require.Nil(t, ctx) }) t.Run("request with stored context key", func(t *testing.T) { t.Parallel() expectedCtx := context.WithValue(context.Background(), contextKey("k"), "v") req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody).WithContext( context.WithValue(context.Background(), localContextKey, expectedCtx), ) ctx, ok := LocalContextFromHTTPRequest(req) require.True(t, ok) require.Equal(t, expectedCtx, ctx) }) } func Test_HTTPHandlerWithContext_local_context(t *testing.T) { t.Parallel() app := fiber.New() // unique type for avoiding collisions in context type key struct{} var testKey key const testVal string = "test-value" // a middleware to add a value to the local context app.Use(func(c fiber.Ctx) error { ctx := context.WithValue(c.Context(), testKey, testVal) c.SetContext(ctx) return c.Next() }) // a handler that checks if the value has been appended to the local context app.Get("/", HTTPHandlerWithContext(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, ok := LocalContextFromHTTPRequest(r) if !ok { http.Error(w, "local context not found", http.StatusInternalServerError) return } val, ok := ctx.Value(testKey).(string) if !ok { 
http.Error(w, "invalid context value", http.StatusInternalServerError) return } w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(http.StatusOK) if _, err := w.Write([]byte(val)); err != nil { t.Logf("write failed: %v", err) } }))) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{ Timeout: 200 * time.Millisecond, FailOnTimeout: false, }) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // no need body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, testVal, string(body)) require.Equal(t, fiber.StatusOK, resp.StatusCode) } type contextKey string func (c contextKey) String() string { return "test-" + string(c) } var ( TestContextKey = contextKey("TestContextKey") TestContextSecondKey = contextKey("TestContextSecondKey") ) func Test_HTTPMiddleware(t *testing.T) { t.Parallel() tests := []struct { name string url string method string statusCode int }{ { name: "Should return 200", url: "/", method: "POST", statusCode: 200, }, { name: "Should return 405", url: "/", method: "GET", statusCode: 405, }, { name: "Should return 400", url: "/unknown", method: "POST", statusCode: 404, }, } nethttpMW := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { w.WriteHeader(http.StatusMethodNotAllowed) return } r = r.WithContext(context.WithValue(r.Context(), TestContextKey, "okay")) r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "not_okay")) r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "okay")) next.ServeHTTP(w, r) }) } app := fiber.New() app.Use(HTTPMiddleware(nethttpMW)) app.Post("/", func(c fiber.Ctx) error { value := c.RequestCtx().Value(TestContextKey) val, ok := value.(string) if !ok { t.Error("unexpected error on type-assertion") } if value != nil { c.Set("context_okay", val) } value = c.RequestCtx().Value(TestContextSecondKey) if 
value != nil { val, ok := value.(string) if !ok { t.Error("unexpected error on type-assertion") } c.Set("context_second_okay", val) } return c.SendStatus(fiber.StatusOK) }) for _, tt := range tests { req, err := http.NewRequestWithContext(context.Background(), tt.method, tt.url, http.NoBody) req.Host = expectedHost require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, tt.statusCode, resp.StatusCode, "StatusCode") } req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", http.NoBody) req.Host = expectedHost require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, "okay", resp.Header.Get("context_okay")) require.Equal(t, "okay", resp.Header.Get("context_second_okay")) } func Test_HTTPMiddlewareWithCookies(t *testing.T) { t.Parallel() const ( cookieHeader = "Cookie" setCookieHeader = "Set-Cookie" cookieOneName = "cookieOne" cookieTwoName = "cookieTwo" cookieOneValue = "valueCookieOne" cookieTwoValue = "valueCookieTwo" ) nethttpMW := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { w.WriteHeader(http.StatusMethodNotAllowed) return } next.ServeHTTP(w, r) }) } app := fiber.New() app.Use(HTTPMiddleware(nethttpMW)) app.Post("/", func(c fiber.Ctx) error { // RETURNING CURRENT COOKIES TO RESPONSE cookies := strings.Split(c.Get(cookieHeader), "; ") for _, cookie := range cookies { c.Set(setCookieHeader, cookie) } return c.SendStatus(fiber.StatusOK) }) // Test case for POST request with cookies t.Run("POST request with cookies", func(t *testing.T) { t.Parallel() req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", http.NoBody) require.NoError(t, err) req.AddCookie(&http.Cookie{Name: cookieOneName, Value: cookieOneValue}) req.AddCookie(&http.Cookie{Name: cookieTwoName, Value: cookieTwoValue}) resp, err := app.Test(req) require.NoError(t, err) cookies := 
resp.Cookies() require.Len(t, cookies, 2) for _, cookie := range cookies { switch cookie.Name { case cookieOneName: require.Equal(t, cookieOneValue, cookie.Value) case cookieTwoName: require.Equal(t, cookieTwoValue, cookie.Value) default: t.Error("unexpected cookie key") } } }) // New test case for GET request t.Run("GET request", func(t *testing.T) { t.Parallel() req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) }) // New test case for request without cookies t.Run("POST request without cookies", func(t *testing.T) { t.Parallel() req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) require.Empty(t, resp.Cookies()) }) } func Test_FiberHandler(t *testing.T) { t.Parallel() testFiberToHandlerFunc(t, false) } func Test_FiberHandler_BodyLimit(t *testing.T) { t.Parallel() tests := []struct { name string bodyLimit int bodySize int expectedStatus int }{ { name: "DefaultLimitExceededReturns413", bodySize: fiber.DefaultBodyLimit + 1024, expectedStatus: fiber.StatusRequestEntityTooLarge, }, { name: "CustomLimitExceededReturns413", bodyLimit: 1 * 1024 * 1024, bodySize: (1 * 1024 * 1024) + 1, expectedStatus: fiber.StatusRequestEntityTooLarge, }, { name: "CustomLimitAllowsLargerPayload", bodyLimit: 2 * fiber.DefaultBodyLimit, bodySize: fiber.DefaultBodyLimit + 512, expectedStatus: fiber.StatusOK, }, { name: "ZeroLimitConfigFallsBackToDefault", bodyLimit: 0, bodySize: fiber.DefaultBodyLimit - 256, expectedStatus: fiber.StatusOK, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() app := fiber.New(fiber.Config{ BodyLimit: tt.bodyLimit, }) app.Post("/", func(c fiber.Ctx) error { return 
c.SendStatus(fiber.StatusOK) }) handlerFunc := FiberApp(app) body := bytes.Repeat([]byte("a"), tt.bodySize) req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(body)) req.ContentLength = int64(len(body)) resp := httptest.NewRecorder() handlerFunc.ServeHTTP(resp, req) require.Equal(t, tt.expectedStatus, resp.Code) }) } } func Test_FiberApp(t *testing.T) { t.Parallel() testFiberToHandlerFunc(t, false, fiber.New()) } func Test_FiberHandlerDefaultPort(t *testing.T) { t.Parallel() testFiberToHandlerFunc(t, true) } func Test_FiberAppDefaultPort(t *testing.T) { t.Parallel() testFiberToHandlerFunc(t, true, fiber.New()) } func testFiberToHandlerFunc(t *testing.T, checkDefaultPort bool, app ...*fiber.App) { t.Helper() expectedMethod := fiber.MethodPost expectedContentLength := len(expectedBody) expectedRemoteAddr := "1.2.3.4:6789" if checkDefaultPort { expectedRemoteAddr = "1.2.3.4:80" } expectedHeader := map[string]string{ "Foo-Bar": "baz", "Abc": "defg", "XXX-Remote-Addr": "123.43.4543.345", } expectedURL, err := url.ParseRequestURI(expectedRequestURI) require.NoError(t, err) callsCount := 0 fiberH := func(c fiber.Ctx) error { callsCount++ require.Equal(t, expectedMethod, c.Method(), "Method") require.Equal(t, expectedRequestURI, string(c.RequestCtx().RequestURI()), "RequestURI") require.Equal(t, expectedContentLength, c.RequestCtx().Request.Header.ContentLength(), "ContentLength") require.Equal(t, expectedHost, c.Hostname(), "Host") require.Equal(t, expectedHost, string(c.Request().Header.Host()), "Host") require.Equal(t, "http://"+expectedHost, c.BaseURL(), "BaseURL") require.Equal(t, expectedRemoteAddr, c.RequestCtx().RemoteAddr().String(), "RemoteAddr") body := string(c.Body()) require.Equal(t, expectedBody, body, "Body") require.Equal(t, expectedURL.String(), c.OriginalURL(), "URL") for k, expectedV := range expectedHeader { v := c.Get(k) require.Equal(t, expectedV, v, "Header") } c.Set("Header1", "value1") c.Set("Header2", "value2") 
c.Status(fiber.StatusBadRequest) _, err := c.Write(fmt.Appendf(nil, "request body is %q", body)) return err } var handlerFunc http.HandlerFunc if len(app) > 0 { app[0].Post("/foo/bar", fiberH) handlerFunc = FiberApp(app[0]) } else { handlerFunc = FiberHandlerFunc(fiberH) } var r http.Request r.Method = expectedMethod r.Body = &netHTTPBody{b: []byte(expectedBody)} r.RequestURI = expectedRequestURI r.ContentLength = int64(expectedContentLength) r.Host = expectedHost r.RemoteAddr = expectedRemoteAddr if checkDefaultPort { r.RemoteAddr = "1.2.3.4" } hdr := make(http.Header) for k, v := range expectedHeader { hdr.Set(k, v) } r.Header = hdr var w netHTTPResponseWriter handlerFunc.ServeHTTP(&w, &r) require.Equal(t, http.StatusBadRequest, w.StatusCode(), "StatusCode") require.Equal(t, "value1", w.Header().Get("Header1"), "Header1") require.Equal(t, "value2", w.Header().Get("Header2"), "Header2") expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody) require.Equal(t, expectedResponseBody, string(w.body), "Body") } func setFiberContextValueMiddleware(next fiber.Handler, key, value any) fiber.Handler { return func(c fiber.Ctx) error { c.Locals(key, value) return next(c) } } func Test_FiberHandler_RequestNilBody(t *testing.T) { t.Parallel() expectedMethod := fiber.MethodGet expectedRequestURI := "/foo/bar" expectedContentLength := 0 callsCount := 0 fiberH := func(c fiber.Ctx) error { callsCount++ require.Equal(t, expectedMethod, c.Method(), "Method") require.Equal(t, expectedRequestURI, string(c.RequestCtx().RequestURI()), "RequestURI") require.Equal(t, expectedContentLength, c.RequestCtx().Request.Header.ContentLength(), "ContentLength") _, err := c.WriteString("request body is nil") return err } nethttpH := FiberHandler(fiberH) var r http.Request r.Method = expectedMethod r.RequestURI = expectedRequestURI var w netHTTPResponseWriter nethttpH.ServeHTTP(&w, &r) expectedResponseBody := "request body is nil" require.Equal(t, expectedResponseBody, string(w.body), 
"Body") } type netHTTPBody struct { b []byte } func (r *netHTTPBody) Read(p []byte) (int, error) { if len(r.b) == 0 { return 0, io.EOF } n := copy(p, r.b) r.b = r.b[n:] return n, nil } func (r *netHTTPBody) Close() error { r.b = r.b[:0] return nil } func createTestRequest(method, uri, remoteAddr string, body io.Reader) *http.Request { r := &http.Request{ Method: method, RequestURI: uri, RemoteAddr: remoteAddr, Header: make(http.Header), Body: http.NoBody, } if body != nil { if rc, ok := body.(io.ReadCloser); ok { r.Body = rc } else { r.Body = io.NopCloser(body) } } return r } func executeHandlerTest(_ *testing.T, handler http.HandlerFunc, req *http.Request) *netHTTPResponseWriter { w := &netHTTPResponseWriter{} handler.ServeHTTP(w, req) return w } type netHTTPResponseWriter struct { h http.Header body []byte statusCode int } func (w *netHTTPResponseWriter) StatusCode() int { if w.statusCode == 0 { return http.StatusOK } return w.statusCode } func (w *netHTTPResponseWriter) Header() http.Header { if w.h == nil { w.h = make(http.Header) } return w.h } func (w *netHTTPResponseWriter) WriteHeader(statusCode int) { w.statusCode = statusCode } func (w *netHTTPResponseWriter) Write(p []byte) (int, error) { w.body = append(w.body, p...) 
return len(p), nil } func (w *netHTTPResponseWriter) Flush() {} func Test_ConvertRequest(t *testing.T) { t.Parallel() t.Run("successful conversion", func(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { httpReq, err := ConvertRequest(c, false) if err != nil { return err } return c.SendString("Request URL: " + httpReq.URL.String()) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test?hello=world&another=test", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, http.StatusOK, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "Request URL: /test?hello=world&another=test", string(body)) }) t.Run("conversion error handling", func(t *testing.T) { t.Parallel() // Test error case by creating a context with an invalid URL that will cause fasthttpadaptor.ConvertRequest to fail app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // Create a malformed request URI that should cause conversion to fail ctx.Request().SetRequestURI("http://[::1:bad:url") // Invalid URL format ctx.Request().Header.SetMethod(fiber.MethodGet) _, err := ConvertRequest(ctx, true) // Use forServer=true which does more validation if err == nil { // If the above doesn't fail, try a different approach ctx.Request().SetRequestURI("\x00\x01\x02") // Invalid characters in URI _, err = ConvertRequest(ctx, true) } // Note: This test may pass if fasthttpadaptor is very permissive // The important thing is that our function doesn't panic if err != nil { require.Error(t, err, "Expected error from fasthttpadaptor.ConvertRequest") } }) } func Test_CopyContextToFiberContext(t *testing.T) { t.Parallel() t.Run("unsupported context type", func(t *testing.T) { t.Parallel() // Test with non-struct context (should return early) var fctx fasthttp.RequestCtx stringContext := "not a struct" // This should not panic and should handle the non-struct 
gracefully CopyContextToFiberContext(&stringContext, &fctx) // No assertions needed - just ensuring it doesn't panic }) t.Run("context with unknown field", func(t *testing.T) { t.Parallel() // Test the default case (continue statement coverage) type customContext struct { UnknownField string } var fctx fasthttp.RequestCtx ctx := customContext{UnknownField: "test"} // This should hit the default case and continue CopyContextToFiberContext(&ctx, &fctx) // No assertions needed - just ensuring it doesn't panic and continues }) t.Run("invalid src", func(t *testing.T) { t.Parallel() var fctx fasthttp.RequestCtx CopyContextToFiberContext(nil, &fctx) // Add assertion to ensure no panic and coverage is detected assert.NotNil(t, &fctx) }) t.Run("nil request context", func(t *testing.T) { t.Parallel() ctx := context.WithValue(context.Background(), contextKey("nil-request-context"), "value") require.NotPanics(t, func() { CopyContextToFiberContext(ctx, nil) }) }) t.Run("nil pointer", func(t *testing.T) { t.Parallel() var nilPtr *context.Context // Nil pointer to a context var fctx fasthttp.RequestCtx CopyContextToFiberContext(nilPtr, &fctx) // Add assertion to ensure no panic and coverage is detected assert.NotNil(t, &fctx) }) t.Run("copies key value pairs", func(t *testing.T) { t.Parallel() var fctx fasthttp.RequestCtx key := contextKey("copy-key") expectedValue := "copy-value" ctx := context.WithValue(context.Background(), key, expectedValue) CopyContextToFiberContext(ctx, &fctx) require.Equal(t, expectedValue, fctx.UserValue(key)) }) t.Run("nested context wrappers", func(t *testing.T) { t.Parallel() var fctx fasthttp.RequestCtx keyA := contextKey("nested-a") keyB := contextKey("nested-b") baseCtx := context.WithValue(context.Background(), keyA, "value-a") cancelCtx, cancel := context.WithCancel(baseCtx) t.Cleanup(cancel) wrappedCtx := context.WithValue(cancelCtx, keyB, "value-b") CopyContextToFiberContext(wrappedCtx, &fctx) require.Equal(t, "value-a", fctx.UserValue(keyA)) 
require.Equal(t, "value-b", fctx.UserValue(keyB)) }) t.Run("multi-level pointer", func(t *testing.T) { t.Parallel() var fctx fasthttp.RequestCtx ctx := context.Background() ptr := &ctx doublePtr := &ptr // Test deref pointer chains CopyContextToFiberContext(doublePtr, &fctx) // No assertions needed - just ensuring it doesn't panic }) t.Run("non-addressable struct", func(t *testing.T) { t.Parallel() var fctx fasthttp.RequestCtx type testStruct struct { Field string } // Pass struct value directly to test addressability check CopyContextToFiberContext(testStruct{Field: "test"}, &fctx) // No assertions needed - just ensuring it doesn't panic and creates temporary }) } func Test_HTTPMiddleware_ErrorHandling(t *testing.T) { t.Parallel() // Test middleware that returns an error from HTTPHandler errorMiddleware := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // This will cause an error in the underlying handler w.WriteHeader(http.StatusInternalServerError) next.ServeHTTP(w, r) }) } fiberHandler := func(c fiber.Ctx) error { return fiber.NewError(fiber.StatusBadRequest, "test error") } app := fiber.New() app.Use(HTTPMiddleware(errorMiddleware)) app.Get("/error", fiberHandler) resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/error", http.NoBody)) require.NoError(t, err) // The error should be handled by the error handler require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) } func Test_FiberHandler_IOError(t *testing.T) { t.Parallel() // Test io.Copy error by using a failing reader fiberH := func(c fiber.Ctx) error { return c.SendString("should not reach here") } handlerFunc := FiberHandlerFunc(fiberH) // Create a reader that fails failingReader := &failingReader{} r := &http.Request{ Method: http.MethodPost, RequestURI: "/test", Body: failingReader, ContentLength: 100, // Set content length so it tries to read Header: make(http.Header), } w := &netHTTPResponseWriter{} handlerFunc.ServeHTTP(w, r) 
// Should return 500 due to io.Copy error require.Equal(t, http.StatusInternalServerError, w.StatusCode()) } func Test_FiberHandler_WithErrorInHandler(t *testing.T) { t.Parallel() // Test error handling in fiber handler fiberH := func(c fiber.Ctx) error { return fiber.NewError(fiber.StatusTeapot, "I'm a teapot") } handlerFunc := FiberHandlerFunc(fiberH) r := &http.Request{ Method: http.MethodGet, RequestURI: "/test", Header: make(http.Header), Body: http.NoBody, } w := &netHTTPResponseWriter{} handlerFunc.ServeHTTP(w, r) // Should return the error status require.Equal(t, fiber.StatusTeapot, w.StatusCode()) } func Test_FiberHandler_WithSendStreamWriter(t *testing.T) { t.Parallel() // Test streaming functionality in FiberHandler using SendStreamWriter. fiberH := func(c fiber.Ctx) error { c.Status(fiber.StatusTeapot) return c.SendStreamWriter(func(w *bufio.Writer) { w.WriteString("Hello ") //nolint:errcheck // not needed w.Flush() //nolint:errcheck // not needed time.Sleep(200 * time.Millisecond) // Simulate a long operation w.WriteString("World!") //nolint:errcheck // not needed }) } handlerFunc := FiberHandlerFunc(fiberH) r := &http.Request{ Method: http.MethodGet, RequestURI: "/test", Header: make(http.Header), Body: http.NoBody, } w := &netHTTPResponseWriter{} handlerFunc.ServeHTTP(w, r) // Should return the error status require.Equal(t, fiber.StatusTeapot, w.StatusCode()) require.Equal(t, "Hello World!", string(w.body)) } func Test_FiberHandler_WithInterruptedSendStreamWriter(t *testing.T) { t.Parallel() // Test streaming functionality to ensure data is sent even during a timeout. 
fiberH := func(c fiber.Ctx) error { c.Status(fiber.StatusTeapot) return c.SendStreamWriter(func(w *bufio.Writer) { w.WriteString("Hello ") //nolint:errcheck // not needed w.Flush() //nolint:errcheck // not needed time.Sleep(500 * time.Millisecond) // Simulate a long operation w.WriteString("World!") //nolint:errcheck // not needed }) } handlerFunc := FiberHandlerFunc(fiberH) // Start a mock HTTP server using the handlerFunc server := &http.Server{ Handler: handlerFunc, ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, } listener, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) addr := fmt.Sprintf("http://%s", listener.Addr()) go func() { server.Serve(listener) //nolint:errcheck // not needed }() defer func() { require.NoError(t, server.Close()) }() cc := &http.Client{ Timeout: 200 * time.Millisecond, } resp, err := cc.Get(addr) //nolint:noctx // ctx is not needed require.NoError(t, err) require.NotNil(t, resp) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) body, readErr := io.ReadAll(resp.Body) require.ErrorIs(t, readErr, context.DeadlineExceeded) require.Equal(t, "Hello ", string(body)) } // failingReader always returns an error when Read is called type failingReader struct{} func (f *failingReader) Read(p []byte) (int, error) { return 0, errors.New("simulated read error") } func (f *failingReader) Close() error { return nil } // Benchmark for FiberHandlerFunc func Benchmark_FiberHandlerFunc(b *testing.B) { benchmarks := []struct { name string bodyContent []byte }{ { name: "No Content", bodyContent: nil, // No body content case }, { name: "100KB", bodyContent: make([]byte, 100*1024), }, { name: "500KB", bodyContent: make([]byte, 500*1024), }, { name: "1MB", bodyContent: make([]byte, 1*1024*1024), }, { name: "5MB", bodyContent: make([]byte, 5*1024*1024), }, { name: "10MB", bodyContent: make([]byte, 10*1024*1024), }, { name: "25MB", bodyContent: make([]byte, 25*1024*1024), }, { name: "50MB", bodyContent: make([]byte, 
50*1024*1024), }, } fiberH := func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) } handlerFunc := FiberHandlerFunc(fiberH) for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { w := httptest.NewRecorder() var bodyBuffer *bytes.Buffer // Handle the "No Content" case where bodyContent is nil if bm.bodyContent != nil { bodyBuffer = bytes.NewBuffer(bm.bodyContent) } else { bodyBuffer = bytes.NewBuffer([]byte{}) // Empty buffer for no content } r := http.Request{ Method: http.MethodPost, Body: nil, } // Replace the empty Body with our buffer r.Body = io.NopCloser(bodyBuffer) defer r.Body.Close() //nolint:errcheck // not needed b.ReportAllocs() for b.Loop() { handlerFunc.ServeHTTP(w, &r) } }) } } func Benchmark_FiberHandlerFunc_Parallel(b *testing.B) { benchmarks := []struct { name string bodyContent []byte }{ { name: "No Content", bodyContent: nil, // No body content case }, { name: "100KB", bodyContent: make([]byte, 100*1024), }, { name: "500KB", bodyContent: make([]byte, 500*1024), }, { name: "1MB", bodyContent: make([]byte, 1*1024*1024), }, { name: "5MB", bodyContent: make([]byte, 5*1024*1024), }, { name: "10MB", bodyContent: make([]byte, 10*1024*1024), }, { name: "25MB", bodyContent: make([]byte, 25*1024*1024), }, { name: "50MB", bodyContent: make([]byte, 50*1024*1024), }, } fiberH := func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) } handlerFunc := FiberHandlerFunc(fiberH) for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { var bodyBuffer *bytes.Buffer // Handle the "No Content" case where bodyContent is nil if bm.bodyContent != nil { bodyBuffer = bytes.NewBuffer(bm.bodyContent) } else { bodyBuffer = bytes.NewBuffer([]byte{}) // Empty buffer for no content } b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { w := httptest.NewRecorder() r := http.Request{ Method: http.MethodPost, Body: nil, } // Replace the empty Body with our buffer r.Body = io.NopCloser(bodyBuffer) defer r.Body.Close() 
//nolint:errcheck // not needed for pb.Next() { handlerFunc(w, &r) } }) }) } } func Benchmark_HTTPHandler(b *testing.B) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("ok")) //nolint:errcheck // not needed }) var err error app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer func() { app.ReleaseCtx(ctx) }() b.ReportAllocs() b.ResetTimer() fiberHandler := HTTPHandler(handler) for b.Loop() { ctx.Request().Reset() ctx.Response().Reset() ctx.Request().SetRequestURI("/test") ctx.Request().Header.SetMethod("GET") err = fiberHandler(ctx) } require.NoError(b, err) } func Benchmark_HTTPHandlerWithContext(b *testing.B) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("ok")) //nolint:errcheck // not needed }) var err error app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer func() { app.ReleaseCtx(ctx) }() b.ReportAllocs() b.ResetTimer() type key struct{} var testKey key ctx.SetContext(context.WithValue(ctx.Context(), testKey, "gofiber")) fiberHandler := HTTPHandlerWithContext(handler) for b.Loop() { ctx.Request().Reset() ctx.Response().Reset() ctx.Request().SetRequestURI("/test") ctx.Request().Header.SetMethod("GET") err = fiberHandler(ctx) } require.NoError(b, err) } func Test_resolveRemoteAddr(t *testing.T) { t.Parallel() tests := []struct { expectedErr error localAddr any name string remoteAddr string errorContains string expectError bool }{ { name: "valid TCP address with port", remoteAddr: "192.168.1.1:8080", localAddr: nil, expectError: false, }, { name: "valid TCP address without port - should add default port 80", remoteAddr: "192.168.1.1", localAddr: nil, expectError: false, }, { name: "unix socket - should return local addr", remoteAddr: "irrelevant", localAddr: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"}, expectError: false, }, { name: "invalid address - should fail", 
remoteAddr: "[invalid:address:format", localAddr: nil, expectError: true, errorContains: "failed to resolve TCP address:", }, { name: "invalid address after adding port - should fail", remoteAddr: "[invalid", localAddr: nil, expectError: true, errorContains: "failed to resolve TCP address after adding port:", }, { name: "empty address - should fail", remoteAddr: "", localAddr: nil, expectError: true, expectedErr: ErrRemoteAddrEmpty, }, { name: "too long address - should fail", remoteAddr: strings.Repeat("a", 254), localAddr: nil, expectError: true, expectedErr: ErrRemoteAddrTooLong, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() addr, err := resolveRemoteAddr(tt.remoteAddr, tt.localAddr) expectError := tt.expectedErr != nil || tt.errorContains != "" if expectError { require.Error(t, err) if tt.expectedErr != nil { require.ErrorIs(t, err, tt.expectedErr) } if tt.errorContains != "" { require.Contains(t, err.Error(), tt.errorContains) } require.Nil(t, addr) } else { require.NoError(t, err) require.NotNil(t, addr) } }) } } func Test_isUnixNetwork(t *testing.T) { t.Parallel() tests := []struct { name string network string expected bool }{ {"unix", "unix", true}, {"unixgram", "unixgram", true}, {"unixpacket", "unixpacket", true}, {"tcp", "tcp", false}, {"tcp4", "tcp4", false}, {"tcp6", "tcp6", false}, {"udp", "udp", false}, {"empty", "", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() result := isUnixNetwork(tt.network) require.Equal(t, tt.expected, result) }) } } func Test_FiberHandler_ErrorFallback(t *testing.T) { t.Parallel() // Test case where resolveRemoteAddr fails and falls back to nil fiberH := func(c fiber.Ctx) error { return c.SendString("success") } handlerFunc := FiberHandlerFunc(fiberH) // Use helper function for cleaner test setup req := createTestRequest(http.MethodGet, "/test", "[invalid:address:format", nil) w := executeHandlerTest(t, handlerFunc, req) // Should still work despite 
the invalid remote address require.Equal(t, http.StatusOK, w.StatusCode()) require.Equal(t, "success", string(w.body)) } func Test_FiberHandler_WithUnixSocket(t *testing.T) { t.Parallel() // Test case where request has unix socket context fiberH := func(c fiber.Ctx) error { return c.SendString("unix socket success") } handlerFunc := FiberHandlerFunc(fiberH) // Create a context with unix socket local address unixAddr := &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"} ctx := context.WithValue(context.Background(), http.LocalAddrContextKey, unixAddr) r := &http.Request{ Method: http.MethodGet, RequestURI: "/test", RemoteAddr: "someremoteaddr", // This will be ignored due to unix socket Header: make(http.Header), Body: http.NoBody, } r = r.WithContext(ctx) w := &netHTTPResponseWriter{} handlerFunc.ServeHTTP(w, r) require.Equal(t, http.StatusOK, w.StatusCode()) require.Equal(t, "unix socket success", string(w.body)) } func Test_FiberHandler_BodySizeLimit(t *testing.T) { t.Parallel() // Test body size limit enforcement fiberH := func(c fiber.Ctx) error { return c.SendString("processed") } handlerFunc := FiberHandlerFunc(fiberH) // Create a large body exceeding limit largeBody := make([]byte, 15*1024*1024) // 15MB > 10MB limit req := createTestRequest(http.MethodPost, "/test", "127.0.0.1:8080", bytes.NewReader(largeBody)) req.ContentLength = int64(len(largeBody)) w := executeHandlerTest(t, handlerFunc, req) // Should return 413 due to size limit require.Equal(t, http.StatusRequestEntityTooLarge, w.StatusCode()) } func Test_CopyContextToFiberContext_Safe(t *testing.T) { t.Parallel() t.Run("safe handling of unexported fields", func(t *testing.T) { t.Parallel() // Test that unexported fields are handled safely type testContext struct { exportedField string unexported string // unexported } var fctx fasthttp.RequestCtx ctx := testContext{exportedField: "exported", unexported: "unexported"} // Should not panic and handle safely CopyContextToFiberContext(&ctx, &fctx) // No 
specific assertion, just ensure no panic }) } func TestUnixSocketAdaptor(t *testing.T) { dir := t.TempDir() socketPath := filepath.Join(dir, "test.sock") defer func() { if err := os.Remove(socketPath); err != nil { t.Logf("cleanup failed: %v", err) } }() app := fiber.New() app.Get("/hello", func(c fiber.Ctx) error { return c.SendString("ok") }) handler := FiberApp(app) listener, err := net.Listen("unix", socketPath) if err != nil { // Skip on platforms where the "unix" network is unsupported if strings.Contains(err.Error(), "unknown network") || strings.Contains(err.Error(), "address family not supported") { t.Skipf("Unix domain sockets not supported on this platform: %v", err) } t.Fatal(err) } defer func() { if closeErr := listener.Close(); closeErr != nil { t.Logf("listener close failed: %v", closeErr) } }() // start server with timeouts srv := &http.Server{ Handler: handler, ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, } done := make(chan struct{}) go func() { if serveErr := srv.Serve(listener); serveErr != nil && serveErr != http.ErrServerClosed { t.Errorf("http server failed: %v", serveErr) } close(done) }() conn, err := net.Dial("unix", socketPath) require.NoError(t, err) defer func() { if closeErr := conn.Close(); closeErr != nil { t.Logf("conn close failed: %v", closeErr) } }() // set deadline for both write + read (2s) require.NoError(t, conn.SetDeadline(time.Now().Add(2*time.Second))) // write request _, err = conn.Write([]byte("GET /hello HTTP/1.1\r\nHost: localhost\r\n\r\n")) require.NoError(t, err) // read response buf := make([]byte, 1024) n, err := conn.Read(buf) require.NoError(t, err) // clear deadline to avoid affecting further calls require.NoError(t, conn.SetDeadline(time.Time{})) raw := string(buf[:n]) t.Logf("Raw response:\n%s", raw) require.Contains(t, raw, "HTTP/1.1 200 OK") require.Contains(t, raw, "ok") // now shutdown the server explicitly before waiting for done ctx, cancel := context.WithTimeout(context.Background(), 
2*time.Second) defer cancel() require.NoError(t, srv.Shutdown(ctx)) select { case <-done: case <-time.After(5 * time.Second): t.Fatal("server shutdown timed out") } } ================================================ FILE: middleware/basicauth/basicauth.go ================================================ package basicauth import ( "encoding/base64" "errors" "strings" "unicode" "unicode/utf8" "github.com/gofiber/fiber/v3" "github.com/gofiber/utils/v2" "golang.org/x/text/unicode/norm" ) // The contextKey type is unexported to prevent collisions with context keys defined in // other packages. type contextKey int // The key for the username value stored in the context const ( usernameKey contextKey = iota ) const basicScheme = "Basic" // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) var cerr base64.CorruptInputError // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Get authorization header and ensure it matches the Basic scheme rawAuth := c.Get(fiber.HeaderAuthorization) if rawAuth == "" { return cfg.Unauthorized(c) } if len(rawAuth) > cfg.HeaderLimit { return c.SendStatus(fiber.StatusRequestHeaderFieldsTooLarge) } if containsInvalidHeaderChars(rawAuth) { return cfg.BadRequest(c) } auth := utils.TrimSpace(rawAuth) if auth == "" { return cfg.Unauthorized(c) } if len(auth) < len(basicScheme) || !utils.EqualFold(auth[:len(basicScheme)], basicScheme) { return cfg.Unauthorized(c) } rest := auth[len(basicScheme):] if len(rest) < 2 || rest[0] != ' ' || rest[1] == ' ' { return cfg.BadRequest(c) } rest = rest[1:] if strings.IndexFunc(rest, unicode.IsSpace) != -1 { return cfg.BadRequest(c) } // Decode the header contents raw, err := base64.StdEncoding.DecodeString(rest) if err != nil { if errors.As(err, &cerr) { raw, err = base64.RawStdEncoding.DecodeString(rest) } if err != nil { 
return cfg.BadRequest(c) } } if !utf8.Valid(raw) { return cfg.BadRequest(c) } if !norm.NFC.IsNormal(raw) { raw = norm.NFC.Bytes(raw) } // Get the credentials var creds string if c.App().Config().Immutable { creds = string(raw) } else { creds = utils.UnsafeString(raw) } // Check if the credentials are in the correct form // which is "username:password". username, password, found := strings.Cut(creds, ":") if !found { return cfg.BadRequest(c) } if containsCTL(username) || containsCTL(password) { return cfg.BadRequest(c) } if cfg.Authorizer(username, password, c) { fiber.StoreInContext(c, usernameKey, username) return c.Next() } // Authentication failed return cfg.Unauthorized(c) } } func containsCTL(s string) bool { return strings.IndexFunc(s, unicode.IsControl) != -1 } func containsInvalidHeaderChars(s string) bool { return strings.IndexFunc(s, func(r rune) bool { return (r < 0x20 && r != '\t') || r == 0x7F || r >= 0x80 }) != -1 } // UsernameFromContext returns the username found in the context. // It accepts fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx, and context.Context. // It returns an empty string if the username does not exist. 
func UsernameFromContext(ctx any) string { if username, ok := fiber.ValueFromContext[string](ctx, usernameKey); ok { return username } return "" } ================================================ FILE: middleware/basicauth/basicauth_test.go ================================================ package basicauth import ( "crypto/sha256" "crypto/sha512" "encoding/base64" "encoding/hex" "fmt" "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" "golang.org/x/crypto/bcrypt" ) func sha256Hash(p string) string { sum := sha256.Sum256([]byte(p)) return "{SHA256}" + base64.StdEncoding.EncodeToString(sum[:]) } func sha512Hash(p string) string { sum := sha512.Sum512([]byte(p)) return "{SHA512}" + base64.StdEncoding.EncodeToString(sum[:]) } // go test -run Test_BasicAuth_Next func Test_BasicAuth_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Next: func(_ fiber.Ctx) bool { return true }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } func Test_Middleware_BasicAuth(t *testing.T) { t.Parallel() app := fiber.New() hashedJohn := sha256Hash("doe") hashedAdmin, err := bcrypt.GenerateFromPassword([]byte("123456"), bcrypt.MinCost) require.NoError(t, err) app.Use(New(Config{ Users: map[string]string{ "john": hashedJohn, "admin": string(hashedAdmin), }, })) app.Get("/testauth", func(c fiber.Ctx) error { username := UsernameFromContext(c) return c.SendString(username) }) tests := []struct { url string username string password string statusCode int }{ { url: "/testauth", statusCode: 200, username: "john", password: "doe", }, { url: "/testauth", statusCode: 200, username: "admin", password: "123456", }, { url: "/testauth", statusCode: 401, username: "ee", password: "123456", }, } for _, tt := range tests { // Base64 encode credentials for http auth header 
creds := base64.StdEncoding.EncodeToString(fmt.Appendf(nil, "%s:%s", tt.username, tt.password)) req := httptest.NewRequest(fiber.MethodGet, "/testauth", http.NoBody) req.Header.Add("Authorization", "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, tt.statusCode, resp.StatusCode) if tt.statusCode == 200 { require.Equal(t, tt.username, string(body)) } } } func Test_BasicAuth_UsernameFromContext_Types(t *testing.T) { t.Parallel() app := fiber.New(fiber.Config{PassLocalsToContext: true}) app.Use(New(Config{ Users: map[string]string{ "john": sha256Hash("doe"), }, })) app.Get("/", func(c fiber.Ctx) error { require.Equal(t, "john", UsernameFromContext(c)) customCtx, ok := c.(fiber.CustomCtx) require.True(t, ok) require.Equal(t, "john", UsernameFromContext(customCtx)) require.Equal(t, "john", UsernameFromContext(c.RequestCtx())) require.Equal(t, "john", UsernameFromContext(c.Context())) return c.SendStatus(fiber.StatusOK) }) creds := base64.StdEncoding.EncodeToString([]byte("john:doe")) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Add(fiber.HeaderAuthorization, "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } func Test_BasicAuth_AuthorizerCtx(t *testing.T) { t.Parallel() app := fiber.New() called := false app.Use(New(Config{ Authorizer: func(user, pass string, c fiber.Ctx) bool { called = true require.Equal(t, "john", user) require.Equal(t, "doe", pass) require.Equal(t, "/ctx", c.Path()) return true }, })) app.Get("/ctx", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) creds := base64.StdEncoding.EncodeToString([]byte("john:doe")) req := httptest.NewRequest(fiber.MethodGet, "/ctx", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.True(t, 
called) } func Test_BasicAuth_WWWAuthenticateHeader(t *testing.T) { t.Parallel() app := fiber.New() hashedJohn := sha256Hash("doe") app.Use(New(Config{Users: map[string]string{"john": hashedJohn}})) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) require.Equal(t, `Basic realm="Restricted", charset="UTF-8"`, resp.Header.Get(fiber.HeaderWWWAuthenticate)) } func Test_BasicAuth_WWWAuthenticateHeader_UTF8(t *testing.T) { t.Parallel() app := fiber.New() hashedJohn := sha256Hash("doe") app.Use(New(Config{Users: map[string]string{"john": hashedJohn}, Charset: "utf-8"})) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) require.Equal(t, `Basic realm="Restricted", charset="UTF-8"`, resp.Header.Get(fiber.HeaderWWWAuthenticate)) } func Test_BasicAuth_InvalidHeader(t *testing.T) { t.Parallel() app := fiber.New() hashedJohn := sha256Hash("doe") app.Use(New(Config{Users: map[string]string{"john": hashedJohn}})) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic notbase64") resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) } func Test_BasicAuth_MissingScheme(t *testing.T) { t.Parallel() app := fiber.New() hashedJohn := sha256Hash("doe") app.Use(New(Config{Users: map[string]string{"john": hashedJohn}})) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Bearer token") resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) require.Equal(t, `Basic realm="Restricted", charset="UTF-8"`, resp.Header.Get(fiber.HeaderWWWAuthenticate)) } func Test_BasicAuth_MissingColon(t *testing.T) { t.Parallel() app := fiber.New() hashedJohn := 
sha256Hash("doe") app.Use(New(Config{Users: map[string]string{"john": hashedJohn}})) creds := base64.StdEncoding.EncodeToString([]byte("john")) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) } func Test_BasicAuth_EmptyAuthorization(t *testing.T) { t.Parallel() app := fiber.New() hashedJohn := sha256Hash("doe") app.Use(New(Config{Users: map[string]string{"john": hashedJohn}})) cases := []string{"", " "} for _, h := range cases { req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, h) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) } } func Test_BasicAuth_HeaderWhitespace(t *testing.T) { t.Parallel() app := fiber.New() hashedJohn := sha256Hash("doe") app.Use(New(Config{Users: map[string]string{"john": hashedJohn}})) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) creds := base64.StdEncoding.EncodeToString([]byte("john:doe")) cases := []struct { header string status int }{ {"Basic " + creds, fiber.StatusTeapot}, {" Basic " + creds, fiber.StatusTeapot}, {"Basic " + creds, fiber.StatusBadRequest}, {"Basic " + creds, fiber.StatusBadRequest}, {"Basic\t" + creds, fiber.StatusBadRequest}, {"Basic \t" + creds, fiber.StatusBadRequest}, {"Basic\u00A0" + creds, fiber.StatusBadRequest}, {"Basic\u3000" + creds, fiber.StatusBadRequest}, {"\tBasic " + creds + "\t", fiber.StatusTeapot}, {"Basic " + creds[:4] + " " + creds[4:], fiber.StatusBadRequest}, } for _, tt := range cases { req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, tt.header) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, tt.status, resp.StatusCode) } } func Test_BasicAuth_ControlChars(t *testing.T) { t.Parallel() called 
:= false app := fiber.New() app.Use(New(Config{ Authorizer: func(_, _ string, _ fiber.Ctx) bool { called = true return true }, })) creds := []string{ base64.StdEncoding.EncodeToString([]byte("john:\x01doe")), base64.StdEncoding.EncodeToString([]byte("jo\x7Fhn:doe")), base64.StdEncoding.EncodeToString([]byte{'j', 'o', 'h', 'n', ':', 0x85, 'd', 'o', 'e'}), base64.StdEncoding.EncodeToString([]byte{'j', 'o', 'h', 'n', ':', 0x9F, 'd', 'o', 'e'}), } for _, c := range creds { called = false req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic "+c) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) require.Empty(t, resp.Header.Get(fiber.HeaderWWWAuthenticate)) require.False(t, called) } } func Test_BasicAuth_UnpaddedBase64(t *testing.T) { t.Parallel() app := fiber.New() hashedJohn := sha256Hash("doe") app.Use(New(Config{Users: map[string]string{"john": hashedJohn}})) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) creds := base64.StdEncoding.EncodeToString([]byte("john:doe")) creds = strings.TrimRight(creds, "=") req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } func Test_BasicAuth_NonASCIIHeader(t *testing.T) { t.Parallel() app := fiber.New() hashedJohn := sha256Hash("doe") app.Use(New(Config{Users: map[string]string{"john": hashedJohn}})) handler := app.Handler() creds := base64.StdEncoding.EncodeToString([]byte("john:doe")) fctx := &fasthttp.RequestCtx{} fctx.Request.SetRequestURI("/") fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.Header.SetBytesKV([]byte(fiber.HeaderAuthorization), []byte("Basic \x80"+creds)) handler(fctx) require.Equal(t, fiber.StatusBadRequest, fctx.Response.StatusCode()) } func Test_BasicAuth_InvalidUTF8(t 
*testing.T) { t.Parallel() called := false app := fiber.New() app.Use(New(Config{ Charset: "UTF-8", Authorizer: func(_, _ string, _ fiber.Ctx) bool { called = true return true }, })) creds := base64.StdEncoding.EncodeToString([]byte{'j', 'o', 'h', 'n', ':', 0xff, 'd', 'o', 'e'}) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) require.False(t, called) } func Test_BasicAuth_UTF8Normalization(t *testing.T) { t.Parallel() app := fiber.New() decomposed := "e\u0301" // e + combining acute accent called := false app.Use(New(Config{ Charset: "UTF-8", Authorizer: func(u, p string, _ fiber.Ctx) bool { called = true require.Equal(t, "é", u) require.Equal(t, "doe", p) return true }, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) creds := base64.StdEncoding.EncodeToString([]byte(decomposed + ":doe")) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) require.True(t, called) } func Test_BasicAuth_HeaderControlCharEdges(t *testing.T) { t.Parallel() app := fiber.New() hashedJohn := sha256Hash("doe") app.Use(New(Config{Users: map[string]string{"john": hashedJohn}})) handler := app.Handler() creds := base64.StdEncoding.EncodeToString([]byte("john:doe")) headers := [][]byte{ []byte("\rBasic " + creds), []byte("\nBasic " + creds), []byte("Basic " + creds + "\r"), []byte("Basic " + creds + "\n"), } for _, h := range headers { fctx := &fasthttp.RequestCtx{} fctx.Request.SetRequestURI("/") fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.Header.SetBytesKV([]byte(fiber.HeaderAuthorization), h) handler(fctx) require.Equal(t, fiber.StatusBadRequest, fctx.Response.StatusCode()) } } func 
Test_BasicAuth_Charset(t *testing.T) { t.Parallel() require.Panics(t, func() { New(Config{Charset: "ISO-8859-1"}) }) require.NotPanics(t, func() { New(Config{Charset: "utf-8"}) }) require.NotPanics(t, func() { New(Config{Charset: "UTF-8"}) }) require.NotPanics(t, func() { New(Config{}) }) } func Test_BasicAuth_HeaderLimit(t *testing.T) { t.Parallel() creds := base64.StdEncoding.EncodeToString([]byte("john:doe")) hashedJohn := sha256Hash("doe") t.Run("too large", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Users: map[string]string{"john": hashedJohn}, HeaderLimit: 10})) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusRequestHeaderFieldsTooLarge, resp.StatusCode) }) t.Run("allowed", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Users: map[string]string{"john": hashedJohn}, HeaderLimit: 100})) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) }) } // go test -v -run=^$ -bench=Benchmark_Middleware_BasicAuth -benchmem -count=4 func Benchmark_Middleware_BasicAuth(b *testing.B) { app := fiber.New() hashedJohn := sha256Hash("doe") app.Use(New(Config{ Users: map[string]string{ "john": hashedJohn, }, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.SetRequestURI("/") fctx.Request.Header.Set(fiber.HeaderAuthorization, "basic am9objpkb2U=") // john:doe b.ReportAllocs() for b.Loop() { h(fctx) } require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) } // go test 
-v -run=^$ -bench=Benchmark_Middleware_BasicAuth -benchmem -count=4 func Benchmark_Middleware_BasicAuth_Upper(b *testing.B) { app := fiber.New() hashedJohn := sha256Hash("doe") app.Use(New(Config{ Users: map[string]string{ "john": hashedJohn, }, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.SetRequestURI("/") fctx.Request.Header.Set(fiber.HeaderAuthorization, "Basic am9objpkb2U=") // john:doe b.ReportAllocs() for b.Loop() { h(fctx) } require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) } func Test_BasicAuth_Immutable(t *testing.T) { t.Parallel() app := fiber.New(fiber.Config{Immutable: true}) hashedJohn := sha256Hash("doe") app.Use(New(Config{Users: map[string]string{"john": hashedJohn}})) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) creds := base64.StdEncoding.EncodeToString([]byte("john:doe")) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } func Test_parseHashedPassword(t *testing.T) { t.Parallel() pass := "secret" sha := sha256.Sum256([]byte(pass)) b64 := base64.StdEncoding.EncodeToString(sha[:]) hexDigest := hex.EncodeToString(sha[:]) bcryptHash, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.MinCost) require.NoError(t, err) cases := []struct { name string hashed string }{ {"bcrypt", string(bcryptHash)}, {"sha512", sha512Hash(pass)}, {"sha256", sha256Hash(pass)}, {"sha256-hex", hexDigest}, {"sha256-b64", b64}, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { t.Parallel() verify, err := parseHashedPassword(tt.hashed) require.NoError(t, err) require.True(t, verify(pass)) require.False(t, verify("wrong")) }) } } func Test_BasicAuth_HashVariants(t *testing.T) { 
t.Parallel() pass := "doe" bcryptHash, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.MinCost) require.NoError(t, err) cases := []struct { name string hashed string }{ {"bcrypt", string(bcryptHash)}, {"sha512", sha512Hash(pass)}, {"sha256", sha256Hash(pass)}, {"sha256-hex", func() string { h := sha256.Sum256([]byte(pass)); return hex.EncodeToString(h[:]) }()}, } for _, tt := range cases { app := fiber.New() app.Use(New(Config{Users: map[string]string{"john": tt.hashed}})) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) creds := base64.StdEncoding.EncodeToString([]byte("john:" + pass)) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } } func Test_BasicAuth_HashVariants_Invalid(t *testing.T) { t.Parallel() pass := "doe" wrong := "wrong" bcryptHash, err := bcrypt.GenerateFromPassword([]byte(pass), bcrypt.MinCost) require.NoError(t, err) cases := []struct { name string hashed string }{ {"bcrypt", string(bcryptHash)}, {"sha512", sha512Hash(pass)}, {"sha256", sha256Hash(pass)}, {"sha256-hex", func() string { h := sha256.Sum256([]byte(pass)); return hex.EncodeToString(h[:]) }()}, } for _, tt := range cases { app := fiber.New() app.Use(New(Config{Users: map[string]string{"john": tt.hashed}})) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) creds := base64.StdEncoding.EncodeToString([]byte("john:" + wrong)) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Basic "+creds) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusUnauthorized, resp.StatusCode) } } ================================================ FILE: middleware/basicauth/config.go ================================================ package basicauth import ( "crypto/sha256" 
"crypto/sha512" "crypto/subtle" "encoding/base64" "encoding/hex" "errors" "fmt" "strconv" "strings" "github.com/gofiber/fiber/v3" "github.com/gofiber/utils/v2" "golang.org/x/crypto/bcrypt" ) var ErrInvalidSHA256PasswordLength = errors.New("decode SHA256 password: invalid length") // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // Users defines the allowed credentials // // Required. Default: map[string]string{} Users map[string]string // Authorizer defines a function you can pass // to check the credentials however you want. // It will be called with a username, password and // the current fiber context and is expected to return // true or false to indicate that the credentials were // approved or not. // // Optional. Default: nil. Authorizer func(string, string, fiber.Ctx) bool // Unauthorized defines the response body for unauthorized responses. // By default it will return with a 401 Unauthorized and the correct WWW-Auth header // // Optional. Default: nil Unauthorized fiber.Handler // BadRequest defines the response body for malformed Authorization headers. // By default it will return with a 400 Bad Request without the WWW-Authenticate header. // // Optional. Default: nil BadRequest fiber.Handler // Realm is a string to define realm attribute of BasicAuth. // the realm identifies the system to authenticate against // and can be used by clients to save credentials // // Optional. Default: "Restricted". Realm string // Charset defines the value for the charset parameter in the // WWW-Authenticate header. According to RFC 7617 clients can use // this value to interpret credentials correctly. Only the value // "UTF-8" is allowed; any other value will panic. // // Optional. Default: "UTF-8". Charset string // HeaderLimit specifies the maximum allowed length of the // Authorization header. 
Requests exceeding this limit will // be rejected. // // Optional. Default: 8192. HeaderLimit int } // ConfigDefault is the default config var ConfigDefault = Config{ Next: nil, Users: map[string]string{}, Realm: "Restricted", Charset: "UTF-8", HeaderLimit: 8192, Authorizer: nil, Unauthorized: nil, BadRequest: nil, } // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.Next == nil { cfg.Next = ConfigDefault.Next } if cfg.Users == nil { cfg.Users = ConfigDefault.Users } if cfg.Realm == "" { cfg.Realm = ConfigDefault.Realm } switch { case cfg.Charset == "": cfg.Charset = ConfigDefault.Charset case utils.EqualFold(cfg.Charset, "UTF-8"): cfg.Charset = "UTF-8" default: panic("basicauth: charset must be UTF-8") } if cfg.HeaderLimit <= 0 { cfg.HeaderLimit = ConfigDefault.HeaderLimit } if cfg.Authorizer == nil { verifiers := make(map[string]func(string) bool, len(cfg.Users)) for u, hpw := range cfg.Users { v, err := parseHashedPassword(hpw) if err != nil { panic(err) } verifiers[u] = v } cfg.Authorizer = func(user, pass string, _ fiber.Ctx) bool { verify, ok := verifiers[user] return ok && verify(pass) } } if cfg.Unauthorized == nil { cfg.Unauthorized = func(c fiber.Ctx) error { header := "Basic realm=" + strconv.Quote(cfg.Realm) if cfg.Charset != "" { header += ", charset=" + strconv.Quote(cfg.Charset) } c.Set(fiber.HeaderWWWAuthenticate, header) c.Set(fiber.HeaderCacheControl, "no-store") c.Set(fiber.HeaderVary, fiber.HeaderAuthorization) return c.SendStatus(fiber.StatusUnauthorized) } } if cfg.BadRequest == nil { cfg.BadRequest = func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusBadRequest) } } return cfg } func parseHashedPassword(h string) (func(string) bool, error) { switch { case strings.HasPrefix(h, "$2"): hash := []byte(h) return func(p string) bool { return 
bcrypt.CompareHashAndPassword(hash, []byte(p)) == nil }, nil case strings.HasPrefix(h, "{SHA512}"): b, err := base64.StdEncoding.DecodeString(h[len("{SHA512}"):]) if err != nil { return nil, fmt.Errorf("decode SHA512 password: %w", err) } return func(p string) bool { sum := sha512.Sum512([]byte(p)) return subtle.ConstantTimeCompare(sum[:], b) == 1 }, nil case strings.HasPrefix(h, "{SHA256}"): b, err := base64.StdEncoding.DecodeString(h[len("{SHA256}"):]) if err != nil { return nil, fmt.Errorf("decode SHA256 password: %w", err) } return func(p string) bool { sum := sha256.Sum256([]byte(p)) return subtle.ConstantTimeCompare(sum[:], b) == 1 }, nil default: b, err := hex.DecodeString(h) if err != nil || len(b) != sha256.Size { if b, err = base64.StdEncoding.DecodeString(h); err != nil { return nil, fmt.Errorf("decode SHA256 password: %w", err) } if len(b) != sha256.Size { return nil, ErrInvalidSHA256PasswordLength } } return func(p string) bool { sum := sha256.Sum256([]byte(p)) return subtle.ConstantTimeCompare(sum[:], b) == 1 }, nil } } ================================================ FILE: middleware/cache/cache.go ================================================ // Special thanks to @codemicro for moving this to fiber core // Original middleware: github.com/codemicro/fiber-cache package cache import ( "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "math" "slices" "sort" "strings" "sync" "sync/atomic" "time" "github.com/gofiber/utils/v2" utilsstrings "github.com/gofiber/utils/v2/strings" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3" ) // timestampUpdatePeriod is the period which is used to check the cache expiration. 
// It should not be too long to provide more or less acceptable expiration error, and in the same // time it should not be too short to avoid overwhelming of the system const timestampUpdatePeriod = 300 * time.Millisecond // buffer size for hexpool const hexLen = sha256.Size * 2 // cache status // unreachable: when cache is bypass, or invalid // hit: cache is served // miss: do not have cache record const ( cacheUnreachable = "unreachable" cacheHit = "hit" cacheMiss = "miss" ) type expirationSource uint8 const ( expirationSourceConfig expirationSource = iota expirationSourceMaxAge expirationSourceSMaxAge expirationSourceExpires expirationSourceGenerator ) // directives const ( noCache = "no-cache" noStore = "no-store" privateDirective = "private" ) type requestCacheDirectives struct { maxAge uint64 maxStale uint64 minFresh uint64 maxAgeSet bool maxStaleSet bool maxStaleAny bool minFreshSet bool noStore bool noCache bool onlyIfCached bool } var ignoreHeaders = map[string]struct{}{ "Age": {}, "Cache-Control": {}, // already stored explicitly by the cache manager "Connection": {}, "Content-Encoding": {}, // already stored explicitly by the cache manager "Content-Type": {}, // already stored explicitly by the cache manager "Date": {}, "ETag": {}, // already stored explicitly by the cache manager "Expires": {}, // already stored explicitly by the cache manager "Keep-Alive": {}, "Proxy-Authenticate": {}, "Proxy-Authorization": {}, "TE": {}, "Trailers": {}, "Transfer-Encoding": {}, "Upgrade": {}, } var cacheableStatusCodes = map[int]struct{}{ fiber.StatusOK: {}, fiber.StatusNonAuthoritativeInformation: {}, fiber.StatusNoContent: {}, fiber.StatusPartialContent: {}, fiber.StatusMultipleChoices: {}, fiber.StatusMovedPermanently: {}, fiber.StatusPermanentRedirect: {}, fiber.StatusNotFound: {}, fiber.StatusMethodNotAllowed: {}, fiber.StatusGone: {}, fiber.StatusRequestURITooLong: {}, fiber.StatusNotImplemented: {}, } // New creates a new middleware handler func New(config 
...Config) fiber.Handler { // Set default config cfg := configDefault(config...) type evictionCandidate struct { key string size uint exp uint64 heapIdx int } redactKeys := !cfg.DisableValueRedaction maskKey := func(key string) string { if redactKeys { return redactedKey } return key } // Nothing to cache if int(cfg.Expiration.Seconds()) < 0 { return func(c fiber.Ctx) error { return c.Next() } } var ( // Cache settings mux = &sync.RWMutex{} timestamp = safeUnixSeconds(time.Now()) ) // Create manager to simplify storage operations ( see manager.go ) manager := newManager(cfg.Storage, redactKeys) // Create indexed heap for tracking expirations ( see heap.go ) heap := &indexedHeap{} // count stored bytes (sizes of response bodies) var storedBytes uint // Pool for hex encoding buffers hexBufPool := &sync.Pool{ New: func() any { buf := make([]byte, hexLen) return &buf }, } hashAuthorization := makeHashAuthFunc(hexBufPool) buildVaryKey := makeBuildVaryKeyFunc(hexBufPool) // Update timestamp in the configured interval go func() { ticker := time.NewTicker(timestampUpdatePeriod) defer ticker.Stop() for range ticker.C { atomic.StoreUint64(×tamp, safeUnixSeconds(time.Now())) } }() // Delete key from both manager and storage deleteKey := func(ctx context.Context, dkey string) error { if err := manager.del(ctx, dkey); err != nil { return err } // External storage saves body data with different key if cfg.Storage != nil { if err := manager.del(ctx, dkey+"_body"); err != nil { return err } } return nil } removeHeapEntry := func(entryKey string, heapIdx int) { if cfg.MaxBytes == 0 { return } if heapIdx < 0 || heapIdx >= len(heap.indices) { return } indexedIdx := heap.indices[heapIdx] if indexedIdx < 0 || indexedIdx >= len(heap.entries) { return } entry := heap.entries[indexedIdx] if entry.idx != heapIdx || entry.key != entryKey { return } _, size := heap.remove(heapIdx) storedBytes -= size } refreshHeapIndex := func(ctx context.Context, candidate evictionCandidate) error { entry, 
err := manager.get(ctx, candidate.key) if err != nil { if errors.Is(err, errCacheMiss) { return nil } return fmt.Errorf("cache: failed to reload key %q after eviction failure: %w", maskKey(candidate.key), err) } entry.heapidx = candidate.heapIdx remainingTTL := max(time.Until(secondsToTime(entry.exp)), 0) if err := manager.set(ctx, candidate.key, entry, remainingTTL); err != nil { return fmt.Errorf("cache: failed to restore heap index for key %q: %w", maskKey(candidate.key), err) } return nil } // Return new handler return func(c fiber.Ctx) error { hasAuthorization := len(c.Request().Header.Peek(fiber.HeaderAuthorization)) > 0 reqCacheControl := c.Request().Header.Peek(fiber.HeaderCacheControl) reqDirectives := parseRequestCacheControl(reqCacheControl) if !reqDirectives.noCache { reqPragma := utils.UnsafeString(c.Request().Header.Peek(fiber.HeaderPragma)) if hasDirective(reqPragma, noCache) { reqDirectives.noCache = true } } // Refrain from caching if reqDirectives.noStore { return c.Next() } requestMethod := c.Method() // Only cache selected methods if !slices.Contains(cfg.Methods, requestMethod) { c.Set(cfg.CacheHeader, cacheUnreachable) return c.Next() } // Get key from request baseKey := cfg.KeyGenerator(c) + "_" + requestMethod manifestKey := baseKey + "|vary" if hasAuthorization { authHash := hashAuthorization(c.Request().Header.Peek(fiber.HeaderAuthorization)) baseKey += "_auth_" + authHash manifestKey = baseKey + "|vary" } key := baseKey reqCtx := c.Context() varyNames, hasVaryManifest, err := loadVaryManifest(reqCtx, manager, manifestKey) if err != nil { return err } if len(varyNames) > 0 { key += buildVaryKey(varyNames, &c.Request().Header) } // Get entry from pool e, err := manager.get(reqCtx, key) if err != nil && !errors.Is(err, errCacheMiss) { return err } entryAge := uint64(0) revalidate := false oldHeapIdx := -1 // Track old heap index for replacement during revalidation handleMinFresh := func(now uint64) { if e == nil || !reqDirectives.minFreshSet 
{ return } remainingFreshness := remainingFreshness(e, now) if remainingFreshness < reqDirectives.minFresh { revalidate = true oldHeapIdx = e.heapidx if cfg.Storage != nil { manager.release(e) } e = nil } } // Lock entry mux.Lock() locked := true unlock := func() { if locked { mux.Unlock() locked = false } } relock := func() { if !locked { mux.Lock() locked = true } } // Get timestamp ts := atomic.LoadUint64(×tamp) // Cache Entry found if e != nil { entryAge = cachedResponseAge(e, ts) if reqDirectives.maxAgeSet && (reqDirectives.maxAge == 0 || entryAge > reqDirectives.maxAge) { revalidate = true oldHeapIdx = e.heapidx if cfg.Storage != nil { manager.release(e) } e = nil } handleMinFresh(ts) } if e != nil && e.ttl == 0 && e.forceRevalidate { revalidate = true oldHeapIdx = e.heapidx if cfg.Storage != nil { manager.release(e) } e = nil } if e != nil && e.ttl == 0 && e.exp != 0 && ts >= e.exp { unlock() if err := deleteKey(reqCtx, key); err != nil { if cfg.Storage != nil { manager.release(e) } return fmt.Errorf("cache: failed to delete expired key %q: %w", maskKey(key), err) } relock() removeHeapEntry(key, e.heapidx) if cfg.Storage != nil { manager.release(e) } e = nil unlock() c.Set(cfg.CacheHeader, cacheUnreachable) goto continueRequest } if e != nil { entryHasPrivate := e != nil && e.private if !entryHasPrivate && cfg.StoreResponseHeaders && len(e.headers) > 0 { if cc, ok := lookupCachedHeader(e.headers, fiber.HeaderCacheControl); ok && hasDirective(utils.UnsafeString(cc), privateDirective) { entryHasPrivate = true } } requestNoCache := reqDirectives.noCache // Invalidate cache if requested if cfg.CacheInvalidator != nil && cfg.CacheInvalidator(c) { e.exp = ts - 1 } entryHasExpiration := e != nil && e.exp != 0 entryExpired := entryHasExpiration && ts >= e.exp staleness := uint64(0) if entryExpired { staleness = ts - e.exp } allowStale := entryExpired && (reqDirectives.maxStaleAny || (reqDirectives.maxStaleSet && staleness <= reqDirectives.maxStale)) if entryExpired 
&& e.revalidate { revalidate = true oldHeapIdx = e.heapidx if cfg.Storage != nil { manager.release(e) } e = nil } handleMinFresh(ts) if revalidate { unlock() c.Set(cfg.CacheHeader, cacheUnreachable) if reqDirectives.onlyIfCached { return c.SendStatus(fiber.StatusGatewayTimeout) } goto continueRequest } servedStale := false switch { case entryExpired && !allowStale: unlock() if err := deleteKey(reqCtx, key); err != nil { if e != nil { manager.release(e) } return fmt.Errorf("cache: failed to delete expired key %q: %w", maskKey(key), err) } relock() idx := e.heapidx manager.release(e) removeHeapEntry(key, idx) e = nil case entryHasPrivate: unlock() if err := deleteKey(reqCtx, key); err != nil { if e != nil { manager.release(e) } return fmt.Errorf("cache: failed to delete private response for key %q: %w", maskKey(key), err) } relock() removeHeapEntry(key, e.heapidx) if cfg.Storage != nil && e != nil { manager.release(e) } e = nil unlock() c.Set(cfg.CacheHeader, cacheUnreachable) if reqDirectives.onlyIfCached { return c.SendStatus(fiber.StatusGatewayTimeout) } return c.Next() case entryHasExpiration && !requestNoCache: servedStale = entryExpired if hasAuthorization && !e.shareable { if cfg.Storage != nil { manager.release(e) } unlock() c.Set(cfg.CacheHeader, cacheUnreachable) return c.Next() } // Separate body value to avoid msgp serialization // We can store raw bytes with Storage 👍 if cfg.Storage != nil { unlock() rawBody, err := manager.getRaw(reqCtx, key+"_body") if err != nil { manager.release(e) return cacheBodyFetchError(maskKey, key, err) } e.body = rawBody } else { unlock() } // Set response headers from cache c.Response().SetBodyRaw(e.body) c.Response().SetStatusCode(e.status) c.Response().Header.SetContentTypeBytes(e.ctype) if len(e.cencoding) > 0 { c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding) } if len(e.cacheControl) > 0 { c.Response().Header.SetBytesV(fiber.HeaderCacheControl, e.cacheControl) } if len(e.expires) > 0 { 
c.Response().Header.SetBytesV(fiber.HeaderExpires, e.expires) } if len(e.etag) > 0 { c.Response().Header.SetBytesV(fiber.HeaderETag, e.etag) } clampedDate := clampDateSeconds(e.date, ts) dateValue := fasthttp.AppendHTTPDate(nil, secondsToTime(clampedDate)) c.Response().Header.SetBytesV(fiber.HeaderDate, dateValue) for i := range e.headers { h := e.headers[i] c.Response().Header.SetBytesKV(h.key, h.value) } // Set Cache-Control header if not disabled and not already set if !cfg.DisableCacheControl && len(c.Response().Header.Peek(fiber.HeaderCacheControl)) == 0 { remaining := uint64(0) if e.exp > ts { remaining = e.exp - ts } maxAge := utils.FormatUint(remaining) c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge) } const maxDeltaSeconds = uint64(math.MaxInt32) ageSeconds := min(entryAge, maxDeltaSeconds) // RFC-compliant Age header (RFC 9111) age := utils.FormatUint(ageSeconds) c.Response().Header.Set(fiber.HeaderAge, age) appendWarningHeaders(&c.Response().Header, servedStale, isHeuristicFreshness(e, &cfg, entryAge)) c.Set(cfg.CacheHeader, cacheHit) // release item allocated from storage if cfg.Storage != nil { manager.release(e) } // Return response return nil default: // no cached response to serve } } if e == nil && revalidate { unlock() c.Set(cfg.CacheHeader, cacheUnreachable) if reqDirectives.onlyIfCached { return c.SendStatus(fiber.StatusGatewayTimeout) } goto continueRequest } if e == nil && reqDirectives.onlyIfCached { unlock() c.Set(cfg.CacheHeader, cacheUnreachable) return c.SendStatus(fiber.StatusGatewayTimeout) } // make sure we're not blocking concurrent requests - do unlock unlock() continueRequest: // Continue stack, return err to Fiber if exist if err := c.Next(); err != nil { return err } cacheControlBytes := c.Response().Header.Peek(fiber.HeaderCacheControl) respCacheControl := parseResponseCacheControl(cacheControlBytes) varyHeader := utils.UnsafeString(c.Response().Header.Peek(fiber.HeaderVary)) hasPrivate := respCacheControl.hasPrivate 
hasNoCache := respCacheControl.hasNoCache varyNames, varyHasStar := parseVary(varyHeader) // Respect server cache-control: no-store if respCacheControl.hasNoStore { c.Set(cfg.CacheHeader, cacheUnreachable) return nil } if hasPrivate || hasNoCache || varyHasStar { if e != nil { if err := deleteKey(reqCtx, key); err != nil { if cfg.Storage != nil { manager.release(e) } return fmt.Errorf("cache: failed to delete cached response for key %q: %w", maskKey(key), err) } mux.Lock() removeHeapEntry(key, e.heapidx) if cfg.Storage != nil { manager.release(e) } e = nil mux.Unlock() } if hasVaryManifest { if err := manager.del(reqCtx, manifestKey); err != nil { return fmt.Errorf("cache: failed to delete stale vary manifest %q: %w", maskKey(manifestKey), err) } } c.Set(cfg.CacheHeader, cacheUnreachable) return nil } shouldStoreVaryManifest := len(varyNames) > 0 if len(varyNames) > 0 { if key == baseKey { key += buildVaryKey(varyNames, &c.Request().Header) } } else if hasVaryManifest { if err := manager.del(reqCtx, manifestKey); err != nil { return fmt.Errorf("cache: failed to delete stale vary manifest %q: %w", maskKey(manifestKey), err) } } isSharedCacheAllowed := allowsSharedCacheDirectives(respCacheControl) if hasAuthorization && !isSharedCacheAllowed { c.Set(cfg.CacheHeader, cacheUnreachable) return nil } sharedCacheMode := !hasAuthorization || isSharedCacheAllowed // Don't cache response if status code is not cacheable if _, ok := cacheableStatusCodes[c.Response().StatusCode()]; !ok { c.Set(cfg.CacheHeader, cacheUnreachable) return nil } // Don't cache response if Next returns true if cfg.Next != nil && cfg.Next(c) { c.Set(cfg.CacheHeader, cacheUnreachable) return nil } // Don't try to cache if body won't fit into cache bodySize := uint(len(c.Response().Body())) if cfg.MaxBytes > 0 && bodySize > cfg.MaxBytes { c.Set(cfg.CacheHeader, cacheUnreachable) return nil } // Eviction loop: atomically reserve space for new entry and evict old entries. // Strategy: // 1. 
Under lock: reserve space by pre-incrementing storedBytes, then collect entries to evict // 2. Outside lock: perform I/O deletions // 3. On deletion failure: restore storedBytes and return error // 4. Track reservation with a flag; unreserve on early return via defer var spaceReserved bool defer func() { // If we reserved space but the entry was not successfully added to heap, unreserve it if cfg.MaxBytes > 0 && spaceReserved { mux.Lock() storedBytes -= bodySize mux.Unlock() } }() if cfg.MaxBytes > 0 { mux.Lock() // Reserve space for the new entry first storedBytes += bodySize spaceReserved = true // Now evict entries until we're under the limit var keysToRemove []string var sizesToRemove []uint var candidates []evictionCandidate for storedBytes > cfg.MaxBytes { if heap.Len() == 0 { // Can't evict more, unreserve space and fail storedBytes -= bodySize // Set spaceReserved to false so the deferred cleanup does not unreserve again spaceReserved = false mux.Unlock() return errors.New("cache: insufficient space and no entries to evict") } next := heap.entries[0] keyToRemove, size := heap.removeFirst() keysToRemove = append(keysToRemove, keyToRemove) sizesToRemove = append(sizesToRemove, size) candidates = append(candidates, evictionCandidate{ key: keyToRemove, size: size, exp: next.exp, }) storedBytes -= size } mux.Unlock() // Perform deletions outside the lock if len(keysToRemove) > 0 { for i, keyToRemove := range keysToRemove { delErr := deleteKey(reqCtx, keyToRemove) if delErr == nil { continue } // Deletion failed: restore storedBytes for failed deletions mux.Lock() // Restore sizes of entries we failed to delete for j := i; j < len(sizesToRemove); j++ { storedBytes += sizesToRemove[j] } // Unreserve space for the new entry storedBytes -= bodySize spaceReserved = false // Re-add entries to the heap to keep expiration tracking consistent var restored []evictionCandidate for j := i; j < len(candidates); j++ { candidate := candidates[j] candidate.heapIdx = 
heap.put(candidate.key, candidate.exp, candidate.size) restored = append(restored, candidate) } mux.Unlock() var restoreErr error for _, candidate := range restored { if err := refreshHeapIndex(reqCtx, candidate); err != nil { restoreErr = errors.Join(restoreErr, err) } } if restoreErr != nil { return errors.Join(fmt.Errorf("cache: failed to delete key %q while evicting: %w", maskKey(keyToRemove), delErr), restoreErr) } return fmt.Errorf("cache: failed to delete key %q while evicting: %w", maskKey(keyToRemove), delErr) } } } e = manager.acquire() // Cache response e.body = utils.CopyBytes(c.Response().Body()) e.status = c.Response().StatusCode() e.ctype = utils.CopyBytes(c.Response().Header.ContentType()) e.cencoding = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding)) e.private = false e.cacheControl = utils.CopyBytes(cacheControlBytes) e.expires = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderExpires)) e.etag = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderETag)) e.date = 0 ageVal := uint64(0) if b := c.Response().Header.Peek(fiber.HeaderAge); len(b) > 0 { if v, err := fasthttp.ParseUint(b); err == nil { if v >= 0 { ageVal = uint64(v) } } } else { c.Response().Header.Set(fiber.HeaderAge, "0") } e.age = ageVal e.shareable = isSharedCacheAllowed now := time.Now().UTC() nowUnix := safeUnixSeconds(now) dateHeader := c.Response().Header.Peek(fiber.HeaderDate) parsedDate, _ := parseHTTPDate(dateHeader) e.date = clampDateSeconds(parsedDate, nowUnix) dateBytes := fasthttp.AppendHTTPDate(nil, secondsToTime(e.date)) c.Response().Header.SetBytesV(fiber.HeaderDate, dateBytes) // Store all response headers // (more: https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1) if cfg.StoreResponseHeaders { allHeaders := c.Response().Header.All() e.headers = e.headers[:0] for key, value := range allHeaders { keyStr := string(key) if _, ok := ignoreHeaders[keyStr]; ok { continue } e.headers = append(e.headers, cachedHeader{ key: 
utils.CopyBytes(utils.UnsafeBytes(keyStr)), value: utils.CopyBytes(value), }) } } expirationSource := expirationSourceConfig expiresParseError := false mustRevalidate := respCacheControl.mustRevalidate || respCacheControl.proxyRevalidate // default cache expiration expiration := cfg.Expiration if sharedCacheMode && respCacheControl.sMaxAgeSet { expiration = secondsToDuration(respCacheControl.sMaxAge) expirationSource = expirationSourceSMaxAge } if expirationSource == expirationSourceConfig { if respCacheControl.maxAgeSet { expiration = secondsToDuration(respCacheControl.maxAge) expirationSource = expirationSourceMaxAge } else if expiresBytes := c.Response().Header.Peek(fiber.HeaderExpires); len(expiresBytes) > 0 { expiresAt, err := fasthttp.ParseHTTPDate(expiresBytes) if err != nil { expiration = time.Nanosecond expiresParseError = true } else { expiration = time.Until(expiresAt) } expirationSource = expirationSourceExpires } } // Calculate expiration by response header or other setting if cfg.ExpirationGenerator != nil { expiration = cfg.ExpirationGenerator(c, &cfg) expirationSource = expirationSourceGenerator } e.forceRevalidate = expiresParseError e.revalidate = mustRevalidate storageExpiration := expiration if expiresParseError || storageExpiration < cfg.Expiration { storageExpiration = cfg.Expiration } if expiration <= 0 && !expiresParseError { c.Set(cfg.CacheHeader, cacheUnreachable) return nil } ts = atomic.LoadUint64(×tamp) responseTS := max(ts, nowUnix) maxAgeSeconds := uint64(time.Duration(math.MaxInt64) / time.Second) var ageDuration time.Duration apparentAge := e.age if e.date > 0 && responseTS > e.date { dateAge := responseTS - e.date if dateAge > apparentAge { apparentAge = dateAge } } if expirationSource != expirationSourceExpires { if apparentAge > maxAgeSeconds { ageDuration = expiration + time.Second } else { ageDuration = time.Duration(apparentAge) * time.Second } } remainingExpiration := expiration - ageDuration if remainingExpiration <= 0 { if 
expirationSource != expirationSourceExpires { c.Set(cfg.CacheHeader, cacheUnreachable) return nil } remainingExpiration = 0 } if shouldStoreVaryManifest { if err := storeVaryManifest(reqCtx, manager, manifestKey, varyNames, storageExpiration); err != nil { return err } } e.exp = responseTS + uint64(remainingExpiration.Seconds()) e.ttl = uint64(expiration.Seconds()) if expiresParseError { e.exp = ts + 1 } // Store entry in heap (space already reserved in eviction phase) var heapIdx int if cfg.MaxBytes > 0 { mux.Lock() heapIdx = heap.put(key, e.exp, bodySize) e.heapidx = heapIdx // Note: storedBytes was incremented during reservation, and evictions // have already been accounted for, so no additional increment is needed spaceReserved = false // Clear flag to prevent defer from unreserving mux.Unlock() } cleanupOnStoreError := func(ctx context.Context, releaseEntry, rawStored bool) error { var cleanupErr error if cfg.MaxBytes > 0 { mux.Lock() _, size := heap.remove(heapIdx) storedBytes -= size mux.Unlock() } if releaseEntry { manager.release(e) } if rawStored { rawKey := key + "_body" if err := manager.del(ctx, rawKey); err != nil { cleanupErr = errors.Join(cleanupErr, fmt.Errorf("cache: failed to delete raw key %q after store error: %w", maskKey(rawKey), err)) } } return cleanupErr } // For external Storage we store raw body separated if cfg.Storage != nil { if err := manager.setRaw(reqCtx, key+"_body", e.body, storageExpiration); err != nil { if cleanupErr := cleanupOnStoreError(reqCtx, true, false); cleanupErr != nil { err = errors.Join(err, cleanupErr) } return err } // avoid body msgp encoding e.body = nil if err := manager.set(reqCtx, key, e, storageExpiration); err != nil { if cleanupErr := cleanupOnStoreError(reqCtx, false, true); cleanupErr != nil { err = errors.Join(err, cleanupErr) } return err } } else { // Store entry in memory if err := manager.set(reqCtx, key, e, storageExpiration); err != nil { if cleanupErr := cleanupOnStoreError(reqCtx, true, false); 
cleanupErr != nil { err = errors.Join(err, cleanupErr) } return err } } // If revalidating, remove old heap entry now that replacement is successfully stored if cfg.MaxBytes > 0 && revalidate && oldHeapIdx >= 0 { mux.Lock() removeHeapEntry(key, oldHeapIdx) mux.Unlock() } c.Set(cfg.CacheHeader, cacheMiss) // Finish response return nil } } // hasDirective checks if a cache-control header contains a directive (case-insensitive) func hasDirective(cc, directive string) bool { ccLen := len(cc) dirLen := len(directive) for i := 0; i <= ccLen-dirLen; i++ { if !utils.EqualFold(cc[i:i+dirLen], directive) { continue } if i > 0 { prev := cc[i-1] if prev != ' ' && prev != ',' { continue } } if i+dirLen == ccLen || cc[i+dirLen] == ',' { return true } } return false } func cacheBodyFetchError(mask func(string) string, key string, err error) error { if errors.Is(err, errCacheMiss) { return fmt.Errorf("cache: no cached body for key %q: %w", mask(key), err) } return err } func parseUintDirective(val []byte) (uint64, bool) { if len(val) == 0 { return 0, false } parsed, err := fasthttp.ParseUint(val) if err != nil || parsed < 0 { return 0, false } return uint64(parsed), true } func parseCacheControlDirectives(cc []byte, fn func(key, value []byte)) { for i := 0; i < len(cc); { // skip leading separators/spaces for i < len(cc) && (cc[i] == ' ' || cc[i] == ',') { i++ } if i >= len(cc) { break } start := i for i < len(cc) && cc[i] != ',' { i++ } partEnd := i for partEnd > start && cc[partEnd-1] == ' ' { partEnd-- } keyStart := start for keyStart < partEnd && cc[keyStart] == ' ' { keyStart++ } if keyStart >= partEnd { continue } keyEnd := keyStart for keyEnd < partEnd && cc[keyEnd] != '=' { keyEnd++ } // Trim trailing spaces from key keyEndTrimmed := keyEnd for keyEndTrimmed > keyStart && cc[keyEndTrimmed-1] == ' ' { keyEndTrimmed-- } key := cc[keyStart:keyEndTrimmed] var value []byte if keyEnd < partEnd && cc[keyEnd] == '=' { valueStart := keyEnd + 1 for valueStart < partEnd && 
cc[valueStart] == ' ' { valueStart++ } valueEnd := partEnd for valueEnd > valueStart && cc[valueEnd-1] == ' ' { valueEnd-- } if valueStart <= valueEnd { value = cc[valueStart:valueEnd] // Handle quoted-string values per RFC 9111 Section 5.2 if len(value) >= 2 && value[0] == '"' && value[len(value)-1] == '"' { value = unquoteCacheDirective(value) } } } fn(key, value) i++ // skip comma } } // unquoteCacheDirective removes quotes and handles escaped characters in quoted-string values. // Per RFC 9111 Section 5.2, quoted-string values follow RFC 9110 Section 5.6.4. func unquoteCacheDirective(quoted []byte) []byte { if len(quoted) < 2 { return quoted } // Remove surrounding quotes inner := quoted[1 : len(quoted)-1] // Check if there are any escaped characters (backslash followed by another character) hasEscapes := false for i := 0; i < len(inner)-1; i++ { if inner[i] == '\\' { hasEscapes = true break } } // If no escapes, return the inner content directly if !hasEscapes { return inner } // Process escaped characters result := make([]byte, 0, len(inner)) for i := 0; i < len(inner); i++ { if inner[i] == '\\' && i+1 < len(inner) { // Skip the backslash and take the next character i++ result = append(result, inner[i]) } else { result = append(result, inner[i]) } } return result } type responseCacheControl struct { maxAge uint64 sMaxAge uint64 maxAgeSet bool sMaxAgeSet bool hasNoCache bool hasNoStore bool hasPrivate bool hasPublic bool mustRevalidate bool proxyRevalidate bool } func parseResponseCacheControl(cc []byte) responseCacheControl { parsed := responseCacheControl{} parseCacheControlDirectives(cc, func(key, value []byte) { switch { case utils.EqualFold(utils.UnsafeString(key), noStore): parsed.hasNoStore = true case utils.EqualFold(utils.UnsafeString(key), noCache): parsed.hasNoCache = true case utils.EqualFold(utils.UnsafeString(key), privateDirective): parsed.hasPrivate = true case utils.EqualFold(utils.UnsafeString(key), "public"): parsed.hasPublic = true case 
utils.EqualFold(utils.UnsafeString(key), "max-age"): if v, ok := parseUintDirective(value); ok { parsed.maxAgeSet = true parsed.maxAge = v } case utils.EqualFold(utils.UnsafeString(key), "s-maxage"): if v, ok := parseUintDirective(value); ok { parsed.sMaxAgeSet = true parsed.sMaxAge = v } case utils.EqualFold(utils.UnsafeString(key), "must-revalidate"): parsed.mustRevalidate = true case utils.EqualFold(utils.UnsafeString(key), "proxy-revalidate"): parsed.proxyRevalidate = true default: // ignore unknown directives } }) return parsed } // parseMaxAge extracts the max-age directive from a Cache-Control header. func parseMaxAge(cc string) (time.Duration, bool) { parsed := parseResponseCacheControl(utils.UnsafeBytes(cc)) if !parsed.maxAgeSet { return 0, false } return secondsToDuration(parsed.maxAge), true } func parseRequestCacheControl(cc []byte) requestCacheDirectives { directives := requestCacheDirectives{} parseCacheControlDirectives(cc, func(key, value []byte) { switch { case utils.EqualFold(utils.UnsafeString(key), noStore): directives.noStore = true case utils.EqualFold(utils.UnsafeString(key), noCache): directives.noCache = true case utils.EqualFold(utils.UnsafeString(key), "only-if-cached"): directives.onlyIfCached = true case utils.EqualFold(utils.UnsafeString(key), "max-age"): if sec, ok := parseUintDirective(value); ok { directives.maxAgeSet = true directives.maxAge = sec } case utils.EqualFold(utils.UnsafeString(key), "max-stale"): directives.maxStaleSet = true directives.maxStaleAny = len(value) == 0 if !directives.maxStaleAny { if sec, ok := parseUintDirective(value); ok { directives.maxStale = sec } } case utils.EqualFold(utils.UnsafeString(key), "min-fresh"): if sec, ok := parseUintDirective(value); ok { directives.minFreshSet = true directives.minFresh = sec } default: // ignore unknown directives } }) return directives } func parseRequestCacheControlString(cc string) requestCacheDirectives { return parseRequestCacheControl(utils.UnsafeBytes(cc)) } 
func cachedResponseAge(e *item, now uint64) uint64 { clampedDate := clampDateSeconds(e.date, now) resident := uint64(0) if e.exp != 0 { if e.exp <= now { resident = e.ttl + (now - e.exp) } else { resident = e.ttl - (e.exp - now) } } dateAge := uint64(0) if clampedDate != 0 && now > clampedDate { dateAge = now - clampedDate } currentAge := max(dateAge, max(resident, e.age)) return currentAge } func appendWarningHeaders(h *fasthttp.ResponseHeader, servedStale, heuristicFreshness bool) { //nolint:revive // flags are intentional to represent Warning variants if servedStale { h.Add(fiber.HeaderWarning, `110 - "Response is stale"`) } if heuristicFreshness { h.Add(fiber.HeaderWarning, `113 - "Heuristic expiration"`) } } func remainingFreshness(e *item, now uint64) uint64 { if e == nil || e.exp == 0 || now >= e.exp { return 0 } return e.exp - now } func isHeuristicFreshness(e *item, cfg *Config, entryAge uint64) bool { const heuristicAgeThresholdSeconds = uint64(24 * time.Hour / time.Second) if entryAge <= heuristicAgeThresholdSeconds { return false } if len(e.expires) > 0 { return false } cacheControl := utils.UnsafeString(e.cacheControl) if parsedCC := parseResponseCacheControl(utils.UnsafeBytes(cacheControl)); parsedCC.maxAgeSet || parsedCC.sMaxAgeSet { return false } return cfg.Expiration > 0 } func lookupCachedHeader(headers []cachedHeader, name string) ([]byte, bool) { for i := range headers { if utils.EqualFold(utils.UnsafeString(headers[i].key), name) { return headers[i].value, true } } return nil, false } func parseHTTPDate(dateBytes []byte) (uint64, bool) { if len(dateBytes) == 0 { return 0, false } parsedDate, err := fasthttp.ParseHTTPDate(dateBytes) if err != nil { return 0, false } return safeUnixSeconds(parsedDate), true } func clampDateSeconds(dateSeconds, fallback uint64) uint64 { const maxUnixSeconds = uint64(math.MaxInt64) if dateSeconds == 0 || dateSeconds > maxUnixSeconds || dateSeconds > fallback { return fallback } return dateSeconds } func 
safeUnixSeconds(t time.Time) uint64 { sec := t.Unix() if sec < 0 { return 0 } return uint64(sec) } func secondsToTime(sec uint64) time.Time { var clamped int64 if sec > uint64(math.MaxInt64) { clamped = math.MaxInt64 } else { clamped = int64(sec) } return time.Unix(clamped, 0).UTC() } func secondsToDuration(sec uint64) time.Duration { const maxSeconds = uint64(math.MaxInt64) / uint64(time.Second) if sec > maxSeconds { return time.Duration(math.MaxInt64) } return time.Duration(sec) * time.Second } func parseVary(vary string) ([]string, bool) { names := make([]string, 0, 8) for part := range strings.SplitSeq(vary, ",") { name := utils.TrimSpace(utilsstrings.ToLower(part)) if name == "" { continue } if name == "*" { return nil, true } names = append(names, name) } if len(names) == 0 { return nil, false } sort.Strings(names) return names, false } func makeBuildVaryKeyFunc(hexBufPool *sync.Pool) func([]string, *fasthttp.RequestHeader) string { return func(names []string, hdr *fasthttp.RequestHeader) string { sum := sha256.New() for _, name := range names { _, _ = sum.Write(utils.UnsafeBytes(name)) //nolint:errcheck // hash.Hash.Write for std hashes never errors _, _ = sum.Write([]byte{0}) //nolint:errcheck // hash.Hash.Write for std hashes never errors _, _ = sum.Write(hdr.Peek(name)) //nolint:errcheck // hash.Hash.Write for std hashes never errors _, _ = sum.Write([]byte{0}) //nolint:errcheck // hash.Hash.Write for std hashes never errors } var hashBytes [sha256.Size]byte sum.Sum(hashBytes[:0]) v := hexBufPool.Get() bufPtr, ok := v.(*[]byte) if !ok || bufPtr == nil { b := make([]byte, hexLen) bufPtr = &b } buf := *bufPtr // Defensive in case someone changed Pool.New or Put a different sized buffer. 
if cap(buf) < hexLen { buf = make([]byte, hexLen) } else { buf = buf[:hexLen] } *bufPtr = buf hex.Encode(buf, hashBytes[:]) result := "|vary|" + string(buf) hexBufPool.Put(bufPtr) return result } } func storeVaryManifest(ctx context.Context, manager *manager, manifestKey string, names []string, exp time.Duration) error { if len(names) == 0 { return nil } data := strings.Join(names, ",") return manager.setRaw(ctx, manifestKey, utils.UnsafeBytes(data), exp) } //nolint:gocritic // returning explicit values keeps the signature concise while avoiding unnecessary named results func loadVaryManifest(ctx context.Context, manager *manager, manifestKey string) ([]string, bool, error) { raw, err := manager.getRaw(ctx, manifestKey) if err != nil { if errors.Is(err, errCacheMiss) { return nil, false, nil } return nil, false, err } manifest := utils.UnsafeString(raw) names, hasStar := parseVary(manifest) if hasStar { return nil, false, nil } return names, len(names) > 0, nil } func allowsSharedCacheDirectives(cc responseCacheControl) bool { if cc.hasPrivate { return false } if cc.hasPublic || cc.sMaxAgeSet || cc.mustRevalidate || cc.proxyRevalidate { return true } // RFC 9111 §4.2.2 permits Expires as an absolute expiry for cacheable responses, but for // authenticated requests §3.6 requires an explicit shared-cache directive. Therefore, // an Expires header alone MUST NOT allow sharing when Authorization is present. 
return false } func allowsSharedCache(cc string) bool { return allowsSharedCacheDirectives(parseResponseCacheControl(utils.UnsafeBytes(cc))) } func makeHashAuthFunc(hexBufPool *sync.Pool) func([]byte) string { return func(authHeader []byte) string { sum := sha256.Sum256(authHeader) v := hexBufPool.Get() bufPtr, ok := v.(*[]byte) if !ok || bufPtr == nil { b := make([]byte, hexLen) bufPtr = &b } buf := *bufPtr if cap(buf) < hexLen { buf = make([]byte, hexLen) } else { buf = buf[:hexLen] } *bufPtr = buf hex.Encode(buf, sum[:]) result := string(buf) hexBufPool.Put(bufPtr) return result } } ================================================ FILE: middleware/cache/cache_test.go ================================================ // Special thanks to @codemicro for moving this to fiber core // Original middleware: github.com/codemicro/fiber-cache package cache import ( "bytes" "context" "errors" "fmt" "io" "math" "net/http" "net/http/httptest" "os" "strconv" "strings" "sync" "sync/atomic" "testing" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/storage/memory" "github.com/gofiber/fiber/v3/middleware/etag" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) type failingCacheStorage struct { data map[string][]byte errs map[string]error mu sync.RWMutex } type mutatingStorage struct { data map[string][]byte mutate func(key string, value []byte) []byte } func newFailingCacheStorage() *failingCacheStorage { return &failingCacheStorage{ data: make(map[string][]byte), errs: make(map[string]error), } } func newMutatingStorage(mutate func(key string, value []byte) []byte) *mutatingStorage { return &mutatingStorage{ data: make(map[string][]byte), mutate: mutate, } } func (s *mutatingStorage) GetWithContext(_ context.Context, key string) ([]byte, error) { return s.Get(key) } func (s *mutatingStorage) Get(key string) ([]byte, error) { if value, ok := s.data[key]; ok { return value, nil } return nil, nil } func (s 
*mutatingStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error { return s.Set(key, val, 0) } func (s *mutatingStorage) Set(key string, val []byte, _ time.Duration) error { if key == "" || len(val) == 0 { return nil } if s.mutate != nil { val = s.mutate(key, val) } s.data[key] = val return nil } func (s *mutatingStorage) DeleteWithContext(_ context.Context, key string) error { return s.Delete(key) } func (s *mutatingStorage) Delete(key string) error { delete(s.data, key) return nil } func (s *mutatingStorage) ResetWithContext(_ context.Context) error { return s.Reset() } func (s *mutatingStorage) Reset() error { s.data = make(map[string][]byte) return nil } func (s *mutatingStorage) Close() error { s.data = nil return nil } func (s *failingCacheStorage) GetWithContext(_ context.Context, key string) ([]byte, error) { s.mu.RLock() defer s.mu.RUnlock() if err, ok := s.errs["get|"+key]; ok && err != nil { return nil, err } if val, ok := s.data[key]; ok { return append([]byte(nil), val...), nil } return nil, nil } func (s *failingCacheStorage) Get(key string) ([]byte, error) { return s.GetWithContext(context.Background(), key) } func (s *failingCacheStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error { s.mu.Lock() defer s.mu.Unlock() if err, ok := s.errs["set|"+key]; ok && err != nil { return err } s.data[key] = append([]byte(nil), val...) 
return nil } func (s *failingCacheStorage) Set(key string, val []byte, exp time.Duration) error { return s.SetWithContext(context.Background(), key, val, exp) } func (s *failingCacheStorage) DeleteWithContext(_ context.Context, key string) error { s.mu.Lock() defer s.mu.Unlock() if err, ok := s.errs["del|"+key]; ok && err != nil { return err } delete(s.data, key) return nil } func (s *failingCacheStorage) Delete(key string) error { return s.DeleteWithContext(context.Background(), key) } func (s *failingCacheStorage) ResetWithContext(context.Context) error { s.mu.Lock() defer s.mu.Unlock() s.data = make(map[string][]byte) s.errs = make(map[string]error) return nil } func (s *failingCacheStorage) Reset() error { return s.ResetWithContext(context.Background()) } func (*failingCacheStorage) Close() error { return nil } type contextRecord struct { key string value string canceled bool } type contextRecorderStorage struct { *failingCacheStorage deletes []contextRecord gets []contextRecord sets []contextRecord } func newContextRecorderStorage() *contextRecorderStorage { return &contextRecorderStorage{failingCacheStorage: newFailingCacheStorage()} } func contextRecordFrom(ctx context.Context, key string) contextRecord { record := contextRecord{ key: key, canceled: errors.Is(ctx.Err(), context.Canceled), } if value, ok := ctx.Value(markerKey).(string); ok { record.value = value } return record } func (s *contextRecorderStorage) GetWithContext(ctx context.Context, key string) ([]byte, error) { s.gets = append(s.gets, contextRecordFrom(ctx, key)) return s.failingCacheStorage.GetWithContext(ctx, key) } func (s *contextRecorderStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error { s.sets = append(s.sets, contextRecordFrom(ctx, key)) return s.failingCacheStorage.SetWithContext(ctx, key, val, exp) } func (s *contextRecorderStorage) DeleteWithContext(ctx context.Context, key string) error { s.deletes = append(s.deletes, 
contextRecordFrom(ctx, key)) return s.failingCacheStorage.DeleteWithContext(ctx, key) } func (s *contextRecorderStorage) recordedGets() []contextRecord { out := make([]contextRecord, len(s.gets)) copy(out, s.gets) return out } func (s *contextRecorderStorage) recordedSets() []contextRecord { out := make([]contextRecord, len(s.sets)) copy(out, s.sets) return out } func (s *contextRecorderStorage) recordedDeletes() []contextRecord { out := make([]contextRecord, len(s.deletes)) copy(out, s.deletes) return out } func TestCacheStorageGetError(t *testing.T) { t.Parallel() storage := newFailingCacheStorage() storage.errs["get|/_GET"] = errors.New("boom") var captured error app := fiber.New(fiber.Config{ ErrorHandler: func(c fiber.Ctx, err error) error { captured = err return c.Status(fiber.StatusInternalServerError).SendString("storage failure") }, }) app.Use(New(Config{Storage: storage, Expiration: time.Second})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("ok") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) require.Error(t, captured) require.ErrorContains(t, captured, "cache: failed to get key") } func TestCacheStorageSetError(t *testing.T) { t.Parallel() storage := newFailingCacheStorage() storage.errs["set|/_GET_body"] = errors.New("boom") var captured error app := fiber.New(fiber.Config{ ErrorHandler: func(c fiber.Ctx, err error) error { captured = err return c.Status(fiber.StatusInternalServerError).SendString("storage failure") }, }) app.Use(New(Config{Storage: storage, Expiration: time.Second})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("ok") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) require.Error(t, captured) require.ErrorContains(t, captured, "cache: failed to store raw key") } 
func TestCacheStorageDeleteError(t *testing.T) { t.Parallel() storage := newFailingCacheStorage() storage.errs["del|/_GET"] = errors.New("boom") // Use an obviously expired timestamp without relying on time-based conversions expired := &item{exp: 1} raw, err := expired.MarshalMsg(nil) require.NoError(t, err) storage.data["/_GET"] = raw var captured error app := fiber.New(fiber.Config{ ErrorHandler: func(c fiber.Ctx, err error) error { captured = err return c.Status(fiber.StatusInternalServerError).SendString("storage failure") }, }) app.Use(New(Config{Storage: storage, Expiration: time.Second})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("ok") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) require.Error(t, captured) require.ErrorContains(t, captured, "cache: failed to delete expired key") } type contextKey string const markerKey contextKey = "marker" func contextWithMarker(label string) context.Context { return context.WithValue(context.Background(), markerKey, label) } func canceledContextWithMarker(label string) context.Context { ctx, cancel := context.WithCancel(contextWithMarker(label)) cancel() return ctx } func TestCacheEvictionPropagatesRequestContextToDelete(t *testing.T) { t.Parallel() storage := newContextRecorderStorage() app := fiber.New() app.Use(func(c fiber.Ctx) error { path := c.Path() if path == "/first" { c.SetContext(contextWithMarker("first")) } if path == "/second" { c.SetContext(canceledContextWithMarker("evict")) } return c.Next() }) app.Use(New(Config{Storage: storage, Expiration: time.Minute, MaxBytes: 5})) app.Get("/first", func(c fiber.Ctx) error { return c.SendString("aaa") }) app.Get("/second", func(c fiber.Ctx) error { return c.SendString("bbbb") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/first", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, 
resp.StatusCode) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/second", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) records := storage.recordedDeletes() require.Len(t, records, 2) var keys []string for _, rec := range records { keys = append(keys, rec.key) require.Equal(t, "evict", rec.value) require.True(t, rec.canceled) } require.ElementsMatch(t, []string{"/first_GET", "/first_GET_body"}, keys) } func TestCacheCleanupPropagatesRequestContextToDelete(t *testing.T) { t.Parallel() storage := newContextRecorderStorage() storage.errs["set|/_GET"] = errors.New("boom") var captured error app := fiber.New(fiber.Config{ ErrorHandler: func(c fiber.Ctx, err error) error { captured = err return c.Status(fiber.StatusInternalServerError).SendString("storage failure") }, }) app.Use(func(c fiber.Ctx) error { c.SetContext(canceledContextWithMarker("cleanup")) return c.Next() }) app.Use(New(Config{Storage: storage, Expiration: time.Minute})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("payload") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) require.Error(t, captured) require.ErrorContains(t, captured, "cache: failed to store key") records := storage.recordedDeletes() require.Len(t, records, 1) require.Equal(t, "/_GET_body", records[0].key) require.Equal(t, "cleanup", records[0].value) require.True(t, records[0].canceled) } func TestCacheStorageOperationsObserveRequestContext(t *testing.T) { t.Parallel() storage := newContextRecorderStorage() app := fiber.New() app.Use(func(c fiber.Ctx) error { ctxLabel := string(c.Request().Header.Peek("X-Context")) if ctxLabel == "" { return c.Next() } canceled := string(c.Request().Header.Peek("X-Cancel")) == "true" if canceled { c.SetContext(canceledContextWithMarker(ctxLabel)) } else { c.SetContext(contextWithMarker(ctxLabel)) } return c.Next() }) 
app.Use(New(Config{Storage: storage, Expiration: time.Minute})) app.Get("/cache", func(c fiber.Ctx) error { return c.SendString("payload") }) firstReq := httptest.NewRequest(fiber.MethodGet, "/cache", http.NoBody) firstReq.Header.Set("X-Context", "store") firstReq.Header.Set("X-Cancel", "true") resp, err := app.Test(firstReq) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) secondReq := httptest.NewRequest(fiber.MethodGet, "/cache", http.NoBody) secondReq.Header.Set("X-Context", "fetch") secondReq.Header.Set("X-Cancel", "true") resp, err = app.Test(secondReq) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) setRecords := storage.recordedSets() require.Len(t, setRecords, 2) for _, rec := range setRecords { require.Contains(t, []string{"/cache_GET", "/cache_GET_body"}, rec.key) require.Equal(t, "store", rec.value) require.True(t, rec.canceled) } getRecords := storage.recordedGets() require.NotEmpty(t, getRecords) var fetchEntry, fetchBody bool for _, rec := range getRecords { if rec.value != "fetch" { continue } if rec.key == "/cache_GET" { require.True(t, rec.canceled) fetchEntry = true } if rec.key == "/cache_GET_body" { require.True(t, rec.canceled) fetchBody = true } } require.True(t, fetchEntry, "expected cached entry retrieval to observe request context") require.True(t, fetchBody, "expected cached body retrieval to observe request context") } func Test_Cache_CacheControl(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 10 * time.Second})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, "public, max-age=10", resp.Header.Get(fiber.HeaderCacheControl)) } func Test_Cache_CacheControl_Disabled(t *testing.T) { t.Parallel() app := fiber.New() 
app.Use(New(Config{ Expiration: 10 * time.Second, DisableCacheControl: true, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Empty(t, resp.Header.Get(fiber.HeaderCacheControl)) } func Test_Cache_Expired(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 2 * time.Second})) count := 0 app.Get("/", func(c fiber.Ctx) error { count++ return c.SendString(strconv.Itoa(count)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) body, err := io.ReadAll(resp.Body) require.NoError(t, err) // Sleep until the cache is expired time.Sleep(3 * time.Second) respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) bodyCached, err := io.ReadAll(respCached.Body) require.NoError(t, err) if bytes.Equal(body, bodyCached) { t.Errorf("Cache should have expired: %s, %s", body, bodyCached) } // Next response should be also cached respCachedNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) bodyCachedNextRound, err := io.ReadAll(respCachedNextRound.Body) require.NoError(t, err) if !bytes.Equal(bodyCachedNextRound, bodyCached) { t.Errorf("Cache should not have expired: %s, %s", bodyCached, bodyCachedNextRound) } } func Test_Cache(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) count := 0 app.Get("/", func(c fiber.Ctx) error { count++ return c.SendString(strconv.Itoa(count)) }) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) resp, err := app.Test(req) require.NoError(t, err) cachedReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) cachedResp, err := app.Test(cachedReq) require.NoError(t, err) body, err := io.ReadAll(resp.Body) 
require.NoError(t, err) cachedBody, err := io.ReadAll(cachedResp.Body) require.NoError(t, err) require.Equal(t, cachedBody, body) } // go test -run Test_Cache_WithNoCacheRequestDirective func Test_Cache_WithNoCacheRequestDirective(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString(fiber.Query(c, "id", "1")) }) // Request id = 1 req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) resp, err := app.Test(req) require.NoError(t, err) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, cacheMiss, resp.Header.Get("X-Cache")) require.Equal(t, []byte("1"), body) // Response cached, entry id = 1 // Request id = 2 without Cache-Control: no-cache cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody) cachedResp, err := app.Test(cachedReq) require.NoError(t, err) cachedBody, err := io.ReadAll(cachedResp.Body) require.NoError(t, err) require.Equal(t, cacheHit, cachedResp.Header.Get("X-Cache")) require.Equal(t, []byte("1"), cachedBody) // Response not cached, returns cached response, entry id = 1 // Request id = 2 with Cache-Control: no-cache noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody) noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache) noCacheResp, err := app.Test(noCacheReq) require.NoError(t, err) noCacheBody, err := io.ReadAll(noCacheResp.Body) require.NoError(t, err) require.Equal(t, cacheMiss, noCacheResp.Header.Get("X-Cache")) require.Equal(t, []byte("2"), noCacheBody) // Response cached, returns updated response, entry = 2 /* Check Test_Cache_WithETagAndNoCacheRequestDirective */ // Request id = 2 with Cache-Control: no-cache again noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody) noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache) noCacheResp1, err := app.Test(noCacheReq1) require.NoError(t, err) noCacheBody1, err := io.ReadAll(noCacheResp1.Body) require.NoError(t, err) 
require.Equal(t, cacheMiss, noCacheResp1.Header.Get("X-Cache")) require.Equal(t, []byte("2"), noCacheBody1) // Response cached, returns updated response, entry = 2 // Request id = 3 with Cache-Control: NO-CACHE noCacheReqUpper := httptest.NewRequest(fiber.MethodGet, "/?id=3", http.NoBody) noCacheReqUpper.Header.Set(fiber.HeaderCacheControl, "NO-CACHE") noCacheRespUpper, err := app.Test(noCacheReqUpper) require.NoError(t, err) noCacheBodyUpper, err := io.ReadAll(noCacheRespUpper.Body) require.NoError(t, err) require.Equal(t, cacheMiss, noCacheRespUpper.Header.Get("X-Cache")) require.Equal(t, []byte("3"), noCacheBodyUpper) // Response cached, returns updated response, entry = 3 // Request id = 4 with Cache-Control: my-no-cache invalidReq := httptest.NewRequest(fiber.MethodGet, "/?id=4", http.NoBody) invalidReq.Header.Set(fiber.HeaderCacheControl, "my-no-cache") invalidResp, err := app.Test(invalidReq) require.NoError(t, err) invalidBody, err := io.ReadAll(invalidResp.Body) require.NoError(t, err) require.Equal(t, cacheHit, invalidResp.Header.Get("X-Cache")) require.Equal(t, []byte("3"), invalidBody) // Response served from cache, existing entry = 3 // Request id = 4 again without Cache-Control: no-cache cachedInvalidReq := httptest.NewRequest(fiber.MethodGet, "/?id=4", http.NoBody) cachedInvalidResp, err := app.Test(cachedInvalidReq) require.NoError(t, err) cachedInvalidBody, err := io.ReadAll(cachedInvalidResp.Body) require.NoError(t, err) require.Equal(t, cacheHit, cachedInvalidResp.Header.Get("X-Cache")) require.Equal(t, []byte("3"), cachedInvalidBody) // Response cached, returns cached response, entry id = 3 // Request id = 1 without Cache-Control: no-cache cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) cachedResp1, err := app.Test(cachedReq1) require.NoError(t, err) cachedBody1, err := io.ReadAll(cachedResp1.Body) require.NoError(t, err) require.Equal(t, cacheHit, cachedResp1.Header.Get("X-Cache")) require.Equal(t, []byte("3"), cachedBody1) 
	// Response not cached, returns cached response, entry id = 3
}

// Test_Cache_WithETagAndNoCacheRequestDirective exercises the cache middleware
// stacked behind the etag middleware, combining If-None-Match with
// Cache-Control: no-cache (in lower and upper case). The handler echoes the "id"
// query value (default "1"). The /?id=2 request below is served from the entry
// stored for "/", so the query string is evidently not part of the default cache
// key: every request in this test reads or refreshes the same entry.
//
// go test -run Test_Cache_WithETagAndNoCacheRequestDirective
func Test_Cache_WithETagAndNoCacheRequestDirective(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(
		etag.New(),
		New(),
	)

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString(fiber.Query(c, "id", "1"))
	})

	// Request id = 1
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	// Response cached, entry id = 1

	// Capture the ETag of the 200 response so later requests can send it back
	// as If-None-Match.
	etagToken := resp.Header.Get("Etag")

	// Request id = 2 with ETag but without Cache-Control: no-cache.
	// The cached entry (body "1") is served, and its tag matches, so 304.
	cachedReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody)
	cachedReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
	cachedResp, err := app.Test(cachedReq)
	require.NoError(t, err)
	require.Equal(t, cacheHit, cachedResp.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusNotModified, cachedResp.StatusCode)
	// Response not cached, returns cached response, entry id = 1, status not modified

	// Request id = 2 with ETag and Cache-Control: no-cache.
	// no-cache bypasses the entry; the handler now returns "2", whose tag differs
	// from the stored one, so a full 200 is sent and the entry is refreshed.
	noCacheReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody)
	noCacheReq.Header.Set(fiber.HeaderCacheControl, noCache)
	noCacheReq.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
	noCacheResp, err := app.Test(noCacheReq)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, noCacheResp.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusOK, noCacheResp.StatusCode)
	// Response cached, returns updated response, entry id = 2

	// Remember the tag of body "2"; it is reused by the two requests below.
	etagToken = noCacheResp.Header.Get("Etag")

	// Request id = 3 with ETag and Cache-Control: NO-CACHE.
	// The directive is matched case-insensitively. Body "3" does not match the
	// tag of body "2", so the response is a full 200.
	noCacheReqUpper := httptest.NewRequest(fiber.MethodGet, "/?id=3", http.NoBody)
	noCacheReqUpper.Header.Set(fiber.HeaderCacheControl, "NO-CACHE")
	noCacheReqUpper.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
	noCacheRespUpper, err := app.Test(noCacheReqUpper)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, noCacheRespUpper.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusOK, noCacheRespUpper.StatusCode)
	// Response cached, returns updated response, entry id = 3

	// Request id = 2 with ETag and Cache-Control: no-cache again.
	// The handler regenerates body "2", which matches etagToken, so the result is
	// 304 even though the cache lookup was bypassed (miss).
	noCacheReq1 := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody)
	noCacheReq1.Header.Set(fiber.HeaderCacheControl, noCache)
	noCacheReq1.Header.Set(fiber.HeaderIfNoneMatch, etagToken)
	noCacheResp1, err := app.Test(noCacheReq1)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, noCacheResp1.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusNotModified, noCacheResp1.StatusCode)
	// Response cached, returns updated response, entry id = 2, status not modified

	// Request id = 1 with neither an ETag nor Cache-Control: no-cache.
	// A plain request is served from the (refreshed) entry.
	cachedReq1 := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	cachedResp1, err := app.Test(cachedReq1)
	require.NoError(t, err)
	require.Equal(t, cacheHit, cachedResp1.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusOK, cachedResp1.StatusCode)
	// Response not cached, returns cached response, entry id = 2
}

// Test_Cache_WithNoStoreRequestDirective verifies the handling of
// Cache-Control: no-store on requests. The directive is matched
// case-insensitively ("no-store" and "NO-STORE"), and such responses are not
// stored: the later /?id=4 request is still a miss. A look-alike token
// ("my-no-store") is not treated as no-store, and that response is cached.
//
// go test -run Test_Cache_WithNoStoreRequestDirective
func Test_Cache_WithNoStoreRequestDirective(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString(fiber.Query(c, "id", "1"))
	})

	// Request id = 2 with Cache-Control: no-store
	noStoreReq := httptest.NewRequest(fiber.MethodGet, "/?id=2", http.NoBody)
	noStoreReq.Header.Set(fiber.HeaderCacheControl, noStore)
	noStoreResp, err := app.Test(noStoreReq)
	require.NoError(t, err)
	noStoreBody, err := io.ReadAll(noStoreResp.Body)
	require.NoError(t, err)
	require.Equal(t, []byte("2"), noStoreBody)
	// Response not cached, returns updated response

	// Request id = 3 with Cache-Control: NO-STORE
	noStoreReqUpper := httptest.NewRequest(fiber.MethodGet, "/?id=3", http.NoBody)
	noStoreReqUpper.Header.Set(fiber.HeaderCacheControl, "NO-STORE")
	noStoreRespUpper, err := app.Test(noStoreReqUpper)
	require.NoError(t, err)
	noStoreBodyUpper, err := io.ReadAll(noStoreRespUpper.Body)
	require.NoError(t, err)
	require.Equal(t, []byte("3"), noStoreBodyUpper)
	// Response not cached, returns updated response

	// Request id = 4 with Cache-Control: my-no-store (not a real directive).
	// A miss here also proves that neither no-store response above was stored.
	invalidReq := httptest.NewRequest(fiber.MethodGet, "/?id=4", http.NoBody)
	invalidReq.Header.Set(fiber.HeaderCacheControl, "my-no-store")
	invalidResp, err := app.Test(invalidReq)
	require.NoError(t, err)
	invalidBody, err := io.ReadAll(invalidResp.Body)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, invalidResp.Header.Get("X-Cache"))
	require.Equal(t, []byte("4"), invalidBody)
	// Response cached, returns updated response, entry = 4

	// Request id = 4 again without Cache-Control
	cachedInvalidReq := httptest.NewRequest(fiber.MethodGet, "/?id=4", http.NoBody)
	cachedInvalidResp, err := app.Test(cachedInvalidReq)
	require.NoError(t, err)
	cachedInvalidBody, err := io.ReadAll(cachedInvalidResp.Body)
	require.NoError(t, err)
	require.Equal(t, cacheHit, cachedInvalidResp.Header.Get("X-Cache"))
	require.Equal(t, []byte("4"), cachedInvalidBody)
	// Response cached previously, served from cache
}

// Test_Cache_WithSeveralRequests sends ten rounds of GET /0 ... /9. It asserts
// that every response body equals the requested id, i.e. entries cached for
// different paths are never mixed up.
func Test_Cache_WithSeveralRequests(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		Expiration: 10 * time.Second,
	}))

	app.Get("/:id", func(c fiber.Ctx) error {
		return c.SendString(c.Params("id"))
	})

	for range 10 {
		for i := range 10 {
			// The per-request closure makes the deferred Body.Close run at the
			// end of each request rather than at the end of the whole test.
			func(id int) {
				rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, fmt.Sprintf("/%d", id), http.NoBody))
				require.NoError(t, err)

				defer func(body io.ReadCloser) {
					closeErr := body.Close()
					require.NoError(t, closeErr)
				}(rsp.Body)

				idFromServ, err := io.ReadAll(rsp.Body)
				require.NoError(t, err)

				a, err := strconv.Atoi(string(idFromServ))
				require.NoError(t, err)

				// Sometimes, the id is not equal to a: a mismatch means a body
				// cached for one path was served for another.
				require.Equal(t, id, a)
			}(i)
		}
	}
}

// Test_Cache_Invalid_Expiration configures a zero Expiration. It checks that
// the second request still returns the first request's body: the handler's
// counter is not advanced, so a zero Expiration does not disable caching.
func Test_Cache_Invalid_Expiration(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	cache := New(Config{Expiration: 0 * time.Second})
	app.Use(cache)

	// count advances on every handler call, so equal bodies below mean the
	// second response came from the cache.
	count := 0
	app.Get("/", func(c fiber.Ctx) error {
		count++
		return c.SendString(strconv.Itoa(count))
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)

	cachedReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	cachedResp, err := app.Test(cachedReq)
	require.NoError(t, err)

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	cachedBody, err := io.ReadAll(cachedResp.Body)
	require.NoError(t, err)

	require.Equal(t, cachedBody, body)
}

// Test_Cache_Get checks that, with the default configuration, POST responses
// are not cached: each POST to "/" echoes its own query value. A repeated GET to
// /get returns the body cached from the first GET, even though the query
// differs.
func Test_Cache_Get(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	app.Use(New())

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendString(fiber.Query[string](c, "cache"))
	})

	app.Get("/get", func(c fiber.Ctx) error {
		return c.SendString(fiber.Query[string](c, "cache"))
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", http.NoBody))
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "123", string(body))

	resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", http.NoBody))
	require.NoError(t, err)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "12345", string(body))

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", http.NoBody))
	require.NoError(t, err)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "123", string(body))

	// Same path, different query: served from the entry cached above.
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", http.NoBody))
	require.NoError(t, err)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "123", string(body))
}

// Test_Cache_Post is the mirror of Test_Cache_Get. With Methods restricted to
// POST, the second POST returns the cached "123", while each GET to /get
// reaches the handler and echoes its own query value.
func Test_Cache_Post(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	app.Use(New(Config{
		Methods: []string{fiber.MethodPost},
	}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendString(fiber.Query[string](c, "cache"))
	})

	app.Get("/get", func(c fiber.Ctx) error {
		return c.SendString(fiber.Query[string](c, "cache"))
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=123", http.NoBody))
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "123", string(body))

	resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", http.NoBody))
	require.NoError(t, err)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "123", string(body))

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=123", http.NoBody))
	require.NoError(t, err)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "123", string(body))

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/get?cache=12345", http.NoBody))
	require.NoError(t, err)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "12345", string(body))
}

// Test_Cache_NothingToCache uses a negative Expiration. It expects the request
// made after a short pause to get a fresh body (the counter advanced), i.e.
// nothing usable was cached.
func Test_Cache_NothingToCache(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	app.Use(New(Config{Expiration: -(time.Second * 1)}))

	count := 0
	app.Get("/", func(c fiber.Ctx) error {
		count++
		return c.SendString(strconv.Itoa(count))
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)

	time.Sleep(500 * time.Millisecond)

	respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	bodyCached, err := io.ReadAll(respCached.Body)
	require.NoError(t, err)

	if bytes.Equal(body, bodyCached) {
		t.Errorf("Cache should have expired: %s, %s", body, bodyCached)
	}
}

// Test_Cache_CustomNext installs a Next func that returns true when the
// response status is not 200. The 200 response from "/" is cached: the second
// body is equal and a Cache-Control header is present. The 500 response from
// /error carries no Cache-Control header.
func Test_Cache_CustomNext(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.Response().StatusCode() != fiber.StatusOK
		},
	}))

	count := 0
	app.Get("/", func(c fiber.Ctx) error {
		count++
		return c.SendString(strconv.Itoa(count))
	})

	errorCount := 0
	app.Get("/error", func(c fiber.Ctx) error {
		errorCount++
		return c.Status(fiber.StatusInternalServerError).SendString(strconv.Itoa(errorCount))
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)

	respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	bodyCached, err := io.ReadAll(respCached.Body)
	require.NoError(t, err)
	require.True(t, bytes.Equal(body, bodyCached))
	require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl))

	_, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/error", http.NoBody))
	require.NoError(t, err)

	errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", http.NoBody))
	require.NoError(t, err)
	require.Empty(t, errRespCached.Header.Get(fiber.HeaderCacheControl))
}

// Test_CustomKey checks that a user-supplied KeyGenerator is invoked.
func Test_CustomKey(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	var called bool
	app.Use(New(Config{KeyGenerator: func(c fiber.Ctx) string {
		called = true
		// Copy the path so the key does not alias a buffer owned by the request.
		return utils.CopyString(c.Path())
	}}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("hi")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	_, err := app.Test(req)
	require.NoError(t, err)
	require.True(t, called)
}

// Test_CustomExpiration uses an ExpirationGenerator that derives the TTL, in
// seconds, from the response's "Cache-Time" header (default "600"). The handler
// sets it to 1, so the entry should stop being served shortly after one second.
// Once it has expired, the next response is cached again.
func Test_CustomExpiration(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	var called bool
	var newCacheTime int
	app.Use(New(Config{ExpirationGenerator: func(c fiber.Ctx, _ *Config) time.Duration {
		called = true
		var err error
		newCacheTime, err = strconv.Atoi(c.GetRespHeader("Cache-Time", "600"))
		// NOTE(review): this runs inside the middleware; under app.Test it may be
		// on a non-test goroutine, where require's t.FailNow is unreliable — confirm.
		require.NoError(t, err)
		return time.Second * time.Duration(newCacheTime)
	}}))

	count := 0
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Response().Header.Add("Cache-Time", "1")
		return c.SendString(strconv.Itoa(count))
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.True(t, called)
	require.Equal(t, 1, newCacheTime)

	// Wait until the cache expires (timestamp tick can delay expiry detection slightly).
	// Poll every 50ms until the response is no longer a cache hit; fail if it
	// is still being served from cache after 3 seconds.
	expireDeadline := time.Now().Add(3 * time.Second)
	var cachedResp *http.Response
	for {
		cachedResp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		if cachedResp.Header.Get("X-Cache") != cacheHit {
			break
		}
		require.True(t, time.Now().Before(expireDeadline), "response remained cached beyond expected expiration")
		time.Sleep(50 * time.Millisecond)
	}

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	cachedBody, err := io.ReadAll(cachedResp.Body)
	require.NoError(t, err)

	if bytes.Equal(body, cachedBody) {
		t.Errorf("Cache should have expired: %s, %s", body, cachedBody)
	}

	// Next response should be cached
	cachedRespNextRound, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	cachedBodyNextRound, err := io.ReadAll(cachedRespNextRound.Body)
	require.NoError(t, err)

	if !bytes.Equal(cachedBodyNextRound, cachedBody) {
		t.Errorf("Cache should not have expired: %s, %s", cachedBodyNextRound, cachedBody)
	}
}

// Test_AdditionalE2EResponseHeaders enables StoreResponseHeaders. It checks
// that a custom header set by the handler ("X-Foobar") is present on the first
// response and on the second, presumably cached, response.
func Test_AdditionalE2EResponseHeaders(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		StoreResponseHeaders: true,
	}))

	app.Get("/", func(c fiber.Ctx) error {
		c.Response().Header.Add("X-Foobar", "foobar")
		return c.SendString("hi")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, "foobar", resp.Header.Get("X-Foobar"))

	req = httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, "foobar", resp.Header.Get("X-Foobar"))
}

// Test_CacheHeader checks the X-Cache status values:
//   - GET "/" is a miss, then a hit.
//   - POST, which is not cached by default, is unreachable.
//   - The 500 response from /error, which Next excludes, is unreachable.
func Test_CacheHeader(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	app.Use(New(Config{
		Expiration: 10 * time.Second,
		Next: func(c fiber.Ctx) bool {
			return c.Response().StatusCode() != fiber.StatusOK
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendString(fiber.Query[string](c, "cache"))
	})

	count := 0
	app.Get("/error", func(c fiber.Ctx) error {
		count++
		c.Response().Header.Add("Cache-Time", "1")
		return c.Status(fiber.StatusInternalServerError).SendString(strconv.Itoa(count))
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))

	resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/?cache=12345", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))

	errRespCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/error", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, errRespCached.Header.Get("X-Cache"))
}

// Test_Cache_WithHead checks that HEAD responses are cached: the second HEAD
// is a hit, and both (empty) bodies are equal.
func Test_Cache_WithHead(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())

	count := 0
	handler := func(c fiber.Ctx) error {
		count++
		c.Response().Header.Add("Cache-Time", "1")
		return c.SendString(strconv.Itoa(count))
	}
	app.RouteChain("/").Get(handler).Head(handler)

	req := httptest.NewRequest(fiber.MethodHead, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))

	cachedReq := httptest.NewRequest(fiber.MethodHead, "/", http.NoBody)
	cachedResp, err := app.Test(cachedReq)
	require.NoError(t, err)
	require.Equal(t, cacheHit, cachedResp.Header.Get("X-Cache"))

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	cachedBody, err := io.ReadAll(cachedResp.Body)
	require.NoError(t, err)

	require.Equal(t, cachedBody, body)
}

// Test_Cache_WithHeadThenGet checks that HEAD and GET for the same URL are
// cached separately. After the HEAD entry exists (miss, then hit, always an
// empty body), the first GET is still a miss and returns the full body; the
// second GET is a hit.
func Test_Cache_WithHeadThenGet(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())

	handler := func(c fiber.Ctx) error {
		return c.SendString(fiber.Query[string](c, "cache"))
	}
	app.RouteChain("/").Get(handler).Head(handler)

	headResp, err := app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", http.NoBody))
	require.NoError(t, err)
	headBody, err := io.ReadAll(headResp.Body)
	require.NoError(t, err)
	require.Empty(t, string(headBody))
	require.Equal(t, cacheMiss, headResp.Header.Get("X-Cache"))

	headResp, err = app.Test(httptest.NewRequest(fiber.MethodHead, "/?cache=123", http.NoBody))
	require.NoError(t, err)
	headBody, err = io.ReadAll(headResp.Body)
	require.NoError(t, err)
	require.Empty(t, string(headBody))
	require.Equal(t, cacheHit, headResp.Header.Get("X-Cache"))

	// A cached HEAD (bodiless) entry must not be replayed for GET.
	getResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", http.NoBody))
	require.NoError(t, err)
	getBody, err := io.ReadAll(getResp.Body)
	require.NoError(t, err)
	require.Equal(t, "123", string(getBody))
	require.Equal(t, cacheMiss, getResp.Header.Get("X-Cache"))

	getResp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/?cache=123", http.NoBody))
	require.NoError(t, err)
	getBody, err = io.ReadAll(getResp.Body)
	require.NoError(t, err)
	require.Equal(t, "123", string(getBody))
	require.Equal(t, cacheHit, getResp.Header.Get("X-Cache"))
}

// Test_CustomCacheHeader checks that CacheHeader renames the header that
// carries the cache status (here "Cache-Status" instead of X-Cache).
func Test_CustomCacheHeader(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	app.Use(New(Config{
		CacheHeader: "Cache-Status",
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("Cache-Status"))
}

// Test_CacheInvalidation checks that when CacheInvalidator returns true
// (?invalidate=true), the existing entry is dropped and a fresh body is
// produced.
func Test_CacheInvalidation(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		CacheInvalidator: func(c fiber.Ctx) bool {
			return fiber.Query[bool](c, "invalidate")
		},
	}))

	count := 0
	app.Get("/", func(c fiber.Ctx) error {
		count++
		return c.SendString(strconv.Itoa(count))
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)

	respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	bodyCached, err := io.ReadAll(respCached.Body)
	require.NoError(t, err)
	require.True(t, bytes.Equal(body, bodyCached))
	require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl))

	// The counter has not advanced yet, so a different body here proves the
	// entry was invalidated and the handler ran again.
	respInvalidate, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", http.NoBody))
	require.NoError(t, err)
	bodyInvalidate, err := io.ReadAll(respInvalidate.Body)
	require.NoError(t, err)
	require.NotEqual(t, body, bodyInvalidate)
}

// Test_CacheInvalidation_noCacheEntry checks that CacheInvalidator is not
// consulted when there is no entry to invalidate. No route is registered, so
// the request has nothing cached under its key.
func Test_CacheInvalidation_noCacheEntry(t *testing.T) {
	t.Parallel()
	t.Run("Cache Invalidator should not be called if no cache entry exist ", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		cacheInvalidatorExecuted := false
		app.Use(New(Config{
			CacheInvalidator: func(c fiber.Ctx) bool {
				cacheInvalidatorExecuted = true
				return fiber.Query[bool](c, "invalidate")
			},
			MaxBytes: 10 * 1024 * 1024,
		}))
		_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", http.NoBody))
		require.NoError(t, err)
		require.False(t, cacheInvalidatorExecuted)
	})
}

// Test_CacheInvalidation_removeFromHeap repeats the Test_CacheInvalidation flow
// with MaxBytes set. Per the subtest name, this is meant to exercise removal of
// the invalidated entry from the size-tracking heap; the test itself only
// asserts the observable body change.
func Test_CacheInvalidation_removeFromHeap(t *testing.T) {
	t.Parallel()
	t.Run("Invalidate and remove from the heap", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			CacheInvalidator: func(c fiber.Ctx) bool {
				return fiber.Query[bool](c, "invalidate")
			},
			MaxBytes: 10 * 1024 * 1024,
		}))

		count := 0
		app.Get("/", func(c fiber.Ctx) error {
			count++
			return c.SendString(strconv.Itoa(count))
		})

		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)

		respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		bodyCached, err := io.ReadAll(respCached.Body)
		require.NoError(t, err)
		require.True(t, bytes.Equal(body, bodyCached))
		require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl))

		respInvalidate, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", http.NoBody))
		require.NoError(t, err)
		bodyInvalidate, err := io.ReadAll(respInvalidate.Body)
		require.NoError(t, err)
		require.NotEqual(t, body, bodyInvalidate)
	})
}

// Test_CacheStorage_CustomHeaders uses an explicit in-memory Storage. A
// response carrying custom Content-Type and Content-Encoding headers is cached
// and replayed with the same body and a Cache-Control header.
func Test_CacheStorage_CustomHeaders(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Storage:  memory.New(),
		MaxBytes: 10 * 1024 * 1024,
	}))

	app.Get("/", func(c fiber.Ctx) error {
		c.Response().Header.Set("Content-Type", "text/xml")
		c.Response().Header.Set("Content-Encoding", "utf8")
		return c.Send([]byte("Test"))
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)

	respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	bodyCached, err := io.ReadAll(respCached.Body)
	require.NoError(t, err)
	require.True(t, bytes.Equal(body, bodyCached))
	require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl))
}

// stableAscendingExpiration returns an ExpirationGenerator that hands out
// strictly increasing TTLs: 1h, 2h, 3h, and so on, one step per call.
//
// Because time points are updated only once every X milliseconds, entries
// created in quick succession in tests can otherwise get equal expiration times
// and be ordered arbitrarily. Increasing TTLs keep the expiration order (and
// hence the eviction order checked by the MaxBytes tests) deterministic.
//
// The counter is not synchronized; it relies on requests being issued one at a
// time.
func stableAscendingExpiration() func(c1 fiber.Ctx, c2 *Config) time.Duration {
	i := 0
	return func(_ fiber.Ctx, _ *Config) time.Duration {
		i++
		return time.Hour * time.Duration(i)
	}
}

// Test_Cache_MaxBytesOrder uses MaxBytes=2 with 1-byte bodies, so at most two
// entries fit. When a third is added, the entry with the earliest expiration
// (see stableAscendingExpiration) is evicted.
func Test_Cache_MaxBytesOrder(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New(Config{
		MaxBytes:            2,
		ExpirationGenerator: stableAscendingExpiration(),
	}))

	app.Get("/*", func(c fiber.Ctx) error {
		return c.SendString("1")
	})

	cases := [][]string{
		// Insert a, b into cache of size 2 bytes (responses are 1 byte)
		{"/a", cacheMiss},
		{"/b", cacheMiss},
		{"/a", cacheHit},
		{"/b", cacheHit},
		// Add c -> a evicted
		{"/c", cacheMiss},
		{"/b", cacheHit},
		// Add a again -> b evicted
		{"/a", cacheMiss},
		{"/c", cacheHit},
		// Add b -> c evicted
		{"/b", cacheMiss},
		{"/c", cacheMiss},
	}

	for idx, tcase := range cases {
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], http.NoBody))
		require.NoError(t, err)
		require.Equal(t, tcase[1], rsp.Header.Get("X-Cache"), "Case %v", idx)
	}
}

// Test_Cache_MaxBytesSizes serves a body of N zero bytes for path "/N" under
// MaxBytes=7. Earlier (shorter-TTL) entries are evicted to make room, and a
// body larger than MaxBytes is never cached.
func Test_Cache_MaxBytesSizes(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	app.Use(New(Config{
		MaxBytes:            7,
		ExpirationGenerator: stableAscendingExpiration(),
	}))

	app.Get("/*", func(c fiber.Ctx) error {
		path := c.RequestCtx().URI().LastPathSegment()
		size, err := strconv.Atoi(string(path))
		require.NoError(t, err)
		return c.Send(make([]byte, size))
	})

	cases := [][]string{
		{"/1", cacheMiss},
		{"/2", cacheMiss},
		{"/3", cacheMiss},
		{"/4", cacheMiss}, // 1+2+3+4 > 7 => 1,2 are evicted now
		{"/3", cacheHit},
		{"/1", cacheMiss},
		{"/2", cacheMiss},
		{"/8", cacheUnreachable}, // too big to cache -> unreachable
	}

	for idx, tcase := range cases {
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tcase[0], http.NoBody))
		require.NoError(t, err)
		require.Equal(t, tcase[1], rsp.Header.Get("X-Cache"), "Case %v", idx)
	}
}

// Test_Cache_UncacheableStatusCodes requests each listed status code through a
// handler that replies with it. It asserts the response is passed through with
// X-Cache "unreachable".
func Test_Cache_UncacheableStatusCodes(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
app.Get("/:statusCode", func(c fiber.Ctx) error { statusCode, err := strconv.Atoi(c.Params("statusCode")) require.NoError(t, err) return c.Status(statusCode).SendString("foo") }) uncacheableStatusCodes := []int{ // Informational responses fiber.StatusContinue, fiber.StatusSwitchingProtocols, fiber.StatusProcessing, fiber.StatusEarlyHints, // Successful responses fiber.StatusCreated, fiber.StatusAccepted, fiber.StatusResetContent, fiber.StatusMultiStatus, fiber.StatusAlreadyReported, fiber.StatusIMUsed, // Redirection responses fiber.StatusFound, fiber.StatusSeeOther, fiber.StatusNotModified, fiber.StatusUseProxy, fiber.StatusSwitchProxy, fiber.StatusTemporaryRedirect, // Client error responses fiber.StatusBadRequest, fiber.StatusUnauthorized, fiber.StatusPaymentRequired, fiber.StatusForbidden, fiber.StatusNotAcceptable, fiber.StatusProxyAuthRequired, fiber.StatusRequestTimeout, fiber.StatusConflict, fiber.StatusLengthRequired, fiber.StatusPreconditionFailed, fiber.StatusRequestEntityTooLarge, fiber.StatusUnsupportedMediaType, fiber.StatusRequestedRangeNotSatisfiable, fiber.StatusExpectationFailed, fiber.StatusMisdirectedRequest, fiber.StatusUnprocessableEntity, fiber.StatusLocked, fiber.StatusFailedDependency, fiber.StatusTooEarly, fiber.StatusUpgradeRequired, fiber.StatusPreconditionRequired, fiber.StatusTooManyRequests, fiber.StatusRequestHeaderFieldsTooLarge, fiber.StatusTeapot, fiber.StatusUnavailableForLegalReasons, // Server error responses fiber.StatusInternalServerError, fiber.StatusBadGateway, fiber.StatusServiceUnavailable, fiber.StatusGatewayTimeout, fiber.StatusHTTPVersionNotSupported, fiber.StatusVariantAlsoNegotiates, fiber.StatusInsufficientStorage, fiber.StatusLoopDetected, fiber.StatusNotExtended, fiber.StatusNetworkAuthenticationRequired, } for _, v := range uncacheableStatusCodes { resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, fmt.Sprintf("/%d", v), http.NoBody)) require.NoError(t, err) require.Equal(t, cacheUnreachable, 
resp.Header.Get("X-Cache")) require.Equal(t, v, resp.StatusCode) } } func TestCacheAgeHeader(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 10 * time.Second})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("ok") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, "0", resp.Header.Get(fiber.HeaderAge)) time.Sleep(4 * time.Second) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, resp.Header.Get("X-Cache")) age, err := strconv.Atoi(resp.Header.Get(fiber.HeaderAge)) require.NoError(t, err) require.Positive(t, age) } func TestCacheUpstreamAge(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 3 * time.Second})) app.Get("/", func(c fiber.Ctx) error { c.Set(fiber.HeaderAge, "5") return c.SendString("hi") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, "5", resp.Header.Get(fiber.HeaderAge)) time.Sleep(1500 * time.Millisecond) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache")) require.Equal(t, "5", resp.Header.Get(fiber.HeaderAge)) } func Test_CacheRequestMaxAgeRevalidates(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Expiration: 30 * time.Second, KeyGenerator: func(c fiber.Ctx) string { return c.Path() + "|req-max-age-zero" }, })) var count int app.Get("/", func(c fiber.Ctx) error { count++ c.Set(fiber.HeaderCacheControl, "public, max-age=30") return c.SendString(strconv.Itoa(count)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, resp.Header.Get("X-Cache")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "1", string(body)) req := 
httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderCacheControl, "max-age=0") resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, cacheMiss, resp.Header.Get("X-Cache")) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "2", string(body)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, resp.Header.Get("X-Cache")) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "2", string(body)) } func Test_CacheExpiresFutureAllowsCaching(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ StoreResponseHeaders: true, })) var count int app.Get("/", func(c fiber.Ctx) error { count++ c.Set(fiber.HeaderExpires, time.Now().Add(30*time.Second).UTC().Format(time.RFC1123)) return c.SendString("expires" + strconv.Itoa(count)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, resp.Header.Get("X-Cache")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "expires1", string(body)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, resp.Header.Get("X-Cache")) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "expires1", string(body)) } func Test_CacheExpiresPastPreventsCaching(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) var count int app.Get("/", func(c fiber.Ctx) error { count++ c.Set(fiber.HeaderExpires, time.Now().Add(-1*time.Minute).UTC().Format(time.RFC1123)) return c.SendString("expires" + strconv.Itoa(count)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "expires1", string(body)) 
resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache")) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "expires2", string(body)) } func Test_CacheAllowsSharedCacheMustRevalidateWithAuthorization(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Expiration: 30 * time.Second, KeyGenerator: func(c fiber.Ctx) string { return c.Path() + "|must-revalidate-auth" }, })) var count int app.Get("/", func(c fiber.Ctx) error { count++ c.Set(fiber.HeaderCacheControl, "must-revalidate, max-age=60") return c.SendString("auth" + strconv.Itoa(count)) }) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Bearer token") resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, cacheMiss, resp.Header.Get("X-Cache")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "auth1", string(body)) req = httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Bearer token") resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, cacheHit, resp.Header.Get("X-Cache")) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "auth1", string(body)) } func Test_CacheAllowsSharedCacheProxyRevalidateWithAuthorization(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Expiration: 30 * time.Second, KeyGenerator: func(c fiber.Ctx) string { return c.Path() + "|proxy-revalidate-auth" }, })) var count int app.Get("/", func(c fiber.Ctx) error { count++ c.Set(fiber.HeaderCacheControl, "proxy-revalidate, max-age=60") return c.SendString("proxy" + strconv.Itoa(count)) }) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Bearer token") resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, cacheMiss, 
resp.Header.Get("X-Cache")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "proxy1", string(body)) req = httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set(fiber.HeaderAuthorization, "Bearer token") resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, cacheHit, resp.Header.Get("X-Cache")) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "proxy1", string(body)) } func Test_CacheInvalidExpiresStoredAsStale(t *testing.T) { t.Parallel() storage := newFailingCacheStorage() app := fiber.New() app.Use(New(Config{ Expiration: 30 * time.Second, KeyGenerator: func(c fiber.Ctx) string { return c.Path() + "|invalid-expires" }, Storage: storage, })) var count int app.Get("/", func(c fiber.Ctx) error { count++ c.Set(fiber.HeaderCacheControl, "public") c.Set(fiber.HeaderExpires, "invalid-date") return c.SendString("body" + strconv.Itoa(count)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, resp.Header.Get("X-Cache")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "body1", string(body)) expectedKey := "/|invalid-expires_GET" require.Contains(t, storage.data, expectedKey) require.Contains(t, storage.data, expectedKey+"_body") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, resp.Header.Get("X-Cache")) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "body2", string(body)) require.Contains(t, storage.data, expectedKey) require.Contains(t, storage.data, expectedKey+"_body") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, resp.Header.Get("X-Cache")) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "body3", string(body)) require.Contains(t, storage.data, expectedKey) 
require.Contains(t, storage.data, expectedKey+"_body") } func Test_CacheSMaxAgeOverridesMaxAgeWhenShorter(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) var count int app.Get("/", func(c fiber.Ctx) error { count++ c.Set(fiber.HeaderCacheControl, "public, max-age=10, s-maxage=1") return c.SendString(strconv.Itoa(count)) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, resp.Header.Get("X-Cache")) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, resp.Header.Get("X-Cache")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "1", string(body)) time.Sleep(1700 * time.Millisecond) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, resp.Header.Get("X-Cache")) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "2", string(body)) } func Test_CacheSMaxAgeOverridesMaxAgeWhenLonger(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) var count int app.Get("/", func(c fiber.Ctx) error { count++ c.Set(fiber.HeaderCacheControl, "public, max-age=1, s-maxage=2") return c.SendString(strconv.Itoa(count)) }) for time.Now().Nanosecond() >= int(100*time.Millisecond) { time.Sleep(10 * time.Millisecond) } resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, resp.Header.Get("X-Cache")) time.Sleep(1200 * time.Millisecond) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, resp.Header.Get("X-Cache")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "1", string(body)) time.Sleep(1700 * time.Millisecond) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) 
require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))
}

// Test_CacheOnlyIfCachedMiss checks that a request carrying "only-if-cached"
// with nothing in the cache gets 504 Gateway Timeout and never reaches the
// handler.
func Test_CacheOnlyIfCachedMiss(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		return c.SendString("ok")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderCacheControl, "only-if-cached")
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusGatewayTimeout, resp.StatusCode)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	require.Equal(t, 0, count)
}

// Test_CacheOnlyIfCachedStaleNotServed checks that "only-if-cached" does not
// return an entry whose max-age=1 has elapsed: the client gets 504 and the
// handler is not invoked a second time.
func Test_CacheOnlyIfCachedStaleNotServed(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "public, max-age=1")
		return c.SendString(strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	// Let the one-second lifetime lapse.
	time.Sleep(1500 * time.Millisecond)
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderCacheControl, "only-if-cached")
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusGatewayTimeout, resp.StatusCode)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	require.Equal(t, 1, count)
}

// Test_CacheMaxStaleServesStaleResponse checks that a request "max-stale=5"
// lets the cache answer with an entry that is past its max-age=2 (about 2.5s
// old) as a hit, without invoking the handler again.
func Test_CacheMaxStaleServesStaleResponse(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "public, max-age=2")
		return c.SendString(strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	// 2.5s > max-age=2: the entry is stale but within the requested max-stale=5.
	time.Sleep(2500 * time.Millisecond)
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderCacheControl, "max-stale=5")
	resp, err = app.Test(req)
	require.NoError(t, err)
	// On failure, also report the parsed request directives, the Age header and
	// the handler call count.
	require.Equalf(t, cacheHit, resp.Header.Get("X-Cache"), "dirs=%+v Age=%s count=%d", parseRequestCacheControlString("max-stale=5"), resp.Header.Get(fiber.HeaderAge), count)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(body))
	require.Equal(t, 1, count)
}

// Test_CacheMaxStaleRespectsMustRevalidate checks that a response
// "must-revalidate" overrides a request "max-stale=30": once max-age=1 has
// elapsed the request is a miss and the handler runs again.
func Test_CacheMaxStaleRespectsMustRevalidate(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "public, max-age=1, must-revalidate")
		return c.SendString(strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	time.Sleep(1500 * time.Millisecond)
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderCacheControl, "max-stale=30")
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))
	require.Equal(t, 2, count)
}

// Test_CacheMaxStaleRespectsProxyRevalidateSharedAuth checks that a response
// "proxy-revalidate" (cached for an authorized request via s-maxage=1)
// overrides a request "max-stale=30": after expiry the request is a miss.
func Test_CacheMaxStaleRespectsProxyRevalidateSharedAuth(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "s-maxage=1, proxy-revalidate")
		return c.SendString(strconv.Itoa(count))
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Bearer abc")
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	time.Sleep(1500 * time.Millisecond)
	req = httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
req.Header.Set(fiber.HeaderAuthorization, "Bearer abc")
	req.Header.Set(fiber.HeaderCacheControl, "max-stale=30")
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))
	require.Equal(t, 2, count)
}

// Test_CachePreservesCacheControlHeaders checks that the origin's
// Cache-Control, Expires and ETag values reach the client unchanged, both on
// the initial miss and on the following cache hit.
func Test_CachePreservesCacheControlHeaders(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	expires := time.Now().Add(10 * time.Second).UTC().Format(http.TimeFormat)
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "public, max-age=5, immutable")
		c.Set(fiber.HeaderExpires, expires)
		c.Set(fiber.HeaderETag, `W/"abc"`)
		return c.SendString("ok")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	require.Equal(t, "public, max-age=5, immutable", resp.Header.Get(fiber.HeaderCacheControl))
	require.Equal(t, expires, resp.Header.Get(fiber.HeaderExpires))
	require.Equal(t, `W/"abc"`, resp.Header.Get(fiber.HeaderETag))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	require.Equal(t, "public, max-age=5, immutable", resp.Header.Get(fiber.HeaderCacheControl))
	require.Equal(t, expires, resp.Header.Get(fiber.HeaderExpires))
	require.Equal(t, `W/"abc"`, resp.Header.Get(fiber.HeaderETag))
}

// setResponseDate returns a middleware that runs the rest of the chain and, if
// it returned no error, overwrites the response Date header with date (in UTC,
// formatted with http.TimeFormat). Tests use it to simulate responses that
// were generated at a different time.
func setResponseDate(date time.Time) fiber.Handler {
	return func(c fiber.Ctx) error {
		if err := c.Next(); err != nil {
			return err
		}
		c.Response().Header.Set(fiber.HeaderDate, date.UTC().Format(http.TimeFormat))
		return nil
	}
}

// Test_CacheDateAndAgeHandling is table-driven over the origin's Date/Age
// headers:
//   - a Date one minute in the past (no Age header) still yields a hit whose
//     Age header is at least expectAgeAtLeast;
//   - a Date 90s in the past with "Age: 90" exceeds max-age=30, so the second
//     request is not served from the cache and the handler runs again.
func Test_CacheDateAndAgeHandling(t *testing.T) {
	t.Parallel()
	type testCase struct {
		name             string
		cacheControl     string
		cacheHeader      string        // expected X-Cache value of the second request
		dateOffset       time.Duration // applied to time.Now() for the Date header
		expiration       time.Duration
		expectAgeAtLeast int
		expectCount      int // handler invocations expected after both requests
		originAge        int // sent as the origin's Age header when > 0
	}
	cases := []testCase{
		{
			name:             "age derived from past date without Age header",
			dateOffset:       -1 * time.Minute,
			cacheControl:     "public, max-age=120",
			cacheHeader:      cacheHit,
			expiration:       5 * time.Minute,
			expectAgeAtLeast: 1,
			expectCount:      1,
		},
		{
			name:         "stale due to past date despite max-age",
			dateOffset:   -90 * time.Second,
			cacheControl: "public, max-age=30",
			cacheHeader:  cacheUnreachable,
			expiration:   5 * time.Minute,
			expectCount:  2,
			originAge:    90,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			app.Use(New(Config{Expiration: tc.expiration}))
			app.Use(setResponseDate(time.Now().Add(tc.dateOffset).UTC()))
			var count int
			app.Get("/", func(c fiber.Ctx) error {
				count++
				if tc.originAge > 0 {
					c.Response().Header.Set(fiber.HeaderAge, strconv.Itoa(tc.originAge))
				}
				c.Set(fiber.HeaderCacheControl, tc.cacheControl)
				return c.SendString(strconv.Itoa(count))
			})
			_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
			require.NoError(t, err)
			if tc.cacheHeader == cacheHit {
				time.Sleep(2 * time.Second)
			}
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
			require.NoError(t, err)
			require.Equal(t, tc.cacheHeader, resp.Header.Get("X-Cache"))
			if tc.cacheHeader == cacheHit {
				ageVal, err := strconv.Atoi(resp.Header.Get(fiber.HeaderAge))
				require.NoError(t, err)
				require.GreaterOrEqual(t, ageVal, tc.expectAgeAtLeast)
				// Compare against the table value (not a hard-coded 1) so every
				// case is checked against its own declared handler count.
				require.Equal(t, tc.expectCount, count)
			} else {
				body, err := io.ReadAll(resp.Body)
				require.NoError(t, err)
				require.Equal(t, strconv.Itoa(tc.expectCount), string(body))
				require.Equal(t, tc.expectCount, count)
			}
		})
	}
}

// Test_CacheClampsInvalidStoredDate rewrites each stored (non-body) entry so
// its date exceeds math.MaxInt64. It then checks that a later hit still
// reports a Date within a minute of now and an Age in [0, 60).
func Test_CacheClampsInvalidStoredDate(t *testing.T) {
	t.Parallel()
	storage := newMutatingStorage(func(key string, val []byte) []byte {
		// Body payloads are left untouched; only the metadata item is corrupted.
		if strings.HasSuffix(key, "_body") {
			return val
		}
		var it item
		if _, err := it.UnmarshalMsg(val); err != nil {
			return val
		}
		it.date = uint64(math.MaxInt64) + 1024
		updated, err := it.MarshalMsg(nil)
		if err != nil {
			return val
		}
		return updated
	})
	app := fiber.New()
	app.Use(New(Config{
		Expiration: time.Minute,
Storage: storage,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString("ok")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	// Despite the out-of-range stored date, the emitted Date must be close to
	// now and Age must stay within [0, max-age).
	parsedDate, err := http.ParseTime(resp.Header.Get(fiber.HeaderDate))
	require.NoError(t, err)
	require.WithinDuration(t, time.Now(), parsedDate, time.Minute)
	ageVal, err := strconv.Atoi(resp.Header.Get(fiber.HeaderAge))
	require.NoError(t, err)
	require.Less(t, ageVal, 60)
	require.GreaterOrEqual(t, ageVal, 0)
}

// Test_CacheClampsFutureStoredDate rewrites each stored (non-body) entry so
// its date lies 2s in the future. It then checks that the hit reports a Date
// that is not after now and a non-negative Age.
func Test_CacheClampsFutureStoredDate(t *testing.T) {
	t.Parallel()
	storage := newMutatingStorage(func(key string, val []byte) []byte {
		// Body payloads are left untouched; only the metadata item is rewritten.
		if strings.HasSuffix(key, "_body") {
			return val
		}
		var it item
		if _, err := it.UnmarshalMsg(val); err != nil {
			return val
		}
		future := time.Now().Add(2 * time.Second).UTC()
		sec := future.Unix()
		if sec < 0 {
			sec = 0
		}
		it.date = uint64(sec) //nolint:gosec // safe: sec is clamped to non-negative range
		updated, err := it.MarshalMsg(nil)
		if err != nil {
			return val
		}
		return updated
	})
	app := fiber.New()
	app.Use(New(Config{
		Expiration: time.Minute,
		Storage:    storage,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString("ok")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	parsedDate, err := http.ParseTime(resp.Header.Get(fiber.HeaderDate))
	require.NoError(t, err)
	require.False(t, parsedDate.After(time.Now()))
	ageVal, err := strconv.Atoi(resp.Header.Get(fiber.HeaderAge))
	require.NoError(t, err)
	require.GreaterOrEqual(t, ageVal, 0)
}

// Test_RequestPragmaNoCacheTriggersMiss checks that a request carrying
// "Pragma: no-cache" bypasses a fresh entry (miss, handler runs, body2). The
// refreshed response is then served from the cache to a plain request.
func Test_RequestPragmaNoCacheTriggersMiss(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Expiration: time.Minute,
	}))
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString("body" + strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "body1", string(body))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "body1", string(body))
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderPragma, "no-cache")
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "body2", string(body))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "body2", string(body))
}

// Test_CacheStaleResponseAddsWarning110 checks that when a stale entry is
// served under a request "max-stale=5", the response carries a Warning header
// containing code 110.
func Test_CacheStaleResponseAddsWarning110(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Expiration: 2 * time.Second,
	}))
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "public, max-age=1")
		return c.SendString("body" + strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	req :=
httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderCacheControl, "max-stale=5")
	// Wait for the cached response to become stale (max-age=1)
	// Add extra time to ensure the entry has expired
	time.Sleep(1200 * time.Millisecond)
	// Poll until a hit with Age >= 1 that carries a Warning header comes back;
	// fail if that has not happened by the deadline.
	deadline := time.Now().Add(3 * time.Second)
	for {
		resp, err = app.Test(req)
		require.NoError(t, err)
		if resp.Header.Get("X-Cache") == cacheHit {
			ageVal, err := strconv.Atoi(resp.Header.Get(fiber.HeaderAge))
			require.NoError(t, err)
			if ageVal >= 1 {
				// Check that Warning header is present before breaking
				warnings := resp.Header.Values(fiber.HeaderWarning)
				if len(warnings) > 0 {
					break
				}
			}
		}
		require.True(t, time.Now().Before(deadline), "response did not become stale before deadline")
		time.Sleep(50 * time.Millisecond)
	}
	warnings := resp.Header.Values(fiber.HeaderWarning)
	require.NotEmpty(t, warnings, "Warning header should be present when serving stale response")
	found := false
	for _, w := range warnings {
		if strings.Contains(w, "110") {
			found = true
			break
		}
	}
	require.True(t, found, "warning 110 not found in %v", warnings)
}

// Test_CacheExplicitFreshnessOmitsWarning113 checks that a cache hit on a
// response with an explicit lifetime ("public, max-age=60") carries no Warning
// value containing "113", the heuristic-expiration warn-code. The name states
// what is asserted: the 113 warning is absent.
func Test_CacheExplicitFreshnessOmitsWarning113(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Expiration: 2 * time.Second,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString("body")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	for _, w := range resp.Header.Values(fiber.HeaderWarning) {
		require.NotContains(t, w, "113", "warning 113 should not be present for explicitly fresh responses")
	}
}

// Test_CacheHeuristicFreshnessAddsWarning113AfterThreshold rewrites the stored
// entry so it appears 25 hours old but still unexpired (exp and ttl 48 hours
// out), while the handler sends no Cache-Control. The resulting hit must carry
// a Warning value containing "113".
func Test_CacheHeuristicFreshnessAddsWarning113AfterThreshold(t *testing.T) {
	t.Parallel()
	storage := newMutatingStorage(func(key string, val []byte) []byte {
if strings.HasSuffix(key, "_body") {
			return val
		}
		var it item
		if _, err := it.UnmarshalMsg(val); err != nil {
			return val
		}
		// Backdate the entry by 25h while pushing exp/ttl 48h out, so it is old
		// but still unexpired.
		oldDate := time.Now().Add(-25 * time.Hour).UTC()
		sec := oldDate.Unix()
		if sec < 0 {
			sec = 0
		}
		it.date = uint64(sec) //nolint:gosec // safe: sec is clamped to non-negative range
		future := time.Now().Add(48 * time.Hour).UTC()
		expSec := future.Unix()
		if expSec < 0 {
			expSec = 0
		}
		it.exp = uint64(expSec) //nolint:gosec // safe: expSec is clamped to non-negative range
		it.ttl = uint64((48 * time.Hour) / time.Second)
		updated, err := it.MarshalMsg(nil)
		if err != nil {
			return val
		}
		return updated
	})
	app := fiber.New()
	app.Use(New(Config{
		Expiration: 2 * time.Second,
		Storage:    storage,
	}))
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		return c.SendString("body" + strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	warnings := resp.Header.Values(fiber.HeaderWarning)
	require.NotEmpty(t, warnings)
	found := false
	for _, w := range warnings {
		if strings.Contains(w, "113") {
			found = true
			break
		}
	}
	require.True(t, found, "warning 113 not found in %v", warnings)
}

// Test_CacheAgeHeaderIsCappedAtMaxDeltaSeconds rewrites the stored entry's age
// to math.MaxInt32+1000 and checks that the Age header on the hit is clamped
// to exactly math.MaxInt32.
func Test_CacheAgeHeaderIsCappedAtMaxDeltaSeconds(t *testing.T) {
	t.Parallel()
	const veryLargeAge = uint64(math.MaxInt32) + 1000
	storage := newMutatingStorage(func(key string, val []byte) []byte {
		if strings.HasSuffix(key, "_body") {
			return val
		}
		var it item
		if _, err := it.UnmarshalMsg(val); err != nil {
			return val
		}
		it.age = veryLargeAge
		updated, err := it.MarshalMsg(nil)
		if err != nil {
			return val
		}
		return updated
	})
	app := fiber.New()
	app.Use(New(Config{
		Expiration: time.Minute,
		Storage:    storage,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString("body")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	ageVal, err := strconv.Atoi(resp.Header.Get(fiber.HeaderAge))
	require.NoError(t, err)
	require.Equal(t, math.MaxInt32, ageVal)
}

// Test_CacheMinFreshForcesRevalidation checks that a request "min-fresh=10"
// cannot be satisfied by an entry with max-age=5 (miss, body "2"). The
// refreshed entry then answers a plain request as a hit.
func Test_CacheMinFreshForcesRevalidation(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "public, max-age=5")
		return c.SendString(strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(body))
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderCacheControl, "min-fresh=10")
	resp, err = app.Test(req)
	require.NoError(t, err)
	// On failure, also report the parsed request directives, Age and call count.
	require.Equalf(t, cacheMiss, resp.Header.Get("X-Cache"), "dirs=%+v Age=%s count=%d", parseRequestCacheControlString("min-fresh=10"), resp.Header.Get(fiber.HeaderAge), count)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))
}

// Test_CachePermanentRedirectCached checks that a 308 response with
// "public, max-age=30" is cached along with its Location header
// (StoreResponseHeaders enabled). The hit must replay the status, the body and
// Location.
func Test_CachePermanentRedirectCached(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Expiration:           30 * time.Second,
		StoreResponseHeaders: true,
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path() + "|status-308"
		},
	}))
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
c.Set(fiber.HeaderCacheControl, "public, max-age=30")
		c.Set(fiber.HeaderLocation, "/dest")
		return c.Status(fiber.StatusPermanentRedirect).SendString("redir" + strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusPermanentRedirect, resp.StatusCode)
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "redir1", string(body))
	require.Equal(t, "/dest", resp.Header.Get(fiber.HeaderLocation))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	require.Equal(t, fiber.StatusPermanentRedirect, resp.StatusCode)
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "redir1", string(body))
	require.Equal(t, "/dest", resp.Header.Get(fiber.HeaderLocation))
}

// Test_CacheNoStoreDirective checks that a "no-store" response is never
// cached: both requests report cacheUnreachable.
func Test_CacheNoStoreDirective(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "no-store")
		return c.SendString("ok")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
}

// Test_CacheNoCacheDirective checks that a "no-cache" response is not served
// from the cache: each request re-runs the handler (bodies "1" then "2").
func Test_CacheNoCacheDirective(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "no-cache")
		return c.SendString(strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(body))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))
}

// Test_CacheNoCacheDirectiveOverridesExistingEntry checks what happens after
// the origin switches from a cacheable response to "no-cache". A client
// "no-cache" request gets the new response, and later plain requests are no
// longer answered with the previously cached "cacheable" body.
func Test_CacheNoCacheDirectiveOverridesExistingEntry(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	// Switches the handler between a cacheable and a "no-cache" response.
	var noCacheMode atomic.Bool
	app.Get("/", func(c fiber.Ctx) error {
		if noCacheMode.Load() {
			c.Set(fiber.HeaderCacheControl, "no-cache")
			return c.SendString("no-cache")
		}
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString("cacheable")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "cacheable", string(body))
	noCacheMode.Store(true)
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderCacheControl, "no-cache")
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "no-cache", string(body))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "no-cache", string(body))
}

// Test_CacheRespectsUpstreamAgeForFreshness checks that an upstream Age header
// counts against max-age when deciding whether, and for how long, a response
// is cached.
func Test_CacheRespectsUpstreamAgeForFreshness(t *testing.T) {
	t.Parallel()
	// Age equal to max-age leaves no freshness: nothing is served from cache.
	t.Run("skipsCachingWhenAgeExhaustsFreshness", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			KeyGenerator: func(c fiber.Ctx) string {
				return c.Path() + "|age-exhausted"
			},
		}))
		var count int
		app.Get("/", func(c fiber.Ctx) error {
			count++
			c.Set(fiber.HeaderCacheControl, "public, max-age=2")
			c.Set(fiber.HeaderAge, "2")
			return c.SendString(strconv.Itoa(count))
		})
		resp, err :=
app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, "1", string(body))
		resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
		body, err = io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, "2", string(body))
	})
	// Age=1 of max-age=2 leaves about one second of freshness: an immediate
	// repeat is a hit, but after sleeping 1.5s the request is a miss.
	t.Run("expiresAfterRemainingLifetime", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			KeyGenerator: func(c fiber.Ctx) string {
				return c.Path() + "|age-remaining"
			},
		}))
		var count int
		app.Get("/", func(c fiber.Ctx) error {
			count++
			c.Set(fiber.HeaderCacheControl, "public, max-age=2")
			c.Set(fiber.HeaderAge, "1")
			return c.SendString(strconv.Itoa(count))
		})
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, "1", string(body))
		resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
		body, err = io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, "1", string(body))
		time.Sleep(1500 * time.Millisecond)
		resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
		body, err = io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, "2", string(body))
	})
}

// Test_CacheVarySeparatesVariants checks that "Vary: Accept-Language" keeps a
// separate entry per language. The "en" and "fr" requests each miss once, then
// each repeat hits its own variant ("en1", "fr2").
func Test_CacheVarySeparatesVariants(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path() + "|vary-separated"
		},
	}))
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderVary, fiber.HeaderAcceptLanguage)
		return c.SendString(c.Get(fiber.HeaderAcceptLanguage) + strconv.Itoa(count))
	})
	reqEN := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	reqEN.Header.Set(fiber.HeaderAcceptLanguage, "en")
	resp, err := app.Test(reqEN)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "en1", string(body))
	reqFR := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	reqFR.Header.Set(fiber.HeaderAcceptLanguage, "fr")
	resp, err = app.Test(reqFR)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "fr2", string(body))
	reqENRepeat := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	reqENRepeat.Header.Set(fiber.HeaderAcceptLanguage, "en")
	resp, err = app.Test(reqENRepeat)
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "en1", string(body))
	reqFRRepeat := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	reqFRRepeat.Header.Set(fiber.HeaderAcceptLanguage, "fr")
	resp, err = app.Test(reqFRRepeat)
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "fr2", string(body))
}

// Test_CacheVaryStarUncacheable checks that "Vary: *" makes the response
// uncacheable: every request reports cacheUnreachable and runs the handler.
func Test_CacheVaryStarUncacheable(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path() + "|vary-star"
		},
	}))
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderVary, "*")
		return c.SendString(strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(body))
	resp, err =
app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))
}

// Test_CachePrivateDirective checks that a "private" response is not served
// from the cache: both requests report cacheUnreachable and run the handler.
func Test_CachePrivateDirective(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "private")
		return c.SendString(strconv.Itoa(count))
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(body))
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))
}

// Test_CachePrivateDirectiveWithAuthorization checks that "private" responses
// are also not served from the cache for requests carrying an Authorization
// header.
func Test_CachePrivateDirectiveWithAuthorization(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "private")
		return c.SendString(strconv.Itoa(count))
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Bearer token")
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(body))
	req = httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Bearer token")
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))
}

// Test_CachePrivateDirectiveInvalidatesExistingEntry checks what happens when
// a "no-cache" request receives a "private" response. The previously cached
// public entry is dropped, so the next plain request is a miss rather than a
// hit.
func Test_CachePrivateDirectiveInvalidatesExistingEntry(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	// Switches the handler between a public and a private response.
	var privateMode atomic.Bool
	app.Get("/", func(c fiber.Ctx) error {
		if privateMode.Load() {
			c.Set(fiber.HeaderCacheControl, "private")
			return c.SendString("private")
		}
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString("public")
	})
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "public", string(body))
	privateMode.Store(true)
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderCacheControl, "no-cache")
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "private", string(body))
	privateMode.Store(false)
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "public", string(body))
}

// Test_CacheControlNotOverwritten checks that, with StoreResponseHeaders
// enabled, the origin's "private" Cache-Control value reaches the client
// unchanged on a repeated request.
func Test_CacheControlNotOverwritten(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{Expiration: 10 * time.Second, StoreResponseHeaders: true}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "private")
		return c.SendString("ok")
	})
	_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, "private", resp.Header.Get(fiber.HeaderCacheControl))
}

// Test_CacheMaxAgeDirective checks that, despite a configured 10s Expiration,
// a response with "max-age=1" is a miss again after 1.5s.
func Test_CacheMaxAgeDirective(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{Expiration: 10 * time.Second}))
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderCacheControl, "max-age=1")
		return c.SendString("1")
	})
	_, err :=
app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)

	time.Sleep(1500 * time.Millisecond)

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
}

// Test_ParseMaxAge checks that parseMaxAge extracts the max-age directive.
// Matching is case-insensitive, and unrelated directives such as s-maxage are
// skipped. It reports ok=false when max-age is absent, has no value, or has a
// non-numeric value.
func Test_ParseMaxAge(t *testing.T) {
	t.Parallel()
	tests := []struct {
		header string
		expect time.Duration
		ok     bool
	}{
		{"max-age=60", 60 * time.Second, true},
		{"public, max-age=86400", 86400 * time.Second, true},
		{"no-store", 0, false},
		{"max-age=invalid", 0, false},
		{"public, s-maxage=100, max-age=50", 50 * time.Second, true},
		{"MAX-AGE=20", 20 * time.Second, true},
		{"public , max-age=0", 0, true},
		{"public , max-age", 0, false},
	}
	for _, tt := range tests {
		t.Run(tt.header, func(t *testing.T) {
			t.Parallel()
			d, ok := parseMaxAge(tt.header)
			// Use require like the rest of this file, instead of hand-written t.Fatalf.
			require.Equal(t, tt.ok, ok, "ok flag for %q", tt.header)
			if ok {
				require.Equal(t, tt.expect, d, "duration for %q", tt.header)
			}
		})
	}
}

// Test_AllowsSharedCache checks which response Cache-Control directive sets
// allowsSharedCache accepts as permitting a shared cache to store the response.
func Test_AllowsSharedCache(t *testing.T) {
	t.Parallel()
	tests := []struct {
		directives string
		expect     bool
	}{
		{"public", true},
		{"private", false},
		{"s-maxage=60", true},
		{"public, max-age=60", true},
		{"public, must-revalidate", true},
		{"max-age=60", false},
		{"no-cache", false},
		{"no-cache, s-maxage=60", true},
		{"", false},
	}
	for _, tt := range tests {
		t.Run(tt.directives, func(t *testing.T) {
			t.Parallel()
			got := allowsSharedCache(tt.directives)
			require.Equal(t, tt.expect, got, "directives: %q", tt.directives)
		})
	}

	// The input is upper-cased, so this also exercises non-lowercase directive names.
	t.Run("private overrules public", func(t *testing.T) {
		t.Parallel()
		got := allowsSharedCache(strings.ToUpper("private, public"))
		require.False(t, got)
	})
}

// TestCacheSkipsAuthorizationByDefault verifies that, with the default config,
// requests carrying an Authorization header report cacheUnreachable and reach
// the handler every time.
func TestCacheSkipsAuthorizationByDefault(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())

	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		return c.SendString(strconv.Itoa(count))
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Bearer token")
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(body))

	req = httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Bearer token")
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))
}

// TestCacheBypassesExistingEntryForAuthorization verifies that an authorized
// request is not served from an entry stored for an anonymous request, and
// that the anonymous entry remains available afterwards.
func TestCacheBypassesExistingEntryForAuthorization(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())

	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		return c.SendString(strconv.Itoa(count))
	})

	nonAuthReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(nonAuthReq)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(body))

	authReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	authReq.Header.Set(fiber.HeaderAuthorization, "Bearer token")
	resp, err = app.Test(authReq)
	require.NoError(t, err)
	require.Equal(t, cacheUnreachable, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "2", string(body))

	resp, err = app.Test(nonAuthReq)
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "1", string(body))
}

// TestCacheAllowsSharedCacheWithAuthorization verifies that an authorized
// request is cached when the response is explicitly marked public.
func TestCacheAllowsSharedCacheWithAuthorization(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{Expiration: 10 * time.Second}))

	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString("ok")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Bearer token")
	resp, err :=
app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))

	req = httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(fiber.HeaderAuthorization, "Bearer token")
	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "ok", string(body))
	require.Equal(t, 1, count)
}

// TestCacheAllowsAuthorizationWithRevalidateDirectives covers requests that
// carry an Authorization header. Responses marked must-revalidate or
// proxy-revalidate (with a max-age) are stored and replayed. A response that
// only sets Expires is never cached, so the handler runs for both requests.
func TestCacheAllowsAuthorizationWithRevalidateDirectives(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name          string
		cacheControl  string
		expires       string
		expectedBody  string
		expectedBody2 string
		expectFirst   string
		expectSecond  string
	}{
		{
			name:          "must-revalidate",
			cacheControl:  "must-revalidate, max-age=60",
			expectedBody:  "ok-1",
			expectedBody2: "ok-1",
			expectFirst:   cacheMiss,
			expectSecond:  cacheHit,
		},
		{
			name:          "proxy-revalidate",
			cacheControl:  "proxy-revalidate, max-age=60",
			expectedBody:  "ok-1",
			expectedBody2: "ok-1",
			expectFirst:   cacheMiss,
			expectSecond:  cacheHit,
		},
		{
			name:          "expires header",
			cacheControl:  "",
			expires:       time.Now().Add(1 * time.Minute).UTC().Format(http.TimeFormat),
			expectedBody:  "ok-1",
			expectedBody2: "ok-2",
			expectFirst:   cacheUnreachable,
			expectSecond:  cacheUnreachable,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			app := fiber.New()
			app.Use(New(Config{Expiration: 10 * time.Second}))

			var hits int
			app.Get("/", func(c fiber.Ctx) error {
				hits++
				c.Set(fiber.HeaderCacheControl, tc.cacheControl)
				if tc.expires != "" {
					c.Set(fiber.HeaderExpires, tc.expires)
				}
				return c.SendString(fmt.Sprintf("ok-%d", hits))
			})

			// The same authorized request is sent twice.
			req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
			req.Header.Set(fiber.HeaderAuthorization, "Bearer token")

			first, err := app.Test(req)
			require.NoError(t, err)
			require.Equal(t, tc.expectFirst, first.Header.Get("X-Cache"))
			firstBody, err := io.ReadAll(first.Body)
			require.NoError(t, err)
			require.Equal(t, tc.expectedBody, string(firstBody))

			second, err := app.Test(req)
			require.NoError(t, err)
			require.Equal(t, tc.expectSecond, second.Header.Get("X-Cache"))
			secondBody, err := io.ReadAll(second.Body)
			require.NoError(t, err)
			require.Equal(t, tc.expectedBody2, string(secondBody))

			// A cache hit on the second request means the handler ran only once.
			wantHits := 2
			if tc.expectSecond == cacheHit {
				wantHits = 1
			}
			require.Equal(t, wantHits, hits)
		})
	}
}

// TestCacheSeparatesAuthorizationValues verifies that responses cached for one
// Authorization value are never replayed to a request with a different value.
func TestCacheSeparatesAuthorizationValues(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{Expiration: 10 * time.Second}))

	var count int
	app.Get("/", func(c fiber.Ctx) error {
		count++
		c.Set(fiber.HeaderCacheControl, "public, max-age=60")
		return c.SendString(fmt.Sprintf("body-%d-%s", count, c.Get(fiber.HeaderAuthorization)))
	})

	newRequest := func(token string) *http.Request {
		req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
		req.Header.Set(fiber.HeaderAuthorization, "Bearer "+token)
		return req
	}

	authTokenA := "token-a"
	authTokenB := "token-b"

	resp, err := app.Test(newRequest(authTokenA))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "body-1-Bearer "+authTokenA, string(body))

	resp, err = app.Test(newRequest(authTokenA))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "body-1-Bearer "+authTokenA, string(body))
	require.Equal(t, 1, count)

	resp, err = app.Test(newRequest(authTokenB))
	require.NoError(t, err)
	require.Equal(t, cacheMiss, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "body-2-Bearer "+authTokenB, string(body))
	require.Equal(t, 2, count)

	resp, err = app.Test(newRequest(authTokenB))
	require.NoError(t, err)
	require.Equal(t, cacheHit, resp.Header.Get("X-Cache"))
	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "body-2-Bearer "+authTokenB, string(body))

	resp, err = app.Test(newRequest(authTokenA))
	require.NoError(t, err)
	require.Equal(t, cacheHit,
resp.Header.Get("X-Cache")) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "body-1-Bearer "+authTokenA, string(body)) require.Equal(t, 2, count) } // go test -v -run=^$ -bench=Benchmark_Cache -benchmem -count=4 func Benchmark_Cache(b *testing.B) { app := fiber.New() app.Use(New()) app.Get("/demo", func(c fiber.Ctx) error { data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark return c.Status(fiber.StatusTeapot).Send(data) }) h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.SetRequestURI("/demo") b.ReportAllocs() for b.Loop() { h(fctx) } require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode()) require.Greater(b, len(fctx.Response.Body()), 30000) } func Benchmark_Cache_Miss(b *testing.B) { app := fiber.New() app.Use(New()) app.Get("/*", func(c fiber.Ctx) error { data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark return c.Status(fiber.StatusOK).Send(data) }) h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) b.ReportAllocs() b.ResetTimer() var n int for b.Loop() { n++ fctx.Request.SetRequestURI("/demo/" + strconv.Itoa(n)) h(fctx) } require.Equal(b, fiber.StatusOK, fctx.Response.Header.StatusCode()) require.Greater(b, len(fctx.Response.Body()), 30000) } // go test -v -run=^$ -bench=Benchmark_Cache_Storage -benchmem -count=4 func Benchmark_Cache_Storage(b *testing.B) { app := fiber.New() app.Use(New(Config{ Storage: memory.New(), })) app.Get("/demo", func(c fiber.Ctx) error { data, _ := os.ReadFile("../../.github/README.md") //nolint:errcheck // We're inside a benchmark return c.Status(fiber.StatusTeapot).Send(data) }) h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.SetRequestURI("/demo") b.ReportAllocs() for b.Loop() { h(fctx) } require.Equal(b, fiber.StatusTeapot, 
fctx.Response.Header.StatusCode())
	require.Greater(b, len(fctx.Response.Body()), 30000)
}

// Benchmark_Cache_AdditionalHeaders measures repeated requests with
// StoreResponseHeaders enabled, where the handler adds a custom X-Foobar
// response header.
func Benchmark_Cache_AdditionalHeaders(b *testing.B) {
	app := fiber.New()
	app.Use(New(Config{
		StoreResponseHeaders: true,
	}))

	app.Get("/demo", func(c fiber.Ctx) error {
		c.Response().Header.Add("X-Foobar", "foobar")
		// Use the named constant so the status matches the assertion below.
		return c.SendStatus(fiber.StatusTeapot)
	})

	h := app.Handler()

	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod(fiber.MethodGet)
	fctx.Request.SetRequestURI("/demo")

	b.ReportAllocs()

	for b.Loop() {
		h(fctx)
	}

	require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
	require.Equal(b, []byte("foobar"), fctx.Response.Header.Peek("X-Foobar"))
}

// Benchmark_Cache_MaxSize measures the overhead of MaxBytes tracking. Each
// iteration requests a distinct URI. It runs with three MaxBytes settings:
//  1. 0: tracking is disabled = no overhead
//  2. math.MaxUint32: enough to store all entries = no removals
//  3. 100: small size = constant insertions and removals
func Benchmark_Cache_MaxSize(b *testing.B) {
	cases := []struct {
		name string
		size uint
	}{
		{name: "Disabled", size: 0},
		{name: "Unlim", size: math.MaxUint32},
		{name: "LowBounded", size: 100},
	}

	for _, tc := range cases {
		b.Run(tc.name, func(b *testing.B) {
			app := fiber.New()
			app.Use(New(Config{MaxBytes: tc.size}))

			app.Get("/*", func(c fiber.Ctx) error {
				return c.Status(fiber.StatusTeapot).SendString("1")
			})

			h := app.Handler()

			fctx := &fasthttp.RequestCtx{}
			fctx.Request.Header.SetMethod(fiber.MethodGet)

			b.ReportAllocs()

			n := 0
			for b.Loop() {
				n++
				// strconv avoids fmt's reflection/boxing inside the timed loop.
				fctx.Request.SetRequestURI("/" + strconv.Itoa(n))
				h(fctx)
			}

			require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
		})
	}
}

// Test_Cache_RevalidationWithMaxBytes checks forced revalidation (request
// max-age=0 or an unsatisfiable min-fresh) when MaxBytes accounting is enabled.
func Test_Cache_RevalidationWithMaxBytes(t *testing.T) {
	t.Parallel()

	t.Run("max-age=0 revalidation removes old entry on storage success", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			MaxBytes: 100,
		}))

		requestCount := 0
		app.Get("/test", func(c fiber.Ctx) error {
			requestCount++
			c.Set(fiber.HeaderCacheControl, "max-age=60")
			return c.SendString(fmt.Sprintf("response-%d", requestCount))
		})

		// First request - cache the response
		req1 :=
httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) resp1, err := app.Test(req1) require.NoError(t, err) require.Equal(t, cacheMiss, resp1.Header.Get("X-Cache")) // Request with max-age=0 to force revalidation req2 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) req2.Header.Set(fiber.HeaderCacheControl, "max-age=0") resp2, err := app.Test(req2) require.NoError(t, err) body2, err := io.ReadAll(resp2.Body) require.NoError(t, err) require.Equal(t, "response-2", string(body2)) // Next request should serve the NEW cached entry req3 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) resp3, err := app.Test(req3) require.NoError(t, err) require.Equal(t, cacheHit, resp3.Header.Get("X-Cache")) body3, err := io.ReadAll(resp3.Body) require.NoError(t, err) require.Equal(t, "response-2", string(body3), "New entry should be cached") }) t.Run("min-fresh revalidation with MaxBytes", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ MaxBytes: 100, })) requestCount := 0 app.Get("/test", func(c fiber.Ctx) error { requestCount++ c.Set(fiber.HeaderCacheControl, "max-age=2") return c.SendString(fmt.Sprintf("response-%d", requestCount)) }) // First request - cache the response req1 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) resp1, err := app.Test(req1) require.NoError(t, err) require.Equal(t, cacheMiss, resp1.Header.Get("X-Cache")) // Wait a bit so the entry has aged time.Sleep(1 * time.Second) // Request with min-fresh that exceeds remaining freshness req2 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) req2.Header.Set(fiber.HeaderCacheControl, "min-fresh=5") resp2, err := app.Test(req2) require.NoError(t, err) body2, err := io.ReadAll(resp2.Body) require.NoError(t, err) require.Equal(t, "response-2", string(body2)) // Next request should serve the NEW cached entry req3 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) resp3, err := app.Test(req3) require.NoError(t, err) require.Equal(t, 
cacheHit, resp3.Header.Get("X-Cache")) body3, err := io.ReadAll(resp3.Body) require.NoError(t, err) require.Equal(t, "response-2", string(body3)) }) t.Run("revalidation respects MaxBytes eviction", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ MaxBytes: 20, // Only room for 2 responses of 10 bytes each ExpirationGenerator: stableAscendingExpiration(), })) app.Get("/*", func(c fiber.Ctx) error { c.Set(fiber.HeaderCacheControl, "max-age=60") return c.SendString("1234567890") // 10 bytes }) // Cache /a and /b req1 := httptest.NewRequest(fiber.MethodGet, "/a", http.NoBody) resp1, err := app.Test(req1) require.NoError(t, err) require.Equal(t, cacheMiss, resp1.Header.Get("X-Cache")) req2 := httptest.NewRequest(fiber.MethodGet, "/b", http.NoBody) resp2, err := app.Test(req2) require.NoError(t, err) require.Equal(t, cacheMiss, resp2.Header.Get("X-Cache")) // Both should be cached req3 := httptest.NewRequest(fiber.MethodGet, "/a", http.NoBody) resp3, err := app.Test(req3) require.NoError(t, err) require.Equal(t, cacheHit, resp3.Header.Get("X-Cache")) req4 := httptest.NewRequest(fiber.MethodGet, "/b", http.NoBody) resp4, err := app.Test(req4) require.NoError(t, err) require.Equal(t, cacheHit, resp4.Header.Get("X-Cache")) // Revalidate /a with max-age=0 req5 := httptest.NewRequest(fiber.MethodGet, "/a", http.NoBody) req5.Header.Set(fiber.HeaderCacheControl, "max-age=0") _, err = app.Test(req5) require.NoError(t, err) // /a should be revalidated and cached again req6 := httptest.NewRequest(fiber.MethodGet, "/a", http.NoBody) resp6, err := app.Test(req6) require.NoError(t, err) require.Equal(t, cacheHit, resp6.Header.Get("X-Cache")) // /b should still be cached (heap accounting should be correct) req7 := httptest.NewRequest(fiber.MethodGet, "/b", http.NoBody) resp7, err := app.Test(req7) require.NoError(t, err) require.Equal(t, cacheHit, resp7.Header.Get("X-Cache")) }) t.Run("revalidation with non-cacheable response preserves old entry", func(t 
*testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ MaxBytes: 100, })) requestCount := 0 app.Get("/test", func(c fiber.Ctx) error { requestCount++ if requestCount == 1 { c.Set(fiber.HeaderCacheControl, "max-age=60") return c.SendString("cacheable") } // Second request returns no-store c.Set(fiber.HeaderCacheControl, "no-store") return c.SendString("not-cacheable") }) // First request - cache the response req1 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) resp1, err := app.Test(req1) require.NoError(t, err) require.Equal(t, cacheMiss, resp1.Header.Get("X-Cache")) body1, err := io.ReadAll(resp1.Body) require.NoError(t, err) require.Equal(t, "cacheable", string(body1)) // Request with max-age=0 to force revalidation // The new response will be no-store (not cacheable) req2 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) req2.Header.Set(fiber.HeaderCacheControl, "max-age=0") resp2, err := app.Test(req2) require.NoError(t, err) body2, err := io.ReadAll(resp2.Body) require.NoError(t, err) require.Equal(t, "not-cacheable", string(body2)) // Next request should still serve the OLD cached entry // because the new response was not cacheable and old entry should remain tracked req3 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) resp3, err := app.Test(req3) require.NoError(t, err) require.Equal(t, cacheHit, resp3.Header.Get("X-Cache")) body3, err := io.ReadAll(resp3.Body) require.NoError(t, err) require.Equal(t, "cacheable", string(body3), "Old entry should still be cached") }) } // Test_parseCacheControlDirectives_QuotedStrings tests RFC 9111 Section 5.2 compliance // for quoted-string values in Cache-Control directives func Test_parseCacheControlDirectives_QuotedStrings(t *testing.T) { t.Parallel() tests := []struct { name string expected map[string]string input string }{ { name: "simple quoted value", input: `community="UCI"`, expected: map[string]string{ "community": "UCI", }, }, { name: "multiple directives with 
quoted values", input: `max-age=3600, community="UCI", custom="value"`, expected: map[string]string{ "max-age": "3600", "community": "UCI", "custom": "value", }, }, { name: "quoted value with spaces", input: `custom="value with spaces"`, expected: map[string]string{ "custom": "value with spaces", }, }, { name: "quoted value with escaped quote", input: `custom="value with \"quotes\""`, expected: map[string]string{ "custom": `value with "quotes"`, }, }, { name: "quoted value with escaped backslash", input: `custom="value with \\ backslash"`, expected: map[string]string{ "custom": `value with \ backslash`, }, }, { name: "mixed quoted and unquoted values", input: `max-age=3600, community="UCI", no-cache, custom="test"`, expected: map[string]string{ "max-age": "3600", "community": "UCI", "no-cache": "", "custom": "test", }, }, { name: "quoted empty value", input: `custom=""`, expected: map[string]string{ "custom": "", }, }, { name: "spaces around quoted value", input: `custom = "value" , another="test"`, expected: map[string]string{ "custom": "value", "another": "test", }, }, { name: "unquoted token value", input: `max-age=3600`, expected: map[string]string{ "max-age": "3600", }, }, { name: "complex mixed case", input: `max-age=3600, s-maxage=7200, community="UCI", no-store, custom="value with \"escaped\" quotes"`, expected: map[string]string{ "max-age": "3600", "s-maxage": "7200", "community": "UCI", "no-store": "", "custom": `value with "escaped" quotes`, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() result := make(map[string]string) parseCacheControlDirectives([]byte(tt.input), func(key, value []byte) { result[string(key)] = string(value) }) require.Equal(t, tt.expected, result) }) } } // Test_unquoteCacheDirective tests the unquoting logic for quoted-string values func Test_unquoteCacheDirective(t *testing.T) { t.Parallel() tests := []struct { name string input []byte expected []byte }{ { name: "simple quoted string", input: 
[]byte(`"value"`), expected: []byte("value"), }, { name: "empty quoted string", input: []byte(`""`), expected: []byte(""), }, { name: "quoted string with spaces", input: []byte(`"value with spaces"`), expected: []byte("value with spaces"), }, { name: "quoted string with escaped quote", input: []byte(`"value with \"quote\""`), expected: []byte(`value with "quote"`), }, { name: "quoted string with escaped backslash", input: []byte(`"value with \\ backslash"`), expected: []byte(`value with \ backslash`), }, { name: "quoted string with multiple escapes", input: []byte(`"a\"b\\c\"d"`), expected: []byte(`a"b\c"d`), }, { name: "too short input", input: []byte(`"`), expected: []byte(`"`), }, { name: "empty input", input: []byte(``), expected: []byte(``), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() result := unquoteCacheDirective(tt.input) require.Equal(t, tt.expected, result) }) } } // Test_Cache_MaxBytes_InsufficientSpace tests the "insufficient space" error path // when an entry is larger than MaxBytes, ensuring such entries are treated as unreachable func Test_Cache_MaxBytes_InsufficientSpace(t *testing.T) { t.Parallel() t.Run("entry larger than MaxBytes with empty cache", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ MaxBytes: 10, // Very small cache Expiration: 1 * time.Hour, })) app.Get("/large", func(c fiber.Ctx) error { // Return data larger than MaxBytes return c.Send(make([]byte, 20)) }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/large", http.NoBody)) require.NoError(t, err) // Should be unreachable because entry is too large require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache")) }) t.Run("entry larger than MaxBytes after eviction", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ MaxBytes: 15, ExpirationGenerator: stableAscendingExpiration(), })) app.Get("/*", func(c fiber.Ctx) error { path := c.Path() if path == "/small" { return c.Send(make([]byte, 
5)) } return c.Send(make([]byte, 20)) }) // Cache a small entry first rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/small", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) // Try to cache a large entry - should return unreachable since it won't fit even after eviction rsp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/large", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache")) }) } func Test_Cache_MaxBytes_DeletionFailureRestoresTracking(t *testing.T) { t.Parallel() storage := newFailingCacheStorage() app := fiber.New() app.Use(New(Config{ MaxBytes: 4, Expiration: 1 * time.Hour, Storage: storage, })) app.Get("/:name", func(c fiber.Ctx) error { return c.SendString("data") }) // Seed the cache with a single entry rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/first", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) var storedKeys []string storage.mu.Lock() for key := range storage.data { storedKeys = append(storedKeys, key) if strings.Contains(key, "/first") { storage.errs["del|"+key] = errors.New("delete failed") } } storage.mu.Unlock() t.Logf("stored keys after first cache: %v", storedKeys) // Next request triggers eviction; deletion failure should surface an error rsp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/second", http.NoBody)) require.NoError(t, err) body, err := io.ReadAll(rsp.Body) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, rsp.StatusCode) require.Contains(t, string(body), "failed to delete key") require.NoError(t, rsp.Body.Close()) var remainingKeys []string storage.mu.RLock() for key := range storage.data { remainingKeys = append(remainingKeys, key) } storage.mu.RUnlock() t.Logf("stored keys after deletion failure: %v", remainingKeys) storage.mu.Lock() storage.errs = make(map[string]error) storage.mu.Unlock() // Another request should succeed 
and be cacheable after restoring heap tracking rsp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/third", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) require.NoError(t, rsp.Body.Close()) } // Test_Cache_MaxBytes_ConcurrencyAndRaceConditions tests that the race condition fix works correctly // under concurrent load, verifying that storedBytes never exceeds MaxBytes even with multiple // goroutines making simultaneous requests func Test_Cache_MaxBytes_ConcurrencyAndRaceConditions(t *testing.T) { t.Parallel() t.Run("concurrent requests with MaxBytes limit", func(t *testing.T) { t.Parallel() app := fiber.New() const maxBytes = uint(1000) const numGoroutines = 20 const requestsPerGoroutine = 5 app.Use(New(Config{ MaxBytes: maxBytes, Expiration: 10 * time.Second, })) app.Get("/*", func(c fiber.Ctx) error { // Return data that will fill up the cache return c.Send(make([]byte, 50)) }) // Launch multiple goroutines making concurrent requests var wg sync.WaitGroup errChan := make(chan error, numGoroutines*requestsPerGoroutine) for i := 0; i < numGoroutines; i++ { id := i wg.Add(1) //nolint:revive // Standard WaitGroup pattern is appropriate here go func() { defer wg.Done() for j := 0; j < requestsPerGoroutine; j++ { path := fmt.Sprintf("/test-%d-%d", id, j) req := httptest.NewRequest(fiber.MethodGet, path, http.NoBody) _, err := app.Test(req) if err != nil { errChan <- err } } }() } wg.Wait() close(errChan) // Check for errors for err := range errChan { require.NoError(t, err, "concurrent request failed") } // The test passes if no errors occurred and no race conditions were detected by -race flag }) t.Run("concurrent requests near capacity triggers eviction", func(t *testing.T) { t.Parallel() app := fiber.New() const maxBytes = uint(200) const numRequests = 10 app.Use(New(Config{ MaxBytes: maxBytes, Expiration: 10 * time.Second, })) app.Get("/*", func(c fiber.Ctx) error { // Each response is about 50 bytes, so we'll 
exceed capacity return c.Send(make([]byte, 50)) }) // Make concurrent requests that will trigger evictions var wg sync.WaitGroup for i := 0; i < numRequests; i++ { id := i wg.Add(1) //nolint:revive // Standard WaitGroup pattern is appropriate here go func() { defer wg.Done() path := fmt.Sprintf("/item-%d", id) req := httptest.NewRequest(fiber.MethodGet, path, http.NoBody) _, err := app.Test(req) if err != nil { t.Logf("request error: %v", err) } }() } wg.Wait() // Test passes if no race conditions or panics occurred // The -race flag will detect any remaining race conditions }) } // Test_Cache_HelperFunctions tests various helper functions for better coverage func Test_Cache_HelperFunctions(t *testing.T) { t.Parallel() t.Run("parseHTTPDate empty", func(t *testing.T) { t.Parallel() result, ok := parseHTTPDate([]byte{}) require.False(t, ok) require.Equal(t, uint64(0), result) }) t.Run("parseHTTPDate invalid", func(t *testing.T) { t.Parallel() result, ok := parseHTTPDate([]byte("invalid")) require.False(t, ok) require.Equal(t, uint64(0), result) }) t.Run("parseHTTPDate valid", func(t *testing.T) { t.Parallel() result, ok := parseHTTPDate([]byte("Mon, 02 Jan 2006 15:04:05 GMT")) require.True(t, ok) require.Positive(t, result) }) t.Run("safeUnixSeconds negative", func(t *testing.T) { t.Parallel() result := safeUnixSeconds(time.Unix(-1, 0)) require.Equal(t, uint64(0), result) }) t.Run("safeUnixSeconds positive", func(t *testing.T) { t.Parallel() result := safeUnixSeconds(time.Unix(1234567890, 0)) require.Equal(t, uint64(1234567890), result) }) t.Run("remainingFreshness nil", func(t *testing.T) { t.Parallel() result := remainingFreshness(nil, 100) require.Equal(t, uint64(0), result) }) t.Run("remainingFreshness zero exp", func(t *testing.T) { t.Parallel() e := &item{exp: 0} result := remainingFreshness(e, 100) require.Equal(t, uint64(0), result) }) t.Run("remainingFreshness expired", func(t *testing.T) { t.Parallel() e := &item{exp: 100} result := remainingFreshness(e, 
200) require.Equal(t, uint64(0), result) }) t.Run("remainingFreshness valid", func(t *testing.T) { t.Parallel() e := &item{exp: 200} result := remainingFreshness(e, 100) require.Equal(t, uint64(100), result) }) t.Run("lookupCachedHeader not found", func(t *testing.T) { t.Parallel() headers := []cachedHeader{{key: []byte("Content-Type"), value: []byte("text/html")}} value, found := lookupCachedHeader(headers, "Authorization") require.False(t, found) require.Nil(t, value) }) t.Run("lookupCachedHeader case insensitive", func(t *testing.T) { t.Parallel() headers := []cachedHeader{{key: []byte("Authorization"), value: []byte("Bearer token")}} value, found := lookupCachedHeader(headers, "authorization") require.True(t, found) require.Equal(t, []byte("Bearer token"), value) }) t.Run("secondsToDuration zero", func(t *testing.T) { t.Parallel() result := secondsToDuration(0) require.Equal(t, time.Duration(0), result) }) t.Run("secondsToDuration large", func(t *testing.T) { t.Parallel() result := secondsToDuration(9223372036) require.Greater(t, result, time.Duration(0)) }) t.Run("secondsToTime zero", func(t *testing.T) { t.Parallel() result := secondsToTime(0) require.Equal(t, time.Unix(0, 0).UTC(), result) }) t.Run("secondsToTime value", func(t *testing.T) { t.Parallel() result := secondsToTime(1234567890) require.Equal(t, time.Unix(1234567890, 0).UTC(), result) }) t.Run("isHeuristicFreshness short age", func(t *testing.T) { t.Parallel() cfg := &Config{Expiration: 1 * time.Hour} e := &item{cacheControl: []byte("public")} result := isHeuristicFreshness(e, cfg, 3600) require.False(t, result) }) t.Run("isHeuristicFreshness with expires", func(t *testing.T) { t.Parallel() cfg := &Config{Expiration: 1 * time.Hour} e := &item{cacheControl: []byte("public"), expires: []byte("Wed, 21 Oct 2015 07:28:00 GMT")} result := isHeuristicFreshness(e, cfg, uint64(25*time.Hour/time.Second)) require.False(t, result) }) t.Run("isHeuristicFreshness true", func(t *testing.T) { t.Parallel() cfg := 
&Config{Expiration: 1 * time.Hour} e := &item{cacheControl: []byte("public")} result := isHeuristicFreshness(e, cfg, uint64(25*time.Hour/time.Second)) require.True(t, result) }) t.Run("cacheBodyFetchError miss", func(t *testing.T) { t.Parallel() mask := func(_ string) string { return "***" } err := cacheBodyFetchError(mask, "key", errCacheMiss) require.Error(t, err) require.Contains(t, err.Error(), "no cached body") }) t.Run("cacheBodyFetchError other", func(t *testing.T) { t.Parallel() mask := func(_ string) string { return "***" } originalErr := errors.New("storage error") err := cacheBodyFetchError(mask, "key", originalErr) require.Equal(t, originalErr, err) }) } // Test_Cache_VaryAndAuth tests vary and auth functionality func Test_Cache_VaryAndAuth(t *testing.T) { t.Parallel() t.Run("storeVaryManifest failure", func(t *testing.T) { t.Parallel() storage := newFailingCacheStorage() storage.errs["set|manifest"] = errors.New("storage fail") manager := &manager{storage: storage} err := storeVaryManifest(context.Background(), manager, "manifest", []string{"Accept"}, 3600*time.Second) require.Error(t, err) }) t.Run("loadVaryManifest not found", func(t *testing.T) { t.Parallel() storage := newFailingCacheStorage() manager := &manager{storage: storage} varyNames, found, err := loadVaryManifest(context.Background(), manager, "nonexistent") require.NoError(t, err) require.False(t, found) require.Nil(t, varyNames) }) t.Run("vary with multiple headers", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Vary", "Accept, Accept-Encoding") c.Response().Header.Set("Cache-Control", "max-age=3600") return c.SendString("test") }) req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) req.Header.Set("Accept", "application/json") req.Header.Set("Accept-Encoding", "gzip") rsp, err := app.Test(req) require.NoError(t, err) require.Equal(t, cacheMiss, 
rsp.Header.Get("X-Cache"))

		req2 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req2.Header.Set("Accept", "application/json")
		req2.Header.Set("Accept-Encoding", "gzip")
		rsp2, err := app.Test(req2)
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
	})

	t.Run("auth with must-revalidate", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "must-revalidate, max-age=3600")
			return c.SendString("content")
		})

		req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req.Header.Set("Authorization", "Bearer token1")
		rsp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))

		req2 := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req2.Header.Set("Authorization", "Bearer token1")
		rsp2, err := app.Test(req2)
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
	})
}

// Test_Cache_DateAndCacheControl sends responses with unusual Date and
// Cache-Control headers:
//   - a fixed past Date;
//   - an unparsable Date;
//   - a quoted extension value containing commas;
//   - extra whitespace around directives.
//
// For each, it checks that the first request is reported as a cache miss
// rather than failing.
func Test_Cache_DateAndCacheControl(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name    string
		headers [][2]string // response headers, applied in order
	}{
		{
			name: "date header parsing",
			headers: [][2]string{
				{"Date", "Mon, 02 Jan 2006 15:04:05 GMT"},
				{"Cache-Control", "max-age=3600"},
			},
		},
		{
			name: "invalid date header",
			headers: [][2]string{
				{"Date", "invalid"},
				{"Cache-Control", "max-age=3600"},
			},
		},
		{
			name: "cache control with quoted values",
			headers: [][2]string{
				{"Cache-Control", `max-age=3600, ext="value, with, commas"`},
			},
		},
		{
			name: "cache control with spaces",
			headers: [][2]string{
				{"Cache-Control", "max-age=3600 , public , must-revalidate"},
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			app := fiber.New()
			app.Use(New(Config{Expiration: 1 * time.Hour}))
			app.Get("/test", func(c fiber.Ctx) error {
				for _, kv := range tc.headers {
					c.Response().Header.Set(kv[0], kv[1])
				}
				return c.SendString("test")
			})

			rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
			require.NoError(t, err)
			require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
		})
	}
}

// Test_Cache_CacheControlCombinations tests common cache control directive combinations
func Test_Cache_CacheControlCombinations(t *testing.T) {
	t.Parallel()

	t.Run("max-age with public", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "public, max-age=3600")
			return c.SendString("public content")
		})

		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))

		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
	})

	t.Run("max-age with private", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "private, max-age=3600")
			return c.SendString("private content")
		})

		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
	})

	t.Run("s-maxage overrides max-age", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "public, max-age=60, s-maxage=3600")
			return c.SendString("content")
		})

		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))

		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
	})

	t.Run("no-store prevents caching", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "no-store")
			return c.SendString("no store content")
		})

		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))

		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp2.Header.Get("X-Cache"))
	})

	t.Run("no-cache with etag", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "no-cache")
			c.Response().Header.Set("ETag", `"123456"`)
			return c.SendString("no-cache content")
		})

		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
	})

	t.Run("must-revalidate with max-age", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "must-revalidate, max-age=3600")
			return c.SendString("must revalidate content")
		})

		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))

		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
	})

	t.Run("proxy-revalidate with max-age", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "public, proxy-revalidate, max-age=3600")
			return c.SendString("proxy revalidate content")
		})

		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))

		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
	})

	t.Run("immutable with max-age", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "public, max-age=31536000, immutable")
			return c.SendString("immutable content")
		})

		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))

		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
	})

	t.Run("max-age=0 with must-revalidate", func(t *testing.T) {
		t.Parallel()
app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "max-age=0, must-revalidate") return c.SendString("always revalidate") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache")) }) t.Run("public with no explicit max-age", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "public") return c.SendString("public no max-age") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache")) }) t.Run("multiple cache directives with extensions", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", `public, max-age=3600, custom="value"`) return c.SendString("content") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache")) }) t.Run("private overrides public", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "public, private, max-age=3600") return c.SendString("conflicting directives") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, 
"/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache")) }) t.Run("stale-while-revalidate with max-age", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "max-age=60, stale-while-revalidate=120") return c.SendString("stale while revalidate") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache")) }) t.Run("stale-if-error with max-age", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "max-age=60, stale-if-error=3600") return c.SendString("stale if error") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache")) }) } // Test_Cache_RequestResponseDirectives tests caching behavior with various request/response cache-control directives func Test_Cache_RequestResponseDirectives(t *testing.T) { t.Parallel() t.Run("negative expiration skips caching", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: -1 * time.Second})) app.Get("/test", func(c fiber.Ctx) error { return c.SendString("test") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.NotEqual(t, cacheMiss, rsp.Header.Get("X-Cache")) require.NotEqual(t, cacheHit, 
rsp.Header.Get("X-Cache")) }) t.Run("request with no-store directive", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { return c.SendString("test") }) req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) req.Header.Set("Cache-Control", "no-store") rsp, err := app.Test(req) require.NoError(t, err) require.NotEqual(t, cacheMiss, rsp.Header.Get("X-Cache")) }) t.Run("request with pragma no-cache", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "max-age=3600") return c.SendString("test") }) req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) req.Header.Set("Pragma", "no-cache") rsp, err := app.Test(req) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) }) t.Run("method not in allowed methods list", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Expiration: 1 * time.Hour, Methods: []string{fiber.MethodGet}, })) app.Post("/test", func(c fiber.Ctx) error { return c.SendString("test") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache")) }) t.Run("request with min-fresh directive", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "max-age=60") return c.SendString("test") }) // First request to cache rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) // Second request with min-fresh that's too high req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) req.Header.Set("Cache-Control", "min-fresh=120") 
rsp, err = app.Test(req) require.NoError(t, err) // Should be a miss because min-fresh requirement not met cacheStatus := rsp.Header.Get("X-Cache") require.Contains(t, []string{cacheMiss, cacheUnreachable}, cacheStatus, "min-fresh requirement should prevent cache hit") }) t.Run("request with max-age=0 directive", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "max-age=3600") return c.SendString("test") }) // First request to cache rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) // Second request with max-age=0 req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) req.Header.Set("Cache-Control", "max-age=0") rsp, err = app.Test(req) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) }) t.Run("request with max-stale directive", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Second})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "max-age=1") return c.SendString("test") }) // First request to cache rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) // Wait for it to become stale time.Sleep(2 * time.Second) // Request with max-stale to accept stale content req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) req.Header.Set("Cache-Control", "max-stale=60") rsp, err = app.Test(req) require.NoError(t, err) // max-stale should allow serving stale content cacheStatus := rsp.Header.Get("X-Cache") // Should be either a hit (if stale is served) or miss (if revalidated) require.Contains(t, []string{cacheHit, cacheMiss, "stale"}, cacheStatus, "max-stale should allow stale content or 
revalidate") }) t.Run("response with expires header", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { futureTime := time.Now().Add(1 * time.Hour).Format(time.RFC1123) c.Response().Header.Set("Expires", futureTime) return c.SendString("test") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache")) }) t.Run("response with age header", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "max-age=3600") c.Response().Header.Set("Age", "30") return c.SendString("test") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache")) }) t.Run("custom key generator", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Expiration: 1 * time.Hour, KeyGenerator: func(c fiber.Ctx) string { return "custom-" + c.Path() }, })) app.Get("/test", func(c fiber.Ctx) error { return c.SendString("test") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache")) }) t.Run("response with warning header", func(t *testing.T) { t.Parallel() app := fiber.New() 
app.Use(New(Config{Expiration: 1 * time.Second})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "max-age=1") return c.SendString("test") }) // Cache the response rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) // Wait for it to become stale time.Sleep(2 * time.Second) // Request again - should get stale warning or revalidate rsp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) // Check that either cache miss (revalidation) or warning header is present cacheStatus := rsp.Header.Get("X-Cache") warningHeader := rsp.Header.Get("Warning") require.True(t, cacheStatus == cacheMiss || warningHeader != "", "stale response should either revalidate or have warning header") }) t.Run("external storage with body key", func(t *testing.T) { t.Parallel() storage := newFailingCacheStorage() app := fiber.New() app.Use(New(Config{ Expiration: 1 * time.Hour, Storage: storage, })) app.Get("/test", func(c fiber.Ctx) error { return c.SendString("test content") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) // Verify body key is stored hasBodyKey := false storage.mu.RLock() for k := range storage.data { if strings.Contains(k, "_body") { hasBodyKey = true break } } storage.mu.RUnlock() require.True(t, hasBodyKey) }) t.Run("only-if-cached with cache miss", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { return c.SendString("test") }) req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) req.Header.Set("Cache-Control", "only-if-cached") rsp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusGatewayTimeout, rsp.StatusCode) }) t.Run("only-if-cached with 
cache hit", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "max-age=3600") return c.SendString("test") }) // First request to cache rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) // Second request with only-if-cached req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody) req.Header.Set("Cache-Control", "only-if-cached") rsp2, err := app.Test(req) require.NoError(t, err) require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache")) }) t.Run("cache control with uppercase directives", func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Expiration: 1 * time.Hour})) app.Get("/test", func(c fiber.Ctx) error { c.Response().Header.Set("Cache-Control", "PUBLIC, MAX-AGE=3600") return c.SendString("test") }) rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache")) rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache")) }) } // Test_Cache_ConfigurationAndResponseHandling tests cache behavior for specific configuration and response edge cases. 
func Test_Cache_ConfigurationAndResponseHandling(t *testing.T) {
	t.Parallel()

	t.Run("response with Vary star prevents caching", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			// A wildcard Vary is expected to make the response unstorable even with max-age set.
			c.Response().Header.Set("Vary", "*")
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			return c.SendString("test")
		})
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
	})

	t.Run("next function prevents caching", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			Expiration: 1 * time.Hour,
			Next: func(c fiber.Ctx) bool {
				return c.Path() == "/skip"
			},
		}))
		app.Get("/skip", func(c fiber.Ctx) error {
			return c.SendString("test")
		})
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/skip", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
	})

	t.Run("non-cacheable status code", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			// 201 Created is expected to be outside the set of cacheable statuses.
			return c.Status(fiber.StatusCreated).SendString("created")
		})
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
	})

	t.Run("body larger than MaxBytes", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			Expiration: 1 * time.Hour,
			MaxBytes:   10,
		}))
		app.Get("/test", func(c fiber.Ctx) error {
			// 100-byte body against a 10-byte budget: the entry can never fit.
			return c.Send(make([]byte, 100))
		})
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
	})

	t.Run("authorization without shared cache directives", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			// Plain max-age only; nothing here opts an authorized response into shared caching.
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			return c.SendString("test")
		})
		req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req.Header.Set("Authorization", "******")
		rsp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp.Header.Get("X-Cache"))
	})

	t.Run("disable cache control header generation", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			Expiration:          1 * time.Hour,
			DisableCacheControl: true,
		}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			return c.SendString("test")
		})
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
	})

	t.Run("disable value redaction", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			Expiration:            1 * time.Hour,
			DisableValueRedaction: true,
		}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			return c.SendString("test")
		})
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
	})

	t.Run("response with ETag header", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			c.Response().Header.Set("ETag", `"abc123"`)
			return c.SendString("test")
		})
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
		// The stored ETag must be replayed on the cached response.
		require.Equal(t, `"abc123"`, rsp2.Header.Get("ETag"))
	})

	t.Run("response with Content-Encoding header", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			c.Response().Header.Set("Content-Encoding", "gzip")
			return c.SendString("test")
		})
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
		require.Equal(t, "gzip", rsp2.Header.Get("Content-Encoding"))
	})

	t.Run("response with custom headers preserved", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{
			Expiration:           1 * time.Hour,
			StoreResponseHeaders: true,
		}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			c.Response().Header.Set("X-Custom-Header", "custom-value")
			return c.SendString("test")
		})
		rsp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheHit, rsp2.Header.Get("X-Cache"))
		// StoreResponseHeaders: true is what makes this non-standard header survive the cache.
		require.Equal(t, "custom-value", rsp2.Header.Get("X-Custom-Header"))
	})

	t.Run("revalidation scenario with cache miss", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			return c.SendString("test")
		})
		// Request with no-cache forces revalidation
		req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req.Header.Set("Cache-Control", "no-cache")
		rsp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
	})

	t.Run("delete vary manifest on no-cache response", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			// First request creates vary manifest
			if c.Query("first") != "" {
				c.Response().Header.Set("Vary", "Accept")
				c.Response().Header.Set("Cache-Control", "max-age=3600")
			} else {
				// Second request returns no-cache to delete manifest
				c.Response().Header.Set("Cache-Control", "no-cache")
			}
			return c.SendString("test")
		})
		req := httptest.NewRequest(fiber.MethodGet, "/test?first=true", http.NoBody)
		req.Header.Set("Accept", "application/json")
		rsp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
		// Second request without Vary should delete manifest
		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheUnreachable, rsp2.Header.Get("X-Cache"))
	})

	t.Run("vary manifest deletion on different vary response", func(t *testing.T) {
		t.Parallel()
		app := fiber.New()
		// counter lets the handler emit Vary only on its first invocation.
		var counter atomic.Int32
		app.Use(New(Config{Expiration: 1 * time.Hour}))
		app.Get("/test", func(c fiber.Ctx) error {
			if counter.Add(1) == 1 {
				c.Response().Header.Set("Vary", "Accept")
			}
			// Second response has no Vary header - should delete manifest
			c.Response().Header.Set("Cache-Control", "max-age=3600")
			return c.SendString("test")
		})
		req := httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
		req.Header.Set("Accept", "application/json")
		rsp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp.Header.Get("X-Cache"))
		// Second request - different vary behavior
		rsp2, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, cacheMiss, rsp2.Header.Get("X-Cache"))
	})
}
================================================ FILE: middleware/cache/config.go ================================================
package cache

import (
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/utils/v2"
)

// Config defines the config for middleware.
type Config struct {
	// Storage is used to store the state of the middleware.
	// When nil, the manager falls back to an in-process memory store.
	//
	// Default: an in-memory store for this process only
	Storage fiber.Storage

	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// CacheInvalidator defines a function to invalidate the cache when returned true.
	//
	// Optional. Default: nil
	CacheInvalidator func(fiber.Ctx) bool

	// KeyGenerator allows you to generate custom keys; by default c.Path() is used.
	//
	// Default: func(c fiber.Ctx) string {
	//   return utils.CopyString(c.Path())
	// }
	KeyGenerator func(fiber.Ctx) string

	// ExpirationGenerator allows you to generate a custom expiration per request.
	// If nil, the Expiration value is used.
	//
	// Default: nil
	ExpirationGenerator func(fiber.Ctx, *Config) time.Duration

	// CacheHeader is the response header used to report the cache status,
	// with one of the following values:
	//
	// hit, miss, unreachable
	//
	// Optional. Default: X-Cache
	CacheHeader string

	// Methods lists the HTTP methods whose responses may be cached.
	// The middleware only caches routes whose method appears in this slice.
	//
	// Default: []string{fiber.MethodGet, fiber.MethodHead}
	Methods []string

	// Expiration is the time that a cached response will live.
	// configDefault treats any value whose whole-second part is zero as unset.
	//
	// Optional. Default: 5 * time.Minute
	Expiration time.Duration

	// MaxBytes is the maximum number of bytes of response bodies simultaneously
	// stored in cache. When the limit is reached, entries with the nearest
	// expiration are deleted to make room for new ones.
	// 0 means no limit. configDefault does not fill this field in, so a zero
	// value in a caller-supplied Config stays unlimited; only the no-argument
	// path (which returns ConfigDefault as-is) gets the 1 MiB default.
	//
	// Optional. Default: 1 * 1024 * 1024
	MaxBytes uint

	// DisableValueRedaction turns off masking cache keys in logs and error messages when set to true.
	//
	// Optional. Default: false
	DisableValueRedaction bool

	// DisableCacheControl disables client side caching if set to true.
	//
	// Optional. Default: false
	DisableCacheControl bool

	// StoreResponseHeaders allows you to store additional headers generated by
	// next middlewares and handlers.
	//
	// Default: false
	StoreResponseHeaders bool
}

// ConfigDefault is the default config. configDefault returns it unchanged
// when no Config is supplied.
var ConfigDefault = Config{
	Next:                  nil,
	Expiration:            5 * time.Minute,
	CacheHeader:           "X-Cache",
	DisableCacheControl:   false,
	CacheInvalidator:      nil,
	DisableValueRedaction: false,
	KeyGenerator: func(c fiber.Ctx) string {
		// Copy so the key does not alias a buffer the framework may reuse after the request.
		return utils.CopyString(c.Path())
	},
	ExpirationGenerator:  nil,
	StoreResponseHeaders: false,
	Storage:              nil,
	MaxBytes:             1 * 1024 * 1024,
	Methods:              []string{fiber.MethodGet, fiber.MethodHead},
}

// configDefault returns ConfigDefault when called without arguments.
// Otherwise it takes the first Config (any further ones are ignored) and
// fills in Next, Expiration, CacheHeader, KeyGenerator and Methods from
// ConfigDefault when they are unset. All other fields, including MaxBytes,
// keep the caller's values.
func configDefault(config ...Config) Config {
	// Return default config if nothing provided
	if len(config) < 1 {
		return ConfigDefault
	}

	// Override default config
	cfg := config[0]

	// Set default values
	if cfg.Next == nil {
		cfg.Next = ConfigDefault.Next
	}
	// int() truncates toward zero, so every Expiration strictly between -1s and
	// +1s (not only exactly 0) is treated as unset and replaced by the default.
	if int(cfg.Expiration.Seconds()) == 0 {
		cfg.Expiration = ConfigDefault.Expiration
	}
	if cfg.CacheHeader == "" {
		cfg.CacheHeader = ConfigDefault.CacheHeader
	}
	if cfg.KeyGenerator == nil {
		cfg.KeyGenerator = ConfigDefault.KeyGenerator
	}
	if len(cfg.Methods) == 0 {
		cfg.Methods = ConfigDefault.Methods
	}
	return cfg
}
================================================ FILE: middleware/cache/heap.go ================================================
package cache

import (
	"container/heap"
)

// heapEntry is a single element stored in indexedHeap.
type heapEntry struct {
	key   string // key passed to put; returned again when the entry is removed
	exp   uint64 // expiration timestamp; the heap is ordered on this field
	bytes uint   // size passed to put; returned as size when the entry is removed
	idx   int    // stable index handed out by put, resolved through indices
}

// indexedHeap is a regular min-heap that allows finding
// elements in constant time. It does so by handing out special indices
// and tracking entry movement.
//
// indexedHeap is used for quickly finding entries with the lowest
// expiration timestamp and deleting arbitrary entries.
type indexedHeap struct { // Slice the heap is built on entries []heapEntry // Mapping "index" to position in heap slice indices []int // Max index handed out maxidx int } // Len implements heap.Interface by reporting the number of entries in the heap. func (h indexedHeap) Len() int { return len(h.entries) } // Less implements heap.Interface and orders entries by expiration time. func (h indexedHeap) Less(i, j int) bool { return h.entries[i].exp < h.entries[j].exp } // Swap implements heap.Interface and swaps the entries at the provided indices. func (h indexedHeap) Swap(i, j int) { h.entries[i], h.entries[j] = h.entries[j], h.entries[i] h.indices[h.entries[i].idx] = i h.indices[h.entries[j].idx] = j } // Push implements heap.Interface and inserts a new entry into the heap. func (h *indexedHeap) Push(x any) { h.pushInternal(x.(heapEntry)) //nolint:forcetypeassert,errcheck // Forced type assertion required to implement the heap.Interface interface } // Pop implements heap.Interface and removes the last entry from the heap. 
func (h *indexedHeap) Pop() any { n := len(h.entries) h.entries = h.entries[0 : n-1] return h.entries[0:n][n-1] } func (h *indexedHeap) pushInternal(entry heapEntry) { h.indices[entry.idx] = len(h.entries) h.entries = append(h.entries, entry) } // Returns index to track entry func (h *indexedHeap) put(key string, exp uint64, bytes uint) int { idx := 0 if len(h.entries) < h.maxidx { // Steal index from previously removed entry // capacity > size is guaranteed n := len(h.entries) idx = h.entries[:n+1][n].idx } else { idx = h.maxidx h.maxidx++ h.indices = append(h.indices, idx) } // Push manually to avoid allocation h.pushInternal(heapEntry{ key: key, exp: exp, idx: idx, bytes: bytes, }) heap.Fix(h, h.Len()-1) return idx } func (h *indexedHeap) removeInternal(realIdx int) (key string, size uint) { //nolint:nonamedreturns // gocritic unnamedResult prefers named key and size when removing heap entries x := heap.Remove(h, realIdx).(heapEntry) //nolint:forcetypeassert,errcheck // Forced type assertion required to implement the heap.Interface interface return x.key, x.bytes } // Remove entry by index func (h *indexedHeap) remove(idx int) (key string, size uint) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned key and size pair return h.removeInternal(h.indices[idx]) } // Remove entry with lowest expiration time func (h *indexedHeap) removeFirst() (key string, size uint) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned key and size pair return h.removeInternal(0) } ================================================ FILE: middleware/cache/manager.go ================================================ package cache import ( "context" "errors" "fmt" "sync" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/memory" ) // msgp -file="manager.go" -o="manager_msgp.go" -tests=true -unexported // Default slice limits are sized for cache payloads, with tighter field caps below. 
// //go:generate msgp -o=manager_msgp.go -tests=true -unexported //nolint:revive // msgp requires tags on unexported fields for limit enforcement. type item struct { headers []cachedHeader `msg:",limit=1024"` // Typical HTTP header count stays well below this. body []byte // Cache bodies are bounded by storage policy, not msgp limits. ctype []byte `msg:",limit=256"` // Content-Type values are short per RFCs. cencoding []byte `msg:",limit=128"` // Content-Encoding is typically a short token. cacheControl []byte `msg:",limit=2048"` // Cache-Control directives are bounded. expires []byte `msg:",limit=128"` // Expires is a short HTTP-date string. etag []byte `msg:",limit=256"` // ETags are small tokens/quoted strings. date uint64 status int age uint64 exp uint64 ttl uint64 forceRevalidate bool revalidate bool shareable bool private bool // used for finding the item in an indexed heap heapidx int } //nolint:revive // msgp requires tags on unexported fields for limit enforcement. type cachedHeader struct { key []byte `msg:",limit=512"` // Header names are small. value []byte `msg:",limit=16384"` // Header values are bounded to reasonable sizes. 
} //msgp:ignore manager type manager struct { pool sync.Pool memory *memory.Storage storage fiber.Storage redactKeys bool } const redactedKey = "[redacted]" var errCacheMiss = errors.New("cache: miss") func newManager(storage fiber.Storage, redactKeys bool) *manager { // Create new storage handler manager := &manager{ pool: sync.Pool{ New: func() any { return new(item) }, }, redactKeys: redactKeys, } if storage != nil { // Use provided storage if provided manager.storage = storage } else { // Fallback to memory storage manager.memory = memory.New() } return manager } // acquire returns an *entry from the sync.Pool func (m *manager) acquire() *item { return m.pool.Get().(*item) //nolint:forcetypeassert,errcheck // We store nothing else in the pool } // release and reset *entry to sync.Pool func (m *manager) release(e *item) { // don't release item if we using in-memory storage if m.storage == nil { return } e.body = nil e.cacheControl = nil e.expires = nil e.etag = nil e.ctype = nil e.cencoding = nil e.date = 0 e.status = 0 e.age = 0 e.exp = 0 e.ttl = 0 e.forceRevalidate = false e.revalidate = false e.headers = nil e.shareable = false e.private = false e.heapidx = 0 m.pool.Put(e) } // get data from storage or memory func (m *manager) get(ctx context.Context, key string) (*item, error) { if m.storage != nil { raw, err := m.storage.GetWithContext(ctx, key) if err != nil { return nil, fmt.Errorf("cache: failed to get key %q from storage: %w", m.logKey(key), err) } if raw == nil { return nil, errCacheMiss } it := m.acquire() if _, err := it.UnmarshalMsg(raw); err != nil { m.release(it) return nil, fmt.Errorf("cache: failed to unmarshal key %q: %w", m.logKey(key), err) } return it, nil } if value := m.memory.Get(key); value != nil { it, ok := value.(*item) if !ok { return nil, fmt.Errorf("cache: unexpected entry type %T for key %q", value, m.logKey(key)) } return it, nil } return nil, errCacheMiss } // get raw data from storage or memory func (m *manager) getRaw(ctx 
context.Context, key string) ([]byte, error) { if m.storage != nil { raw, err := m.storage.GetWithContext(ctx, key) if err != nil { return nil, fmt.Errorf("cache: failed to get raw key %q from storage: %w", m.logKey(key), err) } if raw == nil { return nil, errCacheMiss } return raw, nil } if value := m.memory.Get(key); value != nil { raw, ok := value.([]byte) if !ok { return nil, fmt.Errorf("cache: unexpected raw entry type %T for key %q", value, m.logKey(key)) } return raw, nil } return nil, errCacheMiss } // set data to storage or memory func (m *manager) set(ctx context.Context, key string, it *item, exp time.Duration) error { if m.storage != nil { raw, err := it.MarshalMsg(nil) if err != nil { m.release(it) return fmt.Errorf("cache: failed to marshal key %q: %w", m.logKey(key), err) } if err := m.storage.SetWithContext(ctx, key, raw, exp); err != nil { m.release(it) return fmt.Errorf("cache: failed to store key %q: %w", m.logKey(key), err) } m.release(it) return nil } m.memory.Set(key, it, exp) return nil } // set data to storage or memory func (m *manager) setRaw(ctx context.Context, key string, raw []byte, exp time.Duration) error { if m.storage != nil { if err := m.storage.SetWithContext(ctx, key, raw, exp); err != nil { return fmt.Errorf("cache: failed to store raw key %q: %w", m.logKey(key), err) } return nil } m.memory.Set(key, raw, exp) return nil } // delete data from storage or memory func (m *manager) del(ctx context.Context, key string) error { if m.storage != nil { if err := m.storage.DeleteWithContext(ctx, key); err != nil { return fmt.Errorf("cache: failed to delete key %q: %w", m.logKey(key), err) } return nil } m.memory.Delete(key) return nil } func (m *manager) logKey(key string) string { if m.redactKeys { return redactedKey } return key } ================================================ FILE: middleware/cache/manager_msgp.go ================================================ // Code generated by github.com/tinylib/msgp DO NOT EDIT. 
package cache

import (
	"github.com/tinylib/msgp/msgp"
)

// NOTE(review): produced by tinylib/msgp from the //go:generate directive on
// item in manager.go; change the struct definitions or msg tags there and
// regenerate instead of editing this code by hand.

// DecodeMsg implements msgp.Decodable
func (z *cachedHeader) DecodeMsg(dc *msgp.Reader) (err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, err = dc.ReadMapHeader()
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, err = dc.ReadMapKeyPtr()
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		case "key":
			z.key, err = dc.ReadBytesLimit(z.key, 512)
			if err == nil && z.key == nil {
				z.key = []byte{}
			}
			if err != nil {
				err = msgp.WrapError(err, "key")
				return
			}
		case "value":
			z.value, err = dc.ReadBytesLimit(z.value, 16384)
			if err == nil && z.value == nil {
				z.value = []byte{}
			}
			if err != nil {
				err = msgp.WrapError(err, "value")
				return
			}
		default:
			err = dc.Skip()
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	return
}

// EncodeMsg implements msgp.Encodable
func (z *cachedHeader) EncodeMsg(en *msgp.Writer) (err error) {
	// map header, size 2
	// write "key"
	err = en.Append(0x82, 0xa3, 0x6b, 0x65, 0x79)
	if err != nil {
		return
	}
	err = en.WriteBytes(z.key)
	if err != nil {
		err = msgp.WrapError(err, "key")
		return
	}
	// write "value"
	err = en.Append(0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
	if err != nil {
		return
	}
	err = en.WriteBytes(z.value)
	if err != nil {
		err = msgp.WrapError(err, "value")
		return
	}
	return
}

// MarshalMsg implements msgp.Marshaler
func (z *cachedHeader) MarshalMsg(b []byte) (o []byte, err error) {
	o = msgp.Require(b, z.Msgsize())
	// map header, size 2
	// string "key"
	o = append(o, 0x82, 0xa3, 0x6b, 0x65, 0x79)
	o = msgp.AppendBytes(o, z.key)
	// string "value"
	o = append(o, 0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
	o = msgp.AppendBytes(o, z.value)
	return
}

// UnmarshalMsg implements msgp.Unmarshaler
func (z *cachedHeader) UnmarshalMsg(bts []byte) (o []byte, err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, bts, err = msgp.ReadMapKeyZC(bts)
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		case "key":
			var zb0002 uint32
			zb0002, bts, err = msgp.ReadBytesHeader(bts)
			if err != nil {
				err = msgp.WrapError(err, "key")
				return
			}
			if zb0002 > 512 {
				err = msgp.ErrLimitExceeded
				return
			}
			if z.key == nil || uint32(cap(z.key)) < zb0002 {
				z.key = make([]byte, zb0002)
			} else {
				z.key = z.key[:zb0002]
			}
			if uint32(len(bts)) < zb0002 {
				err = msgp.ErrShortBytes
				return
			}
			copy(z.key, bts[:zb0002])
			bts = bts[zb0002:]
		case "value":
			var zb0003 uint32
			zb0003, bts, err = msgp.ReadBytesHeader(bts)
			if err != nil {
				err = msgp.WrapError(err, "value")
				return
			}
			if zb0003 > 16384 {
				err = msgp.ErrLimitExceeded
				return
			}
			if z.value == nil || uint32(cap(z.value)) < zb0003 {
				z.value = make([]byte, zb0003)
			} else {
				z.value = z.value[:zb0003]
			}
			if uint32(len(bts)) < zb0003 {
				err = msgp.ErrShortBytes
				return
			}
			copy(z.value, bts[:zb0003])
			bts = bts[zb0003:]
		default:
			bts, err = msgp.Skip(bts)
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	o = bts
	return
}

// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *cachedHeader) Msgsize() (s int) {
	s = 1 + 4 + msgp.BytesPrefixSize + len(z.key) + 6 + msgp.BytesPrefixSize + len(z.value)
	return
}

// DecodeMsg implements msgp.Decodable
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, err = dc.ReadMapHeader()
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, err = dc.ReadMapKeyPtr()
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		case "headers":
			var zb0002 uint32
			zb0002, err = dc.ReadArrayHeader()
			if err != nil {
				err = msgp.WrapError(err, "headers")
				return
			}
			if zb0002 > 1024 {
				err = msgp.ErrLimitExceeded
				return
			}
			if cap(z.headers) >= int(zb0002) {
				z.headers = (z.headers)[:zb0002]
			} else {
				z.headers = make([]cachedHeader, zb0002)
			}
			for za0001 := range z.headers {
				var zb0003 uint32
				zb0003, err = dc.ReadMapHeader()
				if err != nil {
					err = msgp.WrapError(err, "headers", za0001)
					return
				}
				if zb0003 > 1024 {
					err = msgp.ErrLimitExceeded
					return
				}
				for zb0003 > 0 {
					zb0003--
					field, err = dc.ReadMapKeyPtr()
					if err != nil {
						err = msgp.WrapError(err, "headers", za0001)
						return
					}
					switch msgp.UnsafeString(field) {
					case "key":
						z.headers[za0001].key, err = dc.ReadBytesLimit(z.headers[za0001].key, 512)
						if err == nil && z.headers[za0001].key == nil {
							z.headers[za0001].key = []byte{}
						}
						if err != nil {
							err = msgp.WrapError(err, "headers", za0001, "key")
							return
						}
					case "value":
						z.headers[za0001].value, err = dc.ReadBytesLimit(z.headers[za0001].value, 16384)
						if err == nil && z.headers[za0001].value == nil {
							z.headers[za0001].value = []byte{}
						}
						if err != nil {
							err = msgp.WrapError(err, "headers", za0001, "value")
							return
						}
					default:
						err = dc.Skip()
						if err != nil {
							err = msgp.WrapError(err, "headers", za0001)
							return
						}
					}
				}
			}
		case "body":
			z.body, err = dc.ReadBytes(z.body)
			if err != nil {
				err = msgp.WrapError(err, "body")
				return
			}
		case "ctype":
			z.ctype, err = dc.ReadBytesLimit(z.ctype, 256)
			if err == nil && z.ctype == nil {
				z.ctype = []byte{}
			}
			if err != nil {
				err = msgp.WrapError(err, "ctype")
				return
			}
		case "cencoding":
			z.cencoding, err = dc.ReadBytesLimit(z.cencoding, 128)
			if err == nil && z.cencoding == nil {
				z.cencoding = []byte{}
			}
			if err != nil {
				err = msgp.WrapError(err, "cencoding")
				return
			}
		case "cacheControl":
			z.cacheControl, err = dc.ReadBytesLimit(z.cacheControl, 2048)
			if err == nil && z.cacheControl == nil {
				z.cacheControl = []byte{}
			}
			if err != nil {
				err = msgp.WrapError(err, "cacheControl")
				return
			}
		case "expires":
			z.expires, err = dc.ReadBytesLimit(z.expires, 128)
			if err == nil && z.expires == nil {
				z.expires = []byte{}
			}
			if err != nil {
				err = msgp.WrapError(err, "expires")
				return
			}
		case "etag":
			z.etag, err = dc.ReadBytesLimit(z.etag, 256)
			if err == nil && z.etag == nil {
				z.etag = []byte{}
			}
			if err != nil {
				err = msgp.WrapError(err, "etag")
				return
			}
		case "date":
			z.date, err = dc.ReadUint64()
			if err != nil {
				err = msgp.WrapError(err, "date")
				return
			}
		case "status":
			z.status, err = dc.ReadInt()
			if err != nil {
				err = msgp.WrapError(err, "status")
				return
			}
		case "age":
			z.age, err = dc.ReadUint64()
			if err != nil {
				err = msgp.WrapError(err, "age")
				return
			}
		case "exp":
			z.exp, err = dc.ReadUint64()
			if err != nil {
				err = msgp.WrapError(err, "exp")
				return
			}
		case "ttl":
			z.ttl, err = dc.ReadUint64()
			if err != nil {
				err = msgp.WrapError(err, "ttl")
				return
			}
		case "forceRevalidate":
			z.forceRevalidate, err = dc.ReadBool()
			if err != nil {
				err = msgp.WrapError(err, "forceRevalidate")
				return
			}
		case "revalidate":
			z.revalidate, err = dc.ReadBool()
			if err != nil {
				err = msgp.WrapError(err, "revalidate")
				return
			}
		case "shareable":
			z.shareable, err = dc.ReadBool()
			if err != nil {
				err = msgp.WrapError(err, "shareable")
				return
			}
		case "private":
			z.private, err = dc.ReadBool()
			if err != nil {
				err = msgp.WrapError(err, "private")
				return
			}
		case "heapidx":
			z.heapidx, err = dc.ReadInt()
			if err != nil {
				err = msgp.WrapError(err, "heapidx")
				return
			}
		default:
			err = dc.Skip()
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	return
}

// EncodeMsg implements msgp.Encodable
func (z *item) EncodeMsg(en *msgp.Writer) (err error) {
	// map header, size 17
	// write "headers"
	err = en.Append(0xde, 0x0, 0x11, 0xa7, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73)
	if err != nil {
		return
	}
	err = en.WriteArrayHeader(uint32(len(z.headers)))
	if err != nil {
		err = msgp.WrapError(err, "headers")
		return
	}
	for za0001 := range z.headers {
		// map header, size 2
		// write "key"
		err = en.Append(0x82, 0xa3, 0x6b, 0x65, 0x79)
		if err != nil {
			return
		}
		err = en.WriteBytes(z.headers[za0001].key)
		if err != nil {
			err = msgp.WrapError(err, "headers", za0001, "key")
			return
		}
		// write "value"
		err = en.Append(0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
		if err != nil {
			return
		}
		err = en.WriteBytes(z.headers[za0001].value)
		if err != nil {
			err = msgp.WrapError(err, "headers", za0001, "value")
			return
		}
	}
	// write "body"
	err = en.Append(0xa4, 0x62, 0x6f, 0x64, 0x79)
	if err != nil {
		return
	}
	err = en.WriteBytes(z.body)
	if err != nil {
		err = msgp.WrapError(err, "body")
		return
	}
	// write "ctype"
	err = en.Append(0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
	if err != nil {
		return
	}
	err = en.WriteBytes(z.ctype)
	if err != nil {
		err = msgp.WrapError(err, "ctype")
		return
	}
	// write "cencoding"
	err = en.Append(0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
	if err != nil {
		return
	}
	err = en.WriteBytes(z.cencoding)
	if err != nil {
		err = msgp.WrapError(err, "cencoding")
		return
	}
	// write "cacheControl"
	err = en.Append(0xac, 0x63, 0x61, 0x63, 0x68, 0x65, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c)
	if err != nil {
		return
	}
	err = en.WriteBytes(z.cacheControl)
	if err != nil {
		err = msgp.WrapError(err, "cacheControl")
		return
	}
	// write "expires"
	err = en.Append(0xa7, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73)
	if err != nil {
		return
	}
	err = en.WriteBytes(z.expires)
	if err != nil {
		err = msgp.WrapError(err, "expires")
		return
	}
	// write "etag"
	err = en.Append(0xa4, 0x65, 0x74, 0x61, 0x67)
	if err != nil {
		return
	}
	err = en.WriteBytes(z.etag)
	if err != nil {
		err = msgp.WrapError(err, "etag")
		return
	}
	// write "date"
	err = en.Append(0xa4, 0x64, 0x61, 0x74, 0x65)
	if err != nil {
		return
	}
	err = en.WriteUint64(z.date)
	if err != nil {
		err = msgp.WrapError(err, "date")
		return
	}
	// write "status"
	err = en.Append(0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
	if err != nil {
		return
	}
	err = en.WriteInt(z.status)
	if err != nil {
		err = msgp.WrapError(err, "status")
		return
	}
	// write "age"
	err = en.Append(0xa3, 0x61, 0x67, 0x65)
	if err != nil {
		return
	}
	err = en.WriteUint64(z.age)
	if err != nil {
		err = msgp.WrapError(err, "age")
		return
	}
	// write "exp"
	err = en.Append(0xa3, 0x65, 0x78, 0x70)
	if err != nil {
		return
	}
	err = en.WriteUint64(z.exp)
	if err != nil {
		err = msgp.WrapError(err, "exp")
		return
	}
	// write "ttl"
	err = en.Append(0xa3, 0x74, 0x74, 0x6c)
	if err != nil {
		return
	}
	err = en.WriteUint64(z.ttl)
	if err != nil {
		err = msgp.WrapError(err, "ttl")
		return
	}
	// write "forceRevalidate"
	err = en.Append(0xaf, 0x66, 0x6f, 0x72, 0x63, 0x65, 0x52, 0x65, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65)
	if err != nil {
		return
	}
	err = en.WriteBool(z.forceRevalidate)
	if err != nil {
		err = msgp.WrapError(err, "forceRevalidate")
		return
	}
	// write "revalidate"
	err = en.Append(0xaa, 0x72, 0x65, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65)
	if err != nil {
		return
	}
	err = en.WriteBool(z.revalidate)
	if err != nil {
		err = msgp.WrapError(err, "revalidate")
		return
	}
	// write "shareable"
	err = en.Append(0xa9, 0x73, 0x68, 0x61, 0x72, 0x65, 0x61, 0x62, 0x6c, 0x65)
	if err != nil {
		return
	}
	err = en.WriteBool(z.shareable)
	if err != nil {
		err = msgp.WrapError(err, "shareable")
		return
	}
	// write "private"
	err = en.Append(0xa7, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65)
	if err != nil {
		return
	}
	err = en.WriteBool(z.private)
	if err != nil {
		err = msgp.WrapError(err, "private")
		return
	}
	// write "heapidx"
	err = en.Append(0xa7, 0x68, 0x65, 0x61, 0x70, 0x69, 0x64, 0x78)
	if err != nil {
		return
	}
	err = en.WriteInt(z.heapidx)
	if err != nil {
		err = msgp.WrapError(err, "heapidx")
		return
	}
	return
}

// MarshalMsg implements msgp.Marshaler
func (z *item) MarshalMsg(b []byte) (o []byte, err error) {
	o = msgp.Require(b, z.Msgsize())
	// map header, size 17
	// string "headers"
	o = append(o, 0xde, 0x0, 0x11, 0xa7, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73)
	o = msgp.AppendArrayHeader(o, uint32(len(z.headers)))
	for za0001 := range z.headers {
		// map header, size 2
		// string "key"
		o = append(o, 0x82, 0xa3, 0x6b, 0x65, 0x79)
		o = msgp.AppendBytes(o, z.headers[za0001].key)
		// string "value"
		o = append(o, 0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
		o = msgp.AppendBytes(o, z.headers[za0001].value)
	}
	// string "body"
	o = append(o, 0xa4, 0x62, 0x6f, 0x64, 0x79)
	o = msgp.AppendBytes(o, z.body)
	// string "ctype"
	o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
	o = msgp.AppendBytes(o, z.ctype)
	// string "cencoding"
	o = append(o, 0xa9, 0x63, 0x65, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67)
	o = msgp.AppendBytes(o, z.cencoding)
	// string "cacheControl"
	o = append(o, 0xac, 0x63, 0x61, 0x63, 0x68, 0x65, 0x43, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c)
	o = msgp.AppendBytes(o, z.cacheControl)
	// string "expires"
	o = append(o, 0xa7, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73)
	o = msgp.AppendBytes(o, z.expires)
	// string "etag"
	o = append(o, 0xa4, 0x65, 0x74, 0x61, 0x67)
	o = msgp.AppendBytes(o, z.etag)
	// string "date"
	o = append(o, 0xa4, 0x64, 0x61, 0x74, 0x65)
	o = msgp.AppendUint64(o, z.date)
	// string "status"
	o = append(o, 0xa6, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73)
	o = msgp.AppendInt(o, z.status)
	// string "age"
	o = append(o, 0xa3, 0x61, 0x67, 0x65)
	o = msgp.AppendUint64(o, z.age)
	// string "exp"
	o = append(o, 0xa3, 0x65, 0x78, 0x70)
	o = msgp.AppendUint64(o, z.exp)
	// string "ttl"
	o = append(o, 0xa3, 0x74, 0x74, 0x6c)
	o = msgp.AppendUint64(o, z.ttl)
	// string "forceRevalidate"
	o = append(o, 0xaf, 0x66, 0x6f, 0x72, 0x63, 0x65, 0x52, 0x65, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65)
	o = msgp.AppendBool(o, z.forceRevalidate)
	// string "revalidate"
	o = append(o, 0xaa, 0x72, 0x65, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65)
	o = msgp.AppendBool(o, z.revalidate)
	// string "shareable"
	o = append(o, 0xa9, 0x73, 0x68, 0x61, 0x72, 0x65, 0x61, 0x62, 0x6c, 0x65)
	o = msgp.AppendBool(o, z.shareable)
	// string "private"
	o = append(o, 0xa7, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65)
	o = msgp.AppendBool(o, z.private)
	// string "heapidx"
	o = append(o, 0xa7, 0x68, 0x65, 0x61, 0x70, 0x69, 0x64, 0x78)
	o = msgp.AppendInt(o, z.heapidx)
	return
}

// UnmarshalMsg implements msgp.Unmarshaler
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, bts, err = msgp.ReadMapKeyZC(bts)
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		case "headers":
			var zb0002 uint32
			zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "headers")
				return
			}
			if zb0002 > 1024 {
				err = msgp.ErrLimitExceeded
				return
			}
			if cap(z.headers) >= int(zb0002) {
				z.headers = (z.headers)[:zb0002]
			} else {
				z.headers = make([]cachedHeader, zb0002)
			}
			for za0001 := range z.headers {
				var zb0003 uint32
				zb0003, bts, err = msgp.ReadMapHeaderBytes(bts)
				if err != nil {
					err = msgp.WrapError(err, "headers", za0001)
					return
				}
				if zb0003 > 1024 {
					err = msgp.ErrLimitExceeded
					return
				}
				for zb0003 > 0 {
					zb0003--
					field, bts, err = msgp.ReadMapKeyZC(bts)
					if err != nil {
						err = msgp.WrapError(err, "headers", za0001)
						return
					}
					switch msgp.UnsafeString(field) {
					case "key":
						var zb0004 uint32
						zb0004, bts, err = msgp.ReadBytesHeader(bts)
						if err != nil {
							err = msgp.WrapError(err, "headers", za0001, "key")
							return
						}
						if zb0004 > 512 {
							err = msgp.ErrLimitExceeded
							return
						}
						if z.headers[za0001].key == nil || uint32(cap(z.headers[za0001].key)) < zb0004 {
							z.headers[za0001].key = make([]byte, zb0004)
						} else {
							z.headers[za0001].key = z.headers[za0001].key[:zb0004]
						}
						if uint32(len(bts)) < zb0004 {
							err = msgp.ErrShortBytes
							return
						}
						copy(z.headers[za0001].key, bts[:zb0004])
						bts = bts[zb0004:]
					case "value":
						var zb0005 uint32
						zb0005, bts, err = msgp.ReadBytesHeader(bts)
						if err != nil {
							err = msgp.WrapError(err, "headers", za0001, "value")
							return
						}
						if zb0005 > 16384 {
							err = msgp.ErrLimitExceeded
							return
						}
						if z.headers[za0001].value == nil || uint32(cap(z.headers[za0001].value)) < zb0005 {
							z.headers[za0001].value = make([]byte, zb0005)
						} else {
							z.headers[za0001].value = z.headers[za0001].value[:zb0005]
						}
						if uint32(len(bts)) < zb0005 {
							err = msgp.ErrShortBytes
							return
						}
						copy(z.headers[za0001].value, bts[:zb0005])
						bts = bts[zb0005:]
					default:
						bts, err = msgp.Skip(bts)
						if err != nil {
							err = msgp.WrapError(err, "headers", za0001)
							return
						}
					}
				}
			}
		case "body":
			z.body, bts, err = msgp.ReadBytesBytes(bts, z.body)
			if err != nil {
				err = msgp.WrapError(err, "body")
				return
			}
		case "ctype":
			var zb0006 uint32
			zb0006, bts, err = msgp.ReadBytesHeader(bts)
			if err != nil {
				err = msgp.WrapError(err, "ctype")
				return
			}
			if zb0006 > 256 {
				err = msgp.ErrLimitExceeded
				return
			}
			if z.ctype == nil || uint32(cap(z.ctype)) < zb0006 {
				z.ctype = make([]byte, zb0006)
			} else {
				z.ctype = z.ctype[:zb0006]
			}
			if uint32(len(bts)) < zb0006 {
				err = msgp.ErrShortBytes
				return
			}
			copy(z.ctype, bts[:zb0006])
			bts = bts[zb0006:]
		case "cencoding":
			var zb0007 uint32
			zb0007, bts, err = msgp.ReadBytesHeader(bts)
			if err != nil {
				err = msgp.WrapError(err, "cencoding")
				return
			}
			if zb0007 > 128 {
				err = msgp.ErrLimitExceeded
				return
			}
			if z.cencoding == nil || uint32(cap(z.cencoding)) < zb0007 {
				z.cencoding = make([]byte, zb0007)
			} else {
				z.cencoding = z.cencoding[:zb0007]
			}
			if uint32(len(bts)) < zb0007 {
				err = msgp.ErrShortBytes
				return
			}
			copy(z.cencoding, bts[:zb0007])
			bts = bts[zb0007:]
		case "cacheControl":
			var zb0008 uint32
			zb0008, bts, err = msgp.ReadBytesHeader(bts)
			if err != nil {
				err = msgp.WrapError(err, "cacheControl")
				return
			}
			if zb0008 > 2048 {
				err = msgp.ErrLimitExceeded
				return
			}
			if z.cacheControl == nil || uint32(cap(z.cacheControl)) < zb0008 {
				z.cacheControl = make([]byte, zb0008)
			} else {
				z.cacheControl = z.cacheControl[:zb0008]
			}
			if uint32(len(bts)) < zb0008 {
				err = msgp.ErrShortBytes
				return
			}
			copy(z.cacheControl, bts[:zb0008])
			bts = bts[zb0008:]
		case "expires":
			var zb0009 uint32
			zb0009, bts, err = msgp.ReadBytesHeader(bts)
			if err != nil {
				err = msgp.WrapError(err, "expires")
				return
			}
			if zb0009 > 128 {
				err = msgp.ErrLimitExceeded
				return
			}
			if z.expires == nil || uint32(cap(z.expires)) < zb0009 {
				z.expires = make([]byte, zb0009)
			} else {
				z.expires = z.expires[:zb0009]
			}
			if uint32(len(bts)) < zb0009 {
				err = msgp.ErrShortBytes
				return
			}
			copy(z.expires, bts[:zb0009])
			bts = bts[zb0009:]
		case "etag":
			var zb0010 uint32
			zb0010, bts, err = msgp.ReadBytesHeader(bts)
			if err != nil {
				err = msgp.WrapError(err, "etag")
				return
			}
			if zb0010 > 256 {
				err = msgp.ErrLimitExceeded
				return
			}
			if z.etag == nil || uint32(cap(z.etag)) < zb0010 {
				z.etag = make([]byte, zb0010)
			} else {
				z.etag = z.etag[:zb0010]
			}
			if uint32(len(bts)) < zb0010 {
				err = msgp.ErrShortBytes
				return
			}
			copy(z.etag, bts[:zb0010])
			bts = bts[zb0010:]
		case "date":
			z.date, bts, err = msgp.ReadUint64Bytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "date")
				return
			}
		case "status":
			z.status, bts, err = msgp.ReadIntBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "status")
				return
			}
		case "age":
			z.age, bts, err = msgp.ReadUint64Bytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "age")
				return
			}
		case "exp":
			z.exp, bts, err = msgp.ReadUint64Bytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "exp")
				return
			}
		case "ttl":
			z.ttl, bts, err = msgp.ReadUint64Bytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "ttl")
				return
			}
		case "forceRevalidate":
			z.forceRevalidate, bts, err = msgp.ReadBoolBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "forceRevalidate")
				return
			}
		case "revalidate":
			z.revalidate, bts, err = msgp.ReadBoolBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "revalidate")
				return
			}
		case "shareable":
			z.shareable, bts, err = msgp.ReadBoolBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "shareable")
				return
			}
		case "private":
			z.private, bts, err = msgp.ReadBoolBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "private")
				return
			}
		case "heapidx":
			z.heapidx, bts, err = msgp.ReadIntBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "heapidx")
				return
			}
		default:
			bts, err = msgp.Skip(bts)
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	o = bts
	return
}

// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *item) Msgsize() (s int) {
	s = 3 + 8 + msgp.ArrayHeaderSize
	for za0001 := range z.headers {
		s += 1 + 4 + msgp.BytesPrefixSize + len(z.headers[za0001].key) + 6 + msgp.BytesPrefixSize + len(z.headers[za0001].value)
	}
	s += 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 13 + msgp.BytesPrefixSize + len(z.cacheControl) + 8 + msgp.BytesPrefixSize + len(z.expires) + 5 + msgp.BytesPrefixSize + len(z.etag) + 5 + msgp.Uint64Size + 7 + msgp.IntSize + 4 + msgp.Uint64Size + 4 + msgp.Uint64Size + 4 + msgp.Uint64Size + 16 + msgp.BoolSize + 11 + msgp.BoolSize + 10 + msgp.BoolSize + 8 + msgp.BoolSize + 8 + msgp.IntSize
	return
}

================================================
FILE: middleware/cache/manager_msgp_test.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.

package cache

import (
	"bytes"
	"testing"

	"github.com/tinylib/msgp/msgp"
)

// NOTE(review): generated round-trip tests and benchmarks for the msgp
// methods of cachedHeader and item; regenerate together with manager_msgp.go
// rather than editing by hand.

func TestMarshalUnmarshalcachedHeader(t *testing.T) {
	v := cachedHeader{}
	bts, err := v.MarshalMsg(nil)
	if err != nil {
		t.Fatal(err)
	}
	left, err := v.UnmarshalMsg(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
	}

	left, err = msgp.Skip(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
	}
}

func BenchmarkMarshalMsgcachedHeader(b *testing.B) {
	v := cachedHeader{}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.MarshalMsg(nil)
	}
}

func BenchmarkAppendMsgcachedHeader(b *testing.B) {
	v := cachedHeader{}
	bts := make([]byte, 0, v.Msgsize())
	bts, _ = v.MarshalMsg(bts[0:0])
	b.SetBytes(int64(len(bts)))
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		bts, _ = v.MarshalMsg(bts[0:0])
	}
}

func BenchmarkUnmarshalcachedHeader(b *testing.B) {
	v := cachedHeader{}
	bts, _ := v.MarshalMsg(nil)
	b.ReportAllocs()
	b.SetBytes(int64(len(bts)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := v.UnmarshalMsg(bts)
		if err != nil {
			b.Fatal(err)
		}
	}
}

func TestEncodeDecodecachedHeader(t *testing.T) {
	v := cachedHeader{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)

	m := v.Msgsize()
	if buf.Len() > m {
		t.Log("WARNING: TestEncodeDecodecachedHeader Msgsize() is inaccurate")
	}

	vn := cachedHeader{}
	err := msgp.Decode(&buf, &vn)
	if err != nil {
		t.Error(err)
	}

	buf.Reset()
	msgp.Encode(&buf, &v)
	err = msgp.NewReader(&buf).Skip()
	if err != nil {
		t.Error(err)
	}
}

func BenchmarkEncodecachedHeader(b *testing.B) {
	v := cachedHeader{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	en := msgp.NewWriter(msgp.Nowhere)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.EncodeMsg(en)
	}
	en.Flush()
}

func BenchmarkDecodecachedHeader(b *testing.B) {
	v := cachedHeader{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	rd := msgp.NewEndlessReader(buf.Bytes(), b)
	dc := msgp.NewReader(rd)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err := v.DecodeMsg(dc)
		if err != nil {
			b.Fatal(err)
		}
	}
}

func TestMarshalUnmarshalitem(t *testing.T) {
	v := item{}
	bts, err := v.MarshalMsg(nil)
	if err != nil {
		t.Fatal(err)
	}
	left, err := v.UnmarshalMsg(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
	}

	left, err = msgp.Skip(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
	}
}

func BenchmarkMarshalMsgitem(b *testing.B) {
	v := item{}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.MarshalMsg(nil)
	}
}

func BenchmarkAppendMsgitem(b *testing.B) {
	v := item{}
	bts := make([]byte, 0, v.Msgsize())
	bts, _ = v.MarshalMsg(bts[0:0])
	b.SetBytes(int64(len(bts)))
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		bts, _ = v.MarshalMsg(bts[0:0])
	}
}

func BenchmarkUnmarshalitem(b *testing.B) {
	v := item{}
	bts, _ := v.MarshalMsg(nil)
	b.ReportAllocs()
	b.SetBytes(int64(len(bts)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := v.UnmarshalMsg(bts)
		if err != nil {
			b.Fatal(err)
		}
	}
}

func TestEncodeDecodeitem(t *testing.T) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)

	m := v.Msgsize()
	if buf.Len() > m {
		t.Log("WARNING: TestEncodeDecodeitem Msgsize() is inaccurate")
	}

	vn := item{}
	err := msgp.Decode(&buf, &vn)
	if err != nil {
		t.Error(err)
	}

	buf.Reset()
	msgp.Encode(&buf, &v)
	err = msgp.NewReader(&buf).Skip()
	if err != nil {
		t.Error(err)
	}
}

func BenchmarkEncodeitem(b *testing.B) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	en := msgp.NewWriter(msgp.Nowhere)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.EncodeMsg(en)
	}
	en.Flush()
}

func BenchmarkDecodeitem(b *testing.B) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	rd := msgp.NewEndlessReader(buf.Bytes(), b)
	dc := msgp.NewReader(rd)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err := v.DecodeMsg(dc)
		if err != nil {
			b.Fatal(err)
		}
	}
}

================================================
FILE: middleware/cache/manager_test.go
================================================
package cache

import (
	"context"
	"testing"
	"time"

	"github.com/gofiber/utils/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Test_manager_get uses the in-memory backend (nil storage) and checks that
// get reports errCacheMiss for an unknown key and returns an item after set.
func Test_manager_get(t *testing.T) {
	t.Parallel()
	cacheManager := newManager(nil, true)

	t.Run("Item not found in cache", func(t *testing.T) {
		t.Parallel()
		it, err := cacheManager.get(context.Background(), utils.UUIDv4())
		require.ErrorIs(t, err, errCacheMiss)
		assert.Nil(t, it)
	})

	t.Run("Item found in cache", func(t *testing.T) {
		t.Parallel()
		id := utils.UUIDv4()
		cacheItem := cacheManager.acquire()
		cacheItem.body = []byte("test-body")
		require.NoError(t, cacheManager.set(context.Background(), id, cacheItem, 10*time.Second))
		it, err := cacheManager.get(context.Background(), id)
		require.NoError(t, err)
		assert.NotNil(t, it)
	})
}

// Test_manager_logKey checks that logKey returns redactedKey when key
// redaction is enabled and the key itself when it is disabled.
func Test_manager_logKey(t *testing.T) {
	t.Parallel()

	redactedManager := newManager(nil, true)
	assert.Equal(t, redactedKey, redactedManager.logKey("secret"))
plainManager := newManager(nil, false) assert.Equal(t, "secret", plainManager.logKey("secret")) } ================================================ FILE: middleware/compress/compress.go ================================================ package compress import ( "strings" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/etag" "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) func hasToken(header, token string) bool { for part := range strings.SplitSeq(header, ",") { if utils.EqualFold(utils.TrimSpace(part), token) { return true } } return false } func shouldSkip(c fiber.Ctx) bool { if c.Method() == fiber.MethodHead { return true } status := c.Response().StatusCode() if status < 200 || status == fiber.StatusNoContent || status == fiber.StatusResetContent || status == fiber.StatusNotModified || status == fiber.StatusPartialContent || len(c.Response().Body()) == 0 || c.Get(fiber.HeaderRange) != "" || hasToken(c.Get(fiber.HeaderCacheControl), "no-transform") || hasToken(c.GetRespHeader(fiber.HeaderCacheControl), "no-transform") { return true } return false } func appendVaryAcceptEncoding(c fiber.Ctx) { vary := c.GetRespHeader(fiber.HeaderVary) if vary == "" { c.Set(fiber.HeaderVary, fiber.HeaderAcceptEncoding) return } if hasToken(vary, "*") || hasToken(vary, fiber.HeaderAcceptEncoding) { return } c.Set(fiber.HeaderVary, vary+", "+fiber.HeaderAcceptEncoding) } // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) 
// Setup request handlers var ( fctx = func(_ *fasthttp.RequestCtx) {} compressor fasthttp.RequestHandler ) // Setup compression algorithm switch cfg.Level { case LevelDefault: // LevelDefault compressor = fasthttp.CompressHandlerBrotliLevel(fctx, fasthttp.CompressBrotliDefaultCompression, fasthttp.CompressDefaultCompression, ) case LevelBestSpeed: // LevelBestSpeed compressor = fasthttp.CompressHandlerBrotliLevel(fctx, fasthttp.CompressBrotliBestSpeed, fasthttp.CompressBestSpeed, ) case LevelBestCompression: // LevelBestCompression compressor = fasthttp.CompressHandlerBrotliLevel(fctx, fasthttp.CompressBrotliBestCompression, fasthttp.CompressBestCompression, ) default: // LevelDisabled return func(c fiber.Ctx) error { return c.Next() } } // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Continue stack if err := c.Next(); err != nil { return err } if shouldSkip(c) { appendVaryAcceptEncoding(c) return nil } if c.GetRespHeader(fiber.HeaderContentEncoding) != "" { appendVaryAcceptEncoding(c) return nil } compressor(c.RequestCtx()) if tag := c.GetRespHeader(fiber.HeaderETag); tag != "" && !strings.HasPrefix(tag, "W/") { if c.GetRespHeader(fiber.HeaderContentEncoding) != "" { c.Set(fiber.HeaderETag, string(etag.Generate(c.Response().Body()))) } } appendVaryAcceptEncoding(c) return nil } } ================================================ FILE: middleware/compress/compress_test.go ================================================ package compress import ( "errors" "fmt" "io" "net/http" "net/http/httptest" "os" "testing" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/etag" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) var filedata []byte var testConfig = fiber.TestConfig{ Timeout: 10 * time.Second, FailOnTimeout: true, } func init() { dat, err := os.ReadFile("../../.github/README.md") if err != nil { panic(err) 
}
	filedata = dat
}

// go test -run Test_Compress_Gzip
// Test_Compress_Gzip expects a gzip-encoded body smaller than the fixture
// when the client accepts gzip.
func Test_Compress_Gzip(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())

	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
		return c.Send(filedata)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.Equal(t, "gzip", resp.Header.Get(fiber.HeaderContentEncoding))

	// Validate that the file size has shrunk
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Less(t, len(body), len(filedata))
}

// go test -run Test_Compress_Different_Level
// Test_Compress_Different_Level runs LevelDefault, LevelBestSpeed and
// LevelBestCompression against each of gzip, deflate, br and zstd.
func Test_Compress_Different_Level(t *testing.T) {
	t.Parallel()
	levels := []Level{LevelDefault, LevelBestSpeed, LevelBestCompression}
	algorithms := []string{"gzip", "deflate", "br", "zstd"}

	for _, algo := range algorithms {
		for _, level := range levels {
			t.Run(fmt.Sprintf("%s_level %d", algo, level), func(t *testing.T) {
				t.Parallel()
				app := fiber.New()

				app.Use(New(Config{Level: level}))
				app.Get("/", func(c fiber.Ctx) error {
					c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
					return c.Send(filedata)
				})

				req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
				req.Header.Set("Accept-Encoding", algo)

				resp, err := app.Test(req, testConfig)
				require.NoError(t, err, "app.Test(req)")
				require.Equal(t, 200, resp.StatusCode, "Status code")
				require.Equal(t, algo, resp.Header.Get(fiber.HeaderContentEncoding))

				// Validate that the file size has shrunk
				body, err := io.ReadAll(resp.Body)
				require.NoError(t, err)
				require.Less(t, len(body), len(filedata))
			})
		}
	}
}

// Test_Compress_Deflate expects a smaller deflate-encoded body.
func Test_Compress_Deflate(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())

	app.Get("/", func(c fiber.Ctx) error {
		return c.Send(filedata)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "deflate")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.Equal(t, "deflate", resp.Header.Get(fiber.HeaderContentEncoding))

	// Validate that the file size has shrunk
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Less(t, len(body), len(filedata))
}

// Test_Compress_Brotli expects a smaller br-encoded body.
func Test_Compress_Brotli(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())

	app.Get("/", func(c fiber.Ctx) error {
		return c.Send(filedata)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "br")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.Equal(t, "br", resp.Header.Get(fiber.HeaderContentEncoding))

	// Validate that the file size has shrunk
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Less(t, len(body), len(filedata))
}

// Test_Compress_Zstd expects a smaller zstd-encoded body.
func Test_Compress_Zstd(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())

	app.Get("/", func(c fiber.Ctx) error {
		return c.Send(filedata)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "zstd")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.Equal(t, "zstd", resp.Header.Get(fiber.HeaderContentEncoding))

	// Validate that the file size has shrunk
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Less(t, len(body), len(filedata))
}

// Test_Compress_Disabled expects LevelDisabled to return the body unencoded
// and at its original length.
func Test_Compress_Disabled(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{Level: LevelDisabled}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.Send(filedata)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "br")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))

	// Validate the file size is not shrunk
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Len(t, body, len(filedata))
}

// Test_Compress_Adds_Vary_Header expects Vary: Accept-Encoding on a short
// "hello" response.
func Test_Compress_Adds_Vary_Header(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "Accept-Encoding", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Vary_Star expects an existing Vary: * to be preserved.
func Test_Compress_Vary_Star(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderVary, "*")
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "*", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Vary_List_Star expects a Vary list containing "*" to be
// preserved unchanged.
func Test_Compress_Vary_List_Star(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderVary, "User-Agent, *")
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "User-Agent, *", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Vary_Similar_Substring checks that "Accept-Encoding2" is not
// mistaken for Accept-Encoding, so the real token is appended.
func Test_Compress_Vary_Similar_Substring(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderVary, "Accept-Encoding2")
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "Accept-Encoding2, Accept-Encoding", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Skip_When_Content_Encoding_Set expects a handler-set
// Content-Encoding to leave the body and ETag untouched while Vary is added.
func Test_Compress_Skip_When_Content_Encoding_Set(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentEncoding, "gzip")
		c.Set(fiber.HeaderETag, "\"abc\"")
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "hello", string(body))
	require.Equal(t, "gzip", resp.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "\"abc\"", resp.Header.Get(fiber.HeaderETag))
	require.Equal(t, "Accept-Encoding", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Skip_When_Content_Encoding_Set_Vary_Star expects Vary: * to
// survive on the pre-encoded path.
func Test_Compress_Skip_When_Content_Encoding_Set_Vary_Star(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentEncoding, "gzip")
		c.Set(fiber.HeaderVary, "*")
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "*", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Skip_When_Content_Encoding_Set_Vary_List_Star expects a Vary
// list containing "*" to survive on the pre-encoded path.
func Test_Compress_Skip_When_Content_Encoding_Set_Vary_List_Star(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentEncoding, "gzip")
		c.Set(fiber.HeaderVary, "User-Agent, *")
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "User-Agent, *", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Skip_When_Content_Encoding_Set_Vary_Similar_Substring checks
// exact token matching of Vary on the pre-encoded path.
func Test_Compress_Skip_When_Content_Encoding_Set_Vary_Similar_Substring(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentEncoding, "gzip")
		c.Set(fiber.HeaderVary, "Accept-Encoding2")
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "Accept-Encoding2, Accept-Encoding", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Strong_ETag_Recalculated expects a strong ETag to equal
// etag.Generate over the compressed body.
func Test_Compress_Strong_ETag_Recalculated(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
		c.Set(fiber.HeaderETag, "\"abc\"")
		return c.Send(filedata)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "gzip", resp.Header.Get(fiber.HeaderContentEncoding))
	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	expected := string(etag.Generate(body))
	require.Equal(t, expected, resp.Header.Get(fiber.HeaderETag))
}

// Test_Compress_Weak_ETag_Unchanged expects a weak ETag to be kept as-is
// after compression.
func Test_Compress_Weak_ETag_Unchanged(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
		c.Set(fiber.HeaderETag, "W/\"abc\"")
		return c.Send(filedata)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "gzip", resp.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "W/\"abc\"", resp.Header.Get(fiber.HeaderETag))
}

// Test_Compress_Strong_ETag_Unchanged_When_Not_Compressed expects the "tiny"
// response to stay unencoded with its strong ETag intact.
func Test_Compress_Strong_ETag_Unchanged_When_Not_Compressed(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderETag, "\"abc\"")
		return c.SendString("tiny")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err)
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "\"abc\"", resp.Header.Get(fiber.HeaderETag))
	require.Equal(t, "Accept-Encoding", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Skip_Head expects GET to be compressed while HEAD on the same
// route gets no body, no Content-Encoding, Vary, and the original ETag.
func Test_Compress_Skip_Head(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())

	handler := func(c fiber.Ctx) error {
		c.Set(fiber.HeaderETag, "\"abc\"")
		return c.Send(filedata)
	}

	app.Get("/", handler)
	app.Head("/", handler)

	getReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	getReq.Header.Set("Accept-Encoding", "gzip")
	getResp, err := app.Test(getReq, testConfig)
	require.NoError(t, err, "app.Test(getReq)")
	getBody, err := io.ReadAll(getResp.Body)
	require.NoError(t, err)
	require.NotEmpty(t, getBody)
	require.Equal(t, "gzip", getResp.Header.Get(fiber.HeaderContentEncoding))

	headReq := httptest.NewRequest(fiber.MethodHead, "/", http.NoBody)
	headReq.Header.Set("Accept-Encoding", "gzip")
	headResp, err := app.Test(headReq, testConfig)
	require.NoError(t, err, "app.Test(headReq)")
	headBody, err := io.ReadAll(headResp.Body)
	require.NoError(t, err)
	require.Empty(t, headBody)
	require.Empty(t, headResp.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "Accept-Encoding", headResp.Header.Get(fiber.HeaderVary))
	require.Equal(t, "\"abc\"", headResp.Header.Get(fiber.HeaderETag))
}

// Test_Compress_Skip_Status_NoContent expects 204 responses to stay
// unencoded with the ETag intact.
func Test_Compress_Skip_Status_NoContent(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderETag, "\"abc\"")
		return c.SendStatus(fiber.StatusNoContent)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusNoContent, resp.StatusCode)
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "\"abc\"", resp.Header.Get(fiber.HeaderETag))
}

// Test_Compress_Skip_Status_NotModified expects 304 responses to stay
// unencoded with the ETag intact.
func Test_Compress_Skip_Status_NotModified(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderETag, "\"abc\"")
		c.Status(fiber.StatusNotModified)
		return nil
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusNotModified, resp.StatusCode)
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "\"abc\"", resp.Header.Get(fiber.HeaderETag))
}

// Test_Compress_Skip_Range expects requests with a Range header to stay
// unencoded while still receiving Vary.
func Test_Compress_Skip_Range(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")
	req.Header.Set("Range", "bytes=0-1")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "Accept-Encoding", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Skip_Range_NoAcceptEncoding expects Vary to be added even
// when the client sent no Accept-Encoding.
func Test_Compress_Skip_Range_NoAcceptEncoding(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Range", "bytes=0-1")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "Accept-Encoding", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Skip_Range_Vary_Star expects Vary: * to survive on the
// range-skip path.
func Test_Compress_Skip_Range_Vary_Star(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderVary, "*")
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")
	req.Header.Set("Range", "bytes=0-1")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "*", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Skip_Range_Vary_Similar_Substring checks exact Vary token
// matching on the range-skip path.
func Test_Compress_Skip_Range_Vary_Similar_Substring(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Set(fiber.HeaderVary, "Accept-Encoding2")
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")
	req.Header.Set("Range", "bytes=0-1")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
	require.Equal(t, "Accept-Encoding2, Accept-Encoding", resp.Header.Get(fiber.HeaderVary))
}

// Test_Compress_Skip_Status_PartialContent expects 206 responses to stay
// unencoded.
func Test_Compress_Skip_Status_PartialContent(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		c.Status(fiber.StatusPartialContent)
		return c.SendString("hello")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req, testConfig)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, fiber.StatusPartialContent, resp.StatusCode)
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
}

// Test_Compress_Skip_NoTransform expects Cache-Control: no-transform on
// either the request or the response to prevent compression.
func Test_Compress_Skip_NoTransform(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		setRequest bool
	}{
		{name: "request", setRequest: true},
		{name: "response", setRequest: false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			app.Use(New())
			app.Get("/", func(c fiber.Ctx) error {
				if !tt.setRequest {
					c.Set(fiber.HeaderCacheControl, "no-transform")
				}
				return c.SendString("hello")
			})

			req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
			req.Header.Set("Accept-Encoding", "gzip")
			if tt.setRequest {
				req.Header.Set(fiber.HeaderCacheControl, "no-transform")
			}

			resp, err := app.Test(req, testConfig)
			require.NoError(t, err, "app.Test(req)")
			require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))
		})
	}
}

// Test_Compress_Next_Error expects a handler error to produce a 500 with the
// error text and no Content-Encoding.
func Test_Compress_Next_Error(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())

	app.Get("/", func(_ fiber.Ctx) error {
		return errors.New("next error")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set("Accept-Encoding", "gzip")

	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 500, resp.StatusCode, "Status code")
	require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding))

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "next error", string(body))
}

// go test -run Test_Compress_Next
// Test_Compress_Next checks that Config.Next returning true bypasses the
// middleware; with no routes registered the app answers 404.
func Test_Compress_Next(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Next: func(_ fiber.Ctx) bool {
			return true
		},
	}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}

// go test -bench=Benchmark_Compress
// Benchmark_Compress measures the default level for each encoding,
// reusing one fasthttp.RequestCtx per sub-benchmark.
func Benchmark_Compress(b *testing.B) {
	tests := []struct {
		name           string
		acceptEncoding string
	}{
		{name: "Gzip", acceptEncoding: "gzip"},
		{name: "Deflate", acceptEncoding: "deflate"},
		{name: "Brotli", acceptEncoding: "br"},
		{name: "Zstd", acceptEncoding: "zstd"},
	}

	for _, tt := range tests {
		b.Run(tt.name, func(b *testing.B) {
			app := fiber.New()
			app.Use(New())
			app.Get("/", func(c fiber.Ctx) error {
				c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
				return c.Send(filedata)
			})

			h := app.Handler()
			fctx := &fasthttp.RequestCtx{}
			fctx.Request.Header.SetMethod(fiber.MethodGet)
			fctx.Request.SetRequestURI("/")

			if tt.acceptEncoding != "" {
				fctx.Request.Header.Set("Accept-Encoding", tt.acceptEncoding)
			}

			b.ReportAllocs()
			for b.Loop() {
				h(fctx)
			}
		})
	}
}

// go test -bench=Benchmark_Compress_Levels
// Benchmark_Compress_Levels measures every encoding at every Level,
// including LevelDisabled.
func Benchmark_Compress_Levels(b *testing.B) {
	tests := []struct {
		name           string
		acceptEncoding string
	}{
		{name: "Gzip", acceptEncoding: "gzip"},
		{name: "Deflate", acceptEncoding: "deflate"},
		{name: "Brotli", acceptEncoding: "br"},
		{name: "Zstd", acceptEncoding: "zstd"},
	}

	levels := []struct {
		name  string
		level Level
	}{
		{name: "LevelDisabled", level: LevelDisabled},
		{name: "LevelDefault", level: LevelDefault},
		{name: "LevelBestSpeed", level: LevelBestSpeed},
		{name: "LevelBestCompression", level: LevelBestCompression},
	}

	for _, tt := range tests {
		for _, lvl := range levels {
			b.Run(tt.name+"_"+lvl.name, func(b *testing.B) {
				app := fiber.New()
				app.Use(New(Config{Level: lvl.level}))
				app.Get("/", func(c fiber.Ctx) error {
					c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
					return c.Send(filedata)
				})

				h := app.Handler()
				fctx := &fasthttp.RequestCtx{}
				fctx.Request.Header.SetMethod(fiber.MethodGet)
				fctx.Request.SetRequestURI("/")

				if tt.acceptEncoding != "" {
					fctx.Request.Header.Set("Accept-Encoding", tt.acceptEncoding)
				}

				b.ReportAllocs()
				for b.Loop() {
					h(fctx)
				}
			})
		}
	}
}

// go test -bench=Benchmark_Compress_Parallel
// Benchmark_Compress_Parallel is Benchmark_Compress under RunParallel, with a
// separate RequestCtx per goroutine.
func Benchmark_Compress_Parallel(b *testing.B) {
	tests := []struct {
		name           string
		acceptEncoding string
	}{
		{name: "Gzip", acceptEncoding: "gzip"},
		{name: "Deflate", acceptEncoding: "deflate"},
		{name: "Brotli", acceptEncoding: "br"},
		{name: "Zstd", acceptEncoding: "zstd"},
	}

	for _, tt := range tests {
		b.Run(tt.name, func(b *testing.B) {
			app := fiber.New()
			app.Use(New())
			app.Get("/", func(c fiber.Ctx) error {
				c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8)
				return c.Send(filedata)
			})

			h := app.Handler()

			b.ReportAllocs()
			b.ResetTimer()

			b.RunParallel(func(pb *testing.PB) {
				fctx := &fasthttp.RequestCtx{}
				fctx.Request.Header.SetMethod(fiber.MethodGet)
				fctx.Request.SetRequestURI("/")

				if tt.acceptEncoding != "" {
					fctx.Request.Header.Set("Accept-Encoding", tt.acceptEncoding)
				}

				for pb.Next() {
					h(fctx)
				}
			})
		})
	}
}

// go test -bench=Benchmark_Compress_Levels_Parallel
func Benchmark_Compress_Levels_Parallel(b *testing.B) {
	tests := []struct {
		name           string
		acceptEncoding string
	}{
		{name:
"Gzip", acceptEncoding: "gzip"}, {name: "Deflate", acceptEncoding: "deflate"}, {name: "Brotli", acceptEncoding: "br"}, {name: "Zstd", acceptEncoding: "zstd"}, } levels := []struct { name string level Level }{ {name: "LevelDisabled", level: LevelDisabled}, {name: "LevelDefault", level: LevelDefault}, {name: "LevelBestSpeed", level: LevelBestSpeed}, {name: "LevelBestCompression", level: LevelBestCompression}, } for _, tt := range tests { for _, lvl := range levels { b.Run(tt.name+"_"+lvl.name, func(b *testing.B) { app := fiber.New() app.Use(New(Config{Level: lvl.level})) app.Get("/", func(c fiber.Ctx) error { c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8) return c.Send(filedata) }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.SetRequestURI("/") if tt.acceptEncoding != "" { fctx.Request.Header.Set("Accept-Encoding", tt.acceptEncoding) } for pb.Next() { h(fctx) } }) }) } } } ================================================ FILE: middleware/compress/config.go ================================================ package compress import ( "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // Level sets the compression level used for the response // // Optional. 
Default: LevelDefault // LevelDisabled: -1 // LevelDefault: 0 // LevelBestSpeed: 1 // LevelBestCompression: 2 Level Level } // Level is numeric representation of compression level type Level int // Represents compression level that will be used in the middleware const ( LevelDisabled Level = -1 LevelDefault Level = 0 LevelBestSpeed Level = 1 LevelBestCompression Level = 2 ) // ConfigDefault is the default config var ConfigDefault = Config{ Next: nil, Level: LevelDefault, } // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.Level < LevelDisabled || cfg.Level > LevelBestCompression { cfg.Level = ConfigDefault.Level } return cfg } ================================================ FILE: middleware/cors/config.go ================================================ package cors import ( "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // AllowOriginsFunc defines a function that will set the 'Access-Control-Allow-Origin' // response header to the 'origin' request header when returned true. This allows for // dynamic evaluation of allowed origins. Note if AllowCredentials is true, wildcard origins // will be not have the 'Access-Control-Allow-Credentials' header set to 'true'. // // The function receives serialized origins (scheme + host) or the literal "null" string. // According to the CORS specification (https://developer.mozilla.org/en-US/docs/Web/HTTP/Guides/CORS#origin), // browsers send "null" for privacy-sensitive contexts like sandboxed iframes or file:// URLs. // // Origins with userinfo, paths, queries, fragments, or wildcards are rejected and will not // be passed to this function. // // Optional. 
Default: nil AllowOriginsFunc func(origin string) bool // AllowOrigin defines a list of origins that may access the resource. // // This supports wildcard matching for subdomains by prefixing the domain with a `*.` // e.g. "http://.domain.com". This will allow all level of subdomains of domain.com to access the resource. // // If the special wildcard `"*"` is present in the list, all origins will be allowed. // // Optional. Default value []string{} AllowOrigins []string // AllowMethods defines a list methods allowed when accessing the resource. // This is used in response to a preflight request. // // Optional. Default value []string{"GET", "POST", "HEAD", "PUT", "DELETE", "PATCH"} AllowMethods []string // AllowHeaders defines a list of request headers that can be used when // making the actual request. This is in response to a preflight request. // // Optional. Default value []string{} AllowHeaders []string // ExposeHeaders defines an allowlist of headers that clients are allowed to // access. // // Optional. Default value []string{}. ExposeHeaders []string // MaxAge indicates how long (in seconds) the results of a preflight request // can be cached. // If you pass MaxAge 0, Access-Control-Max-Age header will not be added and // browser will use 5 seconds by default. // To disable caching completely, pass MaxAge value negative. It will set the Access-Control-Max-Age header 0. // // Optional. Default value 0. MaxAge int // DisableValueRedaction turns off redaction of configuration values and origins in logs and panics. // // Optional. Default: false DisableValueRedaction bool // AllowCredentials indicates whether or not the response to the request // can be exposed when the credentials flag is true. When used as part of // a response to a preflight request, this indicates whether or not the // actual request can be made using credentials. Note: if true, the // AllowOrigins setting cannot contain the wildcard "*" to prevent // security vulnerabilities. 
// // Optional. Default value false. AllowCredentials bool // AllowPrivateNetwork indicates whether the Access-Control-Allow-Private-Network // response header should be set to true, allowing requests from private networks. // // Optional. Default value false. AllowPrivateNetwork bool } // ConfigDefault is the default config var ConfigDefault = Config{ Next: nil, AllowOriginsFunc: nil, AllowOrigins: []string{"*"}, DisableValueRedaction: false, AllowMethods: []string{ fiber.MethodGet, fiber.MethodPost, fiber.MethodHead, fiber.MethodPut, fiber.MethodDelete, fiber.MethodPatch, }, AllowHeaders: []string{}, AllowCredentials: false, ExposeHeaders: []string{}, MaxAge: 0, AllowPrivateNetwork: false, } ================================================ FILE: middleware/cors/cors.go ================================================ package cors import ( "slices" "strconv" "strings" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" utilsstrings "github.com/gofiber/utils/v2/strings" ) const redactedValue = "[redacted]" // isOriginSerializedOrNull checks if the origin is a serialized origin or the literal "null". // It returns two booleans: (isSerialized, isNull). 
func isOriginSerializedOrNull(originHeaderRaw string) (isSerialized, isNull bool) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming serialization and null status results if originHeaderRaw == "null" { return false, true } originIsSerialized, _ := normalizeOrigin(originHeaderRaw) return originIsSerialized, false } // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := ConfigDefault // Override config if provided if len(config) > 0 { cfg = config[0] // Set default values if len(cfg.AllowMethods) == 0 { cfg.AllowMethods = ConfigDefault.AllowMethods } } redactValues := !cfg.DisableValueRedaction maskValue := func(value string) string { if redactValues { return redactedValue } return value } // Warning logs if both AllowOrigins and AllowOriginsFunc are set if len(cfg.AllowOrigins) > 0 && cfg.AllowOriginsFunc != nil { log.Warn("[CORS] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.") } // allowOrigins is a slice of strings that contains the allowed origins // defined in the 'AllowOrigins' configuration. 
allowOrigins := []string{} allowSubOrigins := []subdomain{} // Validate and normalize static AllowOrigins allowAllOrigins := len(cfg.AllowOrigins) == 0 && cfg.AllowOriginsFunc == nil for _, origin := range cfg.AllowOrigins { if origin == "*" { allowAllOrigins = true break } trimmedOrigin := utils.TrimSpace(origin) if before, after, found := strings.Cut(trimmedOrigin, "://*."); found { withoutWildcard := before + "://" + after isValid, normalizedOrigin := normalizeOrigin(withoutWildcard) if !isValid { panic("[CORS] Invalid origin format in configuration: " + maskValue(trimmedOrigin)) } scheme, host, ok := strings.Cut(normalizedOrigin, "://") if !ok { panic("[CORS] Invalid origin format after normalization:" + maskValue(trimmedOrigin)) } sd := subdomain{prefix: scheme + "://", suffix: host} allowSubOrigins = append(allowSubOrigins, sd) } else { isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) if !isValid { panic("[CORS] Invalid origin format in configuration: " + maskValue(trimmedOrigin)) } allowOrigins = append(allowOrigins, normalizedOrigin) } } // Validate CORS credentials configuration if cfg.AllowCredentials && allowAllOrigins { panic("[CORS] Configuration error: When 'AllowCredentials' is set to true, 'AllowOrigins' cannot contain a wildcard origin '*'. 
Please specify allowed origins explicitly or adjust 'AllowCredentials' setting.") } // Warn if allowAllOrigins is set to true and AllowOriginsFunc is defined if allowAllOrigins && cfg.AllowOriginsFunc != nil { log.Warn("[CORS] 'AllowOrigins' is set to allow all origins, 'AllowOriginsFunc' will not be used.") } // Convert int to string maxAge := strconv.Itoa(cfg.MaxAge) // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Get origin header preserving the original case for the response originHeaderRaw := c.Get(fiber.HeaderOrigin) originHeader := utilsstrings.ToLower(originHeaderRaw) // If the request does not have Origin header, the request is outside the scope of CORS if originHeader == "" { // See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches // Unless all origins are allowed, we include the Vary header to cache the response correctly if !allowAllOrigins { c.Vary(fiber.HeaderOrigin) } return c.Next() } // If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" { // Response to OPTIONS request should not be cached but, // some caching can be configured to cache such responses. 
// To Avoid poisoning the cache, we include the Vary header // for non-CORS OPTIONS requests: c.Vary(fiber.HeaderOrigin) return c.Next() } // Set default allowOrigin to empty string allowOrigin := "" // Check allowed origins if allowAllOrigins { allowOrigin = "*" } else { // Check if the origin is in the list of allowed origins if slices.Contains(allowOrigins, originHeader) { allowOrigin = originHeaderRaw } // Check if the origin is in the list of allowed subdomains if allowOrigin == "" { for _, sOrigin := range allowSubOrigins { if sOrigin.match(originHeader) { allowOrigin = originHeaderRaw break } } } } // Run AllowOriginsFunc if the logic for // handling the value in 'AllowOrigins' does // not result in allowOrigin being set. if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeaderRaw) { originIsSerialized, originIsNull := isOriginSerializedOrNull(originHeaderRaw) if originIsSerialized || originIsNull { allowOrigin = originHeaderRaw } } // Simple request // Omit allowMethods and allowHeaders, only used for pre-flight requests if c.Method() != fiber.MethodOptions { if !allowAllOrigins { // See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches c.Vary(fiber.HeaderOrigin) } setSimpleHeaders(c, allowOrigin, &cfg) return c.Next() } // Pre-flight request // Response to OPTIONS request should not be cached but, // some caching can be configured to cache such responses. 
// To Avoid poisoning the cache, we include the Vary header // of preflight responses: c.Vary(fiber.HeaderAccessControlRequestMethod) c.Vary(fiber.HeaderAccessControlRequestHeaders) if cfg.AllowPrivateNetwork && c.Get(fiber.HeaderAccessControlRequestPrivateNetwork) == "true" { c.Vary(fiber.HeaderAccessControlRequestPrivateNetwork) c.Set(fiber.HeaderAccessControlAllowPrivateNetwork, "true") } c.Vary(fiber.HeaderOrigin) setPreflightHeaders(c, allowOrigin, maxAge, &cfg) // Set Preflight headers if len(cfg.AllowMethods) > 0 { c.Set(fiber.HeaderAccessControlAllowMethods, strings.Join(cfg.AllowMethods, ", ")) } if len(cfg.AllowHeaders) > 0 { c.Set(fiber.HeaderAccessControlAllowHeaders, strings.Join(cfg.AllowHeaders, ", ")) } else { h := c.Get(fiber.HeaderAccessControlRequestHeaders) if h != "" { c.Set(fiber.HeaderAccessControlAllowHeaders, h) } } // Send 204 No Content return c.SendStatus(fiber.StatusNoContent) } } // Function to set Simple CORS headers func setSimpleHeaders(c fiber.Ctx, allowOrigin string, cfg *Config) { if cfg == nil { return } if cfg.AllowCredentials { // When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*' if allowOrigin == "*" { c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) log.Warn("[CORS] 'AllowCredentials' is true, but 'AllowOrigins' cannot be set to '*'.") } else if allowOrigin != "" { c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) c.Set(fiber.HeaderAccessControlAllowCredentials, "true") } } else if allowOrigin != "" { // For non-credential requests, it's safe to set to '*' or specific origins c.Set(fiber.HeaderAccessControlAllowOrigin, allowOrigin) } // Set Expose-Headers if not empty if len(cfg.ExposeHeaders) > 0 { c.Set(fiber.HeaderAccessControlExposeHeaders, strings.Join(cfg.ExposeHeaders, ", ")) } } // Function to set Preflight CORS headers func setPreflightHeaders(c fiber.Ctx, allowOrigin, maxAge string, cfg *Config) { setSimpleHeaders(c, allowOrigin, cfg) // Set 
MaxAge if set if cfg != nil && cfg.MaxAge > 0 { c.Set(fiber.HeaderAccessControlMaxAge, maxAge) } else if cfg != nil && cfg.MaxAge < 0 { c.Set(fiber.HeaderAccessControlMaxAge, "0") } } ================================================ FILE: middleware/cors/cors_test.go ================================================ package cors import ( "bytes" "net/http" "net/http/httptest" "os" "strings" "testing" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/log" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func Test_CORS_Defaults(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) testDefaultOrEmptyConfig(t, app) } func Test_CORS_Empty_Config(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{})) testDefaultOrEmptyConfig(t, app) } func Test_CORS_WildcardHeaders(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ AllowMethods: []string{"*"}, AllowHeaders: []string{"*"}, ExposeHeaders: []string{"*"}, })) h := app.Handler() // Test preflight request ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods))) require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders))) require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders))) } func Test_CORS_Negative_MaxAge(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{MaxAge: -1})) ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, 
fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") app.Handler()(ctx) require.Equal(t, "0", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) } func Test_CORS_MaxAge_NotSetOnSimpleRequest(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{MaxAge: 100})) ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") app.Handler()(ctx) require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) } func Test_CORS_Preserve_Origin_Case(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{AllowOrigins: []string{"http://example.com"}})) origin := "HTTP://EXAMPLE.COM" ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, origin) app.Handler()(ctx) require.Equal(t, origin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) { t.Helper() h := app.Handler() // Test default GET response headers ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders))) // Test default OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) require.Equal(t, "GET, POST, HEAD, PUT, DELETE, PATCH", 
string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods))) require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders))) require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) } func Test_CORS_AllowOrigins_Vary(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New( Config{ AllowOrigins: []string{"http://localhost"}, }, )) h := app.Handler() // Test Vary header non-Cors request ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set") // Test Vary header Cors request ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") h(ctx) require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set") } // go test -run -v Test_CORS_Wildcard func Test_CORS_Wildcard(t *testing.T) { t.Parallel() // New fiber instance app := fiber.New() // OPTIONS (preflight) response headers when AllowOrigins is * app.Use(New(Config{ MaxAge: 3600, ExposeHeaders: []string{"X-Request-ID"}, AllowHeaders: []string{"Authentication"}, })) // Get handler pointer handler := app.Handler() // Make request ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) // Perform request handler(ctx) // Check result require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) // Validates request is not reflecting origin in the response require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary 
header should be set") require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Equal(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) require.Equal(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders))) // Test non OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") handler(ctx) require.NotContains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should not be set") require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Equal(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders))) } // go test -run -v Test_CORS_Origin_AllowCredentials func Test_CORS_Origin_AllowCredentials(t *testing.T) { t.Parallel() // New fiber instance app := fiber.New() // OPTIONS (preflight) response headers when AllowOrigins is * app.Use(New(Config{ AllowOrigins: []string{"http://localhost"}, AllowCredentials: true, MaxAge: 3600, ExposeHeaders: []string{"X-Request-ID"}, AllowHeaders: []string{"Authentication"}, })) // Get handler pointer handler := app.Handler() // Make request ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) // Perform request handler(ctx) // Check result require.Equal(t, "http://localhost", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) require.Equal(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Equal(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge))) require.Equal(t, "Authentication", 
string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders))) // Test non OPTIONS (preflight) response headers ctx = &fasthttp.RequestCtx{} ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost") ctx.Request.Header.SetMethod(fiber.MethodGet) handler(ctx) require.Equal(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Equal(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders))) } // go test -run -v Test_CORS_Wildcard_AllowCredentials_Panic // Test for fiber-ghsa-fmg4-x8pw-hjhg func Test_CORS_Wildcard_AllowCredentials_Panic(t *testing.T) { t.Parallel() // New fiber instance app := fiber.New() didPanic := false func() { defer func() { if r := recover(); r != nil { didPanic = true } }() app.Use(New(Config{ AllowOrigins: []string{"*"}, AllowCredentials: true, })) }() if !didPanic { t.Error("Expected a panic when AllowOrigins is '*' and AllowCredentials is true") } } // Test that a warning is logged when AllowOrigins allows all origins and // AllowOriginsFunc is also provided. 
func Test_CORS_Warn_AllowAllOrigins_WithFunc(t *testing.T) { var buf bytes.Buffer log.SetOutput(&buf) t.Cleanup(func() { log.SetOutput(os.Stderr) }) fiber.New().Use(New(Config{ AllowOrigins: []string{"*"}, AllowOriginsFunc: func(string) bool { return true }, })) require.Contains(t, buf.String(), "AllowOriginsFunc' will not be used") } // go test -run -v Test_CORS_Invalid_Origin_Panic func Test_CORS_Invalid_Origins_Panic(t *testing.T) { t.Parallel() invalidOrigins := []string{ "localhost", "http://foo.[a-z]*.example.com", "http://*", "https://*", "http://*.com*", "invalid url", "*", "http://origin.com,invalid url", // add more invalid origins as needed } for _, origin := range invalidOrigins { // New fiber instance app := fiber.New() didPanic := false func() { defer func() { if r := recover(); r != nil { didPanic = true } }() app.Use(New(Config{ AllowOrigins: []string{origin}, AllowCredentials: true, })) }() if !didPanic { t.Errorf("Expected a panic for invalid origin: %s", origin) } } } func Test_CORS_DisableValueRedaction(t *testing.T) { t.Parallel() require.PanicsWithValue(t, "[CORS] Invalid origin format in configuration: [redacted]", func() { New(Config{ AllowOrigins: []string{"http://"}, DisableValueRedaction: false, }) }) require.PanicsWithValue(t, "[CORS] Invalid origin format in configuration: http://", func() { New(Config{ AllowOrigins: []string{"http://"}, DisableValueRedaction: true, }) }) } // go test -run -v Test_CORS_Subdomain func Test_CORS_Subdomain(t *testing.T) { t.Parallel() // New fiber instance app := fiber.New() // OPTIONS (preflight) response headers when AllowOrigins is set to a subdomain app.Use("/", New(Config{ AllowOrigins: []string{" http://*.example.com "}, })) // Get handler pointer handler := app.Handler() // Make request with disallowed origin ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) 
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") // Perform request handler(ctx) // Allow-Origin header should be "" because http://google.com does not satisfy http://*.example.com require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) ctx.Request.Reset() ctx.Response.Reset() // Make request with domain only (disallowed) ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") handler(ctx) require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) ctx.Request.Reset() ctx.Response.Reset() // Make request with malformed subdomain (disallowed) ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://evil.comexample.com") handler(ctx) require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) ctx.Request.Reset() ctx.Response.Reset() // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com") handler(ctx) require.Equal(t, "http://test.example.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } func Test_CORS_AllowOriginScheme(t *testing.T) { t.Parallel() tests := []struct { reqOrigin string pattern []string shouldAllowOrigin bool }{ { pattern: []string{"http://example.com"}, reqOrigin: "http://example.com", shouldAllowOrigin: true, }, { pattern: []string{"HTTP://EXAMPLE.COM"}, reqOrigin: "http://example.com", shouldAllowOrigin: true, }, { pattern: []string{"https://example.com"}, reqOrigin: "https://example.com", 
shouldAllowOrigin: true, }, { pattern: []string{"http://example.com"}, reqOrigin: "https://example.com", shouldAllowOrigin: false, }, { pattern: []string{"http://*.example.com"}, reqOrigin: "http://aaa.example.com", shouldAllowOrigin: true, }, { pattern: []string{"http://*.example.com"}, reqOrigin: "http://bbb.aaa.example.com", shouldAllowOrigin: true, }, { pattern: []string{"http://*.aaa.example.com"}, reqOrigin: "http://bbb.aaa.example.com", shouldAllowOrigin: true, }, { pattern: []string{"http://*.example.com:8080"}, reqOrigin: "http://aaa.example.com:8080", shouldAllowOrigin: true, }, { pattern: []string{"http://*.example.com"}, reqOrigin: "http://1.2.aaa.example.com", shouldAllowOrigin: true, }, { pattern: []string{"http://example.com"}, reqOrigin: "http://gofiber.com", shouldAllowOrigin: false, }, { pattern: []string{"http://*.aaa.example.com"}, reqOrigin: "http://ccc.bbb.example.com", shouldAllowOrigin: false, }, { pattern: []string{"http://*.example.com"}, reqOrigin: "http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com", shouldAllowOrigin: true, }, { pattern: []string{"http://example.com"}, reqOrigin: "http://ccc.bbb.example.com", shouldAllowOrigin: false, }, { pattern: []string{"https://--aaa.bbb.com"}, reqOrigin: "https://prod-preview--aaa.bbb.com", shouldAllowOrigin: false, }, { pattern: []string{"http://*.example.com"}, reqOrigin: "http://ccc.bbb.example.com", shouldAllowOrigin: true, }, { pattern: []string{"http://domain-1.com", "http://example.com"}, reqOrigin: "http://example.com", shouldAllowOrigin: true, }, { pattern: []string{"http://domain-1.com", "http://example.com"}, reqOrigin: "http://domain-2.com", shouldAllowOrigin: false, }, { pattern: []string{"http://domain-1.com", "http://example.com"}, reqOrigin: "http://example.com", shouldAllowOrigin: true, }, { pattern: []string{"http://domain-1.com", "http://example.com"}, reqOrigin: "http://domain-2.com", 
shouldAllowOrigin: false, }, { pattern: []string{"http://domain-1.com", "http://example.com"}, reqOrigin: "http://domain-1.com", shouldAllowOrigin: true, }, } for _, tt := range tests { app := fiber.New() app.Use("/", New(Config{AllowOrigins: tt.pattern})) handler := app.Handler() ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin) handler(ctx) if tt.shouldAllowOrigin { require.Equal(t, tt.reqOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } else { require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } } } func Test_CORS_AllowOriginHeader_NoMatch(t *testing.T) { t.Parallel() // New fiber instance app := fiber.New() app.Use("/", New(Config{ AllowOrigins: []string{"http://example-1.com", "https://example-1.com"}, })) // Get handler pointer handler := app.Handler() // Make request with disallowed origin ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") // Perform request handler(ctx) var headerExists bool for key := range ctx.Response.Header.All() { if string(key) == fiber.HeaderAccessControlAllowOrigin { headerExists = true } } require.False(t, headerExists, "Access-Control-Allow-Origin header should not be set") } // go test -run Test_CORS_Next func Test_CORS_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Next: func(_ fiber.Ctx) bool { return true }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } // go test -run Test_CORS_Headers_BasedOnRequestType func Test_CORS_Headers_BasedOnRequestType(t *testing.T) { t.Parallel() app := fiber.New() 
app.Use(New()) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) methods := []string{ fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete, fiber.MethodPatch, fiber.MethodHead, } // Get handler pointer handler := app.Handler() t.Run("Without origin", func(t *testing.T) { t.Parallel() // Make request without origin header, and without Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(method) ctx.Request.SetRequestURI("https://example.com/") handler(ctx) require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200") require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set") } }) t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) { t.Parallel() // Make preflight request with origin header and with Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.SetRequestURI("https://example.com/") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method) handler(ctx) require.Equal(t, 204, ctx.Response.StatusCode(), "Status code should be 204") require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set") require.Equal(t, "GET, POST, HEAD, PUT, DELETE, PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should be set (preflight request)") require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should be set (preflight request)") } }) t.Run("Non-preflight request with origin", func(t *testing.T) { t.Parallel() // Make 
non-preflight request with origin header and with Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(method) ctx.Request.SetRequestURI("https://example.com/api/action") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") handler(ctx) require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200") require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set") require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should not be set (non-preflight request)") require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), "Access-Control-Allow-Headers header should not be set (non-preflight request)") } }) t.Run("Preflight with Access-Control-Request-Headers", func(t *testing.T) { t.Parallel() // Make preflight request with origin header and with Access-Control-Request-Method for _, method := range methods { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.SetRequestURI("https://example.com/") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestHeaders, "X-Custom-Header") handler(ctx) require.Equal(t, 204, ctx.Response.StatusCode(), "Status code should be 204") require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set") require.Equal(t, "GET, POST, HEAD, PUT, DELETE, PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)), "Access-Control-Allow-Methods header should be set (preflight request)") require.Equal(t, "X-Custom-Header", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)), 
"Access-Control-Allow-Headers header should be set (preflight request)") } }) } func Test_CORS_AllowOriginsAndAllowOriginsFunc(t *testing.T) { t.Parallel() // New fiber instance app := fiber.New() app.Use("/", New(Config{ AllowOrigins: []string{"http://example-1.com"}, AllowOriginsFunc: func(origin string) bool { return strings.Contains(origin, "example-2") }, })) // Get handler pointer handler := app.Handler() // Make request with disallowed origin ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") // Perform request handler(ctx) // Allow-Origin header should be "" because http://google.com does not satisfy http://example-1.com or 'strings.Contains(origin, "example-2")' require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) ctx.Request.Reset() ctx.Response.Reset() // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com") handler(ctx) require.Equal(t, "http://example-1.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) ctx.Request.Reset() ctx.Response.Reset() // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com") handler(ctx) require.Equal(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } func Test_CORS_AllowOriginsFunc(t *testing.T) { t.Parallel() // New fiber instance app := fiber.New() app.Use("/", New(Config{ AllowOriginsFunc: func(origin 
string) bool { return strings.Contains(origin, "example-2") }, })) // Get handler pointer handler := app.Handler() // Make request with disallowed origin ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com") // Perform request handler(ctx) // Allow-Origin header should be empty because http://google.com does not satisfy 'strings.Contains(origin, "example-2")' // and AllowOrigins has not been set require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) ctx.Request.Reset() ctx.Response.Reset() // Make request with allowed origin ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com") handler(ctx) // Allow-Origin header should be "http://example-2.com" require.Equal(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } func Test_CORS_AllowOriginsFuncRejectsNonSerializedOrigins(t *testing.T) { t.Parallel() testCases := []struct { Name string Origin string Method string SetPreflightMethod bool ExpectAllowed bool }{ { Name: "UserInfoPresent", Origin: "http://user:pass@example.com", Method: fiber.MethodGet, }, { Name: "PathPresent", Origin: "http://example.com/path", Method: fiber.MethodOptions, SetPreflightMethod: true, }, { Name: "QueryPresent", Origin: "http://example.com?query=1", Method: fiber.MethodOptions, SetPreflightMethod: true, }, { Name: "FragmentPresent", Origin: "http://example.com#section", Method: fiber.MethodGet, }, { Name: "WildcardHost", Origin: "http://*.example.com", Method: fiber.MethodGet, }, { Name: "StandaloneWildcard", Origin: "*", Method: fiber.MethodGet, }, { Name: "NullOriginUppercase", Origin: "NULL", Method: fiber.MethodGet, }, { Name: "NullOriginMixedCase", Origin: "Null", 
Method: fiber.MethodGet, }, { Name: "NullOriginLowercase", Origin: "null", Method: fiber.MethodGet, ExpectAllowed: true, }, } for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { t.Parallel() app := fiber.New() app.Use("/", New(Config{ AllowOriginsFunc: func(string) bool { return true }, })) app.All("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) handler := app.Handler() ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(tc.Method) ctx.Request.Header.Set(fiber.HeaderOrigin, tc.Origin) if tc.SetPreflightMethod { ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) } handler(ctx) if tc.ExpectAllowed { require.Equal(t, "null", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } else { require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) } if tc.Method == fiber.MethodOptions { require.Equal(t, fiber.StatusNoContent, ctx.Response.StatusCode()) } else { require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode()) } }) } } func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { testCases := []struct { Name string RequestOrigin string ResponseOrigin string Config Config }{ { Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginAllowed", Config: Config{ AllowOrigins: []string{"http://aaa.com"}, AllowOriginsFunc: nil, }, RequestOrigin: "http://aaa.com", ResponseOrigin: "http://aaa.com", }, { Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/OriginAllowed", Config: Config{ AllowOrigins: []string{"http://aaa.com", "http://bbb.com"}, AllowOriginsFunc: nil, }, RequestOrigin: "http://bbb.com", ResponseOrigin: "http://bbb.com", }, { Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/OriginNotAllowed", Config: Config{ AllowOrigins: []string{"http://aaa.com", "http://bbb.com"}, AllowOriginsFunc: nil, }, RequestOrigin: "http://ccc.com", ResponseOrigin: "", }, { Name: 
"AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/Whitespace/OriginAllowed", Config: Config{ AllowOrigins: []string{" http://aaa.com ", " http://bbb.com "}, AllowOriginsFunc: nil, }, RequestOrigin: "http://aaa.com", ResponseOrigin: "http://aaa.com", }, { Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginNotAllowed", Config: Config{ AllowOrigins: []string{"http://aaa.com"}, AllowOriginsFunc: nil, }, RequestOrigin: "http://bbb.com", ResponseOrigin: "", }, { Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginAllowed", Config: Config{ AllowOrigins: []string{"http://aaa.com"}, AllowOriginsFunc: func(_ string) bool { return true }, }, RequestOrigin: "http://aaa.com", ResponseOrigin: "http://aaa.com", }, { Name: "AllowOriginsDefined/AllowOriginsFuncReturnsTrue/OriginNotAllowed", Config: Config{ AllowOrigins: []string{"http://aaa.com"}, AllowOriginsFunc: func(_ string) bool { return true }, }, RequestOrigin: "http://bbb.com", ResponseOrigin: "http://bbb.com", }, { Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginAllowed", Config: Config{ AllowOrigins: []string{"http://aaa.com"}, AllowOriginsFunc: func(_ string) bool { return false }, }, RequestOrigin: "http://aaa.com", ResponseOrigin: "http://aaa.com", }, { Name: "AllowOriginsDefined/AllowOriginsFuncReturnsFalse/OriginNotAllowed", Config: Config{ AllowOrigins: []string{"http://aaa.com"}, AllowOriginsFunc: func(_ string) bool { return false }, }, RequestOrigin: "http://bbb.com", ResponseOrigin: "", }, { Name: "AllowOriginsEmpty/AllowOriginsFuncUndefined/OriginAllowed", Config: Config{ AllowOriginsFunc: nil, }, RequestOrigin: "http://aaa.com", ResponseOrigin: "*", }, { Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsTrue/OriginAllowed", Config: Config{ AllowOriginsFunc: func(_ string) bool { return true }, }, RequestOrigin: "http://aaa.com", ResponseOrigin: "http://aaa.com", }, { Name: "AllowOriginsEmpty/AllowOriginsFuncReturnsFalse/OriginNotAllowed", Config: Config{ 
AllowOriginsFunc: func(_ string) bool { return false }, }, RequestOrigin: "http://aaa.com", ResponseOrigin: "", }, } for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { app := fiber.New() app.Use("/", New(tc.Config)) handler := app.Handler() ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin) handler(ctx) require.Equal(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) }) } } // The fix for issue #2422 func Test_CORS_AllowCredentials(t *testing.T) { testCases := []struct { Name string RequestOrigin string ResponseOrigin string ResponseCredentials string Config Config }{ { Name: "AllowOriginsFuncDefined", Config: Config{ AllowCredentials: true, AllowOriginsFunc: func(_ string) bool { return true }, }, RequestOrigin: "http://aaa.com", // The AllowOriginsFunc config was defined, should use the real origin of the function ResponseOrigin: "http://aaa.com", ResponseCredentials: "true", }, { Name: "fiber-ghsa-fmg4-x8pw-hjhg-wildcard-credentials", Config: Config{ AllowCredentials: true, AllowOriginsFunc: func(_ string) bool { return true }, }, RequestOrigin: "*", ResponseOrigin: "", // Middleware will validate that wildcard won't set credentials to true and reject non-serialized origins ResponseCredentials: "", }, { Name: "AllowOriginsFuncNotDefined", Config: Config{ // Setting this to true will cause the middleware to panic since default AllowOrigins is "*" AllowCredentials: false, }, RequestOrigin: "http://aaa.com", // None of the AllowOrigins or AllowOriginsFunc config was defined, should use the default origin of "*" // which will cause the CORS error in the client: // The value of the 'Access-Control-Allow-Origin' header in the response must not be the wildcard '*' // when the request's credentials mode is 
'include'. ResponseOrigin: "*", ResponseCredentials: "", }, { Name: "AllowOriginsDefined", Config: Config{ AllowCredentials: true, AllowOrigins: []string{"http://aaa.com"}, }, RequestOrigin: "http://aaa.com", ResponseOrigin: "http://aaa.com", ResponseCredentials: "true", }, { Name: "AllowOriginsDefined/UnallowedOrigin", Config: Config{ AllowCredentials: true, AllowOrigins: []string{"http://aaa.com"}, }, RequestOrigin: "http://bbb.com", ResponseOrigin: "", ResponseCredentials: "", }, } for _, tc := range testCases { t.Run(tc.Name, func(t *testing.T) { app := fiber.New() app.Use("/", New(tc.Config)) handler := app.Handler() ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, tc.RequestOrigin) handler(ctx) require.Equal(t, tc.ResponseCredentials, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials))) require.Equal(t, tc.ResponseOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) }) } } // The Enhancement for issue #2804 func Test_CORS_AllowPrivateNetwork(t *testing.T) { t.Parallel() // Test scenario where AllowPrivateNetwork is enabled app := fiber.New() app.Use(New(Config{ AllowPrivateNetwork: true, })) handler := app.Handler() ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com") ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set("Access-Control-Request-Private-Network", "true") handler(ctx) // Verify the Access-Control-Allow-Private-Network header is set to "true" require.Equal(t, "true", string(ctx.Response.Header.Peek("Access-Control-Allow-Private-Network")), "The Access-Control-Allow-Private-Network header should be set to 'true' when AllowPrivateNetwork is enabled") // Non-preflight request 
should not have Access-Control-Allow-Private-Network header ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com") ctx.Request.Header.Set("Access-Control-Request-Private-Network", "true") handler(ctx) require.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Private-Network")), "The Access-Control-Allow-Private-Network header should be set to 'true' when AllowPrivateNetwork is enabled") // Non-preflight GET request should not have Access-Control-Allow-Private-Network header require.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Private-Network")), "The Access-Control-Allow-Private-Network header should be set to 'true' when AllowPrivateNetwork is enabled") // Non-preflight OPTIONS request should not have Access-Control-Allow-Private-Network header ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com") ctx.Request.Header.Set("Access-Control-Request-Private-Network", "true") handler(ctx) require.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Private-Network")), "The Access-Control-Allow-Private-Network header should be set to 'true' when AllowPrivateNetwork is enabled") // Reset ctx for next test ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com") // Test scenario where AllowPrivateNetwork is disabled (default) app = fiber.New() app.Use(New()) handler = app.Handler() handler(ctx) // Verify the Access-Control-Allow-Private-Network header is not present require.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Private-Network")), "The Access-Control-Allow-Private-Network header should not be present by default") // Test scenario where 
AllowPrivateNetwork is disabled but client sends header app = fiber.New() app.Use(New()) handler = app.Handler() ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com") ctx.Request.Header.Set("Access-Control-Request-Private-Network", "true") handler(ctx) // Verify the Access-Control-Allow-Private-Network header is not present require.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Private-Network")), "The Access-Control-Allow-Private-Network header should not be present by default") // Test scenario where AllowPrivateNetwork is enabled and client does NOT send header app = fiber.New() app.Use(New(Config{ AllowPrivateNetwork: true, })) handler = app.Handler() ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com") handler(ctx) // Verify the Access-Control-Allow-Private-Network header is not present require.Empty(t, string(ctx.Response.Header.Peek("Access-Control-Allow-Private-Network")), "The Access-Control-Allow-Private-Network header should not be present by default") // Test scenario where AllowPrivateNetwork is enabled and client sends header with false value app = fiber.New() app.Use(New(Config{ AllowPrivateNetwork: true, })) handler = app.Handler() ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodOptions) ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com") ctx.Request.Header.Set("Access-Control-Request-Private-Network", "false") handler(ctx) // Verify the Access-Control-Allow-Private-Network header is not present require.Empty(t, 
string(ctx.Response.Header.Peek("Access-Control-Allow-Private-Network")), "The Access-Control-Allow-Private-Network header should not be present by default") } // go test -v -run=^$ -bench=Benchmark_CORS_NewHandler -benchmem -count=4 func Benchmark_CORS_NewHandler(b *testing.B) { app := fiber.New() c := New(Config{ AllowOrigins: []string{"http://localhost", "http://example.com"}, AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, AllowCredentials: true, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://localhost") req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) b.ReportAllocs() for b.Loop() { h(ctx) } } // go test -v -run=^$ -bench=Benchmark_CORS_NewHandler_Parallel -benchmem -count=4 func Benchmark_CORS_NewHandler_Parallel(b *testing.B) { app := fiber.New() c := New(Config{ AllowOrigins: []string{"http://localhost", "http://example.com"}, AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, AllowCredentials: true, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { ctx := &fasthttp.RequestCtx{} req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://localhost") req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) for pb.Next() { h(ctx) } }) } // go test 
-v -run=^$ -bench=Benchmark_CORS_NewHandlerSingleOrigin -benchmem -count=4 func Benchmark_CORS_NewHandlerSingleOrigin(b *testing.B) { app := fiber.New() c := New(Config{ AllowOrigins: []string{"http://example.com"}, AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, AllowCredentials: true, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) b.ReportAllocs() for b.Loop() { h(ctx) } } // go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerSingleOrigin_Parallel -benchmem -count=4 func Benchmark_CORS_NewHandlerSingleOrigin_Parallel(b *testing.B) { app := fiber.New() c := New(Config{ AllowOrigins: []string{"http://example.com"}, AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, AllowCredentials: true, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { ctx := &fasthttp.RequestCtx{} req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) for pb.Next() { h(ctx) } }) } // go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerWildcard -benchmem -count=4 func Benchmark_CORS_NewHandlerWildcard(b *testing.B) { app := fiber.New() c := 
New(Config{ AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, AllowCredentials: false, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) b.ReportAllocs() for b.Loop() { h(ctx) } } // go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerWildcard_Parallel -benchmem -count=4 func Benchmark_CORS_NewHandlerWildcard_Parallel(b *testing.B) { app := fiber.New() c := New(Config{ AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, AllowCredentials: false, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { ctx := &fasthttp.RequestCtx{} req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodGet) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) for pb.Next() { h(ctx) } }) } // go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflight -benchmem -count=4 func Benchmark_CORS_NewHandlerPreflight(b *testing.B) { app := fiber.New() c := New(Config{ AllowOrigins: []string{"http://localhost", "http://example.com"}, AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, 
AllowCredentials: true, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Preflight request req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodOptions) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) b.ReportAllocs() for b.Loop() { h(ctx) } } // go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflight_Parallel -benchmem -count=4 func Benchmark_CORS_NewHandlerPreflight_Parallel(b *testing.B) { app := fiber.New() c := New(Config{ AllowOrigins: []string{"http://localhost", "http://example.com"}, AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, AllowCredentials: true, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { ctx := &fasthttp.RequestCtx{} req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodOptions) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) for pb.Next() { h(ctx) } }) } // go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightSingleOrigin -benchmem -count=4 func Benchmark_CORS_NewHandlerPreflightSingleOrigin(b *testing.B) { app := fiber.New() c := New(Config{ AllowOrigins: []string{"http://example.com"}, AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: 
[]string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, AllowCredentials: true, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodOptions) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) b.ReportAllocs() for b.Loop() { h(ctx) } } // go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightSingleOrigin_Parallel -benchmem -count=4 func Benchmark_CORS_NewHandlerPreflightSingleOrigin_Parallel(b *testing.B) { app := fiber.New() c := New(Config{ AllowOrigins: []string{"http://example.com"}, AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, AllowCredentials: true, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { ctx := &fasthttp.RequestCtx{} req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodOptions) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) for pb.Next() { h(ctx) } }) } // go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightWildcard -benchmem -count=4 func Benchmark_CORS_NewHandlerPreflightWildcard(b *testing.B) { app := fiber.New() c := New(Config{ AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: 
[]string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, AllowCredentials: false, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodOptions) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) b.ReportAllocs() for b.Loop() { h(ctx) } } // go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightWildcard_Parallel -benchmem -count=4 func Benchmark_CORS_NewHandlerPreflightWildcard_Parallel(b *testing.B) { app := fiber.New() c := New(Config{ AllowMethods: []string{fiber.MethodGet, fiber.MethodPost, fiber.MethodPut, fiber.MethodDelete}, AllowHeaders: []string{fiber.HeaderOrigin, fiber.HeaderContentType, fiber.HeaderAccept}, AllowCredentials: false, MaxAge: 600, }) app.Use(c) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { ctx := &fasthttp.RequestCtx{} req := &fasthttp.Request{} req.Header.SetMethod(fiber.MethodOptions) req.SetRequestURI("/") req.Header.Set(fiber.HeaderOrigin, "http://example.com") req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") ctx.Init(req, nil, nil) for pb.Next() { h(ctx) } }) } ================================================ FILE: middleware/cors/utils.go ================================================ package cors import ( "net/url" "strings" utilsstrings "github.com/gofiber/utils/v2/strings" ) // matchScheme compares the scheme of the domain and pattern func matchScheme(domain, pattern string) bool { dScheme, _, dFound := strings.Cut(domain, 
":") pScheme, _, pFound := strings.Cut(pattern, ":") return dFound && pFound && dScheme == pScheme } // normalizeDomain removes the scheme and port from the input domain func normalizeDomain(input string) string { // Remove scheme if after, found := strings.CutPrefix(input, "https://"); found { input = after } else if after, found := strings.CutPrefix(input, "http://"); found { input = after } // Find and remove port, if present if input != "" && input[0] != '[' { if before, _, found := strings.Cut(input, ":"); found { input = before } } return input } // normalizeOrigin checks if the provided origin is in a correct format // and normalizes it by removing any path or trailing slash. // It returns a boolean indicating whether the origin is valid // and the normalized origin. func normalizeOrigin(origin string) (valid bool, normalized string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming validity and normalized origin results parsedOrigin, err := url.Parse(origin) if err != nil { return false, "" } // Don't allow a wildcard with a protocol // wildcards cannot be used within any other value. For example, the following header is not valid: // Access-Control-Allow-Origin: https://* if strings.IndexByte(parsedOrigin.Host, '*') >= 0 { return false, "" } // Validate there is a host present. The presence of a path, query, or fragment components // is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized if parsedOrigin.User != nil || parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" { return false, "" } // Normalize the origin by constructing it from the scheme and host. // The path or trailing slash is not included in the normalized origin. 
return true, utilsstrings.ToLower(parsedOrigin.Scheme) + "://" + utilsstrings.ToLower(parsedOrigin.Host) } type subdomain struct { // The wildcard pattern prefix string suffix string } func (s subdomain) match(o string) bool { // Not a subdomain if not long enough for a dot separator. if len(o) < len(s.prefix)+len(s.suffix)+1 { return false } if !strings.HasPrefix(o, s.prefix) || !strings.HasSuffix(o, s.suffix) { return false } // Check for the dot separator and validate that there is at least one // non-empty label between prefix and suffix. Empty labels like // "https://.example.com" or "https://..example.com" should not match. suffixStartIndex := len(o) - len(s.suffix) if suffixStartIndex <= len(s.prefix) { return false } if o[suffixStartIndex-1] != '.' { return false } // Extract the subdomain part (without the trailing dot) and ensure it // doesn't contain empty labels. sub := o[len(s.prefix) : suffixStartIndex-1] if sub == "" || strings.HasPrefix(sub, ".") || strings.Contains(sub, "..") { return false } return true } ================================================ FILE: middleware/cors/utils_test.go ================================================ package cors import ( "testing" "github.com/stretchr/testify/assert" ) // go test -run -v Test_NormalizeOrigin func Test_NormalizeOrigin(t *testing.T) { testCases := []struct { origin string expectedOrigin string expectedValid bool }{ {origin: "http://example.com", expectedValid: true, expectedOrigin: "http://example.com"}, // Simple case should work. {origin: "http://example.com/", expectedValid: true, expectedOrigin: "http://example.com"}, // Trailing slash should be removed. {origin: "http://example.com:3000", expectedValid: true, expectedOrigin: "http://example.com:3000"}, // Port should be preserved. {origin: "http://example.com:3000/", expectedValid: true, expectedOrigin: "http://example.com:3000"}, // Trailing slash should be removed. 
{origin: "app://example.com/", expectedValid: true, expectedOrigin: "app://example.com"}, // App scheme should be accepted.
		{origin: "http://", expectedValid: false, expectedOrigin: ""},                                     // Invalid origin should not be accepted.
		{origin: "file:///etc/passwd", expectedValid: false, expectedOrigin: ""},                          // File scheme should not be accepted.
		{origin: "https://*example.com", expectedValid: false, expectedOrigin: ""},                        // Wildcard domain should not be accepted.
		{origin: "http://*.example.com", expectedValid: false, expectedOrigin: ""},                        // Wildcard subdomain should not be accepted.
		{origin: "http://example.com/path", expectedValid: false, expectedOrigin: ""},                     // Path should not be accepted.
		{origin: "http://example.com?query=123", expectedValid: false, expectedOrigin: ""},                // Query should not be accepted.
		{origin: "http://example.com#fragment", expectedValid: false, expectedOrigin: ""},                 // Fragment should not be accepted.
		{origin: "http://user:pass@example.com", expectedValid: false, expectedOrigin: ""},                // Userinfo should not be accepted.
		{origin: "http://localhost", expectedValid: true, expectedOrigin: "http://localhost"},             // Localhost should be accepted.
		{origin: "http://127.0.0.1", expectedValid: true, expectedOrigin: "http://127.0.0.1"},             // IPv4 address should be accepted.
		{origin: "http://[::1]", expectedValid: true, expectedOrigin: "http://[::1]"},                     // IPv6 address should be accepted.
		{origin: "http://[::1]:8080", expectedValid: true, expectedOrigin: "http://[::1]:8080"},           // IPv6 address with port should be accepted.
		{origin: "http://[::1]:8080/", expectedValid: true, expectedOrigin: "http://[::1]:8080"},          // IPv6 address with port and trailing slash should be accepted.
		{origin: "http://[::1]:8080/path", expectedValid: false, expectedOrigin: ""},                      // IPv6 address with port and path should not be accepted.
		{origin: "http://[::1]:8080?query=123", expectedValid: false, expectedOrigin: ""},                 // IPv6 address with port and query should not be accepted.
		{origin: "http://[::1]:8080#fragment", expectedValid: false, expectedOrigin: ""},                  // IPv6 address with port and fragment should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment", expectedValid: false, expectedOrigin: ""},   // IPv6 address with port, path, query, and fragment should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/", expectedValid: false, expectedOrigin: ""},  // IPv6 address with port, path, query, fragment, and trailing slash should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/invalid", expectedValid: false, expectedOrigin: ""},         // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/invalid/", expectedValid: false, expectedOrigin: ""},        // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with trailing slash should not be accepted.
		{origin: "http://[::1]:8080/path?query=123#fragment/invalid/segment", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with additional segment should not be accepted.
	}

	// Check both results of normalizeOrigin for every case; Errorf (not Fatalf)
	// so that all failing cases are reported in one run.
	for _, tc := range testCases {
		valid, normalizedOrigin := normalizeOrigin(tc.origin)

		if valid != tc.expectedValid {
			t.Errorf("Expected origin '%s' to be valid: %v, but got: %v", tc.origin, tc.expectedValid, valid)
		}

		if normalizedOrigin != tc.expectedOrigin {
			t.Errorf("Expected normalized origin '%s' for origin '%s', but got: '%s'", tc.expectedOrigin, tc.origin, normalizedOrigin)
		}
	}
}

// Test_MatchScheme checks that matchScheme compares only the scheme part:
// host and port differences are irrelevant, scheme differences are not.
//
// go test -v -run Test_MatchScheme
func Test_MatchScheme(t *testing.T) {
	testCases := []struct {
		domain   string
		pattern  string
		expected bool
	}{
		{domain: "http://example.com", pattern: "http://example.com", expected: true},           // Exact match should work.
		{domain: "https://example.com", pattern: "http://example.com", expected: false},         // Scheme mismatch should matter.
		{domain: "http://example.com", pattern: "https://example.com", expected: false},         // Scheme mismatch should matter.
		{domain: "http://example.com", pattern: "http://example.org", expected: true},           // Different domains should not matter.
		{domain: "http://example.com", pattern: "http://example.com:8080", expected: true},      // Port should not matter.
		{domain: "http://example.com:8080", pattern: "http://example.com", expected: true},      // Port should not matter.
		{domain: "http://example.com:8080", pattern: "http://example.com:8081", expected: true}, // Different ports should not matter.
		{domain: "http://localhost", pattern: "http://localhost", expected: true},               // Localhost should match.
		{domain: "http://127.0.0.1", pattern: "http://127.0.0.1", expected: true},               // IPv4 address should match.
		{domain: "http://[::1]", pattern: "http://[::1]", expected: true},                       // IPv6 address should match.
	}

	for _, tc := range testCases {
		result := matchScheme(tc.domain, tc.pattern)

		if result != tc.expected {
			t.Errorf("Expected matchScheme('%s', '%s') to be %v, but got %v", tc.domain, tc.pattern, tc.expected, result)
		}
	}
}

// Test_NormalizeDomain checks scheme and port stripping by normalizeDomain.
// Note that path, query and fragment are expected to be kept as-is.
//
// go test -v -run Test_NormalizeDomain
func Test_NormalizeDomain(t *testing.T) {
	testCases := []struct {
		input          string
		expectedOutput string
	}{
		{input: "http://example.com", expectedOutput: "example.com"},                     // Simple case with http scheme.
		{input: "https://example.com", expectedOutput: "example.com"},                    // Simple case with https scheme.
		{input: "http://example.com:3000", expectedOutput: "example.com"},                // Case with port.
		{input: "https://example.com:3000", expectedOutput: "example.com"},               // Case with port and https scheme.
		{input: "http://example.com/path", expectedOutput: "example.com/path"},           // Case with path.
		{input: "http://example.com?query=123", expectedOutput: "example.com?query=123"}, // Case with query.
		{input: "http://example.com#fragment", expectedOutput: "example.com#fragment"},   // Case with fragment.
		{input: "example.com", expectedOutput: "example.com"},                            // Case without scheme.
		{input: "example.com:8080", expectedOutput: "example.com"},                       // Case without scheme but with port.
		{input: "sub.example.com", expectedOutput: "sub.example.com"},                    // Case with subdomain.
		{input: "sub.sub.example.com", expectedOutput: "sub.sub.example.com"},            // Case with nested subdomain.
		{input: "http://localhost", expectedOutput: "localhost"},                         // Case with localhost.
		{input: "http://127.0.0.1", expectedOutput: "127.0.0.1"},                         // Case with IPv4 address.
		{input: "http://[::1]", expectedOutput: "[::1]"},                                 // Case with IPv6 address.
	}

	for _, tc := range testCases {
		output := normalizeDomain(tc.input)

		if output != tc.expectedOutput {
			t.Errorf("Expected normalized domain '%s' for input '%s', but got: '%s'", tc.expectedOutput, tc.input, output)
		}
	}
}

// Benchmark_CORS_SubdomainMatch measures subdomain.match.
// NOTE(review): with prefix "www" and suffix "example.com", the origin
// "www.example.com" leaves an empty wildcard label between prefix and the
// '.' separator, so match returns false here — this benchmarks the rejection
// path, not a successful match. Confirm that is intended.
//
// go test -v -run=^$ -bench=Benchmark_CORS_SubdomainMatch -benchmem -count=4
func Benchmark_CORS_SubdomainMatch(b *testing.B) {
	s := subdomain{
		prefix: "www",
		suffix: "example.com",
	}

	o := "www.example.com"

	b.ReportAllocs()
	for b.Loop() {
		s.match(o)
	}
}

// Test_CORS_SubdomainMatch is a table test for subdomain.match covering scheme
// mismatches, valid (nested) subdomains, wrong prefixes/suffixes, an empty
// origin, a suffix glued to another label, and empty host labels.
func Test_CORS_SubdomainMatch(t *testing.T) {
	tests := []struct {
		name     string
		sub      subdomain
		origin   string
		expected bool
	}{
		{
			name:     "match with different scheme",
			sub:      subdomain{prefix: "http://api.", suffix: "example.com"},
			origin:   "https://api.service.example.com",
			expected: false,
		},
		{
			name:     "match with different scheme",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "http://api.service.example.com",
			expected: false,
		},
		{
			name:     "match with valid subdomain",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://api.service.example.com",
			expected: true,
		},
		{
			name:     "match with valid nested subdomain",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://1.2.api.service.example.com",
			expected: true,
		},
		{
			name:     "no match with invalid prefix",
			sub:      subdomain{prefix: "https://abc.", suffix: "example.com"},
			origin:   "https://service.example.com",
			expected: false,
		},
		{
			name:     "no match with invalid suffix",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://api.example.org",
			expected: false,
		},
		{
			name:     "no match with empty origin",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "",
			expected: false,
		},
		{
			name:     "no match with malformed subdomain",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://evil.comexample.com",
			expected: false,
		},
		{
			name:     "partial match not considered a match",
			sub:      subdomain{prefix: "https://service.", suffix: "example.com"},
			origin:   "https://api.example.com",
			expected: false,
		},
		{
			name:     "no match with empty host label",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://.example.com",
			expected: false,
		},
		{
			name:     "no match with malformed host label",
			sub:      subdomain{prefix: "https://", suffix: "example.com"},
			origin:   "https://..example.com",
			expected: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.sub.match(tt.origin)
			assert.Equal(t, tt.expected, got, "subdomain.match()")
		})
	}
}

================================================ FILE: middleware/csrf/config.go ================================================
package csrf

import (
	"fmt"
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/fiber/v3/extractors"
	"github.com/gofiber/fiber/v3/log"
	"github.com/gofiber/fiber/v3/middleware/session"
	"github.com/gofiber/utils/v2"
)

// Config defines the config for CSRF middleware.
type Config struct {
	// Storage is used to store the state of the middleware.
	//
	// Optional. Default: memory.New()
	// Ignored if Session is set.
	Storage fiber.Storage

	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// Session is used to store the state of the middleware.
	//
	// Optional. Default: nil
	// If set, the middleware will use the session store instead of the storage.
	Session *session.Store

	// KeyGenerator creates a new CSRF token.
	//
	// Optional. Default: utils.SecureToken
	KeyGenerator func() string

	// ErrorHandler is executed when an error is returned from fiber.Handler.
	//
	// Optional. Default: defaultErrorHandler
	ErrorHandler fiber.ErrorHandler

	// CookieName is the name of the CSRF cookie.
	//
	// Optional. Default: "csrf_"
	CookieName string

	// CookieDomain is the domain of the CSRF cookie.
	//
	// Optional. Default: ""
	CookieDomain string

	// CookiePath is the path of the CSRF cookie.
	//
	// Optional. Default: ""
	CookiePath string

	// CookieSameSite is the SameSite attribute of the CSRF cookie.
	//
	// Optional. Default: "Lax"
	CookieSameSite string

	// TrustedOrigins is a list of trusted origins for unsafe requests.
	// For requests that use the Origin header, the origin must match the
	// Host header or one of the TrustedOrigins.
	// For secure requests that do not include the Origin header, the Referer
	// header must match the Host header or one of the TrustedOrigins.
	//
	// This supports matching subdomains at any level. This means you can use a value like
	// "https://*.example.com" to allow any subdomain of example.com to submit requests,
	// including multiple subdomain levels such as "https://sub.sub.example.com".
	//
	// Optional. Default: []
	TrustedOrigins []string

	// Extractor returns the CSRF token from the request.
	//
	// Optional. Default: extractors.FromHeader("X-Csrf-Token")
	//
	// Available extractors from github.com/gofiber/fiber/v3/extractors:
	// - extractors.FromHeader("X-Csrf-Token"): Most secure, recommended for APIs
	// - extractors.FromForm("_csrf"): Secure, recommended for form submissions
	// - extractors.FromQuery("csrf_token"): Less secure, URLs may be logged
	// - extractors.FromParam("csrf"): Less secure, URLs may be logged
	// - extractors.Chain(...): Advanced chaining of multiple extractors
	//
	// See the Extractors Guide for complete documentation:
	// https://docs.gofiber.io/guide/extractors
	//
	// WARNING: Never create custom extractors that read from cookies with the same
	// CookieName as this defeats CSRF protection entirely.
	Extractor extractors.Extractor

	// IdleTimeout is the duration of time the CSRF token is valid.
	//
	// Optional. Default: 30 * time.Minute
	IdleTimeout time.Duration

	// DisableValueRedaction turns off masking CSRF tokens and storage keys in logs and errors.
	//
	// Optional. Default: false
	DisableValueRedaction bool

	// CookieSecure indicates if CSRF cookie is secure.
	//
	// Optional. Default: false
	CookieSecure bool

	// CookieHTTPOnly indicates if CSRF cookie is HTTP only.
	//
	// Optional. Default: false
	CookieHTTPOnly bool

	// CookieSessionOnly decides whether the cookie should last only for the browser session.
	// NOTE(review): this used to say "Ignores Expiration if set to true", but Config has no
	// Expiration field; presumably the cookie is then sent without an expiry derived from
	// IdleTimeout — confirm against the middleware implementation.
	//
	// Optional. Default: false
	CookieSessionOnly bool

	// SingleUseToken indicates if the CSRF token should be destroyed
	// and a new one generated on each use.
	//
	// Optional. Default: false
	SingleUseToken bool
}

// HeaderName is the default header name for CSRF tokens.
const HeaderName = "X-Csrf-Token"

// ConfigDefault is the default config for CSRF middleware.
var ConfigDefault = Config{ CookieName: "csrf_", CookieSameSite: "Lax", IdleTimeout: 30 * time.Minute, KeyGenerator: utils.SecureToken, ErrorHandler: defaultErrorHandler, Extractor: extractors.FromHeader(HeaderName), DisableValueRedaction: false, } // defaultErrorHandler is the default error handler that processes errors from fiber.Handler. func defaultErrorHandler(_ fiber.Ctx, _ error) error { return fiber.ErrForbidden } // configDefault is a helper function to set default values. func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.IdleTimeout <= 0 { cfg.IdleTimeout = ConfigDefault.IdleTimeout } if cfg.CookieName == "" { cfg.CookieName = ConfigDefault.CookieName } if cfg.CookieSameSite == "" { cfg.CookieSameSite = ConfigDefault.CookieSameSite } if cfg.KeyGenerator == nil { cfg.KeyGenerator = ConfigDefault.KeyGenerator } if cfg.ErrorHandler == nil { cfg.ErrorHandler = ConfigDefault.ErrorHandler } // Check if Extractor is zero value (since it's a struct) if cfg.Extractor.Extract == nil { cfg.Extractor = ConfigDefault.Extractor } // Validate extractor security configurations validateExtractorSecurity(&cfg) return cfg } // validateExtractorSecurity checks for insecure extractor configurations func validateExtractorSecurity(cfg *Config) { if cfg == nil { return } // Check primary extractor if isInsecureCookieExtractor(cfg.Extractor, cfg.CookieName) { panic("CSRF: Extractor reads from the same cookie '" + cfg.CookieName + "' used for token storage. This completely defeats CSRF protection.") } // Check chained extractors for i, extractor := range cfg.Extractor.Chain { if isInsecureCookieExtractor(extractor, cfg.CookieName) { panic(fmt.Sprintf("CSRF: Chained extractor #%d reads from the same cookie '%s' "+ "used for token storage. 
This completely defeats CSRF protection.", i+1, cfg.CookieName)) } } // Additional security warnings (non-fatal) if cfg.Extractor.Source == extractors.SourceQuery || cfg.Extractor.Source == extractors.SourceParam { log.Warnf("[CSRF WARNING] Using %v extractor - URLs may be logged", cfg.Extractor.Source) } } // isInsecureCookieExtractor checks if an extractor unsafely reads from the CSRF cookie func isInsecureCookieExtractor(extractor extractors.Extractor, cookieName string) bool { if extractor.Source == extractors.SourceCookie { // Exact match - definitely insecure if extractor.Key == cookieName { return true } // Case-insensitive match - potentially confusing, warn but don't panic if utils.EqualFold(extractor.Key, cookieName) && extractor.Key != cookieName { log.Warnf("[CSRF WARNING] Extractor cookie name '%s' is similar to CSRF cookie '%s' - this may be confusing", extractor.Key, cookieName) } } return false } ================================================ FILE: middleware/csrf/config_test.go ================================================ package csrf import ( "fmt" "strings" "testing" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) // Test security validation functions func Test_CSRF_ExtractorSecurity_Validation(t *testing.T) { t.Parallel() // Test secure configurations - should not panic t.Run("SecureConfigurations", func(t *testing.T) { t.Parallel() secureConfigs := []Config{ {Extractor: extractors.FromHeader("X-Csrf-Token")}, {Extractor: extractors.FromForm("_csrf")}, {Extractor: extractors.FromQuery("csrf_token")}, {Extractor: extractors.FromParam("csrf")}, {Extractor: extractors.Chain(extractors.FromHeader("X-Csrf-Token"), extractors.FromForm("_csrf"))}, } for i, cfg := range secureConfigs { t.Run(fmt.Sprintf("Config%d", i), func(t *testing.T) { require.NotPanics(t, func() { configDefault(cfg) }) }) } }) // Test insecure configurations - should panic 
t.Run("InsecureCookieExtractor", func(t *testing.T) { t.Parallel() // Create a custom extractor that reads from cookie (simulating dangerous behavior) insecureCookieExtractor := extractors.Extractor{ Extract: func(c fiber.Ctx) (string, error) { return c.Cookies("csrf_"), nil }, Source: extractors.SourceCookie, Key: "csrf_", } cfg := Config{ CookieName: "csrf_", Extractor: insecureCookieExtractor, } require.Panics(t, func() { configDefault(cfg) }, "Should panic when extractor reads from same cookie") }) // Test insecure chained extractors t.Run("InsecureChainedExtractor", func(t *testing.T) { t.Parallel() insecureCookieExtractor := extractors.Extractor{ Extract: func(c fiber.Ctx) (string, error) { return c.Cookies("csrf_"), nil }, Source: extractors.SourceCookie, Key: "csrf_", } chainedExtractor := extractors.Chain( extractors.FromHeader("X-Csrf-Token"), insecureCookieExtractor, // This should trigger panic ) cfg := Config{ CookieName: "csrf_", Extractor: chainedExtractor, } require.Panics(t, func() { configDefault(cfg) }, "Should panic when chained extractor reads from same cookie") }) // Test different cookie names - should be secure t.Run("DifferentCookieNames", func(t *testing.T) { t.Parallel() cookieExtractor := extractors.Extractor{ Extract: func(c fiber.Ctx) (string, error) { return c.Cookies("different_cookie"), nil }, Source: extractors.SourceCookie, Key: "different_cookie", } cfg := Config{ CookieName: "csrf_", Extractor: cookieExtractor, } require.NotPanics(t, func() { configDefault(cfg) }, "Should not panic when extractor reads from different cookie") }) } // Test extractor metadata func Test_CSRF_Extractor_Metadata(t *testing.T) { t.Parallel() testCases := []struct { name string expectedKey string extractor extractors.Extractor expectedSource extractors.Source }{ { name: "FromHeader", extractor: extractors.FromHeader("X-Custom-Token"), expectedSource: extractors.SourceHeader, expectedKey: "X-Custom-Token", }, { name: "FromForm", extractor: 
extractors.FromForm("_token"), expectedSource: extractors.SourceForm, expectedKey: "_token", }, { name: "FromQuery", extractor: extractors.FromQuery("token"), expectedSource: extractors.SourceQuery, expectedKey: "token", }, { name: "FromParam", extractor: extractors.FromParam("id"), expectedSource: extractors.SourceParam, expectedKey: "id", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() require.Equal(t, tc.expectedSource, tc.extractor.Source) require.Equal(t, tc.expectedKey, tc.extractor.Key) require.NotNil(t, tc.extractor.Extract) }) } } // Test chain extractor metadata func Test_CSRF_Chain_Extractor_Metadata(t *testing.T) { t.Parallel() t.Run("EmptyChain", func(t *testing.T) { t.Parallel() chained := extractors.Chain() require.Equal(t, extractors.SourceCustom, chained.Source) require.Empty(t, chained.Key) require.Empty(t, chained.Chain) }) t.Run("SingleExtractor", func(t *testing.T) { t.Parallel() header := extractors.FromHeader("X-Token") chained := extractors.Chain(header) require.Equal(t, extractors.SourceHeader, chained.Source) require.Equal(t, "X-Token", chained.Key) require.Len(t, chained.Chain, 1) }) t.Run("MultipleExtractors", func(t *testing.T) { t.Parallel() header := extractors.FromHeader("X-Token") form := extractors.FromForm("_csrf") chained := extractors.Chain(header, form) // Should use first extractor's metadata require.Equal(t, extractors.SourceHeader, chained.Source) require.Equal(t, "X-Token", chained.Key) require.Len(t, chained.Chain, 2) require.Equal(t, header.Source, chained.Chain[0].Source) require.Equal(t, form.Source, chained.Chain[1].Source) }) } // Test custom extractor with new struct pattern func Test_CSRF_Custom_Extractor_Struct(t *testing.T) { t.Parallel() app := fiber.New() // Custom extractor using new struct pattern customExtractor := extractors.Extractor{ Extract: func(c fiber.Ctx) (string, error) { // Extract from custom header token := c.Get("X-Custom-CSRF") if token == "" { return "", 
extractors.ErrNotFound } return token, nil }, Source: extractors.SourceCustom, Key: "X-Custom-CSRF", } app.Use(New(Config{Extractor: customExtractor})) app.Post("/", func(c fiber.Ctx) error { return c.SendString("OK") }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Generate CSRF token ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] // Test with custom header ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set("X-Custom-CSRF", token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test without custom header ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) } // Test error types for different extractors func Test_CSRF_Extractor_Error_Types(t *testing.T) { t.Parallel() testCases := []struct { expectedError error setupRequest func(*fasthttp.RequestCtx) name string extractor extractors.Extractor }{ { name: "MissingHeader", extractor: extractors.FromHeader("X-Missing"), setupRequest: func(_ *fasthttp.RequestCtx) { // Don't set the header }, expectedError: extractors.ErrNotFound, }, { name: "MissingForm", extractor: extractors.FromForm("_missing"), setupRequest: func(ctx *fasthttp.RequestCtx) { ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm) // Don't set form data }, expectedError: extractors.ErrNotFound, }, { name: "MissingQuery", extractor: extractors.FromQuery("missing"), setupRequest: func(ctx *fasthttp.RequestCtx) { ctx.Request.SetRequestURI("/") // Don't set query param }, expectedError: extractors.ErrNotFound, }, { name: "MissingParam", extractor: extractors.FromParam("missing"), setupRequest: func(_ *fasthttp.RequestCtx) 
{ // This would need special route setup to test properly // For now, we'll test the extractor directly }, expectedError: extractors.ErrNotFound, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() app := fiber.New() ctx := &fasthttp.RequestCtx{} tc.setupRequest(ctx) c := app.AcquireCtx(ctx) _, err := tc.extractor.Extract(c) require.Error(t, err) require.Equal(t, tc.expectedError, err) app.ReleaseCtx(c) }) } } // Test security warning logs (would need to capture log output in real implementation) func Test_CSRF_Security_Warnings(t *testing.T) { t.Parallel() // Test that insecure extractors trigger warnings // Note: In a real implementation, you'd want to capture log output // For now, we just test that the configuration doesn't panic insecureConfigs := []Config{ {Extractor: extractors.FromQuery("csrf_token")}, {Extractor: extractors.FromParam("csrf")}, } for i, cfg := range insecureConfigs { t.Run(fmt.Sprintf("InsecureConfig%d", i), func(t *testing.T) { t.Parallel() require.NotPanics(t, func() { configDefault(cfg) }) }) } } // Test isInsecureCookieExtractor function directly func Test_isInsecureCookieExtractor(t *testing.T) { t.Parallel() testCases := []struct { name string cookieName string extractor extractors.Extractor expected bool }{ { name: "SecureHeaderExtractor", extractor: extractors.Extractor{ Source: extractors.SourceHeader, Key: "X-Csrf-Token", }, cookieName: "csrf_", expected: false, }, { name: "InsecureCookieExtractor", extractor: extractors.Extractor{ Source: extractors.SourceCookie, Key: "csrf_", }, cookieName: "csrf_", expected: true, }, { name: "CookieExtractorDifferentName", extractor: extractors.Extractor{ Source: extractors.SourceCookie, Key: "different_cookie", }, cookieName: "csrf_", expected: false, }, { name: "CustomExtractorSafeName", extractor: extractors.Extractor{ Source: extractors.SourceCustom, Key: "safe_key", }, cookieName: "csrf_", expected: false, }, } for _, tc := range testCases { t.Run(tc.name, 
func(t *testing.T) { t.Parallel() result := isInsecureCookieExtractor(tc.extractor, tc.cookieName) require.Equal(t, tc.expected, result) }) } } func Test_CSRF_CookieName_CaseInsensitive_Warning(t *testing.T) { t.Parallel() // Extractor uses "CSRF_" (uppercase), config uses "csrf_" (lowercase) extractor := extractors.Extractor{ Extract: func(c fiber.Ctx) (string, error) { return c.Cookies("CSRF_"), nil }, Source: extractors.SourceCookie, Key: "CSRF_", } cfg := Config{ CookieName: "csrf_", Extractor: extractor, } // Should not panic, but should log a warning require.NotPanics(t, func() { configDefault(cfg) }, "Should not panic for case-insensitive cookie name match, but should warn") } ================================================ FILE: middleware/csrf/csrf.go ================================================ package csrf import ( "errors" "fmt" "net/url" "slices" "strings" "time" "github.com/gofiber/utils/v2" utilsstrings "github.com/gofiber/utils/v2/strings" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" ) var ( ErrTokenNotFound = errors.New("csrf: token not found") ErrTokenInvalid = errors.New("csrf: token invalid") ErrFetchSiteInvalid = errors.New("csrf: sec-fetch-site header invalid") ErrRefererNotFound = errors.New("csrf: referer header missing") ErrRefererInvalid = errors.New("csrf: referer header invalid") ErrRefererNoMatch = errors.New("csrf: referer does not match host or trusted origins") ErrOriginInvalid = errors.New("csrf: origin header invalid") ErrOriginNoMatch = errors.New("csrf: origin does not match host or trusted origins") errOriginNotFound = errors.New("origin not supplied or is null") // internal error, will not be returned to the user dummyValue = []byte{'+'} // dummyValue is a placeholder value stored in token storage. The actual token validation relies on the key, not this value. 
) // Handler for CSRF middleware type Handler struct { sessionManager *sessionManager storageManager *storageManager config Config } // The contextKey type is unexported to prevent collisions with context keys defined in // other packages. type contextKey int // The keys for the values in context const ( tokenKey contextKey = iota handlerKey ) // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) redactKeys := !cfg.DisableValueRedaction maskValue := func(value string) string { if redactKeys { return redactedKey } return value } // Create manager to simplify storage operations ( see *_manager.go ) var sessionManager *sessionManager var storageManager *storageManager if cfg.Session != nil { sessionManager = newSessionManager(cfg.Session) } else { storageManager = newStorageManager(cfg.Storage, redactKeys) } // Pre-parse trusted origins trustedOrigins := []string{} trustedSubOrigins := []subdomain{} for _, origin := range cfg.TrustedOrigins { trimmedOrigin := utils.TrimSpace(origin) if i := strings.Index(trimmedOrigin, "://*."); i != -1 { withoutWildcard := trimmedOrigin[:i+len("://")] + trimmedOrigin[i+len("://*."):] isValid, normalizedOrigin := normalizeOrigin(withoutWildcard) if !isValid { panic("[CSRF] Invalid origin format in configuration:" + maskValue(origin)) } schemeSep := strings.Index(normalizedOrigin, "://") + len("://") sd := subdomain{prefix: normalizedOrigin[:schemeSep], suffix: normalizedOrigin[schemeSep:]} trustedSubOrigins = append(trustedSubOrigins, sd) } else { isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) if !isValid { panic("[CSRF] Invalid origin format in configuration:" + maskValue(origin)) } trustedOrigins = append(trustedOrigins, normalizedOrigin) } } // Create the handler outside of the returned function handler := &Handler{ config: cfg, sessionManager: sessionManager, storageManager: storageManager, } // Return new handler return func(c 
fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Store the CSRF handler in the context fiber.StoreInContext(c, handlerKey, handler) var token string // Action depends on the HTTP method switch c.Method() { case fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace: cookieToken := c.Cookies(cfg.CookieName) if cookieToken != "" { raw, err := getRawFromStorage(c, cookieToken, &cfg, sessionManager, storageManager) if err != nil { return cfg.ErrorHandler(c, err) } if raw != nil { token = cookieToken // Token is valid, safe to set it } } default: // Assume that anything not defined as 'safe' by RFC7231 needs protection // Evaluate Sec-Fetch-Site to reject cross-site requests earlier when available. if err := validateSecFetchSite(c); err != nil { return cfg.ErrorHandler(c, err) } // Enforce an origin check for unsafe requests. err := originMatchesHost(c, trustedOrigins, trustedSubOrigins) // If there's no origin, enforce a referer check for HTTPS connections. if errors.Is(err, errOriginNotFound) { if c.Scheme() == schemeHTTPS { err = refererMatchesHost(c, trustedOrigins, trustedSubOrigins) } else { // If it's not HTTPS, clear the error to allow the request to proceed. err = nil } } // If there's an error (either from origin check or referer check), handle it. if err != nil { return cfg.ErrorHandler(c, err) } // Extract token from client request i.e. header, query, param, form extractedToken, err := cfg.Extractor.Extract(c) if err != nil { if errors.Is(err, extractors.ErrNotFound) { return cfg.ErrorHandler(c, ErrTokenNotFound) } // If there's an error during extraction (other than not found), handle it. 
return cfg.ErrorHandler(c, err) } if extractedToken == "" { return cfg.ErrorHandler(c, ErrTokenNotFound) } // Double Submit Cookie validation: ensure the extracted token matches the cookie value // This prevents CSRF attacks by requiring attackers to know both the cookie AND submit // the same token through a different channel (header, form, etc.) // WARNING: If using a custom extractor that reads from the same cookie, this provides no protection if !compareStrings(extractedToken, c.Cookies(cfg.CookieName)) { return cfg.ErrorHandler(c, ErrTokenInvalid) } raw, err := getRawFromStorage(c, extractedToken, &cfg, sessionManager, storageManager) if err != nil { return cfg.ErrorHandler(c, err) } if raw == nil { // If token is not in storage, expire the cookie expireCSRFCookie(c, &cfg) // and return an error return cfg.ErrorHandler(c, ErrTokenNotFound) } if cfg.SingleUseToken { // If token is single use, delete it from storage if err := deleteTokenFromStorage(c, extractedToken, &cfg, sessionManager, storageManager); err != nil { return cfg.ErrorHandler(c, err) } } else { token = extractedToken // Token is valid, safe to set it } } // Generate CSRF token if not exist if token == "" { // And generate a new token token = cfg.KeyGenerator() } // Create or extend the token in the storage if err := createOrExtendTokenInStorage(c, token, &cfg, sessionManager, storageManager); err != nil { return cfg.ErrorHandler(c, err) } // Update the CSRF cookie updateCSRFCookie(c, &cfg, token) // Tell the browser that a new header value is generated c.Vary(fiber.HeaderCookie) // Store the token in the context fiber.StoreInContext(c, tokenKey, token) // Continue stack return c.Next() } } // TokenFromContext returns the token found in the context. // It accepts fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx, and context.Context. // It returns an empty string if the token does not exist. 
func TokenFromContext(ctx any) string { if token, ok := fiber.ValueFromContext[string](ctx, tokenKey); ok { return token } return "" } // HandlerFromContext returns the Handler found in the context. // It accepts fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx, and context.Context. // It returns nil if the handler does not exist. func HandlerFromContext(ctx any) *Handler { if handler, ok := fiber.ValueFromContext[*Handler](ctx, handlerKey); ok { return handler } return nil } // getRawFromStorage returns the raw value from the storage for the given token // returns nil if the token does not exist, is expired or is invalid func getRawFromStorage(c fiber.Ctx, token string, cfg *Config, sessionManager *sessionManager, storageManager *storageManager) ([]byte, error) { if cfg.Session != nil { return sessionManager.getRaw(c, token, dummyValue), nil } raw, err := storageManager.getRaw(c, token) if err != nil { return nil, fmt.Errorf("csrf: failed to fetch token from storage: %w", err) } return raw, nil } // createOrExtendTokenInStorage creates or extends the token in the storage func createOrExtendTokenInStorage(c fiber.Ctx, token string, cfg *Config, sessionManager *sessionManager, storageManager *storageManager) error { if cfg.Session != nil { sessionManager.setRaw(c, token, dummyValue, cfg.IdleTimeout) return nil } if err := storageManager.setRaw(c, token, dummyValue, cfg.IdleTimeout); err != nil { return fmt.Errorf("csrf: failed to store token in storage: %w", err) } return nil } func deleteTokenFromStorage(c fiber.Ctx, token string, cfg *Config, sessionManager *sessionManager, storageManager *storageManager) error { if cfg.Session != nil { sessionManager.delRaw(c) return nil } if err := storageManager.delRaw(c, token); err != nil { return fmt.Errorf("csrf: failed to delete token from storage: %w", err) } return nil } // Update CSRF cookie // if expireCookie is true, the cookie will expire immediately func updateCSRFCookie(c fiber.Ctx, cfg *Config, token string) { 
setCSRFCookie(c, cfg, token, cfg.IdleTimeout) } func expireCSRFCookie(c fiber.Ctx, cfg *Config) { setCSRFCookie(c, cfg, "", -time.Hour) } func setCSRFCookie(c fiber.Ctx, cfg *Config, token string, expiry time.Duration) { cookie := &fiber.Cookie{ Name: cfg.CookieName, Value: token, Domain: cfg.CookieDomain, Path: cfg.CookiePath, Secure: cfg.CookieSecure, HTTPOnly: cfg.CookieHTTPOnly, SameSite: cfg.CookieSameSite, SessionOnly: cfg.CookieSessionOnly, Expires: time.Now().Add(expiry), } // Set the CSRF cookie to the response c.Cookie(cookie) } // DeleteToken removes the token found in the context from the storage // and expires the CSRF cookie func (handler *Handler) DeleteToken(c fiber.Ctx) error { // Extract token from the client request cookie cookieToken := c.Cookies(handler.config.CookieName) if cookieToken == "" { return handler.config.ErrorHandler(c, ErrTokenNotFound) } // Remove the token from storage if err := deleteTokenFromStorage(c, cookieToken, &handler.config, handler.sessionManager, handler.storageManager); err != nil { return handler.config.ErrorHandler(c, err) } // Expire the cookie expireCSRFCookie(c, &handler.config) return nil } func validateSecFetchSite(c fiber.Ctx) error { secFetchSite := utils.Trim(c.Get(fiber.HeaderSecFetchSite), ' ') if secFetchSite == "" { return nil } switch utilsstrings.ToLower(secFetchSite) { case "same-origin", "none", "cross-site", "same-site": return nil default: return ErrFetchSiteInvalid } } // originMatchesHost checks that the origin header matches the host header // returns an error if the origin header is not present or is invalid // returns nil if the origin header is valid func originMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error { origin := utilsstrings.ToLower(c.Get(fiber.HeaderOrigin)) if origin == "" || origin == "null" { // "null" is set by some browsers when the origin is a secure context https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin#description 
return errOriginNotFound } originURL, err := url.Parse(origin) if err != nil { return ErrOriginInvalid } if schemeAndHostMatch(originURL.Scheme, originURL.Host, c.Scheme(), c.Host()) { return nil } if slices.Contains(trustedOrigins, origin) { return nil } for _, trustedSubOrigin := range trustedSubOrigins { if trustedSubOrigin.match(origin) { return nil } } return ErrOriginNoMatch } // refererMatchesHost checks that the referer header matches the host header // returns an error if the referer header is not present or is invalid // returns nil if the referer header is valid func refererMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error { referer := utilsstrings.ToLower(c.Get(fiber.HeaderReferer)) if referer == "" { return ErrRefererNotFound } refererURL, err := url.Parse(referer) if err != nil { return ErrRefererInvalid } if schemeAndHostMatch(refererURL.Scheme, refererURL.Host, c.Scheme(), c.Host()) { return nil } referer = refererURL.String() if slices.Contains(trustedOrigins, referer) { return nil } for _, trustedSubOrigin := range trustedSubOrigins { if trustedSubOrigin.match(referer) { return nil } } return ErrRefererNoMatch } ================================================ FILE: middleware/csrf/csrf_test.go ================================================ package csrf import ( "context" "errors" "net" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" "github.com/gofiber/fiber/v3/middleware/session" "github.com/gofiber/utils/v2" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) type failingCSRFStorage struct { data map[string][]byte errs map[string]error } func newFailingCSRFStorage() *failingCSRFStorage { return &failingCSRFStorage{ data: make(map[string][]byte), errs: make(map[string]error), } } func (s *failingCSRFStorage) GetWithContext(_ context.Context, key string) ([]byte, error) { if err, ok := s.errs["get|"+key]; ok && 
err != nil { return nil, err } if val, ok := s.data[key]; ok { return append([]byte(nil), val...), nil } return nil, nil } var trustedProxyConfig = fiber.Config{ TrustProxy: true, TrustProxyConfig: fiber.TrustProxyConfig{ Proxies: []string{"0.0.0.0"}, }, } func newTrustedApp() *fiber.App { return fiber.New(trustedProxyConfig) } func newTrustedRequestCtx() *fasthttp.RequestCtx { ctx := &fasthttp.RequestCtx{} ctx.SetRemoteAddr(net.Addr(&net.TCPAddr{IP: net.ParseIP("0.0.0.0")})) return ctx } func (s *failingCSRFStorage) Get(key string) ([]byte, error) { return s.GetWithContext(context.Background(), key) } func (s *failingCSRFStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error { if err, ok := s.errs["set|"+key]; ok && err != nil { return err } s.data[key] = append([]byte(nil), val...) return nil } func (s *failingCSRFStorage) Set(key string, val []byte, exp time.Duration) error { return s.SetWithContext(context.Background(), key, val, exp) } func (s *failingCSRFStorage) DeleteWithContext(_ context.Context, key string) error { if err, ok := s.errs["del|"+key]; ok && err != nil { return err } delete(s.data, key) return nil } func (s *failingCSRFStorage) Delete(key string) error { return s.DeleteWithContext(context.Background(), key) } func (s *failingCSRFStorage) ResetWithContext(context.Context) error { s.data = make(map[string][]byte) s.errs = make(map[string]error) return nil } func (s *failingCSRFStorage) Reset() error { return s.ResetWithContext(context.Background()) } func (*failingCSRFStorage) Close() error { return nil } func TestCSRFStorageGetError(t *testing.T) { t.Parallel() storage := newFailingCSRFStorage() storage.errs["get|token"] = errors.New("boom") var captured error app := fiber.New() app.Use(New(Config{ Storage: storage, ErrorHandler: func(_ fiber.Ctx, err error) error { captured = err return fiber.ErrTeapot }, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) req := 
httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.AddCookie(&http.Cookie{Name: ConfigDefault.CookieName, Value: "token"}) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) require.Error(t, captured) require.ErrorContains(t, captured, "csrf: failed to fetch token from storage") } /* TestCSRFStorageSetError verifies that when the storage backend fails to save a newly generated token (KeyGenerator is pinned to "token" and the failing storage is configured with a "set|token" error), ErrorHandler receives an error containing "csrf: failed to store token in storage" and its 418 response is returned for the GET. */ func TestCSRFStorageSetError(t *testing.T) { t.Parallel() storage := newFailingCSRFStorage() storage.errs["set|token"] = errors.New("boom") var captured error app := fiber.New() app.Use(New(Config{ Storage: storage, KeyGenerator: func() string { return "token" }, ErrorHandler: func(_ fiber.Ctx, err error) error { captured = err return fiber.ErrTeapot }, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) require.Error(t, captured) require.ErrorContains(t, captured, "csrf: failed to store token in storage") } /* TestCSRFStorageDeleteError verifies that with SingleUseToken enabled, a failure to delete the consumed token from storage (the "del|token" error) reaches ErrorHandler as an error containing "csrf: failed to delete token from storage", and the POST receives ErrorHandler's 418 response. */ func TestCSRFStorageDeleteError(t *testing.T) { t.Parallel() storage := newFailingCSRFStorage() storage.data["token"] = []byte("value") storage.errs["del|token"] = errors.New("boom") var captured error app := fiber.New() app.Use(New(Config{ Storage: storage, SingleUseToken: true, ErrorHandler: func(_ fiber.Ctx, err error) error { captured = err return fiber.ErrTeapot }, })) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) req := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody) req.Header.Set(HeaderName, "token") req.AddCookie(&http.Cookie{Name: ConfigDefault.CookieName, Value: "token"}) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) require.Error(t, captured) require.ErrorContains(t, captured, "csrf: failed to delete token from storage") } /* Test_CSRF runs the following steps for each of GET, HEAD, OPTIONS and TRACE. A POST with no token must get 403. A POST with the unknown header token "johndoe" and no cookie must get 403. A POST whose header token and cookie both carry the value issued in the Set-Cookie of a request made with that method must get 200. The RequestCtx is reused across steps, so each step resets the request/response first. */ func Test_CSRF(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Post("/", func(c fiber.Ctx) 
error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace} for _, method := range methods { // Generate CSRF token ctx.Request.Header.SetMethod(method) h(ctx) // Without CSRF cookie ctx.Request.Header.Reset() ctx.Request.ResetBody() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Invalid CSRF token ctx.Request.Header.Reset() ctx.Request.ResetBody() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, "johndoe") h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Valid CSRF token ctx.Request.Header.Reset() ctx.Request.ResetBody() ctx.Response.Reset() ctx.Request.Header.SetMethod(method) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) } } /* Test_CSRF_WithSession repeats the Test_CSRF reject/accept sequence with Config.Session set to a cookie-based ("_session") session store, sending the pre-saved session cookie on every request. The issued token is picked out of the Set-Cookie entry named ConfigDefault.CookieName. NOTE(review): the first "Generate CSRF token" step always uses GET; only the later token-issuing request uses the loop's method. app.AcquireCtx is called three times, but only the first context is released. */ func Test_CSRF_WithSession(t *testing.T) { t.Parallel() // session store store := session.NewStore(session.Config{ Extractor: extractors.FromCookie("_session"), }) // fiber instance app := fiber.New() // fiber context ctx := &fasthttp.RequestCtx{} defer app.ReleaseCtx(app.AcquireCtx(ctx)) // get session sess, err := store.Get(app.AcquireCtx(ctx)) require.NoError(t, err) require.True(t, sess.Fresh()) // the session ID is generated by the store (no longer a fixed value such as 123) newSessionIDString := sess.ID() require.NoError(t, sess.Save()) app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString) // middleware config config := Config{ Session: store, } // middleware app.Use(New(config)) app.Post("/", func(c fiber.Ctx) error { return 
c.SendStatus(fiber.StatusOK) }) h := app.Handler() methods := [4]string{fiber.MethodGet, fiber.MethodHead, fiber.MethodOptions, fiber.MethodTrace} for _, method := range methods { // Generate CSRF token ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetCookie("_session", newSessionIDString) h(ctx) // Without CSRF cookie ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.SetCookie("_session", newSessionIDString) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Empty/invalid CSRF token ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, "johndoe") ctx.Request.Header.SetCookie("_session", newSessionIDString) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Valid CSRF token ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(method) ctx.Request.Header.SetCookie("_session", newSessionIDString) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) for header := range strings.SplitSeq(token, ";") { if strings.Split(utils.TrimSpace(header), "=")[0] == ConfigDefault.CookieName { token = strings.Split(header, "=")[1] break } } ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie("_session", newSessionIDString) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) } } /* Test_CSRF_WithSession_Middleware chains the session middleware and the CSRF middleware over the same store. A GET stores "hello"="world" in the session and receives both the CSRF cookie and a "session_id" cookie. A POST presenting those cookies plus the CSRF header must then reach the handler with the session value intact (200 rather than 500). */ // go test -run Test_CSRF_WithSession_Middleware func Test_CSRF_WithSession_Middleware(t *testing.T) { t.Parallel() app := fiber.New() // session mw smh, sstore := session.NewWithStore() // csrf mw cmh := New(Config{ Session: sstore, }) app.Use(smh) app.Use(cmh) app.Get("/", func(c fiber.Ctx) error { sess := session.FromContext(c) sess.Set("hello", "world") return c.SendStatus(fiber.StatusOK) }) app.Post("/", func(c fiber.Ctx) error { sess := 
session.FromContext(c) if sess.Get("hello") != "world" { return c.SendStatus(fiber.StatusInternalServerError) } return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Generate CSRF token and session_id ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) csrfCookie := fasthttp.AcquireCookie() csrfCookie.SetKey(ConfigDefault.CookieName) require.True(t, ctx.Response.Header.Cookie(csrfCookie)) csrfToken := string(csrfCookie.Value()) require.NotEmpty(t, csrfToken) fasthttp.ReleaseCookie(csrfCookie) sessionCookie := fasthttp.AcquireCookie() sessionCookie.SetKey("session_id") require.True(t, ctx.Response.Header.Cookie(sessionCookie)) sessionID := string(sessionCookie.Value()) require.NotEmpty(t, sessionID) fasthttp.ReleaseCookie(sessionCookie) // Use the CSRF token and session_id ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, csrfToken) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, csrfToken) ctx.Request.Header.SetCookie("session_id", sessionID) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) } /* Test_CSRF_ExpiredToken verifies that a token issued with an IdleTimeout of 1s is accepted right away and rejected with 403 after the test sleeps for 1.25s. */ // go test -run Test_CSRF_ExpiredToken func Test_CSRF_ExpiredToken(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ IdleTimeout: 1 * time.Second, })) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Generate CSRF token ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] // Use the CSRF token ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Wait for the token to expire time.Sleep(1250 * time.Millisecond) // Expired CSRF token 
ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) } /* Test_CSRF_ExpiredToken_WithSession is the session-backed variant of Test_CSRF_ExpiredToken. The token is accepted immediately and rejected with 403 after a 1.1s sleep against a 1s IdleTimeout. */ // go test -run Test_CSRF_ExpiredToken_WithSession func Test_CSRF_ExpiredToken_WithSession(t *testing.T) { t.Parallel() // session store store := session.NewStore(session.Config{ Extractor: extractors.FromCookie("_session"), }) // fiber instance app := fiber.New() // fiber context ctx := &fasthttp.RequestCtx{} defer app.ReleaseCtx(app.AcquireCtx(ctx)) // get session sess, err := store.Get(app.AcquireCtx(ctx)) require.NoError(t, err) require.True(t, sess.Fresh()) // get session id newSessionIDString := sess.ID() require.NoError(t, sess.Save()) app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString) // middleware config config := Config{ Session: store, IdleTimeout: 1 * time.Second, } // middleware app.Use(New(config)) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() // Generate CSRF token ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetCookie("_session", newSessionIDString) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) for header := range strings.SplitSeq(token, ";") { if strings.Split(utils.TrimSpace(header), "=")[0] == ConfigDefault.CookieName { token = strings.Split(header, "=")[1] break } } // Use the CSRF token ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie("_session", newSessionIDString) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Wait for the token to expire time.Sleep(1*time.Second + 100*time.Millisecond) // Expired CSRF token ctx.Request.Reset() ctx.Response.Reset() 
ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie("_session", newSessionIDString) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) } /* Test_CSRF_MultiUseToken uses a header extractor on "X-Csrf-Token" with SingleUseToken left unset. An unknown token is rejected with 403 and a valid token is accepted. The cookie sent back on the successful POST must carry the same token, i.e. the token is not rotated. */ // go test -run Test_CSRF_MultiUseToken func Test_CSRF_MultiUseToken(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Extractor: extractors.FromHeader("X-Csrf-Token"), })) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Invalid CSRF token ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set("X-Csrf-Token", "johndoe") h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Generate CSRF token ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set("X-Csrf-Token", token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) newToken = strings.Split(strings.Split(newToken, ";")[0], "=")[1] require.Equal(t, 200, ctx.Response.StatusCode()) // Multi-use mode: the cookie returned by the successful POST must carry the same token require.Equal(t, token, newToken) } /* Test_CSRF_SingleUseToken verifies that with SingleUseToken enabled a successful POST rotates the token, so the new Set-Cookie value differs from the old one. Replaying the consumed token is then rejected with 403. */ // go test -run Test_CSRF_SingleUseToken func Test_CSRF_SingleUseToken(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ SingleUseToken: true, })) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Generate CSRF token ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] // Use the CSRF token 
ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) newToken = strings.Split(strings.Split(newToken, ";")[0], "=")[1] if token == newToken { t.Error("new token should not be the same as the old token") } // Use the CSRF token again ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) } /* Test_CSRF_Next verifies that when Next returns true, a GET to an unregistered route falls through to 404. */ // go test -run Test_CSRF_Next func Test_CSRF_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Next: func(_ fiber.Ctx) bool { return true }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } /* Test_CSRF_From_Form verifies token extraction from the "_csrf" form field. A form POST without the field is rejected with 403; one whose body is "_csrf=<token>" and that carries the matching cookie is accepted. */ func Test_CSRF_From_Form(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Extractor: extractors.FromForm("_csrf")})) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Invalid CSRF token ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Generate CSRF token ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm) ctx.Request.SetBodyString("_csrf=" + token) 
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) } /* Test_CSRF_From_Query verifies token extraction from the "_csrf" query parameter. A random UUID is rejected with 403; the issued token in "/?_csrf=<token>" with the matching cookie reaches the handler (200, body "OK"). */ func Test_CSRF_From_Query(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{Extractor: extractors.FromQuery("_csrf")})) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Invalid CSRF token ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.SetRequestURI("/?_csrf=" + utils.UUIDv4()) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Generate CSRF token ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.SetRequestURI("/") h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] ctx.Request.Reset() ctx.Response.Reset() ctx.Request.SetRequestURI("/?_csrf=" + token) ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) require.Equal(t, "OK", string(ctx.Response.Body())) } /* Test_CSRF_From_Param mounts the middleware on the "/:csrf" group and reads the token from the "csrf" route parameter. A random UUID path segment is rejected with 403; "/<token>" with the matching cookie returns 200 "OK". */ func Test_CSRF_From_Param(t *testing.T) { t.Parallel() app := fiber.New() csrfGroup := app.Group("/:csrf", New(Config{Extractor: extractors.FromParam("csrf")})) csrfGroup.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Invalid CSRF token ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.SetRequestURI("/" + utils.UUIDv4()) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Generate CSRF token ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.SetRequestURI("/" + utils.UUIDv4()) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] ctx.Request.Reset() ctx.Response.Reset() ctx.Request.SetRequestURI("/" + token) 
ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) require.Equal(t, "OK", string(ctx.Response.Body())) } /* Test_CSRF_From_Custom verifies a custom extractor that reads the token from a plain-text "key=value" body. An empty body yields extractors.ErrNotFound and 403; "_csrf=<token>" with the matching cookie is accepted. */ func Test_CSRF_From_Custom(t *testing.T) { t.Parallel() app := fiber.New() extractor := extractors.Extractor{ Extract: func(c fiber.Ctx) (string, error) { body := string(c.Body()) // Treat the body as a single key=value pair and return the value as the token selectors := strings.Split(body, "=") if len(selectors) != 2 || selectors[1] == "" { return "", extractors.ErrNotFound } return selectors[1], nil }, Source: extractors.SourceCustom, Key: "_csrf", } app.Use(New(Config{Extractor: extractor})) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Invalid CSRF token ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Generate CSRF token ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain) ctx.Request.SetBodyString("_csrf=" + token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) } /* Test_CSRF_Extractor_EmptyString verifies that an extractor returning ("", nil) is treated as a missing token. Even with a valid cookie, the POST gets 403 and ErrorHandler receives ErrTokenNotFound. */ func Test_CSRF_Extractor_EmptyString(t *testing.T) { t.Parallel() app := fiber.New() extractor := extractors.Extractor{ Extract: func(_ fiber.Ctx) (string, error) { return "", nil }, Source: extractors.SourceCustom, Key: "_csrf", } errorHandler := func(c fiber.Ctx, err error) error { return c.Status(403).SendString(err.Error()) } app.Use(New(Config{ Extractor: extractor, ErrorHandler: 
errorHandler, })) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // Generate CSRF token ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMETextPlain) ctx.Request.SetBodyString("_csrf=" + token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) require.Equal(t, ErrTokenNotFound.Error(), string(ctx.Response.Body())) } /* Test_CSRF_SecFetchSite is a table test of Sec-Fetch-Site and Origin handling on a trusted-proxy app for host example.com. Unsafe methods always send a valid header/cookie token pair, so the expected status reflects only the fetch-metadata and origin checks. Cases with https set also send X-Forwarded-Proto: https. Cases with expectFetchSiteInvalid additionally assert the ErrFetchSiteInvalid message in the body. */ func Test_CSRF_SecFetchSite(t *testing.T) { t.Parallel() errorHandler := func(c fiber.Ctx, err error) error { return c.Status(fiber.StatusForbidden).SendString(err.Error()) } app := newTrustedApp() app.Use(New(Config{ErrorHandler: errorHandler})) app.All("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := newTrustedRequestCtx() ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetHost("example.com") h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] tests := []struct { name string method string secFetchSite string origin string expectedStatus int16 https bool expectFetchSiteInvalid bool }{ { name: "same-origin allowed", method: fiber.MethodPost, secFetchSite: "same-origin", origin: "http://example.com", expectedStatus: http.StatusOK, }, { name: "none allowed", method: fiber.MethodPost, secFetchSite: "none", origin: "http://example.com", expectedStatus: http.StatusOK, }, { name: "cross-site with origin allowed", method: fiber.MethodPost, secFetchSite: "cross-site", origin: 
"http://example.com", expectedStatus: http.StatusOK, }, { name: "same-site with origin allowed", method: fiber.MethodPost, secFetchSite: "same-site", origin: "http://example.com", expectedStatus: http.StatusOK, }, { name: "cross-site with mismatched origin blocked", method: fiber.MethodPost, secFetchSite: "cross-site", origin: "https://attacker.example", expectedStatus: http.StatusForbidden, }, { name: "same-site with null origin blocked", method: fiber.MethodPost, secFetchSite: "same-site", origin: "null", expectedStatus: http.StatusForbidden, https: true, }, { name: "invalid header blocked", method: fiber.MethodPost, secFetchSite: "weird", origin: "http://example.com", expectedStatus: http.StatusForbidden, expectFetchSiteInvalid: true, }, { name: "no header with no origin", method: fiber.MethodPost, origin: "", expectedStatus: http.StatusOK, }, { name: "no header with matching origin", method: fiber.MethodPost, origin: "http://example.com", expectedStatus: http.StatusOK, }, { name: "no header with mismatched origin", method: fiber.MethodPost, origin: "https://attacker.example", expectedStatus: http.StatusForbidden, }, { name: "no header with null origin", method: fiber.MethodPost, origin: "null", expectedStatus: http.StatusForbidden, https: true, }, { name: "GET allowed", method: fiber.MethodGet, secFetchSite: "cross-site", expectedStatus: http.StatusOK, }, { name: "HEAD allowed", method: fiber.MethodHead, secFetchSite: "cross-site", expectedStatus: http.StatusOK, }, { name: "OPTIONS allowed", method: fiber.MethodOptions, secFetchSite: "cross-site", expectedStatus: http.StatusOK, }, { name: "PUT with mismatched origin blocked", method: fiber.MethodPut, secFetchSite: "cross-site", origin: "https://attacker.example", expectedStatus: http.StatusForbidden, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() c := &fasthttp.RequestCtx{} scheme := "http" if tt.https { scheme = "https" } c.Request.Header.SetMethod(tt.method) 
c.Request.URI().SetScheme(scheme) c.Request.URI().SetHost("example.com") c.Request.Header.SetHost("example.com") c.Request.Header.SetProtocol(scheme) if scheme == "https" { c.Request.Header.Set(fiber.HeaderXForwardedProto, "https") } if tt.origin != "" { c.Request.Header.Set(fiber.HeaderOrigin, tt.origin) } if tt.secFetchSite != "" { c.Request.Header.Set(fiber.HeaderSecFetchSite, tt.secFetchSite) } safe := tt.method == fiber.MethodGet || tt.method == fiber.MethodHead || tt.method == fiber.MethodOptions || tt.method == fiber.MethodTrace if !safe { c.Request.Header.Set(HeaderName, token) c.Request.Header.SetCookie(ConfigDefault.CookieName, token) } h(c) require.Equal(t, int(tt.expectedStatus), c.Response.StatusCode()) if tt.expectFetchSiteInvalid { require.Equal(t, ErrFetchSiteInvalid.Error(), string(c.Response.Body())) } }) } } /* Test_CSRF_Origin verifies Origin checking on a trusted-proxy app. It covers these cases: - A matching scheme/host/port is accepted. - Default ports (80/443) are treated as equivalent to an omitted port. - A different port is rejected. - A "null" Origin over plain HTTP is accepted. - X-Forwarded-Host is honored. - A proxy request without X-Forwarded-* headers is rejected, as is a different subdomain. NOTE(review): the first reverse-proxy case sets the URI host to "10.0.1.42.com:8080" while the Host header is "10.0.1.42:8080"; this looks like a typo, so confirm which was intended. */ func Test_CSRF_Origin(t *testing.T) { t.Parallel() app := newTrustedApp() app.Use(New(Config{CookieSecure: true})) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := newTrustedRequestCtx() ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http") h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] // Test Correct Origin with port ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("example.com:8080") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("example.com:8080") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com:8080") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Origin without default HTTP port against host with default port ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) 
ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("example.com:80") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("example.com:80") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Origin with default HTTP port against host without port ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("example.com") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com:80") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Origin with a non-matching port (rejected) ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("example.com") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com:3000") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Test null Origin over plain HTTP (accepted) ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("example.com") ctx.Request.Header.Set(fiber.HeaderOrigin, "null") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Correct Origin with ReverseProxy ctx.Request.Reset() 
ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("10.0.1.42.com:8080") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("10.0.1.42:8080") ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http") ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com") ctx.Request.Header.Set(fiber.HeaderXForwardedFor, `192.0.2.43, "[2001:db8:cafe::17]"`) ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Origin without default HTTPS port against host with default port ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("example.com:443") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.Header.SetHost("example.com:443") ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Origin with default HTTPS port against host without port ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.Header.SetHost("example.com") ctx.Request.Header.Set(fiber.HeaderOrigin, "https://example.com:443") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Correct Origin with ReverseProxy Missing X-Forwarded-* Headers ctx.Request.Reset() 
ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("10.0.1.42:8080") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("10.0.1.42:8080") ctx.Request.Header.Set(fiber.HeaderXUrlScheme, "http") // We need to set this header to make sure c.Protocol() returns http ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Test Wrong Origin ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http") ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://csrf.example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) } /* Test_CSRF_TrustedOrigins verifies exact and wildcard ("*.domain-1.com") TrustedOrigins entries for both the Origin and Referer headers. It also checks that look-alike hosts such as "evildomain-1.com" and "evil.comdomain-1.com" are rejected. NOTE(review): in both "deeply nested subdomain" cases the Origin/Referer equals the request host itself. Those cases would likely pass via same-origin matching and may not exercise the wildcard, so confirm the intended host. */ func Test_CSRF_TrustedOrigins(t *testing.T) { t.Parallel() app := newTrustedApp() app.Use(New(Config{ CookieSecure: true, TrustedOrigins: []string{ "http://safe.example.com", "https://safe.example.com", "http://*.domain-1.com", "https://*.domain-1.com", }, })) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := newTrustedRequestCtx() ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] // Test Trusted Origin ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("example.com") ctx.Request.Header.Set(fiber.HeaderOrigin, 
"http://safe.example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Trusted Origin Subdomain ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("domain-1.com") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("domain-1.com") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://safe.domain-1.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Trusted Origin deeply nested subdomain ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("a.b.c.domain-1.com") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("a.b.c.domain-1.com") ctx.Request.Header.Set(fiber.HeaderOrigin, "https://a.b.c.domain-1.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Trusted Origin Invalid ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("domain-1.com") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("domain-1.com") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://evildomain-1.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Test Trusted Origin malformed subdomain ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("domain-1.com") ctx.Request.Header.SetProtocol("http") 
ctx.Request.Header.SetHost("domain-1.com") ctx.Request.Header.Set(fiber.HeaderOrigin, "http://evil.comdomain-1.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Test Trusted Referer ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("example.com") ctx.Request.Header.Set(fiber.HeaderReferer, "https://safe.example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Trusted Referer Wildcard ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("domain-1.com") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("domain-1.com") ctx.Request.Header.Set(fiber.HeaderReferer, "https://safe.domain-1.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Trusted Referer deeply nested subdomain ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("a.b.c.domain-1.com") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("a.b.c.domain-1.com") ctx.Request.Header.Set(fiber.HeaderReferer, "https://a.b.c.domain-1.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, 
ctx.Response.StatusCode()) // Test Trusted Referer Invalid ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("api.domain-1.com") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("api.domain-1.com") ctx.Request.Header.Set(fiber.HeaderReferer, "https://evildomain-1.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) } /* Test_CSRF_TrustedOrigins_InvalidOrigins verifies that New panics when TrustedOrigins contains a malformed entry. The malformed entries are: a missing scheme, bare or misplaced wildcards, a port wildcard, multiple wildcards, and a non-HTTP(S) scheme. */ func Test_CSRF_TrustedOrigins_InvalidOrigins(t *testing.T) { t.Parallel() tests := []struct { name string origin string }{ {name: "No Scheme", origin: "localhost"}, {name: "Wildcard", origin: "https://*"}, {name: "Wildcard domain", origin: "https://*example.com"}, {name: "File Scheme", origin: "file://example.com"}, {name: "FTP Scheme", origin: "ftp://example.com"}, {name: "Port Wildcard", origin: "http://example.com:*"}, {name: "Multiple Wildcards", origin: "https://*.*.com"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { origin := tt.origin t.Parallel() require.Panics(t, func() { app := fiber.New() app.Use(New(Config{ CookieSecure: true, TrustedOrigins: []string{origin}, })) }, "Expected panic") }) } } func Test_CSRF_Referer(t *testing.T) { t.Parallel() app := newTrustedApp() app.Use(New(Config{CookieSecure: true})) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := newTrustedRequestCtx() ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") h(ctx) token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie)) token = strings.Split(strings.Split(token, ";")[0], "=")[1] // Test Correct Referer with port ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) 
ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("example.com:8443") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("example.com:8443") ctx.Request.Header.Set(fiber.HeaderReferer, ctx.Request.URI().String()) ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Referer without default HTTPS port against host with default port ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("example.com:443") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("example.com:443") ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Referer with default HTTPS port against host without port ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("example.com") ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com:443") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Correct Referer with ReverseProxy ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.URI().SetScheme("https") ctx.Request.URI().SetHost("10.0.1.42.com:8443") 
ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("10.0.1.42:8443") ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com") ctx.Request.Header.Set(fiber.HeaderXForwardedFor, `192.0.2.43, "[2001:db8:cafe::17]"`) ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Referer without default HTTP port against host with default port ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http") ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("example.com:80") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("example.com:80") ctx.Request.Header.Set(fiber.HeaderReferer, "http://example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Referer with default HTTP port against host without port ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http") ctx.Request.URI().SetScheme("http") ctx.Request.URI().SetHost("example.com") ctx.Request.Header.SetProtocol("http") ctx.Request.Header.SetHost("example.com") ctx.Request.Header.Set(fiber.HeaderReferer, "http://example.com:80") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Correct Referer with ReverseProxy Missing X-Forwarded-* Headers ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.URI().SetScheme("https") 
ctx.Request.URI().SetHost("10.0.1.42:8443") ctx.Request.Header.SetProtocol("https") ctx.Request.Header.SetHost("10.0.1.42:8443") ctx.Request.Header.Set(fiber.HeaderXUrlScheme, "https") // We need to set this header to make sure c.Protocol() returns https ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) // Test Correct Referer with path ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com") ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com/action/items?gogogo=true") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) // Test Wrong Referer ctx.Request.Reset() ctx.Response.Reset() ctx.Request.Header.SetMethod(fiber.MethodPost) ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https") ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com") ctx.Request.Header.Set(fiber.HeaderReferer, "https://csrf.example.com") ctx.Request.Header.Set(HeaderName, token) ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token) h(ctx) require.Equal(t, 403, ctx.Response.StatusCode()) } func Test_CSRF_DeleteToken(t *testing.T) { t.Parallel() app := fiber.New() config := ConfigDefault app.Use(New(config)) app.Post("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) h := app.Handler() ctx := &fasthttp.RequestCtx{} // DeleteToken after token generation and remove the cookie ctx.Request.Header.Reset() ctx.Request.ResetBody() ctx.Response.Reset() ctx.Request.Header.Set(HeaderName, "") handler := HandlerFromContext(app.AcquireCtx(ctx)) if handler != nil { ctx.Request.Header.DelAllCookies() err := 
handler.DeleteToken(app.AcquireCtx(ctx))
		require.ErrorIs(t, err, ErrTokenNotFound)
	}
	h(ctx)

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Delete the CSRF token
	ctx.Request.Header.Reset()
	ctx.Request.ResetBody()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	handler = HandlerFromContext(app.AcquireCtx(ctx))
	if handler != nil {
		if err := handler.DeleteToken(app.AcquireCtx(ctx)); err != nil {
			t.Fatal(err)
		}
	}
	h(ctx)

	// Replaying the deleted token must now be rejected.
	ctx.Request.Header.Reset()
	ctx.Request.ResetBody()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_DeleteToken_WithSession repeats the DeleteToken flow with
// Config.Session set to a session store whose ID is read from the "_session"
// cookie, and checks that the deleted token is rejected (403) afterwards.
func Test_CSRF_DeleteToken_WithSession(t *testing.T) {
	t.Parallel()

	// session store
	store := session.NewStore(session.Config{
		Extractor: extractors.FromCookie("_session"),
	})

	// fiber instance
	app := fiber.New()

	// fiber context
	ctx := &fasthttp.RequestCtx{}

	// get session
	sess, err := store.Get(app.AcquireCtx(ctx))
	require.NoError(t, err)
	require.True(t, sess.Fresh())

	// the session ID is generated by the store (it is no longer the hard-coded 123)
	newSessionIDString := sess.ID()
	require.NoError(t, sess.Save())

	app.AcquireCtx(ctx).Request().Header.SetCookie("_session", newSessionIDString)

	// middleware config
	config := Config{
		Session: store,
	}

	// middleware
	app.Use(New(config))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("_session", newSessionIDString)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Delete the CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	handler := HandlerFromContext(app.AcquireCtx(ctx))
	if handler != nil {
		if err := handler.DeleteToken(app.AcquireCtx(ctx)); err != nil {
			t.Fatal(err)
		}
	}
	h(ctx)

	// Replaying the deleted token within the same session must be rejected.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	ctx.Request.Header.SetCookie("_session", newSessionIDString)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_ErrorHandler_InvalidToken checks that a custom ErrorHandler
// receives ErrTokenInvalid for a mismatched header token and controls the
// response status and body.
func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	errHandler := func(ctx fiber.Ctx, err error) error {
		require.Equal(t, ErrTokenInvalid, err)
		return ctx.Status(419).Send([]byte("invalid CSRF token"))
	}

	app.Use(New(Config{ErrorHandler: errHandler}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)

	// invalid CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, "johndoe")
	h(ctx)
	require.Equal(t, 419, ctx.Response.StatusCode())
	require.Equal(t, "invalid CSRF token", string(ctx.Response.Body()))
}

// Test_CSRF_ErrorHandler_EmptyToken checks that a custom ErrorHandler receives
// ErrTokenNotFound when an unsafe request carries no token at all.
func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	errHandler := func(ctx fiber.Ctx, err error) error {
		require.Equal(t, ErrTokenNotFound, err)
		return ctx.Status(419).Send([]byte("empty CSRF token"))
	}

	app.Use(New(Config{ErrorHandler: errHandler}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)

	// empty CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	h(ctx)
	require.Equal(t, 419, ctx.Response.StatusCode())
	require.Equal(t, "empty CSRF token", string(ctx.Response.Body()))
}

// Test_CSRF_ErrorHandler_MissingReferer checks that a custom ErrorHandler
// receives ErrRefererNotFound for an HTTPS POST sent without Origin or Referer.
// NOTE(review): the handler reuses the body text "empty CSRF token"; only the
// status code is asserted here.
func Test_CSRF_ErrorHandler_MissingReferer(t *testing.T) {
	t.Parallel()
	app := newTrustedApp()

	errHandler := func(ctx fiber.Ctx, err error) error {
		require.Equal(t, ErrRefererNotFound, err)
		return ctx.Status(419).Send([]byte("empty CSRF token"))
	}

	app.Use(New(Config{
		CookieSecure: true,
		ErrorHandler: errHandler,
	}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := newTrustedRequestCtx()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Valid token, but neither Origin nor Referer is sent.
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 419, ctx.Response.StatusCode())
}

// Test_CSRF_Cookie_Injection_Exploit checks that an attacker-planted cookie
// ("csrf_=pwned") cannot be paired with a server-issued header token to pass
// validation.
func Test_CSRF_Cookie_Injection_Exploit(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Inject CSRF token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.Set(fiber.HeaderCookie, "csrf_=pwned;")
	ctx.Request.SetRequestURI("/")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Exploit CSRF token we just injected
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.Set(fiber.HeaderCookie, "csrf_=pwned;")
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode(), "CSRF exploit successful")
}

// Test_CSRF_UnsafeHeaderValue ensures that unsafe header values, such as those described in https://github.com/gofiber/fiber/issues/2045, are rejected and the bug remains fixed.
// go test -race -run Test_CSRF_UnsafeHeaderValue
func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	app.Get("/test", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	// Pick the CSRF cookie out of the response cookies.
	var token string
	for _, c := range resp.Cookies() {
		if c.Name != ConfigDefault.CookieName {
			continue
		}
		token = c.Value
		break
	}

	t.Log("token", token)

	getReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	getReq.Header.Set(HeaderName, token)
	resp, err = app.Test(getReq)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	getReq = httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)
	getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
	getReq.Header.Set(fiber.HeaderCacheControl, "no")
	getReq.Header.Set(HeaderName, token)
	getReq.AddCookie(&http.Cookie{
		Name:  ConfigDefault.CookieName,
		Value: token,
	})

	resp, err = app.Test(getReq)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	getReq.Header.Set(fiber.HeaderAccept, "*/*")
	getReq.Header.Del(HeaderName)
	resp, err = app.Test(getReq)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	postReq := httptest.NewRequest(fiber.MethodPost, "/", http.NoBody)
	postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
	postReq.Header.Set(HeaderName, token)
	postReq.AddCookie(&http.Cookie{
		Name:
ConfigDefault.CookieName,
		Value: token,
	})

	resp, err = app.Test(postReq)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

// Benchmark_Middleware_CSRF_Check measures a token-validated HTTPS POST whose
// Referer matches the request host; the same prepared request is replayed on
// every iteration.
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4
func Benchmark_Middleware_CSRF_Check(b *testing.B) {
	app := fiber.New()

	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Test Correct Referer POST
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.SetHost("example.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)

	b.ReportAllocs()
	for b.Loop() {
		h(ctx)
	}

	require.Equal(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode())
}

// Benchmark_Middleware_CSRF_GenerateToken measures safe GET requests, the path
// on which the middleware issues the token cookie.
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4
func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
	app := fiber.New()

	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusTeapot)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)

	b.ReportAllocs()
	for b.Loop() {
		h(ctx)
	}

	// Ensure the GET request returns a 418 status code
	require.Equal(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode())
}

// Test_CSRF_InvalidURLHeaders checks that malformed Origin and Referer values
// are reported to the ErrorHandler as ErrOriginInvalid and ErrRefererInvalid.
func Test_CSRF_InvalidURLHeaders(t *testing.T) {
	t.Parallel()
	app := newTrustedApp()

	errHandler := func(ctx fiber.Ctx, err error) error {
		return ctx.Status(419).Send([]byte(err.Error()))
	}

	app.Use(New(Config{ErrorHandler: errHandler}))

	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := newTrustedRequestCtx()

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// invalid Origin
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.URI().SetScheme("http")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetProtocol("http")
	ctx.Request.Header.SetHost("example.com")
	ctx.Request.Header.Set(fiber.HeaderOrigin, "http://[::1]:%38%30/Invalid Origin")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 419, ctx.Response.StatusCode())
	require.Equal(t, ErrOriginInvalid.Error(), string(ctx.Response.Body()))

	// invalid Referer
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
	ctx.Request.URI().SetScheme("https")
	ctx.Request.URI().SetHost("example.com")
	ctx.Request.Header.SetProtocol("https")
	ctx.Request.Header.SetHost("example.com")
	ctx.Request.Header.Set(fiber.HeaderReferer, "http://[::1]:%38%30/Invalid Referer")
	ctx.Request.Header.Set(HeaderName, token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 419, ctx.Response.StatusCode())
	require.Equal(t, ErrRefererInvalid.Error(), string(ctx.Response.Body()))
}

// Test_CSRF_TokenFromContext checks that TokenFromContext returns a non-empty
// token inside a handler running behind the middleware.
func Test_CSRF_TokenFromContext(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		token := TokenFromContext(c)
		require.NotEmpty(t, token)
		return c.SendStatus(fiber.StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

// Test_CSRF_FromContextMethods checks that TokenFromContext and
// HandlerFromContext return the same values for every supported context:
// fiber.Ctx, fiber.CustomCtx, the fasthttp RequestCtx, and context.Context.
// The last case relies on PassLocalsToContext being enabled.
func Test_CSRF_FromContextMethods(t *testing.T) {
	t.Parallel()
	app := fiber.New(fiber.Config{PassLocalsToContext: true})

	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		token := TokenFromContext(c)
		require.NotEmpty(t, token)

		handler := HandlerFromContext(c)
		require.NotNil(t, handler)

		customCtx, ok := c.(fiber.CustomCtx)
		require.True(t, ok)
		require.Equal(t, token, TokenFromContext(customCtx))
		require.Equal(t, handler, HandlerFromContext(customCtx))

		require.Equal(t, token, TokenFromContext(c.RequestCtx()))
		require.Equal(t, token, TokenFromContext(c.Context()))
		require.Equal(t, handler, HandlerFromContext(c.RequestCtx()))
		require.Equal(t, handler, HandlerFromContext(c.Context()))

		return c.SendStatus(fiber.StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

// Test_CSRF_FromContextMethods_Invalid checks that both helpers return zero
// values (empty token, nil handler) when the middleware is not installed.
func Test_CSRF_FromContextMethods_Invalid(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Get("/", func(c fiber.Ctx) error {
		token := TokenFromContext(c)
		require.Empty(t, token)

		handler := HandlerFromContext(c)
		require.Nil(t, handler)

		return c.SendStatus(fiber.StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

// Test_deleteTokenFromStorage covers both branches of deleteTokenFromStorage:
//   - with Config.Session set, the token is removed via the session manager;
//   - with an empty Config, the token is removed from the storage manager.
func Test_deleteTokenFromStorage(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	t.Cleanup(func() { app.ReleaseCtx(ctx) })

	token := "token123"
	dummy := []byte("dummy")

	// Session-backed branch.
	store := session.NewStore()
	sm := newSessionManager(store)
	stm := newStorageManager(nil, true)
	sm.setRaw(ctx, token, dummy, time.Minute)
	cfg := Config{Session: store}
	require.NoError(t, deleteTokenFromStorage(ctx, token, &cfg, sm, stm))
	raw := sm.getRaw(ctx, token, dummy)
	require.Nil(t, raw)

	// Storage-backed branch.
	sm2 := newSessionManager(nil)
	stm2 := newStorageManager(nil, true)
	require.NoError(t, stm2.setRaw(context.Background(), token, dummy, time.Minute))
	cfg = Config{}
	require.NoError(t, deleteTokenFromStorage(ctx, token, &cfg, sm2, stm2))
	raw, err := stm2.getRaw(context.Background(), token)
	require.NoError(t, err)
	require.Nil(t, raw)
}

// Test_storageManager_logKey checks that logKey returns redactedKey when
// redaction is enabled and the raw key otherwise.
func Test_storageManager_logKey(t *testing.T) {
	t.Parallel()

	redacted := newStorageManager(nil, true)
	require.Equal(t, redactedKey, redacted.logKey("secret"))

	plain := newStorageManager(nil, false)
	require.Equal(t, "secret", plain.logKey("secret"))
}

// Test_CSRF_Chain_Extractor checks header-then-form fallback through
// extractors.Chain. It covers success via either source, header precedence
// when both are present, and rejection when the token is missing or wrong.
func Test_CSRF_Chain_Extractor(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Chain extractor: try header first, fall back to form
	chainExtractor := extractors.Chain(
		extractors.FromHeader("X-Csrf-Token"),
		extractors.FromForm("_csrf"),
	)

	app.Use(New(Config{Extractor: chainExtractor}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Test 1: Token in header (first extractor should succeed)
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set("X-Csrf-Token", token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test 2: Token in form (fallback should succeed)
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
	ctx.Request.SetBodyString("_csrf=" + token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test 3: Token in both header and form (header should take precedence)
	ctx.Request.Reset()
	ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
	ctx.Request.Header.Set("X-Csrf-Token", token)
	ctx.Request.SetBodyString("_csrf=wrong_token")
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test 4: No token in either location
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())

	// Test 5: Wrong token in both locations
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
	ctx.Request.Header.Set("X-Csrf-Token", "wrong_token")
	ctx.Request.SetBodyString("_csrf=also_wrong")
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_Chain_Extractor_Empty checks that an empty extractors.Chain never
// yields a token, so a POST is rejected (403) even when it carries a valid
// token in the X-Csrf-Token header.
func Test_CSRF_Chain_Extractor_Empty(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Empty chain extractor
	emptyChain := extractors.Chain()

	app.Use(New(Config{Extractor: emptyChain}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Test with empty chain - should always fail
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set("X-Csrf-Token", token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_Chain_Extractor_SingleExtractor checks that a one-element Chain
// behaves like the wrapped FromHeader extractor: a valid header token passes
// and a missing one is rejected.
func Test_CSRF_Chain_Extractor_SingleExtractor(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Chain with single extractor (should behave like the single extractor)
	singleChain := extractors.Chain(extractors.FromHeader("X-Csrf-Token"))

	app.Use(New(Config{Extractor: singleChain}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Test valid token in header
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.Set("X-Csrf-Token", token)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 200, ctx.Response.StatusCode())

	// Test no token
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode())
}

// Test_CSRF_All_Extractors is a table test that runs the header, form and
// query extractors against requests with and without the token. The token
// cookie is always set.
func Test_CSRF_All_Extractors(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		setupRequest func(ctx *fasthttp.RequestCtx, token string)
		name         string
		extractor    extractors.Extractor
		expectStatus int
	}{
		{
			name:      "FromHeader",
			extractor: extractors.FromHeader("X-Csrf-Token"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.Header.Set("X-Csrf-Token", token)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 200,
		},
		{
			name:      "FromHeader_Missing",
			extractor: extractors.FromHeader("X-Csrf-Token"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 403,
		},
		{
			name:      "FromForm",
			extractor: extractors.FromForm("_csrf"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
				ctx.Request.SetBodyString("_csrf=" + token)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 200,
		},
		{
			name:      "FromForm_Missing",
			extractor: extractors.FromForm("_csrf"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.Header.Set(fiber.HeaderContentType, fiber.MIMEApplicationForm)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 403,
		},
		{
			name:      "FromQuery",
			extractor: extractors.FromQuery("csrf_token"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.SetRequestURI("/?csrf_token=" + token)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 200,
		},
		{
			name:      "FromQuery_Missing",
			extractor: extractors.FromQuery("csrf_token"),
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.SetRequestURI("/")
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 403,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			app.Use(New(Config{Extractor: tc.extractor}))
			app.Post("/", func(c fiber.Ctx) error {
				return c.SendStatus(fiber.StatusOK)
			})

			h := app.Handler()
			ctx := &fasthttp.RequestCtx{}

			// Generate CSRF token
			ctx.Request.Header.SetMethod(fiber.MethodGet)
			ctx.Request.SetRequestURI("/")
			h(ctx)
			token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
			token = strings.Split(strings.Split(token, ";")[0], "=")[1]

			// Test the extractor
			ctx.Request.Reset()
			ctx.Response.Reset()
			tc.setupRequest(ctx, token)
			h(ctx)
			require.Equal(t, tc.expectStatus, ctx.Response.StatusCode(),
				"Test case %s failed: expected %d, got %d", tc.name, tc.expectStatus, ctx.Response.StatusCode())
		})
	}
}

// Test_CSRF_Param_Extractor checks extractors.FromParam with the middleware
// mounted on a "/:csrf" group: a path segment equal to the cookie token passes,
// and any other value is rejected.
func Test_CSRF_Param_Extractor(t *testing.T) {
	t.Parallel()

	testCases := []struct {
		setupRequest func(ctx *fasthttp.RequestCtx, token string)
		name         string
		expectStatus int
	}{
		{
			name: "FromParam_Valid",
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.SetRequestURI("/" + token)
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 200,
		},
		{
			name: "FromParam_Invalid",
			setupRequest: func(ctx *fasthttp.RequestCtx, token string) {
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.SetRequestURI("/wrong_token")
				ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
			},
			expectStatus: 403,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()

			// Only use param-based routing for param extractor tests
			csrfGroup := app.Group("/:csrf", New(Config{Extractor: extractors.FromParam("csrf")}))
			csrfGroup.Post("/", func(c fiber.Ctx) error {
				return c.SendStatus(fiber.StatusOK)
			})

			h := app.Handler()
			ctx := &fasthttp.RequestCtx{}

			// Generate CSRF token
			ctx.Request.Header.SetMethod(fiber.MethodGet)
			ctx.Request.SetRequestURI("/" + utils.UUIDv4())
			h(ctx)
			token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
			token = strings.Split(strings.Split(token, ";")[0], "=")[1]

			// Test the extractor
			ctx.Request.Reset()
			ctx.Response.Reset()
			tc.setupRequest(ctx, token)
			h(ctx)
			require.Equal(t, tc.expectStatus, ctx.Response.StatusCode(),
				"Test case %s failed: expected %d, got %d", tc.name, tc.expectStatus, ctx.Response.StatusCode())
		})
	}
}

// Test_CSRF_Param_Extractor_Missing checks that FromParam fails with 403 when
// the matched route declares no "csrf" parameter.
func Test_CSRF_Param_Extractor_Missing(t *testing.T) {
	t.Parallel()

	// Test the case where no param is provided (should get 403 from CSRF middleware on the catch-all route)
	app := fiber.New()

	// Add a catch-all route with CSRF middleware for missing param case
	app.Use(New(Config{Extractor: extractors.FromParam("csrf")}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()
	ctx := &fasthttp.RequestCtx{}

	// Generate CSRF token
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/")
	h(ctx)
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	token = strings.Split(strings.Split(token, ";")[0], "=")[1]

	// Test missing param (accessing "/" instead of "/:csrf")
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/")
	ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
	h(ctx)
	require.Equal(t, 403, ctx.Response.StatusCode(), "Missing param should return 403")
}

// Test_CSRF_Extractors_ErrorTypes checks that each extractor returns an empty
// token and extractors.ErrNotFound when its source lacks the requested value.
func Test_CSRF_Extractors_ErrorTypes(t *testing.T) {
	t.Parallel()

	// Test all extractor error types
	testCases := []struct {
		expected  error
		setupCtx  func(ctx *fasthttp.RequestCtx) // prepares the raw request before extraction
		name      string
		extractor extractors.Extractor
	}{
		{
			name:      "Missing header",
			extractor: extractors.FromHeader("X-Missing-Header"),
			expected:  extractors.ErrNotFound,
			setupCtx:  func(_ *fasthttp.RequestCtx) {}, // No setup needed for headers
		},
		{
			name:      "Missing query",
			extractor: extractors.FromQuery("missing_param"),
			expected:  extractors.ErrNotFound,
			setupCtx: func(ctx *fasthttp.RequestCtx) {
				ctx.Request.SetRequestURI("/") // Set URI for query parsing
			},
		},
		{
			name:      "Missing param",
			extractor: extractors.FromParam("missing_param"),
			expected:  extractors.ErrNotFound,
			setupCtx:  func(_ *fasthttp.RequestCtx) {}, // Params are handled by router
		},
		{
			name:      "Missing form",
			extractor: extractors.FromForm("missing_field"),
			expected:  extractors.ErrNotFound,
			setupCtx: func(ctx *fasthttp.RequestCtx) {
				// Properly initialize request for form parsing
				ctx.Request.Header.SetMethod(fiber.MethodPost)
				ctx.Request.Header.SetContentType(fiber.MIMEApplicationForm)
				ctx.Request.SetBodyString("") // Empty form body
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			app := fiber.New()
			requestCtx := &fasthttp.RequestCtx{}
			tc.setupCtx(requestCtx) // Set up the context properly
			ctx := app.AcquireCtx(requestCtx)
			defer app.ReleaseCtx(ctx)

			token, err := tc.extractor.Extract(ctx)
require.Empty(t, token) require.Equal(t, tc.expected, err) }) } } ================================================ FILE: middleware/csrf/helpers.go ================================================ package csrf import ( "crypto/subtle" "net/url" "strings" "github.com/gofiber/utils/v2" utilsstrings "github.com/gofiber/utils/v2/strings" ) const ( schemeHTTP = "http" schemeHTTPS = "https" ) func compareTokens(a, b []byte) bool { return subtle.ConstantTimeCompare(a, b) == 1 } func compareStrings(a, b string) bool { return subtle.ConstantTimeCompare(utils.UnsafeBytes(a), utils.UnsafeBytes(b)) == 1 } func schemeAndHostMatch(schemeA, hostA, schemeB, hostB string) bool { normalizedSchemeA := utilsstrings.ToLower(schemeA) normalizedSchemeB := utilsstrings.ToLower(schemeB) normalizedHostA := normalizeSchemeHost(normalizedSchemeA, hostA) normalizedHostB := normalizeSchemeHost(normalizedSchemeB, hostB) return normalizedSchemeA == normalizedSchemeB && normalizedHostA == normalizedHostB } func normalizeSchemeHost(scheme, host string) string { host = utilsstrings.ToLower(host) defaultPort := "" switch scheme { case schemeHTTP: defaultPort = "80" case schemeHTTPS: defaultPort = "443" default: return host } parsedHost, err := url.Parse(scheme + "://" + host) if err != nil { return host } if port := parsedHost.Port(); port != "" { return host } hostname := parsedHost.Hostname() if hostname == "" { return host } if strings.IndexByte(hostname, ':') >= 0 && !strings.HasPrefix(hostname, "[") { hostname = "[" + hostname + "]" } return hostname + ":" + defaultPort } // normalizeOrigin checks if the provided origin is in a correct format // and normalizes it by removing any path or trailing slash. // It returns a boolean indicating whether the origin is valid // and the normalized origin. 
func normalizeOrigin(origin string) (valid bool, normalized string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming validity and normalized origin results parsedOrigin, err := url.Parse(origin) if err != nil { return false, "" } // Validate the scheme is either http or https if parsedOrigin.Scheme != schemeHTTP && parsedOrigin.Scheme != schemeHTTPS { return false, "" } // Don't allow a wildcard with a protocol // wildcards cannot be used within any other value. For example, the following header is not valid: // Access-Control-Allow-Origin: https://* if strings.IndexByte(parsedOrigin.Host, '*') >= 0 { return false, "" } // Validate there is a host present. The presence of a path, query, or fragment components // is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized if parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" { return false, "" } // Normalize the origin by constructing it from the scheme and host. // The path or trailing slash is not included in the normalized origin. return true, utilsstrings.ToLower(parsedOrigin.Scheme) + "://" + utilsstrings.ToLower(parsedOrigin.Host) } type subdomain struct { prefix string suffix string } func (s subdomain) match(o string) bool { // Not a subdomain if not long enough for a dot separator. if len(o) < len(s.prefix)+len(s.suffix)+1 { return false } if !strings.HasPrefix(o, s.prefix) || !strings.HasSuffix(o, s.suffix) { return false } // Check for the dot separator and validate that there is at least one // non-empty label between prefix and suffix. Empty labels like // "https://.example.com" or "https://..example.com" should not match. suffixStartIndex := len(o) - len(s.suffix) if suffixStartIndex <= len(s.prefix) { return false } if o[suffixStartIndex-1] != '.' 
{ return false } // Extract the subdomain part (without the trailing dot) and ensure it // doesn't contain empty labels. sub := o[len(s.prefix) : suffixStartIndex-1] if sub == "" || sub[0] == '.' || strings.Contains(sub, "..") { return false } return true } ================================================ FILE: middleware/csrf/helpers_test.go ================================================ package csrf import ( "testing" "github.com/stretchr/testify/assert" ) // go test -run -v Test_normalizeOrigin func Test_normalizeOrigin(t *testing.T) { t.Parallel() testCases := []struct { origin string expectedOrigin string expectedValid bool }{ {origin: "http://example.com", expectedValid: true, expectedOrigin: "http://example.com"}, // Simple case should work. {origin: "HTTP://EXAMPLE.COM", expectedValid: true, expectedOrigin: "http://example.com"}, // Case should be normalized. {origin: "http://example.com/", expectedValid: true, expectedOrigin: "http://example.com"}, // Trailing slash should be removed. {origin: "http://example.com:3000", expectedValid: true, expectedOrigin: "http://example.com:3000"}, // Port should be preserved. {origin: "http://example.com:3000/", expectedValid: true, expectedOrigin: "http://example.com:3000"}, // Trailing slash should be removed. {origin: "http://", expectedValid: false, expectedOrigin: ""}, // Invalid origin should not be accepted. {origin: "file:///etc/passwd", expectedValid: false, expectedOrigin: ""}, // File scheme should not be accepted. {origin: "https://*example.com", expectedValid: false, expectedOrigin: ""}, // Wildcard domain should not be accepted. {origin: "http://*.example.com", expectedValid: false, expectedOrigin: ""}, // Wildcard subdomain should not be accepted. {origin: "http://example.com/path", expectedValid: false, expectedOrigin: ""}, // Path should not be accepted. {origin: "http://example.com?query=123", expectedValid: false, expectedOrigin: ""}, // Query should not be accepted. 
{origin: "http://example.com#fragment", expectedValid: false, expectedOrigin: ""}, // Fragment should not be accepted. {origin: "http://localhost", expectedValid: true, expectedOrigin: "http://localhost"}, // Localhost should be accepted. {origin: "http://127.0.0.1", expectedValid: true, expectedOrigin: "http://127.0.0.1"}, // IPv4 address should be accepted. {origin: "http://[::1]", expectedValid: true, expectedOrigin: "http://[::1]"}, // IPv6 address should be accepted. {origin: "http://[::1]:8080", expectedValid: true, expectedOrigin: "http://[::1]:8080"}, // IPv6 address with port should be accepted. {origin: "http://[::1]:8080/", expectedValid: true, expectedOrigin: "http://[::1]:8080"}, // IPv6 address with port and trailing slash should be accepted. {origin: "http://[::1]:8080/path", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port and path should not be accepted. {origin: "http://[::1]:8080?query=123", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port and query should not be accepted. {origin: "http://[::1]:8080#fragment", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port and fragment should not be accepted. {origin: "http://[::1]:8080/path?query=123#fragment", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, and fragment should not be accepted. {origin: "http://[::1]:8080/path?query=123#fragment/", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, and trailing slash should not be accepted. {origin: "http://[::1]:8080/path?query=123#fragment/invalid", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment should not be accepted. 
{origin: "http://[::1]:8080/path?query=123#fragment/invalid/", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with trailing slash should not be accepted. {origin: "http://[::1]:8080/path?query=123#fragment/invalid/segment", expectedValid: false, expectedOrigin: ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with additional segment should not be accepted. } for _, tc := range testCases { valid, normalizedOrigin := normalizeOrigin(tc.origin) if valid != tc.expectedValid { t.Errorf("Expected origin '%s' to be valid: %v, but got: %v", tc.origin, tc.expectedValid, valid) } if normalizedOrigin != tc.expectedOrigin { t.Errorf("Expected normalized origin '%s' for origin '%s', but got: '%s'", tc.expectedOrigin, tc.origin, normalizedOrigin) } } } func Test_normalizeSchemeHost(t *testing.T) { t.Parallel() testCases := []struct { name string scheme string host string expectedHost string }{ { name: "http default port added", scheme: "http", host: "example.com", expectedHost: "example.com:80", }, { name: "https default port added", scheme: "https", host: "example.com", expectedHost: "example.com:443", }, { name: "http custom port preserved", scheme: "http", host: "example.com:8080", expectedHost: "example.com:8080", }, { name: "https ipv6 default port added", scheme: "https", host: "[::1]", expectedHost: "[::1]:443", }, { name: "unknown scheme preserved", scheme: "ftp", host: "example.com", expectedHost: "example.com", }, { name: "https ipv6 custom port preserved", scheme: "https", host: "[::1]:8080", expectedHost: "[::1]:8080", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() assert.Equal(t, tc.expectedHost, normalizeSchemeHost(tc.scheme, tc.host)) }) } } // go test -run -v TestSubdomainMatch func TestSubdomainMatch(t *testing.T) { t.Parallel() tests := []struct { name string sub subdomain origin string expected bool 
}{ { name: "match with different scheme", sub: subdomain{prefix: "http://api.", suffix: "example.com"}, origin: "https://api.service.example.com", expected: false, }, { name: "match with different scheme", sub: subdomain{prefix: "https://", suffix: "example.com"}, origin: "http://api.service.example.com", expected: false, }, { name: "match with valid subdomain", sub: subdomain{prefix: "https://", suffix: "example.com"}, origin: "https://api.service.example.com", expected: true, }, { name: "match with valid nested subdomain", sub: subdomain{prefix: "https://", suffix: "example.com"}, origin: "https://1.2.api.service.example.com", expected: true, }, { name: "no match with invalid prefix", sub: subdomain{prefix: "https://abc.", suffix: "example.com"}, origin: "https://service.example.com", expected: false, }, { name: "no match with invalid suffix", sub: subdomain{prefix: "https://", suffix: "example.com"}, origin: "https://api.example.org", expected: false, }, { name: "no match with empty origin", sub: subdomain{prefix: "https://", suffix: "example.com"}, origin: "", expected: false, }, { name: "no match with malformed subdomain", sub: subdomain{prefix: "https://", suffix: "example.com"}, origin: "https://evil.comexample.com", expected: false, }, { name: "partial match not considered a match", sub: subdomain{prefix: "https://service.", suffix: "example.com"}, origin: "https://api.example.com", expected: false, }, { name: "no match with empty host label", sub: subdomain{prefix: "https://", suffix: "example.com"}, origin: "https://.example.com", expected: false, }, { name: "no match with malformed host label", sub: subdomain{prefix: "https://", suffix: "example.com"}, origin: "https://..example.com", expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() got := tt.sub.match(tt.origin) assert.Equal(t, tt.expected, got, "subdomain.match()") }) } } // go test -v -run=^$ -bench=Benchmark_CSRF_SubdomainMatch -benchmem -count=4 func 
Benchmark_CSRF_SubdomainMatch(b *testing.B) { s := subdomain{ prefix: "www", suffix: "example.com", } o := "www.example.com" b.ReportAllocs() for b.Loop() { s.match(o) } } ================================================ FILE: middleware/csrf/session_manager.go ================================================ package csrf import ( "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/middleware/session" ) type sessionManager struct { session *session.Store } type sessionKeyType int const ( sessionKey sessionKeyType = 0 ) func newSessionManager(s *session.Store) *sessionManager { // Create new storage handler sessionManager := new(sessionManager) if s != nil { // Use provided storage if provided sessionManager.session = s // Register the sessionKeyType and Token type s.RegisterType(sessionKeyType(0)) s.RegisterType(Token{}) } return sessionManager } // get token from session func (m *sessionManager) getRaw(c fiber.Ctx, key string, raw []byte) []byte { sess := session.FromContext(c) var token Token var ok bool if sess != nil { token, ok = sess.Get(sessionKey).(Token) } else { // Try to get the session from the store storeSess, err := m.session.Get(c) if err != nil { // Handle error return nil } token, ok = storeSess.Get(sessionKey).(Token) } if ok { if token.Expiration.Before(time.Now()) || key != token.Key || !compareTokens(raw, token.Raw) { return nil } return token.Raw } return nil } // set token in session func (m *sessionManager) setRaw(c fiber.Ctx, key string, raw []byte, exp time.Duration) { sess := session.FromContext(c) if sess != nil { // the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here sess.Set(sessionKey, Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)}) } else { // Try to get the session from the store storeSess, err := m.session.Get(c) if err != nil { // Handle error return } storeSess.Set(sessionKey, 
Token{Key: key, Raw: raw, Expiration: time.Now().Add(exp)}) if err := storeSess.Save(); err != nil { log.Warn("csrf: failed to save session: ", err) } } } // delete token from session func (m *sessionManager) delRaw(c fiber.Ctx) { sess := session.FromContext(c) if sess != nil { sess.Delete(sessionKey) } else { // Try to get the session from the store storeSess, err := m.session.Get(c) if err != nil { // Handle error return } storeSess.Delete(sessionKey) if err := storeSess.Save(); err != nil { log.Warn("csrf: failed to save session: ", err) } } } ================================================ FILE: middleware/csrf/storage_manager.go ================================================ package csrf import ( "context" "fmt" "sync" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/memory" ) // msgp -file="storage_manager.go" -o="storage_manager_msgp.go" -tests=true -unexported // //go:generate msgp -o=storage_manager_msgp.go -tests=true -unexported type item struct{} const redactedKey = "[redacted]" //msgp:ignore manager //msgp:ignore storageManager type storageManager struct { pool sync.Pool `msg:"-"` //nolint:revive // Ignore unexported type memory *memory.Storage `msg:"-"` //nolint:revive // Ignore unexported type storage fiber.Storage `msg:"-"` //nolint:revive // Ignore unexported type redactKeys bool } func newStorageManager(storage fiber.Storage, redactKeys bool) *storageManager { // Create new storage handler storageManager := &storageManager{ pool: sync.Pool{ New: func() any { return new(item) }, }, redactKeys: redactKeys, } if storage != nil { // Use provided storage if provided storageManager.storage = storage } else { // Fallback to memory storage storageManager.memory = memory.New() } return storageManager } // get raw data from storage or memory func (m *storageManager) getRaw(ctx context.Context, key string) ([]byte, error) { if m.storage != nil { raw, err := m.storage.GetWithContext(ctx, key) if err != nil { return nil, 
fmt.Errorf("csrf: failed to get value from storage: %w", err) } return raw, nil } if value := m.memory.Get(key); value != nil { raw, ok := value.([]byte) if !ok { return nil, fmt.Errorf("csrf: unexpected value type %T in storage", value) } return raw, nil } return nil, nil } // set data to storage or memory func (m *storageManager) setRaw(ctx context.Context, key string, raw []byte, exp time.Duration) error { if m.storage != nil { if err := m.storage.SetWithContext(ctx, key, raw, exp); err != nil { return fmt.Errorf("csrf: failed to store key %q: %w", m.logKey(key), err) } return nil } m.memory.Set(key, raw, exp) return nil } // delete data from storage or memory func (m *storageManager) delRaw(ctx context.Context, key string) error { if m.storage != nil { if err := m.storage.DeleteWithContext(ctx, key); err != nil { return fmt.Errorf("csrf: failed to delete key %q: %w", m.logKey(key), err) } return nil } m.memory.Delete(key) return nil } func (m *storageManager) logKey(key string) string { if m.redactKeys { return redactedKey } return key } ================================================ FILE: middleware/csrf/storage_manager_msgp.go ================================================ // Code generated by github.com/tinylib/msgp DO NOT EDIT. 
package csrf

import (
	"github.com/tinylib/msgp/msgp"
)

// DecodeMsg implements msgp.Decodable
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, err = dc.ReadMapHeader()
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, err = dc.ReadMapKeyPtr()
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		default:
			err = dc.Skip()
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	return
}

// EncodeMsg implements msgp.Encodable
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
	// map header, size 0
	_ = z
	err = en.Append(0x80)
	if err != nil {
		return
	}
	return
}

// MarshalMsg implements msgp.Marshaler
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
	o = msgp.Require(b, z.Msgsize())
	// map header, size 0
	_ = z
	o = append(o, 0x80)
	return
}

// UnmarshalMsg implements msgp.Unmarshaler
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, bts, err = msgp.ReadMapKeyZC(bts)
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		default:
			bts, err = msgp.Skip(bts)
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	o = bts
	return
}

// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z item) Msgsize() (s int) {
	s = 1
	return
}
================================================
FILE: middleware/csrf/storage_manager_msgp_test.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package csrf

import (
	"bytes"
	"testing"

	"github.com/tinylib/msgp/msgp"
)

func TestMarshalUnmarshalitem(t *testing.T) {
	v := item{}
	bts, err := v.MarshalMsg(nil)
	if err != nil {
		t.Fatal(err)
	}
	left, err := v.UnmarshalMsg(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
	}

	left, err = msgp.Skip(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
	}
}

func BenchmarkMarshalMsgitem(b *testing.B) {
	v := item{}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.MarshalMsg(nil)
	}
}

func BenchmarkAppendMsgitem(b *testing.B) {
	v := item{}
	bts := make([]byte, 0, v.Msgsize())
	bts, _ = v.MarshalMsg(bts[0:0])
	b.SetBytes(int64(len(bts)))
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		bts, _ = v.MarshalMsg(bts[0:0])
	}
}

func BenchmarkUnmarshalitem(b *testing.B) {
	v := item{}
	bts, _ := v.MarshalMsg(nil)
	b.ReportAllocs()
	b.SetBytes(int64(len(bts)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := v.UnmarshalMsg(bts)
		if err != nil {
			b.Fatal(err)
		}
	}
}

func TestEncodeDecodeitem(t *testing.T) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)

	m := v.Msgsize()
	if buf.Len() > m {
		t.Log("WARNING: TestEncodeDecodeitem Msgsize() is inaccurate")
	}

	vn := item{}
	err := msgp.Decode(&buf, &vn)
	if err != nil {
		t.Error(err)
	}

	buf.Reset()
	msgp.Encode(&buf, &v)
	err = msgp.NewReader(&buf).Skip()
	if err != nil {
		t.Error(err)
	}
}

func BenchmarkEncodeitem(b *testing.B) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	en := msgp.NewWriter(msgp.Nowhere)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.EncodeMsg(en)
	}
	en.Flush()
}

func BenchmarkDecodeitem(b *testing.B) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	rd := msgp.NewEndlessReader(buf.Bytes(), b)
	dc := msgp.NewReader(rd)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err := v.DecodeMsg(dc)
		if err != nil {
			b.Fatal(err)
		}
	}
}
================================================
FILE: middleware/csrf/token.go
================================================
package csrf

import (
	"time"
)

// Token represents a CSRF token with expiration metadata.
// This is used internally for token storage and validation.
type Token struct {
	// Expiration is the instant after which the token is no longer valid.
	Expiration time.Time `json:"expiration"`
	// Key identifies the token; a lookup must supply the same key to match.
	Key string `json:"key"`
	// Raw is the stored token value that lookups compare against.
	Raw []byte `json:"raw"`
}
================================================
FILE: middleware/earlydata/config.go
================================================
package earlydata

import (
	"github.com/gofiber/fiber/v3"
)

const (
	// DefaultHeaderName is the request header inspected by the default IsEarlyData.
	DefaultHeaderName = "Early-Data"
	// DefaultHeaderTrueValue is the header value the default IsEarlyData treats as "early data".
	DefaultHeaderTrueValue = "1"
)

// Config defines the config for middleware.
type Config struct {
	// Next defines a function to skip this middleware when returned true.
	//
	// Optional. Default: nil
	Next func(c fiber.Ctx) bool

	// IsEarlyData returns whether the request is an early-data request.
	//
	// Optional. Default: a function which checks if the "Early-Data" request header equals "1".
	IsEarlyData func(c fiber.Ctx) bool

	// AllowEarlyData returns whether the early-data request should be allowed or rejected.
	//
	// Optional. Default: a function which rejects the request on unsafe and allows the request on safe HTTP request methods.
	AllowEarlyData func(c fiber.Ctx) bool

	// Error is returned if an early-data request is rejected. It is also
	// returned when an early-data request does not come from a trusted proxy.
	//
	// Optional. Default: fiber.ErrTooEarly.
	Error error
}

// ConfigDefault is the default config
var ConfigDefault = Config{
	IsEarlyData: func(c fiber.Ctx) bool {
		return c.Get(DefaultHeaderName) == DefaultHeaderTrueValue
	},

	AllowEarlyData: func(c fiber.Ctx) bool {
		return fiber.IsMethodSafe(c.Method())
	},

	Error: fiber.ErrTooEarly,
}

// Helper function to set default values.
// Next is deliberately left as provided (nil means "never skip").
func configDefault(config ...Config) Config {
	// Return default config if nothing provided
	if len(config) < 1 {
		return ConfigDefault
	}

	// Override default config
	cfg := config[0]

	// Set default values
	if cfg.IsEarlyData == nil {
		cfg.IsEarlyData = ConfigDefault.IsEarlyData
	}

	if cfg.AllowEarlyData == nil {
		cfg.AllowEarlyData = ConfigDefault.AllowEarlyData
	}

	if cfg.Error == nil {
		cfg.Error = ConfigDefault.Error
	}

	return cfg
}
================================================
FILE: middleware/earlydata/earlydata.go
================================================
package earlydata

import (
	"github.com/gofiber/fiber/v3"
)

// The contextKey type is unexported to prevent collisions with context keys defined in
// other packages.
type contextKey int

const (
	localsKeyAllowed contextKey = 0 // earlydata_allowed
)

// IsEarly returns true if the request used early data and was accepted by the middleware.
func IsEarly(c fiber.Ctx) bool {
	return c.Locals(localsKeyAllowed) != nil
}

// New creates a new middleware handler
// https://datatracker.ietf.org/doc/html/rfc8470#section-5.1
func New(config ...Config) fiber.Handler {
	// Set default config
	cfg := configDefault(config...)

	// Return new handler
	return func(c fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Continue stack if request is not an early-data request
		if !cfg.IsEarlyData(c) {
			return c.Next()
		}

		// Abort if we can't trust the early-data header: a client not behind a
		// trusted proxy could otherwise set it itself.
		if !c.IsProxyTrusted() {
			return cfg.Error
		}

		// Continue stack if we allow early-data for this request
		if cfg.AllowEarlyData(c) {
			// Mark the request so IsEarly reports true downstream.
			_ = c.Locals(localsKeyAllowed, true)
			return c.Next()
		}

		// Else return our error
		return cfg.Error
	}
}
================================================
FILE: middleware/earlydata/earlydata_test.go
================================================
package earlydata

import (
	"errors"
	"fmt"
	"net/http"
	"net/http/httptest"
	"reflect"
	"testing"

	"github.com/gofiber/fiber/v3"
	"github.com/stretchr/testify/require"
	"github.com/valyala/fasthttp"
)

const (
	headerName   = "Early-Data"
	headerValOn  = "1"
	headerValOff = "0"
)

const (
	// trustedRemoteAddr matches the "0.0.0.0" proxy configured in the trusted subtest.
	trustedRemoteAddr   = "0.0.0.0:1234"
	untrustedRemoteAddr = "203.0.113.1:1234"
)

// appWithConfig builds an app with the earlydata middleware, a checker
// middleware that validates IsEarly against the request header/method, and a
// GET/POST "/" handler that fails if the checker did not run successfully.
func appWithConfig(t *testing.T, c *fiber.Config) *fiber.App {
	t.Helper()

	var app *fiber.App
	if c == nil {
		app = fiber.New()
	} else {
		app = fiber.New(*c)
	}

	app.Use(New())

	// Middleware to test IsEarly func
	const localsKeyTestValid = "earlydata_testvalid"
	app.Use(func(c fiber.Ctx) error {
		isEarly := IsEarly(c)

		switch h := c.Get(headerName); h {
		case "", headerValOff:
			if isEarly {
				return errors.New("is early-data even though it's not")
			}

		case headerValOn:
			switch {
			case fiber.IsMethodSafe(c.Method()):
				if !isEarly {
					return errors.New("should be early-data on safe HTTP methods")
				}
			default:
				if isEarly {
					return errors.New("early-data unsupported on unsafe HTTP methods")
				}
			}

		default:
			return fmt.Errorf("header has unsupported value: %s", h)
		}

		_ = c.Locals(localsKeyTestValid, true)

		return c.Next()
	})

	app.Add([]string{
		fiber.MethodGet,
		fiber.MethodPost,
	}, "/", func(c fiber.Ctx) error {
		valid, ok := c.Locals(localsKeyTestValid).(bool)
		if !ok {
			panic(errors.New("failed to type-assert to bool"))
		}
		if !valid {
			return errors.New("handler called even though validation failed")
		}

		return nil
	})

	return app
}

type requestExpectation struct {
	method string
	header string
	status int
}

// executeExpectations sends one request per expectation from remoteAddr and
// asserts the returned status code.
func executeExpectations(t *testing.T, app *fiber.App, remoteAddr string, expectations []requestExpectation) {
	t.Helper()

	for _, expectation := range expectations {
		req := httptest.NewRequest(expectation.method, "/", http.NoBody)
		req.RemoteAddr = remoteAddr
		if expectation.header != "" {
			req.Header.Set(headerName, expectation.header)
		}

		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, expectation.status, resp.StatusCode)
	}
}

// go test -run Test_EarlyData
func Test_EarlyData(t *testing.T) {
	t.Parallel()

	// Untrusted remotes: every request flagged as early data is rejected,
	// regardless of method, because the header cannot be trusted.
	untrustedExpectations := []requestExpectation{
		{method: fiber.MethodGet, status: fiber.StatusOK},
		{method: fiber.MethodGet, header: headerValOff, status: fiber.StatusOK},
		{method: fiber.MethodGet, header: headerValOn, status: fiber.StatusTooEarly},
		{method: fiber.MethodPost, status: fiber.StatusOK},
		{method: fiber.MethodPost, header: headerValOff, status: fiber.StatusOK},
		{method: fiber.MethodPost, header: headerValOn, status: fiber.StatusTooEarly},
	}

	// Trusted proxies: early data is accepted on safe methods (GET) only; the
	// default AllowEarlyData still rejects it on unsafe methods (POST).
	trustedExpectations := []requestExpectation{
		{method: fiber.MethodGet, status: fiber.StatusOK},
		{method: fiber.MethodGet, header: headerValOff, status: fiber.StatusOK},
		{method: fiber.MethodGet, header: headerValOn, status: fiber.StatusOK},
		{method: fiber.MethodPost, status: fiber.StatusOK},
		{method: fiber.MethodPost, header: headerValOff, status: fiber.StatusOK},
		{method: fiber.MethodPost, header: headerValOn, status: fiber.StatusTooEarly},
	}

	t.Run("empty config", func(t *testing.T) {
		t.Parallel()
		app := appWithConfig(t, nil)
		executeExpectations(t, app, untrustedRemoteAddr, untrustedExpectations)
	})

	t.Run("default config", func(t *testing.T) {
		t.Parallel()
		app := appWithConfig(t, &fiber.Config{})
		executeExpectations(t, app, untrustedRemoteAddr, untrustedExpectations)
	})

	t.Run("config with TrustProxy and untrusted remote", func(t *testing.T) {
		t.Parallel()
		app := appWithConfig(t, &fiber.Config{
			TrustProxy: true,
		})
		executeExpectations(t, app, untrustedRemoteAddr, untrustedExpectations)
	})

	t.Run("config with TrustProxy and trusted TrustProxyConfig.Proxies", func(t *testing.T) {
		t.Parallel()
		app := appWithConfig(t, &fiber.Config{
			TrustProxy: true,
			TrustProxyConfig: fiber.TrustProxyConfig{
				Proxies: []string{
					"0.0.0.0",
				},
			},
		})
		executeExpectations(t, app, trustedRemoteAddr, trustedExpectations)
	})
}

// Test_EarlyDataNext verifies that the middleware skips its logic when Next returns true.
func Test_EarlyDataNext(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Next: func(fiber.Ctx) bool { return true },
	}))

	called := false
	app.Get("/", func(c fiber.Ctx) error {
		called = true
		if IsEarly(c) {
			return errors.New("IsEarly(c) should be false when Next returns true")
		}
		return nil
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Set(headerName, headerValOn)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.True(t, called)
}

// Test_configDefault_NoConfig verifies that calling configDefault without
// providing a configuration returns ConfigDefault as-is.
func Test_configDefault_NoConfig(t *testing.T) {
	t.Parallel()
	cfg := configDefault()
	require.Equal(t, ConfigDefault.Error, cfg.Error)
	require.Equal(t, reflect.ValueOf(ConfigDefault.IsEarlyData).Pointer(), reflect.ValueOf(cfg.IsEarlyData).Pointer())
	require.Equal(t, reflect.ValueOf(ConfigDefault.AllowEarlyData).Pointer(), reflect.ValueOf(cfg.AllowEarlyData).Pointer())
}

// Test_configDefault_WithConfig verifies that provided configuration fields are
// kept while missing fields are populated with defaults.
func Test_configDefault_WithConfig(t *testing.T) { t.Parallel() expectedErr := errors.New("boom") called := false custom := Config{ Next: func(_ fiber.Ctx) bool { called = true; return false }, Error: expectedErr, } cfg := configDefault(custom) // Next should be preserved and not invoked by configDefault. require.False(t, called) require.Equal(t, reflect.ValueOf(custom.Next).Pointer(), reflect.ValueOf(cfg.Next).Pointer()) // Custom error must be preserved. require.Equal(t, expectedErr, cfg.Error) // Missing fields should be set to defaults. require.NotNil(t, cfg.IsEarlyData) require.NotNil(t, cfg.AllowEarlyData) // Verify default functions behave as expected. app := fiber.New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.Set(DefaultHeaderName, DefaultHeaderTrueValue) c.Request().Header.SetMethod(fiber.MethodGet) require.True(t, cfg.IsEarlyData(c)) require.True(t, cfg.AllowEarlyData(c)) app.ReleaseCtx(c) } ================================================ FILE: middleware/encryptcookie/config.go ================================================ package encryptcookie import ( "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // Custom function to encrypt cookies. // // Optional. Default: EncryptCookie (using AES-GCM) Encryptor func(name, decryptedString, key string) (string, error) // Custom function to decrypt cookies. // // Optional. Default: DecryptCookie (using AES-GCM) Decryptor func(name, encryptedString, key string) (string, error) // Base64 encoded unique key to encode & decode cookies. // // Required. Key length should be 16, 24, or 32 bytes when decoded // if using the default EncryptCookie and DecryptCookie functions. // You may use `encryptcookie.GenerateKey(length)` to generate a new key. Key string // Array of cookie keys that should not be encrypted. 
// // Optional. Default: [] Except []string } // ConfigDefault is the default config var ConfigDefault = Config{ Next: nil, Except: []string{}, Key: "", Encryptor: EncryptCookie, Decryptor: DecryptCookie, } // Helper function to set default values func configDefault(config ...Config) Config { // Set default config cfg := ConfigDefault // Override config if provided if len(config) > 0 { cfg = config[0] // Set default values if cfg.Next == nil { cfg.Next = ConfigDefault.Next } if cfg.Except == nil { cfg.Except = ConfigDefault.Except } if cfg.Encryptor == nil { cfg.Encryptor = ConfigDefault.Encryptor } if cfg.Decryptor == nil { cfg.Decryptor = ConfigDefault.Decryptor } } if cfg.Key == "" { panic("fiber: encrypt cookie middleware requires key") } if err := validateKey(cfg.Key); err != nil { panic(err) } return cfg } ================================================ FILE: middleware/encryptcookie/config_test.go ================================================ package encryptcookie import ( "crypto/rand" "encoding/base64" "fmt" "testing" "github.com/stretchr/testify/require" ) func Test_configDefault_KeyValidation(t *testing.T) { t.Parallel() t.Run("invalid base64", func(t *testing.T) { t.Parallel() _, decErr := base64.StdEncoding.DecodeString("invalid") expectedErr := fmt.Errorf("failed to base64-decode key: %w", decErr).Error() require.PanicsWithError(t, expectedErr, func() { configDefault(Config{Key: "invalid"}) }) }) t.Run("invalid length", func(t *testing.T) { t.Parallel() key := make([]byte, 20) _, err := rand.Read(key) require.NoError(t, err) invalidKey := base64.StdEncoding.EncodeToString(key) require.PanicsWithValue(t, ErrInvalidKeyLength, func() { configDefault(Config{Key: invalidKey}) }) }) } ================================================ FILE: middleware/encryptcookie/encryptcookie.go ================================================ package encryptcookie import ( "errors" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3" ) // New creates a new 
middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Decrypt request cookies cookiesToDelete := make([][]byte, 0, 4) for key, value := range c.Request().Header.Cookies() { keyString := string(key) if !isDisabled(keyString, cfg.Except) { decryptedValue, err := cfg.Decryptor(keyString, string(value), cfg.Key) if err != nil { cookiesToDelete = append(cookiesToDelete, key) } else { c.Request().Header.SetCookie(keyString, decryptedValue) } } } // Delete cookies that failed to decrypt - outside the loop to avoid mutation during iteration for _, key := range cookiesToDelete { c.Request().Header.DelCookieBytes(key) } // Continue stack err := c.Next() // Encrypt response cookies for key := range c.Response().Header.Cookies() { keyString := string(key) if !isDisabled(keyString, cfg.Except) { cookieValue := fasthttp.Cookie{} cookieValue.SetKeyBytes(key) if c.Response().Header.Cookie(&cookieValue) { encryptedValue, encErr := cfg.Encryptor(keyString, string(cookieValue.Value()), cfg.Key) if encErr != nil { return errors.Join(err, encErr) } cookieValue.SetValue(encryptedValue) c.Response().Header.SetCookie(&cookieValue) } } } return err } } ================================================ FILE: middleware/encryptcookie/encryptcookie_test.go ================================================ package encryptcookie import ( "crypto/rand" "encoding/base64" "errors" "net/http" "net/http/httptest" "strconv" "testing" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3" ) func Test_Middleware_Panics(t *testing.T) { t.Parallel() t.Run("Empty Key", func(t *testing.T) { t.Parallel() app := fiber.New() require.Panics(t, func() { app.Use(New(Config{ Key: "", })) }) }) t.Run("Invalid Key", func(t *testing.T) { t.Parallel() 
require.Panics(t, func() { GenerateKey(11) }) }) } func Test_Middleware_InvalidKeys(t *testing.T) { t.Parallel() tests := []struct { length int }{ {length: 11}, {length: 25}, {length: 60}, } for _, tt := range tests { t.Run(strconv.Itoa(tt.length)+"_length_encrypt", func(t *testing.T) { t.Parallel() key := make([]byte, tt.length) _, err := rand.Read(key) require.NoError(t, err) keyString := base64.StdEncoding.EncodeToString(key) _, err = EncryptCookie("test", "SomeThing", keyString) require.Error(t, err) }) t.Run(strconv.Itoa(tt.length)+"_length_decrypt", func(t *testing.T) { t.Parallel() key := make([]byte, tt.length) _, err := rand.Read(key) require.NoError(t, err) keyString := base64.StdEncoding.EncodeToString(key) _, err = DecryptCookie("test", "SomeThing", keyString) require.Error(t, err) }) } } func Test_Middleware_InvalidBase64(t *testing.T) { t.Parallel() invalidBase64 := "invalid-base64-string-!@#" t.Run("encryptor", func(t *testing.T) { t.Parallel() _, err := EncryptCookie("test", "SomeText", invalidBase64) require.Error(t, err) require.ErrorContains(t, err, "failed to base64-decode key") }) t.Run("decryptor_key", func(t *testing.T) { t.Parallel() _, err := DecryptCookie("test", "SomeText", invalidBase64) require.Error(t, err) require.ErrorContains(t, err, "failed to base64-decode key") }) t.Run("decryptor_value", func(t *testing.T) { t.Parallel() _, err := DecryptCookie("test", invalidBase64, GenerateKey(32)) require.Error(t, err) require.ErrorContains(t, err, "failed to base64-decode value") }) } func Test_DecryptCookie_InvalidEncryptedValue(t *testing.T) { t.Parallel() key := GenerateKey(32) // the decoded value is shorter than the GCM nonce size, so decryption should fail immediately shortValue := base64.StdEncoding.EncodeToString([]byte("short")) _, err := DecryptCookie("session", shortValue, key) require.ErrorIs(t, err, ErrInvalidEncryptedValue) } func Test_Middleware_EncryptionErrorPropagates(t *testing.T) { t.Parallel() testKey := GenerateKey(32) 
expected := errors.New("encrypt failed") var captured error app := fiber.New(fiber.Config{ ErrorHandler: func(c fiber.Ctx, err error) error { captured = err return c.Status(fiber.StatusTeapot).SendString("encryption error") }, }) app.Use(New(Config{ Key: testKey, Encryptor: func(name, value, _ string) (string, error) { if name == "test" { return "", expected } return value, nil }, })) app.Get("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test", Value: "value", }) return nil }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) require.ErrorIs(t, captured, expected) } func Test_Middleware_EncryptionErrorDoesNotMaskNextError(t *testing.T) { t.Parallel() testKey := GenerateKey(32) encryptErr := errors.New("encrypt failed") downstreamErr := errors.New("downstream failed") var captured error app := fiber.New(fiber.Config{ ErrorHandler: func(c fiber.Ctx, err error) error { captured = err return c.Status(fiber.StatusTeapot).SendString("combined error") }, }) app.Use(New(Config{ Key: testKey, Encryptor: func(name, value, _ string) (string, error) { if name == "test" { return "", encryptErr } return value, nil }, })) app.Get("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test", Value: "value", }) return downstreamErr }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) require.ErrorIs(t, captured, downstreamErr) require.ErrorIs(t, captured, encryptErr) } func Test_Middleware_Encrypt_Cookie(t *testing.T) { t.Parallel() testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("value=" + c.Cookies("test")) }) app.Post("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test", Value: "SomeThing", }) return nil }) h := app.Handler() // Test empty 
cookie ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) require.Equal(t, "value=", string(ctx.Response.Body())) // Test invalid cookie ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetCookie("test", "Invalid") h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) require.Equal(t, "value=", string(ctx.Response.Body())) ctx.Request.Header.SetCookie("test", "ixQURE2XOyZUs0WAOh2ehjWcP7oZb07JvnhWOsmeNUhPsj4+RyI=") h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) require.Equal(t, "value=", string(ctx.Response.Body())) // Test valid cookie ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodPost) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) encryptedCookie := fasthttp.Cookie{} encryptedCookie.SetKey("test") require.True(t, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value") decryptedCookieValue, err := DecryptCookie("test", string(encryptedCookie.Value()), testKey) require.NoError(t, err) require.Equal(t, "SomeThing", decryptedCookieValue) ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value())) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) require.Equal(t, "value=SomeThing", string(ctx.Response.Body())) } func Test_EncryptCookie_Rejects_Swapped_Names(t *testing.T) { t.Parallel() testKey := GenerateKey(32) encryptedValue, err := EncryptCookie("cookieA", "ValueA", testKey) require.NoError(t, err) decryptedValue, err := DecryptCookie("cookieA", encryptedValue, testKey) require.NoError(t, err) require.Equal(t, "ValueA", decryptedValue) _, err = DecryptCookie("cookieB", encryptedValue, testKey) require.Error(t, err) require.ErrorContains(t, err, "failed to decrypt ciphertext") } func Test_Encrypt_Cookie_Next(t *testing.T) { t.Parallel() testKey := GenerateKey(32) app := fiber.New() 
app.Use(New(Config{ Key: testKey, Next: func(_ fiber.Ctx) bool { return true }, })) app.Get("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test", Value: "SomeThing", }) return nil }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, "SomeThing", resp.Cookies()[0].Value) } func Test_Encrypt_Cookie_Except(t *testing.T) { t.Parallel() testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, Except: []string{ "test1", }, })) app.Get("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test1", Value: "SomeThing", }) c.Cookie(&fiber.Cookie{ Name: "test2", Value: "SomeThing", }) return nil }) h := app.Handler() ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) rawCookie := fasthttp.Cookie{} rawCookie.SetKey("test1") require.True(t, ctx.Response.Header.Cookie(&rawCookie), "Get cookie value") require.Equal(t, "SomeThing", string(rawCookie.Value())) encryptedCookie := fasthttp.Cookie{} encryptedCookie.SetKey("test2") require.True(t, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value") decryptedCookieValue, err := DecryptCookie("test2", string(encryptedCookie.Value()), testKey) require.NoError(t, err) require.Equal(t, "SomeThing", decryptedCookieValue) } func Test_Encrypt_Cookie_Custom_Encryptor(t *testing.T) { t.Parallel() testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, Encryptor: func(_, decryptedString, _ string) (string, error) { return base64.StdEncoding.EncodeToString([]byte(decryptedString)), nil }, Decryptor: func(_, encryptedString, _ string) (string, error) { decodedBytes, err := base64.StdEncoding.DecodeString(encryptedString) return string(decodedBytes), err }, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("value=" + c.Cookies("test")) }) app.Post("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: 
"test", Value: "SomeThing", }) return nil }) h := app.Handler() ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodPost) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) encryptedCookie := fasthttp.Cookie{} encryptedCookie.SetKey("test") require.True(t, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value") decodedBytes, err := base64.StdEncoding.DecodeString(string(encryptedCookie.Value())) require.NoError(t, err) require.Equal(t, "SomeThing", string(decodedBytes)) ctx = &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value())) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) require.Equal(t, "value=SomeThing", string(ctx.Response.Body())) } func Test_GenerateKey(t *testing.T) { t.Parallel() tests := []struct { length int }{ {length: 16}, {length: 24}, {length: 32}, } decodeBase64 := func(t *testing.T, s string) []byte { t.Helper() data, err := base64.StdEncoding.DecodeString(s) require.NoError(t, err) return data } for _, tt := range tests { t.Run(strconv.Itoa(tt.length)+"_length", func(t *testing.T) { t.Parallel() key := GenerateKey(tt.length) decodedKey := decodeBase64(t, key) require.Len(t, decodedKey, tt.length) }) } t.Run("Invalid Length", func(t *testing.T) { require.Panics(t, func() { GenerateKey(10) }) require.Panics(t, func() { GenerateKey(20) }) }) } func Benchmark_Middleware_Encrypt_Cookie(b *testing.B) { testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("value=" + c.Cookies("test")) }) app.Post("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test", Value: "SomeThing", }) return nil }) h := app.Handler() b.Run("Empty Cookie", func(b *testing.B) { for b.Loop() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) } }) b.Run("Invalid Cookie", func(b *testing.B) { for b.Loop() { ctx := 
&fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetCookie("test", "Invalid") h(ctx) } }) b.Run("Valid Cookie", func(b *testing.B) { for b.Loop() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodPost) h(ctx) } }) } func Benchmark_Encrypt_Cookie_Next(b *testing.B) { testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, Next: func(_ fiber.Ctx) bool { return true }, })) app.Get("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test", Value: "SomeThing", }) return nil }) h := app.Handler() b.Run("Encrypt Cookie Next", func(b *testing.B) { for b.Loop() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.SetRequestURI("/") h(ctx) } }) } func Benchmark_Encrypt_Cookie_Except(b *testing.B) { testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, Except: []string{ "test1", }, })) app.Get("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test1", Value: "SomeThing", }) c.Cookie(&fiber.Cookie{ Name: "test2", Value: "SomeThing", }) return nil }) h := app.Handler() b.Run("Encrypt Cookie Except", func(b *testing.B) { for b.Loop() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) } }) } func Benchmark_Encrypt_Cookie_Custom_Encryptor(b *testing.B) { testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, Encryptor: func(_, decryptedString, _ string) (string, error) { return base64.StdEncoding.EncodeToString([]byte(decryptedString)), nil }, Decryptor: func(_, encryptedString, _ string) (string, error) { decodedBytes, err := base64.StdEncoding.DecodeString(encryptedString) return string(decodedBytes), err }, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("value=" + c.Cookies("test")) }) app.Post("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test", Value: "SomeThing", }) return nil }) h := app.Handler() b.Run("Custom 
Encryptor Post", func(b *testing.B) { for b.Loop() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodPost) h(ctx) } }) b.Run("Custom Encryptor Get", func(b *testing.B) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodPost) h(ctx) encryptedCookie := fasthttp.Cookie{} encryptedCookie.SetKey("test") require.True(b, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value") for b.Loop() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value())) h(ctx) } }) } func Benchmark_Middleware_Encrypt_Cookie_Parallel(b *testing.B) { testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("value=" + c.Cookies("test")) }) app.Post("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test", Value: "SomeThing", }) return nil }) h := app.Handler() b.Run("Empty Cookie Parallel", func(b *testing.B) { b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) } }) }) b.Run("Invalid Cookie Parallel", func(b *testing.B) { b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetCookie("test", "Invalid") h(ctx) } }) }) b.Run("Valid Cookie Parallel", func(b *testing.B) { b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodPost) h(ctx) } }) }) } func Benchmark_Encrypt_Cookie_Next_Parallel(b *testing.B) { testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, Next: func(_ fiber.Ctx) bool { return true }, })) app.Get("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test", Value: 
"SomeThing", }) return nil }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.SetRequestURI("/") h(ctx) } }) } func Benchmark_Encrypt_Cookie_Except_Parallel(b *testing.B) { testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, Except: []string{ "test1", }, })) app.Get("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test1", Value: "SomeThing", }) c.Cookie(&fiber.Cookie{ Name: "test2", Value: "SomeThing", }) return nil }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) h(ctx) } }) } func Benchmark_Encrypt_Cookie_Custom_Encryptor_Parallel(b *testing.B) { testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, Encryptor: func(_, decryptedString, _ string) (string, error) { return base64.StdEncoding.EncodeToString([]byte(decryptedString)), nil }, Decryptor: func(_, encryptedString, _ string) (string, error) { decodedBytes, err := base64.StdEncoding.DecodeString(encryptedString) return string(decodedBytes), err }, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("value=" + c.Cookies("test")) }) app.Post("/", func(c fiber.Ctx) error { c.Cookie(&fiber.Cookie{ Name: "test", Value: "SomeThing", }) return nil }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodPost) h(ctx) encryptedCookie := fasthttp.Cookie{} encryptedCookie.SetKey("test") require.True(b, ctx.Response.Header.Cookie(&encryptedCookie), "Get cookie value") b.ResetTimer() for pb.Next() { ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetCookie("test", string(encryptedCookie.Value())) h(ctx) } }) } func 
Benchmark_GenerateKey(b *testing.B) { tests := []struct { length int }{ {length: 16}, {length: 24}, {length: 32}, } for _, tt := range tests { b.Run(strconv.Itoa(tt.length), func(b *testing.B) { for b.Loop() { GenerateKey(tt.length) } }) } } func Benchmark_GenerateKey_Parallel(b *testing.B) { tests := []struct { length int }{ {length: 16}, {length: 24}, {length: 32}, } for _, tt := range tests { b.Run(strconv.Itoa(tt.length), func(b *testing.B) { b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { GenerateKey(tt.length) } }) }) } } // Test_Middleware_Mixed_Valid_Invalid_Cookies tests that the middleware correctly handles // a mix of valid and invalid cookies during iteration func Test_Middleware_Mixed_Valid_Invalid_Cookies(t *testing.T) { t.Parallel() testKey := GenerateKey(32) app := fiber.New() app.Use(New(Config{ Key: testKey, })) app.Get("/", func(c fiber.Ctx) error { valid1 := c.Cookies("valid1") valid2 := c.Cookies("valid2") invalid := c.Cookies("invalid") return c.SendString("valid1=" + valid1 + ",valid2=" + valid2 + ",invalid=" + invalid) }) h := app.Handler() // First, create some valid encrypted cookies encryptedValue1, err := EncryptCookie("valid1", "value1", testKey) require.NoError(t, err) encryptedValue2, err := EncryptCookie("valid2", "value2", testKey) require.NoError(t, err) // Test with a mix of valid and invalid cookies ctx := &fasthttp.RequestCtx{} ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetCookie("valid1", encryptedValue1) ctx.Request.Header.SetCookie("invalid", "thisisnotvalid") ctx.Request.Header.SetCookie("valid2", encryptedValue2) h(ctx) require.Equal(t, 200, ctx.Response.StatusCode()) require.Equal(t, "valid1=value1,valid2=value2,invalid=", string(ctx.Response.Body())) // Verify the invalid cookie was deleted but valid ones remain require.NotEmpty(t, ctx.Request.Header.Cookie("valid1")) require.Empty(t, ctx.Request.Header.Cookie("invalid")) require.NotEmpty(t, 
ctx.Request.Header.Cookie("valid2")) } ================================================ FILE: middleware/encryptcookie/utils.go ================================================ package encryptcookie import ( "crypto/aes" "crypto/cipher" "crypto/rand" "encoding/base64" "errors" "fmt" "slices" ) var ( ErrInvalidKeyLength = errors.New("encryption key must be 16, 24, or 32 bytes") ErrInvalidEncryptedValue = errors.New("encrypted value is not valid") ) // decodeKey decodes the provided base64-encoded key and validates its length. // It returns the decoded key bytes or an error when invalid. func decodeKey(key string) ([]byte, error) { keyDecoded, err := base64.StdEncoding.DecodeString(key) if err != nil { return nil, fmt.Errorf("failed to base64-decode key: %w", err) } keyLen := len(keyDecoded) if keyLen != 16 && keyLen != 24 && keyLen != 32 { return nil, ErrInvalidKeyLength } return keyDecoded, nil } // validateKey checks if the provided base64-encoded key is of valid length. func validateKey(key string) error { _, err := decodeKey(key) return err } // EncryptCookie Encrypts a cookie value with specific encryption key func EncryptCookie(name, value, key string) (string, error) { keyDecoded, err := decodeKey(key) if err != nil { return "", err } block, err := aes.NewCipher(keyDecoded) if err != nil { return "", fmt.Errorf("failed to create AES cipher: %w", err) } gcm, err := cipher.NewGCMWithRandomNonce(block) if err != nil { return "", fmt.Errorf("failed to create GCM mode: %w", err) } ciphertext := gcm.Seal(nil, nil, []byte(value), []byte(name)) return base64.StdEncoding.EncodeToString(ciphertext), nil } // DecryptCookie Decrypts a cookie value with specific encryption key func DecryptCookie(name, value, key string) (string, error) { keyDecoded, err := decodeKey(key) if err != nil { return "", err } enc, err := base64.StdEncoding.DecodeString(value) if err != nil { return "", fmt.Errorf("failed to base64-decode value: %w", err) } block, err := 
aes.NewCipher(keyDecoded) if err != nil { return "", fmt.Errorf("failed to create AES cipher: %w", err) } gcm, err := cipher.NewGCMWithRandomNonce(block) if err != nil { return "", fmt.Errorf("failed to create GCM mode: %w", err) } if len(enc) < gcm.NonceSize()+gcm.Overhead() { return "", ErrInvalidEncryptedValue } plaintext, err := gcm.Open(nil, nil, enc, []byte(name)) if err != nil { return "", fmt.Errorf("failed to decrypt ciphertext: %w", err) } return string(plaintext), nil } // GenerateKey returns a random string of 16, 24, or 32 bytes. // The length of the key determines the AES encryption algorithm used: // 16 bytes for AES-128, 24 bytes for AES-192, and 32 bytes for AES-256-GCM. func GenerateKey(length int) string { if length != 16 && length != 24 && length != 32 { panic(ErrInvalidKeyLength) } key := make([]byte, length) if _, err := rand.Read(key); err != nil { panic(err) } return base64.StdEncoding.EncodeToString(key) } // Check given cookie key is disabled for encryption or not func isDisabled(key string, except []string) bool { return slices.Contains(except, key) } ================================================ FILE: middleware/envvar/config.go ================================================ package envvar // Config defines the config for middleware. type Config struct { // ExportVars specifies the environment variables that should export ExportVars map[string]string } // ConfigDefault is the default config. 
var ConfigDefault = Config{ ExportVars: map[string]string{}, } func configDefault(config ...Config) Config { if len(config) == 0 { return ConfigDefault } cfg := config[0] if cfg.ExportVars == nil { cfg.ExportVars = ConfigDefault.ExportVars } return cfg } ================================================ FILE: middleware/envvar/envvar.go ================================================ package envvar import ( "os" "github.com/gofiber/fiber/v3" ) const hAllow = fiber.MethodGet + ", " + fiber.MethodHead // EnvVar captures environment variables that are exposed through the // middleware response. type EnvVar struct { Vars map[string]string `json:"vars"` } func (envVar *EnvVar) set(key, val string) { envVar.Vars[key] = val } // New creates a handler that returns configured environment variables as a // JSON response. func New(config ...Config) fiber.Handler { cfg := configDefault(config...) return func(c fiber.Ctx) error { method := c.Method() if method != fiber.MethodGet && method != fiber.MethodHead { c.Set(fiber.HeaderAllow, hAllow) return fiber.ErrMethodNotAllowed } envVar := newEnvVar(cfg) varsByte, err := c.App().Config().JSONEncoder(envVar) if err != nil { return c.Status(fiber.StatusInternalServerError).SendString(err.Error()) } c.Set(fiber.HeaderContentType, fiber.MIMEApplicationJSONCharsetUTF8) return c.Send(varsByte) } } func newEnvVar(cfg Config) *EnvVar { vars := &EnvVar{Vars: make(map[string]string)} if len(cfg.ExportVars) == 0 { // do not expose environment variables when no configuration // is supplied to prevent accidental information disclosure return vars } for key, defaultVal := range cfg.ExportVars { vars.set(key, defaultVal) if envVal, exists := os.LookupEnv(key); exists { vars.set(key, envVal) } } return vars } ================================================ FILE: middleware/envvar/envvar_test.go ================================================ package envvar import ( "context" "encoding/json" "io" "net/http" "net/http/httptest" "testing" 
"github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" ) func Test_EnvVarStructWithExportVars(t *testing.T) { t.Setenv("testKey", "testEnvValue") t.Setenv("anotherEnvKey", "anotherEnvVal") vars := newEnvVar(Config{ ExportVars: map[string]string{"testKey": "", "testDefaultKey": "testDefaultVal"}, }) require.Equal(t, "testEnvValue", vars.Vars["testKey"]) require.Equal(t, "testDefaultVal", vars.Vars["testDefaultKey"]) require.Empty(t, vars.Vars["anotherEnvKey"]) } func Test_EnvVarHandler(t *testing.T) { t.Setenv("testKey", "testVal") expectedEnvVarResponse, err := json.Marshal( struct { Vars map[string]string `json:"vars"` }{ Vars: map[string]string{"testKey": "testVal"}, }) require.NoError(t, err) app := fiber.New() app.Use("/envvars", New(Config{ ExportVars: map[string]string{"testKey": ""}, })) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) respBody, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, expectedEnvVarResponse, respBody) } func Test_EnvVarHandlerNotMatched(t *testing.T) { app := fiber.New() app.Use("/envvars", New(Config{ ExportVars: map[string]string{"testKey": ""}, })) app.Get("/another-path", func(ctx fiber.Ctx) error { require.NoError(t, ctx.SendString("OK")) return nil }) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/another-path", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) respBody, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, []byte("OK"), respBody) } func Test_EnvVarHandlerDefaultConfig(t *testing.T) { t.Setenv("testEnvKey", "testEnvVal") app := fiber.New() app.Use("/envvars", New()) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) 
require.NoError(t, err) respBody, err := io.ReadAll(resp.Body) require.NoError(t, err) var envVars EnvVar require.NoError(t, json.Unmarshal(respBody, &envVars)) _, exists := envVars.Vars["testEnvKey"] require.False(t, exists) } func Test_EnvVarHandlerMethod(t *testing.T) { app := fiber.New() app.Use("/envvars", New()) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "http://localhost/envvars", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusMethodNotAllowed, resp.StatusCode) require.Equal(t, hAllow, resp.Header.Get(fiber.HeaderAllow)) } func Test_EnvVarHandlerHead(t *testing.T) { app := fiber.New() app.Use("/envvars", New()) req := httptest.NewRequest(fiber.MethodHead, "http://localhost/envvars", http.NoBody) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Empty(t, string(body)) } func Test_EnvVarHandlerSpecialValue(t *testing.T) { testEnvKey := "testEnvKey" fakeBase64 := "testBase64:TQ==" t.Setenv(testEnvKey, fakeBase64) app := fiber.New() app.Use("/envvars/export", New(Config{ExportVars: map[string]string{testEnvKey: ""}})) app.Use("/envvars", New()) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) respBody, err := io.ReadAll(resp.Body) require.NoError(t, err) var envVars EnvVar require.NoError(t, json.Unmarshal(respBody, &envVars)) _, exists := envVars.Vars[testEnvKey] require.False(t, exists) req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "http://localhost/envvars/export", http.NoBody) require.NoError(t, err) resp, err = app.Test(req) require.NoError(t, err) respBody, err = io.ReadAll(resp.Body) require.NoError(t, err) var envVarsExport EnvVar require.NoError(t, 
json.Unmarshal(respBody, &envVarsExport)) val := envVarsExport.Vars[testEnvKey] require.Equal(t, fakeBase64, val) } ================================================ FILE: middleware/etag/config.go ================================================ package etag import ( "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // Weak indicates that a weak validator is used. Weak etags are easy // to generate, but are far less useful for comparisons. Strong // validators are ideal for comparisons but can be very difficult // to generate efficiently. Weak ETag values of two representations // of the same resources might be semantically equivalent, but not // byte-for-byte identical. This means weak etags prevent caching // when byte range requests are used, but strong etags mean range // requests can still be cached. Weak bool } // ConfigDefault is the default config var ConfigDefault = Config{ Weak: false, Next: nil, } // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values return cfg } ================================================ FILE: middleware/etag/etag.go ================================================ package etag import ( "bytes" "hash/crc32" "math" "slices" "github.com/gofiber/fiber/v3" "github.com/valyala/bytebufferpool" ) var ( weakPrefix = []byte("W/") crc32q = crc32.MakeTable(0xD5828281) ) // Generate returns a strong ETag for body. 
func Generate(body []byte) []byte { if uint64(len(body)) > math.MaxUint32 { return nil } bb := bytebufferpool.Get() defer bytebufferpool.Put(bb) b := bb.B[:0] b = append(b, '"') b = appendUint(b, uint32(len(body))) // #nosec G115 -- length checked above b = append(b, '-') b = appendUint(b, crc32.Checksum(body, crc32q)) b = append(b, '"') return slices.Clone(b) } // GenerateWeak returns a weak ETag for body. func GenerateWeak(body []byte) []byte { tag := Generate(body) if tag == nil { return nil } return append(weakPrefix, tag...) } // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) normalizedHeaderETag := []byte("Etag") // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Return err if next handler returns one if err := c.Next(); err != nil { return err } // Don't generate ETags for invalid responses if c.Response().StatusCode() != fiber.StatusOK { return nil } body := c.Response().Body() // Skips ETag if no response body is present if len(body) == 0 { return nil } // Skip ETag if header is already present if c.Response().Header.PeekBytes(normalizedHeaderETag) != nil { return nil } bodyLength := len(body) if uint64(bodyLength) > math.MaxUint32 { return c.SendStatus(fiber.StatusRequestEntityTooLarge) } var etag []byte if cfg.Weak { etag = GenerateWeak(body) } else { etag = Generate(body) } // Get ETag header from request clientEtag := c.Request().Header.Peek(fiber.HeaderIfNoneMatch) // Check if client's ETag is weak if bytes.HasPrefix(clientEtag, weakPrefix) { // Check if server's ETag is weak if bytes.Equal(clientEtag[2:], etag) || bytes.Equal(clientEtag[2:], etag[2:]) { // W/1 == 1 || W/1 == W/1 c.RequestCtx().ResetBody() return c.SendStatus(fiber.StatusNotModified) } // W/1 != W/2 || W/1 != 2 c.Response().Header.SetCanonical(normalizedHeaderETag, etag) return nil } if 
bytes.Contains(clientEtag, etag) { // 1 == 1 c.RequestCtx().ResetBody() return c.SendStatus(fiber.StatusNotModified) } // 1 != 2 c.Response().Header.SetCanonical(normalizedHeaderETag, etag) return nil } } // appendUint appends n to dst and returns the extended dst. func appendUint(dst []byte, n uint32) []byte { var b [20]byte buf := b[:] i := len(buf) var q uint32 for n >= 10 { i-- q = n / 10 buf[i] = '0' + byte(n-q*10) n = q } i-- buf[i] = '0' + byte(n) dst = append(dst, buf[i:]...) return dst } ================================================ FILE: middleware/etag/etag_test.go ================================================ package etag import ( "bytes" "io" "net/http" "net/http/httptest" "testing" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) // go test -run Test_ETag_Next func Test_ETag_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Next: func(_ fiber.Ctx) bool { return true }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } // go test -run Test_ETag_SkipError func Test_ETag_SkipError(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(_ fiber.Ctx) error { return fiber.ErrForbidden }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusForbidden, resp.StatusCode) } // go test -run Test_ETag_NotStatusOK func Test_ETag_NotStatusOK(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusCreated) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusCreated, resp.StatusCode) } // go test -run Test_ETag_NoBody func Test_ETag_NoBody(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", 
func(_ fiber.Ctx) error { return nil }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } // go test -run Test_ETag_NewEtag func Test_ETag_NewEtag(t *testing.T) { t.Parallel() t.Run("without HeaderIfNoneMatch", func(t *testing.T) { t.Parallel() testETagNewEtag(t, false, false) }) t.Run("with HeaderIfNoneMatch and not matched", func(t *testing.T) { t.Parallel() testETagNewEtag(t, true, false) }) t.Run("with HeaderIfNoneMatch and matched", func(t *testing.T) { t.Parallel() testETagNewEtag(t, true, true) }) } func testETagNewEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine t.Helper() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) if headerIfNoneMatch { etag := `"non-match"` if matched { etag = `"13-1831710635"` } req.Header.Set(fiber.HeaderIfNoneMatch, etag) } resp, err := app.Test(req) require.NoError(t, err) if !headerIfNoneMatch || !matched { require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, `"13-1831710635"`, resp.Header.Get(fiber.HeaderETag)) return } if matched { require.Equal(t, fiber.StatusNotModified, resp.StatusCode) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Empty(t, b) } } // go test -run Test_ETag_WeakEtag func Test_ETag_WeakEtag(t *testing.T) { t.Parallel() t.Run("without HeaderIfNoneMatch", func(t *testing.T) { t.Parallel() testETagWeakEtag(t, false, false) }) t.Run("with HeaderIfNoneMatch and not matched", func(t *testing.T) { t.Parallel() testETagWeakEtag(t, true, false) }) t.Run("with HeaderIfNoneMatch and matched", func(t *testing.T) { t.Parallel() testETagWeakEtag(t, true, true) }) } func testETagWeakEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools 
as a flow-control is fine t.Helper() app := fiber.New() app.Use(New(Config{Weak: true})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) if headerIfNoneMatch { etag := `W/"non-match"` if matched { etag = `W/"13-1831710635"` } req.Header.Set(fiber.HeaderIfNoneMatch, etag) } resp, err := app.Test(req) require.NoError(t, err) if !headerIfNoneMatch || !matched { require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, `W/"13-1831710635"`, resp.Header.Get(fiber.HeaderETag)) return } if matched { require.Equal(t, fiber.StatusNotModified, resp.StatusCode) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Empty(t, b) } } // go test -run Test_ETag_CustomEtag func Test_ETag_CustomEtag(t *testing.T) { t.Parallel() t.Run("without HeaderIfNoneMatch", func(t *testing.T) { t.Parallel() testETagCustomEtag(t, false, false) }) t.Run("with HeaderIfNoneMatch and not matched", func(t *testing.T) { t.Parallel() testETagCustomEtag(t, true, false) }) t.Run("with HeaderIfNoneMatch and matched", func(t *testing.T) { t.Parallel() testETagCustomEtag(t, true, true) }) } func testETagCustomEtag(t *testing.T, headerIfNoneMatch, matched bool) { //nolint:revive // We're in a test, so using bools as a flow-control is fine t.Helper() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { c.Set(fiber.HeaderETag, `"custom"`) if bytes.Equal(c.Request().Header.Peek(fiber.HeaderIfNoneMatch), []byte(`"custom"`)) { return c.SendStatus(fiber.StatusNotModified) } return c.SendString("Hello, World!") }) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) if headerIfNoneMatch { etag := `"non-match"` if matched { etag = `"custom"` } req.Header.Set(fiber.HeaderIfNoneMatch, etag) } resp, err := app.Test(req) require.NoError(t, err) if !headerIfNoneMatch || !matched { require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, `"custom"`, 
resp.Header.Get(fiber.HeaderETag)) return } if matched { require.Equal(t, fiber.StatusNotModified, resp.StatusCode) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Empty(t, b) } } // go test -run Test_ETag_CustomEtagPut func Test_ETag_CustomEtagPut(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Put("/", func(c fiber.Ctx) error { c.Set(fiber.HeaderETag, `"custom"`) if !bytes.Equal(c.Request().Header.Peek(fiber.HeaderIfMatch), []byte(`"custom"`)) { return c.SendStatus(fiber.StatusPreconditionFailed) } return c.SendString("Hello, World!") }) req := httptest.NewRequest(fiber.MethodPut, "/", http.NoBody) req.Header.Set(fiber.HeaderIfMatch, `"non-match"`) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusPreconditionFailed, resp.StatusCode) } // go test -v -run=^$ -bench=Benchmark_Etag -benchmem -count=4 func Benchmark_Etag(b *testing.B) { app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.SetRequestURI("/") b.ReportAllocs() for b.Loop() { h(fctx) } require.Equal(b, 200, fctx.Response.Header.StatusCode()) require.Equal(b, `"13-1831710635"`, string(fctx.Response.Header.Peek(fiber.HeaderETag))) } ================================================ FILE: middleware/expvar/config.go ================================================ package expvar import ( "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. 
Default: nil Next func(c fiber.Ctx) bool } var ConfigDefault = Config{ Next: nil, } func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.Next == nil { cfg.Next = ConfigDefault.Next } return cfg } ================================================ FILE: middleware/expvar/expvar.go ================================================ package expvar import ( "strings" "github.com/gofiber/fiber/v3" "github.com/valyala/fasthttp/expvarhandler" ) // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } path := c.Path() // We are only interested in /debug/vars routes if len(path) < 11 || !strings.HasPrefix(path, "/debug/vars") { return c.Next() } if path == "/debug/vars" { expvarhandler.ExpvarHandler(c.RequestCtx()) return nil } return c.Redirect().To("/debug/vars") } } ================================================ FILE: middleware/expvar/expvar_test.go ================================================ package expvar import ( "bytes" "io" "net/http" "net/http/httptest" "testing" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" ) func Test_Non_Expvar_Path(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("escaped") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "escaped", string(b)) } func Test_Expvar_Index(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return 
c.SendString("escaped") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) require.Equal(t, fiber.MIMEApplicationJSONCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.True(t, bytes.Contains(b, []byte("cmdline"))) require.True(t, bytes.Contains(b, []byte("memstat"))) } func Test_Expvar_Filter(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("escaped") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars?r=cmd", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) require.Equal(t, fiber.MIMEApplicationJSONCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.True(t, bytes.Contains(b, []byte("cmdline"))) require.False(t, bytes.Contains(b, []byte("memstat"))) } func Test_Expvar_Other_Path(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("escaped") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars/303", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusSeeOther, resp.StatusCode) } // go test -run Test_Expvar_Next func Test_Expvar_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Next: func(_ fiber.Ctx) bool { return true }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/vars", http.NoBody)) require.NoError(t, err) require.Equal(t, 404, resp.StatusCode) } ================================================ FILE: middleware/favicon/config.go ================================================ package favicon import ( "io/fs" "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. 
type Config struct { // FileSystem is an optional alternate filesystem to search for the favicon in. // An example of this could be an embedded or network filesystem // // Optional. Default: nil FileSystem fs.FS `json:"-"` // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // File holds the path to an actual favicon that will be cached // // Optional. Default: "" File string `json:"file"` // URL for favicon handler // // Optional. Default: "/favicon.ico" URL string `json:"url"` // CacheControl defines how the Cache-Control header in the response should be set // // Optional. Default: "public, max-age=31536000" CacheControl string `json:"cache_control"` // Raw data of the favicon file // // Optional. Default: nil Data []byte `json:"-"` // MaxBytes limits the maximum size of the cached favicon asset. // // Optional. Default: 1048576 MaxBytes int64 `json:"max_bytes"` } // ConfigDefault is the default config var ConfigDefault = Config{ Next: nil, File: "", URL: fPath, CacheControl: "public, max-age=31536000", MaxBytes: 1024 * 1024, } func configDefault(config ...Config) Config { if len(config) == 0 { return ConfigDefault } cfg := config[0] if cfg.Next == nil { cfg.Next = ConfigDefault.Next } if cfg.URL == "" { cfg.URL = ConfigDefault.URL } if cfg.File == "" { cfg.File = ConfigDefault.File } if cfg.CacheControl == "" { cfg.CacheControl = ConfigDefault.CacheControl } if cfg.MaxBytes <= 0 { cfg.MaxBytes = ConfigDefault.MaxBytes } return cfg } ================================================ FILE: middleware/favicon/favicon.go ================================================ package favicon import ( "fmt" "io" "io/fs" "os" "strconv" "github.com/gofiber/fiber/v3" ) const ( fPath = "/favicon.ico" hType = "image/x-icon" hAllow = "GET, HEAD, OPTIONS" hZero = "0" ) // New creates a new middleware handler func New(config ...Config) fiber.Handler { cfg := configDefault(config...) 
// Load iconData if provided var ( err error iconData []byte iconLenHeader string iconLen int f fs.File ) if cfg.Data != nil { // use the provided favicon data iconData = cfg.Data iconLenHeader = strconv.Itoa(len(cfg.Data)) iconLen = len(cfg.Data) } else if cfg.File != "" { // read from configured filesystem if present if cfg.FileSystem != nil { f, err = cfg.FileSystem.Open(cfg.File) if err != nil { panic(err) } defer func() { _ = f.Close() //nolint:errcheck // not needed }() if iconData, err = readLimited(f, cfg.MaxBytes); err != nil { panic(err) } } else { f, err = os.Open(cfg.File) if err != nil { panic(err) } defer func() { _ = f.Close() //nolint:errcheck // not needed }() if iconData, err = readLimited(f, cfg.MaxBytes); err != nil { panic(err) } } iconLenHeader = strconv.Itoa(len(iconData)) iconLen = len(iconData) } // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Only respond to favicon requests if c.Path() != cfg.URL { return c.Next() } // Only allow GET, HEAD and OPTIONS requests if c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead { if c.Method() != fiber.MethodOptions { c.Status(fiber.StatusMethodNotAllowed) } else { c.Status(fiber.StatusOK) } c.Set(fiber.HeaderAllow, hAllow) c.Set(fiber.HeaderContentLength, hZero) return nil } // Serve cached favicon if iconLen > 0 { c.Set(fiber.HeaderContentLength, iconLenHeader) c.Set(fiber.HeaderContentType, hType) c.Set(fiber.HeaderCacheControl, cfg.CacheControl) return c.Status(fiber.StatusOK).Send(iconData) } return c.SendStatus(fiber.StatusNoContent) } } func readLimited(reader io.Reader, maxBytes int64) ([]byte, error) { limit := maxBytes + 1 data, err := io.ReadAll(io.LimitReader(reader, limit)) if err != nil { return nil, fmt.Errorf("favicon: read limited: %w", err) } if int64(len(data)) > maxBytes { return nil, fmt.Errorf("favicon: file size exceeds max bytes %d", maxBytes) } return 
data, nil } ================================================ FILE: middleware/favicon/favicon_test.go ================================================ package favicon import ( "bytes" "net/http" "net/http/httptest" "os" "testing" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3" ) // go test -run Test_Middleware_Favicon func Test_Middleware_Favicon(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(_ fiber.Ctx) error { return nil }) // Skip Favicon middleware resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, fiber.StatusNoContent, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodOptions, "/favicon.ico", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodPut, "/favicon.ico", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, fiber.StatusMethodNotAllowed, resp.StatusCode, "Status code") require.Equal(t, "GET, HEAD, OPTIONS", resp.Header.Get(fiber.HeaderAllow)) } // go test -run Test_Middleware_Favicon_Not_Found func Test_Middleware_Favicon_Not_Found(t *testing.T) { t.Parallel() defer func() { if err := recover(); err == nil { t.Error("should catch panic") return } }() fiber.New().Use(New(Config{ File: "non-exist.ico", })) } // go test -run Test_Middleware_Favicon_MaxBytes func Test_Middleware_Favicon_MaxBytes(t *testing.T) { t.Parallel() defer func() { if err := recover(); err == nil { t.Error("should catch panic") } }() dir := t.TempDir() path := dir + "/favicon.ico" err := os.WriteFile(path, bytes.Repeat([]byte("a"), 11), 0o600) 
require.NoError(t, err) fiber.New().Use(New(Config{ File: path, MaxBytes: 10, })) } // go test -run Test_Middleware_Favicon_MaxBytes_FileSystem func Test_Middleware_Favicon_MaxBytes_FileSystem(t *testing.T) { t.Parallel() defer func() { if err := recover(); err == nil { t.Error("should catch panic") } }() dir := t.TempDir() path := dir + "/favicon.ico" err := os.WriteFile(path, bytes.Repeat([]byte("a"), 11), 0o600) require.NoError(t, err) fiber.New().Use(New(Config{ File: "favicon.ico", FileSystem: os.DirFS(dir), MaxBytes: 10, })) } // go test -run Test_Middleware_Favicon_Found func Test_Middleware_Favicon_Found(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ File: "../../.github/testdata/favicon.ico", })) app.Get("/", func(_ fiber.Ctx) error { return nil }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code") require.Equal(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType)) require.Equal(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control") } // go test -run Test_Custom_Favicon_Url func Test_Custom_Favicon_URL(t *testing.T) { app := fiber.New() const customURL = "/favicon.svg" app.Use(New(Config{ File: "../../.github/testdata/favicon.ico", URL: customURL, })) app.Get("/", func(_ fiber.Ctx) error { return nil }) resp, err := app.Test(httptest.NewRequest(http.MethodGet, customURL, http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code") require.Equal(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType)) } // go test -run Test_Custom_Favicon_Data func Test_Custom_Favicon_Data(t *testing.T) { data, err := os.ReadFile("../../.github/testdata/favicon.ico") require.NoError(t, err) app := fiber.New() app.Use(New(Config{ Data: data, })) app.Get("/", func(_ fiber.Ctx) error { return nil }) 
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code") require.Equal(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType)) require.Equal(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control") } // go test -run Test_Middleware_Favicon_FileSystem func Test_Middleware_Favicon_FileSystem(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ File: "favicon.ico", FileSystem: os.DirFS("../../.github/testdata"), })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code") require.Equal(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType)) require.Equal(t, "public, max-age=31536000", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control") } // go test -run Test_Middleware_Favicon_CacheControl func Test_Middleware_Favicon_CacheControl(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ CacheControl: "public, max-age=100", File: "../../.github/testdata/favicon.ico", })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/favicon.ico", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, fiber.StatusOK, resp.StatusCode, "Status code") require.Equal(t, "image/x-icon", resp.Header.Get(fiber.HeaderContentType)) require.Equal(t, "public, max-age=100", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control") } // go test -v -run=^$ -bench=Benchmark_Middleware_Favicon -benchmem -count=4 func Benchmark_Middleware_Favicon(b *testing.B) { app := fiber.New() app.Use(New()) app.Get("/", func(_ fiber.Ctx) error { return nil }) handler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.SetRequestURI("/") b.ReportAllocs() for b.Loop() { handler(c) } } // go test -run 
Test_Favicon_Next func Test_Favicon_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Next: func(_ fiber.Ctx) bool { return true }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } ================================================ FILE: middleware/healthcheck/config.go ================================================ package healthcheck import ( "github.com/gofiber/fiber/v3" ) // Config defines the configuration options for the healthcheck middleware. type Config struct { // Next defines a function to skip this middleware when returned true. If this function returns true // and no other handlers are defined for the route, Fiber will return a status 404 Not Found, since // no other handlers were defined to return a different status. // // Optional. Default: nil Next func(fiber.Ctx) bool // Probe is executed to determine the current health state. It can be used for liveness, // readiness or startup checks. Returning true indicates the application is healthy. // // Optional. Default: func(c fiber.Ctx) bool { return true } Probe func(fiber.Ctx) bool } const ( // LivenessEndpoint is the conventional path for a liveness check. // Register the middleware on this path to expose it. LivenessEndpoint = "/livez" // ReadinessEndpoint is the conventional path for a readiness check. // Register the middleware on this path to expose it. ReadinessEndpoint = "/readyz" // StartupEndpoint is the conventional path for a startup check. // Register the middleware on this path to expose it. StartupEndpoint = "/startupz" ) func defaultProbe(_ fiber.Ctx) bool { return true } // ConfigDefault is the default configuration. 
var ConfigDefault = Config{ Next: nil, Probe: defaultProbe, } func configDefault(config ...Config) Config { if len(config) < 1 { return ConfigDefault } cfg := config[0] if cfg.Probe == nil { cfg.Probe = ConfigDefault.Probe } return cfg } ================================================ FILE: middleware/healthcheck/healthcheck.go ================================================ package healthcheck import ( "github.com/gofiber/fiber/v3" ) // New returns a health-check handler that responds based on the provided // configuration. func New(config ...Config) fiber.Handler { cfg := configDefault(config...) return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } if c.Method() != fiber.MethodGet { return c.Next() } if cfg.Probe(c) { return c.SendStatus(fiber.StatusOK) } return c.SendStatus(fiber.StatusServiceUnavailable) } } ================================================ FILE: middleware/healthcheck/healthcheck_test.go ================================================ package healthcheck import ( "net/http" "net/http/httptest" "strconv" "testing" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func shouldGiveStatus(t *testing.T, app *fiber.App, path string, expectedStatus int) { t.Helper() req, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody)) require.NoError(t, err) require.Equal(t, expectedStatus, req.StatusCode, "path: "+path+" should match "+strconv.Itoa(expectedStatus)) } func shouldGiveOK(t *testing.T, app *fiber.App, path string) { t.Helper() shouldGiveStatus(t, app, path, fiber.StatusOK) } func shouldGiveNotFound(t *testing.T, app *fiber.App, path string) { t.Helper() shouldGiveStatus(t, app, path, fiber.StatusNotFound) } func Test_HealthCheck_Strict_Routing_Default(t *testing.T) { t.Parallel() app := fiber.New(fiber.Config{ StrictRouting: true, }) app.Get(LivenessEndpoint, New()) app.Get(ReadinessEndpoint, 
New()) app.Get(StartupEndpoint, New()) shouldGiveOK(t, app, "/readyz") shouldGiveOK(t, app, "/livez") shouldGiveOK(t, app, "/startupz") shouldGiveNotFound(t, app, "/readyz/") shouldGiveNotFound(t, app, "/livez/") shouldGiveNotFound(t, app, "/startupz/") shouldGiveNotFound(t, app, "/notDefined/readyz") shouldGiveNotFound(t, app, "/notDefined/livez") shouldGiveNotFound(t, app, "/notDefined/startupz") } func Test_HealthCheck_Default(t *testing.T) { t.Parallel() app := fiber.New() app.Get(LivenessEndpoint, New()) app.Get(ReadinessEndpoint, New()) app.Get(StartupEndpoint, New()) shouldGiveOK(t, app, "/readyz") shouldGiveOK(t, app, "/livez") shouldGiveOK(t, app, "/startupz") shouldGiveOK(t, app, "/readyz/") shouldGiveOK(t, app, "/livez/") shouldGiveOK(t, app, "/startupz/") shouldGiveNotFound(t, app, "/notDefined/readyz") shouldGiveNotFound(t, app, "/notDefined/livez") shouldGiveNotFound(t, app, "/notDefined/startupz") } func Test_HealthCheck_Custom(t *testing.T) { t.Parallel() app := fiber.New() c1 := make(chan struct{}, 1) app.Get("/live", New(Config{ Probe: func(_ fiber.Ctx) bool { return true }, })) app.Get("/ready", New(Config{ Probe: func(_ fiber.Ctx) bool { select { case <-c1: return true default: return false } }, })) app.Get(StartupEndpoint, New(Config{ Probe: func(_ fiber.Ctx) bool { return false }, })) // Setup custom liveness and readiness probes to simulate application health status // Live should return 200 with GET request shouldGiveOK(t, app, "/live") // Live should return 404 with POST request req, err := app.Test(httptest.NewRequest(fiber.MethodPost, "/live", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusMethodNotAllowed, req.StatusCode) // Ready should return 404 with POST request req, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/ready", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusMethodNotAllowed, req.StatusCode) // Ready should return 503 with GET request before the channel is closed 
shouldGiveStatus(t, app, "/ready", fiber.StatusServiceUnavailable) shouldGiveStatus(t, app, "/startupz", fiber.StatusServiceUnavailable) // Ready should return 200 with GET request after the channel is closed c1 <- struct{}{} shouldGiveOK(t, app, "/ready") } func Test_HealthCheck_Custom_Nested(t *testing.T) { t.Parallel() app := fiber.New() c1 := make(chan struct{}, 1) app.Get("/probe/live", New(Config{ Probe: func(_ fiber.Ctx) bool { return true }, })) app.Get("/probe/ready", New(Config{ Probe: func(_ fiber.Ctx) bool { select { case <-c1: return true default: return false } }, })) // Testing custom health check endpoints with nested paths shouldGiveOK(t, app, "/probe/live") shouldGiveStatus(t, app, "/probe/ready", fiber.StatusServiceUnavailable) shouldGiveOK(t, app, "/probe/live/") shouldGiveStatus(t, app, "/probe/ready/", fiber.StatusServiceUnavailable) shouldGiveNotFound(t, app, "/probe/livez") shouldGiveNotFound(t, app, "/probe/readyz") shouldGiveNotFound(t, app, "/probe/livez/") shouldGiveNotFound(t, app, "/probe/readyz/") shouldGiveNotFound(t, app, "/livez") shouldGiveNotFound(t, app, "/readyz") shouldGiveNotFound(t, app, "/readyz/") shouldGiveNotFound(t, app, "/livez/") c1 <- struct{}{} shouldGiveOK(t, app, "/probe/ready") c1 <- struct{}{} shouldGiveOK(t, app, "/probe/ready/") } func Test_HealthCheck_Next(t *testing.T) { t.Parallel() app := fiber.New() checker := New(Config{ Next: func(_ fiber.Ctx) bool { return true }, }) app.Get(LivenessEndpoint, checker) app.Get(ReadinessEndpoint, checker) app.Get(StartupEndpoint, checker) // This should give not found since there are no other handlers to execute // so it's like the route isn't defined at all shouldGiveNotFound(t, app, "/readyz") shouldGiveNotFound(t, app, "/livez") shouldGiveNotFound(t, app, "/startupz") } func Benchmark_HealthCheck(b *testing.B) { app := fiber.New() app.Get(LivenessEndpoint, New()) app.Get(ReadinessEndpoint, New()) app.Get(StartupEndpoint, New()) h := app.Handler() fctx := 
&fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.SetRequestURI("/livez") b.ReportAllocs() for b.Loop() { h(fctx) } require.Equal(b, fiber.StatusOK, fctx.Response.Header.StatusCode()) } func Benchmark_HealthCheck_Parallel(b *testing.B) { app := fiber.New() app.Get(LivenessEndpoint, New()) app.Get(ReadinessEndpoint, New()) app.Get(StartupEndpoint, New()) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.SetRequestURI("/livez") for pb.Next() { h(fctx) } }) } ================================================ FILE: middleware/helmet/config.go ================================================ package helmet import ( "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip middleware. // Optional. Default: nil Next func(fiber.Ctx) bool // XSSProtection // Optional. Default value "0". XSSProtection string // ContentTypeNosniff // Optional. Default value "nosniff". ContentTypeNosniff string // XFrameOptions // Optional. Default value "SAMEORIGIN". // Possible values: "SAMEORIGIN", "DENY", "ALLOW-FROM uri" XFrameOptions string // ContentSecurityPolicy // Optional. Default value "". ContentSecurityPolicy string // ReferrerPolicy // Optional. Default value "no-referrer". ReferrerPolicy string // Permissions-Policy // Optional. Default value "". PermissionPolicy string // Cross-Origin-Embedder-Policy // Optional. Default value "require-corp". CrossOriginEmbedderPolicy string // Cross-Origin-Opener-Policy // Optional. Default value "same-origin". CrossOriginOpenerPolicy string // Cross-Origin-Resource-Policy // Optional. Default value "same-origin". CrossOriginResourcePolicy string // Origin-Agent-Cluster // Optional. Default value "?1". OriginAgentCluster string // X-DNS-Prefetch-Control // Optional. Default value "off". 
XDNSPrefetchControl string // X-Download-Options // Optional. Default value "noopen". XDownloadOptions string // X-Permitted-Cross-Domain-Policies // Optional. Default value "none". XPermittedCrossDomain string // HSTSMaxAge // Optional. Default value 0. HSTSMaxAge int // HSTSExcludeSubdomains // Optional. Default value false. HSTSExcludeSubdomains bool // CSPReportOnly // Optional. Default value false. CSPReportOnly bool // HSTSPreloadEnabled // Optional. Default value false. HSTSPreloadEnabled bool } // ConfigDefault is the default config var ConfigDefault = Config{ XSSProtection: "0", ContentTypeNosniff: "nosniff", XFrameOptions: "SAMEORIGIN", ReferrerPolicy: "no-referrer", CrossOriginEmbedderPolicy: "require-corp", CrossOriginOpenerPolicy: "same-origin", CrossOriginResourcePolicy: "same-origin", OriginAgentCluster: "?1", XDNSPrefetchControl: "off", XDownloadOptions: "noopen", XPermittedCrossDomain: "none", } // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.XSSProtection == "" { cfg.XSSProtection = ConfigDefault.XSSProtection } if cfg.ContentTypeNosniff == "" { cfg.ContentTypeNosniff = ConfigDefault.ContentTypeNosniff } if cfg.XFrameOptions == "" { cfg.XFrameOptions = ConfigDefault.XFrameOptions } if cfg.ReferrerPolicy == "" { cfg.ReferrerPolicy = ConfigDefault.ReferrerPolicy } if cfg.CrossOriginEmbedderPolicy == "" { cfg.CrossOriginEmbedderPolicy = ConfigDefault.CrossOriginEmbedderPolicy } if cfg.CrossOriginOpenerPolicy == "" { cfg.CrossOriginOpenerPolicy = ConfigDefault.CrossOriginOpenerPolicy } if cfg.CrossOriginResourcePolicy == "" { cfg.CrossOriginResourcePolicy = ConfigDefault.CrossOriginResourcePolicy } if cfg.OriginAgentCluster == "" { cfg.OriginAgentCluster = ConfigDefault.OriginAgentCluster } if cfg.XDNSPrefetchControl == "" { cfg.XDNSPrefetchControl = 
ConfigDefault.XDNSPrefetchControl } if cfg.XDownloadOptions == "" { cfg.XDownloadOptions = ConfigDefault.XDownloadOptions } if cfg.XPermittedCrossDomain == "" { cfg.XPermittedCrossDomain = ConfigDefault.XPermittedCrossDomain } return cfg } ================================================ FILE: middleware/helmet/helmet.go ================================================ package helmet import ( "fmt" "github.com/gofiber/fiber/v3" ) // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Init config cfg := configDefault(config...) // Return middleware handler return func(c fiber.Ctx) error { // Next request to skip middleware if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Set headers if cfg.XSSProtection != "" { c.Set(fiber.HeaderXXSSProtection, cfg.XSSProtection) } if cfg.ContentTypeNosniff != "" { c.Set(fiber.HeaderXContentTypeOptions, cfg.ContentTypeNosniff) } if cfg.XFrameOptions != "" { c.Set(fiber.HeaderXFrameOptions, cfg.XFrameOptions) } if cfg.CrossOriginEmbedderPolicy != "" { c.Set("Cross-Origin-Embedder-Policy", cfg.CrossOriginEmbedderPolicy) } if cfg.CrossOriginOpenerPolicy != "" { c.Set("Cross-Origin-Opener-Policy", cfg.CrossOriginOpenerPolicy) } if cfg.CrossOriginResourcePolicy != "" { c.Set("Cross-Origin-Resource-Policy", cfg.CrossOriginResourcePolicy) } if cfg.OriginAgentCluster != "" { c.Set("Origin-Agent-Cluster", cfg.OriginAgentCluster) } if cfg.ReferrerPolicy != "" { c.Set("Referrer-Policy", cfg.ReferrerPolicy) } if cfg.XDNSPrefetchControl != "" { c.Set("X-DNS-Prefetch-Control", cfg.XDNSPrefetchControl) } if cfg.XDownloadOptions != "" { c.Set("X-Download-Options", cfg.XDownloadOptions) } if cfg.XPermittedCrossDomain != "" { c.Set("X-Permitted-Cross-Domain-Policies", cfg.XPermittedCrossDomain) } // Handle HSTS headers if c.Protocol() == "https" && cfg.HSTSMaxAge != 0 { subdomains := "" if !cfg.HSTSExcludeSubdomains { subdomains = "; includeSubDomains" } if cfg.HSTSPreloadEnabled { subdomains += "; preload" } 
c.Set(fiber.HeaderStrictTransportSecurity, fmt.Sprintf("max-age=%d%s", cfg.HSTSMaxAge, subdomains)) } // Handle Content-Security-Policy headers if cfg.ContentSecurityPolicy != "" { if cfg.CSPReportOnly { c.Set(fiber.HeaderContentSecurityPolicyReportOnly, cfg.ContentSecurityPolicy) } else { c.Set(fiber.HeaderContentSecurityPolicy, cfg.ContentSecurityPolicy) } } // Handle Permissions-Policy headers if cfg.PermissionPolicy != "" { c.Set(fiber.HeaderPermissionsPolicy, cfg.PermissionPolicy) } return c.Next() } } ================================================ FILE: middleware/helmet/helmet_test.go ================================================ package helmet import ( "net/http" "net/http/httptest" "testing" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func Test_Default(t *testing.T) { app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection)) require.Equal(t, "nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions)) require.Equal(t, "SAMEORIGIN", resp.Header.Get(fiber.HeaderXFrameOptions)) require.Empty(t, resp.Header.Get(fiber.HeaderContentSecurityPolicy)) require.Equal(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy)) require.Empty(t, resp.Header.Get(fiber.HeaderPermissionsPolicy)) require.Equal(t, "require-corp", resp.Header.Get("Cross-Origin-Embedder-Policy")) require.Equal(t, "same-origin", resp.Header.Get("Cross-Origin-Opener-Policy")) require.Equal(t, "same-origin", resp.Header.Get("Cross-Origin-Resource-Policy")) require.Equal(t, "?1", resp.Header.Get("Origin-Agent-Cluster")) require.Equal(t, "off", resp.Header.Get("X-DNS-Prefetch-Control")) require.Equal(t, "noopen", resp.Header.Get("X-Download-Options")) require.Equal(t, "none", 
resp.Header.Get("X-Permitted-Cross-Domain-Policies")) } func Test_CustomValues_AllHeaders(t *testing.T) { app := fiber.New() app.Use(New(Config{ // Custom values for all headers XSSProtection: "0", ContentTypeNosniff: "custom-nosniff", XFrameOptions: "DENY", HSTSExcludeSubdomains: true, ContentSecurityPolicy: "default-src 'none'", CSPReportOnly: true, HSTSPreloadEnabled: true, ReferrerPolicy: "origin", PermissionPolicy: "geolocation=(self)", CrossOriginEmbedderPolicy: "custom-value", CrossOriginOpenerPolicy: "custom-value", CrossOriginResourcePolicy: "custom-value", OriginAgentCluster: "custom-value", XDNSPrefetchControl: "custom-control", XDownloadOptions: "custom-options", XPermittedCrossDomain: "custom-policies", })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) // Assertions for custom header values require.Equal(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection)) require.Equal(t, "custom-nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions)) require.Equal(t, "DENY", resp.Header.Get(fiber.HeaderXFrameOptions)) require.Equal(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicyReportOnly)) require.Equal(t, "origin", resp.Header.Get(fiber.HeaderReferrerPolicy)) require.Equal(t, "geolocation=(self)", resp.Header.Get(fiber.HeaderPermissionsPolicy)) require.Equal(t, "custom-value", resp.Header.Get("Cross-Origin-Embedder-Policy")) require.Equal(t, "custom-value", resp.Header.Get("Cross-Origin-Opener-Policy")) require.Equal(t, "custom-value", resp.Header.Get("Cross-Origin-Resource-Policy")) require.Equal(t, "custom-value", resp.Header.Get("Origin-Agent-Cluster")) require.Equal(t, "custom-control", resp.Header.Get("X-DNS-Prefetch-Control")) require.Equal(t, "custom-options", resp.Header.Get("X-Download-Options")) require.Equal(t, "custom-policies", resp.Header.Get("X-Permitted-Cross-Domain-Policies")) } func 
Test_RealWorldValues_AllHeaders(t *testing.T) { app := fiber.New() app.Use(New(Config{ // Real-world values for all headers XSSProtection: "0", ContentTypeNosniff: "nosniff", XFrameOptions: "SAMEORIGIN", HSTSExcludeSubdomains: false, ContentSecurityPolicy: "default-src 'self';base-uri 'self';font-src 'self' https: data:;form-action 'self';frame-ancestors 'self';img-src 'self' data:;object-src 'none';script-src 'self';script-src-attr 'none';style-src 'self' https: 'unsafe-inline';upgrade-insecure-requests", CSPReportOnly: false, HSTSPreloadEnabled: true, ReferrerPolicy: "no-referrer", PermissionPolicy: "geolocation=(self)", CrossOriginEmbedderPolicy: "require-corp", CrossOriginOpenerPolicy: "same-origin", CrossOriginResourcePolicy: "same-origin", OriginAgentCluster: "?1", XDNSPrefetchControl: "off", XDownloadOptions: "noopen", XPermittedCrossDomain: "none", })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) // Assertions for real-world header values require.Equal(t, "0", resp.Header.Get(fiber.HeaderXXSSProtection)) require.Equal(t, "nosniff", resp.Header.Get(fiber.HeaderXContentTypeOptions)) require.Equal(t, "SAMEORIGIN", resp.Header.Get(fiber.HeaderXFrameOptions)) require.Equal(t, "default-src 'self';base-uri 'self';font-src 'self' https: data:;form-action 'self';frame-ancestors 'self';img-src 'self' data:;object-src 'none';script-src 'self';script-src-attr 'none';style-src 'self' https: 'unsafe-inline';upgrade-insecure-requests", resp.Header.Get(fiber.HeaderContentSecurityPolicy)) require.Equal(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy)) require.Equal(t, "geolocation=(self)", resp.Header.Get(fiber.HeaderPermissionsPolicy)) require.Equal(t, "require-corp", resp.Header.Get("Cross-Origin-Embedder-Policy")) require.Equal(t, "same-origin", resp.Header.Get("Cross-Origin-Opener-Policy")) require.Equal(t, "same-origin", 
resp.Header.Get("Cross-Origin-Resource-Policy")) require.Equal(t, "?1", resp.Header.Get("Origin-Agent-Cluster")) require.Equal(t, "off", resp.Header.Get("X-DNS-Prefetch-Control")) require.Equal(t, "noopen", resp.Header.Get("X-Download-Options")) require.Equal(t, "none", resp.Header.Get("X-Permitted-Cross-Domain-Policies")) } func Test_Next(t *testing.T) { app := fiber.New() app.Use(New(Config{ Next: func(ctx fiber.Ctx) bool { return ctx.Path() == "/next" }, ReferrerPolicy: "no-referrer", })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) app.Get("/next", func(c fiber.Ctx) error { return c.SendString("Skipped!") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, "no-referrer", resp.Header.Get(fiber.HeaderReferrerPolicy)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/next", http.NoBody)) require.NoError(t, err) require.Empty(t, resp.Header.Get(fiber.HeaderReferrerPolicy)) } func Test_ContentSecurityPolicy(t *testing.T) { app := fiber.New() app.Use(New(Config{ ContentSecurityPolicy: "default-src 'none'", })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicy)) } func Test_ContentSecurityPolicyReportOnly(t *testing.T) { app := fiber.New() app.Use(New(Config{ ContentSecurityPolicy: "default-src 'none'", CSPReportOnly: true, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, "default-src 'none'", resp.Header.Get(fiber.HeaderContentSecurityPolicyReportOnly)) require.Empty(t, resp.Header.Get(fiber.HeaderContentSecurityPolicy)) } func Test_PermissionsPolicy(t *testing.T) { app := fiber.New() 
app.Use(New(Config{ PermissionPolicy: "microphone=()", })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, "microphone=()", resp.Header.Get(fiber.HeaderPermissionsPolicy)) } func Test_HSTSHeaders(t *testing.T) { hstsAge := 60 app := fiber.New() app.Use(New(Config{HSTSMaxAge: hstsAge})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) handler := app.Handler() ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetProtocol("https") handler(ctx) require.Equal(t, "max-age=60; includeSubDomains", string(ctx.Response.Header.Peek(fiber.HeaderStrictTransportSecurity))) ctx.Request.Reset() ctx.Response.Reset() ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetProtocol("http") handler(ctx) require.Empty(t, string(ctx.Response.Header.Peek(fiber.HeaderStrictTransportSecurity))) } func Test_HSTSExcludeSubdomainsAndPreload(t *testing.T) { hstsAge := 31536000 app := fiber.New() app.Use(New(Config{ HSTSMaxAge: hstsAge, HSTSExcludeSubdomains: true, HSTSPreloadEnabled: true, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) handler := app.Handler() ctx := &fasthttp.RequestCtx{} ctx.Request.SetRequestURI("/") ctx.Request.Header.SetMethod(fiber.MethodGet) ctx.Request.Header.SetProtocol("https") handler(ctx) require.Equal(t, "max-age=31536000; preload", string(ctx.Response.Header.Peek(fiber.HeaderStrictTransportSecurity))) } ================================================ FILE: middleware/idempotency/config.go ================================================ package idempotency import ( "errors" "fmt" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/storage/memory" ) var ErrInvalidIdempotencyKey = errors.New("invalid idempotency 
key") // Config defines the config for middleware. type Config struct { // Lock locks an idempotency key. // // Optional. Default: an in-memory locker for this process only. Lock Locker // Storage stores response data by idempotency key. // // Optional. Default: an in-memory storage for this process only. Storage fiber.Storage // Next defines a function to skip this middleware when returned true. // // Optional. Default: a function which skips the middleware on safe HTTP request method. Next func(c fiber.Ctx) bool // KeyHeaderValidate defines a function to validate the syntax of the idempotency header. // // Optional. Default: a function which ensures the header is 36 characters long (the size of an UUID). KeyHeaderValidate func(string) error // KeyHeader is the name of the header that contains the idempotency key. // // Optional. Default: X-Idempotency-Key KeyHeader string // KeepResponseHeaders is a list of headers that should be kept from the original response. // // Optional. Default: nil (to keep all headers) KeepResponseHeaders []string // Lifetime is the maximum lifetime of an idempotency key. // // Optional. Default: 30 * time.Minute Lifetime time.Duration // DisableValueRedaction turns off masking idempotency keys in logs and errors when set to true. // // Optional. Default: false DisableValueRedaction bool } // ConfigDefault is the default config var ConfigDefault = Config{ Next: func(c fiber.Ctx) bool { // Skip middleware if the request was done using a safe HTTP method return fiber.IsMethodSafe(c.Method()) }, Lifetime: 30 * time.Minute, KeyHeader: "X-Idempotency-Key", KeyHeaderValidate: func(k string) error { if l, wl := len(k), 36; l != wl { // UUID length is 36 chars return fmt.Errorf("%w: invalid length: %d != %d", ErrInvalidIdempotencyKey, l, wl) } return nil }, KeepResponseHeaders: nil, Lock: nil, // Set in configDefault so we don't allocate data here. Storage: nil, // Set in configDefault so we don't allocate data here. 
DisableValueRedaction: false, } // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { cfg := ConfigDefault cfg.Lock = NewMemoryLock() cfg.Storage = memory.New(memory.Config{ GCInterval: cfg.Lifetime / 2, // Half the lifetime interval }) return cfg } // Override default config cfg := config[0] // Set default values if cfg.Next == nil { cfg.Next = ConfigDefault.Next } if cfg.Lifetime.Nanoseconds() == 0 { cfg.Lifetime = ConfigDefault.Lifetime } if cfg.KeyHeader == "" { cfg.KeyHeader = ConfigDefault.KeyHeader } if cfg.KeyHeaderValidate == nil { cfg.KeyHeaderValidate = ConfigDefault.KeyHeaderValidate } if cfg.KeepResponseHeaders != nil && len(cfg.KeepResponseHeaders) == 0 { cfg.KeepResponseHeaders = ConfigDefault.KeepResponseHeaders } if cfg.Lock == nil { cfg.Lock = NewMemoryLock() } if cfg.Storage == nil { cfg.Storage = memory.New(memory.Config{ GCInterval: cfg.Lifetime / 2, }) } return cfg } ================================================ FILE: middleware/idempotency/idempotency.go ================================================ package idempotency import ( "fmt" "github.com/gofiber/utils/v2" utilsstrings "github.com/gofiber/utils/v2/strings" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/log" ) // Inspired by https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-key-header-02 // and https://github.com/penguin-statistics/backend-next/blob/f2f7d5ba54fc8a58f168d153baa17b2ad4a14e45/internal/pkg/middlewares/idempotency.go // The contextKey type is unexported to prevent collisions with context keys defined in // other packages. type contextKey int const ( localsKeyIsFromCache contextKey = iota // localsKeyWasPutToCache ) const redactedKey = "[redacted]" // IsFromCache reports whether the middleware served the response from the // cache for the current request. 
func IsFromCache(c fiber.Ctx) bool { return c.Locals(localsKeyIsFromCache) != nil } // WasPutToCache reports whether the middleware stored the response produced by // the current request in the cache. func WasPutToCache(c fiber.Ctx) bool { val := c.Locals(localsKeyWasPutToCache) if wasPut, ok := val.(bool); ok { return wasPut } return val != nil } // New creates idempotency middleware that caches responses keyed by the // configured idempotency header. func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) redactKeys := !cfg.DisableValueRedaction maskKey := func(key string) string { if redactKeys { return redactedKey } return key } keepResponseHeadersMap := make(map[string]struct{}, len(cfg.KeepResponseHeaders)) for _, h := range cfg.KeepResponseHeaders { // CopyString is needed because utils.ToLower uses UnsafeString // and map keys must be immutable keepResponseHeadersMap[utilsstrings.ToLower(h)] = struct{}{} } maybeWriteCachedResponse := func(c fiber.Ctx, key string) (bool, error) { if val, err := cfg.Storage.GetWithContext(c, key); err != nil { return false, fmt.Errorf("failed to read response: %w", err) } else if val != nil { var res response if _, err := res.UnmarshalMsg(val); err != nil { return false, fmt.Errorf("failed to unmarshal response: %w", err) } _ = c.Status(res.StatusCode) for header, vals := range res.Headers { for _, val := range vals { c.RequestCtx().Response.Header.Add(header, val) } } if len(res.Body) != 0 { if err := c.Send(res.Body); err != nil { return true, err } } _ = c.Locals(localsKeyIsFromCache, true) return true, nil } return false, nil } return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Don't execute middleware if the idempotency key is empty if c.Get(cfg.KeyHeader) == "" { return c.Next() } // Validate key key := utils.CopyString(c.Get(cfg.KeyHeader)) if err := cfg.KeyHeaderValidate(key); err != nil { 
return err } // First-pass: if the idempotency key is in the storage, get and return the response if ok, err := maybeWriteCachedResponse(c, key); err != nil { return fmt.Errorf("failed to write cached response at fastpath: %w", err) } else if ok { return nil } if err := cfg.Lock.Lock(key); err != nil { return fmt.Errorf("failed to lock: %w", err) } defer func() { if err := cfg.Lock.Unlock(key); err != nil { log.Errorf("[IDEMPOTENCY] failed to unlock key %q: %v", maskKey(key), err) } }() // Lock acquired. If the idempotency key now is in the storage, get and return the response if ok, err := maybeWriteCachedResponse(c, key); err != nil { return fmt.Errorf("failed to write cached response while locked: %w", err) } else if ok { return nil } // Execute the request handler if err := c.Next(); err != nil { // If the request handler returned an error, return it and skip idempotency return err } // Construct response res := &response{ StatusCode: c.Response().StatusCode(), Body: c.Response().Body(), } { headers := make(map[string][]string) if err := c.Bind().RespHeader(headers); err != nil { return fmt.Errorf("failed to bind to response headers: %w", err) } if cfg.KeepResponseHeaders == nil { // Keep all res.Headers = headers } else { // Filter res.Headers = make(map[string][]string) for h := range headers { if _, ok := keepResponseHeadersMap[utilsstrings.ToLower(h)]; ok { res.Headers[h] = headers[h] } } } } bodyLimit := c.App().Config().BodyLimit if bodyLimit > 0 && len(res.Body) > bodyLimit { _ = c.Locals(localsKeyWasPutToCache, false) return nil } // Marshal response bs, err := res.MarshalMsg(nil) if err != nil { return fmt.Errorf("failed to marshal response: %w", err) } // Store response if err := cfg.Storage.SetWithContext(c, key, bs, cfg.Lifetime); err != nil { return fmt.Errorf("failed to save response: %w", err) } _ = c.Locals(localsKeyWasPutToCache, true) return nil } } ================================================ FILE: 
middleware/idempotency/idempotency_test.go
================================================
package idempotency

import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/valyala/fasthttp"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// validKey is a well-formed idempotency key that the default KeyHeaderValidate
// accepts (asserted in Test_configDefault_defaults).
const validKey = "00000000-0000-0000-0000-000000000000"

// go test -run Test_Idempotency
//
// Test_Idempotency exercises the middleware end to end: requests with a safe
// method or without a key always reach the handler, repeated unsafe requests
// with the same key replay the stored response until it expires, and
// concurrent requests sharing a key run the handler only once.
func Test_Idempotency(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	// After every request, check that the IsFromCache/WasPutToCache flags agree
	// with the method's safety and with the presence of the key header.
	app.Use(func(c fiber.Ctx) error {
		if err := c.Next(); err != nil {
			return err
		}

		isMethodSafe := fiber.IsMethodSafe(c.Method())
		isIdempotent := IsFromCache(c) || WasPutToCache(c)
		hasReqHeader := c.Get("X-Idempotency-Key") != ""

		if isMethodSafe {
			if isIdempotent {
				return errors.New("request with safe HTTP method should not be idempotent")
			}
		} else {
			// Unsafe
			if hasReqHeader {
				if !isIdempotent {
					return errors.New("request with unsafe HTTP method should be idempotent if X-Idempotency-Key request header is set")
				}
			} else if isIdempotent {
				return errors.New("request with unsafe HTTP method should not be idempotent if X-Idempotency-Key request header is not set")
			}
		}

		return nil
	})

	// Needs to be at least a second as the memory storage doesn't support shorter durations.
	const lifetime = 2 * time.Second

	app.Use(New(Config{
		Lifetime: lifetime,
	}))

	// nextCount yields 1, 2, 3, ... so every real handler execution produces a
	// new body, while a replayed response repeats an earlier value.
	nextCount := func() func() int {
		var count int32
		return func() int {
			return int(atomic.AddInt32(&count, 1))
		}
	}()

	app.Add([]string{
		fiber.MethodGet,
		fiber.MethodPost,
	}, "/", func(c fiber.Ctx) error {
		return c.SendString(strconv.Itoa(nextCount()))
	})

	// /slow keeps the handler busy for longer than the lifetime so that the
	// concurrent requests in the raciness section overlap with the first one.
	app.Post("/slow", func(c fiber.Ctx) error {
		time.Sleep(3 * lifetime)

		return c.SendString(strconv.Itoa(nextCount()))
	})

	// doReq sends a request, attaching the key header when idempotencyKey is
	// non-empty, requires a 200 status and returns the response body.
	doReq := func(method, route, idempotencyKey string) string {
		req := httptest.NewRequest(method, route, http.NoBody)
		if idempotencyKey != "" {
			req.Header.Set("X-Idempotency-Key", idempotencyKey)
		}
		resp, err := app.Test(req, fiber.TestConfig{
			Timeout:       15 * time.Second,
			FailOnTimeout: true,
		})
		require.NoError(t, err)
		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, resp.StatusCode, string(body))
		return string(body)
	}

	require.Equal(t, "1", doReq(fiber.MethodGet, "/", ""))
	require.Equal(t, "2", doReq(fiber.MethodGet, "/", ""))

	require.Equal(t, "3", doReq(fiber.MethodPost, "/", ""))
	require.Equal(t, "4", doReq(fiber.MethodPost, "/", ""))

	require.Equal(t, "5", doReq(fiber.MethodGet, "/", "00000000-0000-0000-0000-000000000000"))
	require.Equal(t, "6", doReq(fiber.MethodGet, "/", "00000000-0000-0000-0000-000000000000"))

	require.Equal(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
	require.Equal(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
	require.Equal(t, "8", doReq(fiber.MethodPost, "/", ""))
	require.Equal(t, "9", doReq(fiber.MethodPost, "/", "11111111-1111-1111-1111-111111111111"))

	require.Equal(t, "7", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
	// Wait until the stored response for the all-zero key has expired.
	time.Sleep(4 * lifetime)
	require.Equal(t, "10", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))
	require.Equal(t, "10", doReq(fiber.MethodPost, "/", "00000000-0000-0000-0000-000000000000"))

	// Test raciness
	{
		var wg sync.WaitGroup
		for range 100 {
			wg.Go(func() {
				assert.Equal(t, "11", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
			})
		}
		wg.Wait()
		require.Equal(t, "11", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
	}
	time.Sleep(3 * lifetime)
	require.Equal(t, "12", doReq(fiber.MethodPost, "/slow", "22222222-2222-2222-2222-222222222222"))
}

// go test -v -run=^$ -bench=Benchmark_Idempotency -benchmem -count=4
//
// Benchmark_Idempotency measures a POST that carries a key ("hit") against a
// POST without a key ("skip").
func Benchmark_Idempotency(b *testing.B) {
	app := fiber.New()

	// Needs to be at least a second as the memory storage doesn't support shorter durations.
	const lifetime = 1 * time.Second

	app.Use(New(Config{
		Lifetime: lifetime,
	}))

	app.Post("/", func(_ fiber.Ctx) error {
		return nil
	})

	h := app.Handler()

	b.Run("hit", func(b *testing.B) {
		c := &fasthttp.RequestCtx{}
		c.Request.Header.SetMethod(fiber.MethodPost)
		c.Request.SetRequestURI("/")
		c.Request.Header.Set("X-Idempotency-Key", "00000000-0000-0000-0000-000000000000")
		b.ReportAllocs()
		for b.Loop() {
			h(c)
		}
	})

	b.Run("skip", func(b *testing.B) {
		c := &fasthttp.RequestCtx{}
		c.Request.Header.SetMethod(fiber.MethodPost)
		c.Request.SetRequestURI("/")
		b.ReportAllocs()
		for b.Loop() {
			h(c)
		}
	})
}

// Test_configDefault_defaults checks the values filled in when no Config is
// passed, including that the default Next skips GET but not POST requests.
func Test_configDefault_defaults(t *testing.T) {
	t.Parallel()

	cfg := configDefault()
	require.NotNil(t, cfg.Lock)
	require.NotNil(t, cfg.Storage)
	require.Equal(t, ConfigDefault.Lifetime, cfg.Lifetime)
	require.Equal(t, ConfigDefault.KeyHeader, cfg.KeyHeader)
	require.Nil(t, cfg.KeepResponseHeaders)

	app := fiber.New()
	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx := app.AcquireCtx(fctx)
	require.True(t, cfg.Next(ctx))
	app.ReleaseCtx(ctx)

	fctx = &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx = app.AcquireCtx(fctx)
	require.False(t, cfg.Next(ctx))
	app.ReleaseCtx(ctx)

	require.NoError(t, cfg.KeyHeaderValidate(validKey))
	require.Error(t, cfg.KeyHeaderValidate("short"))
}

// Test_configDefault_override checks that user-supplied fields are preserved,
// that an empty KeepResponseHeaders slice ends up nil, and that Next and
// KeyHeaderValidate still receive defaults.
func Test_configDefault_override(t *testing.T) {
	t.Parallel()

	l := &stubLock{}
	s := &stubStorage{}
	cfg := configDefault(Config{
		Lifetime:            42 * time.Second,
		KeyHeader:           "Foo",
		KeepResponseHeaders: []string{},
		Lock:                l,
		Storage:             s,
	})

	require.Equal(t, 42*time.Second, cfg.Lifetime)
	require.Equal(t, "Foo", cfg.KeyHeader)
	require.Nil(t, cfg.KeepResponseHeaders)
	require.Equal(t, l, cfg.Lock)
	require.Equal(t, s, cfg.Storage)
	require.NotNil(t, cfg.Next)
	require.NotNil(t, cfg.KeyHeaderValidate)
}

// do performs req against app with a 5s timeout and returns the response
// together with its fully read body. It panics on a request or read error
// instead of taking a *testing.T.
func do(app *fiber.App, req *http.Request) (resp *http.Response, body string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned response and body payload for clarity
	resp, err := app.Test(req, fiber.TestConfig{Timeout: 5 * time.Second})
	if err != nil {
		panic(err)
	}
	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	return resp, string(payload)
}

// Test_New_NextSkip checks that a Next func returning true bypasses the
// middleware, so the handler runs for both requests with the same key.
func Test_New_NextSkip(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	var count int

	app.Use(New(Config{Next: func(_ fiber.Ctx) bool { return true }}))
	app.Post("/", func(c fiber.Ctx) error {
		count++
		return c.SendString(strconv.Itoa(count))
	})

	req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req.Header.Set(ConfigDefault.KeyHeader, validKey)
	_, body1 := do(app, req)

	req2 := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req2.Header.Set(ConfigDefault.KeyHeader, validKey)
	_, body2 := do(app, req2)

	require.Equal(t, "1", body1)
	require.Equal(t, "2", body2)
}

// Test_New_InvalidKey checks that a key rejected by the default validator
// yields a 500 whose body mentions the length problem.
func Test_New_InvalidKey(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Post("/", func(_ fiber.Ctx) error { return nil })

	req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req.Header.Set(ConfigDefault.KeyHeader, "bad")
	resp, body := do(app, req)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Contains(t, body, "invalid length")
}

// Test_New_StorageGetError checks that a storage read error on the first
// (unlocked) lookup is surfaced as a 500.
func Test_New_StorageGetError(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	s := &stubStorage{getErr: errors.New("boom")}
	app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})

	req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req.Header.Set(ConfigDefault.KeyHeader, validKey)
	resp, body := do(app, req)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Contains(t, body, "failed to write cached response at fastpath")
}

// Test_New_UnmarshalError checks that a stored value which cannot be decoded
// is reported as a fast-path error.
func Test_New_UnmarshalError(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	s := &stubStorage{data: map[string][]byte{validKey: []byte("bad")}}
	app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
	app.Post("/", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})

	req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req.Header.Set(ConfigDefault.KeyHeader, validKey)
	resp, body := do(app, req)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Contains(t, body, "failed to write cached response at fastpath")
}

// Test_New_StoreRetrieve_FilterHeaders checks that the first response is
// stored once and replayed, and that only headers listed in
// KeepResponseHeaders are replayed.
func Test_New_StoreRetrieve_FilterHeaders(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	s := &stubStorage{}
	app.Use(New(Config{
		Storage:             s,
		Lock:                &stubLock{},
		KeepResponseHeaders: []string{"Foo"},
	}))
	var count int
	app.Post("/", func(c fiber.Ctx) error {
		count++
		c.Set("Foo", "foo")
		c.Set("Bar", "bar")
		return c.SendString(fmt.Sprintf("resp%d", count))
	})

	req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req.Header.Set(ConfigDefault.KeyHeader, validKey)
	resp, body := do(app, req)
	require.Equal(t, "resp1", body)
	require.Equal(t, "foo", resp.Header.Get("Foo"))
	require.Equal(t, "bar", resp.Header.Get("Bar"))

	req2 := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req2.Header.Set(ConfigDefault.KeyHeader, validKey)
	resp2, body2 := do(app, req2)
	require.Equal(t, "resp1", body2)
	require.Equal(t, "foo", resp2.Header.Get("Foo"))
	require.Empty(t, resp2.Header.Get("Bar"))
	require.Equal(t, 1, count)
	require.Equal(t, 1, s.setCount)
}

// Test_New_SkipCache_WhenBodyTooLarge checks that responses larger than the
// app's BodyLimit are not stored: the handler runs again and WasPutToCache
// reports false for both requests.
func Test_New_SkipCache_WhenBodyTooLarge(t *testing.T) {
	t.Parallel()

	bodyLimit := 8
	app := fiber.New(fiber.Config{BodyLimit: bodyLimit})
	s := &stubStorage{}

	var wasPut []bool
	app.Use(func(c fiber.Ctx) error {
		if err := c.Next(); err != nil {
			return err
		}
		wasPut = append(wasPut, WasPutToCache(c))
		return nil
	})
	app.Use(New(Config{Storage: s, Lock: &stubLock{}}))

	var count int
	oversized := strings.Repeat("a", bodyLimit+1)
	app.Post("/", func(c fiber.Ctx) error {
		count++
		return c.SendString(oversized)
	})

	req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req.Header.Set(ConfigDefault.KeyHeader, validKey)
	resp1, body1 := do(app, req)
	require.Equal(t, fiber.StatusOK, resp1.StatusCode)
	require.Equal(t, oversized, body1)

	req2 := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req2.Header.Set(ConfigDefault.KeyHeader, validKey)
	resp2, body2 := do(app, req2)
	require.Equal(t, fiber.StatusOK, resp2.StatusCode)
	require.Equal(t, oversized, body2)

	require.Equal(t, 2, count)
	require.Equal(t, 0, s.setCount)
	require.Len(t, wasPut, 2)
	require.False(t, wasPut[0])
	require.False(t, wasPut[1])
}

// Test_New_HandlerError checks that handler errors are returned to the client
// and never stored, so a repeated request hits the handler again.
func Test_New_HandlerError(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	s := &stubStorage{}
	app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
	app.Post("/", func(_ fiber.Ctx) error { return errors.New("boom") })

	req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req.Header.Set(ConfigDefault.KeyHeader, validKey)
	resp, body := do(app, req)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Equal(t, "boom", body)
	require.Equal(t, 0, s.setCount)

	resp2, body2 := do(app, req)
	require.Equal(t, fiber.StatusInternalServerError, resp2.StatusCode)
	require.Equal(t, "boom", body2)
	require.Equal(t, 0, s.setCount)
}

// Test_New_LockError checks that an error from Locker.Lock is surfaced as a 500.
func Test_New_LockError(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	l := &stubLock{lockErr: errors.New("fail")}
	app.Use(New(Config{Lock: l, Storage: &stubStorage{}}))
	app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })

	req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req.Header.Set(ConfigDefault.KeyHeader, validKey)
	resp, body := do(app, req)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Contains(t, body, "failed to lock")
}

// Test_New_StorageSetError checks that a failure to save the response is
// surfaced as a 500.
func Test_New_StorageSetError(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	s := &stubStorage{setErr: errors.New("nope")}
	app.Use(New(Config{Storage: s, Lock: &stubLock{}}))
	app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })

	req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req.Header.Set(ConfigDefault.KeyHeader, validKey)
	resp, body := do(app, req)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Contains(t, body, "failed to save response")
}

// Test_New_UnlockError checks that an Unlock error does not change the
// response; the middleware only logs it.
func Test_New_UnlockError(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	l := &stubLock{unlockErr: errors.New("u")}
	app.Use(New(Config{Lock: l, Storage: &stubStorage{}}))
	app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })

	req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req.Header.Set(ConfigDefault.KeyHeader, validKey)
	resp, body := do(app, req)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, "ok", body)
}

// Test_New_SecondPassReadError checks that a storage error on the second
// lookup, performed after the lock is acquired, is reported with its own
// message rather than as a fast-path error.
func Test_New_SecondPassReadError(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	s := &stubStorage{}
	l := &stubLock{afterLock: func() { s.getErr = errors.New("g") }}
	app.Use(New(Config{Lock: l, Storage: s}))
	app.Post("/", func(c fiber.Ctx) error { return c.SendString("ok") })

	req := httptest.NewRequest(http.MethodPost, "/", http.NoBody)
	req.Header.Set(ConfigDefault.KeyHeader, validKey)
	resp, body := do(app, req)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Contains(t, body, "failed to write cached response while locked")
}

================================================
FILE: middleware/idempotency/locker.go
================================================
package idempotency

import (
	"sync"
)

// Locker provides per-key mutual exclusion for idempotency keys: Lock blocks
// other callers of the same key until Unlock is called.
type Locker interface {
	// Lock acquires the lock for key, waiting while another caller holds it.
	Lock(key string) error
	// Unlock releases the lock for key that was acquired with Lock.
	Unlock(key string) error
}

// countedLock is a per-key mutex plus the number of callers that currently
// hold it or are waiting for it. When the count drops to zero, MemoryLock
// removes the entry from its map.
type countedLock struct {
	mu     sync.Mutex
	locked int
}

// MemoryLock coordinates access to idempotency keys using in-memory locks.
//
// The zero value is ready for use. Map entries are removed as soon as no
// caller holds or waits on a key, so the map does not grow with the number of
// distinct keys ever seen.
type MemoryLock struct {
	keys map[string]*countedLock
	mu   sync.Mutex // guards keys and the locked counter of every entry
}

// Lock acquires the lock for the provided key, creating it when necessary.
// It blocks while another caller holds the same key and always returns nil.
func (l *MemoryLock) Lock(key string) error {
	l.mu.Lock()
	if l.keys == nil {
		// Lazily initialize so that a zero-value MemoryLock does not panic on
		// the map write below.
		l.keys = make(map[string]*countedLock)
	}
	lock, ok := l.keys[key]
	if !ok {
		lock = new(countedLock)
		l.keys[key] = lock
	}
	// Register interest while still holding l.mu, so a concurrent Unlock
	// cannot delete this entry while we wait on lock.mu.
	lock.locked++
	l.mu.Unlock()

	lock.mu.Lock()
	return nil
}

// Unlock releases the lock associated with the provided key.
// Unlocking a key that is not currently tracked is a no-op and returns nil.
func (l *MemoryLock) Unlock(key string) error {
	l.mu.Lock()
	lock, ok := l.keys[key]
	if !ok {
		// This happens if we try to unlock an unknown key
		l.mu.Unlock()
		return nil
	}
	// Update the counter and prune the entry in the same critical section that
	// looked it up. A count of zero means nobody is waiting on lock.mu, so
	// dropping the entry before releasing the mutex is safe. A later Lock for
	// the same key simply creates a fresh entry.
	lock.locked--
	if lock.locked <= 0 {
		delete(l.keys, key)
	}
	l.mu.Unlock()

	lock.mu.Unlock()
	return nil
}

// NewMemoryLock creates a MemoryLock ready for use.
func NewMemoryLock() *MemoryLock { return &MemoryLock{ keys: make(map[string]*countedLock), } } var _ Locker = (*MemoryLock)(nil) ================================================ FILE: middleware/idempotency/locker_test.go ================================================ package idempotency_test import ( "strconv" "sync/atomic" "testing" "time" "github.com/gofiber/fiber/v3/middleware/idempotency" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // go test -run Test_MemoryLock func Test_MemoryLock(t *testing.T) { t.Parallel() l := idempotency.NewMemoryLock() // Test that a lock can be acquired { err := l.Lock("a") require.NoError(t, err) } // Test that the same lock cannot be acquired again while held { done := make(chan struct{}) go func() { defer close(done) err := l.Lock("a") assert.NoError(t, err) }() select { case <-done: t.Fatal("lock acquired again") case <-time.After(time.Second): // Expected: goroutine should still be blocked } } // Release lock "a" to prevent goroutine leak { err := l.Unlock("a") require.NoError(t, err) } // Test lock and unlock sequence { err := l.Lock("b") require.NoError(t, err) } { err := l.Unlock("b") require.NoError(t, err) } { err := l.Lock("b") require.NoError(t, err) } { err := l.Unlock("b") require.NoError(t, err) } // Test unlocking non-existent lock (should succeed) { err := l.Unlock("c") require.NoError(t, err) } // Test another lock { err := l.Lock("d") require.NoError(t, err) } { err := l.Unlock("d") require.NoError(t, err) } } func Benchmark_MemoryLock(b *testing.B) { keys := make([]string, 50_000_000) for i := range keys { keys[i] = strconv.Itoa(i) } lock := idempotency.NewMemoryLock() for i := 0; b.Loop(); i++ { key := keys[i] if err := lock.Lock(key); err != nil { b.Fatal(err) } if err := lock.Unlock(key); err != nil { b.Fatal(err) } } } func Benchmark_MemoryLock_Parallel(b *testing.B) { // In order to prevent using repeated keys I pre-allocate keys keys := make([]string, 1_000_000) for i := 
range keys { keys[i] = strconv.Itoa(i) } b.Run("UniqueKeys", func(b *testing.B) { lock := idempotency.NewMemoryLock() var keyI atomic.Int32 b.ReportAllocs() b.ResetTimer() b.RunParallel(func(p *testing.PB) { for p.Next() { i := int(keyI.Add(1)) % len(keys) key := keys[i] if err := lock.Lock(key); err != nil { b.Fatal(err) } if err := lock.Unlock(key); err != nil { b.Fatal(err) } } }) }) b.Run("RepeatedKeys", func(b *testing.B) { lock := idempotency.NewMemoryLock() var keyI atomic.Int32 b.ReportAllocs() b.ResetTimer() b.RunParallel(func(p *testing.PB) { for p.Next() { // Division by 3 ensures that index will be repeated exactly 3 times i := int(keyI.Add(1)) / 3 % len(keys) key := keys[i] if err := lock.Lock(key); err != nil { b.Fatal(err) } if err := lock.Unlock(key); err != nil { b.Fatal(err) } } }) }) } ================================================ FILE: middleware/idempotency/response.go ================================================ package idempotency // response is a struct that represents the response of a request. // generation tool `go install github.com/tinylib/msgp@latest` // // Idempotency payloads are stored in backing storage, so keep headers/bodies bounded. // //go:generate msgp -o=response_msgp.go -tests=true -unexported type response struct { Headers map[string][]string `msg:"hs,limit=1024"` // HTTP header count norms are well below this. Body []byte `msg:"b"` // Idempotency bodies are bounded by storage policy, not msgp limits. StatusCode int `msg:"sc"` } ================================================ FILE: middleware/idempotency/response_msgp.go ================================================ // Code generated by github.com/tinylib/msgp DO NOT EDIT. 
package idempotency

import (
	"github.com/tinylib/msgp/msgp"
)

// DecodeMsg implements msgp.Decodable
func (z *response) DecodeMsg(dc *msgp.Reader) (err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, err = dc.ReadMapHeader()
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, err = dc.ReadMapKeyPtr()
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		case "hs":
			var zb0002 uint32
			zb0002, err = dc.ReadMapHeader()
			if err != nil {
				err = msgp.WrapError(err, "Headers")
				return
			}
			if zb0002 > 1024 {
				err = msgp.ErrLimitExceeded
				return
			}
			if z.Headers == nil {
				z.Headers = make(map[string][]string, zb0002)
			} else if len(z.Headers) > 0 {
				clear(z.Headers)
			}
			for zb0002 > 0 {
				zb0002--
				var za0001 string
				za0001, err = dc.ReadString()
				if err != nil {
					err = msgp.WrapError(err, "Headers")
					return
				}
				var za0002 []string
				var zb0003 uint32
				zb0003, err = dc.ReadArrayHeader()
				if err != nil {
					err = msgp.WrapError(err, "Headers", za0001)
					return
				}
				if zb0003 > 1024 {
					err = msgp.ErrLimitExceeded
					return
				}
				if cap(za0002) >= int(zb0003) {
					za0002 = (za0002)[:zb0003]
				} else {
					za0002 = make([]string, zb0003)
				}
				for za0003 := range za0002 {
					za0002[za0003], err = dc.ReadString()
					if err != nil {
						err = msgp.WrapError(err, "Headers", za0001, za0003)
						return
					}
				}
				z.Headers[za0001] = za0002
			}
		case "b":
			z.Body, err = dc.ReadBytes(z.Body)
			if err != nil {
				err = msgp.WrapError(err, "Body")
				return
			}
		case "sc":
			z.StatusCode, err = dc.ReadInt()
			if err != nil {
				err = msgp.WrapError(err, "StatusCode")
				return
			}
		default:
			err = dc.Skip()
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	return
}

// EncodeMsg implements msgp.Encodable
func (z *response) EncodeMsg(en *msgp.Writer) (err error) {
	// map header, size 3
	// write "hs"
	err = en.Append(0x83, 0xa2, 0x68, 0x73)
	if err != nil {
		return
	}
	err = en.WriteMapHeader(uint32(len(z.Headers)))
	if err != nil {
		err = msgp.WrapError(err, "Headers")
		return
	}
	for za0001, za0002 := range z.Headers {
		err = en.WriteString(za0001)
		if err != nil {
			err = msgp.WrapError(err, "Headers")
			return
		}
		err = en.WriteArrayHeader(uint32(len(za0002)))
		if err != nil {
			err = msgp.WrapError(err, "Headers", za0001)
			return
		}
		for za0003 := range za0002 {
			err = en.WriteString(za0002[za0003])
			if err != nil {
				err = msgp.WrapError(err, "Headers", za0001, za0003)
				return
			}
		}
	}
	// write "b"
	err = en.Append(0xa1, 0x62)
	if err != nil {
		return
	}
	err = en.WriteBytes(z.Body)
	if err != nil {
		err = msgp.WrapError(err, "Body")
		return
	}
	// write "sc"
	err = en.Append(0xa2, 0x73, 0x63)
	if err != nil {
		return
	}
	err = en.WriteInt(z.StatusCode)
	if err != nil {
		err = msgp.WrapError(err, "StatusCode")
		return
	}
	return
}

// MarshalMsg implements msgp.Marshaler
func (z *response) MarshalMsg(b []byte) (o []byte, err error) {
	o = msgp.Require(b, z.Msgsize())
	// map header, size 3
	// string "hs"
	o = append(o, 0x83, 0xa2, 0x68, 0x73)
	o = msgp.AppendMapHeader(o, uint32(len(z.Headers)))
	for za0001, za0002 := range z.Headers {
		o = msgp.AppendString(o, za0001)
		o = msgp.AppendArrayHeader(o, uint32(len(za0002)))
		for za0003 := range za0002 {
			o = msgp.AppendString(o, za0002[za0003])
		}
	}
	// string "b"
	o = append(o, 0xa1, 0x62)
	o = msgp.AppendBytes(o, z.Body)
	// string "sc"
	o = append(o, 0xa2, 0x73, 0x63)
	o = msgp.AppendInt(o, z.StatusCode)
	return
}

// UnmarshalMsg implements msgp.Unmarshaler
func (z *response) UnmarshalMsg(bts []byte) (o []byte, err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, bts, err = msgp.ReadMapKeyZC(bts)
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		case "hs":
			var zb0002 uint32
			zb0002, bts, err = msgp.ReadMapHeaderBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "Headers")
				return
			}
			if zb0002 > 1024 {
				err = msgp.ErrLimitExceeded
				return
			}
			if z.Headers == nil {
				z.Headers = make(map[string][]string, zb0002)
			} else if len(z.Headers) > 0 {
				clear(z.Headers)
			}
			for zb0002 > 0 {
				var za0002 []string
				zb0002--
				var za0001 string
				za0001, bts, err = msgp.ReadStringBytes(bts)
				if err != nil {
					err = msgp.WrapError(err, "Headers")
					return
				}
				var zb0003 uint32
				zb0003, bts, err = msgp.ReadArrayHeaderBytes(bts)
				if err != nil {
					err = msgp.WrapError(err, "Headers", za0001)
					return
				}
				if zb0003 > 1024 {
					err = msgp.ErrLimitExceeded
					return
				}
				if cap(za0002) >= int(zb0003) {
					za0002 = (za0002)[:zb0003]
				} else {
					za0002 = make([]string, zb0003)
				}
				for za0003 := range za0002 {
					za0002[za0003], bts, err = msgp.ReadStringBytes(bts)
					if err != nil {
						err = msgp.WrapError(err, "Headers", za0001, za0003)
						return
					}
				}
				z.Headers[za0001] = za0002
			}
		case "b":
			z.Body, bts, err = msgp.ReadBytesBytes(bts, z.Body)
			if err != nil {
				err = msgp.WrapError(err, "Body")
				return
			}
		case "sc":
			z.StatusCode, bts, err = msgp.ReadIntBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "StatusCode")
				return
			}
		default:
			bts, err = msgp.Skip(bts)
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	o = bts
	return
}

// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *response) Msgsize() (s int) {
	s = 1 + 3 + msgp.MapHeaderSize
	if z.Headers != nil {
		for za0001, za0002 := range z.Headers {
			_ = za0002
			s += msgp.StringPrefixSize + len(za0001) + msgp.ArrayHeaderSize
			for za0003 := range za0002 {
				s += msgp.StringPrefixSize + len(za0002[za0003])
			}
		}
	}
	s += 2 + msgp.BytesPrefixSize + len(z.Body) + 3 + msgp.IntSize
	return
}

================================================
FILE: middleware/idempotency/response_msgp_test.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.
package idempotency

import (
	"bytes"
	"testing"

	"github.com/tinylib/msgp/msgp"
)

func TestMarshalUnmarshalresponse(t *testing.T) {
	v := response{}
	bts, err := v.MarshalMsg(nil)
	if err != nil {
		t.Fatal(err)
	}
	left, err := v.UnmarshalMsg(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
	}

	left, err = msgp.Skip(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
	}
}

func BenchmarkMarshalMsgresponse(b *testing.B) {
	v := response{}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.MarshalMsg(nil)
	}
}

func BenchmarkAppendMsgresponse(b *testing.B) {
	v := response{}
	bts := make([]byte, 0, v.Msgsize())
	bts, _ = v.MarshalMsg(bts[0:0])
	b.SetBytes(int64(len(bts)))
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		bts, _ = v.MarshalMsg(bts[0:0])
	}
}

func BenchmarkUnmarshalresponse(b *testing.B) {
	v := response{}
	bts, _ := v.MarshalMsg(nil)
	b.ReportAllocs()
	b.SetBytes(int64(len(bts)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := v.UnmarshalMsg(bts)
		if err != nil {
			b.Fatal(err)
		}
	}
}

func TestEncodeDecoderesponse(t *testing.T) {
	v := response{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)

	m := v.Msgsize()
	if buf.Len() > m {
		t.Log("WARNING: TestEncodeDecoderesponse Msgsize() is inaccurate")
	}

	vn := response{}
	err := msgp.Decode(&buf, &vn)
	if err != nil {
		t.Error(err)
	}

	buf.Reset()
	msgp.Encode(&buf, &v)
	err = msgp.NewReader(&buf).Skip()
	if err != nil {
		t.Error(err)
	}
}

func BenchmarkEncoderesponse(b *testing.B) {
	v := response{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	en := msgp.NewWriter(msgp.Nowhere)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.EncodeMsg(en)
	}
	en.Flush()
}

func BenchmarkDecoderesponse(b *testing.B) {
	v := response{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	rd := msgp.NewEndlessReader(buf.Bytes(), b)
	dc := msgp.NewReader(rd)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err := v.DecodeMsg(dc)
		if err != nil {
			b.Fatal(err)
		}
	}
}

================================================
FILE: middleware/idempotency/stub_test.go
================================================
package idempotency

import (
	"context"
	"time"
)

// stubLock implements Locker for testing purposes.
// It never blocks; Lock and Unlock only return the configured errors.
type stubLock struct {
	lockErr   error  // returned by every Lock call
	unlockErr error  // returned by every Unlock call
	afterLock func() // optional hook run inside Lock before lockErr is returned
}

// Lock runs the afterLock hook, if any, and returns lockErr.
func (s *stubLock) Lock(string) error {
	if s.afterLock != nil {
		s.afterLock()
	}
	return s.lockErr
}

// Unlock returns unlockErr.
func (s *stubLock) Unlock(string) error {
	return s.unlockErr
}

// stubStorage implements fiber.Storage for testing.
// Values live in an unsynchronized map and expirations are ignored.
type stubStorage struct {
	data     map[string][]byte // created lazily by Set, or replaced by Reset
	getErr   error             // when non-nil, Get fails with it
	setErr   error             // when non-nil, Set fails with it and stores nothing
	setCount int               // number of successful Set calls
}

// Get returns getErr if set; otherwise the stored value, or nil for an unknown key.
func (s *stubStorage) Get(key string) ([]byte, error) {
	if s.getErr != nil {
		return nil, s.getErr
	}
	if s.data == nil {
		return nil, nil
	}
	return s.data[key], nil
}

// GetWithContext ignores the context and delegates to Get.
func (s *stubStorage) GetWithContext(_ context.Context, key string) ([]byte, error) {
	// Call Get method to avoid code duplication
	return s.Get(key)
}

// Set returns setErr if set; otherwise it stores val and increments setCount.
// The expiration is ignored.
func (s *stubStorage) Set(key string, val []byte, _ time.Duration) error {
	if s.setErr != nil {
		return s.setErr
	}
	if s.data == nil {
		s.data = make(map[string][]byte)
	}
	s.data[key] = val
	s.setCount++
	return nil
}

// SetWithContext ignores the context and expiration and delegates to Set.
func (s *stubStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error {
	// Call Set method to avoid code duplication
	return s.Set(key, val, 0)
}

// Delete removes key if present; it never fails.
func (s *stubStorage) Delete(key string) error {
	if s.data != nil {
		delete(s.data, key)
	}
	return nil
}

// DeleteWithContext ignores the context and delegates to Delete.
func (s *stubStorage) DeleteWithContext(_ context.Context, key string) error {
	// Call Delete method to avoid code duplication
	return s.Delete(key)
}

// Reset discards all stored values.
func (s *stubStorage) Reset() error {
	s.data = make(map[string][]byte)
	return nil
}

// ResetWithContext ignores the context and delegates to Reset.
func (s *stubStorage) ResetWithContext(_ context.Context) error {
	// Call Reset method to avoid code duplication
	return s.Reset()
}

// Close is a no-op.
func (*stubStorage) Close() error {
	return nil
}
================================================ FILE: middleware/keyauth/config.go ================================================ package keyauth import ( "fmt" "net/url" "strings" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" ) const ( ErrorInvalidRequest = "invalid_request" ErrorInvalidToken = "invalid_token" ErrorInsufficientScope = "insufficient_scope" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // SuccessHandler defines a function which is executed for a valid key. // // Optional. Default: c.Next() SuccessHandler fiber.Handler // ErrorHandler defines a function which is executed for an invalid key. // It may be used to define a custom error. // // Optional. Default: 401 Missing or invalid API Key ErrorHandler fiber.ErrorHandler // Validator is a function to validate the key. // // Required. Validator func(c fiber.Ctx, key string) (bool, error) // Realm defines the protected area for WWW-Authenticate responses. // This is used to set the `WWW-Authenticate` header when authentication fails. // // Optional. Default value "Restricted". Realm string // Challenge defines the full `WWW-Authenticate` header value used when // the middleware responds with 401 and no Authorization scheme is // present. // // Optional. Default: `ApiKey realm=""` when no Authorization scheme // is configured. Challenge string // Error is the RFC 6750 `error` parameter appended to Bearer // `WWW-Authenticate` challenges when validation fails. Allowed values // are `invalid_request`, `invalid_token`, or `insufficient_scope`. // // Optional. Default: "". Error string // ErrorDescription is the RFC 6750 `error_description` parameter // appended to Bearer `WWW-Authenticate` challenges when validation // fails. This field requires that `Error` is also set. // // Optional. Default: "". 
ErrorDescription string // ErrorURI is the RFC 6750 `error_uri` parameter appended to Bearer // `WWW-Authenticate` challenges when validation fails. This field // requires that `Error` is also set. // // Optional. Default: "". ErrorURI string // Scope is the RFC 6750 `scope` parameter appended to Bearer // challenges when the `error` is `insufficient_scope`. This field // requires that `Error` is set to `insufficient_scope`. // // Optional. Default: "". Scope string // Extractor is a function to extract the key from the request. // // Optional. Default: extractors.FromAuthHeader("Bearer") Extractor extractors.Extractor } // ConfigDefault is the default config var ConfigDefault = Config{ SuccessHandler: func(c fiber.Ctx) error { return c.Next() }, ErrorHandler: func(c fiber.Ctx, _ error) error { return c.Status(fiber.StatusUnauthorized).SendString(ErrMissingOrMalformedAPIKey.Error()) }, Realm: "Restricted", Extractor: extractors.FromAuthHeader("Bearer"), } // configDefault is a helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { panic("fiber: keyauth middleware requires a validator function") } cfg := config[0] // Require a validator function if cfg.Validator == nil { panic("fiber: keyauth middleware requires a validator function") } // Set default values if cfg.Extractor.Extract == nil { cfg.Extractor = ConfigDefault.Extractor } if cfg.Realm == "" { cfg.Realm = ConfigDefault.Realm } if cfg.SuccessHandler == nil { cfg.SuccessHandler = ConfigDefault.SuccessHandler } if cfg.ErrorHandler == nil { cfg.ErrorHandler = ConfigDefault.ErrorHandler } if len(getAuthSchemes(cfg.Extractor)) == 0 && cfg.Challenge == "" { cfg.Challenge = fmt.Sprintf("ApiKey realm=%q", cfg.Realm) } if cfg.Error != "" { switch cfg.Error { case ErrorInvalidRequest, ErrorInvalidToken, ErrorInsufficientScope: default: panic("fiber: keyauth unsupported error token") } } if cfg.ErrorDescription != "" && 
cfg.Error == "" { panic("fiber: keyauth error_description requires error") } if cfg.ErrorURI != "" { if cfg.Error == "" { panic("fiber: keyauth error_uri requires error") } if u, err := url.Parse(cfg.ErrorURI); err != nil || !u.IsAbs() { panic("fiber: keyauth error_uri must be absolute") } } if cfg.Error == ErrorInsufficientScope { if cfg.Scope == "" { panic("fiber: keyauth insufficient_scope requires scope") } for scope := range strings.SplitSeq(cfg.Scope, " ") { if scope == "" || !isScopeToken(scope) { panic("fiber: keyauth scope contains invalid token") } } } else if cfg.Scope != "" { panic("fiber: keyauth scope requires insufficient_scope error") } return cfg } func isScopeToken(s string) bool { for i := 0; i < len(s); i++ { c := s[i] if c < 0x21 || c > 0x7e || c == '"' || c == '\\' { return false } } return s != "" } ================================================ FILE: middleware/keyauth/config_test.go ================================================ package keyauth import ( "reflect" "testing" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // Test_KeyAuth_ConfigDefault_NoConfig tests the case where no config is provided. func Test_KeyAuth_ConfigDefault_NoConfig(t *testing.T) { t.Parallel() // The New function will call configDefault with no arguments // which will panic because ConfigDefault.Validator is nil. assert.PanicsWithValue(t, "fiber: keyauth middleware requires a validator function", func() { New() }, "Calling New() without a validator should panic") } // Test_KeyAuth_ConfigDefault_PanicWithoutValidator tests that configDefault panics when Validator is nil. 
func Test_KeyAuth_ConfigDefault_PanicWithoutValidator(t *testing.T) { t.Parallel() assert.PanicsWithValue(t, "fiber: keyauth middleware requires a validator function", func() { configDefault(Config{}) }, "configDefault should panic if validator is not provided") } // Test_KeyAuth_ConfigDefault_WithValidator tests that default values are set when only a validator is provided. func Test_KeyAuth_ConfigDefault_WithValidator(t *testing.T) { t.Parallel() validator := func(fiber.Ctx, string) (bool, error) { return true, nil } cfg := configDefault(Config{ Validator: validator, }) require.NotNil(t, cfg.Validator) assert.Equal(t, ConfigDefault.Realm, cfg.Realm) require.NotNil(t, cfg.SuccessHandler) require.NotNil(t, cfg.ErrorHandler) require.NotNil(t, cfg.Extractor.Extract) } // Test_KeyAuth_ConfigDefault_CustomConfig tests that custom values are preserved. func Test_KeyAuth_ConfigDefault_CustomConfig(t *testing.T) { t.Parallel() nextFunc := func(_ fiber.Ctx) bool { return true } successHandler := func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) } errorHandler := func(c fiber.Ctx, _ error) error { return c.SendStatus(fiber.StatusForbidden) } validator := func(_ fiber.Ctx, _ string) (bool, error) { return true, nil } extractor := extractors.FromHeader("X-API-Key") cfg := configDefault(Config{ Next: nextFunc, SuccessHandler: successHandler, ErrorHandler: errorHandler, Validator: validator, Realm: "API", Extractor: extractor, }) // Using reflect.ValueOf to compare function pointers assert.Equal(t, reflect.ValueOf(nextFunc).Pointer(), reflect.ValueOf(cfg.Next).Pointer()) assert.Equal(t, reflect.ValueOf(successHandler).Pointer(), reflect.ValueOf(cfg.SuccessHandler).Pointer()) assert.Equal(t, reflect.ValueOf(errorHandler).Pointer(), reflect.ValueOf(cfg.ErrorHandler).Pointer()) assert.Equal(t, reflect.ValueOf(validator).Pointer(), reflect.ValueOf(cfg.Validator).Pointer()) assert.Equal(t, reflect.ValueOf(extractor.Extract).Pointer(), 
reflect.ValueOf(cfg.Extractor.Extract).Pointer()) assert.Equal(t, "API", cfg.Realm) } ================================================ FILE: middleware/keyauth/keyauth.go ================================================ package keyauth import ( "errors" "fmt" "strings" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" "github.com/gofiber/utils/v2" ) // The contextKey type is unexported to prevent collisions with context keys defined in // other packages. type contextKey int // The keys for the values in context const ( tokenKey contextKey = iota ) // ErrMissingOrMalformedAPIKey is returned when the API key is missing or invalid. var ErrMissingOrMalformedAPIKey = errors.New("missing or invalid API Key") // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Init config cfg := configDefault(config...) // Determine the auth schemes from the extractor chain. authSchemes := getAuthSchemes(cfg.Extractor) // Return middleware handler return func(c fiber.Ctx) error { // Filter request to skip middleware if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Extract and verify key key, err := cfg.Extractor.Extract(c) if errors.Is(err, extractors.ErrNotFound) { // Replace shared extractor not found error with a keyauth specific error err = ErrMissingOrMalformedAPIKey } // If there was no error extracting the key, validate it if err == nil { var valid bool valid, err = cfg.Validator(c, key) if err == nil && valid { fiber.StoreInContext(c, tokenKey, key) return cfg.SuccessHandler(c) } } // Execute the error handler first handlerErr := cfg.ErrorHandler(c, err) status := c.Response().StatusCode() if status == fiber.StatusUnauthorized || status == fiber.StatusProxyAuthRequired { header := fiber.HeaderWWWAuthenticate if status == fiber.StatusProxyAuthRequired { header = fiber.HeaderProxyAuthenticate } if len(authSchemes) > 0 { challenges := make([]string, 0, len(authSchemes)) for _, scheme := range authSchemes { var b 
strings.Builder fmt.Fprintf(&b, "%s realm=%q", scheme, cfg.Realm) if utils.EqualFold(scheme, "Bearer") { if cfg.Error != "" { fmt.Fprintf(&b, ", error=%q", cfg.Error) if cfg.ErrorDescription != "" { fmt.Fprintf(&b, ", error_description=%q", cfg.ErrorDescription) } if cfg.ErrorURI != "" { fmt.Fprintf(&b, ", error_uri=%q", cfg.ErrorURI) } if cfg.Error == ErrorInsufficientScope { fmt.Fprintf(&b, ", scope=%q", cfg.Scope) } } } challenges = append(challenges, b.String()) } c.Set(header, strings.Join(challenges, ", ")) } else if cfg.Challenge != "" { c.Set(header, cfg.Challenge) } } return handlerErr } } // TokenFromContext returns the bearer token from the request context. // It accepts fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx, and context.Context. // It returns an empty string if the token does not exist. func TokenFromContext(ctx any) string { if token, ok := fiber.ValueFromContext[string](ctx, tokenKey); ok { return token } return "" } // getAuthSchemes inspects an extractor and its chain to find all auth schemes // used by FromAuthHeader. It returns a slice of schemes, or an empty slice if // none are found. func getAuthSchemes(e extractors.Extractor) []string { var schemes []string if e.Source == extractors.SourceAuthHeader && e.AuthScheme != "" { schemes = append(schemes, e.AuthScheme) } for _, ex := range e.Chain { schemes = append(schemes, getAuthSchemes(ex)...) 
}
	return schemes
}

================================================
FILE: middleware/keyauth/keyauth_test.go
================================================
package keyauth

import (
	"context"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/fiber/v3/extractors"
)

// CorrectKey is the API key accepted by most validators in this file.
const CorrectKey = "correct-token_123./~+"

// testConfig is passed to app.Test. NOTE(review): Timeout 0 presumably
// disables the app.Test timeout — confirm against fiber.TestConfig docs.
var testConfig = fiber.TestConfig{
	Timeout: 0,
}

// Names of the extractor sources exercised by Test_AuthSources.
const (
	paramExtractorName      = "param"
	formExtractorName       = "form"
	queryExtractorName      = "query"
	headerExtractorName     = "header"
	authHeaderExtractorName = "authHeader"
	cookieExtractorName     = "cookie"
)

// Test_AuthSources runs the same correct/missing/wrong key cases against each
// extractor source (header, auth header, cookie, query, param, form).
func Test_AuthSources(t *testing.T) {
	// define test cases
	testSources := []string{headerExtractorName, authHeaderExtractorName, cookieExtractorName, queryExtractorName, paramExtractorName, formExtractorName}

	tests := []struct {
		route         string
		authTokenName string
		description   string
		APIKey        string
		expectedBody  string
		expectedCode  int
	}{
		{
			route:         "/",
			authTokenName: "access_token",
			description:   "auth with correct key",
			APIKey:        CorrectKey,
			expectedCode:  200,
			expectedBody:  "Success!",
		},
		{
			route:         "/",
			authTokenName: "access_token",
			description:   "auth with no key",
			APIKey:        "",
			expectedCode:  401, // 404 in case of param authentication
			expectedBody:  ErrMissingOrMalformedAPIKey.Error(),
		},
		{
			route:         "/",
			authTokenName: "access_token",
			description:   "auth with wrong key",
			APIKey:        "WRONGKEY",
			expectedCode:  401,
			expectedBody:  ErrMissingOrMalformedAPIKey.Error(),
		},
	}

	for _, authSource := range testSources {
		t.Run(authSource, func(t *testing.T) {
			for _, test := range tests {
				app := fiber.New(fiber.Config{UnescapePath: true})

				testKey := test.APIKey
				correctKey := CorrectKey

				// Use a simple key for param and cookie to avoid encoding issues in the test setup
				if authSource == paramExtractorName || authSource == cookieExtractorName {
					if test.APIKey != "" && test.APIKey != "WRONGKEY" {
						testKey = "simple-key"
						correctKey = "simple-key"
					}
				}

				authMiddleware := New(Config{
					Extractor: func() extractors.Extractor {
						switch authSource {
						case headerExtractorName:
							return extractors.FromHeader(test.authTokenName)
						case authHeaderExtractorName:
							return extractors.FromAuthHeader("Bearer")
						case cookieExtractorName:
							return extractors.FromCookie(test.authTokenName)
						case queryExtractorName:
							return extractors.FromQuery(test.authTokenName)
						case paramExtractorName:
							return extractors.FromParam(test.authTokenName)
						case formExtractorName:
							return extractors.FromForm(test.authTokenName)
						default:
							panic("unknown source")
						}
					}(),
					Validator: func(_ fiber.Ctx, key string) (bool, error) {
						if key == correctKey {
							return true, nil
						}
						return false, errors.New("invalid key")
					},
				})

				handler := func(c fiber.Ctx) error {
					return c.SendString("Success!")
				}

				method := fiber.MethodGet

				switch authSource {
				case paramExtractorName:
					app.Get("/:"+test.authTokenName, authMiddleware, handler)
				case formExtractorName:
					method = fiber.MethodPost
					app.Post("/", authMiddleware, handler)
				default:
					app.Get("/", authMiddleware, handler)
				}

				targetURL := "/"
				if authSource == paramExtractorName {
					targetURL = "/" + url.PathEscape(testKey)
				}

				var reqBody io.Reader
				if authSource == formExtractorName {
					form := url.Values{}
					form.Add(test.authTokenName, testKey)
					bodyStr := form.Encode()
					reqBody = strings.NewReader(bodyStr)
				}

				req, err := http.NewRequestWithContext(context.Background(), method, targetURL, reqBody)
				require.NoError(t, err)

				switch authSource {
				case headerExtractorName:
					req.Header.Set(test.authTokenName, testKey)
				case authHeaderExtractorName:
					if testKey != "" {
						req.Header.Set("Authorization", "Bearer "+testKey)
					}
				case cookieExtractorName:
					req.Header.Set("Cookie", test.authTokenName+"="+testKey)
				case queryExtractorName:
					q := req.URL.Query()
					q.Add(test.authTokenName, testKey)
					req.URL.RawQuery = q.Encode()
				case formExtractorName:
					req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
				default:
					// nothing to do for paramExtractorName
				}

				res, err := app.Test(req, testConfig)
				require.NoError(t, err, test.description)

				body, err := io.ReadAll(res.Body)
				require.NoError(t, err)
				errClose := res.Body.Close()
				require.NoError(t, errClose)

				expectedCode := test.expectedCode
				expectedBody := test.expectedBody

				if test.APIKey == "" || test.APIKey == "WRONGKEY" {
					expectedBody = ErrMissingOrMalformedAPIKey.Error()
				}

				// With the param source and no key, the request path is "/",
				// which does not match the "/:access_token" route registered above.
				if authSource == paramExtractorName && testKey == "" {
					expectedCode = 404
					expectedBody = "Not Found"
				}
				require.Equal(t, expectedCode, res.StatusCode, test.description)
				require.Equal(t, expectedBody, string(body), test.description)
			}
		})
	}
}

// TestMultipleKeyLookup checks that a chained extractor finds a key supplied
// only in the query string, and that a request with no key is rejected with
// ErrMissingOrMalformedAPIKey.
func TestMultipleKeyLookup(t *testing.T) {
	const (
		desc    = "auth with correct key"
		success = "Success!"
		scheme  = "Bearer"
	)

	// set up the fiber endpoint
	app := fiber.New()

	customExtractor := extractors.Chain(
		extractors.FromAuthHeader("Bearer"),
		extractors.FromHeader("key"),
		extractors.FromCookie("key"),
		extractors.FromQuery("key"),
	)

	authMiddleware := New(Config{
		Extractor: customExtractor,
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, errors.New("invalid key")
		},
	})
	app.Use(authMiddleware)
	app.Get("/foo", func(c fiber.Ctx) error {
		return c.SendString(success)
	})

	// construct the test HTTP request
	var (
		req *http.Request
		err error
	)
	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", http.NoBody)
	require.NoError(t, err)
	q := req.URL.Query()
	q.Add("key", CorrectKey)
	req.URL.RawQuery = q.Encode()

	res, err := app.Test(req, testConfig)

	require.NoError(t, err)

	// test the body of the request
	body, err := io.ReadAll(res.Body)
	require.Equal(t, 200, res.StatusCode, desc)
	// body
	require.NoError(t, err)
	require.Equal(t, success, string(body), desc)

	err = res.Body.Close()
	require.NoError(t, err)

	// construct a second request without proper key
	req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/foo", http.NoBody)
	require.NoError(t, err)
	res, err = app.Test(req, testConfig)
	require.NoError(t, err)
	errBody, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(errBody))
}

// Test_MultipleKeyAuth mounts two keyauth middlewares, each limited by Next to
// a single path and accepting its own key, and checks that "/" stays open.
func Test_MultipleKeyAuth(t *testing.T) {
	// set up the fiber endpoint
	app := fiber.New()

	// set up keyauth for /auth1
	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.Path() != "/auth1"
		},
		Extractor: extractors.FromAuthHeader("Bearer"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == "password1" {
				return true, nil
			}
			return false, errors.New("invalid key")
		},
	}))

	// setup keyauth for /auth2
	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.Path() != "/auth2"
		},
		Extractor: extractors.FromAuthHeader("Bearer"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == "password2" {
				return true, nil
			}
			return false, errors.New("invalid key")
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("No auth needed!")
	})

	app.Get("/auth1", func(c fiber.Ctx) error {
		return c.SendString("Successfully authenticated for auth1!")
	})

	app.Get("/auth2", func(c fiber.Ctx) error {
		return c.SendString("Successfully authenticated for auth2!")
	})

	// define test cases
	tests := []struct {
		route        string
		description  string
		APIKey       string
		expectedBody string
		expectedCode int
	}{
		// No auth needed for /
		{
			route:        "/",
			description:  "No password needed",
			APIKey:       "",
			expectedCode: 200,
			expectedBody: "No auth needed!",
		},

		// auth needed for auth1
		{
			route:        "/auth1",
			description:  "Normal Authentication Case",
			APIKey:       "password1",
			expectedCode: 200,
			expectedBody: "Successfully authenticated for auth1!",
		},
		{
			route:        "/auth1",
			description:  "Wrong API Key",
			APIKey:       "WRONG KEY",
			expectedCode: 401,
			expectedBody: ErrMissingOrMalformedAPIKey.Error(),
		},
		{
			route:        "/auth1",
			description:  "Wrong API Key",
			APIKey:       "", // NO KEY
			expectedCode: 401,
			expectedBody: ErrMissingOrMalformedAPIKey.Error(),
		},

		// Auth 2 has a different password
		{
			route:        "/auth2",
			description:  "Normal Authentication Case for auth2",
			APIKey:       "password2",
			expectedCode: 200,
			expectedBody: "Successfully authenticated for auth2!",
		},
		{
			route:        "/auth2",
			description:  "Wrong API Key",
			APIKey:       "WRONG KEY",
			expectedCode: 401,
			expectedBody: ErrMissingOrMalformedAPIKey.Error(),
		},
		{
			route:        "/auth2",
			description:  "Wrong API Key",
			APIKey:       "", // NO KEY
			expectedCode: 401,
			expectedBody: ErrMissingOrMalformedAPIKey.Error(),
		},
	}

	// run the tests
	for _, test := range tests {
		var req *http.Request
		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, test.route, http.NoBody)
		require.NoError(t, err)
		if test.APIKey != "" {
			req.Header.Set("Authorization", "Bearer "+test.APIKey)
		}

		res, err := app.Test(req, testConfig)

		require.NoError(t, err, test.description)

		// test the body of the request
		body, err := io.ReadAll(res.Body)
		require.Equal(t, test.expectedCode, res.StatusCode, test.description)

		// body
		require.NoError(t, err, test.description)
		require.Equal(t, test.expectedBody, string(body), test.description)
	}
}

// Test_CustomSuccessAndFailureHandlers checks that a custom SuccessHandler and
// ErrorHandler produce the responses, and that the route handler is not reached.
func Test_CustomSuccessAndFailureHandlers(t *testing.T) {
	app := fiber.New()

	app.Use(New(Config{
		SuccessHandler: func(c fiber.Ctx) error {
			return c.Status(fiber.StatusOK).SendString("API key is valid and request was handled by custom success handler")
		},
		ErrorHandler: func(c fiber.Ctx, _ error) error {
			return c.Status(fiber.StatusUnauthorized).SendString("API key is invalid and request was handled by custom error handler")
		},
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	// Define a test handler that should not be called
	app.Get("/", func(_ fiber.Ctx) error {
		t.Error("Test handler should not be called")
		return nil
	})

	// Create a request without an API key and send it to the app
	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)

	// Read the response body into a string
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)

	// Check that the response has the expected status code and body
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, "API key is invalid and request was handled by custom error handler", string(body))

	// Create a request with a valid API key in the Authorization header
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer "+CorrectKey)

	// Send the request to the app
	res, err = app.Test(req)
	require.NoError(t, err)

	// Read the response body into a string
	body, err = io.ReadAll(res.Body)
	require.NoError(t, err)

	// Check that the response has the expected status code and body
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "API key is valid and request was handled by custom success handler", string(body))
}

// Test_CustomNextFunc checks that requests for which Next returns true bypass
// key validation, while other paths still require a valid key.
func Test_CustomNextFunc(t *testing.T) {
	app := fiber.New()

	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.Path() == "/allowed"
		},
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	// Define a test handler
	app.Get("/allowed", func(c fiber.Ctx) error {
		return c.SendString("API key is valid and request was allowed by custom filter")
	})
	app.Get("/not-allowed", func(c fiber.Ctx) error {
		return c.SendString("Should be protected")
	})

	// Create a request with the "/allowed" path and send it to the app
	req := httptest.NewRequest(fiber.MethodGet, "/allowed", http.NoBody)
	res, err := app.Test(req)
	require.NoError(t, err)

	// Read the response body into a string
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)

	// Check that the response has the expected status code and body
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "API key is valid and request was allowed by custom filter", string(body))

	// Create a request with a different path and send it to the app without correct key
	req = httptest.NewRequest(fiber.MethodGet, "/not-allowed", http.NoBody)
	res, err = app.Test(req)
	require.NoError(t, err)

	// Read the response body into a string
	body, err = io.ReadAll(res.Body)
	require.NoError(t, err)

	// Check that the response has the expected status code and body
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))

	// Create a request with a different path and send it to the app with correct key
	req = httptest.NewRequest(fiber.MethodGet, "/not-allowed", http.NoBody)
	req.Header.Add("Authorization", "Bearer "+CorrectKey)

	res, err = app.Test(req)
	require.NoError(t, err)

	// Read the response body into a string
	body, err = io.ReadAll(res.Body)
	require.NoError(t, err)

	// Check that the response has the expected status code and body
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "Should be protected", string(body))
}

// Test_TokenFromContext_None checks that TokenFromContext returns "" when no
// keyauth middleware is mounted.
func Test_TokenFromContext_None(t *testing.T) {
	app := fiber.New()
	// Define a test handler that checks TokenFromContext
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString(TokenFromContext(c))
	})

	// Verify a "" is sent back if nothing sets the token on the context.
req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	// Send
	res, err := app.Test(req)
	require.NoError(t, err)

	// Read the response body into a string
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)

	require.Empty(t, body)
}

// Test_TokenFromContext checks that a validated key is retrievable with
// TokenFromContext inside the route handler.
func Test_TokenFromContext(t *testing.T) {
	app := fiber.New()
	// Wire up keyauth middleware to set TokenFromContext now
	app.Use(New(Config{
		Extractor: extractors.FromAuthHeader("Basic"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))
	// Define a test handler that checks TokenFromContext
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString(TokenFromContext(c))
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Basic "+CorrectKey)
	// Send
	res, err := app.Test(req)
	require.NoError(t, err)

	// Read the response body into a string
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, CorrectKey, string(body))
}

// Test_TokenFromContext_Types checks that TokenFromContext accepts fiber.Ctx,
// fiber.CustomCtx, the fasthttp request context, and context.Context (with
// PassLocalsToContext enabled).
func Test_TokenFromContext_Types(t *testing.T) {
	t.Parallel()

	app := fiber.New(fiber.Config{PassLocalsToContext: true})
	app.Use(New(Config{
		Extractor: extractors.FromAuthHeader("Basic"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		require.Equal(t, CorrectKey, TokenFromContext(c))

		customCtx, ok := c.(fiber.CustomCtx)
		require.True(t, ok)
		require.Equal(t, CorrectKey, TokenFromContext(customCtx))

		require.Equal(t, CorrectKey, TokenFromContext(c.RequestCtx()))
		require.Equal(t, CorrectKey, TokenFromContext(c.Context()))

		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Basic "+CorrectKey)

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

// Test_AuthSchemeToken checks a custom "Token" Authorization scheme.
func Test_AuthSchemeToken(t *testing.T) {
	app := fiber.New()

	app.Use(New(Config{
		Extractor: extractors.FromAuthHeader("Token"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	// Define a test handler
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("API key is valid")
	})

	// Create a request with a valid API key in the "Token" Authorization header
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Token "+CorrectKey)

	// Send the request to the app
	res, err := app.Test(req)
	require.NoError(t, err)

	// Read the response body into a string
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)

	// Check that the response has the expected status code and body
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "API key is valid", string(body))
}

// Test_AuthSchemeBasic checks the "Basic" Authorization scheme, both without
// a key (401) and with a valid key (200).
func Test_AuthSchemeBasic(t *testing.T) {
	app := fiber.New()

	app.Use(New(Config{
		Extractor: extractors.FromAuthHeader("Basic"),
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))

	// Define a test handler
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("API key is valid")
	})

	// Create a request without an API key and Send the request to the app
	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)

	// Read the response body into a string
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)

	// Check that the response has the expected status code and body
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))

	// Create a request with a valid API key in the "Authorization" header using the "Basic" scheme
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Basic "+CorrectKey)

	// Send the request to the app
	res, err = app.Test(req)
	require.NoError(t, err)

	// Read the response body into a string
	body, err = io.ReadAll(res.Body)
	require.NoError(t, err)

	// Check that the response has the expected status code and body
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "API key is valid", string(body))
}

// Test_HeaderSchemeCaseInsensitive checks that a lowercase "bearer" scheme is
// accepted by the default extractor.
func Test_HeaderSchemeCaseInsensitive(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "bearer "+CorrectKey)
	res, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusOK, res.StatusCode)
	require.Equal(t, "OK", string(body))
}

// Test_DefaultErrorHandlerChallenge checks the default Bearer challenge header
// when no key is sent.
func Test_DefaultErrorHandlerChallenge(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, ErrMissingOrMalformedAPIKey
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})

	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, "Bearer realm=\"Restricted\"", res.Header.Get("WWW-Authenticate"))
}

// Test_DefaultErrorHandlerInvalid checks that the default ErrorHandler replies
// with ErrMissingOrMalformedAPIKey's message and a challenge when the
// validator returns an arbitrary error.
func Test_DefaultErrorHandlerInvalid(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, errors.New("invalid")
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer "+CorrectKey)
	res, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
	require.Equal(t, "Bearer realm=\"Restricted\"", res.Header.Get("WWW-Authenticate"))
}

// Test_HeaderSchemeMultipleSpaces expects a 401 when the scheme and key are
// separated by more than one space.
func Test_HeaderSchemeMultipleSpaces(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, key string) (bool, error) {
			if key == CorrectKey {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	// NOTE(review): in this copy the literal below shows a single space after
	// "Bearer", but the test name and the expected 401 imply two or more
	// spaces (whitespace may have been collapsed) — verify against the repo.
	req.Header.Add("Authorization", "Bearer "+CorrectKey)
	res, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
}

// Test_HeaderSchemeMissingSpace expects a 401 when the key is glued to the scheme.
func Test_HeaderSchemeMissingSpace(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{Validator: func(_ fiber.Ctx, _ string) (bool, error) {
		return false, ErrMissingOrMalformedAPIKey
	}}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer"+CorrectKey)
	res, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
}

// Test_HeaderSchemeNoToken expects a 401 when only the scheme is sent.
func Test_HeaderSchemeNoToken(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{Validator: func(_ fiber.Ctx, _ string) (bool, error) {
		return false, ErrMissingOrMalformedAPIKey
	}}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer ")
	res, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
}

// Test_HeaderSchemeNoSeparator expects a 401 when scheme and token are not
// separated at all.
func Test_HeaderSchemeNoSeparator(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{Validator: func(_ fiber.Ctx, _ string) (bool, error) {
		return false, ErrMissingOrMalformedAPIKey
	}}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	// No space between "Bearer" and token
	req.Header.Add("Authorization", "BearerTokenWithoutSpace")
	res, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
}

// Test_HeaderSchemeEmptyTokenAfterTrim expects a 401 when the scheme is
// followed only by whitespace.
func Test_HeaderSchemeEmptyTokenAfterTrim(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, ErrMissingOrMalformedAPIKey
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	// Authorization header with scheme followed by only spaces/tabs (no actual token)
	req.Header.Add("Authorization", "Bearer \t \t ")
	res, err := app.Test(req)
	require.NoError(t, err)
	body, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, ErrMissingOrMalformedAPIKey.Error(), string(body))
}

// Test_WWWAuthenticateHeader table-tests the WWW-Authenticate value for the
// default, custom-realm, non-auth-header and chained extractor configurations,
// and its absence on success.
func Test_WWWAuthenticateHeader(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name               string
		expectedHeader     string
		config             Config
		expectedStatusCode int
	}{
		{
			name: "default config on failure",
			config: Config{
				Validator: func(_ fiber.Ctx, _ string) (bool, error) {
					return false, errors.New("validation failed")
				},
			},
			expectedHeader:     `Bearer realm="Restricted"`,
			expectedStatusCode: fiber.StatusUnauthorized,
		},
		{
			name: "custom realm on failure",
			config: Config{
				Validator: func(_ fiber.Ctx, _ string) (bool, error) {
					return false, errors.New("validation failed")
				},
				Realm: "My Custom Realm",
			},
			expectedHeader:     `Bearer realm="My Custom Realm"`,
			expectedStatusCode: fiber.StatusUnauthorized,
		},
		{
			name: "default header for non-auth-header extractor",
			config: Config{
				Validator: func(_ fiber.Ctx, _ string) (bool, error) {
					return false, errors.New("validation failed")
				},
				Extractor: extractors.FromQuery("api_key"),
			},
			expectedHeader:     `ApiKey realm="Restricted"`,
			expectedStatusCode: fiber.StatusUnauthorized,
		},
		{
			name: "no header on success",
			config: Config{
				Validator: func(_ fiber.Ctx, _ string) (bool, error) {
					return true, nil
				},
			},
			expectedHeader:     "",
			expectedStatusCode: fiber.StatusOK,
		},
		{
			name: "chained extractor with auth header",
			config: Config{
				Validator: func(_ fiber.Ctx, _ string) (bool, error) {
					return false, errors.New("validation failed")
				},
				Extractor: extractors.Chain(extractors.FromQuery("q"), extractors.FromAuthHeader("MyScheme")),
			},
			expectedHeader:     `MyScheme realm="Restricted"`,
			expectedStatusCode: fiber.StatusUnauthorized,
		},
		{
			name: "chained extractor without auth header",
			config: Config{
				Validator: func(_ fiber.Ctx, _ string) (bool, error) {
					return false, errors.New("validation failed")
				},
				Extractor: extractors.Chain(extractors.FromQuery("q"), extractors.FromCookie("c")),
			},
			expectedHeader:     `ApiKey realm="Restricted"`,
			expectedStatusCode: fiber.StatusUnauthorized,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			app := fiber.New()
			app.Use(New(tt.config))
			app.Get("/", func(c fiber.Ctx) error {
				return c.SendString("OK")
			})

			req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
			// Provide a key for the default extractor to find
			if tt.config.Extractor.Extract == nil {
				req.Header.Set(fiber.HeaderAuthorization, "Bearer somekey")
			}

			resp, err := app.Test(req)
			require.NoError(t, err)

			assert.Equal(t, tt.expectedStatusCode, resp.StatusCode)
			assert.Equal(t, tt.expectedHeader, resp.Header.Get(fiber.HeaderWWWAuthenticate))
		})
	}
}

// Test_CustomChallenge checks that an explicit Challenge is sent verbatim for a
// non-auth-header extractor.
func Test_CustomChallenge(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Extractor: extractors.FromQuery("api_key"),
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, errors.New("invalid")
		},
		Challenge: `ApiKey realm="Restricted"`,
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	res, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, `ApiKey realm="Restricted"`, res.Header.Get("WWW-Authenticate"))
}

// Test_BearerErrorFields checks that error, error_description and error_uri are
// appended to the Bearer challenge in that order.
func Test_BearerErrorFields(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, errors.New("invalid")
		},
		Error:            "invalid_token",
		ErrorDescription: "token expired",
		ErrorURI:         "https://example.com",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer something")
	res, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, `Bearer realm="Restricted", error="invalid_token", error_description="token expired", error_uri="https://example.com"`, res.Header.Get("WWW-Authenticate"))
}

// Test_BearerErrorURIOnly checks the challenge when error_uri is set without
// error_description.
func Test_BearerErrorURIOnly(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, errors.New("invalid")
		},
		Error:    "invalid_token",
		ErrorURI: "https://example.com/docs",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer something")
	res, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, `Bearer realm="Restricted", error="invalid_token", error_uri="https://example.com/docs"`, res.Header.Get("WWW-Authenticate"))
}

// Test_BearerInsufficientScope checks that the scope parameter is appended for
// the insufficient_scope error.
func Test_BearerInsufficientScope(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, errors.New("invalid")
		},
		Error: ErrorInsufficientScope,
		Scope: "read",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer something")
	res, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, `Bearer realm="Restricted", error="insufficient_scope", scope="read"`, res.Header.Get("WWW-Authenticate"))
}

// Test_ScopeValidation checks the configuration panics for Scope/Error
// combinations and that a valid multi-token scope is accepted.
func Test_ScopeValidation(t *testing.T) {
	require.PanicsWithValue(t, "fiber: keyauth scope requires insufficient_scope error", func() {
		New(Config{
			Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
			Scope:     "foo",
		})
	})

	require.PanicsWithValue(t, "fiber: keyauth insufficient_scope requires scope", func() {
		New(Config{
			Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
			Error:     ErrorInsufficientScope,
		})
	})

	require.PanicsWithValue(t, "fiber: keyauth scope contains invalid token", func() {
		New(Config{
			Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
			Error:     ErrorInsufficientScope,
			Scope:     "read \"write\"",
		})
	})

	require.NotPanics(t, func() {
		New(Config{
			Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil },
			Error:     ErrorInsufficientScope,
			Scope:     "read write:all",
		})
	})
}

// Test_WWWAuthenticateOnlyOn401 checks that no challenge header is added when
// the ErrorHandler responds with a status other than 401/407.
func Test_WWWAuthenticateOnlyOn401(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, errors.New("invalid")
		},
		ErrorHandler: func(c fiber.Ctx, _ error) error {
			return c.Status(fiber.StatusForbidden).SendString("forbidden")
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer bad")
	res, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusForbidden, res.StatusCode)
	require.Empty(t, res.Header.Get("WWW-Authenticate"))
}

// Test_DefaultChallengeForNonAuthExtractor checks the generated ApiKey
// challenge for a query-string extractor.
func Test_DefaultChallengeForNonAuthExtractor(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Extractor: extractors.FromQuery("api_key"),
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, ErrMissingOrMalformedAPIKey
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, `ApiKey realm="Restricted"`, res.Header.Get(fiber.HeaderWWWAuthenticate))
}

// Test_MultipleWWWAuthenticateChallenges checks that each auth-header scheme in
// a chain contributes one comma-separated challenge, in chain order.
func Test_MultipleWWWAuthenticateChallenges(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Extractor: extractors.Chain(
			extractors.FromAuthHeader("Bearer"),
			extractors.FromAuthHeader("ApiKey"),
		),
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, errors.New("invalid")
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	res, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, http.StatusUnauthorized, res.StatusCode)
	require.Equal(t, `Bearer realm="Restricted", ApiKey realm="Restricted"`, res.Header.Get(fiber.HeaderWWWAuthenticate))
}

// Test_ProxyAuthenticateHeader checks that a 407 from the ErrorHandler puts the
// challenge in Proxy-Authenticate instead of WWW-Authenticate.
func Test_ProxyAuthenticateHeader(t *testing.T) {
	app := fiber.New()
	app.Use(New(Config{
		Validator: func(_ fiber.Ctx, _ string) (bool, error) {
			return false, errors.New("invalid")
		},
		ErrorHandler: func(c fiber.Ctx, _ error) error {
			return c.Status(fiber.StatusProxyAuthRequired).SendString("proxy auth")
		},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("OK")
	})
	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	req.Header.Add("Authorization", "Bearer bad")
	res, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, http.StatusProxyAuthRequired, res.StatusCode)
	require.Equal(t, `Bearer realm="Restricted"`, res.Header.Get(fiber.HeaderProxyAuthenticate))
	require.Empty(t,
res.Header.Get(fiber.HeaderWWWAuthenticate)) } func Test_New_InvalidErrorToken(t *testing.T) { assert.Panics(t, func() { New(Config{ Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil }, Error: "unsupported", }) }) } func Test_New_ErrorDescriptionRequiresError(t *testing.T) { assert.Panics(t, func() { New(Config{ Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil }, ErrorDescription: "desc", }) }) } func Test_New_ErrorURIRequiresError(t *testing.T) { assert.PanicsWithValue(t, "fiber: keyauth error_uri requires error", func() { New(Config{ Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil }, ErrorURI: "https://example.com/docs", }) }) } func Test_New_ErrorURIAbsolute(t *testing.T) { assert.PanicsWithValue(t, "fiber: keyauth error_uri must be absolute", func() { New(Config{ Validator: func(_ fiber.Ctx, _ string) (bool, error) { return true, nil }, Error: "invalid_token", ErrorURI: "/docs", }) }) } ================================================ FILE: middleware/limiter/config.go ================================================ package limiter import ( "time" "github.com/gofiber/fiber/v3" ) const defaultLimiterMax = 5 // Config defines the config for middleware. type Config struct { // Store is used to store the state of the middleware // // Default: an in memory store for this process only Storage fiber.Storage // LimiterMiddleware is the struct that implements a limiter middleware. // // Default: a new Fixed Window Rate Limiter LimiterMiddleware Handler // Next defines a function to skip this middleware when returned true. // // Optional. 
Default: nil Next func(c fiber.Ctx) bool // A function to dynamically calculate the max requests supported by the rate limiter middleware // // Default: func(c fiber.Ctx) int { // return c.Max // } MaxFunc func(c fiber.Ctx) int // A function to dynamically calculate the expiration time for rate limiter entries // // Default: A function that returns the static `Expiration` value from the config. ExpirationFunc func(c fiber.Ctx) time.Duration // KeyGenerator allows you to generate custom keys, by default c.IP() is used // // Default: func(c fiber.Ctx) string { // return c.IP() // } KeyGenerator func(fiber.Ctx) string // LimitReached is called when a request hits the limit // // Default: func(c fiber.Ctx) error { // return c.SendStatus(fiber.StatusTooManyRequests) // } LimitReached fiber.Handler // Max number of recent connections during `Expiration` seconds before sending a 429 response // // Default: 5 Max int // Expiration is the time on how long to keep records of requests in memory // // Default: 1 * time.Minute Expiration time.Duration // When set to true, requests with StatusCode >= 400 won't be counted. // // Default: false SkipFailedRequests bool // When set to true, requests with StatusCode < 400 won't be counted. // // Default: false SkipSuccessfulRequests bool // When set to true, the middleware will not include the rate limit headers (X-RateLimit-* and Retry-After) in the response. // // Default: false DisableHeaders bool // DisableValueRedaction turns off masking limiter keys in logs and error messages when set to true. // // Default: false DisableValueRedaction bool } // ConfigDefault is the default config var ConfigDefault = Config{ Max: defaultLimiterMax, Expiration: 1 * time.Minute, MaxFunc: func(_ fiber.Ctx) int { return defaultLimiterMax }, // Note: ExpirationFunc is intentionally nil here so that configDefault() // can create a proper closure that references the configured Expiration value. 
KeyGenerator: func(c fiber.Ctx) string { return c.IP() }, LimitReached: func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTooManyRequests) }, SkipFailedRequests: false, SkipSuccessfulRequests: false, DisableHeaders: false, DisableValueRedaction: false, LimiterMiddleware: FixedWindow{}, } // Helper function to set default values func configDefault(config ...Config) Config { // Use default config if nothing provided var cfg Config if len(config) < 1 { cfg = ConfigDefault } else { cfg = config[0] } // Set default values if cfg.Next == nil { cfg.Next = ConfigDefault.Next } if cfg.Max <= 0 { cfg.Max = ConfigDefault.Max } if int(cfg.Expiration.Seconds()) <= 0 { cfg.Expiration = ConfigDefault.Expiration } if cfg.KeyGenerator == nil { cfg.KeyGenerator = ConfigDefault.KeyGenerator } if cfg.LimitReached == nil { cfg.LimitReached = ConfigDefault.LimitReached } if cfg.LimiterMiddleware == nil { cfg.LimiterMiddleware = ConfigDefault.LimiterMiddleware } if cfg.MaxFunc == nil { cfg.MaxFunc = func(_ fiber.Ctx) int { return cfg.Max } } if cfg.ExpirationFunc == nil { cfg.ExpirationFunc = func(_ fiber.Ctx) time.Duration { return cfg.Expiration } } return cfg } ================================================ FILE: middleware/limiter/limiter.go ================================================ package limiter import ( "errors" "github.com/gofiber/fiber/v3" ) const ( // X-RateLimit-* headers xRateLimitLimit = "X-RateLimit-Limit" xRateLimitRemaining = "X-RateLimit-Remaining" xRateLimitReset = "X-RateLimit-Reset" ) // Handler defines a rate-limiting strategy that can produce a middleware // handler using the provided configuration. type Handler interface { New(config *Config) fiber.Handler } // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) // Return the specified middleware handler. 
return cfg.LimiterMiddleware.New(&cfg) } // getEffectiveStatusCode returns the actual status code, considering both the error and response status func getEffectiveStatusCode(c fiber.Ctx, err error) int { // If there's an error and it's a *fiber.Error, use its status code if err != nil { var fiberErr *fiber.Error if errors.As(err, &fiberErr) { return fiberErr.Code } } // Otherwise, use the response status code return c.Response().StatusCode() } ================================================ FILE: middleware/limiter/limiter_fixed.go ================================================ package limiter import ( "fmt" "strconv" "sync" "github.com/gofiber/fiber/v3" "github.com/gofiber/utils/v2" ) // FixedWindow implements a fixed-window rate limiting strategy. type FixedWindow struct{} // New creates a new fixed window middleware handler func (FixedWindow) New(cfg *Config) fiber.Handler { if cfg == nil { defaultCfg := configDefault() cfg = &defaultCfg } // Limiter variables mux := &sync.RWMutex{} // Create manager to simplify storage operations ( see manager.go ) manager := newManager(cfg.Storage, !cfg.DisableValueRedaction) // Update timestamp every second utils.StartTimeStampUpdater() // Return new handler return func(c fiber.Ctx) error { // Generate maxRequests from generator, if no generator was provided the default value returned is 5 maxRequests := cfg.MaxFunc(c) // Don't execute middleware if Next returns true or if the max is 0 if (cfg.Next != nil && cfg.Next(c)) || maxRequests == 0 { return c.Next() } // Generate expiration from generator expirationDuration := cfg.ExpirationFunc(c) if expirationDuration <= 0 { expirationDuration = ConfigDefault.Expiration } expiration := uint64(expirationDuration.Seconds()) // Get key from request key := cfg.KeyGenerator(c) // Lock entry mux.Lock() reqCtx := c.Context() // Get entry from pool and release when finished e, err := manager.get(reqCtx, key) if err != nil { mux.Unlock() return err } // Get timestamp ts := 
uint64(utils.Timestamp()) // Set expiration if entry does not exist if e.exp == 0 { e.exp = ts + expiration } else if ts >= e.exp { // Check if entry is expired e.currHits = 0 e.exp = ts + expiration } // Increment hits e.currHits++ // Calculate when it resets in seconds resetInSec := e.exp - ts // Set how many hits we have left remaining := maxRequests - e.currHits // Update storage if setErr := manager.set(reqCtx, key, e, expirationDuration); setErr != nil { mux.Unlock() return fmt.Errorf("limiter: failed to persist state: %w", setErr) } // Unlock entry mux.Unlock() // Check if hits exceed the max if remaining < 0 { // Return response with Retry-After header // https://tools.ietf.org/html/rfc6584 if !cfg.DisableHeaders { c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10)) } // Call LimitReached handler return cfg.LimitReached(c) } // Continue stack for reaching c.Response().StatusCode() // Store err for returning err = c.Next() // Get the effective status code from either the error or response statusCode := getEffectiveStatusCode(c, err) // Check for SkipFailedRequests and SkipSuccessfulRequests if (cfg.SkipSuccessfulRequests && statusCode < fiber.StatusBadRequest) || (cfg.SkipFailedRequests && statusCode >= fiber.StatusBadRequest) { // Lock entry mux.Lock() entry, getErr := manager.get(reqCtx, key) if getErr != nil { mux.Unlock() return getErr } e = entry e.currHits-- remaining++ if setErr := manager.set(reqCtx, key, e, expirationDuration); setErr != nil { mux.Unlock() return fmt.Errorf("limiter: failed to persist state: %w", setErr) } // Unlock entry mux.Unlock() } // We can continue, update RateLimit headers if !cfg.DisableHeaders { c.Set(xRateLimitLimit, strconv.Itoa(maxRequests)) c.Set(xRateLimitRemaining, strconv.Itoa(remaining)) c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10)) } return err } } ================================================ FILE: middleware/limiter/limiter_sliding.go ================================================ 
package limiter

import (
	"fmt"
	"math"
	"strconv"
	"sync"
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/utils/v2"
)

// SlidingWindow implements the sliding-window rate limiting strategy.
type SlidingWindow struct{}

// New creates a new sliding window middleware handler.
//
// The request rate for a key is estimated as the hits of the current window
// plus the previous window's hits weighted by the fraction of the current
// window that is still left. A nil cfg falls back to configDefault().
func (SlidingWindow) New(cfg *Config) fiber.Handler {
	if cfg == nil {
		defaultCfg := configDefault()
		cfg = &defaultCfg
	}

	// Limiter variables
	mux := &sync.RWMutex{}

	// Create manager to simplify storage operations ( see manager.go )
	manager := newManager(cfg.Storage, !cfg.DisableValueRedaction)

	// Update timestamp every second
	utils.StartTimeStampUpdater()

	// Return new handler
	return func(c fiber.Ctx) error {
		// Generate maxRequests from generator, if no generator was provided the default value returned is 5
		maxRequests := cfg.MaxFunc(c)

		// Don't execute middleware if Next returns true or if the max is 0
		if (cfg.Next != nil && cfg.Next(c)) || maxRequests == 0 {
			return c.Next()
		}

		// Generate expiration from generator.
		// Windows are tracked in whole seconds: a duration shorter than one
		// second would truncate to expiration == 0. The weight division below
		// would then be 0/0 (NaN, whose int conversion is implementation-
		// dependent), and ttlDuration would return a zero TTL. Apply the same
		// rule configDefault uses for the static Expiration field.
		expirationDuration := cfg.ExpirationFunc(c)
		if int(expirationDuration.Seconds()) <= 0 {
			expirationDuration = ConfigDefault.Expiration
		}
		expiration := uint64(expirationDuration.Seconds())

		// Get key from request
		key := cfg.KeyGenerator(c)

		// Lock entry
		mux.Lock()

		reqCtx := c.Context()

		// Get entry from pool and release when finished
		e, err := manager.get(reqCtx, key)
		if err != nil {
			mux.Unlock()
			return err
		}

		// Get timestamp
		ts := uint64(utils.Timestamp())

		// Rotate window
		resetInSec := rotateWindow(e, ts, expiration)
		windowExpiresAt := e.exp

		// Increment hits
		e.currHits++

		// weight = time until current window reset / total window length
		weight := float64(resetInSec) / float64(expiration)

		// rate = request count in previous window * weight + request count in current window
		rate := int(math.Ceil(float64(e.prevHits)*weight)) + e.currHits

		// Calculate how many hits can be made based on the current rate
		remaining := maxRequests - rate

		// Update storage. Garbage collect when the next window ends.
		// |--------------------------|--------------------------|
		//               ^            ^               ^          ^
		//              ts         e.exp  End sample window   End next window
		//               <------------>
		// 				   Reset In Sec
		// resetInSec = e.exp - ts - time until end of current window.
		// duration + expiration = end of next window.
		// Because we don't want to garbage collect in the middle of a window
		// we add the expiration to the duration.
		// Otherwise, after the end of "sample window", attackers could launch
		// a new request with the full window length.
		if setErr := manager.set(reqCtx, key, e, ttlDuration(resetInSec, expiration)); setErr != nil {
			mux.Unlock()
			return fmt.Errorf("limiter: failed to persist state: %w", setErr)
		}

		// Unlock entry
		mux.Unlock()

		// Check if hits exceed the allowed maximum for this request
		if remaining < 0 {
			// Return response with Retry-After header
			// https://tools.ietf.org/html/rfc6585#section-4
			if !cfg.DisableHeaders {
				c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
			}

			// Call LimitReached handler
			return cfg.LimitReached(c)
		}

		// Continue stack for reaching c.Response().StatusCode()
		// Store err for returning
		err = c.Next()

		// Get the effective status code from either the error or response
		statusCode := getEffectiveStatusCode(c, err)

		skipHit := (cfg.SkipSuccessfulRequests && statusCode < fiber.StatusBadRequest) ||
			(cfg.SkipFailedRequests && statusCode >= fiber.StatusBadRequest)

		// Re-read the entry either to roll back a skipped hit or to report
		// up-to-date header values after the downstream handler ran.
		if skipHit || !cfg.DisableHeaders {
			// Lock entry
			mux.Lock()
			entry, getErr := manager.get(reqCtx, key)
			if getErr != nil {
				mux.Unlock()
				return getErr
			}
			e = entry

			ts = uint64(utils.Timestamp())
			resetInSec = rotateWindow(e, ts, expiration)
			weight = float64(resetInSec) / float64(expiration)

			if skipHit {
				// Remove the hit from whichever bucket it now lives in (it may
				// have moved to prevHits if the window rotated meanwhile).
				if counter := bucketForOriginalHit(e, windowExpiresAt, ts, expiration); counter != nil && *counter > 0 {
					*counter--
				}
			}

			rate = int(math.Ceil(float64(e.prevHits)*weight)) + e.currHits
			remaining = maxRequests - rate

			if setErr := manager.set(reqCtx, key, e, ttlDuration(resetInSec, expiration)); setErr != nil {
				mux.Unlock()
				return fmt.Errorf("limiter: failed to persist state: %w", setErr)
			}

			// Unlock entry
			mux.Unlock()

			// We can continue, update RateLimit headers
			if !cfg.DisableHeaders {
				c.Set(xRateLimitLimit, strconv.Itoa(maxRequests))
				c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
				c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
			}
		}

		return err
	}
}

// rotateWindow advances e to the window that contains ts and returns the
// number of seconds until that window ends. A fresh entry (exp == 0) starts a
// new window. If ts is past e.exp by less than one window, the current hits
// become the previous window's hits. If ts is past by a full window or more,
// both counters are cleared.
func rotateWindow(e *item, ts, expiration uint64) uint64 {
	// Set expiration if entry does not exist
	if e.exp == 0 {
		e.exp = ts + expiration
	} else if ts >= e.exp {
		// The entry has expired, handle the expiration.
		// Reset the current hits to 0.
		elapsed := ts - e.exp
		if elapsed >= expiration {
			e.prevHits = 0
			e.currHits = 0
			e.exp = ts + expiration
		} else {
			e.prevHits = e.currHits
			e.currHits = 0
			e.exp = ts + expiration - elapsed
		}
	}

	// Calculate when it resets in seconds
	return e.exp - ts
}

// bucketForOriginalHit returns the counter holding a hit that was recorded
// in the window ending at requestExpiration, given an entry already rotated
// to ts. It returns the current counter if that window is still active, the
// previous counter if it has just rolled over, and nil if the hit has aged out.
func bucketForOriginalHit(e *item, requestExpiration, ts, expiration uint64) *int {
	if ts < requestExpiration {
		return &e.currHits
	}
	if ts-requestExpiration < expiration {
		return &e.prevHits
	}
	return nil
}

// ttlDuration returns the storage TTL for an entry: the time left in the
// current window plus one full window. It saturates at math.MaxInt64
// instead of overflowing.
func ttlDuration(resetInSec, expiration uint64) time.Duration {
	resetDuration, ok := secondsToDuration(resetInSec)
	if !ok {
		return time.Duration(math.MaxInt64)
	}
	expirationDuration, ok := secondsToDuration(expiration)
	if !ok {
		return time.Duration(math.MaxInt64)
	}
	if resetDuration > time.Duration(math.MaxInt64)-expirationDuration {
		return time.Duration(math.MaxInt64)
	}
	return resetDuration + expirationDuration
}

// secondsToDuration converts whole seconds to a time.Duration. It reports
// false (and returns math.MaxInt64) when the value does not fit.
func secondsToDuration(seconds uint64) (time.Duration, bool) {
	const maxSeconds = math.MaxInt64 / int64(time.Second)
	if seconds > uint64(maxSeconds) {
		return time.Duration(math.MaxInt64), false
	}
	return time.Duration(seconds) * time.Second, true
}

================================================
FILE: middleware/limiter/limiter_test.go
================================================
package limiter

import (
	"context"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
	"github.com/valyala/fasthttp"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/fiber/v3/internal/storage/memory"
)

// failingLimiterStorage is an in-memory storage used as Config.Storage in
// tests. Get/Set for a key can be forced to fail by setting
// errs["get|<key>"] or errs["set|<key>"].
type failingLimiterStorage struct {
	data map[string][]byte
	errs map[string]error
}

const testLimiterClientKey = "client-key"

func newFailingLimiterStorage() *failingLimiterStorage {
	return &failingLimiterStorage{
		data: make(map[string][]byte),
		errs: make(map[string]error),
	}
}

// countingFailStorage fails set operations after a specified number of successful calls
type countingFailStorage struct {
	*failingLimiterStorage
	setFailErr error
	setCount   int
	failAfterN int
}

func newCountingFailStorage(failAfterN int, err error) *countingFailStorage {
	return &countingFailStorage{
		failingLimiterStorage: newFailingLimiterStorage(),
		failAfterN:            failAfterN,
		setFailErr:            err,
	}
}

func (s *countingFailStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error {
	s.setCount++
	if s.setCount > s.failAfterN {
		return s.setFailErr
	}
	return s.failingLimiterStorage.SetWithContext(ctx, key, val, exp)
}

// contextRecord captures the key and the context state seen by one storage call.
type contextRecord struct {
	key      string
	value    string
	canceled bool
}

// contextRecorderLimiterStorage records the context of every Get/Set call.
type contextRecorderLimiterStorage struct {
	*failingLimiterStorage
	gets []contextRecord
	sets []contextRecord
}

// sleepForRetryAfter waits for the limiter window reported in Retry-After.
// It waits twice the reported delay, at least four seconds, plus 500ms. If
// the header is missing it only waits 500ms.
func sleepForRetryAfter(t *testing.T, resp *http.Response) {
	t.Helper()

	retryAfter := resp.Header.Get(fiber.HeaderRetryAfter)
	if retryAfter == "" {
		time.Sleep(500 * time.Millisecond)
		return
	}

	seconds, err := strconv.Atoi(retryAfter)
	require.NoError(t, err)

	delay := time.Duration(seconds) * time.Second
	// Sliding window needs roughly 2x the reported delay for the previous window to expire.
	// The comparison also guards against overflow when doubling.
	if doubled := 2 * delay; doubled > delay {
		delay = doubled
	}
	if minDelay := 4 * time.Second; delay < minDelay {
		delay = minDelay
	}

	time.Sleep(delay + 500*time.Millisecond)
}

func newContextRecorderLimiterStorage() *contextRecorderLimiterStorage {
	return &contextRecorderLimiterStorage{failingLimiterStorage: newFailingLimiterStorage()}
}

// contextRecordFrom snapshots the marker value and cancellation state of ctx.
func contextRecordFrom(ctx context.Context, key string) contextRecord {
	record := contextRecord{
		key:      key,
		canceled: errors.Is(ctx.Err(), context.Canceled),
	}
	if value, ok := ctx.Value(markerKey).(string); ok {
		record.value = value
	}
	return record
}

func (s *contextRecorderLimiterStorage) GetWithContext(ctx context.Context, key string) ([]byte, error) {
	s.gets = append(s.gets, contextRecordFrom(ctx, key))
	return s.failingLimiterStorage.GetWithContext(ctx, key)
}

func (s *contextRecorderLimiterStorage) SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error {
	s.sets = append(s.sets, contextRecordFrom(ctx, key))
	return s.failingLimiterStorage.SetWithContext(ctx, key, val, exp)
}

func (s *contextRecorderLimiterStorage) recordedGets() []contextRecord {
	out := make([]contextRecord, len(s.gets))
	copy(out, s.gets)
	return out
}

func (s *contextRecorderLimiterStorage) recordedSets() []contextRecord {
	out := make([]contextRecord, len(s.sets))
	copy(out, s.sets)
	return out
}

func (s *failingLimiterStorage) GetWithContext(_ context.Context, key string) ([]byte, error) {
	if err, ok := s.errs["get|"+key]; ok && err != nil {
		return nil, err
	}
	if val, ok := s.data[key]; ok {
		return append([]byte(nil), val...), nil
	}
	return nil, nil
}

func (s *failingLimiterStorage) Get(key string) ([]byte, error) {
	return s.GetWithContext(context.Background(), key)
}

func (s *failingLimiterStorage) SetWithContext(_ context.Context, key string, val []byte, _ time.Duration) error {
	if err, ok := s.errs["set|"+key]; ok && err != nil {
		return err
	}
	s.data[key] = append([]byte(nil), val...)
	return nil
}

func (s *failingLimiterStorage) Set(key string, val []byte, exp time.Duration) error {
	return s.SetWithContext(context.Background(), key, val, exp)
}

func (*failingLimiterStorage) DeleteWithContext(context.Context, string) error { return nil }

func (*failingLimiterStorage) Delete(string) error { return nil }

func (*failingLimiterStorage) ResetWithContext(context.Context) error { return nil }

func (*failingLimiterStorage) Reset() error { return nil }

func (*failingLimiterStorage) Close() error { return nil }

type contextKey string

const markerKey contextKey = "marker"

func contextWithMarker(label string) context.Context {
	return context.WithValue(context.Background(), markerKey, label)
}

func canceledContextWithMarker(label string) context.Context {
	ctx, cancel := context.WithCancel(contextWithMarker(label))
	cancel()
	return ctx
}

func TestLimiterDefaultConfigNoPanic(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})

	require.NotPanics(t, func() {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, resp.StatusCode)
	})
}

func TestLimiterFixedStorageGetError(t *testing.T) {
	t.Parallel()

	storage := newFailingLimiterStorage()
	storage.errs["get|"+testLimiterClientKey] = errors.New("boom")

	var captured error
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			captured = err
			return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
		},
	})

	app.Use(New(Config{Storage: storage, Max: 1, Expiration: time.Second, KeyGenerator: func(fiber.Ctx) string { return testLimiterClientKey }}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, "limiter: failed to get key")
	require.ErrorContains(t, captured, "[redacted]")
}

func TestLimiterFixedStorageSetError(t *testing.T) {
	t.Parallel()

	storage := newFailingLimiterStorage()
	storage.errs["set|"+testLimiterClientKey] = errors.New("boom")

	var captured error
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			captured = err
			return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
		},
	})

	app.Use(New(Config{Storage: storage, Max: 1, Expiration: time.Second, KeyGenerator: func(fiber.Ctx) string { return testLimiterClientKey }}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, "limiter: failed to persist state")
	require.ErrorContains(t, captured, "limiter: failed to store key")
	require.ErrorContains(t, captured, "[redacted]")
}

func TestLimiterFixedPropagatesRequestContextToStorage(t *testing.T) {
	t.Parallel()

	storage := newContextRecorderLimiterStorage()
	app := fiber.New()
	app.Use(func(c fiber.Ctx) error {
		path := c.Path()
		if path == "/normal" {
			c.SetContext(contextWithMarker("fixed-normal"))
		}
		if path == "/rollback" {
			c.SetContext(canceledContextWithMarker("fixed-rollback"))
		}
		return c.Next()
	})
	app.Use(New(Config{
		Storage:                storage,
		Max:                    1,
		Expiration:             time.Minute,
		SkipSuccessfulRequests: true,
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path()
		},
		LimiterMiddleware: FixedWindow{},
	}))
	app.Get("/:mode", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})

	for _, path := range []string{"/normal", "/rollback"} {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, resp.StatusCode)
	}

	gets := storage.recordedGets()
	require.Len(t, gets, 4)
	sets := storage.recordedSets()
	require.Len(t, sets, 4)

	verifyRecords := func(t *testing.T, records []contextRecord, key, wantValue string, wantCanceled bool) {
		t.Helper()
		var matched []contextRecord
		for _, rec := range records {
			if rec.key == key {
				matched = append(matched, rec)
			}
		}
		require.Len(t, matched, 2)
		for _, rec := range matched {
			require.Equal(t, wantValue, rec.value)
			require.Equal(t, wantCanceled, rec.canceled)
		}
	}

	verifyRecords(t, gets, "/normal", "fixed-normal", false)
	verifyRecords(t, gets, "/rollback", "fixed-rollback", true)
	verifyRecords(t, sets, "/normal", "fixed-normal", false)
	verifyRecords(t, sets, "/rollback", "fixed-rollback", true)
}

func TestLimiterFixedStorageGetErrorDisableRedaction(t *testing.T) {
	t.Parallel()

	storage := newFailingLimiterStorage()
	storage.errs["get|"+testLimiterClientKey] = errors.New("boom")

	var captured error
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			captured = err
			return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
		},
	})

	app.Use(New(Config{DisableValueRedaction: true, Storage: storage, Max: 1, Expiration: time.Second, KeyGenerator: func(fiber.Ctx) string { return testLimiterClientKey }}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, testLimiterClientKey)
	require.NotContains(t, captured.Error(), "[redacted]")
}

func TestLimiterFixedStorageSetErrorDisableRedaction(t *testing.T) {
	t.Parallel()

	storage := newFailingLimiterStorage()
	storage.errs["set|"+testLimiterClientKey] = errors.New("boom")

	var captured error
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			captured = err
			return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
		},
	})

	app.Use(New(Config{DisableValueRedaction: true, Storage: storage, Max: 1, Expiration: time.Second, KeyGenerator: func(fiber.Ctx) string { return testLimiterClientKey }}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, testLimiterClientKey)
	require.NotContains(t, captured.Error(), "[redacted]")
}

func TestLimiterFixedStorageSetErrorOnSkipSuccessfulRequests(t *testing.T) {
	t.Parallel()

	storage := newCountingFailStorage(1, errors.New("second set failed"))

	var captured error
	app := fiber.New(fiber.Config{
		ErrorHandler: func(c fiber.Ctx, err error) error {
			captured = err
			return c.Status(fiber.StatusInternalServerError).SendString("storage failure")
		},
	})

	app.Use(New(Config{
		Storage:                storage,
		Max:                    10,
		Expiration:             time.Second,
		SkipSuccessfulRequests: true,
		KeyGenerator:           func(fiber.Ctx) string { return testLimiterClientKey },
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)
	require.Error(t, captured)
	require.ErrorContains(t, captured, "limiter: failed to persist state")
}

func TestLimiterSlidingPropagatesRequestContextToStorage(t *testing.T) {
	t.Parallel()

	storage := newContextRecorderLimiterStorage()
	app := fiber.New()
	app.Use(func(c fiber.Ctx) error {
		path := c.Path()
		if path == "/normal" {
			c.SetContext(contextWithMarker("sliding-normal"))
		}
		if path == "/rollback" {
			c.SetContext(canceledContextWithMarker("sliding-rollback"))
		}
		return c.Next()
	})
	app.Use(New(Config{
		Storage:                storage,
		Max:                    1,
		Expiration:             time.Minute,
		SkipSuccessfulRequests: true,
		KeyGenerator: func(c fiber.Ctx) string {
			return c.Path()
		},
		LimiterMiddleware: SlidingWindow{},
	}))
	app.Get("/:mode", func(c fiber.Ctx) error {
		return c.SendString("ok")
	})

	for _, path := range []string{"/normal", "/rollback"} {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, path, http.NoBody))
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, resp.StatusCode)
	}

	gets := storage.recordedGets()
	require.Len(t, gets, 4)
	sets := storage.recordedSets()
	require.Len(t, sets, 4)

	verifyRecords := func(t *testing.T, records []contextRecord, key, wantValue string, wantCanceled bool) {
		t.Helper()
		var matched []contextRecord
		for _, rec := range records {
			if rec.key == key {
				matched = append(matched, rec)
			}
		}
		require.Len(t, matched, 2)
		for _, rec := range matched {
			require.Equal(t, wantValue, rec.value)
			require.Equal(t, wantCanceled, rec.canceled)
		}
	}

	verifyRecords(t, gets, "/normal", "sliding-normal", false)
	verifyRecords(t, gets, "/rollback", "sliding-rollback", true)
	verifyRecords(t, sets, "/normal", "sliding-normal", false)
	verifyRecords(t, sets, "/rollback", "sliding-rollback", true)
}

func TestLimiterSlidingSkipsPostUpdateWhenHeadersDisabled(t *testing.T) {
	t.Parallel()

	storage := newContextRecorderLimiterStorage()
	app := fiber.New()
	app.Use(New(Config{
		Max:               1,
		Expiration:        time.Second,
		Storage:           storage,
		DisableHeaders:    true,
		LimiterMiddleware: SlidingWindow{},
	}))
	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Len(t, storage.recordedGets(), 1)
	require.Len(t, storage.recordedSets(), 1)
}

// go test -run Test_Limiter_With_Max_Func_With_Zero_And_Limiter_Sliding -race -v
func Test_Limiter_With_Max_Func_With_Zero_And_Limiter_Sliding(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// A MaxFunc returning 0 makes the middleware skip straight to c.Next(),
	// so no state is stored and no waiting for a window to expire is needed.
	app.Use(New(Config{
		MaxFunc:                func(_ fiber.Ctx) int { return 0 },
		Expiration:             2 * time.Second,
		SkipFailedRequests:     false,
		SkipSuccessfulRequests: false,
		LimiterMiddleware:      SlidingWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}

func Test_Limiter_Sliding_MaxFuncOverridesStaticMax(t *testing.T) {
	t.Parallel()

	app := fiber.New()
	staticMax := 5
	dynamicMax := 2

	app.Use(New(Config{
		Max:               staticMax,
		MaxFunc:           func(fiber.Ctx) int { return dynamicMax },
		Expiration:        2 * time.Second,
		LimiterMiddleware: SlidingWindow{},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, strconv.Itoa(dynamicMax), resp.Header.Get("X-RateLimit-Limit"))
	require.Equal(t, strconv.Itoa(dynamicMax-1), resp.Header.Get("X-RateLimit-Remaining"))

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, strconv.Itoa(dynamicMax), resp.Header.Get("X-RateLimit-Limit"))
	require.Equal(t, strconv.Itoa(dynamicMax-2), resp.Header.Get("X-RateLimit-Remaining"))

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode)
}

// go test -run Test_Limiter_With_Max_Func_With_Zero -race -v
func Test_Limiter_With_Max_Func_With_Zero(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		MaxFunc:    func(_ fiber.Ctx) int { return 0 },
		Expiration: 2 * time.Second,
		Storage:    memory.New(),
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello tester!")
	})

	var wg sync.WaitGroup

	for i := 0; i <= 4; i++ {
		wg.Go(func() {
			// require.* calls t.FailNow, which must only be called from the
			// test goroutine; inside spawned goroutines use assert and return.
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, fiber.StatusOK, resp.StatusCode)

			body, err := io.ReadAll(resp.Body)
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, "Hello tester!", string(body))
		})
	}

	wg.Wait()

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_With_Max_Func -race -v
func Test_Limiter_With_Max_Func(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	maxRequests := 10

	app.Use(New(Config{
		MaxFunc:    func(_ fiber.Ctx) int { return maxRequests },
		Expiration: 2 * time.Second,
		Storage:    memory.New(),
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello tester!")
	})

	var wg sync.WaitGroup

	for i := 0; i <= maxRequests-1; i++ {
		wg.Go(func() {
			// require.* calls t.FailNow, which must only be called from the
			// test goroutine; inside spawned goroutines use assert and return.
			resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, fiber.StatusOK, resp.StatusCode)

			body, err := io.ReadAll(resp.Body)
			if !assert.NoError(t, err) {
				return
			}
			assert.Equal(t, "Hello tester!", string(body))
		})
	}

	wg.Wait()

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	time.Sleep(3 * time.Second)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_Fixed_ExpirationFuncOverridesStaticExpiration -race -v
func Test_Limiter_Fixed_ExpirationFuncOverridesStaticExpiration(t *testing.T) {
t.Parallel() app := fiber.New() app.Use(New(Config{ Max: 2, Expiration: 10 * time.Second, ExpirationFunc: func(_ fiber.Ctx) time.Duration { return 2 * time.Second }, LimiterMiddleware: FixedWindow{}, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) time.Sleep(3 * time.Second) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } // go test -run Test_Limiter_Sliding_ExpirationFuncOverridesStaticExpiration -race -v func Test_Limiter_Sliding_ExpirationFuncOverridesStaticExpiration(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Max: 2, Expiration: 10 * time.Second, ExpirationFunc: func(_ fiber.Ctx) time.Duration { return 2 * time.Second }, LimiterMiddleware: SlidingWindow{}, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) // Sliding window needs ~2x expiration to fully reset (considers previous window) sleepForRetryAfter(t, resp) resp, err = 
app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } // go test -run Test_Limiter_Fixed_ExpirationFunc_FallbackOnZeroDuration -race -v func Test_Limiter_Fixed_ExpirationFunc_FallbackOnZeroDuration(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Max: 1, ExpirationFunc: func(_ fiber.Ctx) time.Duration { return 0 }, LimiterMiddleware: FixedWindow{}, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) } // go test -run Test_Limiter_Fixed_ExpirationFunc_FallbackOnNegativeDuration -race -v func Test_Limiter_Fixed_ExpirationFunc_FallbackOnNegativeDuration(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Max: 1, ExpirationFunc: func(_ fiber.Ctx) time.Duration { return -1 * time.Second }, LimiterMiddleware: FixedWindow{}, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) } // go test -run Test_Limiter_Sliding_ExpirationFunc_FallbackOnZeroDuration -race -v func Test_Limiter_Sliding_ExpirationFunc_FallbackOnZeroDuration(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Max: 1, ExpirationFunc: func(_ fiber.Ctx) time.Duration { return 0 }, LimiterMiddleware: SlidingWindow{}, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) 
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) } // go test -run Test_Limiter_Sliding_ExpirationFunc_FallbackOnNegativeDuration -race -v func Test_Limiter_Sliding_ExpirationFunc_FallbackOnNegativeDuration(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Max: 1, ExpirationFunc: func(_ fiber.Ctx) time.Duration { return -1 * time.Second }, LimiterMiddleware: SlidingWindow{}, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode) } // go test -run Test_Limiter_Concurrency_Store -race -v func Test_Limiter_Concurrency_Store(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Max: 50, Expiration: 2 * time.Second, Storage: memory.New(), })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello tester!") }) var wg sync.WaitGroup for i := 0; i <= 49; i++ { wg.Go(func() { resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) assert.Equal(t, fiber.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, "Hello tester!", string(body)) }) } wg.Wait() resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 429, resp.StatusCode) time.Sleep(3 * time.Second) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, 
resp.StatusCode) } // go test -run Test_Limiter_Concurrency -race -v func Test_Limiter_Concurrency(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Max: 50, Expiration: 2 * time.Second, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello tester!") }) var wg sync.WaitGroup for i := 0; i <= 49; i++ { wg.Go(func() { resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) assert.Equal(t, fiber.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) assert.Equal(t, "Hello tester!", string(body)) }) } wg.Wait() resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 429, resp.StatusCode) time.Sleep(3 * time.Second) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) } // go test -run Test_Limiter_Fixed_Window_No_Skip_Choices -v func Test_Limiter_Fixed_Window_No_Skip_Choices(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Max: 2, Expiration: 2 * time.Second, SkipFailedRequests: false, SkipSuccessfulRequests: false, LimiterMiddleware: FixedWindow{}, })) app.Get("/:status", func(c fiber.Ctx) error { if c.Params("status") == "fail" { return c.SendStatus(400) } return c.SendStatus(200) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody)) require.NoError(t, err) require.Equal(t, 400, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody)) require.NoError(t, err) require.Equal(t, 429, resp.StatusCode) time.Sleep(3 * time.Second) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) } // 
go test -run Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices -v
func Test_Limiter_Fixed_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// With neither skip option set, failed and successful requests both count.
	app.Use(New(Config{
		Max:                    2,
		Expiration:             2 * time.Second,
		SkipFailedRequests:     false,
		SkipSuccessfulRequests: false,
		Storage:                memory.New(),
		LimiterMiddleware:      FixedWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	time.Sleep(3 * time.Second)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_Sliding_Window_No_Skip_Choices -v
func Test_Limiter_Sliding_Window_No_Skip_Choices(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// With neither skip option set, failed and successful requests both count.
	app.Use(New(Config{
		Max:                    2,
		Expiration:             2 * time.Second,
		SkipFailedRequests:     false,
		SkipSuccessfulRequests: false,
		LimiterMiddleware:      SlidingWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	sleepForRetryAfter(t, resp)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices -v
func Test_Limiter_Sliding_Window_Custom_Storage_No_Skip_Choices(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// With neither skip option set, failed and successful requests both count.
	app.Use(New(Config{
		Max:                    2,
		Expiration:             2 * time.Second,
		SkipFailedRequests:     false,
		SkipSuccessfulRequests: false,
		Storage:                memory.New(),
		LimiterMiddleware:      SlidingWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	sleepForRetryAfter(t, resp)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}

func Test_Limiter_Sliding_Window_RecalculatesAfterHandlerDelay(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		Max:               2,
		Expiration:        time.Second,
		LimiterMiddleware: SlidingWindow{},
	}))

	// The handler is slow enough that the window advances while requests are
	// still in flight, so the limiter's post-handler accounting must use the
	// current time rather than the time the request arrived.
	app.Get("/", func(c fiber.Ctx) error {
		time.Sleep(600 * time.Millisecond)
		return c.SendStatus(fiber.StatusOK)
	})

	for range 2 {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
		require.NoError(t, err)
		require.Equal(t, fiber.StatusOK, resp.StatusCode)
	}

	time.Sleep(time.Second + 100*time.Millisecond)

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, "2", resp.Header.Get(xRateLimitLimit))
	require.Equal(t, "1", resp.Header.Get(xRateLimitRemaining))
	require.NotEmpty(t, resp.Header.Get(xRateLimitReset))
}

func Test_Limiter_Sliding_Window_ExpiresStalePrevHits(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		Max:               1,
		Expiration:        time.Second,
		LimiterMiddleware: SlidingWindow{},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	// After more than two full windows, the first hit must no longer weigh
	// on the new window.
	time.Sleep(2500 * time.Millisecond)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
	require.Equal(t, "0", resp.Header.Get(xRateLimitRemaining))
}

func Test_Limiter_Sliding_Window_SkipFailedRequests_DecrementsPreviousWindow(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		Max:                2,
		Expiration:         200 * time.Millisecond,
		SkipFailedRequests: true,
		LimiterMiddleware:  SlidingWindow{},
	}))

	// The failing handler outlives its window (300ms > 200ms), so the skipped
	// hit has to be removed from what is by then the previous window.
	app.Get("/:mode", func(c fiber.Ctx) error {
		if c.Params("mode") == "fail" {
			time.Sleep(300 * time.Millisecond)
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	type respErr struct {
		resp *http.Response
		err  error
	}

	failCh := make(chan respErr, 1)
	go func() {
		resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
		failCh <- respErr{resp: resp, err: err}
	}()

	time.Sleep(220 * time.Millisecond)

	successResp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/ok", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, successResp.StatusCode)

	result := <-failCh
	require.NoError(t, result.err)
	require.Equal(t, fiber.StatusInternalServerError, result.resp.StatusCode)
	require.Equal(t, "2", result.resp.Header.Get(xRateLimitLimit))
	require.Equal(t, "1", result.resp.Header.Get(xRateLimitRemaining))
	require.NotEmpty(t, result.resp.Header.Get(xRateLimitReset))
}

// go test -run Test_Limiter_Fixed_Window_Skip_Failed_Requests -v
func Test_Limiter_Fixed_Window_Skip_Failed_Requests(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Failed (4xx/5xx) requests are not counted toward Max.
	app.Use(New(Config{
		Max:                1,
		Expiration:         2 * time.Second,
		SkipFailedRequests: true,
		LimiterMiddleware:  FixedWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	time.Sleep(3 * time.Second)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests -v
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Failed (4xx/5xx) requests are not counted toward Max.
	app.Use(New(Config{
		Max:                1,
		Expiration:         2 * time.Second,
		Storage:            memory.New(),
		SkipFailedRequests: true,
		LimiterMiddleware:  FixedWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	time.Sleep(3 * time.Second)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_Sliding_Window_Skip_Failed_Requests -v
func Test_Limiter_Sliding_Window_Skip_Failed_Requests(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Failed (4xx/5xx) requests are not counted toward Max.
	app.Use(New(Config{
		Max:                1,
		Expiration:         2 * time.Second,
		SkipFailedRequests: true,
		LimiterMiddleware:  SlidingWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	sleepForRetryAfter(t, resp)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests -v
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Failed_Requests(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Failed (4xx/5xx) requests are not counted toward Max.
	app.Use(New(Config{
		Max:                1,
		Expiration:         2 * time.Second,
		Storage:            memory.New(),
		SkipFailedRequests: true,
		LimiterMiddleware:  SlidingWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	sleepForRetryAfter(t, resp)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
}

// go test -run Test_Limiter_Fixed_Window_Skip_Successful_Requests -v
func Test_Limiter_Fixed_Window_Skip_Successful_Requests(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Successful requests are not counted toward Max; failed ones are.
	app.Use(New(Config{
		Max:                    1,
		Expiration:             2 * time.Second,
		SkipSuccessfulRequests: true,
		LimiterMiddleware:      FixedWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	time.Sleep(3 * time.Second)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)
}

// go test -run Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests -v
func Test_Limiter_Fixed_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Successful requests are not counted toward Max; failed ones are.
	app.Use(New(Config{
		Max:                    1,
		Expiration:             2 * time.Second,
		Storage:                memory.New(),
		SkipSuccessfulRequests: true,
		LimiterMiddleware:      FixedWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	time.Sleep(3 * time.Second)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)
}

// go test -run Test_Limiter_Sliding_Window_Skip_Successful_Requests -v
func Test_Limiter_Sliding_Window_Skip_Successful_Requests(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Successful requests are not counted toward Max; failed ones are.
	app.Use(New(Config{
		Max:                    1,
		Expiration:             2 * time.Second,
		SkipSuccessfulRequests: true,
		LimiterMiddleware:      SlidingWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	sleepForRetryAfter(t, resp)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)
}

// go test -run Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests -v
func Test_Limiter_Sliding_Window_Custom_Storage_Skip_Successful_Requests(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Successful requests are not counted toward Max; failed ones are.
	app.Use(New(Config{
		Max:                    1,
		Expiration:             2 * time.Second,
		Storage:                memory.New(),
		SkipSuccessfulRequests: true,
		LimiterMiddleware:      SlidingWindow{},
	}))

	app.Get("/:status", func(c fiber.Ctx) error {
		if c.Params("status") == "fail" {
			return c.SendStatus(400)
		}
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/success", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 429, resp.StatusCode)

	sleepForRetryAfter(t, resp)

	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/fail", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 400, resp.StatusCode)
}

// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
func Benchmark_Limiter_Custom_Store(b *testing.B) {
	app := fiber.New()

	app.Use(New(Config{
		Max:        100,
		Expiration: 60 * time.Second,
		Storage:    memory.New(),
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello, World!")
	})

	h := app.Handler()

	// A single request context is reused across iterations.
	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod(fiber.MethodGet)
	fctx.Request.SetRequestURI("/")

	for b.Loop() {
		h(fctx)
	}
}

// Regression test: errors returned via fiber.NewErrorf must be counted as
// failed requests when SkipSuccessfulRequests is enabled (SlidingWindow).
func Test_Limiter_Bug_NewErrorf_SkipSuccessfulRequests_SlidingWindow(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		Max:                    1,
		Expiration:             60 * time.Second,
		LimiterMiddleware:      SlidingWindow{},
		SkipSuccessfulRequests: true,
		SkipFailedRequests:     false,
		DisableHeaders:         true,
	}))

	app.Get("/", func(_ fiber.Ctx) error {
		return fiber.NewErrorf(fiber.StatusInternalServerError, "Error")
	})

	// The first request is allowed through and fails with 500. Because only
	// successful requests are skipped, this failure counts toward Max.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)

	// The second request must be rate limited because the first failed request
	// was counted. Under the bug, the NewErrorf response was not treated as a
	// failure, so this request would have passed through with 500.
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode, "Second request should be rate limited")
}

// Regression test: errors returned via fiber.NewErrorf must be counted as
// failed requests when SkipSuccessfulRequests is enabled (FixedWindow).
func Test_Limiter_Bug_NewErrorf_SkipSuccessfulRequests_FixedWindow(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		Max:                    1,
		Expiration:             60 * time.Second,
		LimiterMiddleware:      FixedWindow{},
		SkipSuccessfulRequests: true,
		SkipFailedRequests:     false,
		DisableHeaders:         true,
	}))

	app.Get("/", func(_ fiber.Ctx) error {
		return fiber.NewErrorf(fiber.StatusInternalServerError, "Error")
	})

	// The first request is allowed through and fails with 500. Because only
	// successful requests are skipped, this failure counts toward Max.
	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode)

	// The second request must be rate limited because the first failed request
	// was counted. Under the bug, the NewErrorf response was not treated as a
	// failure, so this request would have passed through with 500.
	resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusTooManyRequests, resp.StatusCode, "Second request should be rate limited")
}

// go test -run Test_Limiter_Next
func Test_Limiter_Next(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	// Next always returns true, so the limiter is bypassed and the unrouted
	// path falls through to 404.
	app.Use(New(Config{
		Next: func(_ fiber.Ctx) bool {
			return true
		},
	}))

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusNotFound, resp.StatusCode)
}

func Test_Limiter_Headers(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	app.Use(New(Config{
		Max:        50,
		Expiration: 2 * time.Second,
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello tester!")
	})

	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod(fiber.MethodGet)
	fctx.Request.SetRequestURI("/")

	app.Handler()(fctx)

	require.Equal(t, "50", string(fctx.Response.Header.Peek("X-RateLimit-Limit")))
	require.NotEmpty(t, string(fctx.Response.Header.Peek("X-RateLimit-Remaining")),
		"The X-RateLimit-Remaining header is not set correctly - value is an empty string.")
	// The reset is reported in whole seconds and must fall within the 2s window.
	require.Contains(t, []string{"1", "2"}, string(fctx.Response.Header.Peek("X-RateLimit-Reset")),
		"The X-RateLimit-Reset header is not set correctly - value is out of bounds.")
}

func Test_Limiter_Disable_Headers(t *testing.T) {
	t.Parallel()

	app := fiber.New()

	app.Use(New(Config{
		Max:            1,
		Expiration:     2 * time.Second,
		DisableHeaders: true,
	}))

	app.Get("/", func(c fiber.Ctx) error {
		return c.SendString("Hello tester!")
	})

	// first request should pass
	fctx := &fasthttp.RequestCtx{}
	fctx.Request.Header.SetMethod(fiber.MethodGet)
	fctx.Request.SetRequestURI("/")

	app.Handler()(fctx)

	require.Equal(t, fiber.StatusOK, fctx.Response.StatusCode())
	require.Equal(t, "Hello tester!", string(fctx.Response.Body()))
	require.Empty(t, string(fctx.Response.Header.Peek("X-RateLimit-Limit")))
	require.Empty(t, string(fctx.Response.Header.Peek("X-RateLimit-Remaining")))
	require.Empty(t, string(fctx.Response.Header.Peek("X-RateLimit-Reset")))

	// second request should hit the limit and return 429 without headers
	fctx2 := &fasthttp.RequestCtx{}
	fctx2.Request.Header.SetMethod(fiber.MethodGet)
	fctx2.Request.SetRequestURI("/")

	app.Handler()(fctx2)

	require.Equal(t, fiber.StatusTooManyRequests, fctx2.Response.StatusCode())
	require.Empty(t, string(fctx2.Response.Header.Peek(fiber.HeaderRetryAfter)))
	require.Empty(t, string(fctx2.Response.Header.Peek("X-RateLimit-Limit")))
	require.Empty(t, string(fctx2.Response.Header.Peek("X-RateLimit-Remaining")))
	require.Empty(t, string(fctx2.Response.Header.Peek("X-RateLimit-Reset")))
}

// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
func Benchmark_Limiter(b *testing.B) { app := fiber.New() app.Use(New(Config{ Max: 100, Expiration: 60 * time.Second, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.SetRequestURI("/") for b.Loop() { h(fctx) } } // go test -run Test_Sliding_Window -race -v func Test_Sliding_Window(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Max: 10, Expiration: 1 * time.Second, Storage: memory.New(), LimiterMiddleware: SlidingWindow{}, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello tester!") }) singleRequest := func(shouldFail bool) { resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) if shouldFail { require.NoError(t, err) require.Equal(t, 429, resp.StatusCode) } else { require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } } for range 5 { singleRequest(false) } time.Sleep(3 * time.Second) for range 5 { singleRequest(false) } time.Sleep(3 * time.Second) for range 5 { singleRequest(false) } time.Sleep(3 * time.Second) for range 10 { singleRequest(false) } // requests should fail now for range 5 { singleRequest(true) } } ================================================ FILE: middleware/limiter/manager.go ================================================ package limiter import ( "context" "fmt" "sync" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/internal/memory" ) // msgp -file="manager.go" -o="manager_msgp.go" -tests=false -unexported // //go:generate msgp -o=manager_msgp.go -tests=false -unexported type item struct { currHits int prevHits int exp uint64 } //msgp:ignore manager type manager struct { pool sync.Pool memory *memory.Storage storage fiber.Storage redactKeys bool } const redactedKey = "[redacted]" func newManager(storage fiber.Storage, redactKeys bool) *manager { // Create new storage handler manager := &manager{ 
pool: sync.Pool{ New: func() any { return new(item) }, }, redactKeys: redactKeys, } if storage != nil { // Use provided storage if provided manager.storage = storage } else { // Fallback too memory storage manager.memory = memory.New() } return manager } // acquire returns an *entry from the sync.Pool func (m *manager) acquire() *item { return m.pool.Get().(*item) //nolint:forcetypeassert,errcheck // We store nothing else in the pool } // release and reset *entry to sync.Pool func (m *manager) release(e *item) { e.prevHits = 0 e.currHits = 0 e.exp = 0 m.pool.Put(e) } // get data from storage or memory func (m *manager) get(ctx context.Context, key string) (*item, error) { if m.storage != nil { raw, err := m.storage.GetWithContext(ctx, key) if err != nil { return nil, fmt.Errorf("limiter: failed to get key %q from storage: %w", m.logKey(key), err) } if raw != nil { it := m.acquire() if _, err := it.UnmarshalMsg(raw); err != nil { m.release(it) return nil, fmt.Errorf("limiter: failed to unmarshal key %q: %w", m.logKey(key), err) } return it, nil } return m.acquire(), nil } value := m.memory.Get(key) if value == nil { return m.acquire(), nil } it, ok := value.(*item) if !ok { return nil, fmt.Errorf("limiter: unexpected entry type %T for key %q", value, m.logKey(key)) } return it, nil } // set data to storage or memory func (m *manager) set(ctx context.Context, key string, it *item, exp time.Duration) error { if m.storage != nil { raw, err := it.MarshalMsg(nil) if err != nil { m.release(it) return fmt.Errorf("limiter: failed to marshal key %q: %w", m.logKey(key), err) } if err := m.storage.SetWithContext(ctx, key, raw, exp); err != nil { m.release(it) return fmt.Errorf("limiter: failed to store key %q: %w", m.logKey(key), err) } m.release(it) return nil } m.memory.Set(key, it, exp) return nil } func (m *manager) logKey(key string) string { if m.redactKeys { return redactedKey } return key } ================================================ FILE: 
middleware/limiter/manager_msgp.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.

package limiter

import (
	"github.com/tinylib/msgp/msgp"
)

// DecodeMsg implements msgp.Decodable
func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, err = dc.ReadMapHeader()
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, err = dc.ReadMapKeyPtr()
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		case "currHits":
			z.currHits, err = dc.ReadInt()
			if err != nil {
				err = msgp.WrapError(err, "currHits")
				return
			}
		case "prevHits":
			z.prevHits, err = dc.ReadInt()
			if err != nil {
				err = msgp.WrapError(err, "prevHits")
				return
			}
		case "exp":
			z.exp, err = dc.ReadUint64()
			if err != nil {
				err = msgp.WrapError(err, "exp")
				return
			}
		default:
			err = dc.Skip()
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	return
}

// EncodeMsg implements msgp.Encodable
func (z item) EncodeMsg(en *msgp.Writer) (err error) {
	// map header, size 3
	// write "currHits"
	err = en.Append(0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
	if err != nil {
		return
	}
	err = en.WriteInt(z.currHits)
	if err != nil {
		err = msgp.WrapError(err, "currHits")
		return
	}
	// write "prevHits"
	err = en.Append(0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
	if err != nil {
		return
	}
	err = en.WriteInt(z.prevHits)
	if err != nil {
		err = msgp.WrapError(err, "prevHits")
		return
	}
	// write "exp"
	err = en.Append(0xa3, 0x65, 0x78, 0x70)
	if err != nil {
		return
	}
	err = en.WriteUint64(z.exp)
	if err != nil {
		err = msgp.WrapError(err, "exp")
		return
	}
	return
}

// MarshalMsg implements msgp.Marshaler
func (z item) MarshalMsg(b []byte) (o []byte, err error) {
	o = msgp.Require(b, z.Msgsize())
	// map header, size 3
	// string "currHits"
	o = append(o, 0x83, 0xa8, 0x63, 0x75, 0x72, 0x72, 0x48, 0x69, 0x74, 0x73)
	o = msgp.AppendInt(o, z.currHits)
	// string "prevHits"
	o = append(o, 0xa8, 0x70, 0x72, 0x65, 0x76, 0x48, 0x69, 0x74, 0x73)
	o = msgp.AppendInt(o, z.prevHits)
	// string "exp"
	o = append(o, 0xa3, 0x65, 0x78, 0x70)
	o = msgp.AppendUint64(o, z.exp)
	return
}

// UnmarshalMsg implements msgp.Unmarshaler
func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, bts, err = msgp.ReadMapKeyZC(bts)
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		case "currHits":
			z.currHits, bts, err = msgp.ReadIntBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "currHits")
				return
			}
		case "prevHits":
			z.prevHits, bts, err = msgp.ReadIntBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "prevHits")
				return
			}
		case "exp":
			z.exp, bts, err = msgp.ReadUint64Bytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "exp")
				return
			}
		default:
			bts, err = msgp.Skip(bts)
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	o = bts
	return
}

// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z item) Msgsize() (s int) {
	s = 1 + 9 + msgp.IntSize + 9 + msgp.IntSize + 4 + msgp.Uint64Size
	return
}

================================================
FILE: middleware/limiter/manager_msgp_test.go
================================================
package limiter

// Code generated by github.com/tinylib/msgp DO NOT EDIT.
import (
	"bytes"
	"testing"

	"github.com/tinylib/msgp/msgp"
)

func TestMarshalUnmarshalitem(t *testing.T) {
	v := item{}
	bts, err := v.MarshalMsg(nil)
	if err != nil {
		t.Fatal(err)
	}
	left, err := v.UnmarshalMsg(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
	}

	left, err = msgp.Skip(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
	}
}

func BenchmarkMarshalMsgitem(b *testing.B) {
	v := item{}
	b.ReportAllocs()
	for b.Loop() {
		v.MarshalMsg(nil)
	}
}

func BenchmarkAppendMsgitem(b *testing.B) {
	v := item{}
	bts := make([]byte, 0, v.Msgsize())
	bts, _ = v.MarshalMsg(bts[0:0])
	b.SetBytes(int64(len(bts)))
	b.ReportAllocs()
	for b.Loop() {
		bts, _ = v.MarshalMsg(bts[0:0])
	}
}

func BenchmarkUnmarshalitem(b *testing.B) {
	v := item{}
	bts, _ := v.MarshalMsg(nil)
	b.ReportAllocs()
	b.SetBytes(int64(len(bts)))
	for b.Loop() {
		_, err := v.UnmarshalMsg(bts)
		if err != nil {
			b.Fatal(err)
		}
	}
}

func TestEncodeDecodeitem(t *testing.T) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)

	m := v.Msgsize()
	if buf.Len() > m {
		t.Log("WARNING: TestEncodeDecodeitem Msgsize() is inaccurate")
	}

	vn := item{}
	err := msgp.Decode(&buf, &vn)
	if err != nil {
		t.Error(err)
	}

	buf.Reset()
	msgp.Encode(&buf, &v)
	err = msgp.NewReader(&buf).Skip()
	if err != nil {
		t.Error(err)
	}
}

func BenchmarkEncodeitem(b *testing.B) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	en := msgp.NewWriter(msgp.Nowhere)
	b.ReportAllocs()
	for b.Loop() {
		v.EncodeMsg(en)
	}
	en.Flush()
}

func BenchmarkDecodeitem(b *testing.B) {
	v := item{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	rd := msgp.NewEndlessReader(buf.Bytes(), b)
	dc := msgp.NewReader(rd)
	b.ReportAllocs()
	for b.Loop() {
		err := v.DecodeMsg(dc)
		if err != nil {
			b.Fatal(err)
		}
	}
}

================================================
FILE: middleware/logger/config.go
================================================ package logger import ( "io" "os" "time" "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Stream is a writer where logs are written // // Default: os.Stdout Stream io.Writer // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // Skip is a function to determine if logging is skipped or written to Stream. // // Optional. Default: nil Skip func(c fiber.Ctx) bool // Done is a function that is called after the log string for a request is written to Output, // and pass the log string as parameter. // // Optional. Default: nil Done func(c fiber.Ctx, logString []byte) // tagFunctions defines the custom tag action // // Optional. Default: map[string]LogFunc CustomTags map[string]LogFunc // You can define specific things before returning the handler: colors, template, etc. // // Optional. Default: beforeHandlerFunc BeforeHandlerFunc func(*Config) // You can use custom loggers with Fiber by using this field. // This field is really useful if you're using Zerolog, Zap, Logrus, apex/log etc. // If you don't define anything for this field, it'll use default logger of Fiber. // // Optional. Default: defaultLogger LoggerFunc func(c fiber.Ctx, data *Data, cfg *Config) error timeZoneLocation *time.Location // Format defines the logging format for the middleware. // // You can customize the log output by defining a format string with placeholders // such as: ${time}, ${ip}, ${status}, ${method}, ${path}, ${latency}, ${error}, etc. // The full list of available placeholders can be found in 'tags.go' or at // 'https://docs.gofiber.io/api/middleware/logger/#constants'. 
// // Fiber provides predefined logging formats that can be used directly: // // - DefaultFormat → Uses the default log format: "[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}" // - CommonFormat → Uses the Apache Common Log Format (CLF): "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent}\n" // - CombinedFormat → Uses the Apache Combined Log Format: "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent} \"${referer}\" \"${ua}\"\n" // - JSONFormat → Uses the JSON log format: "{\"time\":\"${time}\",\"ip\":\"${ip}\",\"method\":\"${method}\",\"url\":\"${url}\",\"status\":${status},\"bytesSent\":${bytesSent}}\n" // - ECSFormat → Uses the Elastic Common Schema (ECS) log format: {\"@timestamp\":\"${time}\",\"ecs\":{\"version\":\"1.6.0\"},\"client\":{\"ip\":\"${ip}\"},\"http\":{\"request\":{\"method\":\"${method}\",\"url\":\"${url}\",\"protocol\":\"${protocol}\"},\"response\":{\"status_code\":${status},\"body\":{\"bytes\":${bytesSent}}}},\"log\":{\"level\":\"INFO\",\"logger\":\"fiber\"},\"message\":\"${method} ${url} responded with ${status}\"}" // If both `Format` and `CustomFormat` are provided, the `CustomFormat` will be used, and the `Format` field will be ignored. // If no format is specified, the default format is used: // "[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}" Format string // TimeFormat https://programming.guide/go/format-parse-string-time-date-example.html // // Optional. Default: 15:04:05 TimeFormat string // TimeZone can be specified, such as "UTC" and "America/New_York" and "Asia/Chongqing", etc // // Optional. Default: "Local" TimeZone string // TimeInterval is the delay before the timestamp is updated // // Optional. 
Default: 500 * time.Millisecond TimeInterval time.Duration // DisableColors defines if the logs output should be colorized // // Default: false DisableColors bool // ForceColors forces the colors to be enabled even if the output is not a terminal // // Default: false ForceColors bool enableColors bool enableLatency bool } const ( startTag = "${" endTag = "}" paramSeparator = ":" ) // Buffer abstracts the buffer operations used when rendering log entries. type Buffer interface { Len() int ReadFrom(r io.Reader) (int64, error) WriteTo(w io.Writer) (int64, error) Bytes() []byte Write(p []byte) (int, error) WriteByte(c byte) error WriteString(s string) (int, error) Set(p []byte) SetString(s string) String() string } // LogFunc formats logging output using the provided buffer and request data. type LogFunc func(output Buffer, c fiber.Ctx, data *Data, extraParam string) (int, error) // ConfigDefault is the default config var ConfigDefault = Config{ Next: nil, Skip: nil, Done: nil, Format: DefaultFormat, TimeFormat: "15:04:05", TimeZone: "Local", TimeInterval: 500 * time.Millisecond, Stream: os.Stdout, BeforeHandlerFunc: beforeHandlerFunc, LoggerFunc: defaultLoggerInstance, enableColors: true, } // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.Next == nil { cfg.Next = ConfigDefault.Next } if cfg.Skip == nil { cfg.Skip = ConfigDefault.Skip } if cfg.Done == nil { cfg.Done = ConfigDefault.Done } if cfg.Format == "" { cfg.Format = ConfigDefault.Format } if cfg.TimeZone == "" { cfg.TimeZone = ConfigDefault.TimeZone } if cfg.TimeFormat == "" { cfg.TimeFormat = ConfigDefault.TimeFormat } if int(cfg.TimeInterval) <= 0 { cfg.TimeInterval = ConfigDefault.TimeInterval } if cfg.Stream == nil { cfg.Stream = ConfigDefault.Stream } if cfg.BeforeHandlerFunc == nil { cfg.BeforeHandlerFunc = 
ConfigDefault.BeforeHandlerFunc } if cfg.LoggerFunc == nil { cfg.LoggerFunc = ConfigDefault.LoggerFunc } // Enable colors if no custom format or output is given if (!cfg.DisableColors && cfg.Stream == ConfigDefault.Stream) || cfg.ForceColors { cfg.enableColors = true } return cfg } ================================================ FILE: middleware/logger/data.go ================================================ package logger import ( "sync/atomic" "time" ) // Data is a struct to define some variables to use in custom logger function. type Data struct { Start time.Time Stop time.Time ChainErr error Timestamp atomic.Value Pid string ErrPaddingStr string TemplateChain [][]byte LogFuncChain []LogFunc } ================================================ FILE: middleware/logger/default_logger.go ================================================ package logger import ( "fmt" "io" "os" "strconv" "github.com/gofiber/fiber/v3" "github.com/gofiber/utils/v2" "github.com/mattn/go-colorable" "github.com/mattn/go-isatty" "github.com/valyala/bytebufferpool" ) // default logger for fiber func defaultLoggerInstance(c fiber.Ctx, data *Data, cfg *Config) error { if cfg == nil { cfg = &Config{ Stream: os.Stdout, Format: DefaultFormat, enableColors: true, } } // Check if Skip is defined and call it. 
// Now, if Skip(c) == true, we SKIP logging: if cfg.Skip != nil && cfg.Skip(c) { return nil // Skip logging if Skip returns true } // Alias colors colors := c.App().Config().ColorScheme // Get new buffer buf := bytebufferpool.Get() // Default output when no custom Format or io.Writer is given if cfg.Format == DefaultFormat { // Format error if exist formatErr := "" if cfg.enableColors { if data.ChainErr != nil { formatErr = colors.Red + " | " + data.ChainErr.Error() + colors.Reset } fmt.Fprintf(buf, "%s |%s %3d %s| %13v | %15s |%s %-7s %s| %-"+data.ErrPaddingStr+"s %s\n", data.Timestamp.Load().(string), //nolint:forcetypeassert,errcheck // Timestamp is always a string statusColor(c.Response().StatusCode(), &colors), c.Response().StatusCode(), colors.Reset, data.Stop.Sub(data.Start), c.IP(), methodColor(c.Method(), &colors), c.Method(), colors.Reset, c.Path(), formatErr, ) } else { if data.ChainErr != nil { formatErr = " | " + data.ChainErr.Error() } // Helper function to append fixed-width string with padding fixedWidth := func(s string, width int, rightAlign bool) { if rightAlign { for i := len(s); i < width; i++ { buf.WriteByte(' ') } buf.WriteString(s) } else { buf.WriteString(s) for i := len(s); i < width; i++ { buf.WriteByte(' ') } } } // Timestamp buf.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert,errcheck // Timestamp is always a string buf.WriteString(" | ") // Status Code with 3 fixed width, right aligned fixedWidth(strconv.Itoa(c.Response().StatusCode()), 3, true) buf.WriteString(" | ") // Duration with 13 fixed width, right aligned fixedWidth(data.Stop.Sub(data.Start).String(), 13, true) buf.WriteString(" | ") // Client IP with 15 fixed width, right aligned fixedWidth(c.IP(), 15, true) buf.WriteString(" | ") // HTTP Method with 7 fixed width, left aligned fixedWidth(c.Method(), 7, false) buf.WriteString(" | ") // Path with dynamic padding for error message, left aligned errPadding, _ := strconv.Atoi(data.ErrPaddingStr) 
//nolint:errcheck // It is fine to ignore the error fixedWidth(c.Path(), errPadding, false) // Error message buf.WriteString(" ") buf.WriteString(formatErr) buf.WriteString("\n") } // Write buffer to output writeLog(cfg.Stream, buf.Bytes()) if cfg.Done != nil { cfg.Done(c, buf.Bytes()) } // Put buffer back to pool bytebufferpool.Put(buf) // End chain return nil } var err error // Loop over template parts execute dynamic parts and add fixed parts to the buffer for i, logFunc := range data.LogFuncChain { switch { case logFunc == nil: buf.Write(data.TemplateChain[i]) case data.TemplateChain[i] == nil: _, err = logFunc(buf, c, data, "") default: _, err = logFunc(buf, c, data, utils.UnsafeString(data.TemplateChain[i])) } if err != nil { break } } // Also write errors to the buffer if err != nil { buf.WriteString(err.Error()) } writeLog(cfg.Stream, buf.Bytes()) if cfg.Done != nil { cfg.Done(c, buf.Bytes()) } // Put buffer back to pool bytebufferpool.Put(buf) return nil } // run something before returning the handler func beforeHandlerFunc(cfg *Config) { if cfg == nil { return } // If colors are enabled, check terminal compatibility if cfg.enableColors && cfg.Stream == os.Stdout { cfg.Stream = colorable.NewColorableStdout() if !cfg.ForceColors && (os.Getenv("TERM") == "dumb" || os.Getenv("NO_COLOR") == "1" || (!isatty.IsTerminal(os.Stdout.Fd()) && !isatty.IsCygwinTerminal(os.Stdout.Fd()))) { cfg.Stream = colorable.NewNonColorable(os.Stdout) } } } func appendInt(output Buffer, v int) (int, error) { old := output.Len() output.Set(strconv.AppendInt(output.Bytes(), int64(v), 10)) return output.Len() - old, nil } // writeLog writes a msg to w, printing a warning to stderr if the log fails. 
func writeLog(w io.Writer, msg []byte) { if _, err := w.Write(msg); err != nil { // Write error to output if _, writeErr := w.Write([]byte(err.Error())); writeErr != nil { // There is something wrong with the given io.Writer _, _ = fmt.Fprintf(os.Stderr, "Failed to write to log, %v\n", writeErr) } } } ================================================ FILE: middleware/logger/errors.go ================================================ package logger import ( "errors" ) // ErrTemplateParameterMissing indicates that a template parameter was referenced but not provided. var ErrTemplateParameterMissing = errors.New("logger: template parameter missing") ================================================ FILE: middleware/logger/format.go ================================================ package logger const ( // Fiber's default logger DefaultFormat = "[${time}] ${ip} ${status} - ${latency} ${method} ${path} ${error}\n" // Apache Common Log Format (CLF) CommonFormat = "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent}\n" // Apache Combined Log Format CombinedFormat = "${ip} - - [${time}] \"${method} ${url} ${protocol}\" ${status} ${bytesSent} \"${referer}\" \"${ua}\"\n" // JSON log formats JSONFormat = "{\"time\":\"${time}\",\"ip\":\"${ip}\",\"method\":\"${method}\",\"url\":\"${url}\",\"status\":${status},\"bytesSent\":${bytesSent}}\n" // Elastic Common Schema (ECS) Log Format ECSFormat = "{\"@timestamp\":\"${time}\",\"ecs\":{\"version\":\"1.6.0\"},\"client\":{\"ip\":\"${ip}\"},\"http\":{\"request\":{\"method\":\"${method}\",\"url\":\"${url}\",\"protocol\":\"${protocol}\"},\"response\":{\"status_code\":${status},\"body\":{\"bytes\":${bytesSent}}}},\"log\":{\"level\":\"INFO\",\"logger\":\"fiber\"},\"message\":\"${method} ${url} responded with ${status}\"}\n" ) ================================================ FILE: middleware/logger/logger.go ================================================ package logger import ( "os" "strconv" "strings" "sync" 
"sync/atomic" "time" "github.com/gofiber/fiber/v3" ) // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) // Get timezone location tz, err := time.LoadLocation(cfg.TimeZone) if err != nil || tz == nil { cfg.timeZoneLocation = time.Local } else { cfg.timeZoneLocation = tz } // Check if format contains latency cfg.enableLatency = strings.Contains(cfg.Format, "${"+TagLatency+"}") var timestamp atomic.Value // Create correct timeformat timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat)) // Update date/time every 500 milliseconds in a separate go routine if strings.Contains(cfg.Format, "${"+TagTime+"}") { go func() { for { time.Sleep(cfg.TimeInterval) timestamp.Store(time.Now().In(cfg.timeZoneLocation).Format(cfg.TimeFormat)) } }() } // Set PID once pid := strconv.Itoa(os.Getpid()) // Set variables var ( once sync.Once errHandler fiber.ErrorHandler dataPool = sync.Pool{New: func() any { return new(Data) }} ) // Err padding errPadding := 15 errPaddingStr := strconv.Itoa(errPadding) // Before handling func cfg.BeforeHandlerFunc(&cfg) // Logger data // instead of analyzing the template inside(handler) each time, this is done once before // and we create several slices of the same length with the functions to be executed and fixed parts. 
templateChain, logFunChain, err := buildLogFuncChain(&cfg, createTagMap(&cfg)) if err != nil { panic(err) } // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Set error handler once once.Do(func() { // get longest possible path stack := c.App().Stack() for m := range stack { for r := range stack[m] { if len(stack[m][r].Path) > errPadding { errPadding = len(stack[m][r].Path) errPaddingStr = strconv.Itoa(errPadding) } } } // override error handler errHandler = c.App().ErrorHandler }) // Logger data data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool // no need for a reset, as long as we always override everything data.Pid = pid data.ErrPaddingStr = errPaddingStr data.Timestamp = timestamp data.TemplateChain = templateChain data.LogFuncChain = logFunChain // put data back in the pool defer dataPool.Put(data) // Set latency start time if cfg.enableLatency { data.Start = time.Now() } // Handle request, store err for logging chainErr := c.Next() data.ChainErr = chainErr // Manually call error handler if chainErr != nil { if err := errHandler(c, chainErr); err != nil { _ = c.SendStatus(fiber.StatusInternalServerError) //nolint:errcheck // TODO: Explain why we ignore the error here } } // Set latency stop time if cfg.enableLatency { data.Stop = time.Now() } // Logger instance & update some logger data fields return cfg.LoggerFunc(c, data, &cfg) } } ================================================ FILE: middleware/logger/logger_test.go ================================================ //nolint:depguard // Because we test logging :D package logger import ( "bufio" "bytes" "errors" "fmt" "io" "log" "net/http" "net/http/httptest" "os" "regexp" "runtime" "strconv" "sync" "testing" "time" "github.com/stretchr/testify/require" "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3" 
fiberlog "github.com/gofiber/fiber/v3/log" "github.com/gofiber/fiber/v3/middleware/requestid" ) const ( pathFooBar = "/?foo=bar" httpProto = "HTTP/1.1" ) func benchmarkSetup(b *testing.B, app *fiber.App, uri string) { b.Helper() h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.SetRequestURI(uri) b.ReportAllocs() for b.Loop() { h(fctx) } } func benchmarkSetupParallel(b *testing.B, app *fiber.App, path string) { b.Helper() handler := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(fiber.MethodGet) fctx.Request.SetRequestURI(path) for pb.Next() { handler(fctx) } }) } // go test -run Test_Logger func Test_Logger(t *testing.T) { t.Parallel() app := fiber.New() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app.Use(New(Config{ Format: "${error}", Stream: buf, })) app.Get("/", func(_ fiber.Ctx) error { return errors.New("some random error") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) require.Equal(t, "some random error", buf.String()) } // go test -run Test_Logger_locals func Test_Logger_locals(t *testing.T) { t.Parallel() app := fiber.New() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app.Use(New(Config{ Format: "${locals:demo}", Stream: buf, })) app.Get("/", func(c fiber.Ctx) error { c.Locals("demo", "johndoe") return c.SendStatus(fiber.StatusOK) }) app.Get("/int", func(c fiber.Ctx) error { c.Locals("demo", 55) return c.SendStatus(fiber.StatusOK) }) app.Get("/empty", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "johndoe", buf.String()) buf.Reset() resp, err = 
app.Test(httptest.NewRequest(fiber.MethodGet, "/int", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "55", buf.String()) buf.Reset() resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/empty", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Empty(t, buf.String()) } // go test -run Test_Logger_Next func Test_Logger_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Next: func(_ fiber.Ctx) bool { return true }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } // go test -run Test_Logger_Done func Test_Logger_Done(t *testing.T) { t.Parallel() buf := bytes.NewBuffer(nil) app := fiber.New() app.Use(New(Config{ Done: func(c fiber.Ctx, logString []byte) { if c.Response().StatusCode() == fiber.StatusOK { _, err := buf.Write(logString) require.NoError(t, err) } }, })).Get("/logging", func(ctx fiber.Ctx) error { return ctx.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/logging", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Positive(t, buf.Len(), 0) } // Test_Logger_Filter tests the Filter functionality of the logger middleware. // It verifies that logs are written or skipped based on the filter condition. 
func Test_Logger_Filter(t *testing.T) { t.Parallel() t.Run("Test Not Found", func(t *testing.T) { t.Parallel() app := fiber.New() logOutput := bytes.Buffer{} // Return true to skip logging for all requests != 404 app.Use(New(Config{ Skip: func(c fiber.Ctx) bool { return c.Response().StatusCode() != fiber.StatusNotFound }, Stream: &logOutput, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/nonexistent", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) // Expect logs for the 404 request require.Contains(t, logOutput.String(), "404") }) t.Run("Test OK", func(t *testing.T) { t.Parallel() app := fiber.New() logOutput := bytes.Buffer{} // Return true to skip logging for all requests == 200 app.Use(New(Config{ Skip: func(c fiber.Ctx) bool { return c.Response().StatusCode() == fiber.StatusOK }, Stream: &logOutput, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) // We skip logging for status == 200, so "200" should not appear require.NotContains(t, logOutput.String(), "200") }) t.Run("Always Skip", func(t *testing.T) { t.Parallel() app := fiber.New() logOutput := bytes.Buffer{} // Filter always returns true => skip all logs app.Use(New(Config{ Skip: func(_ fiber.Ctx) bool { return true // always skip }, Stream: &logOutput, })) app.Get("/something", func(c fiber.Ctx) error { return c.Status(fiber.StatusTeapot).SendString("I'm a teapot") }) _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/something", http.NoBody)) require.NoError(t, err) // Expect NO logs require.Empty(t, logOutput.String()) }) t.Run("Never Skip", func(t *testing.T) { t.Parallel() app := fiber.New() logOutput := bytes.Buffer{} // Filter always returns false => never skip logs app.Use(New(Config{ Skip: func(_ fiber.Ctx) bool { return false // never skip }, 
Stream: &logOutput, })) app.Get("/always", func(c fiber.Ctx) error { return c.Status(fiber.StatusTeapot).SendString("Teapot again") }) _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/always", http.NoBody)) require.NoError(t, err) // Expect some logging - check any substring require.Contains(t, logOutput.String(), strconv.Itoa(fiber.StatusTeapot)) }) t.Run("Skip /healthz", func(t *testing.T) { t.Parallel() app := fiber.New() logOutput := bytes.Buffer{} // Filter returns true (skip logs) if the request path is /healthz app.Use(New(Config{ Skip: func(c fiber.Ctx) bool { return c.Path() == "/healthz" }, Stream: &logOutput, })) // Normal route app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello World!") }) // Health route app.Get("/healthz", func(c fiber.Ctx) error { return c.SendString("OK") }) // Request to "/" -> should be logged _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Contains(t, logOutput.String(), "200") // Reset output buffer logOutput.Reset() // Request to "/healthz" -> should be skipped _, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/healthz", http.NoBody)) require.NoError(t, err) require.Empty(t, logOutput.String()) }) } // go test -run Test_Logger_ErrorTimeZone func Test_Logger_ErrorTimeZone(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ TimeZone: "invalid", })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } // go test -run Test_Logger_Fiber_Logger func Test_Logger_LoggerToWriter(t *testing.T) { app := fiber.New() buf := bytebufferpool.Get() t.Cleanup(func() { bytebufferpool.Put(buf) }) logger := fiberlog.DefaultLogger[*log.Logger]() stdlogger := logger.Logger() stdlogger.SetFlags(0) logger.SetOutput(buf) testCases := []struct { levelStr string level fiberlog.Level }{ { level: fiberlog.LevelTrace, levelStr: "Trace", }, { level: 
fiberlog.LevelDebug, levelStr: "Debug", }, { level: fiberlog.LevelInfo, levelStr: "Info", }, { level: fiberlog.LevelWarn, levelStr: "Warn", }, { level: fiberlog.LevelError, levelStr: "Error", }, } for _, tc := range testCases { level := strconv.Itoa(int(tc.level)) t.Run(level, func(t *testing.T) { buf.Reset() app.Use("/"+level, New(Config{ Format: "${error}", Stream: LoggerToWriter(logger, tc. level), })) app.Get("/"+level, func(_ fiber.Ctx) error { return errors.New("some random error") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/"+level, http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) require.Equal(t, "["+tc.levelStr+"] some random error\n", buf.String()) }) require.Panics(t, func() { LoggerToWriter(logger, fiberlog.LevelPanic) }) require.Panics(t, func() { LoggerToWriter(logger, fiberlog.LevelFatal) }) require.Panics(t, func() { LoggerToWriter[any](nil, fiberlog.LevelFatal) }) } } type fakeErrorOutput int func (o *fakeErrorOutput) Write([]byte) (int, error) { *o++ return 0, errors.New("fake output") } // go test -run Test_Logger_ErrorOutput_WithoutColor func Test_Logger_ErrorOutput_WithoutColor(t *testing.T) { t.Parallel() o := new(fakeErrorOutput) app := fiber.New() app.Use(New(Config{ Stream: o, DisableColors: true, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) require.EqualValues(t, 2, *o) } // go test -run Test_Logger_ErrorOutput func Test_Logger_ErrorOutput(t *testing.T) { t.Parallel() o := new(fakeErrorOutput) app := fiber.New() app.Use(New(Config{ Stream: o, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) require.EqualValues(t, 2, *o) } // go test -run Test_Logger_All func Test_Logger_All(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer 
bytebufferpool.Put(buf) app := fiber.New() app.Use(New(Config{ Format: "${pid}${reqHeaders}${referer}${scheme}${protocol}${ip}${ips}${host}${url}${ua}${body}${route}${black}${red}${green}${yellow}${blue}${magenta}${cyan}${white}${reset}${error}${reqHeader:test}${query:test}${form:test}${cookie:test}${non}", Stream: buf, })) // Alias colors colors := app.Config().ColorScheme resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, pathFooBar, http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) expected := fmt.Sprintf("%dHost=example.comhttpHTTP/1.10.0.0.0example.com/?foo=bar/%s%s%s%s%s%s%s%s%sNot Found", os.Getpid(), colors.Black, colors.Red, colors.Green, colors.Yellow, colors.Blue, colors.Magenta, colors.Cyan, colors.White, colors.Reset) require.Equal(t, expected, buf.String()) } func Test_Logger_CLF_Format(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := fiber.New() app.Use(New(Config{ Format: CommonFormat, Stream: buf, })) method := fiber.MethodGet status := fiber.StatusNotFound bytesSent := 0 resp, err := app.Test(httptest.NewRequest(method, pathFooBar, http.NoBody)) require.NoError(t, err) require.Equal(t, status, resp.StatusCode) pattern := fmt.Sprintf(`0\.0\.0\.0 - - \[\d{2}:\d{2}:\d{2}\] "%s %s %s" %d %d`, method, regexp.QuoteMeta(pathFooBar), httpProto, status, bytesSent) require.Regexp(t, pattern, buf.String()) } func Test_Logger_Combined_CLF_Format(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := fiber.New() app.Use(New(Config{ Format: CombinedFormat, Stream: buf, })) method := fiber.MethodGet status := fiber.StatusNotFound bytesSent := 0 referer := "http://example.com" ua := "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/74.0.3729.169 Safari/537.36" req := httptest.NewRequest(method, pathFooBar, http.NoBody) req.Header.Set("Referer", referer) req.Header.Set("User-Agent", ua) resp, 
err := app.Test(req) require.NoError(t, err) require.Equal(t, status, resp.StatusCode) pattern := fmt.Sprintf(`0\.0\.0\.0 - - \[\d{2}:\d{2}:\d{2}\] "%s %s %s" %d %d "%s" "%s"`, method, regexp.QuoteMeta(pathFooBar), httpProto, status, bytesSent, regexp.QuoteMeta(referer), regexp.QuoteMeta(ua)) //nolint:gocritic // double quoting for regex and string is not needed require.Regexp(t, pattern, buf.String()) } func Test_Logger_Json_Format(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := fiber.New() app.Use(New(Config{ Format: JSONFormat, Stream: buf, })) method := fiber.MethodGet status := fiber.StatusNotFound ip := "0.0.0.0" bytesSent := 0 req := httptest.NewRequest(method, pathFooBar, http.NoBody) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, status, resp.StatusCode) pattern := fmt.Sprintf(`\{"time":"\d{2}:\d{2}:\d{2}","ip":"%s","method":%q,"url":"%s","status":%d,"bytesSent":%d\}`, regexp.QuoteMeta(ip), method, regexp.QuoteMeta(pathFooBar), status, bytesSent) //nolint:gocritic // double quoting for regex and string is not needed require.Regexp(t, pattern, buf.String()) } func Test_Logger_ECS_Format(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := fiber.New() app.Use(New(Config{ Format: ECSFormat, Stream: buf, })) method := fiber.MethodGet status := fiber.StatusNotFound ip := "0.0.0.0" bytesSent := 0 msg := fmt.Sprintf("%s %s responded with %d", method, pathFooBar, status) req := httptest.NewRequest(method, pathFooBar, http.NoBody) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, status, resp.StatusCode) pattern := fmt.Sprintf(`\{"@timestamp":"\d{2}:\d{2}:\d{2}","ecs":\{"version":"1.6.0"\},"client":\{"ip":"%s"\},"http":\{"request":\{"method":%q,"url":"%s","protocol":%q\},"response":\{"status_code":%d,"body":\{"bytes":%d\}\}\},"log":\{"level":"INFO","logger":"fiber"\},"message":"%s"\}`, regexp.QuoteMeta(ip), method, 
regexp.QuoteMeta(pathFooBar), httpProto, status, bytesSent, regexp.QuoteMeta(msg)) //nolint:gocritic // double quoting for regex and string is not needed require.Regexp(t, pattern, buf.String()) } func getLatencyTimeUnits() []struct { unit string div time.Duration } { // windows does not support µs sleep precision // https://github.com/golang/go/issues/29485 if runtime.GOOS == "windows" { return []struct { unit string div time.Duration }{ {unit: "ms", div: time.Millisecond}, {unit: "s", div: time.Second}, } } return []struct { unit string div time.Duration }{ {unit: "µs", div: time.Microsecond}, {unit: "ms", div: time.Millisecond}, {unit: "s", div: time.Second}, } } // go test -run Test_Logger_WithLatency func Test_Logger_WithLatency(t *testing.T) { buff := bytebufferpool.Get() defer bytebufferpool.Put(buff) app := fiber.New() logger := New(Config{ Stream: buff, Format: "${latency}", }) app.Use(logger) // Define a list of time units to test timeUnits := getLatencyTimeUnits() // Initialize a new time unit sleepDuration := 1 * time.Nanosecond // Define a test route that sleeps app.Get("/test", func(c fiber.Ctx) error { time.Sleep(sleepDuration) return c.SendStatus(fiber.StatusOK) }) // Loop through each time unit and assert that the log output contains the expected latency value for _, tu := range timeUnits { // Update the sleep duration for the next iteration sleepDuration = 1 * tu.div // Create a new HTTP request to the test route resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{ Timeout: 3 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) // Assert that the log output contains the expected latency value in the current time unit require.True(t, bytes.HasSuffix(buff.Bytes(), []byte(tu.unit)), "Expected latency to be in %s, got %s", tu.unit, buff.String()) // Reset the buffer buff.Reset() } } // go test -run Test_Logger_WithLatency_DefaultFormat func 
Test_Logger_WithLatency_DefaultFormat(t *testing.T) { buff := bytebufferpool.Get() defer bytebufferpool.Put(buff) app := fiber.New() logger := New(Config{ Stream: buff, }) app.Use(logger) // Define a list of time units to test timeUnits := getLatencyTimeUnits() // Initialize a new time unit sleepDuration := 1 * time.Nanosecond // Define a test route that sleeps app.Get("/test", func(c fiber.Ctx) error { time.Sleep(sleepDuration) return c.SendStatus(fiber.StatusOK) }) // Loop through each time unit and assert that the log output contains the expected latency value for _, tu := range timeUnits { // Update the sleep duration for the next iteration sleepDuration = 1 * tu.div // Create a new HTTP request to the test route resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) // Assert that the log output contains the expected latency value in the current time unit // parse out the latency value from the log output latency := bytes.Split(buff.Bytes(), []byte(" | "))[2] // Assert that the latency value is in the current time unit require.True(t, bytes.HasSuffix(latency, []byte(tu.unit)), "Expected latency to be in %s, got %s", tu.unit, latency) // Reset the buffer buff.Reset() } } // go test -run Test_Query_Params func Test_Query_Params(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := fiber.New() app.Use(New(Config{ Format: "${queryParams}", Stream: buf, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?foo=bar&baz=moz", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) expected := "foo=bar&baz=moz" require.Equal(t, expected, buf.String()) } // go test -run Test_Response_Body func Test_Response_Body(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := 
fiber.New() app.Use(New(Config{ Format: "${resBody}", Stream: buf, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Sample response body") }) app.Post("/test", func(c fiber.Ctx) error { return c.Send([]byte("Post in test")) }) _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) expectedGetResponse := "Sample response body" require.Equal(t, expectedGetResponse, buf.String()) buf.Reset() // Reset buffer to test POST _, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/test", http.NoBody)) expectedPostResponse := "Post in test" require.NoError(t, err) require.Equal(t, expectedPostResponse, buf.String()) } // go test -run Test_Request_Body func Test_Request_Body(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := fiber.New() app.Use(New(Config{ Format: "${bytesReceived} ${bytesSent} ${status}", Stream: buf, })) app.Post("/", func(c fiber.Ctx) error { c.Response().Header.SetContentLength(5) return c.SendString("World") }) // Create a POST request with a body body := []byte("Hello") req := httptest.NewRequest(fiber.MethodPost, "/", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/octet-stream") _, err := app.Test(req) require.NoError(t, err) require.Equal(t, "5 5 200", buf.String()) } // go test -run Test_Logger_AppendUint func Test_Logger_AppendUint(t *testing.T) { t.Parallel() app := fiber.New() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app.Use(New(Config{ Format: "${bytesReceived} ${bytesSent} ${status}", Stream: buf, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("hello") }) app.Get("/content", func(c fiber.Ctx) error { c.Response().Header.SetContentLength(5) return c.SendString("hello") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "-2 0 200", buf.String()) buf.Reset() resp, err = 
app.Test(httptest.NewRequest(fiber.MethodGet, "/content", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "-2 5 200", buf.String()) } // go test -run Test_Logger_Data_Race -race func Test_Logger_Data_Race(t *testing.T) { t.Parallel() app := fiber.New() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app.Use(New(ConfigDefault)) app.Use(New(Config{ Format: "${time} | ${pid} | ${locals:requestid} | ${status} | ${latency} | ${method} | ${path}\n", })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("hello") }) var ( resp1, resp2 *http.Response err1, err2 error ) wg := &sync.WaitGroup{} wg.Go(func() { resp1, err1 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) }) resp2, err2 = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) wg.Wait() require.NoError(t, err1) require.Equal(t, fiber.StatusOK, resp1.StatusCode) require.NoError(t, err2) require.Equal(t, fiber.StatusOK, resp2.StatusCode) } // go test -run Test_Response_Header func Test_Response_Header(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := fiber.New() app.Use(requestid.New(requestid.Config{ Next: nil, Header: fiber.HeaderXRequestID, Generator: func() string { return "Hello fiber!" 
}, })) app.Use(New(Config{ Format: "${respHeader:X-Request-ID}", Stream: buf, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello fiber!") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "Hello fiber!", buf.String()) } // go test -run Test_Req_Header func Test_Req_Header(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := fiber.New() app.Use(New(Config{ Format: "${reqHeader:test}", Stream: buf, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello fiber!") }) headerReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) headerReq.Header.Add("test", "Hello fiber!") resp, err := app.Test(headerReq) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "Hello fiber!", buf.String()) } // go test -run Test_ReqHeader_Header func Test_ReqHeader_Header(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := fiber.New() app.Use(New(Config{ Format: "${reqHeader:test}", Stream: buf, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello fiber!") }) reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) reqHeaderReq.Header.Add("test", "Hello fiber!") resp, err := app.Test(reqHeaderReq) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "Hello fiber!", buf.String()) } // go test -run Test_CustomTags func Test_CustomTags(t *testing.T) { t.Parallel() customTag := "it is a custom tag" buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := fiber.New() app.Use(New(Config{ Format: "${custom_tag}", CustomTags: map[string]LogFunc{ "custom_tag": func(output Buffer, _ fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(customTag) }, }, Stream: buf, })) app.Get("/", func(c fiber.Ctx) error { return 
c.SendString("Hello fiber!") }) reqHeaderReq := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) reqHeaderReq.Header.Add("test", "Hello fiber!") resp, err := app.Test(reqHeaderReq) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, customTag, buf.String()) } // go test -run Test_Logger_ByteSent_Streaming func Test_Logger_ByteSent_Streaming(t *testing.T) { t.Parallel() app := fiber.New() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app.Use(New(Config{ Format: "${bytesReceived} ${bytesSent} ${status}", Stream: buf, })) app.Get("/", func(c fiber.Ctx) error { c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") c.RequestCtx().SetBodyStreamWriter(func(w *bufio.Writer) { var i int for { i++ msg := fmt.Sprintf("%d - the time is %v", i, time.Now()) fmt.Fprintf(w, "data: Message: %s\n\n", msg) err := w.Flush() if err != nil { break } if i == 10 { break } } }) return nil }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) // -2 means identity, -1 means chunked, 200 status require.Equal(t, "-2 -1 200", buf.String()) } type fakeOutput int func (o *fakeOutput) Write(b []byte) (int, error) { *o++ return len(b), nil } // go test -run Test_Logger_EnableColors func Test_Logger_EnableColors(t *testing.T) { t.Parallel() o := new(fakeOutput) app := fiber.New() app.Use(New(Config{ Stream: o, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) require.EqualValues(t, 1, *o) } // go test -run Test_Logger_ForceColors func Test_Logger_ForceColors(t *testing.T) { t.Parallel() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) app := fiber.New() app.Use(New(Config{ Format: "${ip}${status}${method}${path}${error}\n", Stream: buf, DisableColors: true, ForceColors: true, })) // Alias colors colors := 
app.Config().ColorScheme resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) expected := fmt.Sprintf("0.0.0.0%s404%s%sGET%s/%sNot Found%s\n", colors.Yellow, colors.Reset, colors.Cyan, colors.Reset, colors.Red, colors.Reset) require.Equal(t, expected, buf.String()) } // go test -v -run=^$ -bench=Benchmark_Logger$ -benchmem -count=4 func Benchmark_Logger(b *testing.B) { b.Run("NoMiddleware", func(bb *testing.B) { app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetup(bb, app, "/") }) b.Run("WithBytesAndStatus", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${bytesReceived} ${bytesSent} ${status}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { c.Set("test", "test") return c.SendString("Hello, World!") }) benchmarkSetup(bb, app, "/") }) b.Run("DefaultFormat", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetup(bb, app, "/") }) b.Run("DefaultFormatDisableColors", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Stream: io.Discard, DisableColors: true, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetup(bb, app, "/") }) b.Run("DefaultFormatForceColors", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Stream: io.Discard, ForceColors: true, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetup(bb, app, "/") }) b.Run("DefaultFormatWithFiberLog", func(bb *testing.B) { app := fiber.New() logger := fiberlog.DefaultLogger[*log.Logger]() logger.SetOutput(io.Discard) app.Use(New(Config{ Stream: LoggerToWriter(logger, fiberlog.LevelDebug), })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetup(bb, 
app, "/") }) b.Run("WithTagParameter", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${bytesReceived} ${bytesSent} ${status} ${reqHeader:test}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { c.Set("test", "test") return c.SendString("Hello, World!") }) benchmarkSetup(bb, app, "/") }) b.Run("WithLocals", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${locals:demo}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { c.Locals("demo", "johndoe") return c.SendStatus(fiber.StatusOK) }) benchmarkSetup(bb, app, "/") }) b.Run("WithLocalsInt", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${locals:demo}", Stream: io.Discard, })) app.Get("/int", func(c fiber.Ctx) error { c.Locals("demo", 55) return c.SendStatus(fiber.StatusOK) }) benchmarkSetup(bb, app, "/int") }) b.Run("WithCustomDone", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Done: func(c fiber.Ctx, logString []byte) { if c.Response().StatusCode() == fiber.StatusOK { io.Discard.Write(logString) //nolint:errcheck // ignore error } }, Stream: io.Discard, })) app.Get("/logging", func(ctx fiber.Ctx) error { return ctx.SendStatus(fiber.StatusOK) }) benchmarkSetup(bb, app, "/logging") }) b.Run("WithAllTags", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${pid}${reqHeaders}${referer}${scheme}${protocol}${ip}${ips}${host}${url}${ua}${body}${route}${black}${red}${green}${yellow}${blue}${magenta}${cyan}${white}${reset}${error}${reqHeader:test}${query:test}${form:test}${cookie:test}${non}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetup(bb, app, "/") }) b.Run("Streaming", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${bytesReceived} ${bytesSent} ${status}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") 
c.RequestCtx().SetBodyStreamWriter(func(w *bufio.Writer) { var i int for { i++ msg := fmt.Sprintf("%d - the time is %v", i, time.Now()) fmt.Fprintf(w, "data: Message: %s\n\n", msg) err := w.Flush() if err != nil { break } if i == 10 { break } } }) return nil }) benchmarkSetup(bb, app, "/") }) b.Run("WithBody", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${resBody}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Sample response body") }) benchmarkSetup(bb, app, "/") }) } // go test -v -run=^$ -bench=Benchmark_Logger_Parallel$ -benchmem -count=4 func Benchmark_Logger_Parallel(b *testing.B) { b.Run("NoMiddleware", func(bb *testing.B) { app := fiber.New() app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetupParallel(bb, app, "/") }) b.Run("WithBytesAndStatus", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${bytesReceived} ${bytesSent} ${status}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { c.Set("test", "test") return c.SendString("Hello, World!") }) benchmarkSetupParallel(bb, app, "/") }) b.Run("DefaultFormat", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetupParallel(bb, app, "/") }) b.Run("DefaultFormatWithFiberLog", func(bb *testing.B) { app := fiber.New() logger := fiberlog.DefaultLogger[*log.Logger]() logger.SetOutput(io.Discard) app.Use(New(Config{ Stream: LoggerToWriter(logger, fiberlog.LevelDebug), })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetupParallel(bb, app, "/") }) b.Run("DefaultFormatDisableColors", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Stream: io.Discard, DisableColors: true, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetupParallel(bb, app, "/") }) 
b.Run("DefaultFormatForceColors", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Stream: io.Discard, ForceColors: true, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetupParallel(bb, app, "/") }) b.Run("WithTagParameter", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${bytesReceived} ${bytesSent} ${status} ${reqHeader:test}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { c.Set("test", "test") return c.SendString("Hello, World!") }) benchmarkSetupParallel(bb, app, "/") }) b.Run("WithLocals", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${locals:demo}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { c.Locals("demo", "johndoe") return c.SendStatus(fiber.StatusOK) }) benchmarkSetupParallel(bb, app, "/") }) b.Run("WithLocalsInt", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${locals:demo}", Stream: io.Discard, })) app.Get("/int", func(c fiber.Ctx) error { c.Locals("demo", 55) return c.SendStatus(fiber.StatusOK) }) benchmarkSetupParallel(bb, app, "/int") }) b.Run("WithCustomDone", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Done: func(c fiber.Ctx, logString []byte) { if c.Response().StatusCode() == fiber.StatusOK { io.Discard.Write(logString) //nolint:errcheck // ignore error } }, Stream: io.Discard, })) app.Get("/logging", func(ctx fiber.Ctx) error { return ctx.SendStatus(fiber.StatusOK) }) benchmarkSetupParallel(bb, app, "/logging") }) b.Run("WithAllTags", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${pid}${reqHeaders}${referer}${scheme}${protocol}${ip}${ips}${host}${url}${ua}${body}${route}${black}${red}${green}${yellow}${blue}${magenta}${cyan}${white}${reset}${error}${reqHeader:test}${query:test}${form:test}${cookie:test}${non}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) benchmarkSetupParallel(bb, 
app, "/") }) b.Run("Streaming", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${bytesReceived} ${bytesSent} ${status}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") c.RequestCtx().SetBodyStreamWriter(func(w *bufio.Writer) { var i int for { i++ msg := fmt.Sprintf("%d - the time is %v", i, time.Now()) fmt.Fprintf(w, "data: Message: %s\n\n", msg) err := w.Flush() if err != nil { break } if i == 10 { break } } }) return nil }) benchmarkSetupParallel(bb, app, "/") }) b.Run("WithBody", func(bb *testing.B) { app := fiber.New() app.Use(New(Config{ Format: "${resBody}", Stream: io.Discard, })) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Sample response body") }) benchmarkSetupParallel(bb, app, "/") }) } ================================================ FILE: middleware/logger/tags.go ================================================ package logger import ( "fmt" "maps" "strings" "github.com/gofiber/fiber/v3" ) // Logger variables const ( TagPid = "pid" TagTime = "time" TagReferer = "referer" TagProtocol = "protocol" TagScheme = "scheme" TagPort = "port" TagIP = "ip" TagIPs = "ips" TagHost = "host" TagMethod = "method" TagPath = "path" TagURL = "url" TagUA = "ua" TagLatency = "latency" TagStatus = "status" TagResBody = "resBody" TagReqHeaders = "reqHeaders" TagQueryStringParams = "queryParams" TagBody = "body" TagBytesSent = "bytesSent" TagBytesReceived = "bytesReceived" TagRoute = "route" TagError = "error" TagReqHeader = "reqHeader:" TagRespHeader = "respHeader:" TagLocals = "locals:" TagQuery = "query:" TagForm = "form:" TagCookie = "cookie:" TagBlack = "black" TagRed = "red" TagGreen = "green" TagYellow = "yellow" TagBlue = "blue" TagMagenta = "magenta" TagCyan = "cyan" TagWhite = "white" TagReset = "reset" ) // createTagMap function merged the default with the custom tags func createTagMap(cfg *Config) map[string]LogFunc { // Set default tags 
tagFunctions := map[string]LogFunc{ TagReferer: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.Get(fiber.HeaderReferer)) }, TagProtocol: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.Protocol()) }, TagScheme: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.Scheme()) }, TagPort: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.Port()) }, TagIP: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.IP()) }, TagIPs: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.Get(fiber.HeaderXForwardedFor)) }, TagHost: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.Hostname()) }, TagPath: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.Path()) }, TagURL: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.OriginalURL()) }, TagUA: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.Get(fiber.HeaderUserAgent)) }, TagBody: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.Write(c.Body()) }, TagBytesReceived: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return appendInt(output, c.Request().Header.ContentLength()) }, TagBytesSent: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return appendInt(output, c.Response().Header.ContentLength()) }, TagRoute: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.Route().Path) }, TagResBody: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.Write(c.Response().Body()) }, TagReqHeaders: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { 
out := make(map[string][]string) if err := c.Bind().Header(&out); err != nil { return 0, err } reqHeaders := make([]string, 0, len(out)) for k, v := range out { reqHeaders = append(reqHeaders, k+"="+strings.Join(v, ",")) } return output.WriteString(strings.Join(reqHeaders, "&")) }, TagQueryStringParams: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.Request().URI().QueryArgs().String()) }, TagBlack: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.App().Config().ColorScheme.Black) }, TagRed: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.App().Config().ColorScheme.Red) }, TagGreen: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.App().Config().ColorScheme.Green) }, TagYellow: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.App().Config().ColorScheme.Yellow) }, TagBlue: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.App().Config().ColorScheme.Blue) }, TagMagenta: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.App().Config().ColorScheme.Magenta) }, TagCyan: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.App().Config().ColorScheme.Cyan) }, TagWhite: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.App().Config().ColorScheme.White) }, TagReset: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.App().Config().ColorScheme.Reset) }, TagError: func(output Buffer, c fiber.Ctx, data *Data, _ string) (int, error) { if data.ChainErr != nil { if cfg.enableColors { colors := c.App().Config().ColorScheme return fmt.Fprintf(output, "%s%s%s", colors.Red, data.ChainErr.Error(), colors.Reset) } return 
output.WriteString(data.ChainErr.Error()) } return output.WriteString("-") }, TagReqHeader: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) { return output.WriteString(c.Get(extraParam)) }, TagRespHeader: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) { return output.WriteString(c.GetRespHeader(extraParam)) }, TagQuery: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) { return output.WriteString(fiber.Query[string](c, extraParam)) }, TagForm: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) { return output.WriteString(c.FormValue(extraParam)) }, TagCookie: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) { return output.WriteString(c.Cookies(extraParam)) }, TagLocals: func(output Buffer, c fiber.Ctx, _ *Data, extraParam string) (int, error) { switch v := c.Locals(extraParam).(type) { case []byte: return output.Write(v) case string: return output.WriteString(v) case nil: return 0, nil default: return fmt.Fprintf(output, "%v", v) } }, TagStatus: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { if cfg.enableColors { colors := c.App().Config().ColorScheme return fmt.Fprintf(output, "%s%3d%s", statusColor(c.Response().StatusCode(), &colors), c.Response().StatusCode(), colors.Reset) } return appendInt(output, c.Response().StatusCode()) }, TagMethod: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { if cfg.enableColors { colors := c.App().Config().ColorScheme return fmt.Fprintf(output, "%s%s%s", methodColor(c.Method(), &colors), c.Method(), colors.Reset) } return output.WriteString(c.Method()) }, TagPid: func(output Buffer, _ fiber.Ctx, data *Data, _ string) (int, error) { return output.WriteString(data.Pid) }, TagLatency: func(output Buffer, _ fiber.Ctx, data *Data, _ string) (int, error) { latency := data.Stop.Sub(data.Start) return fmt.Fprintf(output, "%13v", latency) }, TagTime: func(output Buffer, _ 
fiber.Ctx, data *Data, _ string) (int, error) { return output.WriteString(data.Timestamp.Load().(string)) //nolint:forcetypeassert,errcheck // We always store a string in here }, } // merge with custom tags from user maps.Copy(tagFunctions, cfg.CustomTags) return tagFunctions } ================================================ FILE: middleware/logger/template_chain.go ================================================ package logger import ( "bytes" "fmt" "github.com/gofiber/utils/v2" ) // buildLogFuncChain analyzes the template and creates slices with the functions for execution and // slices with the fixed parts of the template and the parameters // // fixParts contains the fixed parts of the template or parameters if a function is stored in the funcChain at this position // funcChain contains for the parts which exist the functions for the dynamic parts // funcChain and fixParts always have the same length and contain nil for the parts where no data is required in the chain, // if a function exists for the part, a parameter for it can also exist in the fixParts slice func buildLogFuncChain(cfg *Config, tagFunctions map[string]LogFunc) ([][]byte, []LogFunc, error) { // process flow is copied from the fasttemplate flow https://github.com/valyala/fasttemplate/blob/2a2d1afadadf9715bfa19683cdaeac8347e5d9f9/template.go#L23-L62 templateB := utils.UnsafeBytes(cfg.Format) startTagB := utils.UnsafeBytes(startTag) endTagB := utils.UnsafeBytes(endTag) paramSeparatorB := utils.UnsafeBytes(paramSeparator) var fixParts [][]byte var funcChain []LogFunc for { before, after, found := bytes.Cut(templateB, startTagB) if !found { // no starting tag found in the existing template part break } // add fixed part funcChain = append(funcChain, nil) fixParts = append(fixParts, before) templateB = after before, after, found = bytes.Cut(templateB, endTagB) if !found { // cannot find end tag - just write it to the output. 
funcChain = append(funcChain, nil) fixParts = append(fixParts, startTagB) break } // ## function block ## // first check for tags with parameters tag, param, foundParam := bytes.Cut(before, paramSeparatorB) if foundParam { logFunc, ok := tagFunctions[utils.UnsafeString(tag)+paramSeparator] if !ok { return nil, nil, fmt.Errorf("%w: %q", ErrTemplateParameterMissing, utils.UnsafeString(before)) } funcChain = append(funcChain, logFunc) // add param to the fixParts fixParts = append(fixParts, param) } else if logFunc, ok := tagFunctions[utils.UnsafeString(before)]; ok { // add functions without parameter funcChain = append(funcChain, logFunc) fixParts = append(fixParts, nil) } // ## function block end ## // reduce the template string templateB = after } // set the rest funcChain = append(funcChain, nil) fixParts = append(fixParts, templateB) return fixParts, funcChain, nil } ================================================ FILE: middleware/logger/utils.go ================================================ package logger import ( "io" "github.com/gofiber/fiber/v3" fiberlog "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" ) func methodColor(method string, colors *fiber.Colors) string { if colors == nil { return "" } switch method { case fiber.MethodGet: return colors.Cyan case fiber.MethodPost: return colors.Green case fiber.MethodPut: return colors.Yellow case fiber.MethodDelete: return colors.Red case fiber.MethodPatch: return colors.White case fiber.MethodHead: return colors.Magenta case fiber.MethodOptions: return colors.Blue default: return colors.Reset } } func statusColor(code int, colors *fiber.Colors) string { if colors == nil { return "" } switch { case code >= fiber.StatusOK && code < fiber.StatusMultipleChoices: return colors.Green case code >= fiber.StatusMultipleChoices && code < fiber.StatusBadRequest: return colors.Blue case code >= fiber.StatusBadRequest && code < fiber.StatusInternalServerError: return colors.Yellow default: return 
colors.Red } } type customLoggerWriter[T any] struct { loggerInstance fiberlog.AllLogger[T] level fiberlog.Level } // Write implements io.Writer and forwards the payload to the configured logger. func (cl *customLoggerWriter[T]) Write(p []byte) (int, error) { switch cl.level { case fiberlog.LevelTrace: cl.loggerInstance.Trace(utils.UnsafeString(p)) case fiberlog.LevelDebug: cl.loggerInstance.Debug(utils.UnsafeString(p)) case fiberlog.LevelInfo: cl.loggerInstance.Info(utils.UnsafeString(p)) case fiberlog.LevelWarn: cl.loggerInstance.Warn(utils.UnsafeString(p)) case fiberlog.LevelError: cl.loggerInstance.Error(utils.UnsafeString(p)) default: return 0, nil } return len(p), nil } // LoggerToWriter is a helper function that returns an io.Writer that writes to a custom logger. // You can integrate 3rd party loggers such as zerolog, logrus, etc. to logger middleware using this function. // // Valid levels: fiberlog.LevelInfo, fiberlog.LevelTrace, fiberlog.LevelWarn, fiberlog.LevelDebug, fiberlog.LevelError func LoggerToWriter[T any](logger fiberlog.AllLogger[T], level fiberlog.Level) io.Writer { // Check if customLogger is nil if logger == nil { fiberlog.Panic("LoggerToWriter: customLogger must not be nil") } // Check if level is valid if level == fiberlog.LevelFatal || level == fiberlog.LevelPanic { fiberlog.Panic("LoggerToWriter: invalid level") } return &customLoggerWriter[T]{ level: level, loggerInstance: logger, } } ================================================ FILE: middleware/paginate/config.go ================================================ package paginate import ( "slices" "github.com/gofiber/fiber/v3" ) // Config defines the config for the pagination middleware. type Config struct { // Next defines a function to skip this middleware when returned true. Next func(c fiber.Ctx) bool // PageKey is the query string key for page number. // // Optional. Default: "page" PageKey string // LimitKey is the query string key for limit. // // Optional. 
Default: "limit" LimitKey string // SortKey is the query string key for sort. // // Optional. Default: "" SortKey string // DefaultSort is the default sort field. // // Optional. Default: "id" DefaultSort string // CursorKey is the query string key for cursor-based pagination. // // Optional. Default: "cursor" CursorKey string // OffsetKey is the query string key for offset. // // Optional. Default: "offset" OffsetKey string // CursorParam is an optional alias for the cursor query key. // // Optional. Default: "" CursorParam string // AllowedSorts is the list of allowed sort fields. // // Optional. Default: nil AllowedSorts []string // DefaultPage is the default page number. // // Optional. Default: 1 DefaultPage int // DefaultLimit is the default items per page. // // Optional. Default: 10 DefaultLimit int // MaxLimit is the maximum items per page. // // Optional. Default: 100 MaxLimit int } // ConfigDefault is the default config. var ConfigDefault = Config{ Next: nil, PageKey: "page", DefaultPage: 1, LimitKey: "limit", DefaultLimit: 10, MaxLimit: DefaultMaxLimit, DefaultSort: "id", OffsetKey: "offset", CursorKey: "cursor", } func configDefault(config ...Config) Config { if len(config) < 1 { return ConfigDefault } cfg := config[0] if cfg.Next == nil { cfg.Next = ConfigDefault.Next } if cfg.PageKey == "" { cfg.PageKey = ConfigDefault.PageKey } if cfg.DefaultLimit < 1 { cfg.DefaultLimit = ConfigDefault.DefaultLimit } if cfg.LimitKey == "" { cfg.LimitKey = ConfigDefault.LimitKey } if cfg.DefaultPage < 1 { cfg.DefaultPage = ConfigDefault.DefaultPage } if cfg.CursorKey == "" { cfg.CursorKey = ConfigDefault.CursorKey } if cfg.DefaultSort == "" { cfg.DefaultSort = ConfigDefault.DefaultSort } if cfg.OffsetKey == "" { cfg.OffsetKey = ConfigDefault.OffsetKey } if cfg.MaxLimit < 1 { cfg.MaxLimit = ConfigDefault.MaxLimit } if cfg.DefaultLimit > cfg.MaxLimit { cfg.DefaultLimit = cfg.MaxLimit } if len(cfg.AllowedSorts) > 0 && !slices.Contains(cfg.AllowedSorts, cfg.DefaultSort) 
{ cfg.DefaultSort = cfg.AllowedSorts[0] } return cfg } ================================================ FILE: middleware/paginate/page_info.go ================================================ package paginate import ( "encoding/base64" "encoding/json" "errors" "fmt" "net/url" "github.com/gofiber/utils/v2" ) // ErrCursorEncode is returned when cursor values cannot be encoded. var ErrCursorEncode = errors.New("paginate: failed to encode cursor values") // SortOrder represents sort order. type SortOrder string const ( ASC SortOrder = "asc" DESC SortOrder = "desc" ) // SortField represents a sort field with direction. type SortField struct { Field string `json:"field"` Order SortOrder `json:"order"` } // SortOrderFromString returns a SortOrder from a string (case-insensitive). func SortOrderFromString(s string) SortOrder { if utils.EqualFold(s, "desc") { return DESC } return ASC } // PageInfo contains pagination information. type PageInfo struct { cursorData map[string]any jsonMarshal utils.JSONMarshal jsonUnmarshal utils.JSONUnmarshal Cursor string `json:"cursor,omitempty"` NextCursor string `json:"next_cursor,omitempty"` Sort []SortField `json:"sort"` Page int `json:"page"` Limit int `json:"limit"` Offset int `json:"offset"` HasMore bool `json:"has_more,omitempty"` } // NewPageInfo creates a new PageInfo. func NewPageInfo(page, limit, offset int, sort []SortField) *PageInfo { return &PageInfo{ Page: page, Limit: limit, Offset: offset, Sort: sort, } } // Start returns the start index based on page/limit or offset. func (p *PageInfo) Start() int { if p.Offset > 0 { return p.Offset } if p.Page < 1 { return 0 } return (p.Page - 1) * p.Limit } // SortBy adds a sort field. Chainable. func (p *PageInfo) SortBy(field string, order SortOrder) *PageInfo { p.Sort = append(p.Sort, SortField{Field: field, Order: order}) return p } // NextPageURLWithKeys returns the URL for the next page using custom query keys. 
func (p *PageInfo) NextPageURLWithKeys(baseURL, pageKey, limitKey string) string {
	nextPage := utils.FormatInt(int64(p.Page + 1))
	perPage := utils.FormatInt(int64(p.Limit))
	return buildPaginationURL(baseURL, pageKey, nextPage, limitKey, perPage)
}

// NextPageURL is NextPageURLWithKeys with the standard "page" and "limit"
// query keys.
func (p *PageInfo) NextPageURL(baseURL string) string {
	const pageKey, limitKey = "page", "limit"
	return p.NextPageURLWithKeys(baseURL, pageKey, limitKey)
}

// PreviousPageURLWithKeys builds the URL of the preceding page, using pageKey
// and limitKey as the query parameter names.
// When Page is 1 or lower there is no previous page and "" is returned.
func (p *PageInfo) PreviousPageURLWithKeys(baseURL, pageKey, limitKey string) string {
	if p.Page <= 1 {
		return ""
	}
	prevPage := utils.FormatInt(int64(p.Page - 1))
	perPage := utils.FormatInt(int64(p.Limit))
	return buildPaginationURL(baseURL, pageKey, prevPage, limitKey, perPage)
}

// PreviousPageURL is PreviousPageURLWithKeys with the standard "page" and
// "limit" query keys. It returns "" when Page is 1 or lower.
func (p *PageInfo) PreviousPageURL(baseURL string) string {
	const pageKey, limitKey = "page", "limit"
	return p.PreviousPageURLWithKeys(baseURL, pageKey, limitKey)
}

// NextCursorURLWithKeys builds the URL for the next cursor page, placing
// NextCursor under cursorKey and Limit under limitKey.
// It returns "" unless HasMore is set.
func (p *PageInfo) NextCursorURLWithKeys(baseURL, cursorKey, limitKey string) string {
	if p.HasMore {
		perPage := utils.FormatInt(int64(p.Limit))
		return buildPaginationURL(baseURL, cursorKey, p.NextCursor, limitKey, perPage)
	}
	return ""
}

// NextCursorURL is NextCursorURLWithKeys with the standard "cursor" and
// "limit" query keys. It returns "" unless HasMore is set.
func (p *PageInfo) NextCursorURL(baseURL string) string {
	const cursorKey, limitKey = "cursor", "limit"
	return p.NextCursorURLWithKeys(baseURL, cursorKey, limitKey)
}

// buildPaginationURL parses baseURL and sets the two given query parameters,
// replacing existing values for those keys and keeping every other query
// value intact. If baseURL cannot be parsed it is returned unchanged.
// url.Values.Encode sorts by key, so the resulting query is alphabetical.
func buildPaginationURL(baseURL, pageParam, pageValue, limitParam, limitValue string) string {
	target, err := url.Parse(baseURL)
	if err != nil {
		return baseURL
	}

	params := target.Query()
	params.Set(pageParam, pageValue)
	params.Set(limitParam, limitValue)
	target.RawQuery = params.Encode()

	return target.String()
}

// CursorValues returns the decoded cursor key-value map.
// If the cursor was parsed by the middleware, the pre-parsed data is returned. // Otherwise it decodes the opaque cursor string. // Returns nil if cursor is empty or invalid. func (p *PageInfo) CursorValues() map[string]any { if p.cursorData != nil { return p.cursorData } if p.Cursor == "" { return nil } if len(p.Cursor) > maxCursorLen { return nil } data, err := base64.RawURLEncoding.DecodeString(p.Cursor) if err != nil { return nil } var values map[string]any unmarshal := p.jsonUnmarshal if unmarshal == nil { unmarshal = json.Unmarshal } if err := unmarshal(data, &values); err != nil { return nil } p.cursorData = values return values } // SetNextCursor encodes a key-value map into an opaque cursor token // and sets both NextCursor and HasMore on the PageInfo. func (p *PageInfo) SetNextCursor(values map[string]any) error { marshal := p.jsonMarshal if marshal == nil { marshal = json.Marshal } data, err := marshal(values) if err != nil { return fmt.Errorf("%w: %w", ErrCursorEncode, err) } encoded := base64.RawURLEncoding.EncodeToString(data) if len(encoded) > maxCursorLen { return fmt.Errorf("%w: cursor token exceeds maximum length (%d)", ErrCursorEncode, maxCursorLen) } p.NextCursor = encoded p.HasMore = true return nil } ================================================ FILE: middleware/paginate/paginate.go ================================================ package paginate import ( "encoding/base64" "slices" "strings" "github.com/gofiber/fiber/v3" "github.com/gofiber/utils/v2" ) // The contextKey type is unexported to prevent collisions with context keys defined in // other packages. type contextKey int const ( pageInfoKey contextKey = iota ) // DefaultMaxLimit is the default maximum limit allowed. const DefaultMaxLimit = 100 // maxCursorLen is the maximum allowed cursor string length. const maxCursorLen = 2048 // New creates a new pagination middleware handler. func New(config ...Config) fiber.Handler { cfg := configDefault(config...) 
return func(c fiber.Ctx) error { if cfg.Next != nil && cfg.Next(c) { return c.Next() } appCfg := c.App().Config() limit := fiber.Query(c, cfg.LimitKey, cfg.DefaultLimit) if limit < 1 { limit = cfg.DefaultLimit } if limit > cfg.MaxLimit { limit = cfg.MaxLimit } sorts := parseSortQuery(c.Query(cfg.SortKey), cfg.AllowedSorts, cfg.DefaultSort) cursorRaw := c.Query(cfg.CursorKey) if cursorRaw == "" && cfg.CursorParam != "" { cursorRaw = c.Query(cfg.CursorParam) } if cursorRaw != "" { if len(cursorRaw) > maxCursorLen { return fiber.NewError(fiber.StatusBadRequest, "cursor too long") } data, err := base64.RawURLEncoding.DecodeString(cursorRaw) if err != nil { return fiber.NewError(fiber.StatusBadRequest, "invalid cursor") } var obj map[string]any if err := appCfg.JSONDecoder(data, &obj); err != nil { return fiber.NewError(fiber.StatusBadRequest, "invalid cursor") } pageInfo := &PageInfo{ Limit: limit, Sort: sorts, Cursor: cursorRaw, cursorData: obj, jsonMarshal: appCfg.JSONEncoder, jsonUnmarshal: appCfg.JSONDecoder, } fiber.StoreInContext(c, pageInfoKey, pageInfo) return c.Next() } page := max(fiber.Query(c, cfg.PageKey, cfg.DefaultPage), 1) offset := max(fiber.Query(c, cfg.OffsetKey, 0), 0) pageInfo := NewPageInfo(page, limit, offset, sorts) pageInfo.jsonMarshal = appCfg.JSONEncoder pageInfo.jsonUnmarshal = appCfg.JSONDecoder fiber.StoreInContext(c, pageInfoKey, pageInfo) return c.Next() } } // FromContext returns the PageInfo from the request context. // It accepts fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx, and context.Context. // Returns nil and false if no PageInfo is stored. 
func FromContext(ctx any) (*PageInfo, bool) {
	return fiber.ValueFromContext[*PageInfo](ctx, pageInfoKey)
}

// parseSortQuery parses a comma-separated sort specification such as
// "name,-created_at" into SortFields. A leading "-" selects DESC, otherwise
// ASC; surrounding whitespace is trimmed and empty entries are skipped.
// When allowedSorts is non-empty, fields not in it are dropped. If nothing
// usable remains (or query is empty), the result is defaultSort ascending.
func parseSortQuery(query string, allowedSorts []string, defaultSort string) []SortField {
	if query == "" {
		return []SortField{{Field: defaultSort, Order: ASC}}
	}

	fields := strings.Split(query, ",")
	sortFields := make([]SortField, 0, len(fields))

	for _, field := range fields {
		field = utils.TrimSpace(field)
		if field == "" {
			continue
		}

		order := ASC
		if strings.HasPrefix(field, "-") {
			order = DESC
			field = utils.TrimSpace(field[1:])
		}

		// A bare "-" leaves nothing to sort by.
		if field == "" {
			continue
		}

		if len(allowedSorts) == 0 || slices.Contains(allowedSorts, field) {
			sortFields = append(sortFields, SortField{Field: field, Order: order})
		}
	}

	if len(sortFields) == 0 {
		return []SortField{{Field: defaultSort, Order: ASC}}
	}

	return sortFields
}

================================================
FILE: middleware/paginate/paginate_test.go
================================================
package paginate

import (
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gofiber/fiber/v3"
	"github.com/stretchr/testify/require"
)

// paginateResponse mirrors the JSON written by the page/offset-mode test handlers.
type paginateResponse struct {
	NextPageURL     string      `json:"next_page_url"`
	PreviousPageURL string      `json:"prev_page_url"`
	Sort            []SortField `json:"sort"`
	Page            int         `json:"page"`
	Limit           int         `json:"limit"`
	Offset          int         `json:"offset"`
	Start           int         `json:"start"`
}

// cursorResponse mirrors the JSON written by cursor-mode test handlers.
type cursorResponse struct {
	Cursor     string      `json:"cursor"`
	NextCursor string      `json:"next_cursor"`
	Sort       []SortField `json:"sort"`
	Limit      int         `json:"limit"`
	HasMore    bool        `json:"has_more"`
}

// --- Config tests ---

func Test_ConfigDefault(t *testing.T) {
	t.Parallel()
	cfg := configDefault()
	require.Equal(t, "page", cfg.PageKey)
	require.Equal(t, 1, cfg.DefaultPage)
	require.Equal(t, "limit", cfg.LimitKey)
	require.Equal(t, 10, cfg.DefaultLimit)
	require.Equal(t, DefaultMaxLimit, cfg.MaxLimit)
}

func Test_ConfigOverride(t *testing.T) {
	t.Parallel()
	cfg := configDefault(Config{
		PageKey:      "p",
		LimitKey:     "l",
		DefaultPage:  5,
		DefaultLimit: 50,
	})
	require.Equal(t, "p", cfg.PageKey)
	require.Equal(t, "l", cfg.LimitKey)
	require.Equal(t, 5, cfg.DefaultPage)
	require.Equal(t, 50, cfg.DefaultLimit)
}

func Test_ConfigDefaultCursorKey(t *testing.T) {
	t.Parallel()
	cfg := configDefault()
	require.Equal(t, "cursor", cfg.CursorKey)
}

func Test_ConfigOverrideCursorKey(t *testing.T) {
	t.Parallel()
	cfg := configDefault(Config{
		CursorKey:   "after",
		CursorParam: "starting_after",
	})
	require.Equal(t, "after", cfg.CursorKey)
	require.Equal(t, "starting_after", cfg.CursorParam)
}

// Non-positive DefaultPage/DefaultLimit must fall back to the package defaults.
func Test_ConfigNegativeDefaults(t *testing.T) {
	t.Parallel()
	cfg := configDefault(Config{
		DefaultPage:  -1,
		DefaultLimit: -1,
	})
	require.Equal(t, 1, cfg.DefaultPage)
	require.Equal(t, 10, cfg.DefaultLimit)
}

// --- PageInfo tests ---

func Test_SortOrderFromString(t *testing.T) {
	t.Parallel()
	tests := []struct {
		input    string
		expected SortOrder
	}{
		{"asc", ASC},
		{"desc", DESC},
		{"DESC", DESC},
		{"Desc", DESC},
		{"invalid", ASC},
		{"", ASC},
	}
	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			t.Parallel()
			require.Equal(t, tt.expected, SortOrderFromString(tt.input))
		})
	}
}

func Test_PageInfoStart(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		pageInfo PageInfo
		expected int
	}{
		{"Page 1, limit 10", PageInfo{Page: 1, Limit: 10}, 0},
		{"Page 2, limit 10", PageInfo{Page: 2, Limit: 10}, 10},
		{"Page 3, limit 20", PageInfo{Page: 3, Limit: 20}, 40},
		// A positive Offset takes precedence over Page/Limit.
		{"With offset", PageInfo{Page: 2, Limit: 10, Offset: 25}, 25},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			require.Equal(t, tt.expected, tt.pageInfo.Start())
		})
	}
}

func Test_PageInfoSortBy(t *testing.T) {
	t.Parallel()
	p := NewPageInfo(1, 10, 0, nil)
	p.SortBy("name", ASC).SortBy("date", DESC)
	require.Len(t, p.Sort, 2)
	require.Equal(t, "name", p.Sort[0].Field)
	require.Equal(t, ASC, p.Sort[0].Order)
	require.Equal(t, "date", p.Sort[1].Field)
	require.Equal(t, DESC, p.Sort[1].Order)
}

func Test_PageInfoNextPageURL(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		baseURL  string
		expected string
		pageInfo PageInfo
	}{
		{
			name:     "Middle page",
			baseURL:  "https://example.com/users",
			expected: "https://example.com/users?limit=10&page=3",
			pageInfo: PageInfo{Page: 2, Limit: 10},
		},
		{
			name:     "First page",
			baseURL:  "https://example.com/users",
			expected: "https://example.com/users?limit=20&page=2",
			pageInfo: PageInfo{Page: 1, Limit: 20},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			require.Equal(t, tt.expected, tt.pageInfo.NextPageURL(tt.baseURL))
		})
	}
}

func Test_PageInfoPreviousPageURL(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		baseURL  string
		expected string
		pageInfo PageInfo
	}{
		{
			name:     "Middle page",
			baseURL:  "https://example.com/users",
			expected: "https://example.com/users?limit=10&page=1",
			pageInfo: PageInfo{Page: 2, Limit: 10},
		},
		{
			name:     "First page returns empty",
			baseURL:  "https://example.com/users",
			expected: "",
			pageInfo: PageInfo{Page: 1, Limit: 20},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			require.Equal(t, tt.expected, tt.pageInfo.PreviousPageURL(tt.baseURL))
		})
	}
}

func Test_PageInfoStartCursorMode(t *testing.T) {
	t.Parallel()
	// In cursor mode, Page is 0 (not set). Start() should return 0, not negative.
	p := &PageInfo{Page: 0, Limit: 20}
	require.Equal(t, 0, p.Start())
}

func Test_PageInfoNextPageURLWithExistingQueryParams(t *testing.T) {
	t.Parallel()
	p := PageInfo{Page: 2, Limit: 10}
	result := p.NextPageURL("https://example.com/users?filter=active")
	require.Contains(t, result, "filter=active")
	require.Contains(t, result, "page=3")
	require.Contains(t, result, "limit=10")
}

func Test_PageInfoPreviousPageURLWithExistingQueryParams(t *testing.T) {
	t.Parallel()
	p := PageInfo{Page: 3, Limit: 10}
	result := p.PreviousPageURL("https://example.com/users?filter=active")
	require.Contains(t, result, "filter=active")
	require.Contains(t, result, "page=2")
	require.Contains(t, result, "limit=10")
}

func Test_PageInfoCursorFields(t *testing.T) {
	t.Parallel()
	p := &PageInfo{
		Cursor:     "abc123",
		HasMore:    true,
		NextCursor: "def456",
	}
	require.Equal(t, "abc123", p.Cursor)
	require.True(t, p.HasMore)
	require.Equal(t, "def456", p.NextCursor)
}

// Encoding with SetNextCursor and decoding with CursorValues must round-trip;
// JSON numbers come back as float64.
func Test_CursorValuesRoundTrip(t *testing.T) {
	t.Parallel()
	original := map[string]any{
		"id":         float64(42),
		"created_at": "2026-01-01T00:00:00Z",
	}
	p := &PageInfo{}
	require.NoError(t, p.SetNextCursor(original))
	require.True(t, p.HasMore)
	require.NotEmpty(t, p.NextCursor)

	p2 := &PageInfo{Cursor: p.NextCursor}
	decoded := p2.CursorValues()
	require.NotNil(t, decoded)
	require.InEpsilon(t, float64(42), decoded["id"], 0)
	require.Equal(t, "2026-01-01T00:00:00Z", decoded["created_at"])
}

func Test_CursorValuesEmptyCursor(t *testing.T) {
	t.Parallel()
	p := &PageInfo{Cursor: ""}
	require.Nil(t, p.CursorValues())
}

func Test_CursorValuesInvalidBase64(t *testing.T) {
	t.Parallel()
	p := &PageInfo{Cursor: "not-valid-base64!!!"}
	require.Nil(t, p.CursorValues())
}

// "bm90LWpzb24" is valid unpadded base64url for the non-JSON text "not-json".
func Test_CursorValuesInvalidJSON(t *testing.T) {
	t.Parallel()
	p := &PageInfo{Cursor: "bm90LWpzb24"}
	require.Nil(t, p.CursorValues())
}

func Test_NextCursorURL(t *testing.T) {
	t.Parallel()
	t.Run("with HasMore", func(t *testing.T) {
		t.Parallel()
		p := &PageInfo{Limit: 20}
		require.NoError(t, p.SetNextCursor(map[string]any{"id": float64(42)}))
		url := p.NextCursorURL("https://example.com/users")
		expected := fmt.Sprintf("https://example.com/users?cursor=%s&limit=20", p.NextCursor)
		require.Equal(t, expected, url)
	})

	t.Run("without HasMore", func(t *testing.T) {
		t.Parallel()
		p := &PageInfo{Limit: 20}
		require.Empty(t, p.NextCursorURL("https://example.com/users"))
	})
}

func Test_SetNextCursorSetsFields(t *testing.T) {
	t.Parallel()
	p := &PageInfo{Limit: 10}
	require.NoError(t, p.SetNextCursor(map[string]any{"id": float64(1)}))
	require.True(t, p.HasMore)
	require.NotEmpty(t, p.NextCursor)
}

// --- Middleware handler tests ---

func Test_PaginateWithQueries(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		DefaultSort: "id",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(paginateResponse{
			Page:            pageInfo.Page,
			Limit:           pageInfo.Limit,
			Offset:          pageInfo.Offset,
			Start:           pageInfo.Start(),
			Sort:            pageInfo.Sort,
			NextPageURL:     pageInfo.NextPageURL(c.BaseURL()),
			PreviousPageURL: pageInfo.PreviousPageURL(c.BaseURL()),
		})
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?page=2&limit=20", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	var body paginateResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
	require.Equal(t, 2, body.Page)
	require.Equal(t, 20, body.Limit)
	require.Equal(t, 0, body.Offset)
	// Start = (page-1) * limit = (2-1) * 20.
	require.Equal(t, 20, body.Start)
	require.Equal(t, "http://example.com?limit=20&page=3", body.NextPageURL)
	require.Equal(t, "http://example.com?limit=20&page=1", body.PreviousPageURL)
	require.Equal(t, []SortField{{Field: "id", Order: ASC}}, body.Sort)
}

func Test_PaginateWithOffset(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(paginateResponse{
			Page:   pageInfo.Page,
			Limit:  pageInfo.Limit,
			Offset: pageInfo.Offset,
			Start:  pageInfo.Start(),
			Sort:   pageInfo.Sort,
		})
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?offset=20&limit=20", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	var body paginateResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
	require.Equal(t, 1, body.Page)
	require.Equal(t, 20, body.Limit)
	require.Equal(t, 20, body.Offset)
	// With an explicit offset, Start returns the offset directly.
	require.Equal(t, 20, body.Start)
}

func Test_PaginateCheckDefaultsWhenNoQueries(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(paginateResponse{
			Page:   pageInfo.Page,
			Limit:  pageInfo.Limit,
			Offset: pageInfo.Offset,
			Start:  pageInfo.Start(),
			Sort:   pageInfo.Sort,
		})
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	require.Equal(t, fiber.StatusOK, resp.StatusCode)

	var body paginateResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
	require.Equal(t, 1, body.Page)
	require.Equal(t, 10, body.Limit)
	require.Equal(t, 0, body.Offset)
	require.Equal(t, 0, body.Start)
	require.Equal(t, []SortField{{Field: "id", Order: ASC}}, body.Sort)
}

func Test_PaginateCheckDefaultsWhenNoPage(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(paginateResponse{
			Page:  pageInfo.Page,
			Limit: pageInfo.Limit,
			Start: pageInfo.Start(),
		})
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?limit=20", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests

	var body paginateResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
	require.Equal(t, 1, body.Page)
	require.Equal(t, 20, body.Limit)
	require.Equal(t, 0, body.Start)
}

func Test_PaginateCheckDefaultsWhenNoLimit(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New())
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(paginateResponse{
			Page:  pageInfo.Page,
			Limit: pageInfo.Limit,
			Start: pageInfo.Start(),
		})
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?page=2", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests

	var body paginateResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
	require.Equal(t, 2, body.Page)
	require.Equal(t, 10, body.Limit)
	require.Equal(t, 10, body.Start)
}

func Test_PaginateConfigDefaultPageDefaultLimit(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		DefaultPage:  100,
		DefaultLimit: DefaultMaxLimit,
		DefaultSort:  "name",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(paginateResponse{
			Page:  pageInfo.Page,
			Limit: pageInfo.Limit,
			Start: pageInfo.Start(),
			Sort:  pageInfo.Sort,
		})
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests

	var body paginateResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
	require.Equal(t, 100, body.Page)
	require.Equal(t, DefaultMaxLimit, body.Limit)
	// Start = (100-1) * 100.
	require.Equal(t, 9900, body.Start)
	require.Equal(t, []SortField{{Field: "name", Order: ASC}}, body.Sort)
}

func Test_PaginateConfigPageKeyLimitKey(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		PageKey:     "site",
		LimitKey:    "size",
		DefaultSort: "id",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(paginateResponse{
			Page:  pageInfo.Page,
			Limit: pageInfo.Limit,
			Start: pageInfo.Start(),
			Sort:  pageInfo.Sort,
		})
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?site=2&size=5", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests

	var body paginateResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
	require.Equal(t, 2, body.Page)
	require.Equal(t, 5, body.Limit)
	require.Equal(t, 5, body.Start)
}

func Test_PaginateNegativeDefaultPageDefaultLimitValues(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		DefaultPage:  -1,
		DefaultLimit: -1,
		DefaultSort:  "id",
	}))
	app.Get("/", func(c fiber.Ctx) error {
		pageInfo, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(paginateResponse{
			Page:  pageInfo.Page,
			Limit: pageInfo.Limit,
			Start: pageInfo.Start(),
		})
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests

	var body paginateResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
	require.Equal(t, 1, body.Page)
	require.Equal(t, 10, body.Limit)
	require.Equal(t, 0, body.Start)
}

// Without the middleware registered, FromContext must report ok == false.
func Test_PaginateFromContextWithoutNew(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Get("/", func(c fiber.Ctx) error {
		_, ok := FromContext(c)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.JSON(nil)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests
	require.Equal(t, fiber.StatusBadRequest, resp.StatusCode)
}

func Test_PaginateNextSkip(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	app.Use(New(Config{
		Next: func(_ fiber.Ctx) bool {
			return true
		},
	}))
	app.Get("/", func(c fiber.Ctx) error
{ _, ok := FromContext(c) if !ok { return fiber.ErrBadRequest } return c.JSON(nil) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests require.Equal(t, fiber.StatusBadRequest, resp.StatusCode) } func Test_PaginateEdgeCases(t *testing.T) { t.Parallel() testCases := []struct { name string url string expectedPage int expectedLimit int }{ {"Negative page", "/?page=-1", 1, 10}, {"Page zero", "/?page=0", 1, 10}, {"Negative limit", "/?limit=-10", 1, 10}, {"Limit zero", "/?limit=0", 1, 10}, {"Limit exceeds max", "/?limit=200", 1, DefaultMaxLimit}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ DefaultSort: "id", DefaultLimit: 10, })) app.Get("/", func(c fiber.Ctx) error { pageInfo, ok := FromContext(c) if !ok { return fiber.ErrBadRequest } return c.JSON(pageInfo) }) resp, err := app.Test(httptest.NewRequest(http.MethodGet, tc.url, http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests require.Equal(t, 200, resp.StatusCode) var result PageInfo require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, tc.expectedPage, result.Page) require.Equal(t, tc.expectedLimit, result.Limit) }) } } func Test_PaginateWithMultipleSorting(t *testing.T) { t.Parallel() testCases := []struct { name string url string expectedSort []SortField }{ {"Default Sort", "/", []SortField{{Field: "id", Order: ASC}}}, {"Single Sort", "/?sort=name", []SortField{{Field: "name", Order: ASC}}}, {"Multiple Sort", "/?sort=name,-date", []SortField{{Field: "name", Order: ASC}, {Field: "date", Order: DESC}}}, {"Invalid Sort", "/?sort=invalid", []SortField{{Field: "id", Order: ASC}}}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ SortKey: "sort", 
DefaultSort: "id", AllowedSorts: []string{"id", "name", "date"}, })) app.Get("/", func(c fiber.Ctx) error { pageInfo, ok := FromContext(c) if !ok { return fiber.ErrBadRequest } return c.JSON(paginateResponse{ Sort: pageInfo.Sort, }) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, tc.url, http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests var result paginateResponse require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, tc.expectedSort, result.Sort) }) } } func Test_ParseSortQuery(t *testing.T) { t.Parallel() tests := []struct { name string query string allowedSorts []string defaultSort string expected []SortField }{ { "Empty query", "", []string{"id", "name", "date"}, "id", []SortField{{Field: "id", Order: ASC}}, }, { "Single allowed field", "name", []string{"id", "name", "date"}, "id", []SortField{{Field: "name", Order: ASC}}, }, { "Multiple fields with mixed order", "name,-date,id", []string{"id", "name", "date"}, "id", []SortField{ {Field: "name", Order: ASC}, {Field: "date", Order: DESC}, {Field: "id", Order: ASC}, }, }, { "Disallowed field", "email,name", []string{"id", "name", "date"}, "id", []SortField{{Field: "name", Order: ASC}}, }, { "All disallowed fields", "email,phone", []string{"id", "name", "date"}, "id", []SortField{{Field: "id", Order: ASC}}, }, { "Nil AllowedSorts allows all fields", "email,-phone", nil, "id", []SortField{ {Field: "email", Order: ASC}, {Field: "phone", Order: DESC}, }, }, { "Bare dash is skipped", "-", nil, "id", []SortField{{Field: "id", Order: ASC}}, }, { "Dash in comma list is skipped", "name,-,email", nil, "id", []SortField{ {Field: "name", Order: ASC}, {Field: "email", Order: ASC}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() result := parseSortQuery(tt.query, tt.allowedSorts, tt.defaultSort) require.Equal(t, tt.expected, result) }) } } // --- Cursor tests --- func 
Test_PaginateWithCursor(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ DefaultSort: "id", })) app.Get("/", func(c fiber.Ctx) error { pageInfo, ok := FromContext(c) if !ok { return fiber.ErrBadRequest } return c.JSON(cursorResponse{ Cursor: pageInfo.Cursor, Limit: pageInfo.Limit, Sort: pageInfo.Sort, }) }) cursorJSON := `{"id":42}` cursor := base64.RawURLEncoding.EncodeToString([]byte(cursorJSON)) resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/?cursor="+cursor+"&limit=20", http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests require.Equal(t, 200, resp.StatusCode) var result cursorResponse require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, cursor, result.Cursor) require.Equal(t, 20, result.Limit) } func Test_PaginateCursorPriorityOverPage(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { pageInfo, ok := FromContext(c) if !ok { return fiber.ErrBadRequest } return c.JSON(pageInfo) }) cursorJSON := `{"id":42}` cursor := base64.RawURLEncoding.EncodeToString([]byte(cursorJSON)) resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/?cursor="+cursor+"&page=5&limit=10", http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests var result PageInfo require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, cursor, result.Cursor) require.Equal(t, 0, result.Page) } func Test_PaginateEmptyCursorIsFirstPage(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { pageInfo, ok := FromContext(c) if !ok { return fiber.ErrBadRequest } return c.JSON(pageInfo) }) resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/?cursor=&limit=10", http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests require.Equal(t, 200, 
resp.StatusCode) var result PageInfo require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Empty(t, result.Cursor) } func Test_PaginateInvalidCursorReturns400(t *testing.T) { t.Parallel() testCases := []struct { name string cursor string }{ {"Invalid base64", "not-valid!!!"}, {"Valid base64 but invalid JSON", base64.RawURLEncoding.EncodeToString([]byte("not-json"))}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { pageInfo, _ := FromContext(c) return c.JSON(pageInfo) }) resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/?cursor="+tc.cursor, http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests require.Equal(t, 400, resp.StatusCode) }) } } func Test_PaginateCursorWithSort(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ SortKey: "sort", DefaultSort: "id", AllowedSorts: []string{"id", "name"}, })) app.Get("/", func(c fiber.Ctx) error { pageInfo, ok := FromContext(c) if !ok { return fiber.ErrBadRequest } return c.JSON(cursorResponse{ Cursor: pageInfo.Cursor, Sort: pageInfo.Sort, }) }) cursorJSON := `{"id":42}` cursor := base64.RawURLEncoding.EncodeToString([]byte(cursorJSON)) resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/?cursor="+cursor+"&sort=name,-id", http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests var result cursorResponse require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, []SortField{{Field: "name", Order: ASC}, {Field: "id", Order: DESC}}, result.Sort) } func Test_PaginateCursorWithCustomKey(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ CursorKey: "after", })) app.Get("/", func(c fiber.Ctx) error { pageInfo, ok := FromContext(c) if !ok { return fiber.ErrBadRequest } return c.JSON(cursorResponse{ Cursor: 
pageInfo.Cursor, Limit: pageInfo.Limit, }) }) cursorJSON := `{"id":1}` cursor := base64.RawURLEncoding.EncodeToString([]byte(cursorJSON)) resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/?after="+cursor, http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests require.Equal(t, 200, resp.StatusCode) var result cursorResponse require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, cursor, result.Cursor) } func Test_PaginateCursorWithParamAlias(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ CursorParam: "starting_after", })) app.Get("/", func(c fiber.Ctx) error { pageInfo, ok := FromContext(c) if !ok { return fiber.ErrBadRequest } return c.JSON(cursorResponse{ Cursor: pageInfo.Cursor, }) }) cursorJSON := `{"id":1}` cursor := base64.RawURLEncoding.EncodeToString([]byte(cursorJSON)) resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/?starting_after="+cursor, http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests require.Equal(t, 200, resp.StatusCode) var result cursorResponse require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, cursor, result.Cursor) } func Test_PaginateNoCursorFallsBackToPageMode(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ DefaultSort: "id", })) app.Get("/", func(c fiber.Ctx) error { pageInfo, ok := FromContext(c) if !ok { return fiber.ErrBadRequest } return c.JSON(paginateResponse{ Page: pageInfo.Page, Limit: pageInfo.Limit, Start: pageInfo.Start(), }) }) resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/?page=3&limit=15", http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests var result paginateResponse require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, 3, result.Page) require.Equal(t, 15, result.Limit) 
require.Equal(t, 30, result.Start) } func Test_NextPageURLWithKeys(t *testing.T) { t.Parallel() p := PageInfo{Page: 2, Limit: 10} result := p.NextPageURLWithKeys("https://example.com/users", "p", "per_page") require.Equal(t, "https://example.com/users?p=3&per_page=10", result) } func Test_PreviousPageURLWithKeys(t *testing.T) { t.Parallel() t.Run("has previous", func(t *testing.T) { t.Parallel() p := PageInfo{Page: 3, Limit: 15} result := p.PreviousPageURLWithKeys("https://example.com/items", "p", "size") require.Equal(t, "https://example.com/items?p=2&size=15", result) }) t.Run("first page returns empty", func(t *testing.T) { t.Parallel() p := PageInfo{Page: 1, Limit: 15} result := p.PreviousPageURLWithKeys("https://example.com/items", "p", "size") require.Empty(t, result) }) } func Test_NextCursorURLWithKeys(t *testing.T) { t.Parallel() t.Run("has more", func(t *testing.T) { t.Parallel() p := &PageInfo{Limit: 20} require.NoError(t, p.SetNextCursor(map[string]any{"id": float64(42)})) result := p.NextCursorURLWithKeys("https://example.com/users", "after", "per_page") expected := fmt.Sprintf("https://example.com/users?after=%s&per_page=20", p.NextCursor) require.Equal(t, expected, result) }) t.Run("no more", func(t *testing.T) { t.Parallel() p := &PageInfo{Limit: 20} result := p.NextCursorURLWithKeys("https://example.com/users", "after", "per_page") require.Empty(t, result) }) } func Test_PaginateCustomMaxLimit(t *testing.T) { t.Parallel() testCases := []struct { name string url string expectedLimit int }{ {"Limit within custom max", "/?limit=40", 40}, {"Limit exceeds custom max", "/?limit=200", 50}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ DefaultSort: "id", DefaultLimit: 10, MaxLimit: 50, })) app.Get("/", func(c fiber.Ctx) error { pageInfo, ok := FromContext(c) if !ok { return fiber.ErrBadRequest } return c.JSON(pageInfo) }) resp, err := app.Test(httptest.NewRequest(http.MethodGet, 
tc.url, http.NoBody)) require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck // close error not relevant in tests require.Equal(t, http.StatusOK, resp.StatusCode) var result PageInfo require.NoError(t, json.NewDecoder(resp.Body).Decode(&result)) require.Equal(t, tc.expectedLimit, result.Limit) }) } } func Test_ConfigDefaultMaxLimitNormalization(t *testing.T) { t.Parallel() cfg := configDefault(Config{MaxLimit: 0}) require.Equal(t, DefaultMaxLimit, cfg.MaxLimit) cfg2 := configDefault(Config{MaxLimit: 50}) require.Equal(t, 50, cfg2.MaxLimit) } // --- Benchmarks --- func Benchmark_PaginateMiddleware(b *testing.B) { app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { pageInfo, _ := FromContext(c) return c.JSON(pageInfo) }) b.ResetTimer() for i := 0; i < b.N; i++ { req := httptest.NewRequest(http.MethodGet, "/?page=2&limit=20&sort=name,-date", http.NoBody) resp, err := app.Test(req, fiber.TestConfig{Timeout: 0}) if err != nil { b.Fatal(err) } resp.Body.Close() //nolint:errcheck // close error not relevant in tests } } func Benchmark_PaginateMiddlewareWithCustomConfig(b *testing.B) { app := fiber.New() app.Use(New(Config{ PageKey: "p", LimitKey: "l", SortKey: "s", DefaultPage: 1, DefaultLimit: 30, DefaultSort: "id", AllowedSorts: []string{"id", "name", "date"}, })) app.Get("/", func(c fiber.Ctx) error { pageInfo, _ := FromContext(c) return c.JSON(pageInfo) }) b.ResetTimer() for i := 0; i < b.N; i++ { req := httptest.NewRequest(http.MethodGet, "/?p=3&l=25&s=name,-id", http.NoBody) resp, err := app.Test(req, fiber.TestConfig{Timeout: 0}) if err != nil { b.Fatal(err) } resp.Body.Close() //nolint:errcheck // close error not relevant in tests } } func Benchmark_PaginateCursorMiddleware(b *testing.B) { app := fiber.New() app.Use(New(Config{ SortKey: "sort", DefaultSort: "id", AllowedSorts: []string{"id", "name", "date"}, })) app.Get("/", func(c fiber.Ctx) error { pageInfo, _ := FromContext(c) return c.JSON(pageInfo) }) cursorJSON := 
`{"id":42,"created_at":"2026-01-01T00:00:00Z"}` cursor := base64.RawURLEncoding.EncodeToString([]byte(cursorJSON)) b.ResetTimer() for i := 0; i < b.N; i++ { req := httptest.NewRequest(http.MethodGet, "/?cursor="+cursor+"&limit=20&sort=name,-id", http.NoBody) resp, err := app.Test(req, fiber.TestConfig{Timeout: 0}) if err != nil { b.Fatal(err) } resp.Body.Close() //nolint:errcheck // close error not relevant in tests } } ================================================ FILE: middleware/pprof/config.go ================================================ package pprof import ( "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // Prefix defines a URL prefix added before "/debug/pprof". // Note that it should start with (but not end with) a slash. // Example: "/federated-fiber" // // Optional. Default: "" Prefix string } var ConfigDefault = Config{ Next: nil, Prefix: "", } func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.Next == nil { cfg.Next = ConfigDefault.Next } if cfg.Prefix == "" { cfg.Prefix = ConfigDefault.Prefix } return cfg } ================================================ FILE: middleware/pprof/pprof.go ================================================ package pprof import ( "net/http/pprof" "strings" "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp/fasthttpadaptor" "github.com/gofiber/fiber/v3" ) // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) 
// Set pprof adaptors var ( pprofIndex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index) pprofCmdline = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Cmdline) pprofProfile = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Profile) pprofSymbol = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Symbol) pprofTrace = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Trace) pprofAllocs = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("allocs").ServeHTTP) pprofBlock = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("block").ServeHTTP) pprofGoroutine = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("goroutine").ServeHTTP) pprofHeap = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("heap").ServeHTTP) pprofMutex = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("mutex").ServeHTTP) pprofThreadcreate = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler("threadcreate").ServeHTTP) ) // Construct actual prefix prefix := cfg.Prefix + "/debug/pprof" // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } path := c.Path() // We are only interested in /debug/pprof routes path, found := strings.CutPrefix(path, prefix) if !found { return c.Next() } // Switch on trimmed path against constant strings switch path { case "/": pprofIndex(c.RequestCtx()) case "/cmdline": pprofCmdline(c.RequestCtx()) case "/profile": pprofProfile(c.RequestCtx()) case "/symbol": pprofSymbol(c.RequestCtx()) case "/trace": pprofTrace(c.RequestCtx()) case "/allocs": pprofAllocs(c.RequestCtx()) case "/block": pprofBlock(c.RequestCtx()) case "/goroutine": pprofGoroutine(c.RequestCtx()) case "/heap": pprofHeap(c.RequestCtx()) case "/mutex": pprofMutex(c.RequestCtx()) case "/threadcreate": pprofThreadcreate(c.RequestCtx()) default: // pprof index only works with trailing slash if strings.HasSuffix(path, "/") { path = utils.TrimRight(path, '/') } else { path = prefix + "/" } return 
c.Redirect().To(path) } return nil } } ================================================ FILE: middleware/pprof/pprof_test.go ================================================ package pprof import ( "bytes" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" ) var testConfig = fiber.TestConfig{ Timeout: 5 * time.Second, FailOnTimeout: true, } func Test_Non_Pprof_Path(t *testing.T) { app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("escaped") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "escaped", string(b)) } func Test_Non_Pprof_Path_WithPrefix(t *testing.T) { app := fiber.New() app.Use(New(Config{Prefix: "/federated-fiber"})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("escaped") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "escaped", string(b)) } func Test_Pprof_Index(t *testing.T) { app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("escaped") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/pprof/", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.True(t, bytes.Contains(b, []byte("/debug/pprof/"))) } func Test_Pprof_Index_WithPrefix(t *testing.T) { app := fiber.New() app.Use(New(Config{Prefix: "/federated-fiber"})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("escaped") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, 
"/federated-fiber/debug/pprof/", http.NoBody)) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(b), "/debug/pprof/") } func Test_Pprof_Subs(t *testing.T) { app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("escaped") }) subs := []string{ "cmdline", "profile", "symbol", "trace", "allocs", "block", "goroutine", "heap", "mutex", "threadcreate", } for _, sub := range subs { t.Run(sub, func(t *testing.T) { target := "/debug/pprof/" + sub if sub == "profile" { target += "?seconds=1" } resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, http.NoBody), testConfig) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) }) } } func Test_Pprof_Subs_WithPrefix(t *testing.T) { app := fiber.New() app.Use(New(Config{Prefix: "/federated-fiber"})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("escaped") }) subs := []string{ "cmdline", "profile", "symbol", "trace", "allocs", "block", "goroutine", "heap", "mutex", "threadcreate", } for _, sub := range subs { t.Run(sub, func(t *testing.T) { target := "/federated-fiber/debug/pprof/" + sub if sub == "profile" { target += "?seconds=1" } resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, target, http.NoBody), testConfig) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) }) } } func Test_Pprof_Other(t *testing.T) { app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("escaped") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/pprof/303", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusSeeOther, resp.StatusCode) } func Test_Pprof_Other_WithPrefix(t *testing.T) { app := fiber.New() app.Use(New(Config{Prefix: "/federated-fiber"})) app.Get("/", func(c fiber.Ctx) error { return 
c.SendString("escaped") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/federated-fiber/debug/pprof/303", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusSeeOther, resp.StatusCode) } // go test -run Test_Pprof_Next func Test_Pprof_Next(t *testing.T) { app := fiber.New() app.Use(New(Config{ Next: func(_ fiber.Ctx) bool { return true }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/debug/pprof/", http.NoBody)) require.NoError(t, err) require.Equal(t, 404, resp.StatusCode) } // go test -run Test_Pprof_Next_WithPrefix func Test_Pprof_Next_WithPrefix(t *testing.T) { app := fiber.New() app.Use(New(Config{ Next: func(_ fiber.Ctx) bool { return true }, Prefix: "/federated-fiber", })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/federated-fiber/debug/pprof/", http.NoBody)) require.NoError(t, err) require.Equal(t, 404, resp.StatusCode) } ================================================ FILE: middleware/proxy/config.go ================================================ package proxy import ( "crypto/tls" "time" "github.com/gofiber/fiber/v3" "github.com/valyala/fasthttp" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // ModifyRequest allows you to alter the request // // Optional. Default: nil ModifyRequest fiber.Handler // ModifyResponse allows you to alter the response // // Optional. Default: nil ModifyResponse fiber.Handler // tls config for the http client. TLSConfig *tls.Config // Client is custom client when client config is complex. // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize, TLSConfig // and DialDualStack will not be used if the client are set. Client *fasthttp.LBClient // Servers defines a list of :// HTTP servers, // // which are used in a round-robin manner. 
// i.e.: "https://foobar.com, http://www.foobar.com" // // Required Servers []string // Timeout is the request timeout used when calling the proxy client // // Optional. Default: 1 second Timeout time.Duration // Per-connection buffer size for requests' reading. // This also limits the maximum header size. // Increase this buffer if your clients send multi-KB RequestURIs // and/or multi-KB headers (for example, BIG cookies). ReadBufferSize int // Per-connection buffer size for responses' writing. WriteBufferSize int // KeepConnectionHeader keeps the "Connection" header when set to true. // // Optional. Default: false KeepConnectionHeader bool // Attempt to connect to both ipv4 and ipv6 host addresses if set to true. // // By default client connects only to ipv4 addresses, since unfortunately ipv6 // remains broken in many networks worldwide :) // // Optional. Default: false DialDualStack bool } // ConfigDefault is the default config var ConfigDefault = Config{ Next: nil, ModifyRequest: nil, ModifyResponse: nil, Timeout: fasthttp.DefaultLBClientTimeout, KeepConnectionHeader: false, } // configDefault function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.Timeout <= 0 { cfg.Timeout = ConfigDefault.Timeout } // Set default values if len(cfg.Servers) == 0 && cfg.Client == nil { panic("Servers cannot be empty") } return cfg } ================================================ FILE: middleware/proxy/proxy.go ================================================ package proxy import ( "bytes" "errors" "net/url" "strings" "sync" "time" "github.com/gofiber/utils/v2" "github.com/gofiber/fiber/v3" "github.com/valyala/fasthttp" ) // Balancer creates a load balancer among multiple upstream servers func Balancer(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) 
// Load balanced client lbc := &fasthttp.LBClient{} // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TLSConfig // will not be used if the client are set. if cfg.Client == nil { // Set timeout lbc.Timeout = cfg.Timeout // Scheme must be provided, falls back to http for _, server := range cfg.Servers { if !strings.HasPrefix(server, "http") { server = "http://" + server } u, err := url.Parse(server) if err != nil { panic(err) } client := &fasthttp.HostClient{ NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, Addr: u.Host, ReadBufferSize: cfg.ReadBufferSize, WriteBufferSize: cfg.WriteBufferSize, TLSConfig: cfg.TLSConfig, DialDualStack: cfg.DialDualStack, } lbc.Clients = append(lbc.Clients, client) } } else { // Set custom client lbc = cfg.Client } // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Set request and response req := c.Request() res := c.Response() if !cfg.KeepConnectionHeader { // Don't proxy "Connection" header req.Header.Del(fiber.HeaderConnection) } // Modify request if cfg.ModifyRequest != nil { if err := cfg.ModifyRequest(c); err != nil { return err } } if c.App().Config().Immutable { req.SetRequestURIBytes(req.RequestURI()) } else { req.SetRequestURI(utils.UnsafeString(req.RequestURI())) } // Forward request if err := lbc.Do(req, res); err != nil { return err } if !cfg.KeepConnectionHeader { // Don't proxy "Connection" header res.Header.Del(fiber.HeaderConnection) } // Modify response if cfg.ModifyResponse != nil { if err := cfg.ModifyResponse(c); err != nil { return err } } // Return nil to end proxying if no error return nil } } var client = &fasthttp.Client{ NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, } var ( errNilProxyClientOverride = errors.New("proxy: nil client override passed to Do/Forward") errNilGlobalProxyClient = errors.New("proxy: global client is nil, set a non-nil client with 
proxy.WithClient") ) var lock sync.RWMutex // WithClient sets the global proxy client. // This function should be called before Do and Forward. func WithClient(cli *fasthttp.Client) { if cli == nil { panic("proxy: WithClient requires a non-nil *fasthttp.Client") } lock.Lock() defer lock.Unlock() client = cli } // Forward performs the given http request and fills the given http response. // This method will return a fiber.Handler func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler { return func(c fiber.Ctx) error { return Do(c, addr, clients...) } } // Do performs the given http request and fills the given http response. // This method can be used within a fiber.Handler func Do(c fiber.Ctx, addr string, clients ...*fasthttp.Client) error { return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { return cli.Do(req, resp) }, clients...) } // DoRedirects performs the given http request and fills the given http response, following up to maxRedirectsCount redirects. // When the redirect count exceeds maxRedirectsCount, ErrTooManyRedirects is returned. // This method can be used within a fiber.Handler func DoRedirects(c fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error { return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { return cli.DoRedirects(req, resp, maxRedirectsCount) }, clients...) } // DoDeadline performs the given request and waits for response until the given deadline. // This method can be used within a fiber.Handler func DoDeadline(c fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error { return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { return cli.DoDeadline(req, resp, deadline) }, clients...) } // DoTimeout performs the given request and waits for response during the given timeout duration. 
// This method can be used within a fiber.Handler func DoTimeout(c fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error { return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { return cli.DoTimeout(req, resp, timeout) }, clients...) } func doAction( c fiber.Ctx, addr string, action func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error, clients ...*fasthttp.Client, ) error { lock.RLock() globalClient := client lock.RUnlock() cli, err := selectClient(globalClient, clients...) if err != nil { return err } req := c.Request() res := c.Response() originalURL := utils.CopyString(c.OriginalURL()) defer req.SetRequestURI(originalURL) copiedURL := utils.CopyString(addr) req.SetRequestURI(copiedURL) // NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https. // Reference: https://github.com/gofiber/fiber/issues/1762 if scheme := getScheme(utils.UnsafeBytes(copiedURL)); len(scheme) > 0 { req.URI().SetSchemeBytes(scheme) } req.Header.Del(fiber.HeaderConnection) if err := action(cli, req, res); err != nil { return err } res.Header.Del(fiber.HeaderConnection) return nil } func selectClient(globalClient *fasthttp.Client, clients ...*fasthttp.Client) (*fasthttp.Client, error) { if len(clients) != 0 { if clients[0] == nil { return nil, errNilProxyClientOverride } return clients[0], nil } if globalClient == nil { return nil, errNilGlobalProxyClient } return globalClient, nil } func getScheme(uri []byte) []byte { i := bytes.IndexByte(uri, '/') if i < 1 || uri[i-1] != ':' || i == len(uri)-1 || uri[i+1] != '/' { return nil } return uri[:i-1] } // DomainForward performs an http request based on the given domain and populates the given http response. 
// This method will return a fiber.Handler func DomainForward(hostname, addr string, clients ...*fasthttp.Client) fiber.Handler { return func(c fiber.Ctx) error { host := utils.UnsafeString(c.Request().Host()) if host == hostname { return Do(c, addr+c.OriginalURL(), clients...) } return nil } } type roundrobin struct { pool []string current int sync.Mutex } // this method will return a string of addr server from list server. func (r *roundrobin) get() string { r.Lock() defer r.Unlock() if r.current >= len(r.pool) { r.current %= len(r.pool) } result := r.pool[r.current] r.current++ return result } // BalancerForward Forward performs the given http request with round robin algorithm to server and fills the given http response. // This method will return a fiber.Handler func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler { r := &roundrobin{ current: 0, pool: servers, } return func(c fiber.Ctx) error { server := r.get() if !strings.HasPrefix(server, "http") { server = "http://" + server } c.Request().Header.Add("X-Real-IP", c.IP()) return Do(c, server+c.OriginalURL(), clients...) 
} } ================================================ FILE: middleware/proxy/proxy_test.go ================================================ package proxy import ( "crypto/tls" "io" "net" "net/http" "net/http/httptest" "os" "strings" "testing" "time" "github.com/gofiber/fiber/v3" clientpkg "github.com/gofiber/fiber/v3/client" "github.com/stretchr/testify/require" "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/valyala/fasthttp" ) func startServer(app *fiber.App, ln net.Listener) { go func() { err := app.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, }) if err != nil { panic(err) } }() } func createProxyTestServer(t *testing.T, handler fiber.Handler, network, address string) (target *fiber.App, addr string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned target app and address for readability t.Helper() target = fiber.New() target.Get("/", handler) ln, err := net.Listen(network, address) require.NoError(t, err) addr = ln.Addr().String() startServer(target, ln) return target, addr } func createProxyTestServerIPv4(t *testing.T, handler fiber.Handler) (target *fiber.App, addr string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned target app and address for readability t.Helper() return createProxyTestServer(t, handler, fiber.NetworkTCP4, "127.0.0.1:0") } func createProxyTestServerIPv6(t *testing.T, handler fiber.Handler) (target *fiber.App, addr string) { //nolint:nonamedreturns // gocritic unnamedResult prefers naming returned target app and address for readability t.Helper() return createProxyTestServer(t, handler, fiber.NetworkTCP6, "[::1]:0") } func createRedirectServer(t *testing.T) string { t.Helper() app := fiber.New() var addr string app.Get("/", func(c fiber.Ctx) error { c.Location("http://" + addr + "/final") return c.Status(fiber.StatusMovedPermanently).SendString("redirect") }) app.Get("/final", func(c fiber.Ctx) error { return c.SendString("final") }) ln, err := 
net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) t.Cleanup(func() { ln.Close() //nolint:errcheck // It is fine to ignore the error here }) addr = ln.Addr().String() startServer(app, ln) return addr } // go test -run Test_Proxy_Empty_Host func Test_Proxy_Empty_Upstream_Servers(t *testing.T) { t.Parallel() defer func() { if r := recover(); r != nil { if r != "Servers cannot be empty" { panic(r) } } }() app := fiber.New() app.Use(Balancer(Config{Servers: []string{}})) } // go test -run Test_Proxy_Empty_Config func Test_Proxy_Empty_Config(t *testing.T) { t.Parallel() defer func() { if r := recover(); r != nil { if r != "Servers cannot be empty" { panic(r) } } }() app := fiber.New() app.Use(Balancer(Config{})) } // go test -run Test_Proxy_Next func Test_Proxy_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Use(Balancer(Config{ Servers: []string{"127.0.0.1"}, Next: func(_ fiber.Ctx) bool { return true }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } // go test -run Test_Proxy func Test_Proxy(t *testing.T) { t.Parallel() target, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) app := fiber.New() app.Use(Balancer(Config{Servers: []string{addr}})) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Host = addr resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } // go test -run Test_Proxy_Balancer_WithTlsConfig func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { t.Parallel() serverTLSConf, _, err := tlstest.GetTLSConfigs() require.NoError(t, err) ln, err := 
net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) ln = tls.NewListener(ln, serverTLSConf) app := fiber.New() app.Get("/tlsbalancer", func(c fiber.Ctx) error { return c.SendString("tls balancer") }) addr := ln.Addr().String() clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine // disable certificate verification in Balancer app.Use(Balancer(Config{ Servers: []string{addr}, TLSConfig: clientTLSConf, })) startServer(app, ln) client := clientpkg.New() client.SetTLSConfig(clientTLSConf) resp, err := client.Get("https://" + addr + "/tlsbalancer") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "tls balancer", string(resp.Body())) resp.Close() } // go test -run Test_Proxy_Balancer_IPv6_Upstream func Test_Proxy_Balancer_IPv6_Upstream(t *testing.T) { t.Parallel() target, addr := createProxyTestServerIPv6(t, func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) app := fiber.New() app.Use(Balancer(Config{Servers: []string{addr}})) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Host = addr resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) } // go test -run Test_Proxy_Balancer_IPv6_Upstream func Test_Proxy_Balancer_IPv6_Upstream_With_DialDualStack(t *testing.T) { t.Parallel() target, addr := createProxyTestServerIPv6(t, func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) app := fiber.New() 
app.Use(Balancer(Config{ Servers: []string{addr}, DialDualStack: true, })) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Host = addr resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } // go test -run Test_Proxy_Balancer_IPv6_Upstream func Test_Proxy_Balancer_IPv4_Upstream_With_DialDualStack(t *testing.T) { t.Parallel() target, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) app := fiber.New() app.Use(Balancer(Config{ Servers: []string{addr}, DialDualStack: true, })) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Host = addr resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } // go test -run Test_Proxy_Forward_WithTlsConfig_To_Http func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) { t.Parallel() _, targetAddr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendString("hello from target") }) proxyServerTLSConf, _, err := tlstest.GetTLSConfigs() require.NoError(t, err) proxyServerLn, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) proxyServerLn = tls.NewListener(proxyServerLn, proxyServerTLSConf) proxyAddr := proxyServerLn.Addr().String() app := fiber.New() app.Use(Forward("http://" + targetAddr)) startServer(app, proxyServerLn) client := clientpkg.New() client.SetTimeout(5 * time.Second) client.TLSConfig().InsecureSkipVerify = true resp, err := client.Get("https://" + proxyAddr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "hello from target", string(resp.Body())) resp.Close() } // go test -run Test_Proxy_Forward func Test_Proxy_Forward(t 
*testing.T) { t.Parallel() app := fiber.New() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendString("forwarded") }) app.Use(Forward("http://" + addr)) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "forwarded", string(b)) } // go test -run Test_Proxy_Forward_WithClient_TLSConfig func Test_Proxy_Forward_WithClient_TLSConfig(t *testing.T) { t.Parallel() serverTLSConf, _, err := tlstest.GetTLSConfigs() require.NoError(t, err) ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) ln = tls.NewListener(ln, serverTLSConf) app := fiber.New() app.Get("/tlsfwd", func(c fiber.Ctx) error { return c.SendString("tls forward") }) addr := ln.Addr().String() clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine // disable certificate verification WithClient(&fasthttp.Client{ TLSConfig: clientTLSConf, }) app.Use(Forward("https://" + addr + "/tlsfwd")) startServer(app, ln) client := clientpkg.New() client.SetTLSConfig(clientTLSConf) resp, err := client.Get("https://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "tls forward", string(resp.Body())) resp.Close() } // go test -run Test_Proxy_Modify_Response func Test_Proxy_Modify_Response(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.Status(500).SendString("not modified") }) app := fiber.New() app.Use(Balancer(Config{ Servers: []string{addr}, ModifyResponse: func(c fiber.Ctx) error { c.Response().SetStatusCode(fiber.StatusOK) return c.SendString("modified response") }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) b, err := 
io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "modified response", string(b)) } // go test -run Test_Proxy_Modify_Request func Test_Proxy_Modify_Request(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { b := c.Request().Body() return c.SendString(string(b)) }) app := fiber.New() app.Use(Balancer(Config{ Servers: []string{addr}, ModifyRequest: func(c fiber.Ctx) error { c.Request().SetBody([]byte("modified request")) return nil }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "modified request", string(b)) } // go test -run Test_Proxy_Timeout_Slow_Server func Test_Proxy_Timeout_Slow_Server(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { time.Sleep(300 * time.Millisecond) return c.SendString("fiber is awesome") }) app := fiber.New() app.Use(Balancer(Config{ Servers: []string{addr}, Timeout: 600 * time.Millisecond, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "fiber is awesome", string(b)) } // go test -run Test_Proxy_With_Timeout func Test_Proxy_With_Timeout(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { time.Sleep(1 * time.Second) return c.SendString("fiber is awesome") }) app := fiber.New() app.Use(Balancer(Config{ Servers: []string{addr}, Timeout: 100 * time.Millisecond, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) require.Equal(t, 
fiber.StatusInternalServerError, resp.StatusCode) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "timeout", string(b)) } // go test -run Test_Proxy_Buffer_Size_Response func Test_Proxy_Buffer_Size_Response(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { long := strings.Join(make([]string, 5000), "-") c.Set("Very-Long-Header", long) return c.SendString("ok") }) app := fiber.New() app.Use(Balancer(Config{Servers: []string{addr}, KeepConnectionHeader: true})) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) app = fiber.New() app.Use(Balancer(Config{ Servers: []string{addr}, ReadBufferSize: 1024 * 8, })) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } // go test -race -run Test_Proxy_Do_RestoreOriginalURL func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendString("proxied") }) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { return Do(c, "http://"+addr) }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err1) require.Equal(t, "/test", resp.Request.URL.String()) require.Equal(t, fiber.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "proxied", string(body)) } // go test -race -run Test_Proxy_Do_WithRealURL func Test_Proxy_Do_WithRealURL(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendString("real url") }) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { return Do(c, "http://"+addr) }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{ Timeout: 2 
* time.Second, FailOnTimeout: true, }) require.NoError(t, err1) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "/test", resp.Request.URL.String()) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "real url", string(body)) } // go test -race -run Test_Proxy_Do_WithRedirect func Test_Proxy_Do_WithRedirect(t *testing.T) { t.Parallel() addr := createRedirectServer(t) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { return Do(c, "http://"+addr) }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "redirect", string(body)) require.Equal(t, fiber.StatusMovedPermanently, resp.StatusCode) } // go test -race -run Test_Proxy_DoRedirects_RestoreOriginalURL func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) { t.Parallel() addr := createRedirectServer(t) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { return DoRedirects(c, "http://"+addr, 1) }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "final", string(body)) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "/test", resp.Request.URL.String()) } // go test -race -run Test_Proxy_DoRedirects_TooManyRedirects func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) { t.Parallel() addr := createRedirectServer(t) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { return DoRedirects(c, "http://"+addr, 0) }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err1) body, err := io.ReadAll(resp.Body) 
require.NoError(t, err) require.Equal(t, "too many redirects detected when doing the request", string(body)) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) require.Equal(t, "/test", resp.Request.URL.String()) } // go test -race -run Test_Proxy_DoTimeout_RestoreOriginalURL func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendString("proxied") }) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { return DoTimeout(c, "http://"+addr, time.Second) }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "proxied", string(body)) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "/test", resp.Request.URL.String()) } // go test -race -run Test_Proxy_DoTimeout_Timeout func Test_Proxy_DoTimeout_Timeout(t *testing.T) { _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { time.Sleep(time.Second * 5) return c.SendString("proxied") }) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { return DoTimeout(c, "http://"+addr, time.Second) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "timeout", string(body)) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) require.Equal(t, "/test", resp.Request.URL.String()) } // go test -race -run Test_Proxy_DoDeadline_RestoreOriginalURL func Test_Proxy_DoDeadline_RestoreOriginalURL(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendString("proxied") }) app := fiber.New() app.Get("/test", func(c fiber.Ctx) 
error { return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second)) }) resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "proxied", string(body)) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "/test", resp.Request.URL.String()) } // go test -race -run Test_Proxy_DoDeadline_PastDeadline func Test_Proxy_DoDeadline_PastDeadline(t *testing.T) { _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { time.Sleep(time.Second * 5) return c.SendString("proxied") }) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { return DoDeadline(c, "http://"+addr, time.Now().Add(2*time.Second)) }) _, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{ Timeout: 1 * time.Second, FailOnTimeout: true, }) require.Equal(t, os.ErrDeadlineExceeded, err1) } // go test -race -run Test_Proxy_Do_HTTP_Prefix_URL func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendString("hello world") }) app := fiber.New() app.Get("/*", func(c fiber.Ctx) error { path := c.OriginalURL() url := strings.TrimPrefix(path, "/") require.Equal(t, "http://"+addr, url) if err := Do(c, url); err != nil { return err } c.Response().Header.Del(fiber.HeaderServer) return nil }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/http://"+addr, http.NoBody)) require.NoError(t, err) s, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "hello world", string(s)) } // go test -race -run Test_Proxy_Forward_Global_Client func Test_Proxy_Forward_Global_Client(t *testing.T) { t.Parallel() ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) WithClient(&fasthttp.Client{ NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, }) app := fiber.New() 
app.Get("/test_global_client", func(c fiber.Ctx) error { return c.SendString("test_global_client") }) addr := ln.Addr().String() app.Use(Forward("http://" + addr + "/test_global_client")) startServer(app, ln) client := clientpkg.New() resp, err := client.Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "test_global_client", string(resp.Body())) resp.Close() } // go test -race -run Test_Proxy_Forward_Local_Client func Test_Proxy_Forward_Local_Client(t *testing.T) { t.Parallel() ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) app := fiber.New() app.Get("/test_local_client", func(c fiber.Ctx) error { return c.SendString("test_local_client") }) addr := ln.Addr().String() app.Use(Forward("http://"+addr+"/test_local_client", &fasthttp.Client{ NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, Dial: fasthttp.Dial, })) startServer(app, ln) client := clientpkg.New() resp, err := client.Get("http://" + addr) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "test_local_client", string(resp.Body())) resp.Close() } // go test -run Test_Proxy_WithClient_Nil_Panics func Test_Proxy_WithClient_Nil_Panics(t *testing.T) { t.Parallel() require.PanicsWithValue(t, "proxy: WithClient requires a non-nil *fasthttp.Client", func() { WithClient(nil) }) } // go test -run Test_Proxy_Do_NilClientOverride func Test_Proxy_Do_NilClientOverride(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendString("proxied") }) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { return Do(c, "http://"+addr, nil) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, errNilProxyClientOverride.Error(), 
string(body)) } // go test -run Test_Proxy_Do_NonNilClientOverride func Test_Proxy_Do_NonNilClientOverride(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendString("proxied") }) app := fiber.New() app.Get("/test", func(c fiber.Ctx) error { return Do(c, "http://"+addr, &fasthttp.Client{ NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, }) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "proxied", string(body)) } // go test -run Test_Proxy_SelectClient_NilGlobal func Test_Proxy_SelectClient_NilGlobal(t *testing.T) { t.Parallel() selectedClient, err := selectClient(nil) require.ErrorIs(t, err, errNilGlobalProxyClient) require.Nil(t, selectedClient) } // go test -run Test_Proxy_NilClientOverride_AcrossHelpers func Test_Proxy_NilClientOverride_AcrossHelpers(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendString("proxied") }) tests := map[string]func(c fiber.Ctx) error{ "DoRedirects": func(c fiber.Ctx) error { return DoRedirects(c, "http://"+addr, 1, nil) }, "DoDeadline": func(c fiber.Ctx) error { return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second), nil) }, "DoTimeout": func(c fiber.Ctx) error { return DoTimeout(c, "http://"+addr, time.Second, nil) }, } for name, run := range tests { t.Run(name, func(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/test", run) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, errNilProxyClientOverride.Error(), string(body)) }) } } // go test 
-run Test_ProxyBalancer_Custom_Client func Test_ProxyBalancer_Custom_Client(t *testing.T) { t.Parallel() target, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) app := fiber.New() app.Use(Balancer(Config{Client: &fasthttp.LBClient{ Clients: []fasthttp.BalancingClient{ &fasthttp.HostClient{ NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, Addr: addr, }, }, Timeout: time.Second, }})) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Host = addr resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } // go test -run Test_Proxy_Domain_Forward_Local func Test_Proxy_Domain_Forward_Local(t *testing.T) { t.Parallel() ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) app := fiber.New() // target server ln1, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) app1 := fiber.New() app1.Get("/test", func(c fiber.Ctx) error { return c.SendString("test_local_client:" + c.Query("query_test")) }) proxyAddr := ln.Addr().String() targetAddr := ln1.Addr().String() localDomain := strings.Replace(proxyAddr, "127.0.0.1", "localhost", 1) app.Use(DomainForward(localDomain, "http://"+targetAddr, &fasthttp.Client{ NoDefaultUserAgentHeader: true, DisablePathNormalizing: true, Dial: fasthttp.Dial, })) startServer(app, ln) startServer(app1, ln1) client := clientpkg.New() resp, err := client.Get("http://" + localDomain + "/test?query_test=true") require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode()) require.Equal(t, "test_local_client:true", string(resp.Body())) resp.Close() } // go test -run Test_Proxy_Balancer_Forward_Local func 
Test_Proxy_Balancer_Forward_Local(t *testing.T) { t.Parallel() app := fiber.New() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendString("forwarded") }) app.Use(BalancerForward([]string{addr})) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) b, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "forwarded", string(b)) } func Test_Proxy_Immutable(t *testing.T) { t.Parallel() target, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }) resp, err := target.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody), fiber.TestConfig{ Timeout: 2 * time.Second, FailOnTimeout: true, }) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) app := fiber.New(fiber.Config{Immutable: true}) app.Use(Balancer(Config{Servers: []string{addr}})) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Host = addr resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } func Test_Proxy_KeepConnectionHeader(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { c.Set(fiber.HeaderConnection, "backend") return c.SendString("ok") }) app := fiber.New() app.Use(Balancer(Config{Servers: []string{addr}, KeepConnectionHeader: true})) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Host = addr req.Header.Set(fiber.HeaderConnection, "keep-alive") resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, "backend", resp.Header.Get(fiber.HeaderConnection)) } func Test_Proxy_DropConnectionHeader(t *testing.T) { t.Parallel() _, addr := createProxyTestServerIPv4(t, func(c fiber.Ctx) error { c.Set(fiber.HeaderConnection, "backend") return c.SendString("ok") }) app := fiber.New() app.Use(Balancer(Config{Servers: []string{addr}})) req := 
httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Host = addr req.Header.Set(fiber.HeaderConnection, "keep-alive") resp, err := app.Test(req) require.NoError(t, err) require.Empty(t, resp.Header.Get(fiber.HeaderConnection)) } ================================================ FILE: middleware/recover/config.go ================================================ package recover //nolint:predeclared // TODO: Rename to some non-builtin import ( "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // PanicHandler defines a function to customize the error produced from a recovered panic/result. // // Optional. Default: DefaultPanicHandler PanicHandler func(c fiber.Ctx, r any) error // StackTraceHandler defines a function to handle stack trace // // Optional. Default: defaultStackTraceHandler StackTraceHandler func(c fiber.Ctx, e any) // EnableStackTrace enables handling stack trace // // Optional. 
Default: false EnableStackTrace bool } // ConfigDefault is the default config var ConfigDefault = Config{ Next: nil, EnableStackTrace: false, StackTraceHandler: defaultStackTraceHandler, PanicHandler: DefaultPanicHandler, } // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] if cfg.EnableStackTrace && cfg.StackTraceHandler == nil { cfg.StackTraceHandler = defaultStackTraceHandler } if cfg.PanicHandler == nil { cfg.PanicHandler = DefaultPanicHandler } return cfg } ================================================ FILE: middleware/recover/recover.go ================================================ package recover //nolint:predeclared // TODO: Rename to some non-builtin import ( "fmt" "os" "runtime/debug" "github.com/gofiber/fiber/v3" ) func defaultStackTraceHandler(_ fiber.Ctx, e any) { fmt.Fprintf(os.Stderr, "panic: %v\n\n%s\n", e, debug.Stack()) } // DefaultPanicHandler returns r directly if it's an error, and creates a new one with the %v verb otherwise. func DefaultPanicHandler(_ fiber.Ctx, r any) error { if err, ok := r.(error); ok { return err } return fmt.Errorf("%v", r) } // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) 
// Return new handler return func(c fiber.Ctx) (err error) { //nolint:nonamedreturns // Uses recover() to overwrite the error // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Catch panics defer func() { if r := recover(); r != nil { if cfg.EnableStackTrace { cfg.StackTraceHandler(c, r) } // Set error that will call the global error handler err = cfg.PanicHandler(c, r) } }() // Return err if exist, else move to next handler return c.Next() } } ================================================ FILE: middleware/recover/recover_test.go ================================================ package recover //nolint:predeclared // TODO: Rename to some non-builtin import ( "errors" "fmt" "net/http" "net/http/httptest" "testing" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" ) // go test -run Test_Recover func Test_Recover(t *testing.T) { t.Parallel() tests := []struct { name string panicVal any panicHandler func(c fiber.Ctx, r any) error errorMsg string }{ { name: "non-error panic will be handled by default", panicVal: "Hi, I'm an error!", panicHandler: nil, errorMsg: "Hi, I'm an error!", }, { name: "error panic will be handled by default", panicVal: errors.New("hi, I'm an error object"), panicHandler: nil, errorMsg: "hi, I'm an error object", }, { name: "non-error panic will be handled", panicVal: "Hi, I'm an error!", panicHandler: func(c fiber.Ctx, r any) error { return fmt.Errorf("[RECOVERED]: %w", DefaultPanicHandler(c, r)) }, errorMsg: "[RECOVERED]: Hi, I'm an error!", }, { name: "error panic will be handled", panicVal: errors.New("hi, I'm an error object"), panicHandler: func(c fiber.Ctx, r any) error { return fmt.Errorf("[RECOVERED]: %w", DefaultPanicHandler(c, r)) }, errorMsg: "[RECOVERED]: hi, I'm an error object", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() app := fiber.New(fiber.Config{ ErrorHandler: func(c fiber.Ctx, err error) error { require.Equal(t, 
tc.errorMsg, err.Error()) return c.SendStatus(fiber.StatusTeapot) }, }) app.Use(New(Config{PanicHandler: tc.panicHandler})) app.Get("/panic", func(_ fiber.Ctx) error { panic(tc.panicVal) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) }) } } // go test -run Test_Recover_Next func Test_Recover_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Next: func(_ fiber.Ctx) bool { return true }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } func Test_Recover_EnableStackTrace(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ EnableStackTrace: true, })) app.Get("/panic", func(_ fiber.Ctx) error { panic("Hi, I'm an error!") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/panic", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) } ================================================ FILE: middleware/redirect/config.go ================================================ package redirect import ( "regexp" "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Filter defines a function to skip middleware. // Optional. Default: nil Next func(fiber.Ctx) bool // Rules defines the URL path rewrite rules. The values captured in asterisk can be // retrieved by index e.g. $1, $2 and so on. // Required. Example: // "/old": "/new", // "/api/*": "/$1", // "/js/*": "/public/javascript/$1", // "/users/*/orders/*": "/user/$1/order/$2", Rules map[string]string rulesRegex map[*regexp.Regexp]string // The status code when redirecting // This is ignored if Redirect is disabled // Optional. 
Default: 302 Temporary Redirect StatusCode int } // ConfigDefault is the default config var ConfigDefault = Config{ StatusCode: fiber.StatusFound, } // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.StatusCode == 0 { cfg.StatusCode = ConfigDefault.StatusCode } return cfg } ================================================ FILE: middleware/redirect/redirect.go ================================================ package redirect import ( "regexp" "strconv" "strings" "github.com/gofiber/fiber/v3" "github.com/gofiber/utils/v2" ) // New creates a new middleware handler func New(config ...Config) fiber.Handler { cfg := configDefault(config...) // Initialize cfg.rulesRegex = map[*regexp.Regexp]string{} for k, v := range cfg.Rules { k = strings.ReplaceAll(k, "*", "(.*)") k += "$" cfg.rulesRegex[regexp.MustCompile(k)] = v } // Middleware function return func(c fiber.Ctx) error { // Next request to skip middleware if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Rewrite for k, v := range cfg.rulesRegex { replacer := captureTokens(k, c.Path()) if replacer != nil { queryString := utils.UnsafeString(c.RequestCtx().QueryArgs().QueryString()) if queryString != "" { queryString = "?" + queryString } return c.Redirect().Status(cfg.StatusCode).To(replacer.Replace(v) + queryString) } } return c.Next() } } // https://github.com/labstack/echo/blob/master/middleware/rewrite.go func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { if len(input) > 1 { input = utils.TrimRight(input, '/') } groups := pattern.FindAllStringSubmatch(input, -1) if groups == nil { return nil } values := groups[0][1:] replace := make([]string, 2*len(values)) for i, v := range values { j := 2 * i replace[j] = "$" + strconv.Itoa(i+1) replace[j+1] = v } return strings.NewReplacer(replace...) 
} ================================================ FILE: middleware/redirect/redirect_test.go ================================================ package redirect import ( "context" "net/http" "testing" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" ) func Test_Redirect(t *testing.T) { app := *fiber.New() app.Use(New(Config{ Rules: map[string]string{ "/default": "google.com", }, StatusCode: fiber.StatusMovedPermanently, })) app.Use(New(Config{ Rules: map[string]string{ "/default/*": "fiber.wiki", }, StatusCode: fiber.StatusTemporaryRedirect, })) app.Use(New(Config{ Rules: map[string]string{ "/redirect/*": "$1", }, StatusCode: fiber.StatusSeeOther, })) app.Use(New(Config{ Rules: map[string]string{ "/pattern/*": "golang.org", }, StatusCode: fiber.StatusFound, })) app.Use(New(Config{ Rules: map[string]string{ "/": "/swagger", }, StatusCode: fiber.StatusMovedPermanently, })) app.Use(New(Config{ Rules: map[string]string{ "/params": "/with_params", }, StatusCode: fiber.StatusMovedPermanently, })) app.Get("/api/*", func(c fiber.Ctx) error { return c.SendString("API") }) app.Get("/new", func(c fiber.Ctx) error { return c.SendString("Hello, World!") }) tests := []struct { name string url string redirectTo string statusCode int }{ { name: "should be returns status StatusFound without a wildcard", url: "/default", redirectTo: "google.com", statusCode: fiber.StatusMovedPermanently, }, { name: "should be returns status StatusTemporaryRedirect using wildcard", url: "/default/xyz", redirectTo: "fiber.wiki", statusCode: fiber.StatusTemporaryRedirect, }, { name: "should be returns status StatusSeeOther without set redirectTo to use the default", url: "/redirect/github.com/gofiber/redirect", redirectTo: "github.com/gofiber/redirect", statusCode: fiber.StatusSeeOther, }, { name: "should return the status code default", url: "/pattern/xyz", redirectTo: "golang.org", statusCode: fiber.StatusFound, }, { name: "access URL without rule", url: "/new", statusCode: 
fiber.StatusOK, }, { name: "redirect to swagger route", url: "/", redirectTo: "/swagger", statusCode: fiber.StatusMovedPermanently, }, { name: "no redirect to swagger route", url: "/api/", statusCode: fiber.StatusOK, }, { name: "no redirect to swagger route #2", url: "/api/test", statusCode: fiber.StatusOK, }, { name: "redirect with query params", url: "/params?query=abc", redirectTo: "/with_params?query=abc", statusCode: fiber.StatusMovedPermanently, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, tt.url, http.NoBody) require.NoError(t, err) req.Header.Set("Location", "github.com/gofiber/redirect") resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, tt.statusCode, resp.StatusCode) require.Equal(t, tt.redirectTo, resp.Header.Get("Location")) }) } } func Test_Next(t *testing.T) { // Case 1 : Next function always returns true app := *fiber.New() app.Use(New(Config{ Next: func(fiber.Ctx) bool { return true }, Rules: map[string]string{ "/default": "google.com", }, StatusCode: fiber.StatusMovedPermanently, })) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) // Case 2 : Next function always returns false app = *fiber.New() app.Use(New(Config{ Next: func(fiber.Ctx) bool { return false }, Rules: map[string]string{ "/default": "google.com", }, StatusCode: fiber.StatusMovedPermanently, })) req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody) require.NoError(t, err) resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusMovedPermanently, resp.StatusCode) require.Equal(t, "google.com", resp.Header.Get("Location")) } func Test_NoRules(t 
*testing.T) { // Case 1: No rules with default route defined app := *fiber.New() app.Use(New(Config{ StatusCode: fiber.StatusMovedPermanently, })) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) // Case 2: No rules and no default route defined app = *fiber.New() app.Use(New(Config{ StatusCode: fiber.StatusMovedPermanently, })) req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody) require.NoError(t, err) resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } func Test_DefaultConfig(t *testing.T) { // Case 1: Default config and no default route app := *fiber.New() app.Use(New()) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) // Case 2: Default config and default route app = *fiber.New() app.Use(New()) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody) require.NoError(t, err) resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } func Test_RegexRules(t *testing.T) { // Case 1: Rules regex is empty app := *fiber.New() app.Use(New(Config{ Rules: map[string]string{}, StatusCode: fiber.StatusMovedPermanently, })) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) 
require.Equal(t, fiber.StatusOK, resp.StatusCode) // Case 2: Rules regex map contains valid regex and well-formed replacement URLs app = *fiber.New() app.Use(New(Config{ Rules: map[string]string{ "/default": "google.com", }, StatusCode: fiber.StatusMovedPermanently, })) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/default", http.NoBody) require.NoError(t, err) resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusMovedPermanently, resp.StatusCode) require.Equal(t, "google.com", resp.Header.Get("Location")) // Case 3: Test invalid regex throws panic app = *fiber.New() require.Panics(t, func() { app.Use(New(Config{ Rules: map[string]string{ "(": "google.com", }, StatusCode: fiber.StatusMovedPermanently, })) }) } ================================================ FILE: middleware/requestid/config.go ================================================ package requestid import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/utils/v2" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // Generator defines a function to generate the unique identifier. // // Optional. Default: utils.SecureToken Generator func() string // Header is the header key where to get/set the unique request ID // // Optional. Default: "X-Request-ID" Header string } // ConfigDefault is the default config // It uses a secure token generator for better privacy and security. 
var ConfigDefault = Config{ Next: nil, Header: fiber.HeaderXRequestID, Generator: utils.SecureToken, } // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.Header == "" { cfg.Header = ConfigDefault.Header } if cfg.Generator == nil { cfg.Generator = ConfigDefault.Generator } return cfg } ================================================ FILE: middleware/requestid/requestid.go ================================================ package requestid import ( "github.com/gofiber/fiber/v3" "github.com/gofiber/utils/v2" ) // The contextKey type is unexported to prevent collisions with context keys defined in // other packages. type contextKey int // The keys for the values in context const ( requestIDKey contextKey = iota ) // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) // Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } rid := sanitizeRequestID(c.Get(cfg.Header), cfg.Generator) // Set new id to response header c.Set(cfg.Header, rid) // Add the request ID to locals fiber.StoreInContext(c, requestIDKey, rid) // Continue stack return c.Next() } } // sanitizeRequestID returns the provided request ID when it is valid, otherwise // it tries up to three values from the configured generator, then falls back to SecureToken. 
func sanitizeRequestID(rid string, generator func() string) string { if isValidRequestID(rid) { return rid } for range 3 { rid = generator() if isValidRequestID(rid) { return rid } } // Final fallback: SecureToken always produces a valid ID return utils.SecureToken() } // isValidRequestID reports whether the request ID contains only visible ASCII // characters (0x20–0x7E) and is non-empty. func isValidRequestID(rid string) bool { if rid == "" { return false } for i := 0; i < len(rid); i++ { c := rid[i] if c < 0x20 || c > 0x7e { return false } } return true } // FromContext returns the request ID from context. // It accepts fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx, and context.Context. // If there is no request ID, an empty string is returned. func FromContext(ctx any) string { if rid, ok := fiber.ValueFromContext[string](ctx, requestIDKey); ok { return rid } return "" } ================================================ FILE: middleware/requestid/requestid_test.go ================================================ package requestid import ( "net/http" "net/http/httptest" "testing" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" ) // go test -run Test_RequestID func Test_RequestID(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New()) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Hello, World 👋!") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) reqid := resp.Header.Get(fiber.HeaderXRequestID) require.Len(t, reqid, 43) req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Add(fiber.HeaderXRequestID, reqid) resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, reqid, resp.Header.Get(fiber.HeaderXRequestID)) } func Test_RequestID_InvalidHeaderValue(t *testing.T) { t.Parallel() rid := sanitizeRequestID("bad\r\nid", func() string { return 
"clean-generated-id" }) require.Equal(t, "clean-generated-id", rid) } func Test_RequestID_InvalidGeneratedValue(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Generator: func() string { return "bad\r\nid" }, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) rid := resp.Header.Get(fiber.HeaderXRequestID) require.NotEmpty(t, rid) require.NotContains(t, rid, "\r") require.NotContains(t, rid, "\n") require.Len(t, rid, 43, "Fallback should produce a SecureToken") } func Test_RequestID_GeneratorAlwaysInvalid(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Generator: func() string { return "invalid\x00id" // Always invalid due to null byte }, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) rid := resp.Header.Get(fiber.HeaderXRequestID) require.NotEmpty(t, rid) require.Len(t, rid, 43, "Should fall back to SecureToken after 3 invalid attempts") } func Test_RequestID_CustomGenerator(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Generator: func() string { return "custom-valid-id" }, })) app.Get("/", func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) rid := resp.Header.Get(fiber.HeaderXRequestID) require.Equal(t, "custom-valid-id", rid) } func Test_isValidRequestID_VisibleASCII(t *testing.T) { t.Parallel() require.True(t, isValidRequestID("request-id-09AZaz ~")) } func Test_isValidRequestID_Boundaries(t *testing.T) { t.Parallel() t.Run("allows space and tilde", func(t *testing.T) { 
t.Parallel() require.True(t, isValidRequestID(" ~")) }) t.Run("rejects out of range", func(t *testing.T) { t.Parallel() require.False(t, isValidRequestID(string([]byte{0x1f}))) require.False(t, isValidRequestID(string([]byte{0x7f}))) }) t.Run("rejects empty", func(t *testing.T) { t.Parallel() require.False(t, isValidRequestID("")) }) } func Test_isValidRequestID_RejectsObsText(t *testing.T) { t.Parallel() require.False(t, isValidRequestID("valid\xff")) } // go test -run Test_RequestID_Next func Test_RequestID_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Use(New(Config{ Next: func(_ fiber.Ctx) bool { return true }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Empty(t, resp.Header.Get(fiber.HeaderXRequestID)) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } // go test -run Test_RequestID_Locals func Test_RequestID_FromContext(t *testing.T) { t.Parallel() reqID := "ThisIsARequestId" app := fiber.New() app.Use(New(Config{ Generator: func() string { return reqID }, })) var ctxVal string app.Use(func(c fiber.Ctx) error { ctxVal = FromContext(c) return c.Next() }) _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, reqID, ctxVal) } func Test_RequestID_FromContext_Empty(t *testing.T) { t.Parallel() app := fiber.New() // No middleware app.Use(func(c fiber.Ctx) error { ctxVal := FromContext(c) require.Empty(t, ctxVal) return c.SendStatus(fiber.StatusOK) }) _, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) } func Test_RequestID_FromContext_Types(t *testing.T) { t.Parallel() reqID := "request-id-123" app := fiber.New(fiber.Config{PassLocalsToContext: true}) app.Use(New(Config{ Generator: func() string { return reqID }, })) app.Get("/", func(c fiber.Ctx) error { require.Equal(t, reqID, FromContext(c)) customCtx, ok := c.(fiber.CustomCtx) require.True(t, ok) require.Equal(t, 
reqID, FromContext(customCtx)) require.Equal(t, reqID, FromContext(c.RequestCtx())) require.Equal(t, reqID, FromContext(c.Context())) return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } ================================================ FILE: middleware/responsetime/config.go ================================================ package responsetime import ( "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // Header is the header key used to set the response time. // // Optional. Default: "X-Response-Time" Header string } // ConfigDefault is the default config. var ConfigDefault = Config{ Next: nil, Header: fiber.HeaderXResponseTime, } // Helper function to set default values. func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.Header == "" { cfg.Header = ConfigDefault.Header } return cfg } ================================================ FILE: middleware/responsetime/responsetime.go ================================================ package responsetime import ( "time" "github.com/gofiber/fiber/v3" ) // New creates a new middleware handler. func New(config ...Config) fiber.Handler { // Set default config cfg := configDefault(config...) 
// Return new handler return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if cfg.Next != nil && cfg.Next(c) { return c.Next() } start := time.Now() err := c.Next() c.Set(cfg.Header, time.Since(start).String()) return err } } ================================================ FILE: middleware/responsetime/responsetime_test.go ================================================ package responsetime import ( "errors" "net/http" "net/http/httptest" "testing" "time" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" ) func TestResponseTimeMiddleware(t *testing.T) { t.Parallel() boom := errors.New("boom") tests := []struct { name string expectedStatus int useCustomErrorHandler bool returnError bool expectHeader bool skipWithNext bool }{ { name: "sets duration header", expectedStatus: fiber.StatusOK, expectHeader: true, }, { name: "skips when Next returns true", expectedStatus: fiber.StatusOK, expectHeader: false, skipWithNext: true, }, { name: "propagates errors", expectedStatus: fiber.StatusTeapot, useCustomErrorHandler: true, returnError: true, expectHeader: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() configs := []Config(nil) if tt.skipWithNext { configs = []Config{{ Next: func(fiber.Ctx) bool { return true }, }} } appConfig := fiber.Config{} if tt.useCustomErrorHandler { appConfig.ErrorHandler = func(c fiber.Ctx, err error) error { t.Helper() require.ErrorIs(t, err, boom) return c.Status(fiber.StatusTeapot).SendString(err.Error()) } } app := fiber.New(appConfig) app.Use(New(configs...)) app.Get("/", func(c fiber.Ctx) error { if tt.returnError { return boom } return c.SendStatus(fiber.StatusOK) }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, tt.expectedStatus, resp.StatusCode) header := resp.Header.Get(fiber.HeaderXResponseTime) if tt.expectHeader { require.NotEmpty(t, header) _, parseErr := 
time.ParseDuration(header) require.NoError(t, parseErr) return } require.Empty(t, header) }) } } ================================================ FILE: middleware/rewrite/config.go ================================================ package rewrite import ( "regexp" "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. type Config struct { // Next defines a function to skip middleware. // Optional. Default: nil Next func(fiber.Ctx) bool // Rules defines the URL path rewrite rules. The values captured in asterisk can be // retrieved by index e.g. $1, $2 and so on. // Required. Example: // "/old": "/new", // "/api/*": "/$1", // "/js/*": "/public/javascript/$1", // "/users/*/orders/*": "/user/$1/order/$2", Rules map[string]string rulesRegex map[*regexp.Regexp]string } // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return Config{} } // Override default config cfg := config[0] return cfg } ================================================ FILE: middleware/rewrite/rewrite.go ================================================ package rewrite import ( "regexp" "strconv" "strings" "github.com/gofiber/fiber/v3" ) // New creates a new middleware handler func New(config ...Config) fiber.Handler { cfg := configDefault(config...) 
// Initialize cfg.rulesRegex = map[*regexp.Regexp]string{} for k, v := range cfg.Rules { k = strings.ReplaceAll(k, "*", "(.*)") k += "$" cfg.rulesRegex[regexp.MustCompile(k)] = v } // Middleware function return func(c fiber.Ctx) error { // Next request to skip middleware if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Rewrite for k, v := range cfg.rulesRegex { replacer := captureTokens(k, c.Path()) if replacer != nil { c.Path(replacer.Replace(v)) break } } return c.Next() } } // https://github.com/labstack/echo/blob/master/middleware/rewrite.go func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { groups := pattern.FindAllStringSubmatch(input, -1) if groups == nil { return nil } values := groups[0][1:] replace := make([]string, 2*len(values)) for i, v := range values { j := 2 * i replace[j] = "$" + strconv.Itoa(i+1) replace[j+1] = v } return strings.NewReplacer(replace...) } ================================================ FILE: middleware/rewrite/rewrite_test.go ================================================ package rewrite import ( "context" "fmt" "io" "net/http" "testing" "github.com/gofiber/fiber/v3" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func Test_New(t *testing.T) { // Test with no config m := New() if m == nil { t.Error("Expected middleware to be returned, got nil") } // Test with config m = New(Config{ Rules: map[string]string{ "/old": "/new", }, }) if m == nil { t.Error("Expected middleware to be returned, got nil") } // Test with full config m = New(Config{ Next: func(fiber.Ctx) bool { return true }, Rules: map[string]string{ "/old": "/new", }, }) if m == nil { t.Error("Expected middleware to be returned, got nil") } } func Test_Rewrite(t *testing.T) { // Case 1: Next function always returns true app := fiber.New() app.Use(New(Config{ Next: func(fiber.Ctx) bool { return true }, Rules: map[string]string{ "/old": "/new", }, })) app.Get("/old", func(c fiber.Ctx) error { return 
c.SendString("Rewrite Successful") }) req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/old", http.NoBody) require.NoError(t, err) resp, err := app.Test(req) require.NoError(t, err) body, err := io.ReadAll(resp.Body) require.NoError(t, err) bodyString := string(body) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "Rewrite Successful", bodyString) // Case 2: Next function always returns false app = fiber.New() app.Use(New(Config{ Next: func(fiber.Ctx) bool { return false }, Rules: map[string]string{ "/old": "/new", }, })) app.Get("/new", func(c fiber.Ctx) error { return c.SendString("Rewrite Successful") }) req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/old", http.NoBody) require.NoError(t, err) resp, err = app.Test(req) require.NoError(t, err) body, err = io.ReadAll(resp.Body) require.NoError(t, err) bodyString = string(body) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "Rewrite Successful", bodyString) // Case 3: check for captured tokens in rewrite rule app = fiber.New() app.Use(New(Config{ Rules: map[string]string{ "/users/*/orders/*": "/user/$1/order/$2", }, })) app.Get("/user/:userID/order/:orderID", func(c fiber.Ctx) error { return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID"))) }) req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/users/123/orders/456", http.NoBody) require.NoError(t, err) resp, err = app.Test(req) require.NoError(t, err) body, err = io.ReadAll(resp.Body) require.NoError(t, err) bodyString = string(body) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "User ID: 123, Order ID: 456", bodyString) // Case 4: Send non-matching request, handled by default route app = fiber.New() app.Use(New(Config{ Rules: map[string]string{ "/users/*/orders/*": "/user/$1/order/$2", }, })) 
app.Get("/user/:userID/order/:orderID", func(c fiber.Ctx) error { return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID"))) }) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/not-matching-any-rule", http.NoBody) require.NoError(t, err) resp, err = app.Test(req) require.NoError(t, err) body, err = io.ReadAll(resp.Body) require.NoError(t, err) bodyString = string(body) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) require.Equal(t, "OK", bodyString) // Case 4: Send non-matching request, with no default route app = fiber.New() app.Use(New(Config{ Rules: map[string]string{ "/users/*/orders/*": "/user/$1/order/$2", }, })) app.Get("/user/:userID/order/:orderID", func(c fiber.Ctx) error { return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID"))) }) req, err = http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/not-matching-any-rule", http.NoBody) require.NoError(t, err) resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusNotFound, resp.StatusCode) } func Benchmark_Rewrite(b *testing.B) { // Helper function to create a new Fiber app with rewrite middleware createApp := func(config Config) *fiber.App { app := fiber.New() app.Use(New(config)) return app } // Benchmark: Rewrite with Next function always returns true b.Run("Next always true", func(b *testing.B) { app := createApp(Config{ Next: func(fiber.Ctx) bool { return true }, Rules: map[string]string{ "/old": "/new", }, }) reqCtx := &fasthttp.RequestCtx{} reqCtx.Request.SetRequestURI("/old") b.ReportAllocs() for b.Loop() { app.Handler()(reqCtx) } }) // Benchmark: Rewrite with Next function always returns false b.Run("Next always false", func(b *testing.B) { app := createApp(Config{ Next: func(fiber.Ctx) bool { return false }, Rules: map[string]string{ "/old": 
"/new", }, }) reqCtx := &fasthttp.RequestCtx{} reqCtx.Request.SetRequestURI("/old") b.ReportAllocs() for b.Loop() { app.Handler()(reqCtx) } }) // Benchmark: Rewrite with tokens b.Run("Rewrite with tokens", func(b *testing.B) { app := createApp(Config{ Rules: map[string]string{ "/users/*/orders/*": "/user/$1/order/$2", }, }) reqCtx := &fasthttp.RequestCtx{} reqCtx.Request.SetRequestURI("/users/123/orders/456") b.ReportAllocs() for b.Loop() { app.Handler()(reqCtx) } }) // Benchmark: Non-matching request, handled by default route b.Run("NonMatch with default", func(b *testing.B) { app := createApp(Config{ Rules: map[string]string{ "/users/*/orders/*": "/user/$1/order/$2", }, }) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) reqCtx := &fasthttp.RequestCtx{} reqCtx.Request.SetRequestURI("/not-matching-any-rule") b.ReportAllocs() for b.Loop() { app.Handler()(reqCtx) } }) // Benchmark: Non-matching request, with no default route b.Run("NonMatch without default", func(b *testing.B) { app := createApp(Config{ Rules: map[string]string{ "/users/*/orders/*": "/user/$1/order/$2", }, }) reqCtx := &fasthttp.RequestCtx{} reqCtx.Request.SetRequestURI("/not-matching-any-rule") b.ReportAllocs() for b.Loop() { app.Handler()(reqCtx) } }) } func Benchmark_Rewrite_Parallel(b *testing.B) { // Helper function to create a new Fiber app with rewrite middleware createApp := func(config Config) *fiber.App { app := fiber.New() app.Use(New(config)) return app } // Parallel Benchmark: Rewrite with Next function always returns true b.Run("Next always true", func(b *testing.B) { app := createApp(Config{ Next: func(fiber.Ctx) bool { return true }, Rules: map[string]string{ "/old": "/new", }, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { reqCtx := &fasthttp.RequestCtx{} reqCtx.Request.SetRequestURI("/old") for pb.Next() { app.Handler()(reqCtx) } }) }) // Parallel Benchmark: Rewrite with Next function always returns false b.Run("Next always false", 
func(b *testing.B) { app := createApp(Config{ Next: func(fiber.Ctx) bool { return false }, Rules: map[string]string{ "/old": "/new", }, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { reqCtx := &fasthttp.RequestCtx{} reqCtx.Request.SetRequestURI("/old") for pb.Next() { app.Handler()(reqCtx) } }) }) // Parallel Benchmark: Rewrite with tokens b.Run("Rewrite with tokens", func(b *testing.B) { app := createApp(Config{ Rules: map[string]string{ "/users/*/orders/*": "/user/$1/order/$2", }, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { reqCtx := &fasthttp.RequestCtx{} reqCtx.Request.SetRequestURI("/users/123/orders/456") for pb.Next() { app.Handler()(reqCtx) } }) }) // Parallel Benchmark: Non-matching request, handled by default route b.Run("NonMatch with default", func(b *testing.B) { app := createApp(Config{ Rules: map[string]string{ "/users/*/orders/*": "/user/$1/order/$2", }, }) app.Use(func(c fiber.Ctx) error { return c.SendStatus(fiber.StatusOK) }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { reqCtx := &fasthttp.RequestCtx{} reqCtx.Request.SetRequestURI("/not-matching-any-rule") for pb.Next() { app.Handler()(reqCtx) } }) }) // Parallel Benchmark: Non-matching request, with no default route b.Run("NonMatch without default", func(b *testing.B) { app := createApp(Config{ Rules: map[string]string{ "/users/*/orders/*": "/user/$1/order/$2", }, }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { reqCtx := &fasthttp.RequestCtx{} reqCtx.Request.SetRequestURI("/not-matching-any-rule") for pb.Next() { app.Handler()(reqCtx) } }) }) } ================================================ FILE: middleware/session/config.go ================================================ package session import ( "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" "github.com/gofiber/fiber/v3/log" "github.com/gofiber/utils/v2" ) // Config defines the configuration for the session 
middleware. type Config struct { // Storage interface for storing session data. // // Optional. Default: memory.New() Storage fiber.Storage // Store defines the session store. // // Required. Store *Store // Next defines a function to skip this middleware when it returns true. // Optional. Default: nil Next func(c fiber.Ctx) bool // ErrorHandler defines a function to handle errors. // // Optional. Default: nil ErrorHandler func(fiber.Ctx, error) // KeyGenerator generates the session key. // // Optional. Default: utils.SecureToken KeyGenerator func() string // CookieDomain defines the domain of the session cookie. // // Optional. Default: "" CookieDomain string // CookiePath defines the path of the session cookie. // // Optional. Default: "" CookiePath string // CookieSameSite specifies the SameSite attribute of the cookie. // // Optional. Default: "Lax" CookieSameSite string // Extractor is used to extract the session ID from the request. // See: https://docs.gofiber.io/guide/extractors // // Optional. Default: extractors.FromCookie("session_id") Extractor extractors.Extractor // IdleTimeout defines the maximum duration of inactivity before the session expires. // // Note: The idle timeout is updated on each `Save()` call. If a middleware handler is used, `Save()` is called automatically. // // Optional. Default: 30 * time.Minute IdleTimeout time.Duration // AbsoluteTimeout defines the maximum duration of the session before it expires. // // If set to 0, the session will not have an absolute timeout, and will expire after the idle timeout. // // Optional. Default: 0 AbsoluteTimeout time.Duration // CookieSecure specifies if the session cookie should be secure. // // Optional. Default: false CookieSecure bool // CookieHTTPOnly specifies if the session cookie should be HTTP-only. // // Optional. Default: false CookieHTTPOnly bool // CookieSessionOnly determines if the cookie should expire when the browser session ends. 
// // If true, the cookie will be deleted when the browser is closed. // Note: This will not delete the session data from the store. // // Optional. Default: false CookieSessionOnly bool } // ConfigDefault provides the default configuration. var ConfigDefault = Config{ IdleTimeout: 30 * time.Minute, KeyGenerator: utils.SecureToken, Extractor: extractors.FromCookie("session_id"), CookieSameSite: "Lax", } // DefaultErrorHandler logs the error and sends a 500 status code. // // Parameters: // - c: The Fiber context. // - err: The error to handle. // // Usage: // // DefaultErrorHandler(c, err) func DefaultErrorHandler(c fiber.Ctx, err error) { log.Errorf("session error: %v", err) if sendErr := c.Status(fiber.StatusInternalServerError).SendString("Internal Server Error"); sendErr != nil { log.Errorf("failed to send error response: %v", sendErr) } } // configDefault sets default values for the Config struct. // // This function ensures that all necessary fields have sensible defaults // if they are not explicitly set by the user. // // Parameters: // - config: Variadic parameter to override default config. // // Returns: // - Config: The configuration with defaults applied. // // Usage: // // cfg := configDefault() // cfg := configDefault(customConfig) func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if cfg.IdleTimeout <= 0 { cfg.IdleTimeout = ConfigDefault.IdleTimeout } // Ensure AbsoluteTimeout is greater than or equal to IdleTimeout. 
if cfg.AbsoluteTimeout > 0 && cfg.AbsoluteTimeout < cfg.IdleTimeout { panic("[session] AbsoluteTimeout must be greater than or equal to IdleTimeout") } // Check if we have a zero-value Extractor if cfg.Extractor.Extract == nil { cfg.Extractor = ConfigDefault.Extractor } if cfg.KeyGenerator == nil { cfg.KeyGenerator = ConfigDefault.KeyGenerator } if cfg.CookieSameSite == "" { cfg.CookieSameSite = ConfigDefault.CookieSameSite } return cfg } ================================================ FILE: middleware/session/config_test.go ================================================ package session import ( "testing" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) func TestConfigDefault(t *testing.T) { // Test default config cfg := configDefault() require.Equal(t, 30*time.Minute, cfg.IdleTimeout) require.NotNil(t, cfg.KeyGenerator) require.NotNil(t, cfg.Extractor) require.Equal(t, "session_id", cfg.Extractor.Key) require.Equal(t, "Lax", cfg.CookieSameSite) } func TestConfigDefaultWithCustomConfig(t *testing.T) { // Test custom config customConfig := Config{ IdleTimeout: 48 * time.Hour, Extractor: extractors.FromHeader("X-Custom-Session"), KeyGenerator: func() string { return "custom_key" }, } cfg := configDefault(customConfig) require.Equal(t, 48*time.Hour, cfg.IdleTimeout) require.NotNil(t, cfg.KeyGenerator) require.NotNil(t, cfg.Extractor) require.Equal(t, "X-Custom-Session", cfg.Extractor.Key) } func TestDefaultErrorHandler(t *testing.T) { // Create a new Fiber app app := fiber.New() // Create a new context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) // Test DefaultErrorHandler DefaultErrorHandler(ctx, fiber.ErrInternalServerError) require.Equal(t, fiber.StatusInternalServerError, ctx.Response().StatusCode()) } func TestAbsoluteTimeoutValidation(t *testing.T) { require.PanicsWithValue(t, "[session] AbsoluteTimeout must be greater than or equal to IdleTimeout", func() { 
configDefault(Config{ IdleTimeout: 30 * time.Minute, AbsoluteTimeout: 15 * time.Minute, // Less than IdleTimeout }) }) } ================================================ FILE: middleware/session/data.go ================================================ package session import ( "sync" ) // msgp -file="data.go" -o="data_msgp.go" -tests=true -unexported // Session state should remain small to fit common storage payload limits. // //go:generate msgp -o=data_msgp.go -tests=true -unexported //msgp:ignore data type data struct { Data map[any]any // Session key counts are expected to be bounded. sync.RWMutex `msg:"-"` } var dataPool = sync.Pool{ New: func() any { d := new(data) d.Data = make(map[any]any) return d }, } // acquireData returns a new data object from the pool. // // Returns: // - *data: The data object. // // Usage: // // d := acquireData() func acquireData() *data { obj := dataPool.Get() if d, ok := obj.(*data); ok { return d } // Handle unexpected type in the pool panic("unexpected type in data pool") } // Reset clears the data map and resets the data object. // // Usage: // // d.Reset() func (d *data) Reset() { d.Lock() defer d.Unlock() clear(d.Data) } // Get retrieves a value from the data map by key. // // Parameters: // - key: The key to retrieve. // // Returns: // - any: The value associated with the key. // // Usage: // // value := d.Get("key") func (d *data) Get(key any) any { d.RLock() defer d.RUnlock() return d.Data[key] } // Set updates or creates a new key-value pair in the data map. // // Parameters: // - key: The key to set. // - value: The value to set. // // Usage: // // d.Set("key", "value") func (d *data) Set(key, value any) { d.Lock() defer d.Unlock() d.Data[key] = value } // Delete removes a key-value pair from the data map. // // Parameters: // - key: The key to delete. // // Usage: // // d.Delete("key") func (d *data) Delete(key any) { d.Lock() defer d.Unlock() delete(d.Data, key) } // Keys retrieves all keys in the data map. 
// // Returns: // - []any: A slice of all keys in the data map. // // Usage: // // keys := d.Keys() func (d *data) Keys() []any { d.RLock() defer d.RUnlock() keys := make([]any, 0, len(d.Data)) for k := range d.Data { keys = append(keys, k) } return keys } // Len returns the number of key-value pairs in the data map. // // Returns: // - int: The number of key-value pairs. // // Usage: // // length := d.Len() func (d *data) Len() int { d.RLock() defer d.RUnlock() return len(d.Data) } ================================================ FILE: middleware/session/data_msgp.go ================================================ // Code generated by github.com/tinylib/msgp DO NOT EDIT. package session ================================================ FILE: middleware/session/data_msgp_test.go ================================================ // Code generated by github.com/tinylib/msgp DO NOT EDIT. package session ================================================ FILE: middleware/session/data_test.go ================================================ package session import ( "reflect" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestKeys(t *testing.T) { t.Parallel() // Test case: Empty data t.Run("Empty data", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() keys := d.Keys() require.Empty(t, keys, "Expected no keys in empty data") }) // Test case: Single key t.Run("Single key", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() d.Set("key1", "value1") keys := d.Keys() require.Len(t, keys, 1, "Expected one key") require.Contains(t, keys, "key1", "Expected key1 to be present") }) // Test case: Multiple keys t.Run("Multiple keys", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() d.Set("key1", 
"value1") d.Set("key2", "value2") d.Set("key3", "value3") keys := d.Keys() require.Len(t, keys, 3, "Expected three keys") require.Contains(t, keys, "key1", "Expected key1 to be present") require.Contains(t, keys, "key2", "Expected key2 to be present") require.Contains(t, keys, "key3", "Expected key3 to be present") }) // Test case: Concurrent access t.Run("Concurrent access", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() d.Set("key1", "value1") d.Set("key2", "value2") d.Set("key3", "value3") done := make(chan bool) go func() { keys := d.Keys() assert.Len(t, keys, 3, "Expected three keys") done <- true }() go func() { keys := d.Keys() assert.Len(t, keys, 3, "Expected three keys") done <- true }() <-done <-done }) } func TestData_Len(t *testing.T) { t.Parallel() // Test case: Empty data t.Run("Empty data", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() length := d.Len() require.Equal(t, 0, length, "Expected length to be 0 for empty data") }) // Test case: Single key t.Run("Single key", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() d.Set("key1", "value1") length := d.Len() require.Equal(t, 1, length, "Expected length to be 1 when one key is set") }) // Test case: Multiple keys t.Run("Multiple keys", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() d.Set("key1", "value1") d.Set("key2", "value2") d.Set("key3", "value3") length := d.Len() require.Equal(t, 3, length, "Expected length to be 3 when three keys are set") }) // Test case: Concurrent access t.Run("Concurrent access", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() d.Set("key1", "value1") 
d.Set("key2", "value2") d.Set("key3", "value3") done := make(chan bool, 2) // Buffered channel with size 2 go func() { length := d.Len() assert.Equal(t, 3, length, "Expected length to be 3 during concurrent access") done <- true }() go func() { length := d.Len() assert.Equal(t, 3, length, "Expected length to be 3 during concurrent access") done <- true }() <-done <-done }) } func TestData_Get(t *testing.T) { t.Parallel() // Test case: Nonexistent key t.Run("Nonexistent key", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() value := d.Get("nonexistent-key") require.Nil(t, value, "Expected nil for nonexistent key") }) // Test case: Existing key t.Run("Existing key", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() d.Set("key1", "value1") value := d.Get("key1") require.Equal(t, "value1", value, "Expected value1 for key1") }) } func TestData_Reset(t *testing.T) { t.Parallel() // Test case: Reset data t.Run("Reset data", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) d.Set("key1", "value1") d.Set("key2", "value2") d.Reset() requireDataEmpty(t, d, "Expected data map to be empty after reset") }) } func mapPointer(m map[any]any) uintptr { return reflect.ValueOf(m).Pointer() } func lockedMapPointer(d *data) uintptr { d.RLock() defer d.RUnlock() return mapPointer(d.Data) } func requireDataEmpty(t *testing.T, d *data, msg string) { t.Helper() d.RLock() defer d.RUnlock() require.Empty(t, d.Data, msg) } func TestData_ResetPreservesAllocation(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool t.Cleanup(func() { d.Reset() dataPool.Put(d) }) originalPtr := lockedMapPointer(d) d.Set("key1", "value1") d.Set("key2", "value2") require.Equal(t, originalPtr, lockedMapPointer(d), "Expected map pointer to stay 
constant after writes") d.Reset() requireDataEmpty(t, d, "Expected data map to be empty after reset") require.Equal(t, originalPtr, lockedMapPointer(d), "Expected reset to preserve underlying map") d.Set("key3", "value3") require.Nil(t, d.Get("key1"), "Expected cleared key not to leak after reset") require.Equal(t, originalPtr, lockedMapPointer(d), "Expected map pointer to remain stable after further writes") } func TestData_PoolReuseDoesNotLeakEntries(t *testing.T) { t.Parallel() acquired := make([]*data, 0, 6) t.Cleanup(func() { for _, item := range acquired { item.Reset() dataPool.Put(item) } }) acquireWithCleanup := func() *data { d := acquireData() acquired = append(acquired, d) return d } first := acquireWithCleanup() first.Set("key1", "value1") first.Set("key2", "value2") first.Reset() originalPtr := lockedMapPointer(first) dataPool.Put(first) var reused *data for i := 0; i < 5; i++ { candidate := acquireWithCleanup() if lockedMapPointer(candidate) == originalPtr { reused = candidate break } requireDataEmpty(t, candidate, "Expected pooled data to be empty when new instance is returned") require.Nil(t, candidate.Get("key2"), "Expected no leakage of prior entries on alternate pooled instance") } if reused == nil { t.Skip("sync.Pool returned a different instance; reuse cannot be asserted") return } require.Equal(t, originalPtr, lockedMapPointer(reused), "Expected pooled data to reuse cleared map") requireDataEmpty(t, reused, "Expected pooled data to be empty after reuse") require.Nil(t, reused.Get("key2"), "Expected no leakage of prior entries on reuse") reused.Set("key4", "value4") require.Equal(t, "value4", reused.Get("key4"), "Expected pooled map to accept new values") } func TestData_Delete(t *testing.T) { t.Parallel() // Test case: Delete existing key t.Run("Delete existing key", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() d.Set("key1", "value1") d.Delete("key1") value 
:= d.Get("key1") require.Nil(t, value, "Expected nil for deleted key") }) // Test case: Delete nonexistent key t.Run("Delete nonexistent key", func(t *testing.T) { t.Parallel() d := acquireData() d.Reset() // Ensure clean state from pool defer dataPool.Put(d) defer d.Reset() d.Delete("nonexistent-key") // No assertion needed, just ensure no panic or error }) } ================================================ FILE: middleware/session/middleware.go ================================================ // Package session provides session management middleware for Fiber. // This middleware handles user sessions, including storing session data in the store. package session import ( "errors" "sync" "github.com/gofiber/fiber/v3" ) // Middleware holds session data and configuration. type Middleware struct { Session *Session ctx fiber.Ctx config Config mu sync.RWMutex destroyed bool } // Context key for session middleware lookup. type middlewareKey int const ( // middlewareContextKey is the key used to store the *Middleware in the context locals. middlewareContextKey middlewareKey = iota ) var ( // ErrTypeAssertionFailed occurs when a type assertion fails. ErrTypeAssertionFailed = errors.New("failed to type-assert to *Middleware") // Pool for reusing middleware instances. middlewarePool = &sync.Pool{ New: func() any { return &Middleware{} }, } ) // New initializes session middleware with optional configuration. // // Parameters: // - config: Variadic parameter to override default config. // // Returns: // - fiber.Handler: The Fiber handler for the session middleware. // // Usage: // // app.Use(session.New()) // // Usage: // // app.Use(session.New()) func New(config ...Config) fiber.Handler { if len(config) > 0 { handler, _ := NewWithStore(config[0]) return handler } handler, _ := NewWithStore() return handler } // NewWithStore creates session middleware with an optional custom store. // // Parameters: // - config: Variadic parameter to override default config. 
// // Returns: // - fiber.Handler: The Fiber handler for the session middleware. // - *Store: The session store. // // Usage: // // handler, store := session.NewWithStore() func NewWithStore(config ...Config) (fiber.Handler, *Store) { cfg := configDefault(config...) if cfg.Store == nil { cfg.Store = NewStore(cfg) } handler := func(c fiber.Ctx) error { if cfg.Next != nil && cfg.Next(c) { return c.Next() } // Acquire session middleware m := acquireMiddleware() m.initialize(c, &cfg) stackErr := c.Next() m.mu.RLock() destroyed := m.destroyed m.mu.RUnlock() if !destroyed { m.saveSession() } releaseMiddleware(m) return stackErr } return handler, cfg.Store } // initialize sets up middleware for the request. func (m *Middleware) initialize(c fiber.Ctx, cfg *Config) { m.mu.Lock() defer m.mu.Unlock() session, err := cfg.Store.getSession(c) if err != nil { panic(err) // handle or log this error appropriately in production } m.config = *cfg m.Session = session m.ctx = c fiber.StoreInContext(c, middlewareContextKey, m) } // saveSession handles session saving and error management after the response. func (m *Middleware) saveSession() { if err := m.Session.saveSession(); err != nil { if m.config.ErrorHandler != nil { m.config.ErrorHandler(m.ctx, err) } else { DefaultErrorHandler(m.ctx, err) } } releaseSession(m.Session) } // acquireMiddleware retrieves a middleware instance from the pool. func acquireMiddleware() *Middleware { m, ok := middlewarePool.Get().(*Middleware) if !ok { panic(ErrTypeAssertionFailed.Error()) } return m } // releaseMiddleware resets and returns middleware to the pool. // // Parameters: // - m: The middleware object to release. // // Usage: // // releaseMiddleware(m) func releaseMiddleware(m *Middleware) { m.mu.Lock() m.config = Config{} m.Session = nil m.ctx = nil m.destroyed = false m.mu.Unlock() middlewarePool.Put(m) } // FromContext returns the Middleware from the Fiber context. 
// It accepts fiber.CustomCtx, fiber.Ctx, *fasthttp.RequestCtx, and context.Context. // // Parameters: // - c: The Fiber context. // // Returns: // - *Middleware: The middleware object if found; otherwise, nil. // // Usage: // // m := session.FromContext(c) func FromContext(ctx any) *Middleware { if m, ok := fiber.ValueFromContext[*Middleware](ctx, middlewareContextKey); ok { return m } return nil } // Set sets a key-value pair in the session. // // Parameters: // - key: The key to set. // - value: The value to set. // // Usage: // // m.Set("key", "value") func (m *Middleware) Set(key, value any) { m.mu.Lock() defer m.mu.Unlock() m.Session.Set(key, value) } // Get retrieves a value from the session by key. // // Parameters: // - key: The key to retrieve. // // Returns: // - any: The value associated with the key. // // Usage: // // value := m.Get("key") func (m *Middleware) Get(key any) any { m.mu.RLock() defer m.mu.RUnlock() return m.Session.Get(key) } // Delete removes a key-value pair from the session. // // Parameters: // - key: The key to delete. // // Usage: // // m.Delete("key") func (m *Middleware) Delete(key any) { m.mu.Lock() defer m.mu.Unlock() m.Session.Delete(key) } // Keys returns all keys in the current session. // // Returns: // - []any: A slice of all keys in the session. // // Usage: // // keys := m.Keys() func (m *Middleware) Keys() []any { m.mu.RLock() defer m.mu.RUnlock() return m.Session.Keys() } // Destroy destroys the session. // // Returns: // - error: An error if the destruction fails. // // Usage: // // err := m.Destroy() func (m *Middleware) Destroy() error { m.mu.Lock() defer m.mu.Unlock() err := m.Session.Destroy() m.destroyed = true return err } // Fresh checks if the session is fresh. // // Returns: // - bool: True if the session is fresh; otherwise, false. // // Usage: // // isFresh := m.Fresh() func (m *Middleware) Fresh() bool { return m.Session.Fresh() } // ID returns the session ID. // // Returns: // - string: The session ID. 
//
// Usage:
//
//	id := m.ID()
func (m *Middleware) ID() string {
	// NOTE(review): unlike Get/Set/Reset, this accessor does not take m.mu.
	return m.Session.ID()
}

// Reset resets the session.
//
// It delegates to Session.Reset while holding the middleware write lock.
//
// Returns:
//   - error: An error if the reset fails.
//
// Usage:
//
//	err := m.Reset()
func (m *Middleware) Reset() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.Session.Reset()
}

// Regenerate generates a new session ID while preserving session data.
//
// This method is commonly used after authentication to prevent session fixation attacks.
// Unlike Reset(), this method preserves all existing session data.
// It delegates to Session.Regenerate while holding the middleware write lock.
//
// Returns:
//   - error: An error if the regeneration fails.
//
// Usage:
//
//	err := m.Regenerate()
func (m *Middleware) Regenerate() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.Session.Regenerate()
}

// Store returns the session store from the middleware's configuration.
//
// Returns:
//   - *Store: The session store.
//
// Usage:
//
//	store := m.Store()
func (m *Middleware) Store() *Store {
	return m.config.Store
}

================================================
FILE: middleware/session/middleware_test.go
================================================
package session

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"sort"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/fiber/v3/extractors"
	"github.com/gofiber/utils/v2"
	"github.com/stretchr/testify/require"
	"github.com/valyala/fasthttp"
)

// Test_Session_Middleware drives the session middleware end-to-end through
// get/set/delete/reset/regenerate/destroy/fresh/keys routes, carrying the
// session_id cookie from one request to the next.
func Test_Session_Middleware(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())

	app.Get("/get", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		value, ok := sess.Get("key").(string)
		if !ok {
			return c.Status(fiber.StatusNotFound).SendString("key not found")
		}
		return c.SendString("value=" + value)
	})

	app.Post("/set", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		// get a value from the body
		value := c.FormValue("value")
		sess.Set("key", value)
		return c.SendStatus(fiber.StatusOK)
	})

	app.Post("/delete", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		sess.Delete("key")
		return c.SendStatus(fiber.StatusOK)
	})

	app.Post("/reset", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		// Set a value to ensure it is cleared after reset
		sess.Set("key", "value")
		if err := sess.Reset(); err != nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		// Ensure value is cleared
		value, ok := sess.Get("key").(string)
		if ok || value != "" {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	app.Post("/regenerate", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		// Set a value to ensure it is preserved after regeneration
		sess.Set("key", "value")
		// Regenerate the session ID
		if err := sess.Regenerate(); err != nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		// Ensure the session ID has changed but session data is preserved
		newID := sess.ID()
		if newID == "" {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		// Check if the session data is still accessible
		value, ok := sess.Get("key").(string)
		if !ok || value != "value" {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	app.Post("/destroy", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		if err := sess.Destroy(); err != nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	app.Post("/fresh", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		// Reset the session to make it fresh
		if err := sess.Reset(); err != nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		if sess.Fresh() {
			return c.SendStatus(fiber.StatusOK)
		}
		return c.SendStatus(fiber.StatusInternalServerError)
	})

	app.Post("/keys", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		// get a value from the body
		value := c.FormValue("keys")
		for rawKey := range strings.SplitSeq(value, ",") {
			key := utils.TrimSpace(rawKey)
			if key == "" {
				continue
			}
			// Set each key in the session
			sess.Set(key, "value_"+key)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	app.Get("/keys", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		keys := sess.Keys()
		if len(keys) == 0 {
			return c.SendStatus(fiber.StatusNotFound)
		}
		// Keys may be of any type, so convert to string for display
		strKeys := []string{}
		for _, key := range keys {
			strKeys = append(strKeys, fmt.Sprintf("%v", key))
		}
		return c.SendString("keys=" + strings.Join(strKeys, ","))
	})

	// Test GET, SET, DELETE, RESET, REGENERATE, DESTROY by sending requests to the respective routes
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/get")
	h := app.Handler()
	h(ctx)
	require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, token, "Expected Set-Cookie header to be present")
	tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2)
	require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token")
	token = tokenParts[1]
	require.Equal(t, "key not found", string(ctx.Response.Body()))

	// Test POST /set
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/set")
	ctx.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Set the Content-Type
	ctx.Request.SetBodyString("value=hello")
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	// Test GET /get to check if the value was set
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/get")
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	require.Equal(t, "value=hello", string(ctx.Response.Body()))

	// Test POST /delete to delete the value
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/delete")
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	// Test GET /get to check if the value was deleted
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/get")
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
	require.Equal(t, "key not found", string(ctx.Response.Body()))

	// Test POST /reset to reset the session
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/reset")
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// verify we have a new session token
	newToken := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present")
	newTokenParts := strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2)
	require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token")
	newToken = newTokenParts[1]
	require.NotEqual(t, token, newToken)
	token = newToken

	// Test POST /regenerate to regenerate the session ID
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/regenerate")
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// verify we have a new session token
	newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present")
	newTokenParts = strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2)
	require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token")
	newToken = newTokenParts[1]
	require.NotEqual(t, token, newToken)
	token = newToken

	// Test POST /destroy to destroy the session
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/destroy")
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// Verify the session cookie has expired
	setCookieHeader := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.Contains(t, setCookieHeader, "max-age=0")
	// Sleep so that the session expires
	time.Sleep(1 * time.Second)

	// Test GET /get to check if the session was destroyed
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/get")
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusNotFound, ctx.Response.StatusCode())
	// check that we have a new session token
	newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present")
	parts := strings.Split(newToken, ";")
	require.Greater(t, len(parts), 1)
	valueParts := strings.Split(parts[0], "=")
	require.Greater(t, len(valueParts), 1)
	newToken = valueParts[1]
	require.NotEqual(t, token, newToken)
	token = newToken

	// Test POST /fresh to check if the session is fresh
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/fresh")
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// check that we have a new session token
	newToken = string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, newToken, "Expected Set-Cookie header to be present")
	newTokenParts = strings.SplitN(strings.SplitN(newToken, ";", 2)[0], "=", 2)
	require.Len(t, newTokenParts, 2, "Expected Set-Cookie header to contain a token")
	newToken = newTokenParts[1]
	require.NotEqual(t, token, newToken)
	token = newToken

	// Test POST /keys to set multiple keys
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.SetRequestURI("/keys")
	ctx.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Set the Content-Type
	ctx.Request.SetBodyString("keys=key1,key2")
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	// Test GET /keys to check if the session has the keys
	ctx.Request.Reset()
	ctx.Response.Reset()
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.SetRequestURI("/keys")
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	body := string(ctx.Response.Body())
	require.True(t, strings.HasPrefix(body, "keys="))
	parts = strings.Split(strings.TrimPrefix(body, "keys="), ",")
	require.Len(t, parts, 2, "Expected two keys in the session")
	sort.Strings(parts)
	require.Equal(t, []string{"key1", "key2"}, parts)
}

// Test_Session_NewWithStore checks that a session ID issued on a first request
// is reused when the client sends it back as the session_id cookie.
// NOTE(review): despite its name, this test registers New(), not NewWithStore().
func Test_Session_NewWithStore(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New())

	app.Get("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		id := sess.ID()
		return c.SendString("value=" + id)
	})
	app.Post("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		id := sess.ID()
		c.Cookie(&fiber.Cookie{
			Name:  "session_id",
			Value: id,
		})
		return nil
	})

	h := app.Handler()

	// Test GET request without cookie
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// Get session cookie
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, token, "Expected Set-Cookie header to be present")
	tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2)
	require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token")
	token = tokenParts[1]
	require.Equal(t, "value="+token, string(ctx.Response.Body()))

	// Test GET request with cookie
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	require.Equal(t, "value="+token, string(ctx.Response.Body()))
}

// Test_Session_FromSession checks that FromContext returns nil for a context
// that never passed through the session middleware.
// NOTE(review): the acquired context is never released, and the middleware
// registered at the end is never exercised by a request.
func Test_Session_FromSession(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	sess := FromContext(app.AcquireCtx(&fasthttp.RequestCtx{}))
	require.Nil(t, sess)

	app.Use(New())
}

// Test_Session_FromContext_Types checks that FromContext finds the middleware
// through every supported context type when PassLocalsToContext is enabled.
func Test_Session_FromContext_Types(t *testing.T) {
	t.Parallel()

	app := fiber.New(fiber.Config{PassLocalsToContext: true})
	app.Use(New())

	app.Get("/", func(c fiber.Ctx) error {
		require.NotNil(t, FromContext(c))

		customCtx, ok := c.(fiber.CustomCtx)
		require.True(t, ok)
		require.NotNil(t, FromContext(customCtx))

		require.NotNil(t, FromContext(c.RequestCtx()))

		require.NotNil(t, FromContext(c.Context()))

		return c.SendStatus(fiber.StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode)
}

// Test_Session_WithConfig exercises a custom cookie extractor, key generator,
// Next predicate and a one-second idle timeout.
func Test_Session_WithConfig(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	app.Use(New(Config{
		Next: func(c fiber.Ctx) bool {
			return c.Get("key") == "value"
		},
		IdleTimeout:  1 * time.Second,
		Extractor:    extractors.FromCookie("session_id_test"),
		KeyGenerator: func() string {
			return "test"
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		id := sess.ID()
		return c.SendString("value=" + id)
	})

	app.Get("/isFresh", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess.Fresh() {
			return c.SendStatus(fiber.StatusOK)
		}
		return c.SendStatus(fiber.StatusInternalServerError)
	})

	app.Post("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		id := sess.ID()
		c.Cookie(&fiber.Cookie{
			Name:  "session_id_test",
			Value: id,
		})
		return nil
	})

	h := app.Handler()

	// Test GET request without cookie
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// Get session cookie
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, token, "Expected Set-Cookie header to be present")
	tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2)
	require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token")
	token = tokenParts[1]
	require.Equal(t, "value="+token, string(ctx.Response.Body()))

	// Test GET request with cookie
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("session_id_test", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	require.Equal(t, "value="+token, string(ctx.Response.Body()))

	// Test POST request with cookie
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.SetCookie("session_id_test", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	// Test POST request without cookie
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	// Test POST request with wrong key
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.SetCookie("session_id", token)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	// Test POST request with wrong value
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodPost)
	ctx.Request.Header.SetCookie("session_id_test", "wrong")
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())

	// Check idle timeout not expired
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("session_id_test", token)
	ctx.Request.SetRequestURI("/isFresh")
	h(ctx)
	require.Equal(t, fiber.StatusInternalServerError, ctx.Response.StatusCode())

	// Test idle timeout
	time.Sleep(1200 * time.Millisecond)
	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	ctx.Request.Header.SetCookie("session_id_test", token)
	ctx.Request.SetRequestURI("/isFresh")
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
}

// Test_Session_Next checks that no session is attached to the context when the
// Next predicate returns true.
func Test_Session_Next(t *testing.T) {
	t.Parallel()

	var (
		doNext bool
		muNext sync.RWMutex
	)

	app := fiber.New()

	app.Use(New(Config{
		Next: func(_ fiber.Ctx) bool {
			muNext.RLock()
			defer muNext.RUnlock()
			return doNext
		},
	}))

	app.Get("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		if sess == nil {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		id := sess.ID()
		return c.SendString("value=" + id)
	})

	h := app.Handler()

	// Test with Next returning false
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
	// Get session cookie
	token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
	require.NotEmpty(t, token, "Expected Set-Cookie header to be present")
	tokenParts := strings.SplitN(strings.SplitN(token, ";", 2)[0], "=", 2)
	require.Len(t, tokenParts, 2, "Expected Set-Cookie header to contain a token")
	token = tokenParts[1]
	require.Equal(t, "value="+token, string(ctx.Response.Body()))

	// Test with Next returning true
	muNext.Lock()
	doNext = true
	muNext.Unlock()

	ctx = &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	require.Equal(t, fiber.StatusInternalServerError, ctx.Response.StatusCode())
}

// Test_Session_Middleware_Store checks that Middleware.Store returns the same
// store that NewWithStore handed back.
func Test_Session_Middleware_Store(t *testing.T) {
	t.Parallel()
	app := fiber.New()

	handler, sessionStore := NewWithStore()

	app.Use(handler)

	app.Get("/", func(c fiber.Ctx) error {
		sess := FromContext(c)
		st := sess.Store()
		if st != sessionStore {
			return c.SendStatus(fiber.StatusInternalServerError)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	h := app.Handler()

	// Test GET request
	ctx := &fasthttp.RequestCtx{}
	ctx.Request.Header.SetMethod(fiber.MethodGet)
	h(ctx)
	require.Equal(t, fiber.StatusOK, ctx.Response.StatusCode())
}

================================================
FILE: middleware/session/session.go
================================================
package session

import (
	"bytes"
	"context"
	"encoding/gob"
	"fmt"
	"sync"
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/fiber/v3/extractors"
	"github.com/gofiber/utils/v2"
	"github.com/valyala/fasthttp"
)

// Session represents a user session.
type Session struct {
	ctx         fiber.Ctx     // fiber context
	config      *Store        // store configuration
	data        *data         // key value data
	id          string        // session id
	idleTimeout time.Duration // idleTimeout of this session
	mu          sync.RWMutex  // Mutex to protect non-data fields
	fresh       bool          // if new session
}

// absExpirationKeyType is a private key type so the absolute-expiration entry
// stored in session data cannot collide with user-chosen keys.
type absExpirationKeyType int

const (
	// absExpirationKey is the session-data key under which the absolute
	// expiration time is stored.
	absExpirationKey absExpirationKeyType = iota
)

// byteBufferPool reuses bytes.Buffer values when encoding and decoding session data.
var byteBufferPool = sync.Pool{
	New: func() any {
		return new(bytes.Buffer)
	},
}

// sessionPool reuses Session objects to reduce allocations.
var sessionPool = sync.Pool{
	New: func() any {
		return &Session{}
	},
}

// acquireSession returns a new Session from the pool.
//
// The data container is allocated on first use, and the session is marked fresh.
//
// Returns:
//   - *Session: The session object.
//
// Usage:
//
//	s := acquireSession()
func acquireSession() *Session {
	s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool
	if s.data == nil {
		s.data = acquireData()
	}
	s.fresh = true
	return s
}

// Release releases the session back to the pool.
//
// This function should be called after the session is no longer needed.
// This function is used to reduce the number of allocations and
// to improve the performance of the session store.
//
// The session should not be used after calling this function.
// // Important: The Release function should only be used when accessing the session directly, // for example, when you have called func (s *Session) Get(ctx) to get the session. // It should not be used when using the session with a *Middleware handler in the request // call stack, as the middleware will still need to access the session. // // Usage: // // sess := session.Get(ctx) // defer sess.Release() func (s *Session) Release() { if s == nil { return } releaseSession(s) } func releaseSession(s *Session) { s.mu.Lock() s.id = "" s.idleTimeout = 0 s.ctx = nil s.config = nil if s.data != nil { s.data.Reset() } s.mu.Unlock() sessionPool.Put(s) } // Fresh returns whether the session is new // // Returns: // - bool: True if the session is fresh; otherwise, false. // // Usage: // // isFresh := s.Fresh() func (s *Session) Fresh() bool { s.mu.RLock() defer s.mu.RUnlock() return s.fresh } // ID returns the session ID // // Returns: // - string: The session ID. // // Usage: // // id := s.ID() func (s *Session) ID() string { s.mu.RLock() defer s.mu.RUnlock() return s.id } // Get returns the value associated with the given key. // // Parameters: // - key: The key to retrieve. // // Returns: // - any: The value associated with the key. // // Usage: // // value := s.Get("key") func (s *Session) Get(key any) any { if s.data == nil { return nil } return s.data.Get(key) } // Set updates or creates a new key-value pair in the session. // // Parameters: // - key: The key to set. // - val: The value to set. // // Usage: // // s.Set("key", "value") func (s *Session) Set(key, val any) { if s.data == nil { return } s.data.Set(key, val) } // Delete removes the key-value pair from the session. // // Parameters: // - key: The key to delete. // // Usage: // // s.Delete("key") func (s *Session) Delete(key any) { if s.data == nil { return } s.data.Delete(key) } // Destroy deletes the session from storage and expires the session cookie. 
// // Returns: // - error: An error if the destruction fails. // // Usage: // // err := s.Destroy() func (s *Session) Destroy() error { if s.data == nil { return nil } // Reset local data s.data.Reset() s.mu.RLock() defer s.mu.RUnlock() // Use external Storage if exist var ctx context.Context = s.ctx if ctx == nil { ctx = context.Background() } if err := s.config.Storage.DeleteWithContext(ctx, s.id); err != nil { return err } // Expire session s.delSession() return nil } // Regenerate generates a new session id and deletes the old one from storage. // // Returns: // - error: An error if the regeneration fails. // // Usage: // // err := s.Regenerate() func (s *Session) Regenerate() error { s.mu.Lock() defer s.mu.Unlock() // Delete old id from storage var ctx context.Context = s.ctx if ctx == nil { ctx = context.Background() } if err := s.config.Storage.DeleteWithContext(ctx, s.id); err != nil { return err } // Generate a new session, and set session.fresh to true s.refresh() return nil } // Reset generates a new session id, deletes the old one from storage, and resets the associated data. // // Returns: // - error: An error if the reset fails. // // Usage: // // err := s.Reset() func (s *Session) Reset() error { // Reset local data if s.data != nil { s.data.Reset() } s.mu.Lock() defer s.mu.Unlock() // Reset expiration s.idleTimeout = 0 // Delete old id from storage var ctx context.Context = s.ctx if ctx == nil { ctx = context.Background() } if err := s.config.Storage.DeleteWithContext(ctx, s.id); err != nil { return err } // Expire session s.delSession() // Generate a new session, and set session.fresh to true s.refresh() return nil } // refresh generates a new session, and sets session.fresh to be true. 
func (s *Session) refresh() { s.id = s.config.KeyGenerator() s.fresh = true } // Save saves the session data and updates the cookie // // Note: If the session is being used in the handler, calling Save will have // no effect and the session will automatically be saved when the handler returns. // // Returns: // - error: An error if the save operation fails. // // Usage: // // err := s.Save() func (s *Session) Save() error { if s.ctx == nil { return s.saveSession() } // If the session is being used in the handler, it should not be saved if m, ok := s.ctx.Locals(middlewareContextKey).(*Middleware); ok { if m.Session == s { // Session is in use, so we do nothing and return return nil } } return s.saveSession() } // saveSession encodes session data to saves it to storage. func (s *Session) saveSession() error { if s.data == nil { return nil } s.mu.Lock() defer s.mu.Unlock() // Set idleTimeout if not already set if s.idleTimeout <= 0 { s.idleTimeout = s.config.IdleTimeout } // Update client cookie s.setSession() // Encode session data s.data.RLock() encodedBytes, err := s.encodeSessionData() s.data.RUnlock() if err != nil { return fmt.Errorf("failed to encode data: %w", err) } // Pass copied bytes with session id to provider var ctx context.Context = s.ctx if ctx == nil { ctx = context.Background() } return s.config.Storage.SetWithContext(ctx, s.id, encodedBytes, s.idleTimeout) } // Keys retrieves all keys in the current session. // // Returns: // - []any: A slice of all keys in the session. // // Usage: // // keys := s.Keys() func (s *Session) Keys() []any { if s.data == nil { return []any{} } return s.data.Keys() } // SetIdleTimeout used when saving the session on the next call to `Save()`. // // Parameters: // - idleTimeout: The duration for the idle timeout. 
// // Usage: // // s.SetIdleTimeout(time.Hour) func (s *Session) SetIdleTimeout(idleTimeout time.Duration) { s.mu.Lock() defer s.mu.Unlock() s.idleTimeout = idleTimeout } // getExtractorInfo returns all cookie and header extractors from the chain func (s *Session) getExtractorInfo() []extractors.Extractor { if s.config == nil { return []extractors.Extractor{{Source: extractors.SourceCookie, Key: "session_id"}} // Safe default } extractor := s.config.Extractor var relevantExtractors []extractors.Extractor // If it's a chained extractor, collect all cookie/header extractors if len(extractor.Chain) > 0 { for _, chainExtractor := range extractor.Chain { if chainExtractor.Source == extractors.SourceCookie || chainExtractor.Source == extractors.SourceHeader { relevantExtractors = append(relevantExtractors, chainExtractor) } } } else if extractor.Source == extractors.SourceCookie || extractor.Source == extractors.SourceHeader { // Single extractor - only include if it's cookie or header relevantExtractors = append(relevantExtractors, extractor) } // If no cookie/header extractors found and the config has a store but no explicit cookie/header extractors, // we should not default to cookie. This allows for read-only configurations (e.g., query/param/form/custom). 
// Only add default cookie extractor if we have no extractors at all (nil config case is handled above) return relevantExtractors } func (s *Session) setSession() { if s.ctx == nil { return } // Get all relevant extractors relevantExtractors := s.getExtractorInfo() // Set session ID for each extractor type for _, ext := range relevantExtractors { switch ext.Source { case extractors.SourceHeader: s.ctx.Response().Header.SetBytesV(ext.Key, utils.UnsafeBytes(s.id)) case extractors.SourceCookie: fcookie := fasthttp.AcquireCookie() fcookie.SetKey(ext.Key) fcookie.SetValue(s.id) fcookie.SetPath(s.config.CookiePath) fcookie.SetDomain(s.config.CookieDomain) // Cookies are also session cookies if they do not specify the Expires or Max-Age attribute. // refer: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie if !s.config.CookieSessionOnly { fcookie.SetMaxAge(int(s.idleTimeout.Seconds())) fcookie.SetExpire(time.Now().Add(s.idleTimeout)) } s.setCookieAttributes(fcookie) s.ctx.Response().Header.SetCookie(fcookie) fasthttp.ReleaseCookie(fcookie) default: // For non-cookie/header sources, do nothing (read-only) } } } func (s *Session) delSession() { if s.ctx == nil { return } // Get all relevant extractors relevantExtractors := s.getExtractorInfo() // Delete session ID for each extractor type for _, ext := range relevantExtractors { switch ext.Source { case extractors.SourceHeader: s.ctx.Request().Header.Del(ext.Key) s.ctx.Response().Header.Del(ext.Key) case extractors.SourceCookie: s.ctx.Request().Header.DelCookie(ext.Key) s.ctx.Response().Header.DelCookie(ext.Key) fcookie := fasthttp.AcquireCookie() fcookie.SetKey(ext.Key) fcookie.SetPath(s.config.CookiePath) fcookie.SetDomain(s.config.CookieDomain) fcookie.SetMaxAge(-1) fcookie.SetExpire(time.Now().Add(-1 * time.Minute)) s.setCookieAttributes(fcookie) s.ctx.Response().Header.SetCookie(fcookie) fasthttp.ReleaseCookie(fcookie) default: // For non-cookie/header sources, do nothing (read-only) } } } // 
setCookieAttributes sets the cookie attributes based on the session config. func (s *Session) setCookieAttributes(fcookie *fasthttp.Cookie) { // Set SameSite attribute switch { case utils.EqualFold(s.config.CookieSameSite, fiber.CookieSameSiteStrictMode): fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode) case utils.EqualFold(s.config.CookieSameSite, fiber.CookieSameSiteNoneMode): fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode) default: fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) } // The Secure attribute is required for SameSite=None if fcookie.SameSite() == fasthttp.CookieSameSiteNoneMode { fcookie.SetSecure(true) } else { fcookie.SetSecure(s.config.CookieSecure) } fcookie.SetHTTPOnly(s.config.CookieHTTPOnly) } // decodeSessionData decodes session data from raw bytes // // Parameters: // - rawData: The raw byte data to decode. // // Returns: // - error: An error if the decoding fails. // // Usage: // // err := s.decodeSessionData(rawData) func (s *Session) decodeSessionData(rawData []byte) error { byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool defer byteBufferPool.Put(byteBuffer) defer byteBuffer.Reset() _, _ = byteBuffer.Write(rawData) decCache := gob.NewDecoder(byteBuffer) if err := decCache.Decode(&s.data.Data); err != nil { return fmt.Errorf("failed to decode session data: %w", err) } return nil } // encodeSessionData encodes session data to raw bytes // // Parameters: // - rawData: The raw byte data to encode. // // Returns: // - error: An error if the encoding fails. 
// // Usage: // // err := s.encodeSessionData(rawData) func (s *Session) encodeSessionData() ([]byte, error) { byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool defer byteBufferPool.Put(byteBuffer) defer byteBuffer.Reset() encCache := gob.NewEncoder(byteBuffer) if err := encCache.Encode(&s.data.Data); err != nil { return nil, fmt.Errorf("failed to encode session data: %w", err) } // Copy the bytes // Copy the data in buffer encodedBytes := make([]byte, byteBuffer.Len()) copy(encodedBytes, byteBuffer.Bytes()) return encodedBytes, nil } // absExpiration returns the session absolute expiration time or a zero time if not set. // // Returns: // - time.Time: The session absolute expiration time. Zero time if not set. // // Usage: // // expiration := s.absExpiration() func (s *Session) absExpiration() time.Time { absExpiration, ok := s.Get(absExpirationKey).(time.Time) if ok { return absExpiration } return time.Time{} } // isAbsExpired returns true if the session is expired. // // If the session has an absolute expiration time set, this function will return true if the // current time is after the absolute expiration time. // // Returns: // - bool: True if the session is expired; otherwise, false. func (s *Session) isAbsExpired() bool { absExpiration := s.absExpiration() return !absExpiration.IsZero() && time.Now().After(absExpiration) } // setAbsExpiration sets the absolute session expiration time. // // Parameters: // - expiration: The session expiration time. 
// // Usage: // // s.setAbsExpiration(time.Now().Add(time.Hour)) func (s *Session) setAbsExpiration(absExpiration time.Time) { s.Set(absExpirationKey, absExpiration) } ================================================ FILE: middleware/session/session_test.go ================================================ package session import ( "errors" "strings" "sync" "testing" "time" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" "github.com/gofiber/fiber/v3/internal/storage/memory" "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) // go test -run Test_Session func Test_Session(t *testing.T) { t.Parallel() // session store store := NewStore() // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) // Get a new session sess, err := store.Get(ctx) require.NoError(t, err) require.True(t, sess.Fresh()) token := sess.ID() require.NoError(t, sess.Save()) sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) // set session using default cookie extractor ctx.Request().Header.SetCookie("session_id", token) // get session sess, err = store.Get(ctx) require.NoError(t, err) require.False(t, sess.Fresh()) // get keys keys := sess.Keys() require.Equal(t, []any{}, keys) // get value name := sess.Get("name") require.Nil(t, name) // set value sess.Set("name", "john") // get value name = sess.Get("name") require.Equal(t, "john", name) keys = sess.Keys() require.Equal(t, []any{"name"}, keys) // delete key sess.Delete("name") // get value name = sess.Get("name") require.Nil(t, name) // get keys keys = sess.Keys() require.Equal(t, []any{}, keys) // get id id := sess.ID() require.Equal(t, token, id) // save the old session first err = sess.Save() require.NoError(t, err) // release the session sess.Release() // release the context app.ReleaseCtx(ctx) // requesting entirely new context to prevent falsy tests ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) sess, err = 
store.Get(ctx) require.NoError(t, err) require.True(t, sess.Fresh()) // this id should be randomly generated as session key was deleted require.Len(t, sess.ID(), 43) sess.Release() // when we use the original session for the second time // the session be should be same if the session is not expired app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // request the server with the old session ctx.Request().Header.SetCookie("session_id", id) sess, err = store.Get(ctx) defer sess.Release() require.NoError(t, err) require.False(t, sess.Fresh()) require.Equal(t, sess.id, id) } // go test -run Test_Session_Types func Test_Session_Types(t *testing.T) { t.Parallel() // session store store := NewStore() // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) // set cookie ctx.Request().Header.SetCookie("session_id", "123") // get session sess, err := store.Get(ctx) require.NoError(t, err) require.True(t, sess.Fresh()) // the session string is no longer be 123 newSessionIDString := sess.ID() type User struct { Name string } store.RegisterType(User{}) vuser := User{ Name: "John", } // set value var ( vbool = true vstring = "str" vint = 13 vint8 int8 = 13 vint16 int16 = 13 vint32 int32 = 13 vint64 int64 = 13 vuint uint = 13 vuint8 uint8 = 13 vuint16 uint16 = 13 vuint32 uint32 = 13 vuint64 uint64 = 13 vuintptr uintptr = 13 vbyte byte = 'k' vrune = 'k' vfloat32 float32 = 13 vfloat64 float64 = 13 vcomplex64 complex64 = 13 vcomplex128 complex128 = 13 ) sess.Set("vuser", vuser) sess.Set("vbool", vbool) sess.Set("vstring", vstring) sess.Set("vint", vint) sess.Set("vint8", vint8) sess.Set("vint16", vint16) sess.Set("vint32", vint32) sess.Set("vint64", vint64) sess.Set("vuint", vuint) sess.Set("vuint8", vuint8) sess.Set("vuint16", vuint16) sess.Set("vuint32", vuint32) sess.Set("vuint32", vuint32) sess.Set("vuint64", vuint64) sess.Set("vuintptr", vuintptr) sess.Set("vbyte", vbyte) sess.Set("vrune", 
vrune) sess.Set("vfloat32", vfloat32) sess.Set("vfloat64", vfloat64) sess.Set("vcomplex64", vcomplex64) sess.Set("vcomplex128", vcomplex128) // save session err = sess.Save() require.NoError(t, err) sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) ctx.Request().Header.SetCookie("session_id", newSessionIDString) // get session sess, err = store.Get(ctx) require.NoError(t, err) require.False(t, sess.Fresh()) // get value vuserResult, ok := sess.Get("vuser").(User) require.True(t, ok) require.Equal(t, vuser, vuserResult) vboolResult, ok := sess.Get("vbool").(bool) require.True(t, ok) require.Equal(t, vbool, vboolResult) vstringResult, ok := sess.Get("vstring").(string) require.True(t, ok) require.Equal(t, vstring, vstringResult) vintResult, ok := sess.Get("vint").(int) require.True(t, ok) require.Equal(t, vint, vintResult) vint8Result, ok := sess.Get("vint8").(int8) require.True(t, ok) require.Equal(t, vint8, vint8Result) vint16Result, ok := sess.Get("vint16").(int16) require.True(t, ok) require.Equal(t, vint16, vint16Result) vint32Result, ok := sess.Get("vint32").(int32) require.True(t, ok) require.Equal(t, vint32, vint32Result) vint64Result, ok := sess.Get("vint64").(int64) require.True(t, ok) require.Equal(t, vint64, vint64Result) vuintResult, ok := sess.Get("vuint").(uint) require.True(t, ok) require.Equal(t, vuint, vuintResult) vuint8Result, ok := sess.Get("vuint8").(uint8) require.True(t, ok) require.Equal(t, vuint8, vuint8Result) vuint16Result, ok := sess.Get("vuint16").(uint16) require.True(t, ok) require.Equal(t, vuint16, vuint16Result) vuint32Result, ok := sess.Get("vuint32").(uint32) require.True(t, ok) require.Equal(t, vuint32, vuint32Result) vuint64Result, ok := sess.Get("vuint64").(uint64) require.True(t, ok) require.Equal(t, vuint64, vuint64Result) vuintptrResult, ok := sess.Get("vuintptr").(uintptr) require.True(t, ok) require.Equal(t, vuintptr, vuintptrResult) vbyteResult, ok := sess.Get("vbyte").(byte) require.True(t, 
ok) require.Equal(t, vbyte, vbyteResult) vruneResult, ok := sess.Get("vrune").(rune) require.True(t, ok) require.Equal(t, vrune, vruneResult) vfloat32Result, ok := sess.Get("vfloat32").(float32) require.True(t, ok) require.InEpsilon(t, vfloat32, vfloat32Result, 0.001) vfloat64Result, ok := sess.Get("vfloat64").(float64) require.True(t, ok) require.InEpsilon(t, vfloat64, vfloat64Result, 0.001) vcomplex64Result, ok := sess.Get("vcomplex64").(complex64) require.True(t, ok) require.Equal(t, vcomplex64, vcomplex64Result) vcomplex128Result, ok := sess.Get("vcomplex128").(complex128) require.True(t, ok) require.Equal(t, vcomplex128, vcomplex128Result) sess.Release() app.ReleaseCtx(ctx) } // go test -run Test_Session_Store_Reset func Test_Session_Store_Reset(t *testing.T) { t.Parallel() // session store store := NewStore() // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) // get session sess, err := store.Get(ctx) require.NoError(t, err) // make sure its new require.True(t, sess.Fresh()) // set value & save sess.Set("hello", "world") ctx.Request().Header.SetCookie("session_id", sess.ID()) require.NoError(t, sess.Save()) // reset store require.NoError(t, store.Reset(ctx)) id := sess.ID() sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) ctx.Request().Header.SetCookie("session_id", id) // make sure the session is recreated sess, err = store.Get(ctx) defer sess.Release() require.NoError(t, err) require.True(t, sess.Fresh()) require.Nil(t, sess.Get("hello")) } func Test_Session_KeyTypes(t *testing.T) { // Note: This test cannot run in parallel because it registers types // in the global gob registry via store.RegisterType(), which would // cause race conditions with other parallel tests. 
// session store store := NewStore() // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) // get session sess, err := store.Get(ctx) require.NoError(t, err) require.True(t, sess.Fresh()) type Person struct { Name string } type unexportedKey int // register non-default types store.RegisterType(Person{}) store.RegisterType(unexportedKey(0)) type unregisteredKeyType int type unregisteredValueType int // verify unregistered keys types are not allowed var ( unregisteredKey unregisteredKeyType unregisteredValue unregisteredValueType ) sess.Set(unregisteredKey, "test") err = sess.Save() require.Error(t, err) sess.Delete(unregisteredKey) err = sess.Save() require.NoError(t, err) sess.Set("abc", unregisteredValue) err = sess.Save() require.Error(t, err) sess.Delete("abc") err = sess.Save() require.NoError(t, err) require.NoError(t, sess.Reset()) // Release session before continuing sess.Release() // Get a new session after reset sess, err = store.Get(ctx) require.NoError(t, err) require.True(t, sess.Fresh()) var ( kbool = true kstring = "str" kint = 13 kint8 int8 = 13 kint16 int16 = 13 kint32 int32 = 13 kint64 int64 = 13 kuint uint = 13 kuint8 uint8 = 13 kuint16 uint16 = 13 kuint32 uint32 = 13 kuint64 uint64 = 13 kuintptr uintptr = 13 kbyte byte = 'k' krune = 'k' kfloat32 float32 = 13 kfloat64 float64 = 13 kcomplex64 complex64 = 13 kcomplex128 complex128 = 13 kuser = Person{Name: "John"} kunexportedKey = unexportedKey(13) ) var ( vbool = true vstring = "str" vint = 13 vint8 int8 = 13 vint16 int16 = 13 vint32 int32 = 13 vint64 int64 = 13 vuint uint = 13 vuint8 uint8 = 13 vuint16 uint16 = 13 vuint32 uint32 = 13 vuint64 uint64 = 13 vuintptr uintptr = 13 vbyte byte = 'k' vrune = 'k' vfloat32 float32 = 13 vfloat64 float64 = 13 vcomplex64 complex64 = 13 vcomplex128 complex128 = 13 vuser = Person{Name: "John"} vunexportedKey = unexportedKey(13) ) keys := []any{ kbool, kstring, kint, kint8, kint16, kint32, kint64, kuint, kuint8, 
kuint16, kuint32, kuint64, kuintptr, kbyte, krune, kfloat32, kfloat64, kcomplex64, kcomplex128, kuser, kunexportedKey, } values := []any{ vbool, vstring, vint, vint8, vint16, vint32, vint64, vuint, vuint8, vuint16, vuint32, vuint64, vuintptr, vbyte, vrune, vfloat32, vfloat64, vcomplex64, vcomplex128, vuser, vunexportedKey, } // loop test all key value pairs for i, key := range keys { sess.Set(key, values[i]) } id := sess.ID() ctx.Request().Header.SetCookie("session_id", id) // save session err = sess.Save() require.NoError(t, err) sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) ctx.Request().Header.SetCookie("session_id", id) // get session sess, err = store.Get(ctx) require.NoError(t, err) defer sess.Release() require.False(t, sess.Fresh()) // loop test all key value pairs for i, key := range keys { // get value result := sess.Get(key) require.Equal(t, values[i], result) } } // go test -run Test_Session_Save func Test_Session_Save(t *testing.T) { t.Parallel() t.Run("save to cookie", func(t *testing.T) { t.Parallel() // session store store := NewStore() // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) // get session sess, err := store.Get(ctx) require.NoError(t, err) // set value sess.Set("name", "john") // save session err = sess.Save() require.NoError(t, err) sess.Release() }) t.Run("save to header", func(t *testing.T) { t.Parallel() // session store store := NewStore(Config{ Extractor: extractors.FromHeader("session_id"), }) // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) require.NoError(t, err) // set value sess.Set("name", "john") // save session err = sess.Save() require.NoError(t, err) require.Equal(t, sess.ID(), string(ctx.Response().Header.Peek("session_id"))) sess.Release() }) } // Test chained extractors to ensure both cookie 
and header are set when both are present func Test_Session_ChainedExtractors(t *testing.T) { t.Parallel() t.Run("cookie and header chain", func(t *testing.T) { t.Parallel() // session store with chained extractors store := NewStore(Config{ Extractor: extractors.Chain(extractors.FromCookie("session_id"), extractors.FromHeader("x-session-id")), }) // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) require.NoError(t, err) // set value sess.Set("name", "john") // save session err = sess.Save() require.NoError(t, err) // verify both cookie and header are set cookie := ctx.Response().Header.PeekCookie("session_id") require.NotNil(t, cookie) require.Contains(t, string(cookie), sess.ID()) header := string(ctx.Response().Header.Peek("x-session-id")) require.Equal(t, sess.ID(), header) sess.Release() }) t.Run("header and cookie chain", func(t *testing.T) { t.Parallel() // session store with chained extractors (different order) store := NewStore(Config{ Extractor: extractors.Chain(extractors.FromHeader("x-session-id"), extractors.FromCookie("session_id")), }) // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) require.NoError(t, err) // set value sess.Set("name", "john") // save session err = sess.Save() require.NoError(t, err) // verify both header and cookie are set header := string(ctx.Response().Header.Peek("x-session-id")) require.Equal(t, sess.ID(), header) cookie := ctx.Response().Header.PeekCookie("session_id") require.NotNil(t, cookie) require.Contains(t, string(cookie), sess.ID()) sess.Release() }) t.Run("only SourceOther extractors - no response setting", func(t *testing.T) { t.Parallel() // session store with only query/form extractors store := NewStore(Config{ Extractor: extractors.Chain(extractors.FromQuery("session_id"), 
extractors.FromForm("session_id")), }) // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) require.NoError(t, err) // set value sess.Set("name", "john") // save session err = sess.Save() require.NoError(t, err) // verify no cookie or header is set cookie := ctx.Response().Header.PeekCookie("session_id") require.Nil(t, cookie) header := string(ctx.Response().Header.Peek("session_id")) require.Empty(t, header) sess.Release() }) t.Run("mixed chain with SourceOther", func(t *testing.T) { t.Parallel() // session store with mixed extractors including SourceOther store := NewStore(Config{ Extractor: extractors.Chain(extractors.FromCookie("session_id"), extractors.FromQuery("session_id"), extractors.FromHeader("x-session-id")), }) // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) require.NoError(t, err) // set value sess.Set("name", "john") // save session err = sess.Save() require.NoError(t, err) // verify both cookie and header are set (query is ignored for response) cookie := ctx.Response().Header.PeekCookie("session_id") require.NotNil(t, cookie) require.Contains(t, string(cookie), sess.ID()) header := string(ctx.Response().Header.Peek("x-session-id")) require.Equal(t, sess.ID(), header) sess.Release() }) } func Test_Session_Save_IdleTimeout(t *testing.T) { t.Parallel() t.Run("save to cookie", func(t *testing.T) { t.Parallel() const sessionDuration = 5 * time.Second // session store store := NewStore() // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) require.NoError(t, err) // set value sess.Set("name", "john") token := sess.ID() // expire this session in 5 seconds sess.SetIdleTimeout(sessionDuration) // save session 
err = sess.Save() require.NoError(t, err) sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) // here you need to get the old session yet ctx.Request().Header.SetCookie("session_id", token) sess, err = store.Get(ctx) require.NoError(t, err) require.Equal(t, token, sess.ID(), "session ID should match before expiration") name := sess.Get("name") require.Equal(t, "john", name, "session should contain the saved value before expiration") // just to make sure the session has been expired // Add extra buffer time to ensure expiration is processed time.Sleep(sessionDuration + (100 * time.Millisecond)) sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // here you should get a new session ctx.Request().Header.SetCookie("session_id", token) sess, err = store.Get(ctx) defer sess.Release() require.NoError(t, err) require.Nil(t, sess.Get("name")) require.NotEqual(t, sess.ID(), token) }) } func Test_Session_Save_AbsoluteTimeout(t *testing.T) { t.Parallel() t.Run("save to cookie", func(t *testing.T) { t.Parallel() const absoluteTimeout = 2 * time.Second // extra headroom to avoid flakiness under -race // session store store := NewStore(Config{ IdleTimeout: absoluteTimeout, AbsoluteTimeout: absoluteTimeout, }) // force change to IdleTimeout store.IdleTimeout = 10 * time.Second // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) require.NoError(t, err) // set value sess.Set("name", "john") token := sess.ID() // save session err = sess.Save() require.NoError(t, err) sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) // here you need to get the old session yet ctx.Request().Header.SetCookie("session_id", token) sess, err = store.Get(ctx) require.NoError(t, err) require.Equal(t, "john", sess.Get("name")) // just to make sure the session has been expired 
time.Sleep(absoluteTimeout + (200 * time.Millisecond)) sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) // here you should get a new session ctx.Request().Header.SetCookie("session_id", token) sess, err = store.Get(ctx) require.NoError(t, err) require.Nil(t, sess.Get("name")) require.NotEqual(t, sess.ID(), token) require.True(t, sess.Fresh()) require.IsType(t, time.Time{}, sess.Get(absExpirationKey)) token = sess.ID() sess.Set("name", "john") // save session err = sess.Save() require.NoError(t, err) sess.Release() app.ReleaseCtx(ctx) // just to make sure the session has been expired time.Sleep(absoluteTimeout + (200 * time.Millisecond)) // try to get expired session by id sess, err = store.GetByID(ctx, token) require.Error(t, err) require.ErrorIs(t, err, ErrSessionIDNotFoundInStore) require.Nil(t, sess) }) } // go test -run Test_Session_Destroy func Test_Session_Destroy(t *testing.T) { t.Parallel() t.Run("destroy from cookie", func(t *testing.T) { t.Parallel() // session store store := NewStore() // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) defer sess.Release() require.NoError(t, err) sess.Set("name", "fenny") require.NoError(t, sess.Destroy()) name := sess.Get("name") require.Nil(t, name) }) t.Run("destroy from header", func(t *testing.T) { t.Parallel() // session store store := NewStore(Config{ Extractor: extractors.FromHeader("session_id"), }) // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) require.NoError(t, err) // set value & save sess.Set("name", "fenny") id := sess.ID() require.NoError(t, sess.Save()) sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session ctx.Request().Header.Set("session_id", id) sess, err = 
store.Get(ctx) require.NoError(t, err) defer sess.Release() err = sess.Destroy() require.NoError(t, err) require.Empty(t, string(ctx.Response().Header.Peek("session_id"))) }) } // go test -run Test_Session_Custom_Config func Test_Session_Custom_Config(t *testing.T) { t.Parallel() store := NewStore(Config{IdleTimeout: time.Hour, KeyGenerator: func() string { return "very random" }}) require.Equal(t, time.Hour, store.IdleTimeout) require.Equal(t, "very random", store.KeyGenerator()) store = NewStore(Config{IdleTimeout: 0}) require.Equal(t, ConfigDefault.IdleTimeout, store.IdleTimeout) } // go test -run Test_Session_Cookie func Test_Session_Cookie(t *testing.T) { t.Parallel() // session store store := NewStore() // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) require.NoError(t, err) require.NoError(t, sess.Save()) sess.Release() // cookie should be set on Save ( even if empty data ) cookie := ctx.Response().Header.PeekCookie("session_id") require.NotNil(t, cookie) require.Regexp(t, `^session_id=[A-Za-z0-9\-_]{43}; max-age=\d+; path=/; SameSite=Lax$`, string(cookie)) } // go test -run Test_Session_Cookie_SameSite func Test_Session_Cookie_SameSite(t *testing.T) { t.Parallel() tests := []struct { expectedInHeader string name string sameSite string initialSecure bool }{ { name: "Lax should not force secure", sameSite: "Lax", initialSecure: false, expectedInHeader: "SameSite=Lax", }, { name: "Lax with secure should stay secure", sameSite: "Lax", initialSecure: true, expectedInHeader: "SameSite=Lax; secure", }, { name: "Strict should not force secure", sameSite: "Strict", initialSecure: false, expectedInHeader: "SameSite=Strict", }, { name: "Strict with secure should stay secure", sameSite: "Strict", initialSecure: true, expectedInHeader: "SameSite=Strict; secure", }, { name: "None should force secure", sameSite: "None", initialSecure: false, 
expectedInHeader: "SameSite=None; secure", }, { name: "None with secure should stay secure", sameSite: "None", initialSecure: true, expectedInHeader: "SameSite=None; secure", }, { name: "Case-insensitive none should force secure", sameSite: "none", initialSecure: false, expectedInHeader: "SameSite=None; secure", }, { name: "Case-insensitive strict should not force secure", sameSite: "strict", initialSecure: false, expectedInHeader: "SameSite=Strict", }, { name: "Default should be Lax", sameSite: "invalid", initialSecure: false, expectedInHeader: "SameSite=Lax", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() // session store store := NewStore(Config{ CookieSameSite: tc.sameSite, CookieSecure: tc.initialSecure, }) // fiber instance app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) require.NoError(t, err) defer sess.Release() // save session to trigger cookie setting err = sess.Save() require.NoError(t, err) // check cookie cookie := string(ctx.Response().Header.PeekCookie("session_id")) // The order of attributes in the cookie string is not guaranteed. // Instead of checking for a single substring, we check for the presence of each part. 
parts := strings.SplitSeq(tc.expectedInHeader, "; ") for part := range parts { require.Contains(t, cookie, part) } // Also check that secure is NOT present when it shouldn't be if !tc.initialSecure && tc.sameSite != "None" && tc.sameSite != "none" { require.NotContains(t, cookie, "secure") } }) } } // go test -run Test_Session_Cookie_In_Response // Regression: https://github.com/gofiber/fiber/pull/1191 func Test_Session_Cookie_In_Middleware_Chain(t *testing.T) { store := NewStore() app := fiber.New() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // get session sess, err := store.Get(ctx) require.NoError(t, err) sess.Set("id", "1") require.True(t, sess.Fresh()) id := sess.ID() require.NoError(t, sess.Save()) sess.Release() sess, err = store.Get(ctx) require.NoError(t, err) defer sess.Release() sess.Set("name", "john") require.False(t, sess.Fresh()) // Session should not be fresh - it reuses the same ID from context locals require.Equal(t, id, sess.ID()) // session id should be the same require.Equal(t, "1", sess.Get("id")) require.Equal(t, "john", sess.Get("name")) } // go test -run Test_Session_Deletes_Single_Key // Regression: https://github.com/gofiber/fiber/issues/1365 func Test_Session_Deletes_Single_Key(t *testing.T) { t.Parallel() store := NewStore() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) sess, err := store.Get(ctx) require.NoError(t, err) id := sess.ID() sess.Set("id", "1") require.NoError(t, sess.Save()) sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) ctx.Request().Header.SetCookie("session_id", id) sess, err = store.Get(ctx) require.NoError(t, err) sess.Delete("id") require.NoError(t, sess.Save()) sess.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) ctx.Request().Header.SetCookie("session_id", id) sess, err = store.Get(ctx) defer sess.Release() require.NoError(t, err) require.False(t, sess.Fresh()) require.Nil(t, sess.Get("id")) 
app.ReleaseCtx(ctx) } // go test -run Test_Session_Reset func Test_Session_Reset(t *testing.T) { t.Parallel() // fiber instance app := fiber.New() // session store store := NewStore() t.Run("reset session data and id, and set fresh to be true", func(t *testing.T) { t.Parallel() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) // a random session uuid originalSessionUUIDString := "" // now the session is in the storage freshSession, err := store.Get(ctx) require.NoError(t, err) originalSessionUUIDString = freshSession.ID() // set a value freshSession.Set("name", "fenny") freshSession.Set("email", "fenny@example.com") err = freshSession.Save() require.NoError(t, err) freshSession.Release() app.ReleaseCtx(ctx) ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) // set cookie ctx.Request().Header.SetCookie("session_id", originalSessionUUIDString) // as the session is in the storage, session.fresh should be false acquiredSession, err := store.Get(ctx) require.NoError(t, err) require.False(t, acquiredSession.Fresh()) err = acquiredSession.Reset() require.NoError(t, err) require.NotEqual(t, originalSessionUUIDString, acquiredSession.ID()) // acquiredSession.fresh should be true after resetting require.True(t, acquiredSession.Fresh()) // Check that the session data has been reset keys := acquiredSession.Keys() require.Equal(t, []any{}, keys) // Set a new value for 'name' and check that it's updated acquiredSession.Set("name", "john") require.Equal(t, "john", acquiredSession.Get("name")) require.Nil(t, acquiredSession.Get("email")) // Save after resetting err = acquiredSession.Save() require.NoError(t, err) acquiredSession.Release() // Check that the session id is not in the header or cookie anymore require.Empty(t, string(ctx.Response().Header.Peek("session_id"))) require.Empty(t, string(ctx.Request().Header.Peek("session_id"))) app.ReleaseCtx(ctx) }) } // go test -run Test_Session_Regenerate // Regression: https://github.com/gofiber/fiber/issues/1395 func 
Test_Session_Regenerate(t *testing.T) { t.Parallel() // fiber instance app := fiber.New() t.Run("set fresh to be true when regenerating a session", func(t *testing.T) { t.Parallel() // session store store := NewStore() // a random session uuid originalSessionUUIDString := "" // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // now the session is in the storage freshSession, err := store.Get(ctx) require.NoError(t, err) originalSessionUUIDString = freshSession.ID() err = freshSession.Save() require.NoError(t, err) freshSession.Release() // release the context app.ReleaseCtx(ctx) // acquire a new context ctx = app.AcquireCtx(&fasthttp.RequestCtx{}) // set cookie ctx.Request().Header.SetCookie("session_id", originalSessionUUIDString) // as the session is in the storage, session.fresh should be false acquiredSession, err := store.Get(ctx) require.NoError(t, err) defer acquiredSession.Release() require.False(t, acquiredSession.Fresh()) err = acquiredSession.Regenerate() require.NoError(t, err) require.NotEqual(t, originalSessionUUIDString, acquiredSession.ID()) // acquiredSession.fresh should be true after regenerating require.True(t, acquiredSession.Fresh()) // release the context app.ReleaseCtx(ctx) }) } // go test -v -run=^$ -bench=Benchmark_Session -benchmem -count=4 func Benchmark_Session(b *testing.B) { b.Run("default", func(b *testing.B) { app, store := fiber.New(), NewStore() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) c.Request().Header.SetCookie("session_id", "12356789") b.ReportAllocs() for b.Loop() { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark sess.Release() } }) b.Run("storage", func(b *testing.B) { app := fiber.New() store := NewStore(Config{ Storage: memory.New(), }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) c.Request().Header.SetCookie("session_id", 
"12356789") b.ReportAllocs() for b.Loop() { sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark sess.Release() } }) } // go test -v -run=^$ -bench=Benchmark_Session_Parallel -benchmem -count=4 func Benchmark_Session_Parallel(b *testing.B) { b.Run("default", func(b *testing.B) { app, store := fiber.New(), NewStore() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.SetCookie("session_id", "12356789") sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark sess.Release() app.ReleaseCtx(c) } }) }) b.Run("storage", func(b *testing.B) { app := fiber.New() store := NewStore(Config{ Storage: memory.New(), }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.SetCookie("session_id", "12356789") sess, _ := store.Get(c) //nolint:errcheck // We're inside a benchmark sess.Set("john", "doe") _ = sess.Save() //nolint:errcheck // We're inside a benchmark sess.Release() app.ReleaseCtx(c) } }) }) } // go test -v -run=^$ -bench=Benchmark_Session_Asserted -benchmem -count=4 func Benchmark_Session_Asserted(b *testing.B) { b.Run("default", func(b *testing.B) { app, store := fiber.New(), NewStore() c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) c.Request().Header.SetCookie("session_id", "12356789") b.ReportAllocs() for b.Loop() { sess, err := store.Get(c) require.NoError(b, err) sess.Set("john", "doe") err = sess.Save() require.NoError(b, err) sess.Release() } }) b.Run("storage", func(b *testing.B) { app := fiber.New() store := NewStore(Config{ Storage: memory.New(), }) c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) 
c.Request().Header.SetCookie("session_id", "12356789") b.ReportAllocs() for b.Loop() { sess, err := store.Get(c) require.NoError(b, err) sess.Set("john", "doe") err = sess.Save() require.NoError(b, err) sess.Release() } }) } // go test -v -run=^$ -bench=Benchmark_Session_Asserted_Parallel -benchmem -count=4 func Benchmark_Session_Asserted_Parallel(b *testing.B) { b.Run("default", func(b *testing.B) { app, store := fiber.New(), NewStore() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.SetCookie("session_id", "12356789") sess, err := store.Get(c) require.NoError(b, err) sess.Set("john", "doe") require.NoError(b, sess.Save()) sess.Release() app.ReleaseCtx(c) } }) }) b.Run("storage", func(b *testing.B) { app := fiber.New() store := NewStore(Config{ Storage: memory.New(), }) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { c := app.AcquireCtx(&fasthttp.RequestCtx{}) c.Request().Header.SetCookie("session_id", "12356789") sess, err := store.Get(c) require.NoError(b, err) sess.Set("john", "doe") require.NoError(b, sess.Save()) sess.Release() app.ReleaseCtx(c) } }) }) } // go test -v -race -run Test_Session_Concurrency ./... 
func Test_Session_Concurrency(t *testing.T) { app := fiber.New() store := NewStore() var wg sync.WaitGroup errChan := make(chan error, 10) // Buffered channel to collect errors const numGoroutines = 10 // Number of concurrent goroutines to test // Start numGoroutines goroutines for range numGoroutines { wg.Go(func() { localCtx := app.AcquireCtx(&fasthttp.RequestCtx{}) sess, err := store.getSession(localCtx) if err != nil { errChan <- err return } // Set a value sess.Set("name", "john") // get the session id id := sess.ID() // Check if the session is fresh if !sess.Fresh() { errChan <- errors.New("session should be fresh") return } // Save the session if saveErr := sess.Save(); saveErr != nil { errChan <- saveErr return } // release the session sess.Release() // Release the context app.ReleaseCtx(localCtx) // Acquire a new context localCtx = app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(localCtx) // Set the session id in the header localCtx.Request().Header.SetCookie("session_id", id) // Get the session sess, err = store.Get(localCtx) if err != nil { errChan <- err return } defer sess.Release() // Get the value name := sess.Get("name") if name != "john" { errChan <- errors.New("name should be john") return } // Get ID from the session if sess.ID() != id { errChan <- errors.New("id should be the same") return } // Check if the session is fresh if sess.Fresh() { errChan <- errors.New("session should not be fresh") return } // Delete the key sess.Delete("name") // Get the value name = sess.Get("name") if name != nil { errChan <- errors.New("name should be nil") return } // Destroy the session if err := sess.Destroy(); err != nil { errChan <- err return } }) } wg.Wait() // Wait for all goroutines to finish close(errChan) // Close the channel to signal no more errors will be sent // Check for errors sent to errChan for err := range errChan { require.NoError(t, err) } } func Test_Session_StoreGetDecodeSessionDataError(t *testing.T) { // Initialize a new 
store with default config store := NewStore() // Create a new Fiber app app := fiber.New() // Generate a fake session ID sessionID := uuid.New().String() // Store invalid session data to simulate decode error err := store.Storage.Set(sessionID, []byte("invalid data"), 0) require.NoError(t, err, "Failed to set invalid session data") // Create a new request context c := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(c) // Set the session ID in cookies c.Request().Header.SetCookie("session_id", sessionID) // Attempt to get the session _, err = store.Get(c) require.Error(t, err, "Expected error due to invalid session data, but got nil") // Check that the error message is as expected require.Contains(t, err.Error(), "failed to decode session data", "Unexpected error message") // Check that the error is as expected require.ErrorContains(t, err, "failed to decode session data", "Unexpected error") // Attempt to get the session by ID _, err = store.GetByID(c, sessionID) require.Error(t, err, "Expected error due to invalid session data, but got nil") // Check that the error message is as expected require.ErrorContains(t, err, "failed to decode session data", "Unexpected error") } // go test -run Test_Session_Fresh_Flag_Bug // This test verifies the fix for the fresh flag bug where calling getSession() // multiple times in the same request would incorrectly mark the session as fresh // when the ID was found in context locals. 
func Test_Session_Fresh_Flag_Bug(t *testing.T) {
	t.Parallel()
	store := NewStore()
	app := fiber.New()

	// Test Case 1: First call with no session cookie - should be fresh
	ctx1 := app.AcquireCtx(&fasthttp.RequestCtx{})
	sess1, err := store.Get(ctx1)
	require.NoError(t, err)
	require.True(t, sess1.Fresh(), "First session should be fresh (no cookie provided)")
	sessionID := sess1.ID()
	// Persist the (empty) session so the next request can find it in storage.
	require.NoError(t, sess1.Save())
	sess1.Release()
	app.ReleaseCtx(ctx1)

	// Test Case 2: Second call with session cookie - should NOT be fresh
	ctx2 := app.AcquireCtx(&fasthttp.RequestCtx{})
	ctx2.Request().Header.SetCookie("session_id", sessionID)
	sess2, err := store.Get(ctx2)
	require.NoError(t, err)
	require.False(t, sess2.Fresh(), "Existing session should not be fresh")
	require.Equal(t, sessionID, sess2.ID())

	// Test Case 3: Call getSession() again in the same request
	// This simulates what happens when CSRF middleware calls store operations.
	// NOTE(review): in store.go's getSession the ID is written to context locals
	// only when a new ID is generated, so here the second call re-reads the ID
	// from the cookie; confirm whether the locals path is meant to be covered.
	sess3, err := store.getSession(ctx2)
	require.NoError(t, err)
	require.False(t, sess3.Fresh(), "Session should still not be fresh on second getSession() call in same request")
	require.Equal(t, sessionID, sess3.ID())
	sess2.Release()
	sess3.Release()
	app.ReleaseCtx(ctx2)

	// Test Case 4: Expired session - should generate new ID and be fresh.
	// An ID with no stored data is never reused; a new one must be generated.
	ctx3 := app.AcquireCtx(&fasthttp.RequestCtx{})
	ctx3.Request().Header.SetCookie("session_id", "expired-or-nonexistent-id")
	sess4, err := store.Get(ctx3)
	require.NoError(t, err)
	require.True(t, sess4.Fresh(), "New session (after expired/missing data) should be fresh")
	require.NotEqual(t, "expired-or-nonexistent-id", sess4.ID(), "Should have generated a new session ID")
	sess4.Release()
	app.ReleaseCtx(ctx3)
}

// go test -run Test_Session_CSRF_Scenario
// This test simulates the user-reported issue with CSRF + session middleware
// where a POST without CSRF token would result in a new session_id cookie.
// It relies on real wall-clock sleeps relative to IdleTimeout, so it takes
// a little over two seconds to run.
func Test_Session_CSRF_Scenario(t *testing.T) {
	t.Parallel()

	store := NewStore(Config{
		IdleTimeout: 2 * time.Second, // Longer timeout to ensure session persists
	})
	app := fiber.New()

	// Simulate: First GET request creates session
	ctx1 := app.AcquireCtx(&fasthttp.RequestCtx{})
	sess1, err := store.Get(ctx1)
	require.NoError(t, err)
	require.True(t, sess1.Fresh())
	firstSessionID := sess1.ID()

	// Store some data (simulating CSRF token storage)
	sess1.Set("csrf_token", "token-123")
	require.NoError(t, sess1.Save())
	sess1.Release()
	app.ReleaseCtx(ctx1)

	// Small delay to ensure save completes
	time.Sleep(10 * time.Millisecond)

	// Simulate: POST request with valid session (before expiration)
	ctx2 := app.AcquireCtx(&fasthttp.RequestCtx{})
	ctx2.Request().Header.SetCookie("session_id", firstSessionID)
	sess2, err := store.Get(ctx2)
	require.NoError(t, err)
	require.False(t, sess2.Fresh(), "Session should not be fresh - it exists")
	require.Equal(t, firstSessionID, sess2.ID(), "Session ID should remain the same")
	require.Equal(t, "token-123", sess2.Get("csrf_token"))

	// Simulate CSRF validation failure (session is accessed but request fails)
	// Session should still maintain the same ID
	require.Equal(t, firstSessionID, sess2.ID())
	sess2.Release()
	app.ReleaseCtx(ctx2)

	// Wait for session to expire (2.2s > the 2s IdleTimeout configured above)
	time.Sleep(2200 * time.Millisecond)

	// Simulate: POST request with expired session
	// This is the scenario the user reported - session data is gone
	ctx3 := app.AcquireCtx(&fasthttp.RequestCtx{})
	ctx3.Request().Header.SetCookie("session_id", firstSessionID)
	sess3, err := store.Get(ctx3)
	require.NoError(t, err)
	require.True(t, sess3.Fresh(), "Session should be fresh - old data expired")
	require.NotEqual(t, firstSessionID, sess3.ID(), "Should have generated new session ID (expected behavior)")
	require.Nil(t, sess3.Get("csrf_token"), "Old session data should be gone")
	sess3.Release()
	app.ReleaseCtx(ctx3)
}

// go test -run Test_Session_Multiple_GetSession_Calls
// This test ensures that calling getSession() multiple times within the same
// request context keeps returning the same, non-fresh session instead of
// marking it as fresh on later calls.
func Test_Session_Multiple_GetSession_Calls(t *testing.T) {
	t.Parallel()

	store := NewStore()
	app := fiber.New()

	// Create initial session
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	sess1, err := store.Get(ctx)
	require.NoError(t, err)
	require.True(t, sess1.Fresh())
	sessionID := sess1.ID()
	sess1.Set("test_key", "test_value")
	require.NoError(t, sess1.Save())
	sess1.Release()
	app.ReleaseCtx(ctx)

	// New request with existing session
	ctx2 := app.AcquireCtx(&fasthttp.RequestCtx{})
	ctx2.Request().Header.SetCookie("session_id", sessionID)

	// First getSession() call - loads from storage
	sess2, err := store.getSession(ctx2)
	require.NoError(t, err)
	require.False(t, sess2.Fresh(), "First call: existing session should not be fresh")
	require.Equal(t, sessionID, sess2.ID())
	require.Equal(t, "test_value", sess2.Get("test_key"))

	// Second getSession() call in the same request.
	// This is where the bug would manifest before the fix.
	// NOTE(review): store.go's getSession only stores the ID in context locals
	// when it generates a new one, so this call resolves the ID from the cookie.
	sess3, err := store.getSession(ctx2)
	require.NoError(t, err)
	require.False(t, sess3.Fresh(), "Second call: session should STILL not be fresh (bug fix verification)")
	require.Equal(t, sessionID, sess3.ID())
	require.Equal(t, "test_value", sess3.Get("test_key"))

	// Third call to ensure consistency
	sess4, err := store.getSession(ctx2)
	require.NoError(t, err)
	require.False(t, sess4.Fresh(), "Third call: session should remain not fresh")
	require.Equal(t, sessionID, sess4.ID())

	// Each getSession call acquired its own pooled session object; release all of them.
	sess2.Release()
	sess3.Release()
	sess4.Release()
	app.ReleaseCtx(ctx2)
}

================================================
FILE: middleware/session/store.go
================================================
package session

import (
	"context"
	"encoding/gob"
	"errors"
	"fmt"
	"time"

	"github.com/gofiber/fiber/v3"
	"github.com/gofiber/fiber/v3/internal/storage/memory"
	"github.com/gofiber/fiber/v3/log"
)

// ErrEmptySessionID is an error that occurs when the session ID is empty.
var ( ErrEmptySessionID = errors.New("session ID cannot be empty") ErrSessionAlreadyLoadedByMiddleware = errors.New("session already loaded by middleware") ErrSessionIDNotFoundInStore = errors.New("session ID not found in session store") ) // sessionIDKey is the local key type used to store and retrieve the session ID in context. type sessionIDKey int const ( // sessionIDContextKey is the key used to store the session ID in the context locals. sessionIDContextKey sessionIDKey = iota ) // Store manages session data using the configured storage backend. type Store struct { Config } // NewStore creates a new session store with the provided configuration. // // Parameters: // - config: Variadic parameter to override default config. // // Returns: // - *Store: The session store. // // Usage: // // store := session.NewStore() func NewStore(config ...Config) *Store { // Set default config cfg := configDefault(config...) if cfg.Storage == nil { cfg.Storage = memory.New() } store := &Store{ Config: cfg, } if cfg.AbsoluteTimeout > 0 { store.RegisterType(absExpirationKey) store.RegisterType(time.Time{}) } return store } // RegisterType registers a custom type for encoding/decoding into any storage provider. // // Parameters: // - i: The custom type to register. // // Usage: // // store.RegisterType(MyCustomType{}) func (*Store) RegisterType(i any) { gob.Register(i) } // Get will get/create a session. // // This function will return an ErrSessionAlreadyLoadedByMiddleware if // the session is already loaded by the middleware. // // Parameters: // - c: The Fiber context. // // Returns: // - *Session: The session object. // - error: An error if the session retrieval fails or if the session is already loaded by the middleware. 
// // Usage: // // sess, err := store.Get(c) // if err != nil { // // handle error // } func (s *Store) Get(c fiber.Ctx) (*Session, error) { // If session is already loaded in the context, // it should not be loaded again _, ok := c.Locals(middlewareContextKey).(*Middleware) if ok { return nil, ErrSessionAlreadyLoadedByMiddleware } return s.getSession(c) } // getSession retrieves a session based on the context. // // Parameters: // - c: The Fiber context. // // Returns: // - *Session: The session object. // - error: An error if the session retrieval fails. // // Usage: // // sess, err := store.getSession(c) // if err != nil { // // handle error // } func (s *Store) getSession(c fiber.Ctx) (*Session, error) { var rawData []byte var err error id, ok := c.Locals(sessionIDContextKey).(string) if !ok { id = s.getSessionID(c) } fresh := false // Session is not fresh initially; only set to true if we generate a new ID // Attempt to fetch session data if an ID is provided if id != "" { rawData, err = s.Storage.GetWithContext(c, id) if err != nil { return nil, err } if rawData == nil { // Data not found, prepare to generate a new session id = "" } } // Generate a new ID if needed if id == "" { fresh = true // The session is fresh if a new ID is generated id = s.KeyGenerator() c.Locals(sessionIDContextKey, id) } // Create session object sess := acquireSession() sess.mu.Lock() sess.ctx = c sess.config = s sess.id = id sess.fresh = fresh // Decode session data if found if rawData != nil { sess.data.Lock() err := sess.decodeSessionData(rawData) sess.data.Unlock() if err != nil { sess.mu.Unlock() sess.Release() return nil, fmt.Errorf("failed to decode session data: %w", err) } } sess.mu.Unlock() if fresh && s.AbsoluteTimeout > 0 { sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout)) } else if sess.isAbsExpired() { if err := sess.Reset(); err != nil { return nil, fmt.Errorf("failed to reset session: %w", err) } sess.setAbsExpiration(time.Now().Add(s.AbsoluteTimeout)) } return 
sess, nil } // getSessionID returns the session ID using the configured extractor. // The extractor is provided by the shared extractors package. // // Parameters: // - c: The Fiber context. // // Returns: // - string: The session ID. // // Usage: // // id := store.getSessionID(c) func (s *Store) getSessionID(c fiber.Ctx) string { sessionID, err := s.Extractor.Extract(c) if err != nil { // If extraction fails, return empty string to generate a new session return "" } return sessionID } // Reset deletes all sessions from the storage. // // Returns: // - error: An error if the reset operation fails. // // Usage: // // err := store.Reset() // if err != nil { // // handle error // } func (s *Store) Reset(ctx context.Context) error { return s.Storage.ResetWithContext(ctx) } // Delete deletes a session by its ID. // // Parameters: // - id: The unique identifier of the session. // // Returns: // - error: An error if the deletion fails or if the session ID is empty. // // Usage: // // err := store.Delete(id) // if err != nil { // // handle error // } func (s *Store) Delete(ctx context.Context, id string) error { if id == "" { return ErrEmptySessionID } return s.Storage.DeleteWithContext(ctx, id) } // GetByID retrieves a session by its ID from the storage. // If the session is not found, it returns nil and an error. // // Unlike session middleware methods, this function does not automatically: // // - Load the session into the request context. // // - Save the session data to the storage or update the client cookie. // // Important Notes: // // - The session object returned by GetByID does not have a context associated with it. // // - When using this method alongside session middleware, there is a potential for collisions, // so be mindful of interactions between manually retrieved sessions and middleware-managed sessions. // // - If you modify a session returned by GetByID, you must call session.Save() to persist the changes. 
// // - When you are done with the session, you should call session.Release() to release the session back to the pool. // // Parameters: // - id: The unique identifier of the session. // // Returns: // - *Session: The session object if found; otherwise, nil. // - error: An error if the session retrieval fails or if the session ID is empty. // // Usage: // // sess, err := store.GetByID(id) // if err != nil { // // handle error // } func (s *Store) GetByID(ctx context.Context, id string) (*Session, error) { if id == "" { return nil, ErrEmptySessionID } rawData, err := s.Storage.GetWithContext(ctx, id) if err != nil { return nil, err } if rawData == nil { return nil, ErrSessionIDNotFoundInStore } sess := acquireSession() sess.mu.Lock() sess.config = s sess.id = id sess.fresh = false sess.data.Lock() decodeErr := sess.decodeSessionData(rawData) sess.data.Unlock() sess.mu.Unlock() if decodeErr != nil { sess.Release() return nil, fmt.Errorf("failed to decode session data: %w", decodeErr) } if s.AbsoluteTimeout > 0 { if sess.isAbsExpired() { if err := sess.Destroy(); err != nil { //nolint:contextcheck // it is not right sess.Release() log.Errorf("failed to destroy session: %v", err) } return nil, ErrSessionIDNotFoundInStore } } return sess, nil } ================================================ FILE: middleware/session/store_test.go ================================================ package session import ( "context" "fmt" "testing" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" ) // go test -run Test_Store_getSessionID func Test_Store_getSessionID(t *testing.T) { t.Parallel() expectedID := "test-session-id" // fiber instance app := fiber.New() t.Run("from cookie", func(t *testing.T) { t.Parallel() // session store store := NewStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // set cookie 
ctx.Request().Header.SetCookie(store.Extractor.Key, expectedID) require.Equal(t, expectedID, store.getSessionID(ctx)) }) t.Run("from header", func(t *testing.T) { t.Parallel() // session store store := NewStore(Config{ Extractor: extractors.FromHeader("session_id"), }) // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // set header ctx.Request().Header.Set(store.Extractor.Key, expectedID) require.Equal(t, expectedID, store.getSessionID(ctx)) }) t.Run("from url query", func(t *testing.T) { t.Parallel() // session store store := NewStore(Config{ Extractor: extractors.FromQuery("session_id"), }) // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // set url parameter ctx.Request().SetRequestURI(fmt.Sprintf("/path?%s=%s", store.Extractor.Key, expectedID)) require.Equal(t, expectedID, store.getSessionID(ctx)) }) } // go test -run Test_Store_Get // Regression: https://github.com/gofiber/fiber/issues/1408 // Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v func Test_Store_Get(t *testing.T) { // Regression: https://github.com/gofiber/fiber/security/advisories/GHSA-98j2-3j3p-fw2v t.Parallel() unexpectedID := "test-session-id" // fiber instance app := fiber.New() t.Run("session should be re-generated if it is invalid", func(t *testing.T) { t.Parallel() // session store store := NewStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // set cookie ctx.Request().Header.SetCookie(store.Extractor.Key, unexpectedID) acquiredSession, err := store.Get(ctx) require.NoError(t, err) require.NotEqual(t, unexpectedID, acquiredSession.ID()) }) } // go test -run Test_Store_DeleteSession func Test_Store_DeleteSession(t *testing.T) { t.Parallel() // fiber instance app := fiber.New() // session store store := NewStore() // fiber context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // Create a new session 
session, err := store.Get(ctx) require.NoError(t, err) // Save the session ID sessionID := session.ID() // Delete the session err = store.Delete(ctx, sessionID) require.NoError(t, err) // Try to get the session again session, err = store.Get(ctx) require.NoError(t, err) // The session ID should be different now, because the old session was deleted require.NotEqual(t, sessionID, session.ID()) } func TestStore_Get_SessionAlreadyLoaded(t *testing.T) { // Create a new Fiber app app := fiber.New() // Create a new context ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) // Mock middleware and set it in the context middleware := &Middleware{} ctx.Locals(middlewareContextKey, middleware) // Create a new store store := &Store{} // Call the Get method sess, err := store.Get(ctx) // Assert that the error is ErrSessionAlreadyLoadedByMiddleware require.Nil(t, sess) require.Equal(t, ErrSessionAlreadyLoadedByMiddleware, err) } func TestStore_Delete(t *testing.T) { // Create a new store store := NewStore() t.Run("delete with empty session ID", func(t *testing.T) { err := store.Delete(context.Background(), "") require.Error(t, err) require.Equal(t, ErrEmptySessionID, err) }) t.Run("delete non-existing session", func(t *testing.T) { err := store.Delete(context.Background(), "non-existing-session-id") require.NoError(t, err) }) } func Test_Store_GetByID(t *testing.T) { t.Parallel() // Create a new store store := NewStore() t.Run("empty session ID", func(t *testing.T) { t.Parallel() sess, err := store.GetByID(context.Background(), "") require.Error(t, err) require.Nil(t, sess) require.Equal(t, ErrEmptySessionID, err) }) t.Run("nonexistent session ID", func(t *testing.T) { t.Parallel() sess, err := store.GetByID(context.Background(), "nonexistent-session-id") require.Error(t, err) require.Nil(t, sess) require.Equal(t, ErrSessionIDNotFoundInStore, err) }) t.Run("valid session ID", func(t *testing.T) { t.Parallel() app := fiber.New() // Create a new session ctx := 
app.AcquireCtx(&fasthttp.RequestCtx{}) session, err := store.Get(ctx) defer session.Release() defer app.ReleaseCtx(ctx) require.NoError(t, err) // Save the session ID sessionID := session.ID() // Save the session err = session.Save() require.NoError(t, err) // Retrieve the session by ID retrievedSession, err := store.GetByID(context.Background(), sessionID) require.NoError(t, err) require.NotNil(t, retrievedSession) require.Equal(t, sessionID, retrievedSession.ID()) // Call Save on the retrieved session retrievedSession.Set("key", "value") err = retrievedSession.Save() require.NoError(t, err) // Call Other Session methods require.Equal(t, "value", retrievedSession.Get("key")) require.False(t, retrievedSession.Fresh()) require.NoError(t, retrievedSession.Reset()) require.NoError(t, retrievedSession.Destroy()) require.IsType(t, []any{}, retrievedSession.Keys()) require.NoError(t, retrievedSession.Regenerate()) require.NotPanics(t, func() { retrievedSession.Release() }) }) } ================================================ FILE: middleware/skip/skip.go ================================================ package skip import ( "github.com/gofiber/fiber/v3" ) // New returns a middleware that calls the provided predicate for each request. // If the predicate evaluates to true the wrapped handler is skipped and the next // handler in the chain is executed. 
func New(handler fiber.Handler, exclude func(c fiber.Ctx) bool) fiber.Handler { if exclude == nil { return handler } return func(c fiber.Ctx) error { if exclude(c) { return c.Next() } return handler(c) } } ================================================ FILE: middleware/skip/skip_test.go ================================================ package skip_test import ( "net/http" "net/http/httptest" "testing" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/middleware/skip" "github.com/stretchr/testify/require" ) // go test -run Test_Skip func Test_Skip(t *testing.T) { t.Parallel() app := fiber.New() app.Use(skip.New(errTeapotHandler, func(fiber.Ctx) bool { return true })) app.Get("/", helloWorldHandler) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusOK, resp.StatusCode) } // go test -run Test_SkipFalse func Test_SkipFalse(t *testing.T) { t.Parallel() app := fiber.New() app.Use(skip.New(errTeapotHandler, func(fiber.Ctx) bool { return false })) app.Get("/", helloWorldHandler) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } // go test -run Test_SkipNilFunc func Test_SkipNilFunc(t *testing.T) { t.Parallel() app := fiber.New() app.Use(skip.New(errTeapotHandler, nil)) app.Get("/", helloWorldHandler) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } func helloWorldHandler(c fiber.Ctx) error { return c.SendString("Hello, World 👋!") } func errTeapotHandler(fiber.Ctx) error { return fiber.ErrTeapot } ================================================ FILE: middleware/static/config.go ================================================ package static import ( "io/fs" "time" "github.com/gofiber/fiber/v3" ) // Config defines the config for middleware. 
type Config struct { // FS is the file system to serve the static files from. // You can use interfaces compatible with fs.FS like embed.FS, os.DirFS etc. // // Optional. Default: nil FS fs.FS // Next defines a function to skip this middleware when returned true. // // Optional. Default: nil Next func(c fiber.Ctx) bool // ModifyResponse defines a function that allows you to alter the response. // // Optional. Default: nil ModifyResponse fiber.Handler // NotFoundHandler defines a function to handle when the path is not found. // // Optional. Default: nil NotFoundHandler fiber.Handler // The names of the index files for serving a directory. // // Optional. Default: []string{"index.html"}. IndexNames []string `json:"index"` // Expiration duration for inactive file handlers. // Use a negative time.Duration to disable it. // // Optional. Default: 10 * time.Second. CacheDuration time.Duration `json:"cache_duration"` // The value for the Cache-Control HTTP-header // that is set on the file response. MaxAge is defined in seconds. // // Optional. Default: 0. MaxAge int `json:"max_age"` // When set to true, the server tries minimizing CPU usage by caching compressed files. // This works differently than the github.com/gofiber/compression middleware. // // Optional. Default: false Compress bool `json:"compress"` // When set to true, enables byte range requests. // // Optional. Default: false ByteRange bool `json:"byte_range"` // When set to true, enables directory browsing. // // Optional. Default: false. Browse bool `json:"browse"` // When set to true, enables direct download. // // Optional. Default: false. 
Download bool `json:"download"` } // ConfigDefault is the default config var ConfigDefault = Config{ IndexNames: []string{"index.html"}, CacheDuration: 10 * time.Second, } // Helper function to set default values func configDefault(config ...Config) Config { // Return default config if nothing provided if len(config) < 1 { return ConfigDefault } // Override default config cfg := config[0] // Set default values if len(cfg.IndexNames) == 0 { cfg.IndexNames = ConfigDefault.IndexNames } if cfg.CacheDuration == 0 { cfg.CacheDuration = ConfigDefault.CacheDuration } return cfg } ================================================ FILE: middleware/static/static.go ================================================ package static import ( "bytes" "errors" "fmt" "io/fs" "net/url" "os" pathpkg "path" "path/filepath" "slices" "strconv" "strings" "sync" "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3" ) var ErrInvalidPath = errors.New("invalid path") // sanitizePath validates and cleans the requested path. // It returns an error if the path attempts to traverse directories. 
func sanitizePath(p []byte, filesystem fs.FS) ([]byte, error) { var s string hasTrailingSlash := len(p) > 0 && p[len(p)-1] == '/' if bytes.IndexByte(p, '\\') >= 0 { b := make([]byte, len(p)) copy(b, p) for i := range b { if b[i] == '\\' { b[i] = '/' } } s = utils.UnsafeString(b) } else { s = utils.UnsafeString(p) } // repeatedly unescape until it no longer changes, catching errors for strings.IndexByte(s, '%') >= 0 { us, err := url.PathUnescape(s) if err != nil { return nil, ErrInvalidPath } if us == s { break } s = us } if strings.IndexByte(s, '\\') >= 0 { return nil, ErrInvalidPath } // reject any null bytes if strings.IndexByte(s, '\x00') >= 0 { return nil, ErrInvalidPath } normalized := filepath.ToSlash(s) if filesystem == nil && strings.HasPrefix(normalized, "//") { return nil, ErrInvalidPath } s = pathpkg.Clean("/" + normalized) trimmed := utils.TrimLeft(s, '/') if trimmed != "" { if slices.Contains(strings.Split(trimmed, "/"), "..") { return nil, ErrInvalidPath } } if filesystem == nil { normalizedClean := filepath.ToSlash(trimmed) if strings.HasPrefix(normalizedClean, "//") { return nil, ErrInvalidPath } if volume := filepath.VolumeName(normalizedClean); volume != "" { return nil, ErrInvalidPath } if len(normalizedClean) >= 2 && normalizedClean[1] == ':' { drive := normalizedClean[0] if (drive >= 'a' && drive <= 'z') || (drive >= 'A' && drive <= 'Z') { return nil, ErrInvalidPath } } if strings.HasPrefix(filepath.ToSlash(s), "//") { return nil, ErrInvalidPath } } if filesystem != nil { s = trimmed if s == "" { return []byte("/"), nil } if !fs.ValidPath(s) { return nil, ErrInvalidPath } s = "/" + s } if hasTrailingSlash && len(s) > 1 && s[len(s)-1] != '/' { s += "/" } return utils.UnsafeBytes(s), nil } // New creates a new middleware handler. // The root argument specifies the root directory from which to serve static assets. // // Note: Root has to be string or fs.FS; otherwise, it will panic. 
func New(root string, cfg ...Config) fiber.Handler { config := configDefault(cfg...) var createFS sync.Once var fileHandler fasthttp.RequestHandler var cacheControlValue string var rootIsFile bool // adjustments for io/fs compatibility if config.FS != nil && root == "" { root = "." } return func(c fiber.Ctx) error { // Don't execute middleware if Next returns true if config.Next != nil && config.Next(c) { return c.Next() } // We only serve static assets on GET or HEAD methods method := c.Method() if method != fiber.MethodGet && method != fiber.MethodHead { return c.Next() } // Initialize FS createFS.Do(func() { prefix := c.Route().Path if check, err := isFile(root, config.FS); err == nil { rootIsFile = check } // Is prefix a partial wildcard? if before, _, found := strings.Cut(prefix, "*"); found { // /john* -> /john prefix = before } prefixLen := len(prefix) if prefixLen > 1 && prefix[prefixLen-1:] == "/" { // /john/ -> /john prefixLen-- } fileServer := &fasthttp.FS{ Root: root, FS: config.FS, AllowEmptyRoot: true, GenerateIndexPages: config.Browse, AcceptByteRange: config.ByteRange, Compress: config.Compress, CompressBrotli: config.Compress, // Brotli compression won't work without this CompressZstd: config.Compress, // Zstd compression won't work without this CompressedFileSuffixes: c.App().Config().CompressedFileSuffixes, CacheDuration: config.CacheDuration, SkipCache: config.CacheDuration < 0, IndexNames: config.IndexNames, PathNotFound: func(fctx *fasthttp.RequestCtx) { fctx.Response.SetStatusCode(fiber.StatusNotFound) }, } fileServer.PathRewrite = func(fctx *fasthttp.RequestCtx) []byte { path := fctx.Path() if len(path) >= prefixLen { checkFile, err := isFile(root, fileServer.FS) if err != nil { return path } // If the root is a file, we need to reset the path to "/" always. 
switch { case checkFile && fileServer.FS == nil: path = []byte("/") case checkFile && fileServer.FS != nil: path = utils.UnsafeBytes(root) default: path = path[prefixLen:] if len(path) == 0 || path[len(path)-1] != '/' { path = append(path, '/') } } } if len(path) > 0 && path[0] != '/' { path = append([]byte("/"), path...) } sanitized, err := sanitizePath(path, fileServer.FS) if err != nil { // return a guaranteed-missing path so fs responds with 404 return []byte("/__fiber_invalid__") } return sanitized } maxAge := config.MaxAge if maxAge > 0 { cacheControlValue = "public, max-age=" + strconv.Itoa(maxAge) } fileHandler = fileServer.NewRequestHandler() }) // Serve file fileHandler(c.RequestCtx()) // Sets the response Content-Disposition header to attachment if the Download option is true if config.Download { name := filepath.Base(c.Path()) if rootIsFile { name = filepath.Base(root) } c.Attachment(name) } // Return request if found and not forbidden status := c.RequestCtx().Response.StatusCode() if status != fiber.StatusNotFound && status != fiber.StatusForbidden { if cacheControlValue != "" { c.RequestCtx().Response.Header.Set(fiber.HeaderCacheControl, cacheControlValue) } if config.ModifyResponse != nil { return config.ModifyResponse(c) } return nil } // Return custom 404 handler if provided. if config.NotFoundHandler != nil { return config.NotFoundHandler(c) } // Reset response to default c.RequestCtx().SetContentType("") // Issue #420 c.RequestCtx().Response.SetStatusCode(fiber.StatusOK) c.RequestCtx().Response.SetBodyString("") // Next middleware return c.Next() } } // isFile checks if the root is a file. 
func isFile(root string, filesystem fs.FS) (bool, error) { var file fs.File var err error if filesystem != nil { file, err = filesystem.Open(root) if err != nil { return false, fmt.Errorf("static: %w", err) } defer func() { _ = file.Close() //nolint:errcheck // not needed }() } else { file, err = os.Open(filepath.Clean(root)) if err != nil { return false, fmt.Errorf("static: %w", err) } defer func() { _ = file.Close() //nolint:errcheck // not needed }() } stat, err := file.Stat() if err != nil { return false, fmt.Errorf("static: %w", err) } return stat.Mode().IsRegular(), nil } ================================================ FILE: middleware/static/static_test.go ================================================ package static import ( "embed" "io" "io/fs" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "runtime" "strings" "testing" "time" "github.com/stretchr/testify/require" "github.com/gofiber/fiber/v3" ) const ( winOS = "windows" testCSSDir = "../../.github/testdata/fs/css" ) var testConfig = fiber.TestConfig{ Timeout: 10 * time.Second, FailOnTimeout: true, } // go test -run Test_Static_Index_Default func Test_Static_Index_Default(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/prefix", New("../../.github/workflows")) app.Get("", New("../../.github/")) app.Get("test", New("", Config{ IndexNames: []string{"index.html"}, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(body), "Hello, World!") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/not-found", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") 
require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextPlainCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "Not Found", string(body)) } // go test -run Test_Static_Index func Test_Static_Direct(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/*", New("../../.github")) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/index.html", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(body), "Hello, World!") resp, err = app.Test(httptest.NewRequest(fiber.MethodPost, "/index.html", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 405, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextPlainCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/testdata/testRoutes.json", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMEApplicationJSON, resp.Header.Get("Content-Type")) require.Empty(t, resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control") body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(body), "test_routes") } // go test -run Test_Static_MaxAge func Test_Static_MaxAge(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/*", New("../../.github", Config{ MaxAge: 100, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/index.html", http.NoBody)) require.NoError(t, err, 
"app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, "text/html; charset=utf-8", resp.Header.Get(fiber.HeaderContentType)) require.Equal(t, "public, max-age=100", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control") } // go test -run Test_Static_Custom_CacheControl func Test_Static_Custom_CacheControl(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/*", New("../../.github", Config{ ModifyResponse: func(c fiber.Ctx) error { if strings.Contains(c.GetRespHeader("Content-Type"), "text/html") { c.Response().Header.Set("Cache-Control", "no-cache, no-store, must-revalidate") } return nil }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/index.html", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, "no-cache, no-store, must-revalidate", resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control") normalResp, normalErr := app.Test(httptest.NewRequest(fiber.MethodGet, "/config.yml", http.NoBody)) require.NoError(t, normalErr, "app.Test(req)") require.Empty(t, normalResp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control") } func Test_Static_Disable_Cache(t *testing.T) { // Skip on Windows. It's not possible to delete a file that is in use. 
if runtime.GOOS == winOS { t.SkipNow() } t.Parallel() app := fiber.New() file, err := os.Create("../../.github/test.txt") require.NoError(t, err) _, err = file.WriteString("Hello, World!") require.NoError(t, err) require.NoError(t, file.Close()) // Remove the file even if the test fails defer func() { _ = os.Remove("../../.github/test.txt") //nolint:errcheck // not needed }() app.Get("/*", New("../../.github/", Config{ CacheDuration: -1, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/test.txt", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Empty(t, resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(body), "Hello, World!") require.NoError(t, os.Remove("../../.github/test.txt")) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/test.txt", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Empty(t, resp.Header.Get(fiber.HeaderCacheControl), "CacheControl Control") body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "Not Found", string(body)) } func Test_Static_NotFoundHandler(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/*", New("../../.github", Config{ NotFoundHandler: func(c fiber.Ctx) error { return c.SendString("Custom 404") }, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/not-found", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextPlainCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "Custom 404", string(body)) } // go test -run Test_Static_Download func Test_Static_Download(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/fiber.png", New("../../.github/testdata/fs/img/fiber.png", Config{ Download: true, })) 
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/fiber.png", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, "image/png", resp.Header.Get(fiber.HeaderContentType)) require.Equal(t, `attachment; filename="fiber.png"`, resp.Header.Get(fiber.HeaderContentDisposition)) } func Test_Static_Download_NonASCII(t *testing.T) { // Skip on Windows. It's not possible to delete a file that is in use. if runtime.GOOS == "windows" { t.SkipNow() } t.Parallel() dir := t.TempDir() fname := "файл.txt" path := filepath.Join(dir, fname) require.NoError(t, os.WriteFile(path, []byte("x"), 0o644)) //nolint:gosec // Not a concern app := fiber.New() app.Get("/file", New(path, Config{Download: true})) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/file", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") expect := "attachment; filename=\"" + fname + "\"; filename*=UTF-8''" + url.PathEscape(fname) require.Equal(t, expect, resp.Header.Get(fiber.HeaderContentDisposition)) } // go test -run Test_Static_Group func Test_Static_Group(t *testing.T) { t.Parallel() app := fiber.New() grp := app.Group("/v1", func(c fiber.Ctx) error { c.Set("Test-Header", "123") return c.Next() }) grp.Get("/v2*", New("../../.github/index.html")) req := httptest.NewRequest(fiber.MethodGet, "/v1/v2", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) require.Equal(t, "123", resp.Header.Get("Test-Header")) grp = app.Group("/v2") grp.Get("/v3*", New("../../.github/index.html")) req = httptest.NewRequest(fiber.MethodGet, "/v2/v3/john/doe", http.NoBody) resp, err = 
app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) } func Test_Static_Wildcard(t *testing.T) { t.Parallel() app := fiber.New() app.Get("*", New("../../.github/index.html")) req := httptest.NewRequest(fiber.MethodGet, "/yesyes/john/doe", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(body), "Test file") } func Test_Static_Prefix_Wildcard(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/test*", New("../../.github/index.html")) req := httptest.NewRequest(fiber.MethodGet, "/test/john/doe", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) app.Get("/my/nameisjohn*", New("../../.github/index.html")) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/my/nameisjohn/no/its/not", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(body), "Test file") } func Test_Static_Prefix(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/john*", New("../../.github")) req := 
httptest.NewRequest(fiber.MethodGet, "/john/index.html", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) app.Get("/prefix*", New("../../.github/testdata")) req = httptest.NewRequest(fiber.MethodGet, "/prefix/index.html", http.NoBody) resp, err = app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) app.Get("/single*", New("../../.github/testdata/testRoutes.json")) req = httptest.NewRequest(fiber.MethodGet, "/single", http.NoBody) resp, err = app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMEApplicationJSON, resp.Header.Get(fiber.HeaderContentType)) } func Test_Static_Trailing_Slash(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/john*", New("../../.github")) req := httptest.NewRequest(fiber.MethodGet, "/john/", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) app.Get("/john_without_index*", New(testCSSDir)) req = httptest.NewRequest(fiber.MethodGet, "/john_without_index/", http.NoBody) resp, err = app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextPlainCharsetUTF8, 
resp.Header.Get(fiber.HeaderContentType)) app.Use("/john", New("../../.github")) req = httptest.NewRequest(fiber.MethodGet, "/john/", http.NoBody) resp, err = app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) req = httptest.NewRequest(fiber.MethodGet, "/john", http.NoBody) resp, err = app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) app.Use("/john_without_index/", New(testCSSDir)) req = httptest.NewRequest(fiber.MethodGet, "/john_without_index/", http.NoBody) resp, err = app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextPlainCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) } func Test_Static_Next(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/*", New("../../.github", Config{ Next: func(c fiber.Ctx) bool { return c.Get("X-Custom-Header") == "skip" }, })) app.Get("/*", func(c fiber.Ctx) error { return c.SendString("You've skipped app.Static") }) t.Run("app.Static is skipped: invoking Get handler", func(t *testing.T) { t.Parallel() req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set("X-Custom-Header", "skip") resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextPlainCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(body), "You've skipped 
app.Static") }) t.Run("app.Static is not skipped: serving index.html", func(t *testing.T) { t.Parallel() req := httptest.NewRequest(fiber.MethodGet, "/", http.NoBody) req.Header.Set("X-Custom-Header", "don't skip") resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(body), "Hello, World!") }) } func Test_Route_Static_Root(t *testing.T) { t.Parallel() dir := testCSSDir app := fiber.New() app.Get("/*", New(dir, Config{ Browse: true, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/style.css", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "color") app = fiber.New() app.Get("/*", New(dir)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/style.css", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") body, err = io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "color") } func Test_Route_Static_HasPrefix(t *testing.T) { t.Parallel() dir := testCSSDir app := fiber.New() app.Get("/static*", New(dir, Config{ Browse: true, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/static", http.NoBody)) require.NoError(t, err, "app.Test(req)") 
require.Equal(t, 200, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/static/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/static/style.css", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "color") app = fiber.New() app.Get("/static/*", New(dir, Config{ Browse: true, })) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/static", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/static/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/static/style.css", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") body, err = io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "color") app = fiber.New() app.Get("/static*", New(dir)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/static", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/static/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/static/style.css", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") body, err = io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "color") app = 
fiber.New() app.Get("/static*", New(dir)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/static", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/static/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 404, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/static/style.css", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") body, err = io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "color") } func Test_Static_FS(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/*", New("", Config{ FS: os.DirFS("../../.github/testdata/fs"), Browse: true, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/css/style.css", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, fiber.MIMETextCSSCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err := io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "color") } /*func Test_Static_FS_DifferentRoot(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/*", New("fs", Config{ FS: os.DirFS("../../.github/testdata"), IndexNames: []string{"index2.html"}, Browse: true, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err 
:= io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "

Hello, World!

") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/css/style.css", nil)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, fiber.MIMETextCSSCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err = io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "color") }*/ //go:embed static.go config.go var fsTestFilesystem embed.FS func Test_Static_FS_Browse(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/embed*", New("", Config{ FS: fsTestFilesystem, Browse: true, })) app.Get("/dirfs*", New("", Config{ FS: os.DirFS(testCSSDir), Browse: true, })) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/dirfs", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err := io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "style.css") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/dirfs/style.css", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, fiber.MIMETextCSSCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err = io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "color") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/dirfs/test", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/dirfs/test/style2.css", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, fiber.MIMETextCSSCharsetUTF8, 
resp.Header.Get(fiber.HeaderContentType)) body, err = io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "color") resp, err = app.Test(httptest.NewRequest(fiber.MethodGet, "/embed", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err = io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "static.go") } func Test_Static_FS_Prefix_Wildcard(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/test*", New("index.html", Config{ FS: os.DirFS("../../.github"), IndexNames: []string{"not_index.html"}, })) req := httptest.NewRequest(fiber.MethodGet, "/test/john/doe", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.NotEmpty(t, resp.Header.Get(fiber.HeaderContentLength)) require.Equal(t, fiber.MIMETextHTMLCharsetUTF8, resp.Header.Get(fiber.HeaderContentType)) body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(body), "Test file") } func Test_isFile(t *testing.T) { t.Parallel() cases := []struct { filesystem fs.FS gotError error name string path string expected bool }{ { name: "file", path: "index.html", filesystem: os.DirFS("../../.github"), expected: true, }, { name: "file", path: "index2.html", filesystem: os.DirFS("../../.github"), expected: false, gotError: fs.ErrNotExist, }, { name: "directory", path: ".", filesystem: os.DirFS("../../.github"), expected: false, }, { name: "directory", path: "not_exists", filesystem: os.DirFS("../../.github"), expected: false, gotError: fs.ErrNotExist, }, { name: "directory", path: ".", filesystem: os.DirFS(testCSSDir), expected: false, }, { name: "file", path: testCSSDir + "/style.css", filesystem: nil, expected: true, }, { name: "file", path: testCSSDir + 
"/style2.css", filesystem: nil, expected: false, gotError: fs.ErrNotExist, }, { name: "directory", path: testCSSDir, filesystem: nil, expected: false, }, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { c := c t.Parallel() actual, err := isFile(c.path, c.filesystem) require.ErrorIs(t, err, c.gotError) require.Equal(t, c.expected, actual) }) } } func Test_Static_Compress(t *testing.T) { t.Parallel() dir := "../../.github/testdata/fs" app := fiber.New() app.Get("/*", New(dir, Config{ Compress: true, })) // Note: deflate is not supported by fasthttp.FS algorithms := []string{"zstd", "gzip", "br"} for _, algo := range algorithms { t.Run(algo+"_compression", func(t *testing.T) { t.Parallel() // request non-compressible file (less than 200 bytes), Content Length will remain the same req := httptest.NewRequest(fiber.MethodGet, "/css/style.css", http.NoBody) req.Header.Set("Accept-Encoding", algo) resp, err := app.Test(req, testConfig) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding)) require.Equal(t, "46", resp.Header.Get(fiber.HeaderContentLength)) // request compressible file, ContentLength will change req = httptest.NewRequest(fiber.MethodGet, "/index.html", http.NoBody) req.Header.Set("Accept-Encoding", algo) resp, err = app.Test(req, testConfig) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, algo, resp.Header.Get(fiber.HeaderContentEncoding)) require.Greater(t, "299", resp.Header.Get(fiber.HeaderContentLength)) }) } } func Test_Static_Compress_WithoutEncoding(t *testing.T) { t.Parallel() dir := "../../.github/testdata/fs" app := fiber.New() app.Get("/*", New(dir, Config{ Compress: true, CacheDuration: 1 * time.Second, })) // request compressible file without encoding req := httptest.NewRequest(fiber.MethodGet, "/index.html", http.NoBody) resp, err := app.Test(req, testConfig) 
require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Empty(t, resp.Header.Get(fiber.HeaderContentEncoding)) require.Equal(t, "299", resp.Header.Get(fiber.HeaderContentLength)) // request compressible file with different encodings algorithms := []string{"zstd", "gzip", "br"} fileSuffixes := map[string]string{ "gzip": ".fiber.gz", "br": ".fiber.br", "zstd": ".fiber.zst", } for _, algo := range algorithms { // Wait for cache to expire time.Sleep(2 * time.Second) fileName := "index.html" compressedFileName := dir + "/index.html" + fileSuffixes[algo] req = httptest.NewRequest(fiber.MethodGet, "/"+fileName, http.NoBody) req.Header.Set("Accept-Encoding", algo) resp, err = app.Test(req, testConfig) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, algo, resp.Header.Get(fiber.HeaderContentEncoding)) require.Greater(t, "299", resp.Header.Get(fiber.HeaderContentLength)) // verify suffixed file was created _, err := os.Stat(compressedFileName) require.NoError(t, err, "File should exist") } } func Test_Static_Compress_WithFileSuffixes(t *testing.T) { t.Parallel() dir := "../../.github/testdata/fs" fileSuffixes := map[string]string{ "gzip": ".test.gz", "br": ".test.br", "zstd": ".test.zst", } app := fiber.New(fiber.Config{ CompressedFileSuffixes: fileSuffixes, }) app.Get("/*", New(dir, Config{ Compress: true, CacheDuration: 1 * time.Second, })) // request compressible file with different encodings algorithms := []string{"zstd", "gzip", "br"} for _, algo := range algorithms { // Wait for cache to expire time.Sleep(2 * time.Second) fileName := "index.html" compressedFileName := dir + "/index.html" + fileSuffixes[algo] req := httptest.NewRequest(fiber.MethodGet, "/"+fileName, http.NoBody) req.Header.Set("Accept-Encoding", algo) resp, err := app.Test(req, testConfig) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") 
require.Equal(t, algo, resp.Header.Get(fiber.HeaderContentEncoding)) require.Greater(t, "299", resp.Header.Get(fiber.HeaderContentLength)) // verify suffixed file was created _, err = os.Stat(compressedFileName) require.NoError(t, err, "File should exist") } } func Test_Router_Mount_n_Static(t *testing.T) { t.Parallel() app := fiber.New() app.Use("/static", New(testCSSDir, Config{Browse: true})) app.Get("/", func(c fiber.Ctx) error { return c.SendString("Home") }) subApp := fiber.New() app.Use("/mount", subApp) subApp.Get("/test", func(c fiber.Ctx) error { return c.SendString("Hello from /test") }) app.Use(func(c fiber.Ctx) error { return c.Status(fiber.StatusNotFound).SendString("Not Found") }) resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/static/style.css", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") } func Test_Static_PathTraversal(t *testing.T) { // Skip this test if running on Windows if runtime.GOOS == winOS { t.Skip("Skipping Windows-specific tests") } t.Parallel() app := fiber.New() // Serve only from testCSSDir // This directory should contain `style.css` but not `index.html` or anything above it. rootDir := testCSSDir app.Get("/*", New(rootDir)) // A valid request: should succeed validReq := httptest.NewRequest(fiber.MethodGet, "/style.css", http.NoBody) validResp, err := app.Test(validReq) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, validResp.StatusCode, "Status code") require.Equal(t, fiber.MIMETextCSSCharsetUTF8, validResp.Header.Get(fiber.HeaderContentType)) validBody, err := io.ReadAll(validResp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(validBody), "color") // Helper function to assert that a given path is blocked. // Blocked can mean different status codes depending on what triggered the block. // We'll accept 400 or 404 as "blocked" statuses: // - 404 is the expected blocked response in most cases. 
// - 400 might occur if fasthttp rejects the request before it's even processed (e.g., null bytes). assertTraversalBlocked := func(path string) { req := httptest.NewRequest(fiber.MethodGet, path, http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") status := resp.StatusCode require.Truef(t, status == 400 || status == 404, "Status code for path traversal %s should be 400 or 404, got %d", path, status) body, err := io.ReadAll(resp.Body) require.NoError(t, err) // If we got a 404, we expect the "Not Found" message because that's how fiber handles NotFound by default. if status == 404 { require.Contains(t, string(body), "Not Found", "Blocked traversal should have a \"Not Found\" message for %s", path) } else { require.Contains(t, string(body), "Are you a hacker?", "Blocked traversal should have a \"Not Found\" message for %s", path) } } // Basic attempts to escape the directory assertTraversalBlocked("/index.html..") assertTraversalBlocked("/style.css..") assertTraversalBlocked("/../index.html") assertTraversalBlocked("/../../index.html") assertTraversalBlocked("/../../../index.html") // Attempts with double slashes assertTraversalBlocked("//../index.html") assertTraversalBlocked("/..//index.html") // Encoded attempts: `%2e` is '.' 
and `%2f` is '/' assertTraversalBlocked("/..%2findex.html") // ../index.html assertTraversalBlocked("/%2e%2e/index.html") // ../index.html assertTraversalBlocked("/%2e%2e%2f%2e%2e/secret") // ../../../secret // Mixed encoded and normal attempts assertTraversalBlocked("/%2e%2e/../index.html") // ../../index.html assertTraversalBlocked("/..%2f..%2fsecret.json") // ../../../secret.json // Attempts with current directory references assertTraversalBlocked("/./../index.html") assertTraversalBlocked("/././../index.html") // Trailing slashes assertTraversalBlocked("/../") assertTraversalBlocked("/../../") // Attempts to load files from an absolute path outside the root assertTraversalBlocked("/" + rootDir + "/../../index.html") // Additional edge cases: // Double-encoded `..` assertTraversalBlocked("/%252e%252e/index.html") // double-encoded .. -> ../index.html after double decoding // Multiple levels of encoding and traversal assertTraversalBlocked("/%2e%2e%2F..%2f%2e%2e%2fWINDOWS") // multiple ups and unusual pattern assertTraversalBlocked("/%2e%2e%2F..%2f%2e%2e%2f%2e%2e/secret") // more complex chain of ../ // Null byte attempts assertTraversalBlocked("/index.html%00.jpg") assertTraversalBlocked("/%00index.html") assertTraversalBlocked("/somefolder%00/something") assertTraversalBlocked("/%00/index.html") // Attempts to access known system files assertTraversalBlocked("/etc/passwd") assertTraversalBlocked("/etc/") // Complex mixed attempts with encoded slashes and dots assertTraversalBlocked("/..%2F..%2F..%2F..%2Fetc%2Fpasswd") // Attempts inside subdirectories with encoded traversal assertTraversalBlocked("/somefolder/%2e%2e%2findex.html") assertTraversalBlocked("/somefolder/%2e%2e%2f%2e%2e%2findex.html") // Backslash encoded attempts assertTraversalBlocked("/%5C..%5Cindex.html") assertTraversalBlocked("/%5c..%5c..%5cetc%5cpasswd") assertTraversalBlocked("/%255c..%255c..%255cetc%255cpasswd") assertTraversalBlocked("/..%5c..%5cetc%5cpasswd") 
assertTraversalBlocked("/%2e%2e%5c%2e%2e%5cetc%5cpasswd") assertTraversalBlocked("/%2e%2e%2f%2e%2e%5cetc%5cpasswd") assertTraversalBlocked("/%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd") assertTraversalBlocked("/.%2e/.%2e/etc/passwd") assertTraversalBlocked("/..%2f..%2f..%2f..%2fetc%2fpasswd") assertTraversalBlocked("/%2e%2e%2f%2e%2e%2f%2e%2e%2fetc%2fpasswd") assertTraversalBlocked("/%2e%2e%2f%2e%2e%2fetc%2fshadow") assertTraversalBlocked("/%2e%2e/%2e%2e/var/log/auth.log") assertTraversalBlocked("/..%2f..%2fvar%2flog%2fauth.log") assertTraversalBlocked("/%2e%2e//%2e%2e//etc/passwd") assertTraversalBlocked("/..//..//etc/passwd") assertTraversalBlocked("/%2e%2e%2f%2e%2e%2f%2e%2e%2fproc%2fself%2fenviron") assertTraversalBlocked("/%2e%2e%2f%2e%2e%2f%2e%2e%2froot%2f.ssh%2fauthorized_keys") } func Test_Static_PathTraversal_WindowsOnly(t *testing.T) { // Skip this test if not running on Windows if runtime.GOOS != winOS { t.Skip("Skipping Windows-specific tests") } t.Parallel() app := fiber.New() // Serve only from testCSSDir rootDir := testCSSDir app.Get("/*", New(rootDir)) // A valid request (relative path without backslash): validReq := httptest.NewRequest(fiber.MethodGet, "/style.css", http.NoBody) validResp, err := app.Test(validReq) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, validResp.StatusCode, "Status code for valid file on Windows") body, err := io.ReadAll(validResp.Body) require.NoError(t, err, "app.Test(req)") require.Contains(t, string(body), "color") // Helper to test blocked responses assertTraversalBlocked := func(path string) { req := httptest.NewRequest(fiber.MethodGet, path, http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req)") // We expect a blocked request to return either 400 or 404 status := resp.StatusCode require.Containsf(t, []int{400, 404}, status, "Status code for path traversal %s should be 400 or 404, got %d", path, status) // If it's a 404, we expect a "Not Found" message if status == 404 { respBody, err 
:= io.ReadAll(resp.Body) require.NoError(t, err) require.Contains(t, string(respBody), "Not Found", "Blocked traversal should have a \"Not Found\" message for %s", path) } else { require.Contains(t, string(body), "Are you a hacker?", "Blocked traversal should have a \"Not Found\" message for %s", path) } } // Windows-specific traversal attempts // Backslashes are treated as directory separators on Windows. assertTraversalBlocked("/..\\index.html") assertTraversalBlocked("/..\\..\\index.html") assertTraversalBlocked("/..\\..\\..\\Windows\\win.ini") assertTraversalBlocked("/..\\..\\..\\Windows\\System32\\drivers\\etc\\hosts") assertTraversalBlocked("/%5C..%5C..%5CWindows%5Cwin.ini") assertTraversalBlocked("/%255C..%255C..%255CWindows%255Cwin.ini") assertTraversalBlocked("/%5c..%5c..%5cWindows%5cSystem32%5cdrivers%5cetc%5chosts") assertTraversalBlocked("/C:\\Windows\\System32\\cmd.exe") assertTraversalBlocked("/C:%5CWindows%5CSystem32%5Ccmd.exe") assertTraversalBlocked("/%43:%5CWindows%5CSystem32%5Ccmd.exe") assertTraversalBlocked("/%5c%5cserver%5cshare%5csecret.txt") assertTraversalBlocked("//server\\share\\secret.txt") assertTraversalBlocked("//server/share/secret.txt") assertTraversalBlocked("/%2F%2Fserver%2Fshare%2Fsecret.txt") // Attempt with a path that might try to reference Windows drives or absolute paths // Note: These are artificial tests to ensure no drive-letter escapes are allowed. 
assertTraversalBlocked("/C:\\Windows\\System32\\cmd.exe") assertTraversalBlocked("/C:/Windows/System32/cmd.exe") // Attempt with UNC-like paths (though unlikely in a web context, good to test) assertTraversalBlocked("//server\\share\\secret.txt") // Attempt using a mixture of forward and backward slashes assertTraversalBlocked("/..\\..\\/index.html") // Attempt that includes a null-byte on Windows assertTraversalBlocked("/index.html%00.txt") // Check behavior on an obviously nonexistent and suspicious file assertTraversalBlocked("/\\this\\path\\does\\not\\exist\\..") // Attempts involving relative traversal and current directory reference assertTraversalBlocked("/.\\../index.html") assertTraversalBlocked("/./..\\index.html") } func Benchmark_SanitizePath(b *testing.B) { bench := func(name string, filesystem fs.FS, path []byte) { b.Run(name, func(b *testing.B) { b.ReportAllocs() for b.Loop() { if _, err := sanitizePath(path, filesystem); err != nil { b.Fatal(err) } } }) } bench("nilFS - urlencoded chars", nil, []byte("/foo%2Fbar/../baz%20qux/index.html")) bench("dirFS - urlencoded chars", os.DirFS("."), []byte("/foo%2Fbar/../baz%20qux/index.html")) bench("nilFS - slashes", nil, []byte("\\foo%2Fbar\\baz%20qux\\index.html")) } func Test_SanitizePath(t *testing.T) { t.Parallel() type testCase struct { filesystem fs.FS name string expectPath string input []byte } testCases := []testCase{ {name: "simple path", input: []byte("/foo/bar.txt"), expectPath: "/foo/bar.txt"}, {name: "traversal attempt", input: []byte("/foo/../../bar.txt"), expectPath: "/bar.txt"}, {name: "encoded traversal", input: []byte("/foo/%2e%2e/bar.txt"), expectPath: "/bar.txt"}, {name: "double encoded traversal", input: []byte("/%252e%252e/bar.txt"), expectPath: "/bar.txt"}, {name: "current dir reference", input: []byte("/foo/./bar.txt"), expectPath: "/foo/bar.txt"}, {name: "encoded slash", input: []byte("/foo%2Fbar.txt"), expectPath: "/foo/bar.txt"}, {name: "empty path", input: []byte(""), expectPath: 
"/"}, {name: "dot segments", input: []byte("/foo/./bar/../baz.txt"), expectPath: "/foo/baz.txt"}, {name: "leading dot segment", input: []byte("/./foo/bar.txt"), expectPath: "/foo/bar.txt"}, {name: "encoded space", input: []byte("/foo%20bar/baz.txt"), expectPath: "/foo bar/baz.txt"}, {name: "encoded plus literal", input: []byte("/foo+bar/baz.txt"), expectPath: "/foo+bar/baz.txt"}, // windows-specific paths {name: "backslash path", input: []byte("\\foo\\bar.txt"), expectPath: "/foo/bar.txt"}, {name: "backslash traversal", input: []byte("\\foo\\..\\..\\bar.txt"), expectPath: "/bar.txt"}, {name: "mixed slashes", input: []byte("/foo\\bar.txt"), expectPath: "/foo/bar.txt"}, {name: "trailing slash preserved", input: []byte("/foo/bar/"), expectPath: "/foo/bar/"}, {name: "encoded trailing slash", input: []byte("/foo/bar%2F"), expectPath: "/foo/bar"}, {filesystem: os.DirFS("."), name: "filesystem empty path", input: []byte(""), expectPath: "/"}, {filesystem: os.DirFS("."), name: "filesystem trailing slash", input: []byte("/foo/"), expectPath: "/foo/"}, {filesystem: os.DirFS("."), name: "filesystem traversal clean", input: []byte("/foo/../bar.txt"), expectPath: "/bar.txt"}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() got, err := sanitizePath(tc.input, tc.filesystem) require.NoError(t, err) require.Equal(t, tc.expectPath, string(got)) }) } } func Test_SanitizePath_Error(t *testing.T) { t.Parallel() type testCase struct { filesystem fs.FS name string input []byte } testCases := []testCase{ {name: "null byte", input: []byte("/foo/bar.txt%00")}, {name: "encoded backslash traversal", input: []byte("/foo%5C..%5Cbar.txt")}, {name: "double encoded backslash traversal", input: []byte("/%255C..%255C..%255CWindows%255Cwin.ini")}, {name: "encoded backslash absolute", input: []byte("/%5CWindows%5CSystem32%5Cdrivers%5Cetc%5Chosts")}, {name: "double encoded backslash absolute", input: []byte("/%255CWindows%255CSystem32%255Cdrivers%255Cetc%255Chosts")}, 
{name: "encoded backslash mixed slashes", input: []byte("/..%5C..%5Cetc%5Cpasswd")}, {name: "encoded backslash mixed encoding", input: []byte("/%2e%2e%5c%2e%2e%5cetc%5cpasswd")}, {name: "encoded backslash with encoded slash", input: []byte("/%2e%2e%2f%2e%2e%5cetc%5cpasswd")}, {name: "encoded backslash unc path", input: []byte("//server%5Cshare%5Csecret.txt")}, {name: "encoded backslash drive letter", input: []byte("/C:%5CWindows%5CSystem32%5Ccmd.exe")}, {name: "double slash path", input: []byte("//foo//bar.txt")}, {name: "drive letter", input: []byte("C:/Windows/System32/cmd.exe")}, {name: "drive letter with leading slash", input: []byte("/C:/Windows/System32/cmd.exe")}, {name: "encoded drive letter", input: []byte("/%43:%5CWindows%5CSystem32%5Ccmd.exe")}, {name: "unc path", input: []byte("//server/share/secret.txt")}, {name: "encoded unc path", input: []byte("/%2F%2Fserver%2Fshare%2Fsecret.txt")}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() _, err := sanitizePath(tc.input, tc.filesystem) require.ErrorIs(t, err, ErrInvalidPath, "Expected ErrInvalidPath for input: %s", tc.input) }) } } ================================================ FILE: middleware/timeout/config.go ================================================ package timeout import ( "time" "github.com/gofiber/fiber/v3" ) // Config holds the configuration for the timeout middleware. type Config struct { // Next defines a function to skip this middleware. // Optional. Default: nil Next func(c fiber.Ctx) bool // OnTimeout is executed when a timeout occurs. // Optional. Default: nil (return fiber.ErrRequestTimeout) OnTimeout fiber.Handler // Errors defines custom errors that are treated as timeouts. // Optional. Default: nil Errors []error // Timeout defines the timeout duration for all routes. // Optional. Default: 0 (no timeout) Timeout time.Duration } // ConfigDefault is the default configuration. 
var ConfigDefault = Config{ Next: nil, Timeout: 0, OnTimeout: nil, Errors: nil, } // configDefault returns the first Config value or ConfigDefault. func configDefault(config ...Config) Config { if len(config) < 1 { return ConfigDefault } cfg := config[0] if cfg.Timeout < 0 { cfg.Timeout = ConfigDefault.Timeout } if cfg.Errors == nil { cfg.Errors = ConfigDefault.Errors } if cfg.OnTimeout == nil { cfg.OnTimeout = ConfigDefault.OnTimeout } if cfg.Next == nil { cfg.Next = ConfigDefault.Next } return cfg } ================================================ FILE: middleware/timeout/timeout.go ================================================ package timeout import ( "context" "errors" "runtime/debug" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/log" ) // New enforces a timeout for each incoming request. It replaces the request's // context with one that has the configured deadline, which is exposed through // c.Context(). Handlers can detect the timeout by listening on c.Context().Done() // and return early. // // When a timeout occurs, the middleware returns immediately with fiber.ErrRequestTimeout // (or the result of OnTimeout if configured). The handler goroutine can continue // safely, and resources are recycled when it finishes via the Abandon/ForceRelease // mechanism. func New(h fiber.Handler, config ...Config) fiber.Handler { cfg := configDefault(config...) 
return func(ctx fiber.Ctx) error { if cfg.Next != nil && cfg.Next(ctx) { return h(ctx) } timeout := cfg.Timeout if timeout <= 0 { return h(ctx) } // Create timeout context - handler can check c.Context().Done() parent := ctx.Context() tCtx, cancel := context.WithTimeout(parent, timeout) ctx.SetContext(tCtx) // Channels for handler result and panics done := make(chan error, 1) panicChan := make(chan any, 1) // Run handler in goroutine so we can race against the timeout go func() { defer func() { if p := recover(); p != nil { log.Errorw("panic recovered in timeout handler", "panic", p, "stack", string(debug.Stack())) select { case panicChan <- p: default: // Middleware already returned, panic value discarded } } }() err := h(ctx) select { case done <- err: default: // Middleware already returned, error discarded } }() // Wait for handler completion, panic, or timeout select { case err := <-done: // Handler finished normally - cleanup and return cancel() ctx.SetContext(parent) return handleResult(err, ctx, cfg) case <-panicChan: // Handler panicked - cleanup and return error cancel() ctx.SetContext(parent) return fiber.ErrInternalServerError case <-tCtx.Done(): // Timeout occurred - abandon context and return immediately // The cleanup goroutine will cancel the timeout context once the handler finishes; // the abandoned fiber.Ctx stays out of the pool. return handleTimeout(parent, ctx, cancel, done, panicChan, cfg) } } } // handleResult processes the handler's return value func handleResult(err error, ctx fiber.Ctx, cfg Config) error { if err != nil && isTimeoutError(err, cfg.Errors) { return invokeOnTimeout(ctx, cfg) } return err } // handleTimeout handles the timeout case using the Abandon mechanism func handleTimeout( parent context.Context, ctx fiber.Ctx, cancel context.CancelFunc, done <-chan error, panicChan <-chan any, cfg Config, ) error { // Mark fiber context as abandoned - ReleaseCtx will skip pooling. // The context will NOT be returned to the pool. 
This is an intentional // trade-off: we accept the small memory cost of not recycling timed-out // contexts in exchange for complete race-freedom. // // This is the same approach fasthttp uses - timed-out RequestCtx objects // are never returned to the pool (see fasthttp's releaseCtx which panics // if timeoutResponse is set). ctx.Abandon() // Prepare the timeout response before marking the RequestCtx as timed out so // custom OnTimeout handlers can shape the response body. timeoutErr := invokeOnTimeout(ctx, cfg) // If no OnTimeout handler is configured or the response is still the default // 200/empty, ensure a sensible timeout response is captured for fasthttp to send. if cfg.OnTimeout == nil || (ctx.Response().StatusCode() == fiber.StatusOK && len(ctx.Response().Body()) == 0) { ctx.Response().SetStatusCode(fiber.StatusRequestTimeout) if len(ctx.Response().Body()) == 0 { ctx.Response().SetBodyString(fiber.ErrRequestTimeout.Message) } } // Tell fasthttp to not recycle the RequestCtx - it will acquire a new one // for the response and send the captured payload (either default or from // OnTimeout). All ctx mutations after this call are ignored by fasthttp. ctx.RequestCtx().TimeoutErrorWithResponse(&ctx.RequestCtx().Response) // Spawn cleanup goroutine that waits for handler to finish. // This only does context cleanup (cancel + restore parent), NOT ctx release. // The fiber.Ctx is intentionally NOT released to avoid races with requestHandler // which may still access ctx (e.g., ErrorHandler) after this function returns. // ForceRelease cannot be called safely here for the same reason. go func() { select { case <-done: case <-panicChan: } // Handler finished - cancel timeout context and restore parent cancel() ctx.SetContext(parent) // TODO: Currently the ctx is not returned to the pool (memory leak for timed-out requests). 
// Future improvement: Implement a concurrent "garbage collector" list where abandoned // contexts are queued after both the handler AND requestHandler are done. A background // goroutine would periodically process this list and call ForceRelease() to recycle // the contexts safely. This would require tracking when requestHandler finishes // (e.g., via a channel signaled in ReleaseCtx) without adding per-request overhead // for non-timeout cases. }() return timeoutErr } // invokeOnTimeout calls the OnTimeout handler if configured func invokeOnTimeout(ctx fiber.Ctx, cfg Config) error { if cfg.OnTimeout != nil { return cfg.OnTimeout(ctx) } return fiber.ErrRequestTimeout } // isTimeoutError checks if err is a timeout-like error (context.DeadlineExceeded // or any of the custom errors). func isTimeoutError(err error, customErrors []error) bool { if errors.Is(err, context.DeadlineExceeded) { return true } if len(customErrors) > 0 { for _, e := range customErrors { if errors.Is(err, e) { return true } } } return false } ================================================ FILE: middleware/timeout/timeout_test.go ================================================ package timeout import ( "context" "errors" "fmt" "io" "net/http" "net/http/httptest" "sync/atomic" "testing" "time" "github.com/stretchr/testify/require" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3" ) var ( // Custom error that we treat like a timeout when returned by the handler. errCustomTimeout = errors.New("custom timeout error") // Some unrelated error that should NOT trigger a request timeout. errUnrelated = errors.New("unmatched error") ) // sleepWithContext simulates a task that takes `d` time, but returns `te` if the context is canceled. 
func sleepWithContext(ctx context.Context, d time.Duration, te error) error { timer := time.NewTimer(d) defer timer.Stop() // Clean up the timer select { case <-ctx.Done(): return te case <-timer.C: return nil } } // TestTimeout_Success tests a handler that completes within the allotted timeout. func TestTimeout_Success(t *testing.T) { t.Parallel() app := fiber.New() // Our middleware wraps a handler that sleeps for 10ms, well under the 50ms limit. app.Get("/fast", New(func(c fiber.Ctx) error { // Simulate some work if err := sleepWithContext(c.Context(), 10*time.Millisecond, context.DeadlineExceeded); err != nil { return err } return c.SendString("OK") }, Config{Timeout: 50 * time.Millisecond})) req := httptest.NewRequest(fiber.MethodGet, "/fast", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req) should not fail") require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK for fast requests") } // TestTimeout_Exceeded tests a handler that exceeds the provided timeout. func TestTimeout_Exceeded(t *testing.T) { t.Parallel() app := fiber.New() // This handler listens for context cancelation and returns early when timeout occurs. 
app.Get("/slow", New(func(c fiber.Ctx) error { if err := sleepWithContext(c.Context(), 200*time.Millisecond, context.DeadlineExceeded); err != nil { return err } return c.SendString("Should never get here") }, Config{Timeout: 50 * time.Millisecond})) req := httptest.NewRequest(fiber.MethodGet, "/slow", http.NoBody) start := time.Now() resp, err := app.Test(req) elapsed := time.Since(start) require.NoError(t, err, "app.Test(req) should not fail") require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 Request Timeout") // Handler should return shortly after timeout (not wait full 200ms) require.Less(t, elapsed, 150*time.Millisecond, "handler should return early on context cancelation") } // TestTimeout_ContextPropagation verifies that the timeout context is properly // passed to the handler so it can detect cancelation (Issue #3671). func TestTimeout_ContextPropagation(t *testing.T) { t.Parallel() app := fiber.New() errCh := make(chan error, 1) app.Get("/context-aware", New(func(c fiber.Ctx) error { timer := time.NewTimer(500 * time.Millisecond) defer timer.Stop() select { case <-timer.C: errCh <- nil return c.SendString("completed") case <-c.Context().Done(): errCh <- c.Context().Err() return c.Context().Err() } }, Config{Timeout: 50 * time.Millisecond})) req := httptest.NewRequest(fiber.MethodGet, "/context-aware", http.NoBody) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode) safety := time.NewTimer(1 * time.Second) defer safety.Stop() select { case handlerErr := <-errCh: require.ErrorIs(t, handlerErr, context.DeadlineExceeded, "handler should report DeadlineExceeded") case <-safety.C: t.Fatal("timed out waiting for handler to report context state") } } // TestTimeout_HandlerReturnsEarlyOnCancel verifies that handlers checking context // can return early, making the overall request faster than the handler's work time. 
func TestTimeout_HandlerReturnsEarlyOnCancel(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/early-return", New(func(c fiber.Ctx) error { // Handler that would take 500ms but checks context for i := 0; i < 50; i++ { select { case <-c.Context().Done(): return c.Context().Err() case <-time.After(10 * time.Millisecond): // Continue work } } return c.SendString("completed") }, Config{Timeout: 30 * time.Millisecond})) req := httptest.NewRequest(fiber.MethodGet, "/early-return", http.NoBody) start := time.Now() resp, err := app.Test(req) elapsed := time.Since(start) require.NoError(t, err) require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode) // Should complete much faster than 500ms because handler checks context require.Less(t, elapsed, 100*time.Millisecond) } // TestTimeout_CustomError tests that returning a user-defined error is also treated as a timeout. func TestTimeout_CustomError(t *testing.T) { t.Parallel() app := fiber.New() // This handler sleeps 50ms and returns errCustomTimeout if canceled. app.Get("/custom", New(func(c fiber.Ctx) error { // Sleep might time out, or might return early. If the context is canceled, // we treat errCustomTimeout as a 'timeout-like' condition. if err := sleepWithContext(c.Context(), 200*time.Millisecond, errCustomTimeout); err != nil { return fmt.Errorf("wrapped: %w", err) } return c.SendString("Should never get here") }, Config{Timeout: 50 * time.Millisecond, Errors: []error{errCustomTimeout}})) req := httptest.NewRequest(fiber.MethodGet, "/custom", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req) should not fail") require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode, "Expected 408 for custom timeout error") } // TestTimeout_UnmatchedError checks that if the handler returns an error // that is neither a deadline exceeded nor a custom 'timeout' error, it is // propagated as a regular 500 (internal server error). 
func TestTimeout_UnmatchedError(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/unmatched", New(func(_ fiber.Ctx) error { return errUnrelated // Not in the custom error list }, Config{Timeout: 100 * time.Millisecond, Errors: []error{errCustomTimeout}})) req := httptest.NewRequest(fiber.MethodGet, "/unmatched", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req) should not fail") require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode, "Expected 500 because the error is not recognized as a timeout error") } // TestTimeout_ZeroDuration tests the edge case where the timeout is set to zero. // Usually this means the request can never exceed a 'deadline' – effectively no timeout. func TestTimeout_ZeroDuration(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/zero", New(func(c fiber.Ctx) error { // Sleep 50ms, but there's no real 'deadline' since zero-timeout. time.Sleep(50 * time.Millisecond) return c.SendString("No timeout used") })) req := httptest.NewRequest(fiber.MethodGet, "/zero", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req) should not fail") require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK with zero timeout") } // TestTimeout_NegativeDuration ensures negative timeout values fall back to zero. func TestTimeout_NegativeDuration(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/negative", New(func(c fiber.Ctx) error { time.Sleep(50 * time.Millisecond) return c.SendString("No timeout used") }, Config{Timeout: -100 * time.Millisecond})) req := httptest.NewRequest(fiber.MethodGet, "/negative", http.NoBody) resp, err := app.Test(req) require.NoError(t, err, "app.Test(req) should not fail") require.Equal(t, fiber.StatusOK, resp.StatusCode, "Expected 200 OK with zero timeout") } // TestTimeout_CustomHandler ensures that a custom handler runs on timeout. 
func TestTimeout_CustomHandler(t *testing.T) { t.Parallel() app := fiber.New() var called atomic.Int32 app.Get("/custom-handler", New(func(c fiber.Ctx) error { if err := sleepWithContext(c.Context(), 100*time.Millisecond, context.DeadlineExceeded); err != nil { return err } return c.SendString("should not reach") }, Config{ Timeout: 20 * time.Millisecond, OnTimeout: func(c fiber.Ctx) error { called.Add(1) return c.Status(408).JSON(fiber.Map{"error": "timeout"}) }, })) req := httptest.NewRequest(fiber.MethodGet, "/custom-handler", http.NoBody) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode) require.Equal(t, int32(1), called.Load()) body, readErr := io.ReadAll(resp.Body) require.NoError(t, readErr) require.JSONEq(t, `{"error":"timeout"}`, string(body)) } // TestTimeout_PanicInHandler verifies that panics in the handler return 500. func TestTimeout_PanicInHandler(t *testing.T) { t.Parallel() app := fiber.New() app.Get("/panic", New(func(_ fiber.Ctx) error { panic("test panic") }, Config{Timeout: 100 * time.Millisecond})) req := httptest.NewRequest(fiber.MethodGet, "/panic", http.NoBody) resp, err := app.Test(req) require.NoError(t, err) // Panic in handler results in 500 Internal Server Error require.Equal(t, fiber.StatusInternalServerError, resp.StatusCode) } // TestIsTimeoutError_DeadlineExceeded ensures context.DeadlineExceeded triggers timeout. func TestIsTimeoutError_DeadlineExceeded(t *testing.T) { t.Parallel() require.True(t, isTimeoutError(context.DeadlineExceeded, nil)) require.True(t, isTimeoutError(fmt.Errorf("wrap: %w", context.DeadlineExceeded), nil)) } // TestIsTimeoutError_CustomErrors verifies custom errors are detected. 
func TestIsTimeoutError_CustomErrors(t *testing.T) { t.Parallel() customErr := errors.New("custom timeout") require.True(t, isTimeoutError(customErr, []error{customErr})) require.True(t, isTimeoutError(fmt.Errorf("wrap: %w", customErr), []error{customErr})) require.False(t, isTimeoutError(errUnrelated, []error{customErr})) } // TestIsTimeoutError_WithOnTimeout verifies that custom OnTimeout is called for custom errors. func TestIsTimeoutError_WithOnTimeout(t *testing.T) { t.Parallel() app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) called := false cfg := Config{ Timeout: 100 * time.Millisecond, Errors: []error{errCustomTimeout}, OnTimeout: func(_ fiber.Ctx) error { called = true return errors.New("handled") }, } // Test via full middleware to ensure OnTimeout is called handler := New(func(_ fiber.Ctx) error { return fmt.Errorf("wrap: %w", errCustomTimeout) }, cfg) err := handler(ctx) require.True(t, called) require.EqualError(t, err, "handled") } // TestTimeout_ImmediateReturn verifies that the middleware returns immediately on timeout // without waiting for the handler to finish (using Abandon mechanism). 
func TestTimeout_ImmediateReturn(t *testing.T) { t.Parallel() app := fiber.New() handlerStarted := make(chan struct{}) handlerDone := make(chan struct{}) app.Get("/immediate", New(func(_ fiber.Ctx) error { close(handlerStarted) // Handler takes 500ms but middleware should return after 20ms time.Sleep(500 * time.Millisecond) close(handlerDone) return nil }, Config{Timeout: 20 * time.Millisecond})) req := httptest.NewRequest(fiber.MethodGet, "/immediate", http.NoBody) start := time.Now() resp, err := app.Test(req) elapsed := time.Since(start) require.NoError(t, err) require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode) // Middleware should return immediately after timeout, not wait 500ms require.Less(t, elapsed, 100*time.Millisecond, "middleware should return immediately on timeout") // Wait for handler to verify it was abandoned properly <-handlerStarted select { case <-handlerDone: // Handler finished - cleanup goroutine should have released context case <-time.After(1 * time.Second): t.Log("Handler still running (expected for abandoned context)") } } // TestTimeout_PanicAfterTimeout ensures panics after a timeout are handled. func TestTimeout_PanicAfterTimeout(t *testing.T) { t.Parallel() app := fiber.New() panicDone := make(chan struct{}) app.Get("/panic-after-timeout", New(func(c fiber.Ctx) error { <-c.Context().Done() defer close(panicDone) panic("panic after timeout") }, Config{Timeout: 20 * time.Millisecond})) req := httptest.NewRequest(fiber.MethodGet, "/panic-after-timeout", http.NoBody) resp, err := app.Test(req) require.NoError(t, err) // With immediate return, we get 408 (not 500) because panic happens after middleware returned require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode) // Wait for panic to occur and be handled by cleanup goroutine select { case <-panicDone: case <-time.After(200 * time.Millisecond): t.Fatal("panic did not occur") } } // TestTimeout_Next verifies the Next function skips the middleware. 
func TestTimeout_Next(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	// The handler (100ms) would exceed the 10ms timeout, but Next always
	// returns true, so the middleware is bypassed and the handler's 200 is kept.
	app.Get("/skip", New(func(c fiber.Ctx) error {
		time.Sleep(100 * time.Millisecond)
		return c.SendString("OK")
	}, Config{
		Timeout: 10 * time.Millisecond,
		Next: func(_ fiber.Ctx) bool {
			return true // Always skip
		},
	}))
	req := httptest.NewRequest(fiber.MethodGet, "/skip", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode, "Middleware should be skipped")
}

// TestTimeout_ContextCleanup verifies that the context is properly released
// after the handler finishes (even after timeout).
func TestTimeout_ContextCleanup(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	handlerDone := make(chan struct{})
	// The handler blocks until the request context is cancelled by the
	// timeout, then keeps running briefly after the 408 has been produced.
	app.Get("/cleanup", New(func(c fiber.Ctx) error {
		defer close(handlerDone)
		<-c.Context().Done()
		// Small delay to simulate cleanup
		time.Sleep(50 * time.Millisecond)
		return nil
	}, Config{Timeout: 20 * time.Millisecond}))
	req := httptest.NewRequest(fiber.MethodGet, "/cleanup", http.NoBody)
	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusRequestTimeout, resp.StatusCode)
	// Wait for handler to finish - cleanup goroutine should release context
	select {
	case <-handlerDone:
		// Give cleanup goroutine time to run (lets -race observe any late access)
		time.Sleep(20 * time.Millisecond)
	case <-time.After(200 * time.Millisecond):
		t.Fatal("handler did not finish")
	}
}

// TestTimeout_AbandonMechanism verifies the Abandon mechanism works correctly.
func TestTimeout_AbandonMechanism(t *testing.T) {
	t.Parallel()
	app := fiber.New()
	ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
	// The abandoned ctx is never returned to the pool by ReleaseCtx below,
	// so release it explicitly when the test ends.
	t.Cleanup(ctx.ForceRelease)
	// Initially not abandoned
	require.False(t, ctx.IsAbandoned())
	// Abandon it
	ctx.Abandon()
	require.True(t, ctx.IsAbandoned())
	// ReleaseCtx should be a no-op when abandoned
	app.ReleaseCtx(ctx)
	require.True(t, ctx.IsAbandoned(), "ReleaseCtx should not release abandoned context")
	// Note: We intentionally do NOT test ForceRelease here.
	// In the timeout middleware, abandoned contexts are NOT released back to the pool
	// to avoid race conditions with requestHandler. This is the same approach
	// fasthttp uses for timed-out RequestCtx objects.
}

================================================
FILE: mount.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 🤖 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io

package fiber

import (
	"sort"
	"sync"
	"sync/atomic"

	"github.com/gofiber/utils/v2"
)

// mountFields groups the state an App keeps about mounting: the apps mounted
// into it, the path it was itself mounted at, and one-shot startup guards.
type mountFields struct {
	// Mounted and main apps, keyed by mount prefix ("" is the app itself)
	appList map[string]*App
	// Prefix of app if it was mounted
	mountPath string
	// Ordered keys of apps (sorted by key length for Render)
	appListKeys []string
	// check added routes of sub-apps (guards processSubAppsRoutes)
	subAppsRoutesAdded sync.Once
	// check mounted sub-apps (guards flattening of nested app lists)
	subAppsProcessed sync.Once
}

// newMountFields returns a mountFields whose appList contains only app itself,
// registered under the empty prefix, and an empty (non-nil) appListKeys.
func newMountFields(app *App) *mountFields {
	return &mountFields{
		appList:     map[string]*App{"": app},
		appListKeys: make([]string, 0),
	}
}

// mount attaches another app instance as a sub-router along a routing path.
// It's very useful to split up a large API into many independent routers and
// compose them as a single service. The sub-app, together with every app
// already mounted inside it, is registered in this application's app list
// under the mount prefix; per the existing design, their error handlers are
// then used for errors that happen within that prefix route.
func (app *App) mount(prefix string, subApp *App) Router { prefix = utils.TrimRight(prefix, '/') if prefix == "" { prefix = "/" } app.mutex.Lock() // Support for configs of mounted-apps and sub-mounted-apps for mountedPrefixes, subApp := range subApp.mountFields.appList { path := getGroupPath(prefix, mountedPrefixes) subApp.mountFields.mountPath = path app.mountFields.appList[path] = subApp } app.mutex.Unlock() // register mounted group mountGroup := &Group{Prefix: prefix, app: subApp} app.register([]string{methodUse}, prefix, mountGroup) // Execute onMount hooks if err := subApp.hooks.executeOnMountHooks(app); err != nil { panic(err) } return app } // Mount attaches another app instance as a sub-router along a routing path. // It's very useful to split up a large API as many independent routers and // compose them as a single service using Mount. func (grp *Group) mount(prefix string, subApp *App) Router { groupPath := getGroupPath(grp.Prefix, prefix) groupPath = utils.TrimRight(groupPath, '/') if groupPath == "" { groupPath = "/" } grp.app.mutex.Lock() // Support for configs of mounted-apps and sub-mounted-apps for mountedPrefixes, subApp := range subApp.mountFields.appList { path := getGroupPath(groupPath, mountedPrefixes) subApp.mountFields.mountPath = path grp.app.mountFields.appList[path] = subApp } grp.app.mutex.Unlock() // register mounted group mountGroup := &Group{Prefix: groupPath, app: subApp} grp.app.register([]string{methodUse}, groupPath, mountGroup) // Execute onMount hooks if err := subApp.hooks.executeOnMountHooks(grp.app); err != nil { panic(err) } return grp } // MountPath returns the route pattern where the current app instance was mounted as a sub-application. func (app *App) MountPath() string { return app.mountFields.mountPath } // hasMountedApps Checks if there are any mounted apps in the current application. 
func (app *App) hasMountedApps() bool { return len(app.mountFields.appList) > 1 } // mountStartupProcess Handles the startup process of mounted apps by appending sub-app routes, generating app list keys, and processing sub-app routes. func (app *App) mountStartupProcess() { if app.hasMountedApps() { // add routes of sub-apps app.mountFields.subAppsProcessed.Do(func() { app.appendSubAppLists(app.mountFields.appList) app.generateAppListKeys() }) // adds the routes of the sub-apps to the current application. app.mountFields.subAppsRoutesAdded.Do(func() { app.processSubAppsRoutes() }) } } // generateAppListKeys generates app list keys for Render, should work after appendSubAppLists func (app *App) generateAppListKeys() { for key := range app.mountFields.appList { app.mountFields.appListKeys = append(app.mountFields.appListKeys, key) } sort.Slice(app.mountFields.appListKeys, func(i, j int) bool { return len(app.mountFields.appListKeys[i]) < len(app.mountFields.appListKeys[j]) }) } // appendSubAppLists supports nested for sub apps func (app *App) appendSubAppLists(appList map[string]*App, parent ...string) { // Optimize: Cache parent prefix parentPrefix := "" if len(parent) > 0 { parentPrefix = parent[0] } for prefix, subApp := range appList { // skip real app if prefix == "" { continue } if parentPrefix != "" { prefix = getGroupPath(parentPrefix, prefix) } if _, ok := app.mountFields.appList[prefix]; !ok { app.mountFields.appList[prefix] = subApp } // The first element of appList is always the app itself. If there are no other sub apps, we should skip appending nested apps. 
if len(subApp.mountFields.appList) > 1 { app.appendSubAppLists(subApp.mountFields.appList, prefix) } } } // processSubAppsRoutes adds routes of sub-apps recursively when the server is started func (app *App) processSubAppsRoutes() { for prefix, subApp := range app.mountFields.appList { // skip real app if prefix == "" { continue } // process the inner routes if subApp.hasMountedApps() { subApp.mountFields.subAppsRoutesAdded.Do(func() { subApp.processSubAppsRoutes() }) } } var handlersCount uint32 // Iterate over the stack of the parent app for m := range app.stack { // Iterate over each route in the stack stackLen := len(app.stack[m]) for i := 0; i < stackLen; i++ { route := app.stack[m][i] // Check if the route has a mounted app if !route.mount { if !route.use || (route.use && m == 0) { handlersCount += uint32(len(route.Handlers)) //nolint:gosec // G115 - handler count is always small } continue } // Create a slice to hold the sub-app's routes subRoutes := make([]*Route, len(route.group.app.stack[m])) // Iterate over the sub-app's routes for j, subAppRoute := range route.group.app.stack[m] { // Clone the sub-app's route subAppRouteClone := app.copyRoute(subAppRoute) // Add the parent route's path as a prefix to the sub-app's route app.addPrefixToRoute(route.path, subAppRouteClone) // Add the cloned sub-app's route to the slice of sub-app routes subRoutes[j] = subAppRouteClone } // Insert the sub-app's routes into the parent app's stack newStack := make([]*Route, len(app.stack[m])+len(subRoutes)-1) copy(newStack[:i], app.stack[m][:i]) copy(newStack[i:i+len(subRoutes)], subRoutes) copy(newStack[i+len(subRoutes):], app.stack[m][i+1:]) app.stack[m] = newStack i-- // Mark the parent app's routes as refreshed app.routesRefreshed = true // update stackLen after appending subRoutes to app.stack[m] stackLen = len(app.stack[m]) } } atomic.StoreUint32(&app.handlersCount, handlersCount) } ================================================ FILE: mount_test.go 
================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 🤖 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io package fiber import ( "errors" "io" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/require" ) // go test -run Test_App_Mount func Test_App_Mount(t *testing.T) { t.Parallel() micro := New() micro.Get("/doe", func(c Ctx) error { return c.SendStatus(StatusOK) }) app := New() app.Use("/john", micro) resp, err := app.Test(httptest.NewRequest(MethodGet, "/john/doe", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, uint32(2), app.handlersCount) } func Test_App_Mount_RootPath_Nested(t *testing.T) { t.Parallel() app := New() dynamic := New() apiserver := New() apiroutes := apiserver.Group("/v1") apiroutes.Get("/home", func(c Ctx) error { return c.SendString("home") }) dynamic.Use("/api", apiserver) app.Use("/", dynamic) resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/v1/home", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, uint32(2), app.handlersCount) } // go test -run Test_App_Mount_Nested func Test_App_Mount_Nested(t *testing.T) { t.Parallel() app := New() one := New() two := New() three := New() two.Use("/three", three) app.Use("/one", one) one.Use("/two", two) one.Get("/doe", func(c Ctx) error { return c.SendStatus(StatusOK) }) two.Get("/nested", func(c Ctx) error { return c.SendStatus(StatusOK) }) three.Get("/test", func(c Ctx) error { return c.SendStatus(StatusOK) }) resp, err := app.Test(httptest.NewRequest(MethodGet, "/one/doe", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(MethodGet, "/one/two/nested", http.NoBody)) require.NoError(t, err, 
"app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") resp, err = app.Test(httptest.NewRequest(MethodGet, "/one/two/three/test", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, uint32(6), app.handlersCount) } // go test -run Test_App_Mount_Express_Behavior func Test_App_Mount_Express_Behavior(t *testing.T) { t.Parallel() createTestHandler := func(body string) func(c Ctx) error { return func(c Ctx) error { return c.SendString(body) } } testEndpoint := func(app *App, route, expectedBody string, expectedStatusCode int) { resp, err := app.Test(httptest.NewRequest(MethodGet, route, http.NoBody)) require.NoError(t, err, "app.Test(req)") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, expectedStatusCode, resp.StatusCode, "Status code") require.Equal(t, expectedBody, string(body), "Unexpected response body") } app := New() subApp := New() // app setup subApp.Get("/hello", createTestHandler("subapp hello!")) subApp.Get("/world", createTestHandler("subapp world!")) // <- wins app.Get("/hello", createTestHandler("app hello!")) // <- wins app.Use("/", subApp) // <- subApp registration app.Get("/world", createTestHandler("app world!")) app.Get("/bar", createTestHandler("app bar!")) subApp.Get("/bar", createTestHandler("subapp bar!")) // <- wins subApp.Get("/foo", createTestHandler("subapp foo!")) // <- wins app.Get("/foo", createTestHandler("app foo!")) // 404 Handler app.Use(func(c Ctx) error { return c.SendStatus(StatusNotFound) }) // expectation check testEndpoint(app, "/world", "subapp world!", StatusOK) testEndpoint(app, "/hello", "app hello!", StatusOK) testEndpoint(app, "/bar", "subapp bar!", StatusOK) testEndpoint(app, "/foo", "subapp foo!", StatusOK) testEndpoint(app, "/unknown", ErrNotFound.Message, StatusNotFound) require.Equal(t, uint32(17), app.handlersCount) } // go test -run Test_App_Mount_RoutePositions func Test_App_Mount_RoutePositions(t 
*testing.T) { t.Parallel() testEndpoint := func(app *App, route, expectedBody string) { resp, err := app.Test(httptest.NewRequest(MethodGet, route, http.NoBody)) require.NoError(t, err, "app.Test(req)") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode, "Status code") require.Equal(t, expectedBody, string(body), "Unexpected response body") } app := New() subApp1 := New() subApp2 := New() // app setup { app.Use(func(c Ctx) error { // set initial value c.Locals("world", "world") return c.Next() }) app.Use("/subApp1", subApp1) app.Use(func(c Ctx) error { return c.Next() }) app.Get("/bar", func(c Ctx) error { return c.SendString("ok") }) app.Use(func(c Ctx) error { // is overwritten when the positioning is not correct c.Locals("world", "hello") return c.Next() }) methods := subApp2.Group("/subApp2") methods.Get("/world", func(c Ctx) error { v, ok := c.Locals("world").(string) if !ok { panic("unexpected data type") } return c.SendString(v) }) app.Use("", subApp2) } testEndpoint(app, "/subApp2/world", "hello") routeStackGET := app.Stack()[0] require.True(t, routeStackGET[0].use) require.Equal(t, "/", routeStackGET[0].path) require.True(t, routeStackGET[1].use) require.Equal(t, "/", routeStackGET[1].path) require.False(t, routeStackGET[2].use) require.Equal(t, "/bar", routeStackGET[2].path) require.True(t, routeStackGET[3].use) require.Equal(t, "/", routeStackGET[3].path) require.False(t, routeStackGET[4].use) require.Equal(t, "/subapp2/world", routeStackGET[4].path) require.Len(t, routeStackGET, 5) } // go test -run Test_App_MountPath func Test_App_MountPath(t *testing.T) { t.Parallel() app := New() one := New() two := New() three := New() two.Use("/three", three) one.Use("/two", two) app.Use("/one", one) require.Equal(t, "/one", one.MountPath()) require.Equal(t, "/one/two", two.MountPath()) require.Equal(t, "/one/two/three", three.MountPath()) require.Empty(t, app.MountPath()) } func Test_App_ErrorHandler_GroupMount(t 
*testing.T) { t.Parallel() micro := New(Config{ ErrorHandler: func(c Ctx, err error) error { require.Equal(t, "0: GET error", err.Error()) return c.Status(500).SendString("1: custom error") }, }) micro.Get("/doe", func(_ Ctx) error { return errors.New("0: GET error") }) app := New() v1 := app.Group("/v1") v1.Use("/john", micro) resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", http.NoBody)) testErrorResponse(t, err, resp, "1: custom error") } func Test_App_ErrorHandler_GroupMountRootLevel(t *testing.T) { t.Parallel() micro := New(Config{ ErrorHandler: func(c Ctx, err error) error { require.Equal(t, "0: GET error", err.Error()) return c.Status(500).SendString("1: custom error") }, }) micro.Get("/john/doe", func(_ Ctx) error { return errors.New("0: GET error") }) app := New() v1 := app.Group("/v1") v1.Use("/", micro) resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", http.NoBody)) testErrorResponse(t, err, resp, "1: custom error") } // go test -run Test_App_Group_Mount func Test_App_Group_Mount(t *testing.T) { t.Parallel() micro := New() micro.Get("/doe", func(c Ctx) error { return c.SendStatus(StatusOK) }) app := New() v1 := app.Group("/v1") v1.Use("/john", micro) resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") require.Equal(t, uint32(2), app.handlersCount) } func Test_App_UseParentErrorHandler(t *testing.T) { t.Parallel() app := New(Config{ ErrorHandler: func(ctx Ctx, _ error) error { return ctx.Status(500).SendString("hi, i'm a custom error") }, }) fiber := New() fiber.Get("/", func(_ Ctx) error { return errors.New("something happened") }) app.Use("/api", fiber) resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", http.NoBody)) testErrorResponse(t, err, resp, "hi, i'm a custom error") } func Test_App_UseMountedErrorHandler(t *testing.T) { t.Parallel() app := New() fiber := New(Config{ 
ErrorHandler: func(c Ctx, _ error) error { return c.Status(500).SendString("hi, i'm a custom error") }, }) fiber.Get("/", func(_ Ctx) error { return errors.New("something happened") }) app.Use("/api", fiber) resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", http.NoBody)) testErrorResponse(t, err, resp, "hi, i'm a custom error") } func Test_App_UseMountedErrorHandlerRootLevel(t *testing.T) { t.Parallel() app := New() fiber := New(Config{ ErrorHandler: func(c Ctx, _ error) error { return c.Status(500).SendString("hi, i'm a custom error") }, }) fiber.Get("/api", func(_ Ctx) error { return errors.New("something happened") }) app.Use("/", fiber) resp, err := app.Test(httptest.NewRequest(MethodGet, "/api", http.NoBody)) testErrorResponse(t, err, resp, "hi, i'm a custom error") } func Test_App_UseMountedErrorHandlerForBestPrefixMatch(t *testing.T) { t.Parallel() app := New() tsf := func(c Ctx, _ error) error { return c.Status(200).SendString("hi, i'm a custom sub fiber error 2") } tripleSubFiber := New(Config{ ErrorHandler: tsf, }) tripleSubFiber.Get("/", func(_ Ctx) error { return errors.New("something happened") }) sf := func(c Ctx, _ error) error { return c.Status(200).SendString("hi, i'm a custom sub fiber error") } subfiber := New(Config{ ErrorHandler: sf, }) subfiber.Get("/", func(_ Ctx) error { return errors.New("something happened") }) subfiber.Use("/third", tripleSubFiber) f := func(c Ctx, _ error) error { return c.Status(200).SendString("hi, i'm a custom error") } fiber := New(Config{ ErrorHandler: f, }) fiber.Get("/", func(_ Ctx) error { return errors.New("something happened") }) fiber.Use("/sub", subfiber) app.Use("/api", fiber) resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub", http.NoBody)) require.NoError(t, err, "/api/sub req") require.Equal(t, 200, resp.StatusCode, "Status code") b, err := io.ReadAll(resp.Body) require.NoError(t, err, "iotuil.ReadAll()") require.Equal(t, "hi, i'm a custom sub fiber error", string(b), "Response 
body") resp2, err := app.Test(httptest.NewRequest(MethodGet, "/api/sub/third", http.NoBody)) require.NoError(t, err, "/api/sub/third req") require.Equal(t, 200, resp2.StatusCode, "Status code") b, err = io.ReadAll(resp2.Body) require.NoError(t, err, "iotuil.ReadAll()") require.Equal(t, "hi, i'm a custom sub fiber error 2", string(b), "Third fiber Response body") } // go test -run Test_Mount_Route_Names func Test_Mount_Route_Names(t *testing.T) { t.Parallel() // create sub-app with 2 handlers: subApp1 := New() subApp1.Get("/users", func(c Ctx) error { url, err := c.GetRouteURL("add-user", Map{}) require.NoError(t, err) require.Equal(t, "/app1/users", url, "handler: app1.add-user") // the prefix is /app1 because of the mount // if subApp1 is not mounted, expected url just /users return nil }).Name("get-users") subApp1.Post("/users", func(c Ctx) error { route := c.App().GetRoute("get-users") require.Equal(t, MethodGet, route.Method, "handler: app1.get-users method") require.Equal(t, "/app1/users", route.Path, "handler: app1.get-users path") return nil }).Name("add-user") // create sub-app with 2 handlers inside a group: subApp2 := New() app2Grp := subApp2.Group("/users").Name("users.") app2Grp.Get("", emptyHandler).Name("get") app2Grp.Post("", emptyHandler).Name("add") // put both sub-apps into root app rootApp := New() _ = rootApp.Use("/app1", subApp1) _ = rootApp.Use("/app2", subApp2) rootApp.startupProcess() // take route directly from sub-app route := subApp1.GetRoute("get-users") require.Equal(t, MethodGet, route.Method) require.Equal(t, "/users", route.Path) route = subApp1.GetRoute("add-user") require.Equal(t, MethodPost, route.Method) require.Equal(t, "/users", route.Path) // take route directly from sub-app with group route = subApp2.GetRoute("users.get") require.Equal(t, MethodGet, route.Method) require.Equal(t, "/users", route.Path) route = subApp2.GetRoute("users.add") require.Equal(t, MethodPost, route.Method) require.Equal(t, "/users", route.Path) // 
take route from root app (using names of sub-apps) route = rootApp.GetRoute("add-user") require.Equal(t, MethodPost, route.Method) require.Equal(t, "/app1/users", route.Path) route = rootApp.GetRoute("users.add") require.Equal(t, MethodPost, route.Method) require.Equal(t, "/app2/users", route.Path) // GetRouteURL inside handler req := httptest.NewRequest(MethodGet, "/app1/users", http.NoBody) resp, err := rootApp.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") // ctx.App().GetRoute() inside handler req = httptest.NewRequest(MethodPost, "/app1/users", http.NoBody) resp, err = rootApp.Test(req) require.NoError(t, err, "app.Test(req)") require.Equal(t, StatusOK, resp.StatusCode, "Status code") } // go test -run Test_Ctx_Render_Mount func Test_Ctx_Render_Mount(t *testing.T) { t.Parallel() engine := &testTemplateEngine{} err := engine.Load() require.NoError(t, err) sub := New(Config{ Views: engine, }) sub.Get("/:name", func(c Ctx) error { return c.Render("hello_world.tmpl", Map{ "Name": c.Params("name"), }) }) app := New() app.Use("/hello", sub) resp, err := app.Test(httptest.NewRequest(MethodGet, "/hello/a", http.NoBody)) require.Equal(t, StatusOK, resp.StatusCode, "Status code") require.NoError(t, err, "app.Test(req)") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "

Hello a!

", string(body)) } // go test -run Test_Ctx_Render_Mount_ParentOrSubHasViews func Test_Ctx_Render_Mount_ParentOrSubHasViews(t *testing.T) { t.Parallel() engine := &testTemplateEngine{} err := engine.Load() require.NoError(t, err) engine2 := &testTemplateEngine{path: "testdata2"} err = engine2.Load() require.NoError(t, err) engine3 := &testTemplateEngine{path: "testdata3"} err = engine3.Load() require.NoError(t, err) sub := New(Config{ Views: engine3, }) sub2 := New(Config{ Views: engine2, }) app := New(Config{ Views: engine, }) app.Get("/test", func(c Ctx) error { return c.Render("index.tmpl", Map{ "Title": "Hello, World!", }) }) sub.Get("/world/:name", func(c Ctx) error { return c.Render("hello_world.tmpl", Map{ "Name": c.Params("name"), }) }) sub2.Get("/moment", func(c Ctx) error { return c.Render("bruh.tmpl", Map{}) }) sub.Use("/bruh", sub2) app.Use("/hello", sub) resp, err := app.Test(httptest.NewRequest(MethodGet, "/hello/world/a", http.NoBody)) require.Equal(t, StatusOK, resp.StatusCode, "Status code") require.NoError(t, err, "app.Test(req)") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "

Hello a!

", string(body)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody)) require.Equal(t, StatusOK, resp.StatusCode, "Status code") require.NoError(t, err, "app.Test(req)") body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "

Hello, World!

", string(body)) resp, err = app.Test(httptest.NewRequest(MethodGet, "/hello/bruh/moment", http.NoBody)) require.Equal(t, StatusOK, resp.StatusCode, "Status code") require.NoError(t, err, "app.Test(req)") body, err = io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "

I'm Bruh

", string(body)) } func Test_Ctx_Render_MountGroup(t *testing.T) { t.Parallel() engine := &testTemplateEngine{} err := engine.Load() require.NoError(t, err) micro := New(Config{ Views: engine, }) micro.Get("/doe", func(c Ctx) error { return c.Render("hello_world.tmpl", Map{ "Name": "doe", }) }) app := New() v1 := app.Group("/v1") v1.Use("/john", micro) resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/john/doe", http.NoBody)) require.NoError(t, err, "app.Test(req)") require.Equal(t, 200, resp.StatusCode, "Status code") body, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, "

Hello doe!

", string(body)) } ================================================ FILE: path.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 📄 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io // ⚠️ This path parser was inspired by https://github.com/ucarion/urlpath // 💖 Maintained and modified for Fiber by @renewerner87 package fiber import ( "bytes" "fmt" "regexp" "strconv" "strings" "sync" "time" "unicode" "github.com/gofiber/utils/v2" utilsbytes "github.com/gofiber/utils/v2/bytes" utilsstrings "github.com/gofiber/utils/v2/strings" "github.com/google/uuid" ) // routeParser holds the path segments and param names type routeParser struct { segs []*routeSegment // the parsed segments of the route params []string // that parameter names the parsed route wildCardCount int // number of wildcard parameters, used internally to give the wildcard parameter its number plusCount int // number of plus parameters, used internally to give the plus parameter its number } var routerParserPool = &sync.Pool{ New: func() any { return &routeParser{} }, } // routeSegment holds the segment metadata type routeSegment struct { // const information Const string // constant part of the route ParamName string // name of the parameter for access to it, for wildcards and plus parameters access iterators starting with 1 are added ComparePart string // search part to find the end of the parameter Constraints []*Constraint // Constraint type if segment is a parameter, if not it will be set to noConstraint by default PartCount int // how often is the search part contained in the non-param segments? -> necessary for greedy search Length int // length of the parameter for segment, when its 0 then the length is undetermined // future TODO: add support for optional groups "/abc(/def)?" 
// parameter information IsParam bool // Truth value that indicates whether it is a parameter or a constant part IsGreedy bool // indicates whether the parameter is greedy or not, is used with wildcard and plus IsOptional bool // indicates whether the parameter is optional or not // common information IsLast bool // shows if the segment is the last one for the route HasOptionalSlash bool // segment has the possibility of an optional slash } // different special routing signs const ( wildcardParam byte = '*' // indicates an optional greedy parameter plusParam byte = '+' // indicates a required greedy parameter optionalParam byte = '?' // concludes a parameter by name and makes it optional paramStarterChar byte = ':' // start character for a parameter with name slashDelimiter byte = '/' // separator for the route, unlike the other delimiters this character at the end can be optional escapeChar byte = '\\' // escape character paramConstraintStart byte = '<' // start of type constraint for a parameter paramConstraintEnd byte = '>' // end of type constraint for a parameter paramConstraintSeparator byte = ';' // separator of type constraints for a parameter paramConstraintDataStart byte = '(' // start of data of type constraint for a parameter paramConstraintDataEnd byte = ')' // end of data of type constraint for a parameter paramConstraintDataSeparator byte = ',' // separator of data of type constraint for a parameter ) // TypeConstraint parameter constraint types type TypeConstraint uint16 // Constraint describes the validation rules that apply to a dynamic route // segment when matching incoming requests. type Constraint struct { RegexCompiler *regexp.Regexp Name string Data []string customConstraints []CustomConstraint ID TypeConstraint } // CustomConstraint is an interface for custom constraints type CustomConstraint interface { // Name returns the name of the constraint. // This name is used in the constraint matching. 
Name() string // Execute executes the constraint. // It returns true if the constraint is matched and right. // param is the parameter value to check. // args are the constraint arguments. Execute(param string, args ...string) bool } const ( noConstraint TypeConstraint = 1 << iota intConstraint boolConstraint floatConstraint alphaConstraint datetimeConstraint guidConstraint minLenConstraint maxLenConstraint lenConstraint betweenLenConstraint minConstraint maxConstraint rangeConstraint regexConstraint ) const ( needOneData = minLenConstraint | maxLenConstraint | lenConstraint | minConstraint | maxConstraint | datetimeConstraint | regexConstraint needTwoData = betweenLenConstraint | rangeConstraint ) // list of possible parameter and segment delimiter var ( // slash has a special role, unlike the other parameters it must not be interpreted as a parameter routeDelimiter = []byte{slashDelimiter, '-', '.'} // list of greedy parameters greedyParameters = []byte{wildcardParam, plusParam} // list of chars for the parameter recognizing parameterStartChars = [256]bool{ wildcardParam: true, plusParam: true, paramStarterChar: true, } // list of chars of delimiters and the starting parameter name char parameterDelimiterChars = append([]byte{paramStarterChar, escapeChar}, routeDelimiter...) // list of chars to find the end of a parameter parameterEndChars = [256]bool{ optionalParam: true, paramStarterChar: true, escapeChar: true, slashDelimiter: true, '-': true, '.': true, } ) // RoutePatternMatch reports whether path matches the provided Fiber route pattern. // // Patterns use the same syntax as routes registered on an App, including // parameters (for example `:id`), wildcards (`*`, `+`), and optional segments. // The optional Config argument can be used to control case sensitivity and // strict routing behavior. This helper allows checking potential matches // without registering a route. 
func RoutePatternMatch(path, pattern string, cfg ...Config) bool { // See logic in (*Route).match and (*App).register var ctxParams [maxParams]string config := Config{} if len(cfg) > 0 { config = cfg[0] } if path == "" { path = "/" } // Cannot have an empty pattern if pattern == "" { pattern = "/" } // Pattern always start with a '/' if pattern[0] != '/' { pattern = "/" + pattern } patternPretty := []byte(pattern) // Case-sensitive routing, all to lowercase if !config.CaseSensitive { patternPretty = utilsbytes.UnsafeToLower(patternPretty) path = utilsstrings.ToLower(path) } // Strict routing, remove trailing slashes if !config.StrictRouting && len(patternPretty) > 1 { patternPretty = utils.TrimRight(patternPretty, '/') } parser, _ := routerParserPool.Get().(*routeParser) //nolint:errcheck // only contains routeParser parser.reset() patternStr := string(patternPretty) parser.parseRoute(patternStr) defer routerParserPool.Put(parser) // '*' wildcard matches any path if (patternStr == "/" && path == "/") || patternStr == "/*" { return true } // Does this route have parameters if len(parser.params) > 0 { if match := parser.getMatch(path, path, &ctxParams, false); match { return true } } // Check for a simple match patternPretty = RemoveEscapeCharBytes(patternPretty) return string(patternPretty) == path } func (parser *routeParser) reset() { parser.segs = parser.segs[:0] parser.params = parser.params[:0] parser.wildCardCount = 0 parser.plusCount = 0 } // parseRoute analyzes the route and divides it into segments for constant areas and parameters, // this information is needed later when assigning the requests to the declared routes func (parser *routeParser) parseRoute(pattern string, customConstraints ...CustomConstraint) { var n int var seg *routeSegment for pattern != "" { nextParamPosition := findNextParamPosition(pattern) // handle the parameter part if nextParamPosition == 0 { n, seg = parser.analyseParameterPart(pattern, customConstraints...) 
parser.params, parser.segs = append(parser.params, seg.ParamName), append(parser.segs, seg) } else { n, seg = parser.analyseConstantPart(pattern, nextParamPosition) parser.segs = append(parser.segs, seg) } pattern = pattern[n:] } // mark last segment if len(parser.segs) > 0 { parser.segs[len(parser.segs)-1].IsLast = true } parser.segs = addParameterMetaInfo(parser.segs) } // parseRoute analyzes the route and divides it into segments for constant areas and parameters, // this information is needed later when assigning the requests to the declared routes func parseRoute(pattern string, customConstraints ...CustomConstraint) routeParser { parser := routeParser{} parser.parseRoute(pattern, customConstraints...) // Check if the route has too many parameters if len(parser.params) > maxParams { panic(fmt.Sprintf("Route '%s' has %d parameters, which exceeds the maximum of %d", pattern, len(parser.params), maxParams)) } return parser } // addParameterMetaInfo add important meta information to the parameter segments // to simplify the search for the end of the parameter func addParameterMetaInfo(segs []*routeSegment) []*routeSegment { var comparePart string segLen := len(segs) // loop from end to begin for i := segLen - 1; i >= 0; i-- { // set the compare part for the parameter if segs[i].IsParam { // important for finding the end of the parameter segs[i].ComparePart = RemoveEscapeChar(comparePart) } else { comparePart = segs[i].Const if len(comparePart) > 1 { comparePart = utils.TrimRight(comparePart, slashDelimiter) } } } // loop from beginning to end for i := range segLen { // check how often the compare part is in the following const parts if segs[i].IsParam { // check if parameter segments are directly after each other; // when neither this parameter nor the next parameter are greedy, we only want one character if segLen > i+1 && !segs[i].IsGreedy && segs[i+1].IsParam && !segs[i+1].IsGreedy { segs[i].Length = 1 } if segs[i].ComparePart == "" { continue } for j := i + 1; 
j <= len(segs)-1; j++ { if !segs[j].IsParam { // count is important for the greedy match segs[i].PartCount += strings.Count(segs[j].Const, segs[i].ComparePart) } } // check if the end of the segment is an optional slash and then if the segment is optional or the last one } else if segs[i].Const[len(segs[i].Const)-1] == slashDelimiter && (segs[i].IsLast || (segLen > i+1 && segs[i+1].IsOptional)) { segs[i].HasOptionalSlash = true } } return segs } // findNextParamPosition search for the next possible parameter start position func findNextParamPosition(pattern string) int { // Find the first parameter position next := -1 for i := range pattern { if parameterStartChars[pattern[i]] && (i == 0 || pattern[i-1] != escapeChar) { next = i break } } if next > 0 && pattern[next] != wildcardParam { // checking the found parameterStartChar is a cluster for i := next + 1; i < len(pattern); i++ { if !parameterStartChars[pattern[i]] { return i - 1 } } return len(pattern) - 1 } return next } // analyseConstantPart find the end of the constant part and create the route segment func (*routeParser) analyseConstantPart(pattern string, nextParamPosition int) (int, *routeSegment) { // handle the constant part processedPart := pattern if nextParamPosition != -1 { // remove the constant part until the parameter processedPart = pattern[:nextParamPosition] } constPart := RemoveEscapeChar(processedPart) return len(processedPart), &routeSegment{ Const: constPart, Length: len(constPart), } } // analyseParameterPart find the parameter end and create the route segment func (parser *routeParser) analyseParameterPart(pattern string, customConstraints ...CustomConstraint) (int, *routeSegment) { isWildCard := pattern[0] == wildcardParam isPlusParam := pattern[0] == plusParam paramEndPosition := 0 paramConstraintStartPosition := -1 paramConstraintEndPosition := -1 // handle wildcard end if !isWildCard && !isPlusParam { paramEndPosition = -1 search := pattern[1:] for i := range search { if 
paramConstraintStartPosition == -1 && search[i] == paramConstraintStart && (i == 0 || search[i-1] != escapeChar) { paramConstraintStartPosition = i + 1 continue } if paramConstraintEndPosition == -1 && search[i] == paramConstraintEnd && (i == 0 || search[i-1] != escapeChar) { paramConstraintEndPosition = i + 1 continue } if parameterEndChars[search[i]] { if (paramConstraintStartPosition == -1 && paramConstraintEndPosition == -1) || (paramConstraintStartPosition != -1 && paramConstraintEndPosition != -1) { paramEndPosition = i break } } } switch { case paramEndPosition == -1: paramEndPosition = len(pattern) - 1 case bytes.IndexByte(parameterDelimiterChars, pattern[paramEndPosition+1]) == -1: paramEndPosition++ default: // do nothing } } // cut params part processedPart := pattern[0 : paramEndPosition+1] n := paramEndPosition + 1 paramName := RemoveEscapeChar(GetTrimmedParam(processedPart)) // Check has constraint var constraints []*Constraint if hasConstraint := paramConstraintStartPosition != -1 && paramConstraintEndPosition != -1; hasConstraint { constraintString := pattern[paramConstraintStartPosition+1 : paramConstraintEndPosition] userConstraints := splitNonEscaped(constraintString, paramConstraintSeparator) constraints = make([]*Constraint, 0, len(userConstraints)) for _, c := range userConstraints { start := findNextNonEscapedCharPosition(c, paramConstraintDataStart) end := strings.LastIndexByte(c, paramConstraintDataEnd) // Assign constraint if start != -1 && end != -1 { constraint := &Constraint{ ID: getParamConstraintType(c[:start]), Name: c[:start], customConstraints: customConstraints, } // remove escapes from data if constraint.ID != regexConstraint { constraint.Data = splitNonEscaped(c[start+1:end], paramConstraintDataSeparator) if len(constraint.Data) == 1 { constraint.Data[0] = RemoveEscapeChar(constraint.Data[0]) } else if len(constraint.Data) == 2 { // This is fine, we simply expect two parts constraint.Data[0] = 
RemoveEscapeChar(constraint.Data[0]) constraint.Data[1] = RemoveEscapeChar(constraint.Data[1]) } } // Precompile regex if has regex constraint if constraint.ID == regexConstraint { constraint.Data = []string{c[start+1 : end]} constraint.RegexCompiler = regexp.MustCompile(constraint.Data[0]) } constraints = append(constraints, constraint) } else { constraints = append(constraints, &Constraint{ ID: getParamConstraintType(c), Data: []string{}, Name: c, customConstraints: customConstraints, }) } } paramName = RemoveEscapeChar(GetTrimmedParam(pattern[0:paramConstraintStartPosition])) } // add access iterator to wildcard and plus if isWildCard { parser.wildCardCount++ paramName += strconv.Itoa(parser.wildCardCount) } else if isPlusParam { parser.plusCount++ paramName += strconv.Itoa(parser.plusCount) } segment := &routeSegment{ ParamName: paramName, IsParam: true, IsOptional: isWildCard || pattern[paramEndPosition] == optionalParam, IsGreedy: isWildCard || isPlusParam, } if len(constraints) > 0 { segment.Constraints = constraints } return n, segment } // findNextNonEscapedCharPosition searches the next char position and skips the escaped characters func findNextNonEscapedCharPosition(search string, char byte) int { for i := 0; i < len(search); i++ { if search[i] == char && (i == 0 || search[i-1] != escapeChar) { return i } } return -1 } // splitNonEscaped slices s into all substrings separated by sep and returns a slice of the substrings between those separators // This function also takes a care of escape char when splitting. 
func splitNonEscaped(s string, sep byte) []string { var result []string i := findNextNonEscapedCharPosition(s, sep) for i > -1 { result = append(result, s[:i]) s = s[i+1:] i = findNextNonEscapedCharPosition(s, sep) } return append(result, s) } func hasPartialMatchBoundary(path string, matchedLength int) bool { if matchedLength < 0 || matchedLength > len(path) { return false } if matchedLength == len(path) { return true } if matchedLength == 0 { return false } if path[matchedLength-1] == slashDelimiter { return true } if matchedLength < len(path) && path[matchedLength] == slashDelimiter { return true } return false } // getMatch parses the passed url and tries to match it against the route segments and determine the parameter positions func (parser *routeParser) getMatch(detectionPath, path string, params *[maxParams]string, partialCheck bool) bool { //nolint:revive // Accepting a bool param is fine here originalDetectionPath := detectionPath var i, paramsIterator, partLen int for _, segment := range parser.segs { partLen = len(detectionPath) // check const segment if !segment.IsParam { i = segment.Length // is optional part or the const part must match with the given string // check if the end of the segment is an optional slash if segment.HasOptionalSlash && partLen == i-1 && detectionPath == segment.Const[:i-1] { i-- } else if i > partLen || detectionPath[:i] != segment.Const { return false } } else { // determine parameter length i = findParamLen(detectionPath, segment) if !segment.IsOptional && i == 0 { return false } // take over the params positions params[paramsIterator] = path[:i] if !segment.IsOptional || i != 0 { // check constraint for _, c := range segment.Constraints { if matched := c.CheckConstraint(params[paramsIterator]); !matched { return false } } } paramsIterator++ } // reduce founded part from the string if partLen > 0 { detectionPath, path = detectionPath[i:], path[i:] } } if detectionPath != "" { if !partialCheck { return false } 
consumedLength := len(originalDetectionPath) - len(detectionPath) if !hasPartialMatchBoundary(originalDetectionPath, consumedLength) { return false } } return true } // findParamLen for the expressjs wildcard behavior (right to left greedy) // look at the other segments and take what is left for the wildcard from right to left func findParamLen(s string, segment *routeSegment) int { if segment.IsLast { return findParamLenForLastSegment(s, segment) } if segment.Length != 0 && len(s) >= segment.Length { return segment.Length } else if segment.IsGreedy { // Search the parameters until the next constant part // special logic for greedy params searchCount := strings.Count(s, segment.ComparePart) if searchCount > 1 { return findGreedyParamLen(s, searchCount, segment) } } if len(segment.ComparePart) == 1 { if constPosition := strings.IndexByte(s, segment.ComparePart[0]); constPosition != -1 { return constPosition } } else if constPosition := strings.Index(s, segment.ComparePart); constPosition != -1 { // if the compare part was found, but contains a slash although this part is not greedy, then it must not match // example: /api/:param/fixedEnd -> path: /api/123/456/fixedEnd = no match , /api/123/fixedEnd = match if !segment.IsGreedy && strings.IndexByte(s[:constPosition], slashDelimiter) != -1 { return 0 } return constPosition } return len(s) } // findParamLenForLastSegment get the length of the parameter if it is the last segment func findParamLenForLastSegment(s string, seg *routeSegment) int { if !seg.IsGreedy { if i := strings.IndexByte(s, slashDelimiter); i != -1 { return i } } return len(s) } // findGreedyParamLen get the length of the parameter for greedy segments from right to left func findGreedyParamLen(s string, searchCount int, segment *routeSegment) int { // check all from right to left segments for i := segment.PartCount; i > 0 && searchCount > 0; i-- { searchCount-- constPosition := strings.LastIndex(s, segment.ComparePart) if constPosition == -1 { break } 
s = s[:constPosition] } return len(s) } // GetTrimmedParam trims the ':' & '?' from a string func GetTrimmedParam(param string) string { start := 0 end := len(param) if end == 0 || param[start] != paramStarterChar { // is not a param return param } start++ if param[end-1] == optionalParam { // is ? end-- } return param[start:end] } // RemoveEscapeChar removes escape characters func RemoveEscapeChar(word string) string { // Fast path: check if there are any escape characters first escapeIdx := strings.IndexByte(word, '\\') if escapeIdx == -1 { return word // No escape chars, return original string without allocation } // Slow path: copy and remove escape characters b := []byte(word) dst := escapeIdx for src := escapeIdx + 1; src < len(b); src++ { if b[src] != '\\' { b[dst] = b[src] dst++ } } return string(b[:dst]) } // RemoveEscapeCharBytes removes escape characters func RemoveEscapeCharBytes(word []byte) []byte { dst := 0 for src := range word { if word[src] != '\\' { word[dst] = word[src] dst++ } } return word[:dst] } func getParamConstraintType(constraintPart string) TypeConstraint { switch constraintPart { case ConstraintInt: return intConstraint case ConstraintBool: return boolConstraint case ConstraintFloat: return floatConstraint case ConstraintAlpha: return alphaConstraint case ConstraintGUID: return guidConstraint case ConstraintMinLen, ConstraintMinLenLower: return minLenConstraint case ConstraintMaxLen, ConstraintMaxLenLower: return maxLenConstraint case ConstraintLen: return lenConstraint case ConstraintBetweenLen, ConstraintBetweenLenLower: return betweenLenConstraint case ConstraintMin: return minConstraint case ConstraintMax: return maxConstraint case ConstraintRange: return rangeConstraint case ConstraintDatetime: return datetimeConstraint case ConstraintRegex: return regexConstraint default: return noConstraint } } // CheckConstraint validates if a param matches the given constraint // Returns true if the param passes the constraint check, false 
otherwise func (c *Constraint) CheckConstraint(param string) bool { // First check if there's a custom constraint with the same name // This allows custom constraints to override built-in constraints for _, cc := range c.customConstraints { if cc.Name() == c.Name { return cc.Execute(param, c.Data...) } } var ( err error num int ) // Validate constraint has required data if c.ID&needOneData != 0 && len(c.Data) == 0 { return false } if c.ID&needTwoData != 0 && len(c.Data) < 2 { return false } switch c.ID { case noConstraint: return true case intConstraint: _, err = strconv.Atoi(param) case boolConstraint: _, err = strconv.ParseBool(param) case floatConstraint: _, err = strconv.ParseFloat(param, 32) case alphaConstraint: for _, r := range param { if !unicode.IsLetter(r) { return false } } case guidConstraint: _, err = uuid.Parse(param) case minLenConstraint: data, parseErr := strconv.Atoi(c.Data[0]) if parseErr != nil { return false } if len(param) < data { return false } case maxLenConstraint: data, parseErr := strconv.Atoi(c.Data[0]) if parseErr != nil { return false } if len(param) > data { return false } case lenConstraint: data, parseErr := strconv.Atoi(c.Data[0]) if parseErr != nil { return false } if len(param) != data { return false } case betweenLenConstraint: data, parseErr := strconv.Atoi(c.Data[0]) if parseErr != nil { return false } data2, parseErr := strconv.Atoi(c.Data[1]) if parseErr != nil { return false } length := len(param) if length < data || length > data2 { return false } case minConstraint: data, parseErr := strconv.Atoi(c.Data[0]) if parseErr != nil { return false } num, err = strconv.Atoi(param) if err != nil || num < data { return false } case maxConstraint: data, parseErr := strconv.Atoi(c.Data[0]) if parseErr != nil { return false } num, err = strconv.Atoi(param) if err != nil || num > data { return false } case rangeConstraint: data, parseErr := strconv.Atoi(c.Data[0]) if parseErr != nil { return false } data2, parseErr := 
strconv.Atoi(c.Data[1]) if parseErr != nil { return false } num, err = strconv.Atoi(param) if err != nil || num < data || num > data2 { return false } case datetimeConstraint: _, err = time.Parse(c.Data[0], param) if err != nil { return false } case regexConstraint: if c.RegexCompiler == nil { return false } if match := c.RegexCompiler.MatchString(param); !match { return false } default: return false } return err == nil } ================================================ FILE: path_test.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 📝 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io package fiber import ( "fmt" "strings" "testing" "github.com/stretchr/testify/require" ) // go test -race -run Test_Path_parseRoute func Test_Path_parseRoute(t *testing.T) { t.Parallel() var rp routeParser rp = parseRoute("/shop/product/::filter/color::color/size::size") require.Equal(t, routeParser{ segs: []*routeSegment{ {Const: "/shop/product/:", Length: 15}, {IsParam: true, ParamName: "filter", ComparePart: "/color:", PartCount: 1}, {Const: "/color:", Length: 7}, {IsParam: true, ParamName: "color", ComparePart: "/size:", PartCount: 1}, {Const: "/size:", Length: 6}, {IsParam: true, ParamName: "size", IsLast: true}, }, params: []string{"filter", "color", "size"}, }, rp) rp = parseRoute("/api/v1/:param/abc/*") require.Equal(t, routeParser{ segs: []*routeSegment{ {Const: "/api/v1/", Length: 8}, {IsParam: true, ParamName: "param", ComparePart: "/abc", PartCount: 1}, {Const: "/abc/", Length: 5, HasOptionalSlash: true}, {IsParam: true, ParamName: "*1", IsGreedy: true, IsOptional: true, IsLast: true}, }, params: []string{"param", "*1"}, wildCardCount: 1, }, rp) rp = parseRoute("/v1/some/resource/name\\:customVerb") require.Equal(t, routeParser{ segs: []*routeSegment{ {Const: "/v1/some/resource/name:customVerb", Length: 33, IsLast: true}, }, params: nil, }, rp) rp 
= parseRoute("/v1/some/resource/:name\\:customVerb") require.Equal(t, routeParser{ segs: []*routeSegment{ {Const: "/v1/some/resource/", Length: 18}, {IsParam: true, ParamName: "name", ComparePart: ":customVerb", PartCount: 1}, {Const: ":customVerb", Length: 11, IsLast: true}, }, params: []string{"name"}, }, rp) // heavy test with escaped characters rp = parseRoute("/v1/some/resource/name\\\\:customVerb?\\?/:param/*") require.Equal(t, routeParser{ segs: []*routeSegment{ {Const: "/v1/some/resource/name:customVerb??/", Length: 36}, {IsParam: true, ParamName: "param", ComparePart: "/", PartCount: 1}, {Const: "/", Length: 1, HasOptionalSlash: true}, {IsParam: true, ParamName: "*1", IsGreedy: true, IsOptional: true, IsLast: true}, }, params: []string{"param", "*1"}, wildCardCount: 1, }, rp) rp = parseRoute("/api/*/:param/:param2") require.Equal(t, routeParser{ segs: []*routeSegment{ {Const: "/api/", Length: 5, HasOptionalSlash: true}, {IsParam: true, ParamName: "*1", IsGreedy: true, IsOptional: true, ComparePart: "/", PartCount: 2}, {Const: "/", Length: 1}, {IsParam: true, ParamName: "param", ComparePart: "/", PartCount: 1}, {Const: "/", Length: 1}, {IsParam: true, ParamName: "param2", IsLast: true}, }, params: []string{"*1", "param", "param2"}, wildCardCount: 1, }, rp) rp = parseRoute("/test:optional?:optional2?") require.Equal(t, routeParser{ segs: []*routeSegment{ {Const: "/test", Length: 5}, {IsParam: true, ParamName: "optional", IsOptional: true, Length: 1}, {IsParam: true, ParamName: "optional2", IsOptional: true, IsLast: true}, }, params: []string{"optional", "optional2"}, }, rp) rp = parseRoute("/config/+.json") require.Equal(t, routeParser{ segs: []*routeSegment{ {Const: "/config/", Length: 8}, {IsParam: true, ParamName: "+1", IsGreedy: true, IsOptional: false, ComparePart: ".json", PartCount: 1}, {Const: ".json", Length: 5, IsLast: true}, }, params: []string{"+1"}, plusCount: 1, }, rp) rp = parseRoute("/api/:day.:month?.:year?") require.Equal(t, routeParser{ 
segs: []*routeSegment{ {Const: "/api/", Length: 5}, {IsParam: true, ParamName: "day", IsOptional: false, ComparePart: ".", PartCount: 2}, {Const: ".", Length: 1}, {IsParam: true, ParamName: "month", IsOptional: true, ComparePart: ".", PartCount: 1}, {Const: ".", Length: 1}, {IsParam: true, ParamName: "year", IsOptional: true, IsLast: true}, }, params: []string{"day", "month", "year"}, }, rp) rp = parseRoute("/*v1*/proxy") require.Equal(t, routeParser{ segs: []*routeSegment{ {Const: "/", Length: 1, HasOptionalSlash: true}, {IsParam: true, ParamName: "*1", IsGreedy: true, IsOptional: true, ComparePart: "v1", PartCount: 1}, {Const: "v1", Length: 2}, {IsParam: true, ParamName: "*2", IsGreedy: true, IsOptional: true, ComparePart: "/proxy", PartCount: 1}, {Const: "/proxy", Length: 6, IsLast: true}, }, params: []string{"*1", "*2"}, wildCardCount: 2, }, rp) } // go test -race -run Test_Path_matchParams func Test_Path_matchParams(t *testing.T) { t.Parallel() var ctxParams [maxParams]string testCaseFn := func(testCollection routeCaseCollection) { parser := parseRoute(testCollection.pattern) for _, c := range testCollection.testCases { match := parser.getMatch(c.url, c.url, &ctxParams, c.partialCheck) require.Equal(t, c.match, match, "route: '%s', url: '%s'", testCollection.pattern, c.url) if match && len(c.params) > 0 { require.Equal(t, c.params[0:len(c.params)], ctxParams[0:len(c.params)], "route: '%s', url: '%s'", testCollection.pattern, c.url) } } } for _, testCaseCollection := range routeTestCases { testCaseFn(testCaseCollection) } } // go test -race -run Test_RoutePatternMatch func Test_RoutePatternMatch(t *testing.T) { t.Parallel() testCaseFn := func(pattern string, cases []routeTestCase) { for _, c := range cases { // skip all cases for partial checks if c.partialCheck { continue } match := RoutePatternMatch(c.url, pattern) require.Equal(t, c.match, match, "route: '%s', url: '%s'", pattern, c.url) } } for _, testCase := range routeTestCases { 
testCaseFn(testCase.pattern, testCase.testCases) } } func TestHasPartialMatchBoundary(t *testing.T) { t.Parallel() testCases := []struct { name string path string matchedLength int expected bool }{ { name: "negative length", path: "/demo", matchedLength: -1, expected: false, }, { name: "greater than length", path: "/demo", matchedLength: 6, expected: false, }, { name: "exact match", path: "/demo", matchedLength: len("/demo"), expected: true, }, { name: "zero length", path: "/demo", matchedLength: 0, expected: false, }, { name: "previous rune slash", path: "/demo/child", matchedLength: len("/demo/"), expected: true, }, { name: "next rune slash", path: "/demo/child", matchedLength: len("/demo"), expected: true, }, { name: "no boundary", path: "/demo/child", matchedLength: len("/dem"), expected: false, }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { t.Parallel() require.Equal(t, testCase.expected, hasPartialMatchBoundary(testCase.path, testCase.matchedLength)) }) } } func Test_Utils_GetTrimmedParam(t *testing.T) { t.Parallel() res := GetTrimmedParam("") require.Empty(t, res) res = GetTrimmedParam("*") require.Equal(t, "*", res) res = GetTrimmedParam(":param") require.Equal(t, "param", res) res = GetTrimmedParam(":param1?") require.Equal(t, "param1", res) res = GetTrimmedParam("noParam") require.Equal(t, "noParam", res) } func Test_Utils_RemoveEscapeChar(t *testing.T) { t.Parallel() res := RemoveEscapeChar(":test\\:bla") require.Equal(t, ":test:bla", res) res = RemoveEscapeChar("\\abc") require.Equal(t, "abc", res) res = RemoveEscapeChar("noEscapeChar") require.Equal(t, "noEscapeChar", res) } func Test_ConstraintCheckConstraint_InvalidMetadata(t *testing.T) { t.Parallel() testCases := []struct { name string param string constraint Constraint }{ { name: "minLen invalid metadata", constraint: Constraint{ID: minLenConstraint, Data: []string{"abc"}}, param: "abcd", }, { name: "maxLen invalid metadata", constraint: Constraint{ID: 
maxLenConstraint, Data: []string{"abc"}}, param: "abcd", }, { name: "len invalid metadata", constraint: Constraint{ID: lenConstraint, Data: []string{"abc"}}, param: "abcd", }, { name: "betweenLen invalid first metadata", constraint: Constraint{ID: betweenLenConstraint, Data: []string{"abc", "5"}}, param: "abcd", }, { name: "betweenLen invalid second metadata", constraint: Constraint{ID: betweenLenConstraint, Data: []string{"1", "abc"}}, param: "abcd", }, { name: "min invalid metadata", constraint: Constraint{ID: minConstraint, Data: []string{"abc"}}, param: "10", }, { name: "max invalid metadata", constraint: Constraint{ID: maxConstraint, Data: []string{"abc"}}, param: "10", }, { name: "range invalid first metadata", constraint: Constraint{ID: rangeConstraint, Data: []string{"abc", "10"}}, param: "7", }, { name: "range invalid second metadata", constraint: Constraint{ID: rangeConstraint, Data: []string{"1", "abc"}}, param: "7", }, } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { t.Parallel() require.False(t, testCase.constraint.CheckConstraint(testCase.param)) }) } } func Benchmark_Utils_RemoveEscapeChar(b *testing.B) { b.ReportAllocs() var res string for b.Loop() { res = RemoveEscapeChar(":test\\:bla") } require.Equal(b, ":test:bla", res) } // go test -race -run Test_Path_matchParams func Benchmark_Path_matchParams(t *testing.B) { var ctxParams [maxParams]string benchCaseFn := func(testCollection routeCaseCollection) { parser := parseRoute(testCollection.pattern) for _, c := range testCollection.testCases { var matchRes bool state := "match" if !c.match { state = "not match" } t.Run(testCollection.pattern+" | "+state+" | "+c.url, func(b *testing.B) { for b.Loop() { if match := parser.getMatch(c.url, c.url, &ctxParams, c.partialCheck); match { // Get testCases from the original path matchRes = true } } require.Equal(t, c.match, matchRes, "route: '%s', url: '%s'", testCollection.pattern, c.url) if matchRes && len(c.params) > 0 { 
require.Equal(t, c.params[0:len(c.params)-1], ctxParams[0:len(c.params)-1], "route: '%s', url: '%s'", testCollection.pattern, c.url) } }) } } for _, testCollection := range benchmarkCases { benchCaseFn(testCollection) } } // go test -race -run Test_RoutePatternMatch func Benchmark_RoutePatternMatch(t *testing.B) { benchCaseFn := func(testCollection routeCaseCollection) { for _, c := range testCollection.testCases { // skip all cases for partial checks if c.partialCheck { continue } var matchRes bool state := "match" if !c.match { state = "not match" } t.Run(testCollection.pattern+" | "+state+" | "+c.url, func(b *testing.B) { for b.Loop() { if match := RoutePatternMatch(c.url, testCollection.pattern); match { // Get testCases from the original path matchRes = true } } require.Equal(t, c.match, matchRes, "route: '%s', url: '%s'", testCollection.pattern, c.url) }) } } for _, testCollection := range benchmarkCases { benchCaseFn(testCollection) } } func Test_Route_TooManyParams_Panic(t *testing.T) { t.Parallel() // Test with exactly maxParams (30) - should work t.Run("exactly_maxParams", func(t *testing.T) { t.Parallel() route := paramsRoute(t, maxParams) require.NotPanics(t, func() { parseRoute(route) }) }) // Test with maxParams + 1 (31) - should panic t.Run("maxParams_plus_one", func(t *testing.T) { t.Parallel() route := paramsRoute(t, maxParams+1) require.PanicsWithValue(t, "Route '"+route+"' has 31 parameters, which exceeds the maximum of 30", func() { parseRoute(route) }) }) // Test with 35 params - should panic t.Run("35_params", func(t *testing.T) { t.Parallel() route := paramsRoute(t, maxParams+5) require.PanicsWithValue(t, "Route '"+route+"' has 35 parameters, which exceeds the maximum of 30", func() { parseRoute(route) }) }) } func Test_App_Register_TooManyParams_Panic(t *testing.T) { t.Parallel() // Test registering a route with too many params via app t.Run("register_via_Get", func(t *testing.T) { t.Parallel() app := New() route := paramsRoute(t, 
maxParams+1) require.PanicsWithValue(t, "Route '"+route+"' has 31 parameters, which exceeds the maximum of 30", func() { app.Get(route, func(c Ctx) error { return c.SendString("test") }) }) }) // Test registering a route with maxParams works t.Run("register_maxParams_works", func(t *testing.T) { t.Parallel() app := New() route := paramsRoute(t, maxParams) require.NotPanics(t, func() { app.Get(route, func(c Ctx) error { return c.SendString("test") }) }) }) } // paramsRoute generates a route with n parameters for testing parseRoute maxParams condition. // Returns a route in the format "/:p1/:p2/:p3/.../:pN" func paramsRoute(t *testing.T, n int) string { t.Helper() params := make([]string, n) for i := range params { params[i] = fmt.Sprintf(":p%d", i+1) } return "/" + strings.Join(params, "/") } ================================================ FILE: path_testcases_test.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 📝 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io package fiber import ( "strings" ) type routeTestCase struct { url string params []string match bool partialCheck bool } type routeCaseCollection struct { pattern string testCases []routeTestCase } var ( benchmarkCases []routeCaseCollection routeTestCases []routeCaseCollection ) func init() { // smaller list for benchmark cases benchmarkCases = []routeCaseCollection{ { pattern: "/api/v1/const", testCases: []routeTestCase{ {url: "/api/v1/const", params: []string{}, match: true}, {url: "/api/v1", params: nil, match: false}, {url: "/api/v1/", params: nil, match: false}, {url: "/api/v1/something", params: nil, match: false}, }, }, { pattern: "/api/:param/fixedEnd", testCases: []routeTestCase{ {url: "/api/abc/fixedEnd", params: []string{"abc"}, match: true}, {url: "/api/abc/def/fixedEnd", params: nil, match: false}, }, }, { pattern: "/api/v1/:param/*", testCases: []routeTestCase{ {url: 
"/api/v1/entity", params: []string{"entity", ""}, match: true}, {url: "/api/v1/entity/", params: []string{"entity", ""}, match: true}, {url: "/api/v1/entity/1", params: []string{"entity", "1"}, match: true}, {url: "/api/v", params: nil, match: false}, {url: "/api/v2", params: nil, match: false}, {url: "/api/v1/", params: nil, match: false}, }, }, } // combine benchmark cases and other cases routeTestCases = benchmarkCases routeTestCases = append( routeTestCases, []routeCaseCollection{ { pattern: "/api/v1/:param/+", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/entity/", params: nil, match: false}, {url: "/api/v1/entity/1", params: []string{"entity", "1"}, match: true}, {url: "/api/v", params: nil, match: false}, {url: "/api/v2", params: nil, match: false}, {url: "/api/v1/", params: nil, match: false}, }, }, { pattern: "/api/v1/:param?", testCases: []routeTestCase{ {url: "/api/v1", params: []string{""}, match: true}, {url: "/api/v1/", params: []string{""}, match: true}, {url: "/api/v1/optional", params: []string{"optional"}, match: true}, {url: "/api/v", params: nil, match: false}, {url: "/api/v2", params: nil, match: false}, {url: "/api/xyz", params: nil, match: false}, }, }, { pattern: `/v1/some/resource/name\:customVerb`, testCases: []routeTestCase{ {url: "/v1/some/resource/name:customVerb", params: nil, match: true}, {url: "/v1/some/resource/name:test", params: nil, match: false}, }, }, { pattern: `/v1/some/resource/:name\:customVerb`, testCases: []routeTestCase{ {url: "/v1/some/resource/test:customVerb", params: []string{"test"}, match: true}, {url: "/v1/some/resource/test:test", params: nil, match: false}, }, }, { pattern: `/v1/some/resource/name\\:customVerb?\?/:param/*`, testCases: []routeTestCase{ {url: "/v1/some/resource/name:customVerb??/test/optionalWildCard/character", params: []string{"test", "optionalWildCard/character"}, match: true}, {url: "/v1/some/resource/name:customVerb??/test", params: 
[]string{"test", ""}, match: true}, }, }, { pattern: "/api/v1/*", testCases: []routeTestCase{ {url: "/api/v1", params: []string{""}, match: true}, {url: "/api/v1/", params: []string{""}, match: true}, {url: "/api/v1/entity", params: []string{"entity"}, match: true}, {url: "/api/v1/entity/1/2", params: []string{"entity/1/2"}, match: true}, {url: "/api/v1/Entity/1/2", params: []string{"Entity/1/2"}, match: true}, {url: "/api/v", params: nil, match: false}, {url: "/api/v2", params: nil, match: false}, {url: "/api/abc", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: []string{"entity"}, match: true}, {url: "/api/v1/entity/8728382", params: nil, match: false}, {url: "/api/v1", params: nil, match: false}, {url: "/api/v1/", params: nil, match: false}, }, }, { pattern: "/api/v1/:param-:param2", testCases: []routeTestCase{ {url: "/api/v1/entity-entity2", params: []string{"entity", "entity2"}, match: true}, {url: "/api/v1/entity/8728382", params: nil, match: false}, {url: "/api/v1/entity-8728382", params: []string{"entity", "8728382"}, match: true}, {url: "/api/v1", params: nil, match: false}, {url: "/api/v1/", params: nil, match: false}, }, }, { pattern: "/api/v1/:filename.:extension", testCases: []routeTestCase{ {url: "/api/v1/test.pdf", params: []string{"test", "pdf"}, match: true}, {url: "/api/v1/test/pdf", params: nil, match: false}, {url: "/api/v1/test-pdf", params: nil, match: false}, {url: "/api/v1/test_pdf", params: nil, match: false}, {url: "/api/v1", params: nil, match: false}, {url: "/api/v1/", params: nil, match: false}, }, }, { pattern: "/shop/product/::filter/color::color/size::size", testCases: []routeTestCase{ {url: "/shop/product/:test/color:blue/size:xs", params: []string{"test", "blue", "xs"}, match: true}, {url: "/shop/product/test/color:blue/size:xs", params: nil, match: false}, }, }, { pattern: "/::param?", testCases: []routeTestCase{ {url: "/:hello", params: []string{"hello"}, 
match: true}, {url: "/:", params: []string{""}, match: true}, {url: "/", params: nil, match: false}, }, }, // successive parameters, each take one character and the last parameter gets everything { pattern: "/test:sign:param", testCases: []routeTestCase{ {url: "/test-abc", params: []string{"-", "abc"}, match: true}, {url: "/test", params: nil, match: false}, }, }, // optional parameters are not greedy { pattern: "/:param1:param2?:param3", testCases: []routeTestCase{ {url: "/abbbc", params: []string{"a", "b", "bbc"}, match: true}, // {url: "/ac", testCases: []string{"a", "", "c"}, match: true}, // TODO: fix it {url: "/test", params: []string{"t", "e", "st"}, match: true}, }, }, { pattern: "/test:optional?:mandatory", testCases: []routeTestCase{ // {url: "/testo", testCases: []string{"", "o"}, match: true}, // TODO: fix it {url: "/testoaaa", params: []string{"o", "aaa"}, match: true}, {url: "/test", params: nil, match: false}, }, }, { pattern: "/test:optional?:optional2?", testCases: []routeTestCase{ {url: "/testo", params: []string{"o", ""}, match: true}, {url: "/testoaaa", params: []string{"o", "aaa"}, match: true}, {url: "/test", params: []string{"", ""}, match: true}, {url: "/tes", params: nil, match: false}, }, }, { pattern: "/foo:param?bar", testCases: []routeTestCase{ {url: "/foofaselbar", params: []string{"fasel"}, match: true}, {url: "/foobar", params: []string{""}, match: true}, {url: "/fooba", params: nil, match: false}, {url: "/fobar", params: nil, match: false}, }, }, { pattern: "/foo*bar", testCases: []routeTestCase{ {url: "/foofaselbar", params: []string{"fasel"}, match: true}, {url: "/foobar", params: []string{""}, match: true}, {url: "/", params: nil, match: false}, }, }, { pattern: "/foo+bar", testCases: []routeTestCase{ {url: "/foofaselbar", params: []string{"fasel"}, match: true}, {url: "/foobar", params: nil, match: false}, {url: "/", params: nil, match: false}, }, }, { pattern: "/a*cde*g/", testCases: []routeTestCase{ {url: "/abbbcdefffg", 
params: []string{"bbb", "fff"}, match: true}, {url: "/acdeg", params: []string{"", ""}, match: true}, {url: "/", params: nil, match: false}, }, }, { pattern: "/*v1*/proxy", testCases: []routeTestCase{ {url: "/customer/v1/cart/proxy", params: []string{"customer/", "/cart"}, match: true}, {url: "/v1/proxy", params: []string{"", ""}, match: true}, {url: "/v1/", params: nil, match: false}, }, }, // successive wildcard -> first wildcard is greedy { pattern: "/foo***bar", testCases: []routeTestCase{ {url: "/foo*abar", params: []string{"*a", "", ""}, match: true}, {url: "/foo*bar", params: []string{"*", "", ""}, match: true}, {url: "/foobar", params: []string{"", "", ""}, match: true}, {url: "/fooba", params: nil, match: false}, }, }, // chars in front of a parameter { pattern: "/name::name", testCases: []routeTestCase{ {url: "/name:john", params: []string{"john"}, match: true}, }, }, { pattern: "/@:name", testCases: []routeTestCase{ {url: "/@john", params: []string{"john"}, match: true}, }, }, { pattern: "/-:name", testCases: []routeTestCase{ {url: "/-john", params: []string{"john"}, match: true}, }, }, { pattern: "/.:name", testCases: []routeTestCase{ {url: "/.john", params: []string{"john"}, match: true}, }, }, { pattern: "/api/v1/:param/abc/*", testCases: []routeTestCase{ {url: "/api/v1/well/abc/wildcard", params: []string{"well", "wildcard"}, match: true}, {url: "/api/v1/well/abc/", params: []string{"well", ""}, match: true}, {url: "/api/v1/well/abc", params: []string{"well", ""}, match: true}, {url: "/api/v1/well/ttt", params: nil, match: false}, }, }, { pattern: "/api/:day/:month?/:year?", testCases: []routeTestCase{ {url: "/api/1", params: []string{"1", "", ""}, match: true}, {url: "/api/1/", params: []string{"1", "", ""}, match: true}, {url: "/api/1//", params: []string{"1", "", ""}, match: true}, {url: "/api/1/-/", params: []string{"1", "-", ""}, match: true}, {url: "/api/1-", params: []string{"1-", "", ""}, match: true}, {url: "/api/1.", params: []string{"1.", 
"", ""}, match: true}, {url: "/api/1/2", params: []string{"1", "2", ""}, match: true}, {url: "/api/1/2/3", params: []string{"1", "2", "3"}, match: true}, {url: "/api/", params: nil, match: false}, }, }, { pattern: "/api/:day.:month?.:year?", testCases: []routeTestCase{ {url: "/api/1", params: nil, match: false}, {url: "/api/1/", params: nil, match: false}, {url: "/api/1.", params: nil, match: false}, {url: "/api/1..", params: []string{"1", "", ""}, match: true}, {url: "/api/1.2", params: nil, match: false}, {url: "/api/1.2.", params: []string{"1", "2", ""}, match: true}, {url: "/api/1.2.3", params: []string{"1", "2", "3"}, match: true}, {url: "/api/", params: nil, match: false}, }, }, { pattern: "/api/:day-:month?-:year?", testCases: []routeTestCase{ {url: "/api/1", params: nil, match: false}, {url: "/api/1/", params: nil, match: false}, {url: "/api/1-", params: nil, match: false}, {url: "/api/1--", params: []string{"1", "", ""}, match: true}, {url: "/api/1-/", params: nil, match: false}, // {url: "/api/1-/-", testCases: nil, match: false}, // TODO: fix this part {url: "/api/1-2", params: nil, match: false}, {url: "/api/1-2-", params: []string{"1", "2", ""}, match: true}, {url: "/api/1-2-3", params: []string{"1", "2", "3"}, match: true}, {url: "/api/", params: nil, match: false}, }, }, { pattern: "/api/*", testCases: []routeTestCase{ {url: "/api/", params: []string{""}, match: true}, {url: "/api/joker", params: []string{"joker"}, match: true}, {url: "/api", params: []string{""}, match: true}, {url: "/api/v1/entity", params: []string{"v1/entity"}, match: true}, {url: "/api2/v1/entity", params: nil, match: false}, {url: "/api_ignore/v1/entity", params: nil, match: false}, }, }, { pattern: "/partialCheck/foo/", testCases: []routeTestCase{ {url: "/partialCheck/foo/", params: nil, match: true, partialCheck: true}, {url: "/partialCheck/foo/bar", params: nil, match: true, partialCheck: true}, {url: "/partialCheck/foo/bar/baz", params: nil, match: true, partialCheck: 
true}, {url: "/partialCheck/foobar", params: nil, match: false, partialCheck: true}, }, }, { pattern: "/partialCheck/foo/bar/:param", testCases: []routeTestCase{ {url: "/partialCheck/foo/bar/test", params: []string{"test"}, match: true, partialCheck: true}, {url: "/partialCheck/foo/bar/test/test2", params: []string{"test"}, match: true, partialCheck: true}, {url: "/partialCheck/foo/bar", params: nil, match: false, partialCheck: true}, {url: "/partialFoo", params: nil, match: false, partialCheck: true}, }, }, { pattern: "/partialCheck/foo", testCases: []routeTestCase{ {url: "/partialCheck/foo", params: nil, match: true, partialCheck: true}, {url: "/partialCheck/foo/", params: nil, match: true, partialCheck: true}, {url: "/partialCheck/foo/bar", params: nil, match: true, partialCheck: true}, {url: "/partialCheck/foobar", params: nil, match: false, partialCheck: true}, }, }, { pattern: "/partialCheck/:param", testCases: []routeTestCase{ {url: "/partialCheck/value", params: []string{"value"}, match: true, partialCheck: true}, {url: "/partialCheck/value/", params: []string{"value"}, match: true, partialCheck: true}, {url: "/partialCheck/value/next", params: []string{"value"}, match: true, partialCheck: true}, }, }, { pattern: "/", testCases: []routeTestCase{ {url: "/api", params: nil, match: false}, {url: "", params: []string{}, match: true}, {url: "/", params: []string{}, match: true}, }, }, { pattern: "/config/abc.json", testCases: []routeTestCase{ {url: "/config/abc.json", params: []string{}, match: true}, {url: "config/abc.json", params: nil, match: false}, {url: "/config/efg.json", params: nil, match: false}, {url: "/config", params: nil, match: false}, }, }, { pattern: "/config/*.json", testCases: []routeTestCase{ {url: "/config/abc.json", params: []string{"abc"}, match: true}, {url: "/config/efg.json", params: []string{"efg"}, match: true}, {url: "/config/.json", params: []string{""}, match: true}, {url: "/config/efg.csv", params: nil, match: false}, {url: 
"config/abc.json", params: nil, match: false}, {url: "/config", params: nil, match: false}, }, }, { pattern: "/config/+.json", testCases: []routeTestCase{ {url: "/config/abc.json", params: []string{"abc"}, match: true}, {url: "/config/.json", params: nil, match: false}, {url: "/config/efg.json", params: []string{"efg"}, match: true}, {url: "/config/efg.csv", params: nil, match: false}, {url: "config/abc.json", params: nil, match: false}, {url: "/config", params: nil, match: false}, }, }, { pattern: "/xyz", testCases: []routeTestCase{ {url: "xyz", params: nil, match: false}, {url: "xyz/", params: nil, match: false}, }, }, { pattern: "/api/*/:param?", testCases: []routeTestCase{ {url: "/api/", params: []string{"", ""}, match: true}, {url: "/api/joker", params: []string{"joker", ""}, match: true}, {url: "/api/joker/batman", params: []string{"joker", "batman"}, match: true}, {url: "/api/joker//batman", params: []string{"joker/", "batman"}, match: true}, {url: "/api/joker/batman/robin", params: []string{"joker/batman", "robin"}, match: true}, {url: "/api/joker/batman/robin/1", params: []string{"joker/batman/robin", "1"}, match: true}, {url: "/api/joker/batman/robin/1/", params: []string{"joker/batman/robin/1", ""}, match: true}, {url: "/api/joker-batman/robin/1", params: []string{"joker-batman/robin", "1"}, match: true}, {url: "/api/joker-batman-robin/1", params: []string{"joker-batman-robin", "1"}, match: true}, {url: "/api/joker-batman-robin-1", params: []string{"joker-batman-robin-1", ""}, match: true}, {url: "/api", params: []string{"", ""}, match: true}, }, }, { pattern: "/api/*/:param", testCases: []routeTestCase{ {url: "/api/test/abc", params: []string{"test", "abc"}, match: true}, {url: "/api/joker/batman", params: []string{"joker", "batman"}, match: true}, {url: "/api/joker/batman/robin", params: []string{"joker/batman", "robin"}, match: true}, {url: "/api/joker/batman/robin/1", params: []string{"joker/batman/robin", "1"}, match: true}, {url: 
"/api/joker/batman-robin/1", params: []string{"joker/batman-robin", "1"}, match: true}, {url: "/api/joker-batman-robin-1", params: nil, match: false}, {url: "/api", params: nil, match: false}, }, }, { pattern: "/api/+/:param", testCases: []routeTestCase{ {url: "/api/test/abc", params: []string{"test", "abc"}, match: true}, {url: "/api/joker/batman/robin/1", params: []string{"joker/batman/robin", "1"}, match: true}, {url: "/api/joker", params: nil, match: false}, {url: "/api", params: nil, match: false}, }, }, { pattern: "/api/*/:param/:param2", testCases: []routeTestCase{ {url: "/api/test/abc/1", params: []string{"test", "abc", "1"}, match: true}, {url: "/api/joker/batman", params: nil, match: false}, {url: "/api/joker/batman-robin/1", params: []string{"joker", "batman-robin", "1"}, match: true}, {url: "/api/joker-batman-robin-1", params: nil, match: false}, {url: "/api/test/abc", params: nil, match: false}, {url: "/api/joker/batman/robin", params: []string{"joker", "batman", "robin"}, match: true}, {url: "/api/joker/batman/robin/1", params: []string{"joker/batman", "robin", "1"}, match: true}, {url: "/api/joker/batman/robin/1/2", params: []string{"joker/batman/robin", "1", "2"}, match: true}, {url: "/api", params: nil, match: false}, {url: "/api/:test", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/8728382", params: []string{"8728382"}, match: true}, {url: "/api/v1/true", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/8728382", params: nil, match: false}, {url: "/api/v1/true", params: []string{"true"}, match: true}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/8728382", params: []string{"8728382"}, match: true}, {url: "/api/v1/8728382.5", params: 
[]string{"8728382.5"}, match: true}, {url: "/api/v1/true", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: []string{"entity"}, match: true}, {url: "/api/v1/#!?", params: nil, match: false}, {url: "/api/v1/8728382", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/8728382", params: nil, match: false}, {url: "/api/v1/f0fa66cc-d22e-445b-866d-1d76e776371d", params: []string{"f0fa66cc-d22e-445b-866d-1d76e776371d"}, match: true}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/8728382", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: []string{"entity"}, match: true}, {url: "/api/v1/ent", params: nil, match: false}, {url: "/api/v1/8728382", params: []string{"8728382"}, match: true}, {url: "/api/v1/123", params: nil, match: false}, {url: "/api/v1/12345", params: []string{"12345"}, match: true}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/ent", params: []string{"ent"}, match: true}, {url: "/api/v1/8728382", params: nil, match: false}, {url: "/api/v1/123", params: []string{"123"}, match: true}, {url: "/api/v1/12345", params: []string{"12345"}, match: true}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/ent", params: nil, match: false}, {url: "/api/v1/123", params: nil, match: false}, {url: "/api/v1/12345", params: []string{"12345"}, match: true}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/ent", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/e", params: nil, match: false}, {url: 
"/api/v1/en", params: []string{"en"}, match: true}, {url: "/api/v1/8728382", params: nil, match: false}, {url: "/api/v1/123", params: []string{"123"}, match: true}, {url: "/api/v1/12345", params: []string{"12345"}, match: true}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/e", params: nil, match: false}, {url: "/api/v1/en", params: []string{"en"}, match: true}, {url: "/api/v1/8728382", params: nil, match: false}, {url: "/api/v1/123", params: []string{"123"}, match: true}, {url: "/api/v1/12345", params: []string{"12345"}, match: true}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/ent", params: nil, match: false}, {url: "/api/v1/1", params: nil, match: false}, {url: "/api/v1/5", params: []string{"5"}, match: true}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/ent", params: nil, match: false}, {url: "/api/v1/1", params: []string{"1"}, match: true}, {url: "/api/v1/5", params: []string{"5"}, match: true}, {url: "/api/v1/15", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/ent", params: nil, match: false}, {url: "/api/v1/9", params: []string{"9"}, match: true}, {url: "/api/v1/5", params: []string{"5"}, match: true}, {url: "/api/v1/15", params: nil, match: false}, }, }, { pattern: `/api/v1/:param`, testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/8728382", params: nil, match: false}, {url: "/api/v1/2005-11-01", params: []string{"2005-11-01"}, match: true}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/ent", params: nil, match: false}, {url: "/api/v1/15", params: nil, match: false}, {url: "/api/v1/peach", params: []string{"peach"}, match: true}, {url: "/api/v1/p34ch", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/12", params: nil, match: false}, {url: "/api/v1/xy", params: nil, 
match: false}, {url: "/api/v1/test", params: []string{"test"}, match: true}, {url: "/api/v1/" + strings.Repeat("a", 64), params: nil, match: false}, }, }, { pattern: `/api/v1/:param`, testCases: []routeTestCase{ {url: "/api/v1/ent", params: nil, match: false}, {url: "/api/v1/15", params: nil, match: false}, {url: "/api/v1/2022-08-27", params: []string{"2022-08-27"}, match: true}, {url: "/api/v1/2022/08-27", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/8728382", params: []string{"8728382"}, match: true}, {url: "/api/v1/true", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/8728382", params: nil, match: false}, {url: "/api/v1/123", params: []string{"123"}, match: true}, {url: "/api/v1/true", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/87283827683", params: nil, match: false}, {url: "/api/v1/123", params: []string{"123"}, match: true}, {url: "/api/v1/true", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/87283827683", params: nil, match: false}, {url: "/api/v1/25", params: []string{"25"}, match: true}, {url: "/api/v1/true", params: nil, match: false}, }, }, { pattern: `/api/v1/:param`, testCases: []routeTestCase{ {url: "/api/v1/entity", params: []string{"entity"}, match: true}, {url: "/api/v1/87283827683", params: []string{"87283827683"}, match: true}, {url: "/api/v1/25", params: []string{"25"}, match: true}, {url: "/api/v1/true", params: []string{"true"}, match: true}, }, }, { pattern: `/api/v1/:param`, testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/87283827683", params: nil, 
match: false}, {url: "/api/v1/25", params: nil, match: false}, {url: "/api/v1/1200", params: nil, match: false}, {url: "/api/v1/true", params: nil, match: false}, }, }, { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/87283827683", params: nil, match: false}, {url: "/api/v1/25", params: []string{"25"}, match: true}, {url: "/api/v1/1200", params: []string{"1200"}, match: true}, {url: "/api/v1/true", params: nil, match: false}, }, }, { pattern: "/api/v1/:lang/videos/:page", testCases: []routeTestCase{ {url: "/api/v1/try/videos/200", params: nil, match: false}, {url: "/api/v1/tr/videos/1800", params: nil, match: false}, {url: "/api/v1/tr/videos/100", params: []string{"tr", "100"}, match: true}, {url: "/api/v1/e/videos/10", params: nil, match: false}, }, }, { pattern: "/api/v1/:lang/:page", testCases: []routeTestCase{ {url: "/api/v1/try/200", params: nil, match: false}, {url: "/api/v1/tr/1800", params: nil, match: false}, {url: "/api/v1/tr/100", params: []string{"tr", "100"}, match: true}, {url: "/api/v1/e/10", params: nil, match: false}, }, }, { pattern: "/api/v1/:lang/:page", testCases: []routeTestCase{ {url: "/api/v1/try/200", params: []string{"try", "200"}, match: true}, {url: "/api/v1/tr/1800", params: nil, match: false}, {url: "/api/v1/tr/100", params: []string{"tr", "100"}, match: true}, {url: "/api/v1/e/10", params: nil, match: false}, }, }, { pattern: "/api/v1/:lang/:page", testCases: []routeTestCase{ {url: "/api/v1/try/200", params: nil, match: false}, {url: "/api/v1/tr/1800", params: []string{"tr", "1800"}, match: true}, {url: "/api/v1/tr/100", params: []string{"tr", "100"}, match: true}, {url: "/api/v1/e/10", params: nil, match: false}, }, }, { pattern: `/api/v1/:date/:regex`, testCases: []routeTestCase{ {url: "/api/v1/2005-11-01/a", params: nil, match: false}, {url: "/api/v1/2005-1101/paach", params: nil, match: false}, {url: "/api/v1/2005-11-01/peach", params: 
[]string{"2005-11-01", "peach"}, match: true}, }, }, { pattern: "/api/v1/:param?", testCases: []routeTestCase{ {url: "/api/v1/entity", params: nil, match: false}, {url: "/api/v1/8728382", params: []string{"8728382"}, match: true}, {url: "/api/v1/true", params: nil, match: false}, {url: "/api/v1/", params: []string{""}, match: true}, }, }, // Add test case for RegexCompiler == nil { pattern: "/api/v1/:param", testCases: []routeTestCase{ {url: "/api/v1/123", params: []string{"123"}, match: true}, {url: "/api/v1/abc", params: nil, match: false}, }, }, }..., ) } ================================================ FILE: prefork.go ================================================ package fiber import ( "crypto/tls" "errors" "fmt" "net" "os" "os/exec" "runtime" "sync/atomic" "time" "github.com/valyala/fasthttp/reuseport" "github.com/gofiber/fiber/v3/log" ) const ( envPreforkChildKey = "FIBER_PREFORK_CHILD" envPreforkChildVal = "1" sleepDuration = 100 * time.Millisecond windowsOS = "windows" ) var ( testPreforkMaster = false testOnPrefork = false ) // IsChild determines if the current process is a child of Prefork func IsChild() bool { return os.Getenv(envPreforkChildKey) == envPreforkChildVal } // prefork manages child processes to make use of the OS REUSEPORT or REUSEADDR feature func (app *App) prefork(addr string, tlsConfig *tls.Config, cfg *ListenConfig) error { if cfg == nil { cfg = &ListenConfig{} } var ln net.Listener var err error // 👶 child process 👶 if IsChild() { // use 1 cpu core per child process runtime.GOMAXPROCS(1) // Linux will use SO_REUSEPORT and Windows falls back to SO_REUSEADDR // Only tcp4 or tcp6 is supported when preforking, both are not supported if ln, err = reuseport.Listen(cfg.ListenerNetwork, addr); err != nil { if !cfg.DisableStartupMessage { time.Sleep(sleepDuration) // avoid colliding with startup message } return fmt.Errorf("prefork: %w", err) } // wrap a tls config around the listener if provided if tlsConfig != nil { ln = 
tls.NewListener(ln, tlsConfig) } // kill current child proc when master exits masterPID := os.Getppid() go watchMaster(masterPID) // prepare the server for the start app.startupProcess() if cfg.ListenerAddrFunc != nil { cfg.ListenerAddrFunc(ln.Addr()) } // listen for incoming connections return app.server.Serve(ln) } // 👮 master process 👮 type child struct { err error pid int } // create variables maxProcs := runtime.GOMAXPROCS(0) children := make(map[int]*exec.Cmd) channel := make(chan child, maxProcs) // kill child procs when master exits defer func() { for _, proc := range children { if err = proc.Process.Kill(); err != nil { if !errors.Is(err, os.ErrProcessDone) { log.Errorf("prefork: failed to kill child: %v", err) } } } }() // collect child pids var childPIDs []int // launch child procs for range maxProcs { cmd := exec.Command(os.Args[0], os.Args[1:]...) //nolint:gosec // It's fine to launch the same process again if testPreforkMaster { // When test prefork master, // just start the child process with a dummy cmd, // which will exit soon cmd = dummyCmd() } cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr // add fiber prefork child flag into child proc env cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", envPreforkChildKey, envPreforkChildVal), ) if err = cmd.Start(); err != nil { return fmt.Errorf("failed to start a child prefork process, error: %w", err) } // store child process pid := cmd.Process.Pid children[pid] = cmd childPIDs = append(childPIDs, pid) // execute fork hook if app.hooks != nil { if testOnPrefork { app.hooks.executeOnForkHooks(dummyPid) } else { app.hooks.executeOnForkHooks(pid) } } // notify master if child crashes go func() { channel <- child{pid: pid, err: cmd.Wait()} }() } // Run onListen hooks // Hooks have to be run here as different as non-prefork mode due to they should run as child or master listenData := app.prepareListenData(addr, tlsConfig != nil, cfg, childPIDs) app.runOnListenHooks(listenData) app.startupMessage(listenData, 
cfg) if cfg.EnablePrintRoutes { app.printRoutesMessage() } // return error if child crashes return (<-channel).err } // watchMaster watches the master process and exits if it dies. // It detects master death by checking if the parent PID has changed, // which happens when the master exits and the child is reparented to // another process (often init/PID 1, but could be a subreaper). func watchMaster(masterPID int) { if runtime.GOOS == windowsOS { // finds parent process, // and waits for it to exit p, err := os.FindProcess(masterPID) if err == nil { _, _ = p.Wait() //nolint:errcheck // It is fine to ignore the error here } os.Exit(1) //nolint:revive // Calling os.Exit is fine here in the prefork } // Watch for parent PID changes. When the master exits, the OS // reparents the child to another process, causing Getppid() to change. // Comparing against the original PID instead of hardcoding 1 ensures // this works correctly when the master itself is PID 1 (e.g. in // Docker containers). const watchInterval = 500 * time.Millisecond for range time.NewTicker(watchInterval).C { if os.Getppid() != masterPID { os.Exit(1) //nolint:revive // Calling os.Exit is fine here in the prefork } } } var ( dummyPid = 1 dummyChildCmd atomic.Value ) // dummyCmd is for internal prefork testing func dummyCmd() *exec.Cmd { command := "go" if storeCommand := dummyChildCmd.Load(); storeCommand != nil && storeCommand != "" { command = storeCommand.(string) //nolint:forcetypeassert,errcheck // We always store a string in here } if runtime.GOOS == windowsOS { return exec.Command("cmd", "/C", command, "version") } return exec.Command(command, "version") } ================================================ FILE: prefork_test.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 📄 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io // 💖 Maintained and modified for Fiber by 
@renewerner87 package fiber import ( "crypto/tls" "io" "os" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func Test_App_Prefork_Child_Process(t *testing.T) { // Reset test var testPreforkMaster = true setupIsChild(t) app := New() cfg := listenConfigDefault() err := app.prefork("invalid", nil, &cfg) require.Error(t, err) go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() ipv6Cfg := ListenConfig{ListenerNetwork: NetworkTCP6} require.NoError(t, app.prefork("[::1]:", nil, &ipv6Cfg)) // Create tls certificate cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") if err != nil { require.NoError(t, err) } //nolint:gosec // We're in a test so using old ciphers is fine config := &tls.Config{Certificates: []tls.Certificate{cer}} go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() cfg = listenConfigDefault() require.NoError(t, app.prefork("127.0.0.1:", config, &cfg)) } func Test_App_Prefork_Master_Process(t *testing.T) { // Reset test var testPreforkMaster = true app := New() go func() { time.Sleep(1000 * time.Millisecond) assert.NoError(t, app.Shutdown()) }() cfg := listenConfigDefault() require.NoError(t, app.prefork(":0", nil, &cfg)) dummyChildCmd.Store("invalid") cfg = listenConfigDefault() err := app.prefork("127.0.0.1:", nil, &cfg) require.Error(t, err) dummyChildCmd.Store("go") } func Test_App_Prefork_Child_Process_Never_Show_Startup_Message(t *testing.T) { setupIsChild(t) rescueStdout := os.Stdout defer func() { os.Stdout = rescueStdout }() r, w, err := os.Pipe() require.NoError(t, err) os.Stdout = w cfg := listenConfigDefault() app := New() app.startupProcess() listenData := app.prepareListenData(":0", false, &cfg, nil) app.startupMessage(listenData, &cfg) require.NoError(t, w.Close()) out, err := io.ReadAll(r) require.NoError(t, err) require.Empty(t, out) } func setupIsChild(t *testing.T) { t.Helper() 
t.Setenv(envPreforkChildKey, envPreforkChildVal) } ================================================ FILE: readonly.go ================================================ //go:build !s390x && !ppc64 && !ppc64le package fiber import ( "unsafe" ) //go:linkname runtimeRodata runtime.rodata var runtimeRodata byte //go:linkname runtimeErodata runtime.erodata var runtimeErodata byte func isReadOnly(p unsafe.Pointer) bool { start := uintptr(unsafe.Pointer(&runtimeRodata)) //nolint:gosec // converting runtime symbols end := uintptr(unsafe.Pointer(&runtimeErodata)) //nolint:gosec // converting runtime symbols addr := uintptr(p) return addr >= start && addr < end } ================================================ FILE: readonly_strict.go ================================================ //go:build s390x || ppc64 || ppc64le package fiber import "unsafe" func isReadOnly(_ unsafe.Pointer) bool { return false } ================================================ FILE: redirect.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 📝 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io package fiber import ( "bytes" "encoding/hex" "sync" "github.com/gofiber/utils/v2" utilsbytes "github.com/gofiber/utils/v2/bytes" "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" "github.com/gofiber/fiber/v3/binder" ) // Pool for redirection var ( redirectPool = sync.Pool{ New: func() any { return &Redirect{ status: StatusSeeOther, messages: make(redirectionMsgs, 0), } }, } oldInputPool = sync.Pool{ New: func() any { return make(map[string]string) }, } ) const maxPoolableMapSize = 64 // FlashCookieName Cookie name to send flash messages when to use redirection. const ( FlashCookieName = "fiber_flash" ) var flashCookieNeedle = []byte(FlashCookieName + "=") // hasFlashCookie is on the request hot path and runs on every request/response cycle. 
// Keep this cheap for users who don't use flash messages: // 1) a fast raw-header prefilter to avoid unnecessary cookie parsing, // 2) an exact cookie lookup to avoid prefix false positives (e.g. fiber_flashX). func hasFlashCookie(header *fasthttp.RequestHeader) bool { rawHeaders := header.RawHeaders() if len(rawHeaders) == 0 { return false } if !bytes.Contains(rawHeaders, flashCookieNeedle) { return false } return header.Cookie(FlashCookieName) != nil } // redirectionMsgs is a struct that used to store flash messages and old input data in cookie using MSGP. // msgp -file="redirect.go" -o="redirect_msgp.go" -unexported // Cookie payloads are limited to ~4KB, so keep flash message counts bounded but usable. // //msgp:limit arrays:256 maps:32 marshal:true //msgp:ignore Redirect RedirectConfig OldInputData FlashMessage type redirectionMsg struct { key string value string level uint8 isOldInput bool } type redirectionMsgs []redirectionMsg // OldInputData is a struct that holds the old input data. type OldInputData struct { Key string Value string } // FlashMessage is a struct that holds the flash message data. type FlashMessage struct { Key string Value string Level uint8 } // Redirect is a struct that holds the redirect data. type Redirect struct { c *DefaultCtx // Embed ctx messages redirectionMsgs // Flash messages and old input data status int // Status code of redirection. Default: 303 StatusSeeOther } // RedirectConfig A config to use with Redirect().Route() // You can specify queries or route parameters. // NOTE: We don't use net/url to parse parameters because of it has poor performance. You have to pass map. 
type RedirectConfig struct { Params Map // Route parameters Queries map[string]string // Query map } // AcquireRedirect return default Redirect reference from the redirect pool func AcquireRedirect() *Redirect { redirect, ok := redirectPool.Get().(*Redirect) if !ok { panic(errRedirectTypeAssertion) } return redirect } // ReleaseRedirect returns c acquired via Redirect to redirect pool. // // It is forbidden accessing req and/or its members after returning // it to redirect pool. func ReleaseRedirect(r *Redirect) { r.release() redirectPool.Put(r) } func (r *Redirect) release() { r.status = StatusSeeOther r.messages = r.messages[:0] r.c = nil } func acquireOldInput() map[string]string { oldInput, ok := oldInputPool.Get().(map[string]string) if !ok { return make(map[string]string) } return oldInput } func releaseOldInput(oldInput map[string]string) { if len(oldInput) > maxPoolableMapSize { return } clear(oldInput) oldInputPool.Put(oldInput) } // Status sets the status code of redirection. // If status is not specified, status defaults to 303 See Other. func (r *Redirect) Status(code int) *Redirect { r.status = code return r } // With You can send flash messages by using With(). // They will be sent as a cookie. // You can get them by using: Redirect().Messages(), Redirect().Message() // Note: You must use escape char before using ',' and ':' chars to avoid wrong parsing. func (r *Redirect) With(key, value string, level ...uint8) *Redirect { // Get level var msgLevel uint8 if len(level) > 0 { msgLevel = level[0] } // Override old message if exists for i, msg := range r.messages { if msg.key == key && !msg.isOldInput { r.messages[i].value = value r.messages[i].level = msgLevel return r } } r.messages = append(r.messages, redirectionMsg{ key: key, value: value, level: msgLevel, }) return r } // WithInput You can send input data by using WithInput(). // They will be sent as a cookie. // This method can send form, multipart form, query data to redirected route. 
// You can get them by using: Redirect().OldInputs(), Redirect().OldInput() func (r *Redirect) WithInput() *Redirect { // Get content-type ctype := utils.UnsafeString(utilsbytes.UnsafeToLower(r.c.RequestCtx().Request.Header.ContentType())) ctype = binder.FilterFlags(utils.ParseVendorSpecificContentType(ctype)) oldInput := acquireOldInput() defer releaseOldInput(oldInput) switch ctype { case MIMEApplicationForm, MIMEMultipartForm: _ = r.c.Bind().Form(oldInput) //nolint:errcheck // not needed default: _ = r.c.Bind().Query(oldInput) //nolint:errcheck // not needed } // Add old input data for k, v := range oldInput { r.messages = append(r.messages, redirectionMsg{ key: k, value: v, isOldInput: true, }) } return r } // Messages Get flash messages. func (r *Redirect) Messages() []FlashMessage { if len(r.c.flashMessages) == 0 { return nil } flashMessages := make([]FlashMessage, 0, len(r.c.flashMessages)) writeIdx := 0 for _, msg := range r.c.flashMessages { if msg.isOldInput { r.c.flashMessages[writeIdx] = msg writeIdx++ continue } flashMessages = append(flashMessages, FlashMessage{ Key: msg.key, Value: msg.value, Level: msg.level, }) } for i := writeIdx; i < len(r.c.flashMessages); i++ { r.c.flashMessages[i] = redirectionMsg{} } r.c.flashMessages = r.c.flashMessages[:writeIdx] return flashMessages } // Message Get flash message by key. func (r *Redirect) Message(key string) FlashMessage { if len(r.c.flashMessages) == 0 { return FlashMessage{} } var flashMessage FlashMessage found := false writeIdx := 0 for _, msg := range r.c.flashMessages { if msg.isOldInput || found || msg.key != key { r.c.flashMessages[writeIdx] = msg writeIdx++ continue } flashMessage = FlashMessage{ Key: msg.key, Value: msg.value, Level: msg.level, } found = true } for i := writeIdx; i < len(r.c.flashMessages); i++ { r.c.flashMessages[i] = redirectionMsg{} } r.c.flashMessages = r.c.flashMessages[:writeIdx] return flashMessage } // OldInputs Get old input data. 
func (r *Redirect) OldInputs() []OldInputData { // Count old inputs first to avoid allocation if none exist count := 0 for _, msg := range r.c.flashMessages { if msg.isOldInput { count++ } } if count == 0 { return nil } inputs := make([]OldInputData, 0, count) for _, msg := range r.c.flashMessages { if msg.isOldInput { inputs = append(inputs, OldInputData{ Key: msg.key, Value: msg.value, }) } } return inputs } // OldInput Get old input data by key. func (r *Redirect) OldInput(key string) OldInputData { msgs := r.c.flashMessages for _, msg := range msgs { if msg.key == key && msg.isOldInput { return OldInputData{ Key: msg.key, Value: msg.value, } } } return OldInputData{} } // To redirect to the URL derived from the specified path, with specified status. func (r *Redirect) To(location string) error { r.c.setCanonical(HeaderLocation, location) r.c.Status(r.status) r.processFlashMessages() return nil } // Route redirects to the Route registered in the app with appropriate parameters. // If you want to send queries or params to route, you should use config parameter. func (r *Redirect) Route(name string, config ...RedirectConfig) error { // Check config cfg := RedirectConfig{} if len(config) > 0 { cfg = config[0] } // Get location from route name route := r.c.App().GetRoute(name) location, err := r.c.getLocationFromRoute(&route, cfg.Params) if err != nil { return err } // Check queries if len(cfg.Queries) > 0 { queryText := bytebufferpool.Get() defer bytebufferpool.Put(queryText) first := true for k, v := range cfg.Queries { if !first { queryText.WriteByte('&') } first = false queryText.WriteString(k) queryText.WriteByte('=') queryText.WriteString(v) } return r.To(location + "?" + r.c.app.toString(queryText.Bytes())) } return r.To(location) } // Back redirect to the URL to referer. 
func (r *Redirect) Back(fallback ...string) error { location := r.c.Get(HeaderReferer) if location == "" { // Check fallback URL if len(fallback) == 0 { err := ErrRedirectBackNoFallback r.c.Status(err.Code) return err } location = fallback[0] } return r.To(location) } // parseAndClearFlashMessages is a method to get flash messages before they are getting removed func (r *Redirect) parseAndClearFlashMessages() { // parse flash messages cookieValue, err := hex.DecodeString(r.c.Cookies(FlashCookieName)) if err != nil { return } _, err = r.c.flashMessages.UnmarshalMsg(cookieValue) if err != nil { return } r.c.Cookie(&Cookie{ Name: FlashCookieName, Value: "", Path: "/", MaxAge: -1, }) } // processFlashMessages is a helper function to process flash messages and old input data // and set them as cookies func (r *Redirect) processFlashMessages() { if len(r.messages) == 0 { return } val, err := r.messages.MarshalMsg(nil) if err != nil { return } dst := make([]byte, hex.EncodedLen(len(val))) hex.Encode(dst, val) r.c.Cookie(&Cookie{ Name: FlashCookieName, Value: r.c.app.toString(dst), SessionOnly: true, }) } ================================================ FILE: redirect_msgp.go ================================================ // Code generated by github.com/tinylib/msgp DO NOT EDIT. 
package fiber

import (
	"github.com/tinylib/msgp/msgp"
)

// Size limits for msgp deserialization
const (
	zc920acdalimitArrays = 256
	zc920acdalimitMaps   = 32
)

// DecodeMsg implements msgp.Decodable
func (z *redirectionMsg) DecodeMsg(dc *msgp.Reader) (err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, err = dc.ReadMapHeader()
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	if zb0001 > zc920acdalimitMaps {
		err = msgp.ErrLimitExceeded
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, err = dc.ReadMapKeyPtr()
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		case "key":
			z.key, err = dc.ReadString()
			if err != nil {
				err = msgp.WrapError(err, "key")
				return
			}
		case "value":
			z.value, err = dc.ReadString()
			if err != nil {
				err = msgp.WrapError(err, "value")
				return
			}
		case "level":
			z.level, err = dc.ReadUint8()
			if err != nil {
				err = msgp.WrapError(err, "level")
				return
			}
		case "isOldInput":
			z.isOldInput, err = dc.ReadBool()
			if err != nil {
				err = msgp.WrapError(err, "isOldInput")
				return
			}
		default:
			err = dc.Skip()
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	return
}

// EncodeMsg implements msgp.Encodable
func (z *redirectionMsg) EncodeMsg(en *msgp.Writer) (err error) {
	// map header, size 4
	// write "key"
	err = en.Append(0x84, 0xa3, 0x6b, 0x65, 0x79)
	if err != nil {
		return
	}
	err = en.WriteString(z.key)
	if err != nil {
		err = msgp.WrapError(err, "key")
		return
	}
	// write "value"
	err = en.Append(0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
	if err != nil {
		return
	}
	err = en.WriteString(z.value)
	if err != nil {
		err = msgp.WrapError(err, "value")
		return
	}
	// write "level"
	err = en.Append(0xa5, 0x6c, 0x65, 0x76, 0x65, 0x6c)
	if err != nil {
		return
	}
	err = en.WriteUint8(z.level)
	if err != nil {
		err = msgp.WrapError(err, "level")
		return
	}
	// write "isOldInput"
	err = en.Append(0xaa, 0x69, 0x73, 0x4f, 0x6c, 0x64, 0x49, 0x6e, 0x70, 0x75, 0x74)
	if err != nil {
		return
	}
	err = en.WriteBool(z.isOldInput)
	if err != nil {
		err = msgp.WrapError(err, "isOldInput")
		return
	}
	return
}

// MarshalMsg implements msgp.Marshaler
func (z *redirectionMsg) MarshalMsg(b []byte) (o []byte, err error) {
	o = msgp.Require(b, z.Msgsize())
	// map header, size 4
	// string "key"
	o = append(o, 0x84, 0xa3, 0x6b, 0x65, 0x79)
	o = msgp.AppendString(o, z.key)
	// string "value"
	o = append(o, 0xa5, 0x76, 0x61, 0x6c, 0x75, 0x65)
	o = msgp.AppendString(o, z.value)
	// string "level"
	o = append(o, 0xa5, 0x6c, 0x65, 0x76, 0x65, 0x6c)
	o = msgp.AppendUint8(o, z.level)
	// string "isOldInput"
	o = append(o, 0xaa, 0x69, 0x73, 0x4f, 0x6c, 0x64, 0x49, 0x6e, 0x70, 0x75, 0x74)
	o = msgp.AppendBool(o, z.isOldInput)
	return
}

// UnmarshalMsg implements msgp.Unmarshaler
func (z *redirectionMsg) UnmarshalMsg(bts []byte) (o []byte, err error) {
	var field []byte
	_ = field
	var zb0001 uint32
	zb0001, bts, err = msgp.ReadMapHeaderBytes(bts)
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	if zb0001 > zc920acdalimitMaps {
		err = msgp.ErrLimitExceeded
		return
	}
	for zb0001 > 0 {
		zb0001--
		field, bts, err = msgp.ReadMapKeyZC(bts)
		if err != nil {
			err = msgp.WrapError(err)
			return
		}
		switch msgp.UnsafeString(field) {
		case "key":
			z.key, bts, err = msgp.ReadStringBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "key")
				return
			}
		case "value":
			z.value, bts, err = msgp.ReadStringBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "value")
				return
			}
		case "level":
			z.level, bts, err = msgp.ReadUint8Bytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "level")
				return
			}
		case "isOldInput":
			z.isOldInput, bts, err = msgp.ReadBoolBytes(bts)
			if err != nil {
				err = msgp.WrapError(err, "isOldInput")
				return
			}
		default:
			bts, err = msgp.Skip(bts)
			if err != nil {
				err = msgp.WrapError(err)
				return
			}
		}
	}
	o = bts
	return
}

// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z *redirectionMsg) Msgsize() (s int) {
	s = 1 + 4 + msgp.StringPrefixSize + len(z.key) + 6 + msgp.StringPrefixSize + len(z.value) + 6 + msgp.Uint8Size + 11 + msgp.BoolSize
	return
}

// DecodeMsg implements msgp.Decodable
func (z *redirectionMsgs) DecodeMsg(dc *msgp.Reader) (err error) {
	var zb0002 uint32
	zb0002, err = dc.ReadArrayHeader()
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	if zb0002 > zc920acdalimitArrays {
		err = msgp.ErrLimitExceeded
		return
	}
	if cap((*z)) >= int(zb0002) {
		(*z) = (*z)[:zb0002]
	} else {
		(*z) = make(redirectionMsgs, zb0002)
	}
	for zb0001 := range *z {
		err = (*z)[zb0001].DecodeMsg(dc)
		if err != nil {
			err = msgp.WrapError(err, zb0001)
			return
		}
	}
	return
}

// EncodeMsg implements msgp.Encodable
func (z redirectionMsgs) EncodeMsg(en *msgp.Writer) (err error) {
	err = en.WriteArrayHeader(uint32(len(z)))
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	if uint32(len(z)) > zc920acdalimitArrays {
		err = msgp.ErrLimitExceeded
		return
	}
	for zb0003 := range z {
		err = z[zb0003].EncodeMsg(en)
		if err != nil {
			err = msgp.WrapError(err, zb0003)
			return
		}
	}
	return
}

// MarshalMsg implements msgp.Marshaler
func (z redirectionMsgs) MarshalMsg(b []byte) (o []byte, err error) {
	o = msgp.Require(b, z.Msgsize())
	o = msgp.AppendArrayHeader(o, uint32(len(z)))
	if uint32(len(z)) > zc920acdalimitArrays {
		return nil, msgp.ErrLimitExceeded
	}
	for zb0003 := range z {
		o, err = z[zb0003].MarshalMsg(o)
		if err != nil {
			err = msgp.WrapError(err, zb0003)
			return
		}
	}
	return
}

// UnmarshalMsg implements msgp.Unmarshaler
func (z *redirectionMsgs) UnmarshalMsg(bts []byte) (o []byte, err error) {
	var zb0002 uint32
	zb0002, bts, err = msgp.ReadArrayHeaderBytes(bts)
	if err != nil {
		err = msgp.WrapError(err)
		return
	}
	if zb0002 > zc920acdalimitArrays {
		err = msgp.ErrLimitExceeded
		return
	}
	if cap((*z)) >= int(zb0002) {
		(*z) = (*z)[:zb0002]
	} else {
		(*z) = make(redirectionMsgs, zb0002)
	}
	for zb0001 := range *z {
		bts, err = (*z)[zb0001].UnmarshalMsg(bts)
		if err != nil {
			err = msgp.WrapError(err, zb0001)
			return
		}
	}
	o = bts
	return
}

// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
func (z redirectionMsgs) Msgsize() (s int) {
	s = msgp.ArrayHeaderSize
	for zb0003 := range z {
		s += z[zb0003].Msgsize()
	}
	return
}

================================================
FILE: redirect_msgp_test.go
================================================
// Code generated by github.com/tinylib/msgp DO NOT EDIT.

package fiber

import (
	"bytes"
	"testing"

	"github.com/tinylib/msgp/msgp"
)

func TestMarshalUnmarshalredirectionMsg(t *testing.T) {
	v := redirectionMsg{}
	bts, err := v.MarshalMsg(nil)
	if err != nil {
		t.Fatal(err)
	}
	left, err := v.UnmarshalMsg(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
	}

	left, err = msgp.Skip(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
	}
}

func BenchmarkMarshalMsgredirectionMsg(b *testing.B) {
	v := redirectionMsg{}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.MarshalMsg(nil)
	}
}

func BenchmarkAppendMsgredirectionMsg(b *testing.B) {
	v := redirectionMsg{}
	bts := make([]byte, 0, v.Msgsize())
	bts, _ = v.MarshalMsg(bts[0:0])
	b.SetBytes(int64(len(bts)))
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		bts, _ = v.MarshalMsg(bts[0:0])
	}
}

func BenchmarkUnmarshalredirectionMsg(b *testing.B) {
	v := redirectionMsg{}
	bts, _ := v.MarshalMsg(nil)
	b.ReportAllocs()
	b.SetBytes(int64(len(bts)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := v.UnmarshalMsg(bts)
		if err != nil {
			b.Fatal(err)
		}
	}
}

func TestEncodeDecoderedirectionMsg(t *testing.T) {
	v := redirectionMsg{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)

	m := v.Msgsize()
	if buf.Len() > m {
		t.Log("WARNING: TestEncodeDecoderedirectionMsg Msgsize() is inaccurate")
	}

	vn := redirectionMsg{}
	err := msgp.Decode(&buf, &vn)
	if err != nil {
		t.Error(err)
	}

	buf.Reset()
	msgp.Encode(&buf, &v)
	err = msgp.NewReader(&buf).Skip()
	if err != nil {
		t.Error(err)
	}
}

func BenchmarkEncoderedirectionMsg(b *testing.B) {
	v := redirectionMsg{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	en := msgp.NewWriter(msgp.Nowhere)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.EncodeMsg(en)
	}
	en.Flush()
}

func BenchmarkDecoderedirectionMsg(b *testing.B) {
	v := redirectionMsg{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	rd := msgp.NewEndlessReader(buf.Bytes(), b)
	dc := msgp.NewReader(rd)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err := v.DecodeMsg(dc)
		if err != nil {
			b.Fatal(err)
		}
	}
}

func TestMarshalUnmarshalredirectionMsgs(t *testing.T) {
	v := redirectionMsgs{}
	bts, err := v.MarshalMsg(nil)
	if err != nil {
		t.Fatal(err)
	}
	left, err := v.UnmarshalMsg(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after UnmarshalMsg(): %q", len(left), left)
	}

	left, err = msgp.Skip(bts)
	if err != nil {
		t.Fatal(err)
	}
	if len(left) > 0 {
		t.Errorf("%d bytes left over after Skip(): %q", len(left), left)
	}
}

func BenchmarkMarshalMsgredirectionMsgs(b *testing.B) {
	v := redirectionMsgs{}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.MarshalMsg(nil)
	}
}

func BenchmarkAppendMsgredirectionMsgs(b *testing.B) {
	v := redirectionMsgs{}
	bts := make([]byte, 0, v.Msgsize())
	bts, _ = v.MarshalMsg(bts[0:0])
	b.SetBytes(int64(len(bts)))
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		bts, _ = v.MarshalMsg(bts[0:0])
	}
}

func BenchmarkUnmarshalredirectionMsgs(b *testing.B) {
	v := redirectionMsgs{}
	bts, _ := v.MarshalMsg(nil)
	b.ReportAllocs()
	b.SetBytes(int64(len(bts)))
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, err := v.UnmarshalMsg(bts)
		if err != nil {
			b.Fatal(err)
		}
	}
}

func TestEncodeDecoderedirectionMsgs(t *testing.T) {
	v := redirectionMsgs{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)

	m := v.Msgsize()
	if buf.Len() > m {
		t.Log("WARNING: TestEncodeDecoderedirectionMsgs Msgsize() is inaccurate")
	}

	vn := redirectionMsgs{}
	err := msgp.Decode(&buf, &vn)
	if err != nil {
		t.Error(err)
	}

	buf.Reset()
	msgp.Encode(&buf, &v)
	err = msgp.NewReader(&buf).Skip()
	if err != nil {
		t.Error(err)
	}
}

func BenchmarkEncoderedirectionMsgs(b *testing.B) {
	v := redirectionMsgs{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	en := msgp.NewWriter(msgp.Nowhere)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		v.EncodeMsg(en)
	}
	en.Flush()
}

func BenchmarkDecoderedirectionMsgs(b *testing.B) {
	v := redirectionMsgs{}
	var buf bytes.Buffer
	msgp.Encode(&buf, &v)
	b.SetBytes(int64(buf.Len()))
	rd := msgp.NewEndlessReader(buf.Bytes(), b)
	dc := msgp.NewReader(rd)
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err := v.DecodeMsg(dc)
		if err != nil {
			b.Fatal(err)
		}
	}
}

================================================
FILE: redirect_test.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 📝 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io

package fiber

import (
	"bytes"
	"encoding/hex"
	"encoding/json"
	"io"
	"mime/multipart"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"github.com/stretchr/testify/require"
	"github.com/valyala/fasthttp"
)

// assertFlashCookieCleared fails the test unless setCookie (a Set-Cookie
// header value, compared case-insensitively) deletes the flash cookie: it must
// name FlashCookieName, carry Max-Age=0 or Max-Age=-1, and set Path=/.
func assertFlashCookieCleared(t *testing.T, setCookie string) {
	t.Helper()
	setCookie = strings.ToLower(setCookie)
	require.Contains(t, setCookie, FlashCookieName+"=")
	require.True(t, strings.Contains(setCookie, "max-age=0") || strings.Contains(setCookie, "max-age=-1"))
	require.Contains(t, setCookie, "path=/")
}

// go test -run Test_Redirect_To
// Checks the default 303 status and Location header, then an explicit Status(301).
func Test_Redirect_To(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	defer app.ReleaseCtx(c)

	err := c.Redirect().To("http://default.com")
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())
	require.Equal(t, "http://default.com", string(c.Response().Header.Peek(HeaderLocation)))

	err = c.Redirect().Status(301).To("http://example.com")
	require.NoError(t, err)
	require.Equal(t, 301, c.Response().StatusCode())
	require.Equal(t, "http://example.com", string(c.Response().Header.Peek(HeaderLocation)))
}

// go test -run Test_Redirect_To_WithFlashMessages
// Checks that With() overwrites a repeated key ("success": "2" -> "1") and
// that levels survive the hex + msgp cookie round trip.
func Test_Redirect_To_WithFlashMessages(t *testing.T) {
	t.Parallel()
	app := New()
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	err := c.Redirect().With("success", "2").With("success", "1").With("message", "test", 2).To("http://example.com")
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())
	require.Equal(t, "http://example.com", string(c.Response().Header.Peek(HeaderLocation)))

	// Feed the response Set-Cookie back as the request Cookie header so that
	// c.Cookies can read the flash cookie.
	c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing

	var msgs redirectionMsgs
	decoded, err := hex.DecodeString(c.Cookies(FlashCookieName))
	require.NoError(t, err)
	_, err = msgs.UnmarshalMsg(decoded)
	require.NoError(t, err)

	require.Len(t, msgs, 2)
	require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
	require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 2, isOldInput: false})
}

// go test -run Test_Redirect_Route_WithParams
func Test_Redirect_Route_WithParams(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/user/:name", func(c Ctx) error {
		return c.JSON(c.Params("name"))
	}).Name("user")
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	err := c.Redirect().Route("user", RedirectConfig{
		Params: Map{
			"name": "fiber",
		},
	})
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())
	require.Equal(t, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation)))
}

// go test -run Test_Redirect_Route_WithParams_WithQueries
func Test_Redirect_Route_WithParams_WithQueries(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/user/:name", func(c Ctx) error {
		return c.JSON(c.Params("name"))
	}).Name("user")
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	err := c.Redirect().Route("user", RedirectConfig{
		Params: Map{
			"name": "fiber",
		},
		Queries: map[string]string{"data[0][name]": "john", "data[0][age]": "10", "test": "doe"},
	})
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())

	// Parse the Location header instead of comparing strings: Route builds the
	// query from a map, so the parameter order is not deterministic.
	location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation)))
	require.NoError(t, err, "url.Parse(location)")
	require.Equal(t, "/user/fiber", location.Path)
	require.Equal(t, url.Values{"data[0][name]": []string{"john"}, "data[0][age]": []string{"10"}, "test": []string{"doe"}}, location.Query())
}

// go test -run Test_Redirect_Route_WithOptionalParams
func Test_Redirect_Route_WithOptionalParams(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/user/:name?", func(c Ctx) error {
		return c.JSON(c.Params("name"))
	}).Name("user")
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	err := c.Redirect().Route("user", RedirectConfig{
		Params: Map{
			"name": "fiber",
		},
	})
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())
	require.Equal(t, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation)))
}

// go test -run Test_Redirect_Route_WithOptionalParamsWithoutValue
// An omitted optional parameter leaves the trailing slash in the location.
func Test_Redirect_Route_WithOptionalParamsWithoutValue(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/user/:name?", func(c Ctx) error {
		return c.JSON(c.Params("name"))
	}).Name("user")
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	err := c.Redirect().Route("user")
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())
	require.Equal(t, "/user/", string(c.Response().Header.Peek(HeaderLocation)))
}

// go test -run Test_Redirect_Route_WithGreedyParameters
func Test_Redirect_Route_WithGreedyParameters(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/user/+", func(c Ctx) error {
		return c.JSON(c.Params("+"))
	}).Name("user")
	c := app.AcquireCtx(&fasthttp.RequestCtx{})

	err := c.Redirect().Route("user", RedirectConfig{
		Params: Map{
			"+": "test/routes",
		},
	})
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())
	require.Equal(t, "/user/test/routes", string(c.Response().Header.Peek(HeaderLocation)))
}

// go test -run Test_Redirect_Back
// Without a Referer header, Back uses the fallback; with no fallback it
// returns ErrRedirectBackNoFallback and sets status 500.
func Test_Redirect_Back(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/", func(c Ctx) error {
		return c.JSON("Home")
	}).Name("home")
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	err := c.Redirect().Back("/")
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())
	require.Equal(t, "/", string(c.Response().Header.Peek(HeaderLocation)))

	err = c.Redirect().Back()
	require.Equal(t, 500, c.Response().StatusCode())
	// NOTE(review): require.ErrorAs stores the matched error into its target,
	// so this writes to the package-level ErrRedirectBackNoFallback from a
	// parallel test; require.ErrorIs (or a local target) would avoid touching
	// shared state.
	require.ErrorAs(t, err, &ErrRedirectBackNoFallback)
}

// go test -run Test_Redirect_Back_WithFlashMessages
func Test_Redirect_Back_WithFlashMessages(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")
	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	err := c.Redirect().With("success", "1").With("message", "test").Back("/")
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())
	require.Equal(t, "/", string(c.Response().Header.Peek(HeaderLocation)))

	// Feed the response Set-Cookie back as the request Cookie header so that
	// c.Cookies can read the flash cookie.
	c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing

	var msgs redirectionMsgs
	decoded, err := hex.DecodeString(c.Cookies(FlashCookieName))
	require.NoError(t, err)
	_, err = msgs.UnmarshalMsg(decoded)
	require.NoError(t, err)

	require.Len(t, msgs, 2)
	require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
	require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
}

// go test -run Test_Redirect_Back_WithReferer
// A Referer header takes precedence over the fallback URL.
func Test_Redirect_Back_WithReferer(t *testing.T) {
	t.Parallel()
	app := New()
	app.Get("/", func(c Ctx) error {
		return c.JSON("Home")
	}).Name("home")
	app.Get("/back", func(c Ctx) error {
		return c.JSON("Back")
	}).Name("back")
	c := app.AcquireCtx(&fasthttp.RequestCtx{})
	c.Request().Header.Set(HeaderReferer, "/back")
	err := c.Redirect().Back("/")
	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())
	require.Equal(t, "/back", c.Get(HeaderReferer))
	require.Equal(t, "/back", string(c.Response().Header.Peek(HeaderLocation)))
}

// go test -run Test_Redirect_Route_WithFlashMessages
// Checks both the in-memory c.redirect.messages and the encoded cookie.
func Test_Redirect_Route_WithFlashMessages(t *testing.T) {
	t.Parallel()

	app := New()
	app.Get("/user", func(c Ctx) error {
		return c.SendString("user")
	}).Name("user")

	c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

	err := c.Redirect().With("success", "1").With("message", "test").Route("user")

	require.Contains(t, c.redirect.messages, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
	require.Contains(t, c.redirect.messages, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})

	require.NoError(t, err)
	require.Equal(t, StatusSeeOther, c.Response().StatusCode())
	require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation)))

	// Feed the response Set-Cookie back as the request Cookie header so that
	// c.Cookies can read the flash cookie.
	c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing

	var msgs redirectionMsgs
	decoded, err := hex.DecodeString(c.Cookies(FlashCookieName))
	require.NoError(t, err)
	_, err = msgs.UnmarshalMsg(decoded)
	require.NoError(t, err)

	require.Len(t, msgs, 2)
	require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
	require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
}

// go test -run Test_Redirect_Route_WithOldInput
// WithInput must capture query, urlencoded-form and multipart-form input as
// old input entries alongside regular flash messages.
func Test_Redirect_Route_WithOldInput(t *testing.T) {
	t.Parallel()

	t.Run("Query", func(t *testing.T) {
		t.Parallel()

		app := New()
		app.Get("/user", func(c Ctx) error {
			return c.SendString("user")
		}).Name("user")

		c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

		c.Request().URI().SetQueryString("id=1&name=tom")
		err := c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user")

		require.Contains(t, c.redirect.messages, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
		require.Contains(t, c.redirect.messages, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
		require.Contains(t, c.redirect.messages, redirectionMsg{key: "id", value: "1", isOldInput: true})
		require.Contains(t, c.redirect.messages, redirectionMsg{key: "name", value: "tom", isOldInput: true})

		require.NoError(t, err)
		require.Equal(t, StatusSeeOther, c.Response().StatusCode())
		require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation)))

		c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing

		var msgs redirectionMsgs
		decoded, err := hex.DecodeString(c.Cookies(FlashCookieName))
		require.NoError(t, err)
		_, err = msgs.UnmarshalMsg(decoded)
		require.NoError(t, err)

		require.Len(t, msgs, 4)
		require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
		require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
		require.Contains(t, msgs, redirectionMsg{key: "id", value: "1", level: 0, isOldInput: true})
		require.Contains(t, msgs, redirectionMsg{key: "name", value: "tom", level: 0, isOldInput: true})
	})

	t.Run("Form", func(t *testing.T) {
		t.Parallel()

		app := New()
		app.Post("/user", func(c Ctx) error {
			return c.SendString("user")
		}).Name("user")

		c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

		c.Request().Header.Set(HeaderContentType, MIMEApplicationForm)
		c.Request().SetBodyString("id=1&name=tom")
		err := c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user")

		require.Contains(t, c.redirect.messages, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
		require.Contains(t, c.redirect.messages, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
		require.Contains(t, c.redirect.messages, redirectionMsg{key: "id", value: "1", isOldInput: true})
		require.Contains(t, c.redirect.messages, redirectionMsg{key: "name", value: "tom", isOldInput: true})

		require.NoError(t, err)
		require.Equal(t, StatusSeeOther, c.Response().StatusCode())
		require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation)))

		c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing

		var msgs redirectionMsgs
		decoded, err := hex.DecodeString(c.Cookies(FlashCookieName))
		require.NoError(t, err)
		_, err = msgs.UnmarshalMsg(decoded)
		require.NoError(t, err)

		require.Len(t, msgs, 4)
		require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
		require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false})
		require.Contains(t, msgs, redirectionMsg{key: "id", value: "1", level: 0, isOldInput: true})
		require.Contains(t, msgs, redirectionMsg{key: "name", value: "tom", level: 0, isOldInput: true})
	})

	t.Run("MultipartForm", func(t *testing.T) {
		t.Parallel()

		app := New()
		app.Get("/user", func(c Ctx) error {
			return c.SendString("user")
		}).Name("user")

		c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed

		// NOTE(review): multipart.Writer generates a lowercase-hex boundary, so
		// this subtest does not cover mixed-case boundaries such as the
		// "----WebKitFormBoundary..." values browsers send.
		body := &bytes.Buffer{}
		writer := multipart.NewWriter(body)

		require.NoError(t, writer.WriteField("id", "1"))
		require.NoError(t, writer.WriteField("name", "tom"))
		require.NoError(t, writer.Close())

		c.Request().SetBody(body.Bytes())
		c.Request().Header.Set(HeaderContentType, writer.FormDataContentType())

		err := c.Redirect().With("success", "1").With("message", "test").WithInput().Route("user")

		require.Contains(t, c.redirect.messages, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false})
		require.Contains(t, c.redirect.messages,
redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) require.Contains(t, c.redirect.messages, redirectionMsg{key: "id", value: "1", isOldInput: true}) require.Contains(t, c.redirect.messages, redirectionMsg{key: "name", value: "tom", isOldInput: true}) require.NoError(t, err) require.Equal(t, StatusSeeOther, c.Response().StatusCode()) require.Equal(t, "/user", string(c.Response().Header.Peek(HeaderLocation))) c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing var msgs redirectionMsgs decoded, err := hex.DecodeString(c.Cookies(FlashCookieName)) require.NoError(t, err) _, err = msgs.UnmarshalMsg(decoded) require.NoError(t, err) require.Len(t, msgs, 4) require.Contains(t, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) require.Contains(t, msgs, redirectionMsg{key: "id", value: "1", level: 0, isOldInput: true}) require.Contains(t, msgs, redirectionMsg{key: "name", value: "tom", level: 0, isOldInput: true}) }) } func Test_Redirect_WithInput_ReusesClearedMap(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed defer app.ReleaseCtx(c) c.Request().URI().SetQueryString("first=1") c.Redirect().WithInput() require.Contains(t, c.redirect.messages, redirectionMsg{key: "first", value: "1", isOldInput: true}) c.redirect.messages = c.redirect.messages[:0] c.Request().URI().SetQueryString("second=2") c.Redirect().WithInput() require.Len(t, c.redirect.messages, 1) require.Contains(t, c.redirect.messages, redirectionMsg{key: "second", value: "2", isOldInput: true}) require.NotContains(t, c.redirect.messages, redirectionMsg{key: "first", value: "1", isOldInput: true}) } // go test -run Test_Redirect_parseAndClearFlashMessages func Test_Redirect_parseAndClearFlashMessages(t 
*testing.T) { t.Parallel() app := New() app.Get("/user", func(c Ctx) error { return c.SendString("user") }).Name("user") c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed msgs := redirectionMsgs{ { key: "success", value: "1", }, { key: "message", value: "test", }, { key: "name", value: "tom", isOldInput: true, }, { key: "id", value: "1", isOldInput: true, }, } val, err := msgs.MarshalMsg(nil) require.NoError(t, err) c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) c.Redirect().parseAndClearFlashMessages() require.Equal(t, FlashMessage{ Key: "success", Value: "1", Level: 0, }, c.Redirect().Message("success")) require.Equal(t, FlashMessage{ Key: "message", Value: "test", Level: 0, }, c.Redirect().Message("message")) require.Equal(t, FlashMessage{}, c.Redirect().Message("success")) require.Equal(t, FlashMessage{}, c.Redirect().Message("not_message")) require.Empty(t, c.Redirect().Messages()) require.Equal(t, OldInputData{ Key: "id", Value: "1", }, c.Redirect().OldInput("id")) require.Equal(t, OldInputData{ Key: "name", Value: "tom", }, c.Redirect().OldInput("name")) require.Equal(t, OldInputData{}, c.Redirect().OldInput("not_name")) require.Equal(t, []OldInputData{ { Key: "name", Value: "tom", }, { Key: "id", Value: "1", }, }, c.Redirect().OldInputs()) assertFlashCookieCleared(t, string(c.Response().Header.Peek(HeaderSetCookie))) c.Request().Header.Set(HeaderCookie, "fiber_flash=test") c.Redirect().parseAndClearFlashMessages() require.Empty(t, c.Redirect().messages) } // Test_Redirect_parseAndClearFlashMessages_InvalidHex tests the case where hex decoding fails func Test_Redirect_parseAndClearFlashMessages_InvalidHex(t *testing.T) { t.Parallel() app := New() // Setup request and response c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed defer app.ReleaseCtx(c) // Create redirect instance r := AcquireRedirect() r.c = c // Set 
invalid hex value in flash cookie c.Request().Header.SetCookie(FlashCookieName, "not-a-valid-hex-string") // Call parseAndClearFlashMessages r.parseAndClearFlashMessages() // Verify that no flash messages are processed (should be empty) require.Empty(t, r.messages) // Release redirect ReleaseRedirect(r) } func Test_Redirect_Messages_ClearsFlashMessages(t *testing.T) { t.Parallel() app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed defer app.ReleaseCtx(c) val, err := testredirectionMsgs.MarshalMsg(nil) require.NoError(t, err) c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) c.Redirect().parseAndClearFlashMessages() require.Equal(t, []FlashMessage{ { Key: "success", Value: "1", Level: 0, }, { Key: "message", Value: "test", Level: 0, }, }, c.Redirect().Messages()) require.Empty(t, c.Redirect().Messages()) require.Equal(t, FlashMessage{}, c.Redirect().Message("success")) require.Equal(t, []OldInputData{ { Key: "name", Value: "tom", }, { Key: "id", Value: "1", }, }, c.Redirect().OldInputs()) } func Test_Redirect_CompleteFlowWithFlashMessages(t *testing.T) { t.Parallel() app := New() // First handler that sets flash messages and redirects app.Get("/source", func(c Ctx) error { // Redirect to the target handler return c.Redirect().With("string_message", "Hello, World!"). With("number_message", "12345"). With("bool_message", "true"). 
To("/target") }) // Second handler that receives and processes flash messages app.Get("/target", func(c Ctx) error { // Get all flash messages and return them as a JSON response return c.JSON(Map{ "string_message": c.Redirect().Message("string_message").Value, "number_message": c.Redirect().Message("number_message").Value, "bool_message": c.Redirect().Message("bool_message").Value, }) }) // Step 1: Make the initial request to the source route req := httptest.NewRequest(MethodGet, "/source", http.NoBody) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, StatusSeeOther, resp.StatusCode) require.Equal(t, "/target", resp.Header.Get(HeaderLocation)) // Verify and get the cookie from the response cookies := resp.Cookies() var flashCookie *http.Cookie for _, cookie := range cookies { if cookie.Name == "fiber_flash" { flashCookie = cookie break } } require.NotNil(t, flashCookie, "Flash cookie should be set") // Step 2: Make the second request to the target route with the cookie req = httptest.NewRequest(MethodGet, "/target", http.NoBody) req.Header.Set("Cookie", flashCookie.Name+"="+flashCookie.Value) resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) assertFlashCookieCleared(t, resp.Header.Get(HeaderSetCookie)) // Parse the JSON response and verify flash messages body, err := io.ReadAll(resp.Body) require.NoError(t, err) var result map[string]any err = json.Unmarshal(body, &result) require.NoError(t, err) // Verify all flash messages were received correctly require.Equal(t, "Hello, World!", result["string_message"]) require.Equal(t, "12345", result["number_message"]) // JSON numbers are float64 require.Equal(t, "true", result["bool_message"]) } func Test_Redirect_FlashMessagesWithSpecialChars(t *testing.T) { t.Parallel() app := New() // Handler that sets flash messages with special characters and redirects app.Get("/special-source", func(c Ctx) error { // Create a large message to test encoding of larger data 
return c.Redirect().With("null_bytes", "Contains\x00null\x00bytes"). With("control_chars", "Contains\r\ncontrol\tcharacters"). With("unicode", "Unicode: 你好世界"). With("emoji", "Emoji: 🔥🚀😊"). To("/special-target") }) // Target handler that receives the flash messages app.Get("/special-target", func(c Ctx) error { return c.JSON(Map{ "null_bytes": c.Redirect().Message("null_bytes").Value, "control_chars": c.Redirect().Message("control_chars").Value, "unicode": c.Redirect().Message("unicode").Value, "emoji": c.Redirect().Message("emoji").Value, }) }) // Step 1: Make the initial request req := httptest.NewRequest(MethodGet, "/special-source", http.NoBody) resp, err := app.Test(req) require.NoError(t, err) require.Equal(t, StatusSeeOther, resp.StatusCode) require.Equal(t, "/special-target", resp.Header.Get(HeaderLocation)) // Get the flash cookie var flashCookie *http.Cookie for _, cookie := range resp.Cookies() { if cookie.Name == "fiber_flash" { flashCookie = cookie break } } require.NotNil(t, flashCookie, "Flash cookie should be set") // Step 2: Make the second request with the cookie req = httptest.NewRequest(MethodGet, "/special-target", http.NoBody) req.Header.Set("Cookie", flashCookie.Name+"="+flashCookie.Value) resp, err = app.Test(req) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) assertFlashCookieCleared(t, resp.Header.Get(HeaderSetCookie)) // Parse and verify the response body, err := io.ReadAll(resp.Body) require.NoError(t, err) var result map[string]any err = json.Unmarshal(body, &result) require.NoError(t, err) // Verify special character handling require.Equal(t, "Contains\x00null\x00bytes", result["null_bytes"]) require.Equal(t, "Contains\r\ncontrol\tcharacters", result["control_chars"]) require.Equal(t, "Unicode: 你好世界", result["unicode"]) require.Equal(t, "Emoji: 🔥🚀😊", result["emoji"]) } // go test -v -run=^$ -bench=Benchmark_Redirect_Route -benchmem -count=4 func Benchmark_Redirect_Route(b *testing.B) { app := New() 
app.Get("/user/:name", func(c Ctx) error { return c.JSON(c.Params("name")) }).Name("user") c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed b.ReportAllocs() var err error for b.Loop() { err = c.Redirect().Route("user", RedirectConfig{ Params: Map{ "name": "fiber", }, }) } require.NoError(b, err) require.Equal(b, StatusSeeOther, c.Response().StatusCode()) require.Equal(b, "/user/fiber", string(c.Response().Header.Peek(HeaderLocation))) } // go test -v -run=^$ -bench=Benchmark_Redirect_Route_WithQueries -benchmem -count=4 func Benchmark_Redirect_Route_WithQueries(b *testing.B) { app := New() app.Get("/user/:name", func(c Ctx) error { return c.JSON(c.Params("name")) }).Name("user") c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed b.ReportAllocs() var err error for b.Loop() { err = c.Redirect().Route("user", RedirectConfig{ Params: Map{ "name": "fiber", }, Queries: map[string]string{"a": "a", "b": "b"}, }) } require.NoError(b, err) require.Equal(b, StatusSeeOther, c.Response().StatusCode()) // analysis of query parameters with url parsing, since a map pass is always randomly ordered location, err := url.Parse(string(c.Response().Header.Peek(HeaderLocation))) require.NoError(b, err, "url.Parse(location)") require.Equal(b, "/user/fiber", location.Path) require.Equal(b, url.Values{"a": []string{"a"}, "b": []string{"b"}}, location.Query()) } // go test -v -run=^$ -bench=Benchmark_Redirect_Route_WithFlashMessages -benchmem -count=4 func Benchmark_Redirect_Route_WithFlashMessages(b *testing.B) { app := New() app.Get("/user", func(c Ctx) error { return c.SendString("user") }).Name("user") c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed b.ReportAllocs() var err error for b.Loop() { err = c.Redirect().With("success", "1").With("message", "test").Route("user") } require.NoError(b, err) require.Equal(b, 
StatusSeeOther, c.Response().StatusCode()) require.Equal(b, "/user", string(c.Response().Header.Peek(HeaderLocation))) c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing var msgs redirectionMsgs decoded, err := hex.DecodeString(c.Cookies(FlashCookieName)) require.NoError(b, err) _, err = msgs.UnmarshalMsg(decoded) require.NoError(b, err) require.Contains(b, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) require.Contains(b, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) } var testredirectionMsgs = redirectionMsgs{ { key: "success", value: "1", }, { key: "message", value: "test", }, { key: "name", value: "tom", isOldInput: true, }, { key: "id", value: "1", isOldInput: true, }, } // go test -v -run=^$ -bench=Benchmark_Redirect_parseAndClearFlashMessages -benchmem -count=4 func Benchmark_Redirect_parseAndClearFlashMessages(b *testing.B) { app := New() app.Get("/user", func(c Ctx) error { return c.SendString("user") }).Name("user") c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed val, err := testredirectionMsgs.MarshalMsg(nil) require.NoError(b, err) c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) b.ReportAllocs() for b.Loop() { c.Redirect().parseAndClearFlashMessages() } require.Equal(b, FlashMessage{ Key: "success", Value: "1", }, c.Redirect().Message("success")) require.Equal(b, FlashMessage{ Key: "message", Value: "test", }, c.Redirect().Message("message")) require.Equal(b, OldInputData{ Key: "id", Value: "1", }, c.Redirect().OldInput("id")) require.Equal(b, OldInputData{ Key: "name", Value: "tom", }, c.Redirect().OldInput("name")) } // go test -v -run=^$ -bench=Benchmark_Redirect_processFlashMessages -benchmem -count=4 func Benchmark_Redirect_processFlashMessages(b *testing.B) { app := New() app.Get("/user", func(c Ctx) error { return c.SendString("user") 
}).Name("user") c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed c.Redirect().With("success", "1").With("message", "test") b.ReportAllocs() for b.Loop() { c.Redirect().processFlashMessages() } c.RequestCtx().Request.Header.Set(HeaderCookie, c.GetRespHeader(HeaderSetCookie)) // necessary for testing var msgs redirectionMsgs decoded, err := hex.DecodeString(c.Cookies(FlashCookieName)) require.NoError(b, err) _, err = msgs.UnmarshalMsg(decoded) require.NoError(b, err) require.Len(b, msgs, 2) require.Contains(b, msgs, redirectionMsg{key: "success", value: "1", level: 0, isOldInput: false}) require.Contains(b, msgs, redirectionMsg{key: "message", value: "test", level: 0, isOldInput: false}) } // go test -v -run=^$ -bench=Benchmark_Redirect_Messages -benchmem -count=4 func Benchmark_Redirect_Messages(b *testing.B) { app := New() app.Get("/user", func(c Ctx) error { return c.SendString("user") }).Name("user") c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed val, err := testredirectionMsgs.MarshalMsg(nil) require.NoError(b, err) c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) c.Redirect().parseAndClearFlashMessages() var msgs []FlashMessage msgTemplate := testredirectionMsgs b.ReportAllocs() for b.Loop() { c.flashMessages = c.flashMessages[:0] c.flashMessages = append(c.flashMessages, msgTemplate...) 
msgs = c.Redirect().Messages() } require.Contains(b, msgs, FlashMessage{ Key: "success", Value: "1", Level: 0, }) require.Contains(b, msgs, FlashMessage{ Key: "message", Value: "test", Level: 0, }) } // go test -v -run=^$ -bench=Benchmark_Redirect_OldInputs -benchmem -count=4 func Benchmark_Redirect_OldInputs(b *testing.B) { app := New() app.Get("/user", func(c Ctx) error { return c.SendString("user") }).Name("user") c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed val, err := testredirectionMsgs.MarshalMsg(nil) require.NoError(b, err) c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) c.Redirect().parseAndClearFlashMessages() var oldInputs []OldInputData b.ReportAllocs() for b.Loop() { oldInputs = c.Redirect().OldInputs() } require.Contains(b, oldInputs, OldInputData{ Key: "name", Value: "tom", }) require.Contains(b, oldInputs, OldInputData{ Key: "id", Value: "1", }) } // go test -v -run=^$ -bench=Benchmark_Redirect_Message -benchmem -count=4 func Benchmark_Redirect_Message(b *testing.B) { app := New() app.Get("/user", func(c Ctx) error { return c.SendString("user") }).Name("user") c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed val, err := testredirectionMsgs.MarshalMsg(nil) require.NoError(b, err) c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) c.Redirect().parseAndClearFlashMessages() var msg FlashMessage msgTemplate := testredirectionMsgs b.ReportAllocs() for b.Loop() { c.flashMessages = c.flashMessages[:0] c.flashMessages = append(c.flashMessages, msgTemplate...) 
msg = c.Redirect().Message("message") } require.Equal(b, FlashMessage{ Key: "message", Value: "test", Level: 0, }, msg) } // go test -v -run=^$ -bench=Benchmark_Redirect_OldInput -benchmem -count=4 func Benchmark_Redirect_OldInput(b *testing.B) { app := New() app.Get("/user", func(c Ctx) error { return c.SendString("user") }).Name("user") c := app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed val, err := testredirectionMsgs.MarshalMsg(nil) require.NoError(b, err) c.Request().Header.Set(HeaderCookie, "fiber_flash="+hex.EncodeToString(val)) c.Redirect().parseAndClearFlashMessages() var input OldInputData b.ReportAllocs() for b.Loop() { input = c.Redirect().OldInput("name") } require.Equal(b, OldInputData{ Key: "name", Value: "tom", }, input) } ================================================ FILE: register.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 🤖 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io package fiber // Register defines all router handle interface generate by RouteChain(). type Register interface { All(handler any, handlers ...any) Register Get(handler any, handlers ...any) Register Head(handler any, handlers ...any) Register Post(handler any, handlers ...any) Register Put(handler any, handlers ...any) Register Delete(handler any, handlers ...any) Register Connect(handler any, handlers ...any) Register Options(handler any, handlers ...any) Register Trace(handler any, handlers ...any) Register Patch(handler any, handlers ...any) Register Add(methods []string, handler any, handlers ...any) Register RouteChain(path string) Register } var _ Register = (*Registering)(nil) // Registering provides route registration helpers for a specific path on the // application instance. 
type Registering struct { app *App group *Group path string } // All registers a middleware route that will match requests // with the provided path which is stored in register struct. // // app.RouteChain("/").All(func(c fiber.Ctx) error { // return c.Next() // }) // app.RouteChain("/api").All(func(c fiber.Ctx) error { // return c.Next() // }) // app.RouteChain("/api").All(handler, func(c fiber.Ctx) error { // return c.Next() // }) // // This method will match all HTTP verbs: GET, POST, PUT, HEAD etc... func (r *Registering) All(handler any, handlers ...any) Register { converted := collectHandlers("register", append([]any{handler}, handlers...)...) r.app.register([]string{methodUse}, r.path, r.group, converted...) return r } // Get registers a route for GET methods that requests a representation // of the specified resource. Requests using GET should only retrieve data. func (r *Registering) Get(handler any, handlers ...any) Register { return r.Add([]string{MethodGet}, handler, handlers...) } // Head registers a route for HEAD methods that asks for a response identical // to that of a GET request, but without the response body. func (r *Registering) Head(handler any, handlers ...any) Register { return r.Add([]string{MethodHead}, handler, handlers...) } // Post registers a route for POST methods that is used to submit an entity to the // specified resource, often causing a change in state or side effects on the server. func (r *Registering) Post(handler any, handlers ...any) Register { return r.Add([]string{MethodPost}, handler, handlers...) } // Put registers a route for PUT methods that replaces all current representations // of the target resource with the request payload. func (r *Registering) Put(handler any, handlers ...any) Register { return r.Add([]string{MethodPut}, handler, handlers...) } // Delete registers a route for DELETE methods that deletes the specified resource. 
func (r *Registering) Delete(handler any, handlers ...any) Register {
	return r.Add([]string{MethodDelete}, handler, handlers...)
}

// Connect registers a route for CONNECT methods that establishes a tunnel to the
// server identified by the target resource.
func (r *Registering) Connect(handler any, handlers ...any) Register {
	return r.Add([]string{MethodConnect}, handler, handlers...)
}

// Options registers a route for OPTIONS methods that is used to describe the
// communication options for the target resource.
func (r *Registering) Options(handler any, handlers ...any) Register {
	return r.Add([]string{MethodOptions}, handler, handlers...)
}

// Trace registers a route for TRACE methods that performs a message loop-back
// test along the path to the target resource.
func (r *Registering) Trace(handler any, handlers ...any) Register {
	return r.Add([]string{MethodTrace}, handler, handlers...)
}

// Patch registers a route for PATCH methods that is used to apply partial
// modifications to a resource.
func (r *Registering) Patch(handler any, handlers ...any) Register {
	return r.Add([]string{MethodPatch}, handler, handlers...)
}

// Add registers the handlers for each of the given HTTP methods on the stored
// path (and group). The handlers are executed in order, starting with
// `handler` and then the variadic `handlers`. It returns the receiver so calls
// can be chained.
func (r *Registering) Add(methods []string, handler any, handlers ...any) Register {
	// handler is a separate parameter so that at least one is always supplied;
	// put it back in front of the rest before converting them all.
	converted := collectHandlers("register", append([]any{handler}, handlers...)...)
	r.app.register(methods, r.path, r.group, converted...)
	return r
}

// RouteChain returns a new Register instance whose route path takes
// the path in the current instance as its prefix.
func (r *Registering) RouteChain(path string) Register {
	// Build a new Registering that shares this one's app and group; its path is
	// the current path combined with the given suffix via getGroupPath.
	route := &Registering{app: r.app, group: r.group, path: getGroupPath(r.path, path)}

	return route
}

================================================
FILE: req.go
================================================
package fiber

import (
	"bytes"
	"errors"
	"fmt"
	"math"
	"mime/multipart"
	"net"
	"strconv"
	"strings"

	"github.com/gofiber/utils/v2"
	utilsbytes "github.com/gofiber/utils/v2/bytes"
	utilsstrings "github.com/gofiber/utils/v2/strings"
	"github.com/valyala/bytebufferpool"
	"github.com/valyala/fasthttp"
	"golang.org/x/net/idna"
)

// Pre-allocated byte slices for common header comparisons to avoid allocations
var (
	xForwardedPrefix        = []byte("X-Forwarded-")
	xForwardedProtoBytes    = []byte(HeaderXForwardedProto)
	xForwardedProtocolBytes = []byte(HeaderXForwardedProtocol)
	xForwardedSslBytes      = []byte(HeaderXForwardedSsl)
	xURLSchemeBytes         = []byte(HeaderXUrlScheme)
	onBytes                 = []byte("on")
)

// Range represents the parsed HTTP Range header extracted by DefaultReq.Range.
type Range struct {
	Type   string
	Ranges []RangeSet
}

// RangeSet represents a single content range from a request.
type RangeSet struct {
	Start int64
	End   int64
}

// DefaultReq is the default implementation of Req used by DefaultCtx.
//
//go:generate ifacemaker --file req.go --struct DefaultReq --iface Req --pkg fiber --output req_interface_gen.go --not-exported true --iface-comment "Req is an interface for request-related Ctx methods."
type DefaultReq struct {
	c *DefaultCtx
}

// Accepts checks if the specified extensions or content types are acceptable.
// It returns the offer chosen by getOffer, or "" when none is acceptable.
func (r *DefaultReq) Accepts(offers ...string) string {
	// PeekAll yields every Accept header line. joinHeaderValues presumably
	// merges them into a single list — confirm.
	header := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAccept))
	return getOffer(header, acceptsOfferType, offers...)
}

// AcceptsCharsets checks if the specified charset is acceptable.
func (r *DefaultReq) AcceptsCharsets(offers ...string) string { header := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptCharset)) return getOffer(header, acceptsOffer, offers...) } // AcceptsEncodings checks if the specified encoding is acceptable. func (r *DefaultReq) AcceptsEncodings(offers ...string) string { header := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptEncoding)) return getOffer(header, acceptsOffer, offers...) } // AcceptsLanguages checks if the specified language is acceptable using // RFC 4647 Basic Filtering. func (r *DefaultReq) AcceptsLanguages(offers ...string) string { header := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptLanguage)) return getOffer(header, acceptsLanguageOfferBasic, offers...) } // AcceptsLanguagesExtended checks if the specified language is acceptable using // RFC 4647 Extended Filtering. func (r *DefaultReq) AcceptsLanguagesExtended(offers ...string) string { header := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptLanguage)) return getOffer(header, acceptsLanguageOfferExtended, offers...) } // App returns the *App reference to the instance of the Fiber application func (r *DefaultReq) App() *App { return r.c.app } // BaseURL returns (protocol + host + base path). func (r *DefaultReq) BaseURL() string { return r.c.BaseURL() } // BodyRaw contains the raw body submitted in a POST request. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. 
func (r *DefaultReq) BodyRaw() []byte {
	// No Content-Encoding handling here; Body is the decoding variant.
	raw := r.getBody()
	return raw
}

// tryDecodeBodyInOrder undoes the content codings in encodings, starting with
// the last listed coding and walking back to the first.
//
// Every step except the final one writes its decoded payload back into the
// request with SetBodyRaw, so the next decoder reads it. Before the first
// write-back, a private copy of the untouched request body is stored in
// *originalBody so the caller can restore the request afterwards.
//
// It returns the decoded body, the number of steps realized, and an error:
//   - ErrNotImplemented for compress/x-compress,
//   - ErrUnsupportedMediaType for any other unknown coding,
//   - or the decoder's own error.
//
// An unknown or unimplemented coding is not counted in the step count; a step
// whose decoder fails is counted.
//
//nolint:nonamedreturns // gocritic unnamedResult prefers naming decoded body, decode count, and error
func (r *DefaultReq) tryDecodeBodyInOrder(
	originalBody *[]byte,
	encodings []string,
) (body []byte, decodesRealized uint8, err error) {
	req := &r.c.fasthttp.Request
	last := len(encodings) - 1

	for i := last; i >= 0; i-- {
		decodesRealized++

		var stepErr error
		switch encodings[i] {
		case StrGzip, "x-gzip":
			body, stepErr = req.BodyGunzip()
		case StrBr, StrBrotli:
			body, stepErr = req.BodyUnbrotli()
		case StrDeflate:
			body, stepErr = req.BodyInflate()
		case StrZstd:
			body, stepErr = req.BodyUnzstd()
		case StrIdentity:
			body = req.Body()
		case StrCompress, "x-compress":
			return nil, decodesRealized - 1, ErrNotImplemented
		default:
			return nil, decodesRealized - 1, ErrUnsupportedMediaType
		}

		if stepErr != nil {
			return nil, decodesRealized, stepErr
		}

		// The final step (i == 0) is returned to the caller, not fed back.
		if i == 0 || decodesRealized == 0 {
			continue
		}
		if i == last {
			// First write-back: snapshot the untouched body. make+copy always
			// yields a non-nil slice, which the caller's restore check relies on.
			src := req.Body()
			*originalBody = make([]byte, len(src))
			copy(*originalBody, src)
		}
		req.SetBodyRaw(body)
	}

	return body, decodesRealized, nil
}

// Body contains the raw body submitted in a POST request.
// This method will decompress the body if the 'Content-Encoding' header is provided.
// It returns the original (or decompressed) body data which is valid only within the handler.
// Don't store direct references to the returned data.
// If you need to keep the body's data later, make a copy or use the Immutable option.
func (r *DefaultReq) Body() []byte {
	var (
		err                error
		body, originalBody []byte
		headerEncoding     string
		// NOTE(review): presumably a reusable buffer handed to getSplicedStrList
		// so up to three codings avoid an extra allocation — confirm.
		encodingOrder = []string{"", "", ""}
	)

	request := &r.c.fasthttp.Request

	// Get Content-Encoding header (lower-cased for matching).
	// NOTE(review): the Unsafe* helper presumably lower-cases the header bytes
	// in place, so the stored request header is modified as well — confirm.
	headerEncoding = utils.UnsafeString(utilsbytes.UnsafeToLower(request.Header.ContentEncoding()))

	// If no encoding is provided, return the original body
	if headerEncoding == "" {
		return r.getBody()
	}

	// Split and get the encodings list, in order to attend the
	// rule defined at: https://www.rfc-editor.org/rfc/rfc9110#section-8.4-5
	encodingOrder = getSplicedStrList(headerEncoding, encodingOrder)
	for i := range encodingOrder {
		encodingOrder[i] = utilsstrings.UnsafeToLower(encodingOrder[i])
	}
	if len(encodingOrder) == 0 {
		return r.getBody()
	}

	var decodesRealized uint8
	body, decodesRealized, err = r.tryDecodeBodyInOrder(&originalBody, encodingOrder)

	// Ensure that the body will be the original: tryDecodeBodyInOrder may have
	// replaced the request body with intermediate decoded data, and it stores a
	// snapshot in originalBody before doing so.
	if originalBody != nil && decodesRealized > 0 {
		request.SetBodyRaw(originalBody)
	}
	if err != nil {
		// Unknown codings set 415 and compress/x-compress sets 501 on the
		// response; other decode errors leave the status untouched.
		switch {
		case errors.Is(err, ErrUnsupportedMediaType):
			_ = r.c.DefaultRes.SendStatus(StatusUnsupportedMediaType) //nolint:errcheck,staticcheck // It is fine to ignore the error and the static check
		case errors.Is(err, ErrNotImplemented):
			_ = r.c.DefaultRes.SendStatus(StatusNotImplemented) //nolint:errcheck,staticcheck // It is fine to ignore the error and the static check
		default:
			// do nothing
		}
		// On failure the error message bytes are returned in place of a body.
		return []byte(err.Error())
	}

	// NOTE(review): GetBytes presumably copies when Immutable is enabled — confirm.
	return r.c.app.GetBytes(body)
}

// RequestCtx returns *fasthttp.RequestCtx that carries a deadline,
// a cancellation signal, and other values across API boundaries.
func (r *DefaultReq) RequestCtx() *fasthttp.RequestCtx {
	return r.c.fasthttp
}

// FullURL returns the full request URL (protocol + host + original URL).
func (c *DefaultCtx) FullURL() string { buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) buf.WriteString(c.Scheme()) buf.WriteString("://") buf.WriteString(c.Host()) buf.WriteString(c.OriginalURL()) return c.app.toString(buf.Bytes()) } // UserAgent returns the User-Agent request header. func (c *DefaultCtx) UserAgent() string { return c.app.toString(c.fasthttp.Request.Header.UserAgent()) } // Referer returns the Referer request header. func (c *DefaultCtx) Referer() string { return c.app.toString(c.fasthttp.Request.Header.Referer()) } // AcceptLanguage returns the Accept-Language request header. func (c *DefaultCtx) AcceptLanguage() string { return c.app.toString(c.fasthttp.Request.Header.Peek(HeaderAcceptLanguage)) } // AcceptEncoding returns the Accept-Encoding request header. func (c *DefaultCtx) AcceptEncoding() string { return c.app.toString(c.fasthttp.Request.Header.Peek(HeaderAcceptEncoding)) } // HasHeader reports whether the request includes a header with the given key. func (c *DefaultCtx) HasHeader(key string) bool { return len(c.fasthttp.Request.Header.Peek(key)) > 0 } // MediaType returns the MIME type from the Content-Type header without parameters. func (c *DefaultCtx) MediaType() string { contentType := utils.TrimSpace(c.fasthttp.Request.Header.ContentType()) if len(contentType) == 0 { return "" } if idx := bytes.IndexByte(contentType, ';'); idx != -1 { contentType = contentType[:idx] } contentType = utils.TrimSpace(contentType) return c.app.toString(contentType) } // Charset returns the charset parameter from the Content-Type header. 
func (c *DefaultCtx) Charset() string { contentType := c.fasthttp.Request.Header.ContentType() if len(contentType) == 0 { return "" } _, after, ok := bytes.Cut(contentType, []byte{';'}) if !ok { return "" } params := after for len(params) > 0 { params = utils.TrimSpace(params) if len(params) == 0 { return "" } param := params if idx := bytes.IndexByte(params, ';'); idx != -1 { param = params[:idx] params = params[idx+1:] } else { params = nil } param = utils.TrimSpace(param) if len(param) == 0 { continue } before, after, ok := bytes.Cut(param, []byte{'='}) if !ok { continue } name := utils.TrimSpace(before) if !bytes.EqualFold(name, []byte("charset")) { continue } value := utils.TrimSpace(after) if len(value) >= 2 && value[0] == '"' && value[len(value)-1] == '"' { value = value[1 : len(value)-1] } return c.app.toString(value) } return "" } // IsJSON reports whether the Content-Type header is JSON. func (c *DefaultCtx) IsJSON() bool { return utils.EqualFold(c.MediaType(), MIMEApplicationJSON) } // IsForm reports whether the Content-Type header is form-encoded. func (c *DefaultCtx) IsForm() bool { return utils.EqualFold(c.MediaType(), MIMEApplicationForm) } // IsMultipart reports whether the Content-Type header is multipart form data. func (c *DefaultCtx) IsMultipart() bool { return utils.EqualFold(c.MediaType(), MIMEMultipartForm) } // AcceptsJSON reports whether the Accept header allows JSON. func (c *DefaultCtx) AcceptsJSON() bool { return c.Accepts(MIMEApplicationJSON) != "" } // AcceptsHTML reports whether the Accept header allows HTML. func (c *DefaultCtx) AcceptsHTML() bool { return c.Accepts(MIMETextHTML) != "" } // AcceptsXML reports whether the Accept header allows XML. func (c *DefaultCtx) AcceptsXML() bool { return c.Accepts(MIMEApplicationXML, MIMETextXML) != "" } // AcceptsEventStream reports whether the Accept header allows text/event-stream. 
func (c *DefaultCtx) AcceptsEventStream() bool {
	return len(c.Accepts("text/event-stream")) > 0
}

// Cookies returns the value of the request cookie named key.
// When the cookie doesn't exist, the first defaultValue is returned if one
// is given; otherwise the empty string "".
// The returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting to use the value outside the Handler.
func (r *DefaultReq) Cookies(key string, defaultValue ...string) string {
	raw := r.c.fasthttp.Request.Header.Cookie(key)
	return defaultString(r.c.app.toString(raw), defaultValue)
}

// Request returns the underlying *fasthttp.Request, giving access to every
// fasthttp request method.
// https://godoc.org/github.com/valyala/fasthttp#Request
func (r *DefaultReq) Request() *fasthttp.Request {
	fctx := r.c.fasthttp
	return &fctx.Request
}

// FormFile returns the first file for key from the multipart form.
func (r *DefaultReq) FormFile(key string) (*multipart.FileHeader, error) {
	fctx := r.c.fasthttp
	return fctx.FormFile(key)
}

// FormValue returns the first value for key, as resolved by fasthttp's
// FormValue (QueryArgs, PostArgs, MultipartForm and FormFile, in that order).
// When no value exists, the first defaultValue is returned if one is given;
// otherwise the empty string "".
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting instead.
func (r *DefaultReq) FormValue(key string, defaultValue ...string) string {
	raw := r.c.fasthttp.FormValue(key)
	return defaultString(r.c.app.toString(raw), defaultValue)
}

// Fresh returns true when the response is still “fresh” in the client's cache,
// otherwise false is returned to indicate that the client cache is now stale
// and the full response should be sent.
// When a client sends the Cache-Control: no-cache request header to indicate an end-to-end
// reload request, this module will return false to make handling these requests transparent.
// https://github.com/jshttp/fresh/blob/master/index.js#L33
//
// NOTE(review): unlike jshttp/fresh, If-Modified-Since is only consulted when
// If-None-Match is also present and is not exactly "*". A request carrying
// only If-Modified-Since (and no no-cache directive) is reported fresh without
// looking at Last-Modified. RFC 9110 §13.1.3 prescribes the opposite
// precedence (ignore If-Modified-Since when If-None-Match is present).
// Confirm this is intended before relying on it.
func (r *DefaultReq) Fresh() bool {
	header := &r.c.fasthttp.Request.Header

	// fields
	modifiedSince := header.Peek(HeaderIfModifiedSince)
	noneMatch := header.Peek(HeaderIfNoneMatch)

	// unconditional request
	if len(modifiedSince) == 0 && len(noneMatch) == 0 {
		return false
	}

	// Always return stale when Cache-Control: no-cache
	// to support end-to-end reload requests
	// https://www.rfc-editor.org/rfc/rfc9111#section-5.2.1.4
	cacheControl := header.Peek(HeaderCacheControl)
	if len(cacheControl) > 0 && isNoCache(utils.UnsafeString(cacheControl)) {
		return false
	}

	// if-none-match (the branch is skipped when the value is exactly "*")
	if len(noneMatch) > 0 && (len(noneMatch) != 1 || noneMatch[0] != '*') {
		app := r.c.app
		response := &r.c.fasthttp.Response
		// Without a response ETag there is nothing to validate against: stale.
		etag := app.toString(response.Header.Peek(HeaderETag))
		if etag == "" {
			return false
		}

		if app.isEtagStale(etag, noneMatch) {
			return false
		}

		if len(modifiedSince) > 0 {
			lastModified := response.Header.Peek(HeaderLastModified)
			if len(lastModified) > 0 {
				// Unparseable dates on either side are treated as stale.
				lastModifiedTime, err := fasthttp.ParseHTTPDate(lastModified)
				if err != nil {
					return false
				}
				modifiedSinceTime, err := fasthttp.ParseHTTPDate(modifiedSince)
				if err != nil {
					return false
				}
				// Fresh when Last-Modified is not after If-Modified-Since.
				return lastModifiedTime.Compare(modifiedSinceTime) != 1
			}
		}
	}
	return true
}

// Get returns the HTTP request header specified by field.
// Field names are case-insensitive
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting instead.
func (r *DefaultReq) Get(key string, defaultValue ...string) string {
	return GetReqHeader(r.c, key, defaultValue...)
}

// GetReqHeader returns the HTTP request header specified by field.
// This function is generic and can handle different headers type values.
// If the generic type cannot be matched to a supported type, the function // returns the default value (if provided) or the zero value of type V. func GetReqHeader[V GenericType](c Ctx, key string, defaultValue ...V) V { v, err := genericParseType[V](c.App().toString(c.Request().Header.Peek(key))) if err != nil && len(defaultValue) > 0 { return defaultValue[0] } return v } // GetHeaders (a.k.a GetReqHeaders) returns the HTTP request headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (r *DefaultReq) GetHeaders() map[string][]string { app := r.c.app reqHeader := &r.c.fasthttp.Request.Header // Pre-allocate map with known header count to avoid reallocations headers := make(map[string][]string, reqHeader.Len()) for k, v := range reqHeader.All() { key := app.toString(k) headers[key] = append(headers[key], app.toString(v)) } return headers } // Host contains the host derived from the X-Forwarded-Host or Host HTTP header. // Returned value is only valid within the handler. Do not store any references. // In a network context, `Host` refers to the combination of a hostname and potentially a port number used for connecting, // while `Hostname` refers specifically to the name assigned to a device on a network, excluding any port information. // Example: URL: https://example.com:8080 -> Host: example.com:8080 // Make copies or use the Immutable setting instead. // Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy. func (r *DefaultReq) Host() string { if r.IsProxyTrusted() { if host := r.Get(HeaderXForwardedHost); host != "" { if before, _, found := strings.Cut(host, ","); found { return utils.TrimSpace(before) } return utils.TrimSpace(host) } } return r.c.app.toString(r.c.fasthttp.Request.URI().Host()) } // Hostname contains the hostname derived from the X-Forwarded-Host or Host HTTP header using the c.Host() method. 
// Returned value is only valid within the handler. Do not store any references.
// Example: URL: https://example.com:8080 -> Hostname: example.com
// Make copies or use the Immutable setting instead.
// Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy.
func (r *DefaultReq) Hostname() string {
	// Drop any ":port" suffix from Host().
	addr, _ := parseAddr(r.Host())

	return addr
}

// Port returns the remote port of the request.
// It returns "0" when the remote address is unknown (nil), "" for Unix-socket
// peers, and "" when the address string cannot be split into host and port.
func (r *DefaultReq) Port() string {
	addr := r.c.fasthttp.RemoteAddr()
	if addr == nil {
		return "0"
	}

	switch typedAddr := addr.(type) {
	case *net.TCPAddr:
		return strconv.Itoa(typedAddr.Port)
	case *net.UnixAddr:
		return ""
	}

	// Other address types: fall back to parsing their string form.
	_, port, err := net.SplitHostPort(addr.String())
	if err != nil {
		return ""
	}

	return port
}

// IP returns the remote IP address of the request.
// If ProxyHeader and IP Validation is configured, it will parse that header and return the first valid IP address.
// Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy.
func (r *DefaultReq) IP() string {
	app := r.c.app
	// The proxy header is only honored when the immediate peer is trusted.
	if r.IsProxyTrusted() && app.config.ProxyHeader != "" {
		return r.extractIPFromHeader(app.config.ProxyHeader)
	}

	if ip := r.c.fasthttp.RemoteIP(); ip != nil {
		return ip.String()
	}

	return ""
}

// extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear.
// When IP validation is enabled, any invalid IPs will be omitted.
// When validation is disabled every comma-separated element is returned
// after trimming, which can include empty strings for blank elements such as
// the middle of "a, ,b".
func (r *DefaultReq) extractIPsFromHeader(header string) []string {
	// TODO: Reuse the c.extractIPFromHeader func somehow in here
	headerValue := r.Get(header)

	// We can't know how many IPs we will return, but we will try to guess with this constant division.
	// Counting ',' makes function slower for about 50ns in general case.
	const maxEstimatedCount = 8
	estimatedCount := min(len(headerValue)/maxEstimatedCount,
		// Avoid big allocation on big header
		maxEstimatedCount)

	ipsFound := make([]string, 0, estimatedCount)

	// i marks the start of the current element; j scans forward from i+1 to
	// the next ',' (or the end), recording '.' (IPv4 hint) and ':' (IPv6 hint).
	i := 0
	j := -1
	for {
		var v4, v6 bool

		// Manually splitting string without allocating slice, working with parts directly
		i, j = j+1, j+2

		// j has moved past the end: every element has been consumed.
		if j > len(headerValue) {
			break
		}

		for j < len(headerValue) && headerValue[j] != ',' {
			switch headerValue[j] {
			case ':':
				v6 = true
			case '.':
				v4 = true
			default:
				// do nothing
			}
			j++
		}

		// Skip leading spaces and commas left over from empty elements.
		for i < j && (headerValue[i] == ' ' || headerValue[i] == ',') {
			i++
		}

		s := utils.TrimRight(headerValue[i:j], ' ')

		if r.c.app.config.EnableIPValidation {
			// Skip validation if IP is clearly not IPv4/IPv6; otherwise, validate without allocations
			if (!v6 && !v4) || (v6 && !utils.IsIPv6(s)) || (v4 && !utils.IsIPv4(s)) {
				continue
			}
		}

		ipsFound = append(ipsFound, s)
	}

	return ipsFound
}

// extractIPFromHeader will attempt to pull the real client IP from the given header when IP validation is enabled.
// currently, it will return the first valid IP address in header.
// when IP validation is disabled, it will simply return the value of the header without any inspection.
// Implementation is almost the same as in extractIPsFromHeader, but without allocation of []string.
func (r *DefaultReq) extractIPFromHeader(header string) string { app := r.c.app if app.config.EnableIPValidation { headerValue := r.Get(header) i := 0 j := -1 for { var v4, v6 bool // Manually splitting string without allocating slice, working with parts directly i, j = j+1, j+2 if j > len(headerValue) { break } for j < len(headerValue) && headerValue[j] != ',' { switch headerValue[j] { case ':': v6 = true case '.': v4 = true default: // do nothing } j++ } for i < j && headerValue[i] == ' ' { i++ } s := utils.TrimRight(headerValue[i:j], ' ') if app.config.EnableIPValidation { if (!v6 && !v4) || (v6 && !utils.IsIPv6(s)) || (v4 && !utils.IsIPv4(s)) { continue } } return s } if ip := r.c.fasthttp.RemoteIP(); ip != nil { return ip.String() } return "" } // default behavior if IP validation is not enabled is just to return whatever value is // in the proxy header. Even if it is empty or invalid return r.Get(app.config.ProxyHeader) } // IPs returns a string slice of IP addresses specified in the X-Forwarded-For request header. // When IP validation is enabled, only valid IPs are returned. func (r *DefaultReq) IPs() []string { return r.extractIPsFromHeader(HeaderXForwardedFor) } // Is returns the matching content type, // if the incoming request's Content-Type HTTP header field matches the MIME type specified by the type parameter func (r *DefaultReq) Is(extension string) bool { extensionHeader := utils.GetMIME(extension) if extensionHeader == "" { return false } ct := r.c.app.toString(r.c.fasthttp.Request.Header.ContentType()) if i := strings.IndexByte(ct, ';'); i != -1 { ct = ct[:i] } ct = utils.TrimSpace(ct) return utils.EqualFold(ct, extensionHeader) } // Locals makes it possible to pass any values under keys scoped to the request // and therefore available to all following routes that match the request. // // All the values are removed from ctx after returning from the top // RequestHandler. 
Additionally, Close method is called on each value // implementing io.Closer before removing the value from ctx. func (r *DefaultReq) Locals(key any, value ...any) any { if len(value) == 0 { return r.c.fasthttp.UserValue(key) } r.c.fasthttp.SetUserValue(key, value[0]) return value[0] } // Locals function utilizing Go's generics feature. // This function allows for manipulating and retrieving local values within a // request context with a more specific data type. // // All the values are removed from ctx after returning from the top // RequestHandler. Additionally, Close method is called on each value // implementing io.Closer before removing the value from ctx. func Locals[V any](c Ctx, key any, value ...V) V { var v V var ok bool if len(value) == 0 { v, ok = c.Locals(key).(V) } else { v, ok = c.Locals(key, value[0]).(V) } if !ok { return v // return zero of type V } return v } // Method returns the HTTP request method for the context, optionally overridden by the provided argument. // If no override is given or if the provided override is not a valid HTTP method, it returns the current method from the context. // Otherwise, it updates the context's method and returns the overridden method as a string. func (r *DefaultReq) Method(override ...string) string { app := r.c.app if len(override) == 0 { // Nothing to override, just return current method from context return app.method(r.c.methodInt) } method := utilsstrings.ToUpper(override[0]) methodInt := app.methodInt(method) if methodInt == -1 { // Provided override does not valid HTTP method, no override, return current method return app.method(r.c.methodInt) } r.c.methodInt = methodInt return method } // MultipartForm parse form entries from binary. // This returns a map[string][]string, so given a key, the value will be a string slice. func (r *DefaultReq) MultipartForm() (*multipart.Form, error) { return r.c.fasthttp.MultipartForm() } // OriginalURL contains the original request URL. 
// Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. func (r *DefaultReq) OriginalURL() string { return r.c.app.toString(r.c.fasthttp.Request.Header.RequestURI()) } // Params is used to get the route parameters. // Defaults to empty string "" if the param doesn't exist. // If a default value is given, it will return that value if the param doesn't exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. func (r *DefaultReq) Params(key string, defaultValue ...string) string { if key == "*" || key == "+" { key += "1" } app := r.c.app route := r.c.Route() values := &r.c.values for i := range route.Params { if len(key) != len(route.Params[i]) { continue } if route.Params[i] == key || (!app.config.CaseSensitive && utils.EqualFold(route.Params[i], key)) { // if there is no value for the key if len(values) <= i || values[i] == "" { break } val := values[i] return r.c.app.GetString(val) } } return defaultString("", defaultValue) } // Params is used to get the route parameters. // This function is generic and can handle different route parameters type values. // If the generic type cannot be matched to a supported type, the function // returns the default value (if provided) or the zero value of type V. // // Example: // // http://example.com/user/:user -> http://example.com/user/john // Params[string](c, "user") -> returns john // // http://example.com/id/:id -> http://example.com/user/114 // Params[int](c, "id") -> returns 114 as integer. // // http://example.com/id/:number -> http://example.com/id/john // Params[int](c, "number", 0) -> returns 0 because can't parse 'john' as integer. 
func Params[V GenericType](c Ctx, key string, defaultValue ...V) V { v, err := genericParseType[V](c.Params(key)) if err != nil && len(defaultValue) > 0 { return defaultValue[0] } return v } // Scheme contains the request protocol string: http or https for TLS requests. // Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy. func (r *DefaultReq) Scheme() string { ctx := r.c.fasthttp if ctx.IsTLS() { return schemeHTTPS } if !r.IsProxyTrusted() { return schemeHTTP } app := r.c.app scheme := schemeHTTP const lenXHeaderName = 12 for key, val := range ctx.Request.Header.All() { if len(key) < lenXHeaderName { continue // Neither "X-Forwarded-" nor "X-Url-Scheme" } switch { case utils.EqualFold(key[:len(xForwardedPrefix)], xForwardedPrefix): if utils.EqualFold(key, xForwardedProtoBytes) || utils.EqualFold(key, xForwardedProtocolBytes) { v := app.toString(val) if before, _, found := strings.Cut(v, ","); found { scheme = utils.TrimSpace(before) } else { scheme = utils.TrimSpace(v) } } else if utils.EqualFold(key, xForwardedSslBytes) && utils.EqualFold(val, onBytes) { scheme = schemeHTTPS } case utils.EqualFold(key, xURLSchemeBytes): scheme = utils.TrimSpace(app.toString(val)) default: continue } } return utilsstrings.ToLower(utils.TrimSpace(scheme)) } // Protocol returns the HTTP protocol of request: HTTP/1.1 and HTTP/2. func (r *DefaultReq) Protocol() string { return r.c.app.toString(r.c.fasthttp.Request.Header.Protocol()) } // Query returns the query string parameter in the url. // Defaults to empty string "" if the query doesn't exist. // If a default value is given, it will return that value if the query doesn't exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. func (r *DefaultReq) Query(key string, defaultValue ...string) string { return Query(r.c, key, defaultValue...) 
} // Queries returns a map of query parameters and their values. // // GET /?name=alex&wanna_cake=2&id= // Queries()["name"] == "alex" // Queries()["wanna_cake"] == "2" // Queries()["id"] == "" // // GET /?field1=value1&field1=value2&field2=value3 // Queries()["field1"] == "value2" // Queries()["field2"] == "value3" // // GET /?list_a=1&list_a=2&list_a=3&list_b[]=1&list_b[]=2&list_b[]=3&list_c=1,2,3 // Queries()["list_a"] == "3" // Queries()["list_b[]"] == "3" // Queries()["list_c"] == "1,2,3" // // GET /api/search?filters.author.name=John&filters.category.name=Technology&filters[customer][name]=Alice&filters[status]=pending // Queries()["filters.author.name"] == "John" // Queries()["filters.category.name"] == "Technology" // Queries()["filters[customer][name]"] == "Alice" // Queries()["filters[status]"] == "pending" func (r *DefaultReq) Queries() map[string]string { app := r.c.app queryArgs := r.c.fasthttp.QueryArgs() m := make(map[string]string, queryArgs.Len()) for key, value := range queryArgs.All() { m[app.toString(key)] = app.toString(value) } return m } // Query Retrieves the value of a query parameter from the request's URI. // The function is generic and can handle query parameter values of different types. // It takes the following parameters: // - c: The context object representing the current request. // - key: The name of the query parameter. // - defaultValue: (Optional) The default value to return if the query parameter is not found or cannot be parsed. // The function performs the following steps: // 1. Type-asserts the context object to *DefaultCtx. // 2. Retrieves the raw query parameter value from the request's URI. // 3. Parses the raw value into the appropriate type based on the generic type parameter V. // If parsing fails, the function checks if a default value is provided. If so, it returns the default value. // 4. Returns the parsed value. 
//
// If the generic type cannot be matched to a supported type, the function returns the default value (if provided) or the zero value of type V.
//
// Example usage:
//
//	GET /?search=john&age=8
//	name := Query[string](c, "search") // Returns "john"
//	age := Query[int](c, "age") // Returns 8
//	unknown := Query[string](c, "unknown", "default") // Returns "default" since the query parameter "unknown" is not found
func Query[V GenericType](c Ctx, key string, defaultValue ...V) V {
	q := c.App().toString(c.RequestCtx().QueryArgs().Peek(key))
	v, err := genericParseType[V](q)
	// The default is used only when parsing fails and one was supplied.
	if err != nil && len(defaultValue) > 0 {
		return defaultValue[0]
	}
	return v
}

// Range parses the Range request header against a resource of the given
// size and returns the range unit plus the satisfiable byte ranges.
//
// Behavior visible in the code below:
//   - The unit must be "bytes" (case-insensitive) and the header must contain
//     exactly one '='; otherwise ErrRangeMalformed.
//   - Each comma-separated spec needs a '-'. "-N" is a suffix range (last N
//     bytes), "N-" runs to the end, and the end is clamped to size-1.
//   - Bounds above math.MaxInt64 are rejected as ErrRangeMalformed.
//   - Specs that are unsatisfiable after clamping (start > end) are silently
//     dropped.
//   - More than Config.MaxRanges specs, or no satisfiable spec at all, sets the
//     response status to 416 and Content-Range to "bytes */<size>", and returns
//     ErrRangeTooLarge or ErrRequestedRangeNotSatisfiable respectively.
//
// NOTE(review): a non-numeric start such as "abc-5" is not rejected. Its
// parse error is not ErrRangeMalformed, so it is handled like the suffix
// range "-5". Confirm this leniency is intended.
func (r *DefaultReq) Range(size int64) (Range, error) {
	var (
		rangeData Range
		ranges    string
	)
	rangeStr := utils.TrimSpace(r.Get(HeaderRange))

	// Preallocate for the common case, capped so a huge MaxRanges doesn't
	// allocate up front.
	maxRanges := r.c.app.config.MaxRanges
	const maxRangePrealloc = 8
	prealloc := min(maxRanges, maxRangePrealloc)
	if prealloc > 0 {
		rangeData.Ranges = make([]RangeSet, 0, prealloc)
	}

	// parseBound converts one bound; values that would overflow int64 are
	// reported as ErrRangeMalformed, other failures as a wrapped parse error.
	parseBound := func(value string) (int64, error) {
		parsed, err := utils.ParseUint(value)
		if err != nil {
			return 0, fmt.Errorf("parse range bound %q: %w", value, err)
		}
		if parsed > (math.MaxUint64 >> 1) {
			return 0, ErrRangeMalformed
		}
		return int64(parsed), nil
	}

	before, after, found := strings.Cut(rangeStr, "=")
	if !found || strings.IndexByte(after, '=') >= 0 {
		return rangeData, ErrRangeMalformed
	}
	rangeData.Type = utilsstrings.ToLower(utils.TrimSpace(before))
	if rangeData.Type != "bytes" {
		return rangeData, ErrRangeMalformed
	}
	ranges = utils.TrimSpace(after)

	var (
		singleRange string
		moreRanges  = ranges
		rangeCount  int
	)
	// Consume one comma-separated spec per iteration.
	for moreRanges != "" {
		rangeCount++
		if rangeCount > maxRanges {
			r.c.DefaultRes.Status(StatusRequestedRangeNotSatisfiable)
			r.c.DefaultRes.Set(HeaderContentRange, "bytes */"+utils.FormatInt(size)) //nolint:staticcheck // It is fine to ignore the static check
			return rangeData, ErrRangeTooLarge
		}
		singleRange = moreRanges
		if i := strings.IndexByte(moreRanges, ','); i >= 0 {
			singleRange = moreRanges[:i]
			moreRanges = utils.TrimSpace(moreRanges[i+1:])
		} else {
			moreRanges = ""
		}

		singleRange = utils.TrimSpace(singleRange)

		var (
			startStr, endStr string
			i                int
		)
		if i = strings.IndexByte(singleRange, '-'); i == -1 {
			return rangeData, ErrRangeMalformed
		}
		startStr = utils.TrimSpace(singleRange[:i])
		endStr = utils.TrimSpace(singleRange[i+1:])

		start, startErr := parseBound(startStr)
		end, endErr := parseBound(endStr)
		if errors.Is(startErr, ErrRangeMalformed) || errors.Is(endErr, ErrRangeMalformed) {
			return rangeData, ErrRangeMalformed
		}
		if startErr != nil { // -nnn
			start = max(size-end, 0)
			end = size - 1
		} else if endErr != nil { // nnn-
			end = size - 1
		}
		if end > size-1 { // limit last-byte-pos to current length
			end = size - 1
		}
		if start > end || start < 0 {
			continue
		}
		rangeData.Ranges = append(rangeData.Ranges, RangeSet{
			Start: start,
			End:   end,
		})
	}
	if len(rangeData.Ranges) < 1 {
		r.c.DefaultRes.Status(StatusRequestedRangeNotSatisfiable)
		r.c.DefaultRes.Set(HeaderContentRange, "bytes */"+utils.FormatInt(size)) //nolint:staticcheck // It is fine to ignore the static check
		return rangeData, ErrRequestedRangeNotSatisfiable
	}

	return rangeData, nil
}

// Route returns the matched Route struct.
func (r *DefaultReq) Route() *Route {
	return r.c.Route()
}

// Subdomains returns a slice of subdomains from the host, excluding the last `offset` components.
// If the offset is negative or exceeds the number of subdomains, an empty slice is returned.
// If the offset is zero every label (no trimming) is returned.
// The offset defaults to 2. The host is lower-cased, stripped of trailing
// dots, and punycode ("xn--") labels are decoded to Unicode when decoding
// succeeds. IP-address hosts (including bracketed IPv6) yield an empty slice.
func (r *DefaultReq) Subdomains(offset ...int) []string {
	o := 2
	if len(offset) > 0 {
		o = offset[0]
	}

	// Negative offset, return nothing.
	if o < 0 {
		return []string{}
	}

	// Normalize host according to RFC 3986
	host := r.Hostname()
	// Trim the trailing dot of a fully-qualified domain
	if strings.HasSuffix(host, ".") {
		host = utils.TrimRight(host, '.')
	}
	host = utilsstrings.ToLower(host)

	// Decode punycode labels only when necessary
	if strings.Contains(host, "xn--") {
		if u, err := idna.Lookup.ToUnicode(host); err == nil {
			host = utilsstrings.ToLower(u)
		}
	}

	// Return nothing for IP addresses
	ip := host
	if strings.HasPrefix(ip, "[") && strings.HasSuffix(ip, "]") {
		ip = ip[1 : len(ip)-1]
	}
	if utils.IsIPv4(ip) || utils.IsIPv6(ip) {
		return []string{}
	}

	// Use stack-allocated array for typical domain names (up to 8 labels)
	// This avoids heap allocation for most common cases
	var partsBuf [8]string
	parts := partsBuf[:0]
	for part := range strings.SplitSeq(host, ".") {
		parts = append(parts, part)
	}

	// offset == 0, caller wants everything.
	if o == 0 {
		// Need to return a copy since partsBuf is on the stack
		result := make([]string, len(parts))
		copy(result, parts)
		return result
	}

	// If we trim away the whole slice (or more), nothing remains.
	if o >= len(parts) {
		return []string{}
	}

	// Return a heap-allocated copy of the relevant portion
	result := make([]string, len(parts)-o)
	copy(result, parts[:len(parts)-o])
	return result
}

// Stale returns the inverse of Fresh, indicating if the client's cached response is considered stale.
func (r *DefaultReq) Stale() bool {
	return !r.Fresh()
}

// IsProxyTrusted checks trustworthiness of remote ip.
// If Config.TrustProxy false, it returns false.
// IsProxyTrusted can check remote ip by proxy ranges and ip map.
// Unix-socket peers are trusted according to TrustProxyConfig.UnixSocket;
// TCP/UDP peers are trusted when their IP is loopback/private/link-local
// (each gated by its TrustProxyConfig flag), listed in the trusted IP map, or
// inside a trusted CIDR range. Any other (or nil) address type is untrusted.
//
// NOTE(review): `config := r.c.app.config` copies the config if it is a struct
// value; this runs on every Host/Scheme/IP call — consider taking a pointer.
func (r *DefaultReq) IsProxyTrusted() bool {
	config := r.c.app.config
	if !config.TrustProxy {
		return false
	}

	remoteAddr := r.c.fasthttp.RemoteAddr()
	switch remoteAddr.(type) {
	case *net.UnixAddr:
		return config.TrustProxyConfig.UnixSocket
	case *net.TCPAddr, *net.UDPAddr:
		// Keep existing RemoteIP/IP-map/CIDR checks for TCP/UDP paths as-is.
	default:
		// Unknown address type: do not trust by default.
		return false
	}

	ip := r.c.fasthttp.RemoteIP()
	if ip == nil {
		return false
	}

	if (config.TrustProxyConfig.Loopback && ip.IsLoopback()) ||
		(config.TrustProxyConfig.Private && ip.IsPrivate()) ||
		(config.TrustProxyConfig.LinkLocal && ip.IsLinkLocalUnicast()) {
		return true
	}

	if _, trusted := config.TrustProxyConfig.ips[ip.String()]; trusted {
		return true
	}

	for _, ipNet := range config.TrustProxyConfig.ranges {
		if ipNet.Contains(ip) {
			return true
		}
	}

	return false
}

// IsFromLocal will return true if request came from local.
// "Local" means a Unix-socket peer or a loopback remote IP.
func (r *DefaultReq) IsFromLocal() bool {
	// Unix sockets are inherently local - only processes on the same host can connect.
	remoteAddr := r.c.fasthttp.RemoteAddr()
	if _, ok := remoteAddr.(*net.UnixAddr); ok {
		return true
	}

	if ip := r.c.fasthttp.RemoteIP(); ip != nil {
		return ip.IsLoopback()
	}

	return false
}

// release clears the back-reference to the owning context so the DefaultReq
// can be reused; the original note says it is used via ReleaseCtx().
func (r *DefaultReq) release() {
	r.c = nil
}

// getBody returns the stored request body passed through app.GetBytes
// (presumably a detached copy when Immutable is enabled — confirm in GetBytes).
func (r *DefaultReq) getBody() []byte {
	return r.c.app.GetBytes(r.c.fasthttp.Request.Body())
}

================================================
FILE: req_interface_gen.go
================================================
// Code generated by ifacemaker; DO NOT EDIT.

package fiber

import (
	"mime/multipart"

	"github.com/valyala/fasthttp"
)

// Req is an interface for request-related Ctx methods.
type Req interface {
	// Accepts checks if the specified extensions or content types are acceptable.
	Accepts(offers ...string) string
	// AcceptsCharsets checks if the specified charset is acceptable.
	AcceptsCharsets(offers ...string) string
	// AcceptsEncodings checks if the specified encoding is acceptable.
	AcceptsEncodings(offers ...string) string
	// AcceptsLanguages checks if the specified language is acceptable using
	// RFC 4647 Basic Filtering.
AcceptsLanguages(offers ...string) string // AcceptsLanguagesExtended checks if the specified language is acceptable using // RFC 4647 Extended Filtering. AcceptsLanguagesExtended(offers ...string) string // App returns the *App reference to the instance of the Fiber application App() *App // BaseURL returns (protocol + host + base path). BaseURL() string // BodyRaw contains the raw body submitted in a POST request. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. BodyRaw() []byte //nolint:nonamedreturns // gocritic unnamedResult prefers naming decoded body, decode count, and error tryDecodeBodyInOrder(originalBody *[]byte, encodings []string) (body []byte, decodesRealized uint8, err error) // Body contains the raw body submitted in a POST request. // This method will decompress the body if the 'Content-Encoding' header is provided. // It returns the original (or decompressed) body data which is valid only within the handler. // Don't store direct references to the returned data. // If you need to keep the body's data later, make a copy or use the Immutable option. Body() []byte // RequestCtx returns *fasthttp.RequestCtx that carries a deadline // a cancellation signal, and other values across API boundaries. RequestCtx() *fasthttp.RequestCtx // Cookies are used for getting a cookie value by key. // Defaults to the empty string "" if the cookie doesn't exist. // If a default value is given, it will return that value if the cookie doesn't exist. // The returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. 
Cookies(key string, defaultValue ...string) string // Request return the *fasthttp.Request object // This allows you to use all fasthttp request methods // https://godoc.org/github.com/valyala/fasthttp#Request Request() *fasthttp.Request // FormFile returns the first file by key from a MultipartForm. FormFile(key string) (*multipart.FileHeader, error) // FormValue returns the first value by key from a MultipartForm. // Search is performed in QueryArgs, PostArgs, MultipartForm and FormFile in this particular order. // Defaults to the empty string "" if the form value doesn't exist. // If a default value is given, it will return that value if the form value does not exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. FormValue(key string, defaultValue ...string) string // Fresh returns true when the response is still “fresh” in the client's cache, // otherwise false is returned to indicate that the client cache is now stale // and the full response should be sent. // When a client sends the Cache-Control: no-cache request header to indicate an end-to-end // reload request, this module will return false to make handling these requests transparent. // https://github.com/jshttp/fresh/blob/master/index.js#L33 Fresh() bool // Get returns the HTTP request header specified by field. // Field names are case-insensitive // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. Get(key string, defaultValue ...string) string // GetHeaders (a.k.a GetReqHeaders) returns the HTTP request headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. GetHeaders() map[string][]string // Host contains the host derived from the X-Forwarded-Host or Host HTTP header. // Returned value is only valid within the handler. Do not store any references. 
// In a network context, `Host` refers to the combination of a hostname and potentially a port number used for connecting, // while `Hostname` refers specifically to the name assigned to a device on a network, excluding any port information. // Example: URL: https://example.com:8080 -> Host: example.com:8080 // Make copies or use the Immutable setting instead. // Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy. Host() string // Hostname contains the hostname derived from the X-Forwarded-Host or Host HTTP header using the c.Host() method. // Returned value is only valid within the handler. Do not store any references. // Example: URL: https://example.com:8080 -> Hostname: example.com // Make copies or use the Immutable setting instead. // Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy. Hostname() string // Port returns the remote port of the request. Port() string // IP returns the remote IP address of the request. // If ProxyHeader and IP Validation is configured, it will parse that header and return the first valid IP address. // Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy. IP() string // extractIPsFromHeader will return a slice of IPs it found given a header name in the order they appear. // When IP validation is enabled, any invalid IPs will be omitted. extractIPsFromHeader(header string) []string // extractIPFromHeader will attempt to pull the real client IP from the given header when IP validation is enabled. // currently, it will return the first valid IP address in header. // when IP validation is disabled, it will simply return the value of the header without any inspection. // Implementation is almost the same as in extractIPsFromHeader, but without allocation of []string. extractIPFromHeader(header string) string // IPs returns a string slice of IP addresses specified in the X-Forwarded-For request header. 
// When IP validation is enabled, only valid IPs are returned. IPs() []string // Is returns the matching content type, // if the incoming request's Content-Type HTTP header field matches the MIME type specified by the type parameter Is(extension string) bool // Locals makes it possible to pass any values under keys scoped to the request // and therefore available to all following routes that match the request. // // All the values are removed from ctx after returning from the top // RequestHandler. Additionally, Close method is called on each value // implementing io.Closer before removing the value from ctx. Locals(key any, value ...any) any // Method returns the HTTP request method for the context, optionally overridden by the provided argument. // If no override is given or if the provided override is not a valid HTTP method, it returns the current method from the context. // Otherwise, it updates the context's method and returns the overridden method as a string. Method(override ...string) string // MultipartForm parse form entries from binary. // This returns a map[string][]string, so given a key, the value will be a string slice. MultipartForm() (*multipart.Form, error) // OriginalURL contains the original request URL. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. OriginalURL() string // Params is used to get the route parameters. // Defaults to empty string "" if the param doesn't exist. // If a default value is given, it will return that value if the param doesn't exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. Params(key string, defaultValue ...string) string // Scheme contains the request protocol string: http or https for TLS requests. 
// Please use Config.TrustProxy to prevent header spoofing if your app is behind a proxy. Scheme() string // Protocol returns the HTTP protocol of request: HTTP/1.1 and HTTP/2. Protocol() string // Query returns the query string parameter in the url. // Defaults to empty string "" if the query doesn't exist. // If a default value is given, it will return that value if the query doesn't exist. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. Query(key string, defaultValue ...string) string // Queries returns a map of query parameters and their values. // // GET /?name=alex&wanna_cake=2&id= // Queries()["name"] == "alex" // Queries()["wanna_cake"] == "2" // Queries()["id"] == "" // // GET /?field1=value1&field1=value2&field2=value3 // Queries()["field1"] == "value2" // Queries()["field2"] == "value3" // // GET /?list_a=1&list_a=2&list_a=3&list_b[]=1&list_b[]=2&list_b[]=3&list_c=1,2,3 // Queries()["list_a"] == "3" // Queries()["list_b[]"] == "3" // Queries()["list_c"] == "1,2,3" // // GET /api/search?filters.author.name=John&filters.category.name=Technology&filters[customer][name]=Alice&filters[status]=pending // Queries()["filters.author.name"] == "John" // Queries()["filters.category.name"] == "Technology" // Queries()["filters[customer][name]"] == "Alice" // Queries()["filters[status]"] == "pending" Queries() map[string]string // Range returns a struct containing the type and a slice of ranges. Range(size int64) (Range, error) // Route returns the matched Route struct. Route() *Route // Subdomains returns a slice of subdomains from the host, excluding the last `offset` components. // If the offset is negative or exceeds the number of subdomains, an empty slice is returned. // If the offset is zero every label (no trimming) is returned. 
Subdomains(offset ...int) []string // Stale returns the inverse of Fresh, indicating if the client's cached response is considered stale. Stale() bool // IsProxyTrusted checks trustworthiness of remote ip. // If Config.TrustProxy false, it returns false. // IsProxyTrusted can check remote ip by proxy ranges and ip map. IsProxyTrusted() bool // IsFromLocal will return true if request came from local. IsFromLocal() bool // Release is a method to reset Req fields when to use ReleaseCtx() release() getBody() []byte } ================================================ FILE: res.go ================================================ package fiber import ( "bufio" "bytes" "fmt" "html/template" "io" "io/fs" "net/http" "net/url" "os" pathpkg "path" "path/filepath" "strconv" "strings" "time" "unicode" "unicode/utf8" "github.com/gofiber/utils/v2" "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" ) // SendFile defines configuration options when to transfer file with SendFile. type SendFile struct { // FS is the file system to serve the static files from. // You can use interfaces compatible with fs.FS like embed.FS, os.DirFS etc. // // Optional. Default: nil FS fs.FS // When set to true, the server tries minimizing CPU usage by caching compressed files. // This works differently than the github.com/gofiber/compression middleware. // You have to set Content-Encoding header to compress the file. // Available compression methods are gzip, br, and zstd. // // Optional. Default: false Compress bool `json:"compress"` // When set to true, enables byte range requests. // // Optional. Default: false ByteRange bool `json:"byte_range"` // When set to true, enables direct download. // // Optional. Default: false Download bool `json:"download"` // Expiration duration for inactive file handlers. // Use a negative time.Duration to disable it. // // Optional. 
Default: 10 * time.Second CacheDuration time.Duration `json:"cache_duration"` // The value for the Cache-Control HTTP-header // that is set on the file response. MaxAge is defined in seconds. // // Optional. Default: 0 MaxAge int `json:"max_age"` } // sendFileStore is used to keep the SendFile configuration and the handler. type sendFileStore struct { handler fasthttp.RequestHandler cacheControlValue string config SendFile } // configEqual compares the current SendFile config with the new one // and returns true if they are equal. // // Here we don't use reflect.DeepEqual because it is quite slow compared to manual comparison. func (sf *sendFileStore) configEqual(cfg SendFile) bool { if sf.config.FS != cfg.FS { return false } if sf.config.Compress != cfg.Compress { return false } if sf.config.ByteRange != cfg.ByteRange { return false } if sf.config.Download != cfg.Download { return false } if sf.config.CacheDuration != cfg.CacheDuration { return false } if sf.config.MaxAge != cfg.MaxAge { return false } return true } // Cookie defines the values used when configuring cookies emitted by // DefaultRes.Cookie. 
type Cookie struct {
	Expires     time.Time `json:"expires"`      // The expiration date of the cookie; cleared by DefaultRes.Cookie when SessionOnly is set
	Name        string    `json:"name"`         // The name of the cookie
	Value       string    `json:"value"`        // The value of the cookie
	Path        string    `json:"path"`         // URL path allowed to receive the cookie; DefaultRes.Cookie substitutes "/" when empty
	Domain      string    `json:"domain"`       // Specifies the domain which is allowed to receive the cookie
	SameSite    string    `json:"same_site"`    // Cross-site policy, matched case-insensitively against the CookieSameSite* constants; unrecognized values are treated as Lax
	MaxAge      int       `json:"max_age"`      // The maximum age (in seconds) of the cookie; cleared by DefaultRes.Cookie when SessionOnly is set
	Secure      bool      `json:"secure"`       // Only transmit over HTTPS; DefaultRes.Cookie forces it to true for SameSite=None and Partitioned cookies
	HTTPOnly    bool      `json:"http_only"`    // Indicates that the cookie is accessible only through the HTTP protocol
	Partitioned bool      `json:"partitioned"`  // Indicates if the cookie is stored in a partitioned cookie jar (CHIPS)
	SessionOnly bool      `json:"session_only"` // When true, MaxAge and Expires are cleared and not emitted, producing a session cookie
}

// ResFmt associates a Content Type to a fiber.Handler for c.Format.
// The special MediaType "default" marks the fallback handler used when no
// offered type is acceptable.
type ResFmt struct {
	Handler   func(Ctx) error // Handler invoked when MediaType is selected
	MediaType string          // Offered media type, or "default" for the fallback handler
}

// DefaultRes is the default implementation of Res used by DefaultCtx.
// It is a thin view over the owning DefaultCtx; every method operates on
// that context's underlying fasthttp response.
//
//go:generate ifacemaker --file res.go --struct DefaultRes --iface Res --pkg fiber --output res_interface_gen.go --not-exported true --iface-comment "Res is an interface for response-related Ctx methods."
type DefaultRes struct {
	c *DefaultCtx // owning request context
}

// App returns the *App reference to the instance of the Fiber application
// that owns the current request context.
func (r *DefaultRes) App() *App {
	return r.c.app
}

// Append the specified value to the HTTP response header field.
// If the header is not already set, it creates the header with the specified value.
func (r *DefaultRes) Append(field string, values ...string) { if len(values) == 0 { return } h := r.c.app.toString(r.c.fasthttp.Response.Header.Peek(field)) originalH := h for _, value := range values { if h == "" { h = value } else if !headerContainsValue(h, value) { h += ", " + value } } if originalH != h { r.Set(field, h) } } // headerContainsValue checks if a header value already contains the given value // as a comma-separated element. Per RFC 9110, list elements are separated by commas // with optional whitespace (OWS) around them. func headerContainsValue(header, value string) bool { // Empty value should never match if value == "" { return false } // Exact match (single value header) if header == value { return true } // Check each comma-separated element, handling optional whitespace (OWS) for part := range strings.SplitSeq(header, ",") { if utils.TrimSpace(part) == value { return true } } return false } func sanitizeFilename(filename string) string { for _, r := range filename { if unicode.IsControl(r) { b := make([]byte, 0, len(filename)) for _, rr := range filename { if !unicode.IsControl(rr) { b = utf8.AppendRune(b, rr) } } return utils.TrimSpace(string(b)) } } return utils.TrimSpace(filename) } func fallbackFilenameIfInvalid(filename string) string { if filename == "" || filename == "." { return "download" } return filename } // Attachment sets the HTTP response Content-Disposition header field to attachment. 
func (r *DefaultRes) Attachment(filename ...string) { if len(filename) > 0 { fname := filepath.Base(filename[0]) fname = sanitizeFilename(fname) fname = fallbackFilenameIfInvalid(fname) r.Type(filepath.Ext(fname)) app := r.c.app var quoted string if app.isASCII(fname) { quoted = app.quoteString(fname) } else { quoted = app.quoteRawString(fname) } disp := `attachment; filename="` + quoted + `"` if !app.isASCII(fname) { disp += `; filename*=UTF-8''` + url.PathEscape(fname) } r.setCanonical(HeaderContentDisposition, disp) return } r.setCanonical(HeaderContentDisposition, "attachment") } // ClearCookie expires a specific cookie by key on the client side. // If no key is provided it expires all cookies that came with the request. func (r *DefaultRes) ClearCookie(key ...string) { request := &r.c.fasthttp.Request response := &r.c.fasthttp.Response if len(key) > 0 { for i := range key { response.Header.DelClientCookie(key[i]) } return } for k := range request.Header.Cookies() { response.Header.DelClientCookieBytes(k) } } // RequestCtx returns *fasthttp.RequestCtx that carries a deadline // a cancellation signal, and other values across API boundaries. func (r *DefaultRes) RequestCtx() *fasthttp.RequestCtx { return r.c.fasthttp } // Cookie sets a cookie by passing a cookie struct. 
func (r *DefaultRes) Cookie(cookie *Cookie) { if cookie.Path == "" { cookie.Path = "/" } if cookie.SessionOnly { cookie.MaxAge = 0 cookie.Expires = time.Time{} } var sameSite http.SameSite switch { case utils.EqualFold(cookie.SameSite, CookieSameSiteStrictMode): sameSite = http.SameSiteStrictMode case utils.EqualFold(cookie.SameSite, CookieSameSiteNoneMode): sameSite = http.SameSiteNoneMode // SameSite=None requires Secure=true per RFC and browser requirements cookie.Secure = true case utils.EqualFold(cookie.SameSite, CookieSameSiteDisabled): sameSite = 0 case utils.EqualFold(cookie.SameSite, CookieSameSiteLaxMode): sameSite = http.SameSiteLaxMode default: sameSite = http.SameSiteLaxMode } // Partitioned requires Secure=true per CHIPS spec if cookie.Partitioned { cookie.Secure = true } // create/validate cookie using net/http hc := &http.Cookie{ Name: cookie.Name, Value: cookie.Value, Path: cookie.Path, Domain: cookie.Domain, Expires: cookie.Expires, MaxAge: cookie.MaxAge, Secure: cookie.Secure, HttpOnly: cookie.HTTPOnly, SameSite: sameSite, Partitioned: cookie.Partitioned, } if err := hc.Valid(); err != nil { // invalid cookies are ignored, same approach as net/http return } // create fasthttp cookie fcookie := fasthttp.AcquireCookie() fcookie.SetKey(hc.Name) fcookie.SetValue(hc.Value) fcookie.SetPath(hc.Path) fcookie.SetDomain(hc.Domain) if !cookie.SessionOnly { fcookie.SetMaxAge(hc.MaxAge) fcookie.SetExpire(hc.Expires) } fcookie.SetSecure(hc.Secure) fcookie.SetHTTPOnly(hc.HttpOnly) switch sameSite { case http.SameSiteLaxMode: fcookie.SetSameSite(fasthttp.CookieSameSiteLaxMode) case http.SameSiteStrictMode: fcookie.SetSameSite(fasthttp.CookieSameSiteStrictMode) case http.SameSiteNoneMode: fcookie.SetSameSite(fasthttp.CookieSameSiteNoneMode) case http.SameSiteDefaultMode: fcookie.SetSameSite(fasthttp.CookieSameSiteDefaultMode) default: fcookie.SetSameSite(fasthttp.CookieSameSiteDisabled) } fcookie.SetPartitioned(hc.Partitioned) // Set resp header 
r.c.fasthttp.Response.Header.SetCookie(fcookie) fasthttp.ReleaseCookie(fcookie) } // Download transfers the file from path as an attachment. // Typically, browsers will prompt the user for download. // By default, the Content-Disposition header filename= parameter is the filepath (this typically appears in the browser dialog). // Override this default with the filename parameter. func (r *DefaultRes) Download(file string, filename ...string) error { var fname string if len(filename) > 0 { fname = filepath.Base(filename[0]) } else { fname = filepath.Base(file) } fname = sanitizeFilename(fname) fname = fallbackFilenameIfInvalid(fname) app := r.c.app var quoted string if app.isASCII(fname) { quoted = app.quoteString(fname) } else { quoted = app.quoteRawString(fname) } disp := `attachment; filename="` + quoted + `"` if !app.isASCII(fname) { disp += `; filename*=UTF-8''` + url.PathEscape(fname) } r.setCanonical(HeaderContentDisposition, disp) return r.SendFile(file) } // Response return the *fasthttp.Response object // This allows you to use all fasthttp response methods // https://godoc.org/github.com/valyala/fasthttp#Response func (r *DefaultRes) Response() *fasthttp.Response { return &r.c.fasthttp.Response } // Format performs content-negotiation on the Accept HTTP header. // It uses Accepts to select a proper format and calls the matching // user-provided handler function. // If no accepted format is found, and a format with MediaType "default" is given, // that default handler is called. If no format is found and no default is given, // StatusNotAcceptable is sent. 
func (r *DefaultRes) Format(handlers ...ResFmt) error { if len(handlers) == 0 { return ErrNoHandlers } for i, h := range handlers { if h.Handler == nil { return fmt.Errorf("format handler is nil for media type %q at index %d", h.MediaType, i) } } r.Vary(HeaderAccept) if r.c.DefaultReq.Get(HeaderAccept) == "" { r.c.fasthttp.Response.Header.SetContentType(handlers[0].MediaType) return handlers[0].Handler(r.c) } // Using an int literal as the slice capacity allows for the slice to be // allocated on the stack. The number was chosen arbitrarily as an // approximation of the maximum number of content types a user might handle. // If the user goes over, it just causes allocations, so it's not a problem. types := make([]string, 0, 8) var defaultHandler Handler for _, h := range handlers { if h.MediaType == "default" { defaultHandler = h.Handler continue } types = append(types, h.MediaType) } accept := r.c.DefaultReq.Accepts(types...) //nolint:staticcheck // It is fine to ignore the static check if accept == "" { if defaultHandler == nil { return r.SendStatus(StatusNotAcceptable) } return defaultHandler(r.c) } for _, h := range handlers { if h.MediaType == accept { r.c.fasthttp.Response.Header.SetContentType(h.MediaType) return h.Handler(r.c) } } return fmt.Errorf("%w: format: an Accept was found but no handler was called", errUnreachable) } // AutoFormat performs content-negotiation on the Accept HTTP header. // It uses Accepts to select a proper format. // The supported content types are text/html, text/plain, application/json, application/xml, application/vnd.msgpack, and application/cbor. // For more flexible content negotiation, use Format. // If the header is not specified or there is no proper format, text/plain is used. 
func (r *DefaultRes) AutoFormat(body any) error { // Get accepted content type accept := r.c.DefaultReq.Accepts("html", "json", "txt", "xml", "msgpack", "cbor") //nolint:staticcheck // It is fine to ignore the static check // Set accepted content type r.Type(accept) // Type convert provided body var b string switch val := body.(type) { case string: b = val case []byte: b = r.c.app.toString(val) default: b = fmt.Sprintf("%v", val) } // Format based on the accept content type switch accept { case "txt": return r.SendString(b) case "json": return r.JSON(body) case "xml": return r.XML(body) case "html": return r.SendString("

" + b + "

") case "msgpack": return r.MsgPack(body) case "cbor": return r.CBOR(body) } // Default case return r.SendString(b) } // Get (a.k.a. GetRespHeader) returns the HTTP response header specified by field. // Field names are case-insensitive // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (r *DefaultRes) Get(key string, defaultValue ...string) string { return defaultString(r.c.app.toString(r.c.fasthttp.Response.Header.Peek(key)), defaultValue) } // GetHeaders (a.k.a GetRespHeaders) returns the HTTP response headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (r *DefaultRes) GetHeaders() map[string][]string { app := r.c.app respHeader := &r.c.fasthttp.Response.Header // Pre-allocate map with known header count to avoid reallocations headers := make(map[string][]string, respHeader.Len()) for k, v := range respHeader.All() { key := app.toString(k) headers[key] = append(headers[key], app.toString(v)) } return headers } // JSON converts any interface or string to JSON. // Array and slice values encode as JSON arrays, // except that []byte encodes as a base64-encoded string, // and a nil slice encodes as the null JSON value. // If the ctype parameter is given, this method will set the // Content-Type header equal to ctype. If ctype is not given, // The Content-Type header will be set to application/json; charset=utf-8. func (r *DefaultRes) JSON(data any, ctype ...string) error { raw, err := r.c.app.config.JSONEncoder(data) if err != nil { return err } response := &r.c.fasthttp.Response response.SetBodyRaw(raw) if len(ctype) > 0 { response.Header.SetContentType(ctype[0]) } else { response.Header.SetContentType(MIMEApplicationJSONCharsetUTF8) } return nil } // MsgPack converts any interface or string to MessagePack encoded bytes. 
// If the ctype parameter is given, this method will set the // Content-Type header equal to ctype. If ctype is not given, // The Content-Type header will be set to application/vnd.msgpack. func (r *DefaultRes) MsgPack(data any, ctype ...string) error { raw, err := r.c.app.config.MsgPackEncoder(data) if err != nil { return err } response := &r.c.fasthttp.Response response.SetBodyRaw(raw) if len(ctype) > 0 { response.Header.SetContentType(ctype[0]) } else { response.Header.SetContentType(MIMEApplicationMsgPack) } return nil } // CBOR converts any interface or string to CBOR encoded bytes. // If the ctype parameter is given, this method will set the // Content-Type header equal to ctype. If ctype is not given, // The Content-Type header will be set to application/cbor. func (r *DefaultRes) CBOR(data any, ctype ...string) error { raw, err := r.c.app.config.CBOREncoder(data) if err != nil { return err } response := &r.c.fasthttp.Response response.SetBodyRaw(raw) if len(ctype) > 0 { response.Header.SetContentType(ctype[0]) } else { response.Header.SetContentType(MIMEApplicationCBOR) } return nil } // JSONP sends a JSON response with JSONP support. // This method is identical to JSON, except that it opts-in to JSONP callback support. // By default, the callback name is simply callback. 
func (r *DefaultRes) JSONP(data any, callback ...string) error { raw, err := r.c.app.config.JSONEncoder(data) if err != nil { return err } cb := "callback" if len(callback) > 0 { cb = callback[0] } // Build JSONP response: callback(data); // Use bytebufferpool to avoid string concatenation allocations buf := bytebufferpool.Get() buf.WriteString(cb) buf.WriteByte('(') buf.Write(raw) buf.WriteString(");") r.setCanonical(HeaderXContentTypeOptions, "nosniff") r.c.fasthttp.Response.Header.SetContentType(MIMETextJavaScriptCharsetUTF8) // Use SetBody (not SetBodyRaw) to copy the bytes before returning buffer to pool r.c.fasthttp.Response.SetBody(buf.Bytes()) bytebufferpool.Put(buf) return nil } // XML converts any interface or string to XML. // This method also sets the content header to application/xml; charset=utf-8. func (r *DefaultRes) XML(data any) error { raw, err := r.c.app.config.XMLEncoder(data) if err != nil { return err } response := &r.c.fasthttp.Response response.SetBodyRaw(raw) response.Header.SetContentType(MIMEApplicationXMLCharsetUTF8) return nil } // Links joins the links followed by the property to populate the response's Link HTTP header field. func (r *DefaultRes) Links(link ...string) { if len(link) == 0 { return } bb := bytebufferpool.Get() for i := range link { if i%2 == 0 { bb.WriteByte('<') bb.WriteString(link[i]) bb.WriteByte('>') } else { bb.WriteString(`; rel="`) bb.WriteString(link[i]) bb.WriteString(`",`) } } r.setCanonical(HeaderLink, utils.TrimRight(r.c.app.toString(bb.Bytes()), ',')) bytebufferpool.Put(bb) } // Location sets the response Location HTTP header to the specified path parameter. func (r *DefaultRes) Location(path string) { r.setCanonical(HeaderLocation, path) } // OriginalURL contains the original request URL. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. 
func (r *DefaultRes) OriginalURL() string {
	return r.c.OriginalURL()
}

// Redirect returns the Redirect reference.
// Use Redirect().Status() to set custom redirection status code.
// If status is not specified, status defaults to 303 See Other.
// You can use Redirect().To(), Redirect().Route() and Redirect().Back() for redirection.
func (r *DefaultRes) Redirect() *Redirect {
	return r.c.Redirect()
}

// ViewBind Add vars to default view var map binding to template engine.
// Variables are read by the Render method and may be overwritten.
func (r *DefaultRes) ViewBind(vars Map) error {
	return r.c.ViewBind(vars)
}

// getLocationFromRoute builds the URL location for route by concatenating its
// constant segments and substituting parameter segments from params.
// Parameter keys match case-insensitively unless Config.CaseSensitive is set;
// a single-character greedy key (e.g. "*") fills a greedy segment.
// It returns ErrNotFound when route is nil or has no path.
func (r *DefaultRes) getLocationFromRoute(route *Route, params Map) (string, error) {
	if route == nil || route.Path == "" {
		return "", ErrNotFound
	}

	app := r.c.app
	buf := bytebufferpool.Get()
	// Return the buffer to the pool on every path, including the error returns.
	defer bytebufferpool.Put(buf)

	for _, segment := range route.routeParser.segs {
		if !segment.IsParam {
			if _, err := buf.WriteString(segment.Const); err != nil {
				return "", fmt.Errorf("failed to write string: %w", err)
			}
			continue
		}

		for key, val := range params {
			isSame := key == segment.ParamName || (!app.config.CaseSensitive && utils.EqualFold(key, segment.ParamName))
			isGreedy := segment.IsGreedy && len(key) == 1 && bytes.IndexByte(greedyParameters, key[0]) >= 0
			if isSame || isGreedy {
				if _, err := buf.WriteString(utils.ToString(val)); err != nil {
					return "", fmt.Errorf("failed to write string: %w", err)
				}
			}
		}
	}

	// buf.String copies the bytes, so the result stays valid after buf is pooled.
	return buf.String(), nil
}

// GetRouteURL generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831"
func (r *DefaultRes) GetRouteURL(routeName string, params Map) (string, error) {
	route := r.c.app.GetRoute(routeName)
	return r.getLocationFromRoute(&route, params)
}

// Render a template with data and sends a text/html response.
// We support the following engines: https://github.com/gofiber/template func (r *DefaultRes) Render(name string, bind any, layouts ...string) error { // Get new buffer from pool buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) // Initialize empty bind map if bind is nil if bind == nil { bind = make(Map) } // Pass-locals-to-views, bind, appListKeys r.c.renderExtensions(bind) rootApp := r.c.app var rendered bool for i := len(rootApp.mountFields.appListKeys) - 1; i >= 0; i-- { prefix := rootApp.mountFields.appListKeys[i] app := rootApp.mountFields.appList[prefix] if prefix == "" || strings.Contains(r.c.OriginalURL(), prefix) { if len(layouts) == 0 && app.config.ViewsLayout != "" { layouts = []string{ app.config.ViewsLayout, } } // Render template from Views if app.config.Views != nil { if err := app.config.Views.Render(buf, name, bind, layouts...); err != nil { return fmt.Errorf("failed to render: %w", err) } rendered = true break } } } if !rendered { // Render raw template using 'name' as filepath if no engine is set var tmpl *template.Template if _, err := readContent(buf, name); err != nil { return err } // Parse template tmpl, err := template.New("").Parse(rootApp.toString(buf.Bytes())) if err != nil { return fmt.Errorf("failed to parse: %w", err) } buf.Reset() // Render template if err := tmpl.Execute(buf, bind); err != nil { return fmt.Errorf("failed to execute: %w", err) } } response := &r.c.fasthttp.Response // Set Content-Type to text/html response.Header.SetContentType(MIMETextHTMLCharsetUTF8) // Set rendered template to body response.SetBody(buf.Bytes()) return nil } func (r *DefaultRes) renderExtensions(bind any) { r.c.renderExtensions(bind) } // Send sets the HTTP response body without copying it. // From this point onward the body argument must not be changed. 
func (r *DefaultRes) Send(body []byte) error { // Write response body r.c.fasthttp.Response.SetBodyRaw(body) return nil } // SendEarlyHints allows the server to hint to the browser what resources a page would need // so the browser can preload them while waiting for the server's full response. Only Link // headers already written to the response will be transmitted as Early Hints. // // This is a HTTP/2+ feature but all browsers will either understand it or safely ignore it. // // NOTE: Older HTTP/1.1 non-browser clients may face compatibility issues. // // See: https://developer.chrome.com/docs/web-platform/early-hints and // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Link#syntax func (r *DefaultRes) SendEarlyHints(hints []string) error { if len(hints) == 0 { return nil } for _, h := range hints { r.c.fasthttp.Response.Header.Add("Link", h) } return r.c.fasthttp.EarlyHints() } // SendFile transfers the file from the specified path. // By default, the file is not compressed. To enable compression, set SendFile.Compress to true. // The Content-Type response HTTP header field is set based on the file's extension. // If the file extension is missing or invalid, the Content-Type is detected from the file's format. 
func (r *DefaultRes) SendFile(file string, config ...SendFile) error {
	// Save the filename, we will need it in the error message if the file isn't found
	filename := file

	var cfg SendFile
	if len(config) > 0 {
		cfg = config[0]
	}

	// A zero CacheDuration means "use the default"; this runs before the
	// handler lookup so cached handlers are keyed by the effective value.
	if cfg.CacheDuration == 0 {
		cfg.CacheDuration = 10 * time.Second
	}

	var fsHandler fasthttp.RequestHandler
	var cacheControlValue string

	// Reuse a previously built handler whose config matches this one.
	app := r.c.app
	app.sendfilesMutex.RLock()
	for _, sf := range app.sendfiles {
		if sf.configEqual(cfg) {
			fsHandler = sf.handler
			cacheControlValue = sf.cacheControlValue
			break
		}
	}
	app.sendfilesMutex.RUnlock()

	// No match: build a new fasthttp.FS handler and remember it.
	// NOTE(review): two concurrent misses for the same config can both append
	// a store; that only duplicates an entry, the first match is used later.
	if fsHandler == nil {
		fasthttpFS := &fasthttp.FS{
			Root:                   "",
			FS:                     cfg.FS,
			AllowEmptyRoot:         true,
			GenerateIndexPages:     false,
			AcceptByteRange:        cfg.ByteRange,
			Compress:               cfg.Compress,
			CompressBrotli:         cfg.Compress,
			CompressZstd:           cfg.Compress,
			CompressedFileSuffixes: app.config.CompressedFileSuffixes,
			CacheDuration:          cfg.CacheDuration,
			SkipCache:              cfg.CacheDuration < 0,
			IndexNames:             []string{"index.html"},
			PathNotFound: func(ctx *fasthttp.RequestCtx) {
				ctx.Response.SetStatusCode(StatusNotFound)
			},
		}

		// With an fs.FS the request path is resolved relative to its root.
		if cfg.FS != nil {
			fasthttpFS.Root = "."
		}

		sf := &sendFileStore{
			config:  cfg,
			handler: fasthttpFS.NewRequestHandler(),
		}

		// Precompute the Cache-Control value; MaxAge <= 0 means no header.
		maxAge := cfg.MaxAge
		if maxAge > 0 {
			sf.cacheControlValue = "public, max-age=" + strconv.Itoa(maxAge)
		}

		// set vars
		fsHandler = sf.handler
		cacheControlValue = sf.cacheControlValue

		app.sendfilesMutex.Lock()
		app.sendfiles = append(app.sendfiles, sf)
		app.sendfilesMutex.Unlock()
	}

	// Keep original path for mutable params
	r.c.pathOriginal = utils.CopyString(r.c.pathOriginal)

	request := &r.c.fasthttp.Request

	// Delete the Accept-Encoding header if compression is disabled
	if !cfg.Compress {
		// https://github.com/valyala/fasthttp/blob/7cc6f4c513f9e0d3686142e0a1a5aa2f76b3194a/fs.go#L55
		request.Header.Del(HeaderAcceptEncoding)
	}

	// copy of https://github.com/valyala/fasthttp/blob/7cc6f4c513f9e0d3686142e0a1a5aa2f76b3194a/fs.go#L103-L121 with small adjustments
	if file == "" || (!filepath.IsAbs(file) && cfg.FS == nil) {
		// extend relative path to absolute path
		hasTrailingSlash := file != "" && (file[len(file)-1] == '/' || file[len(file)-1] == '\\')

		var err error
		file = filepath.FromSlash(file)
		if file, err = filepath.Abs(file); err != nil {
			return fmt.Errorf("failed to determine abs file path: %w", err)
		}
		// filepath.Abs cleans the path, so restore a trailing slash if there was one.
		if hasTrailingSlash {
			file += "/"
		}
	}

	// convert the path to forward slashes regardless the OS in order to set the URI properly
	// the handler will convert back to OS path separator before opening the file
	file = filepath.ToSlash(file)

	// Restore the original requested URL
	originalURL := utils.CopyString(r.c.OriginalURL())
	defer request.SetRequestURI(originalURL)

	// Set new URI for fileHandler
	request.SetRequestURI(file)

	// For range requests, look up the file size up front so a 416 response
	// can be given a "bytes */<size>" Content-Range below. Lookup errors are
	// ignored; they only mean that header is not added.
	var (
		sendFileSize    int64
		hasSendFileSize bool
	)
	if cfg.ByteRange && len(request.Header.Peek(HeaderRange)) > 0 {
		sizePath := file
		if cfg.FS != nil {
			sizePath = filepath.ToSlash(filename)
		}
		if size, err := sendFileContentLength(sizePath, cfg); err == nil {
			sendFileSize = size
			hasSendFileSize = true
		}
	}

	// Save status code
	response := &r.c.fasthttp.Response
	status := response.StatusCode()

	// Serve file
	fsHandler(r.c.fasthttp)

	// Sets the response Content-Disposition header to attachment if the Download option is true
	// (bare "attachment", without a filename parameter).
	if cfg.Download {
		r.Attachment()
	}

	// Get the status code which is set by fasthttp
	fsStatus := response.StatusCode()

	// Check for error: the file handler produced a 404 the caller did not set itself.
	if status != StatusNotFound && fsStatus == StatusNotFound {
		return NewError(StatusNotFound, fmt.Sprintf("sendfile: file %s not found", filename))
	}

	// Set the status code set by the user if it is different from the fasthttp status code and 200
	if status != fsStatus && status != StatusOK {
		r.Status(status)
	}

	// Apply cache control header
	// NOTE(review): this gate checks the status captured before serving
	// (status), not the one produced by the file handler (fsStatus) — confirm intent.
	if status != StatusNotFound && status != StatusForbidden {
		if cfg.ByteRange && hasSendFileSize && response.StatusCode() == StatusRequestedRangeNotSatisfiable && len(response.Header.Peek(HeaderContentRange)) == 0 {
			response.Header.Set(HeaderContentRange, "bytes */"+strconv.FormatInt(sendFileSize, 10))
		}
		if cacheControlValue != "" {
			response.Header.Set(HeaderCacheControl, cacheControlValue)
		}
		return nil
	}

	return nil
}

// sendFileContentLength returns the size in bytes of the file at path. With
// cfg.FS set, path is made relative (leading '/' stripped, cleaned) and looked
// up via fs.Stat; otherwise it is converted to the OS separator and passed to os.Stat.
// NOTE(review): "." is mapped to "", which io/fs does not treat as a valid
// path, so stat of the FS root presumably fails; SendFile ignores the error.
func sendFileContentLength(path string, cfg SendFile) (int64, error) {
	if cfg.FS != nil {
		cleanPath := pathpkg.Clean(utils.TrimLeft(path, '/'))
		if cleanPath == "." {
			cleanPath = ""
		}
		info, err := fs.Stat(cfg.FS, cleanPath)
		if err != nil {
			return 0, fmt.Errorf("stat %q: %w", cleanPath, err)
		}
		return info.Size(), nil
	}

	info, err := os.Stat(filepath.FromSlash(path))
	if err != nil {
		return 0, fmt.Errorf("stat %q: %w", path, err)
	}
	return info.Size(), nil
}

// SendStatus sets the HTTP status code and if the response body is empty,
// it sets the correct status message in the body.
func (r *DefaultRes) SendStatus(status int) error { r.Status(status) if statusDisallowsBody(status) { r.c.fasthttp.Response.ResetBody() return nil } // Only set status body when there is no response body if len(r.c.fasthttp.Response.Body()) == 0 { return r.SendString(utils.StatusMessage(status)) } return nil } // SendString sets the HTTP response body for string types. // This means no type assertion, recommended for faster performance func (r *DefaultRes) SendString(body string) error { r.c.fasthttp.Response.SetBodyString(body) return nil } // SendStream sets response body stream and optional body size. func (r *DefaultRes) SendStream(stream io.Reader, size ...int) error { if len(size) > 0 && size[0] >= 0 { r.c.fasthttp.Response.SetBodyStream(stream, size[0]) } else { r.c.fasthttp.Response.SetBodyStream(stream, -1) } return nil } // SendStreamWriter sets response body stream writer func (r *DefaultRes) SendStreamWriter(streamWriter func(*bufio.Writer)) error { r.c.fasthttp.Response.SetBodyStreamWriter(fasthttp.StreamWriter(streamWriter)) return nil } // Set sets the response's HTTP header field to the specified key, value. func (r *DefaultRes) Set(key, val string) { r.c.fasthttp.Response.Header.Set(key, val) } func (r *DefaultRes) setCanonical(key, val string) { r.c.fasthttp.Response.Header.SetCanonical(utils.UnsafeBytes(key), utils.UnsafeBytes(val)) } // Status sets the HTTP status for the response. // This method is chainable. func (r *DefaultRes) Status(status int) Ctx { r.c.fasthttp.Response.SetStatusCode(status) return r.c } func statusDisallowsBody(status int) bool { // As per RFC 9110, 1xx (Informational) responses cannot have a body. if status >= 100 && status < 200 { return true } switch status { case StatusNoContent, StatusResetContent, StatusNotModified: return true default: return false } } // Type sets the Content-Type HTTP header to the MIME type specified by the file extension. 
func (r *DefaultRes) Type(extension string, charset ...string) Ctx { mimeType := utils.GetMIME(extension) if len(charset) > 0 { r.c.fasthttp.Response.Header.SetContentType(mimeType + "; charset=" + charset[0]) } else { // Automatically add UTF-8 charset for text-based MIME types if shouldIncludeCharset(mimeType) { r.c.fasthttp.Response.Header.SetContentType(mimeType + "; charset=utf-8") } else { r.c.fasthttp.Response.Header.SetContentType(mimeType) } } return r.c } // shouldIncludeCharset determines if a MIME type should include UTF-8 charset by default func shouldIncludeCharset(mimeType string) bool { // Everything under text/ gets UTF-8 by default. if strings.HasPrefix(mimeType, "text/") { return true } // Explicit application types that should default to UTF-8. switch mimeType { case MIMEApplicationJSON, MIMEApplicationJavaScript, MIMEApplicationXML: return true } // Any application/*+json or application/*+xml. if strings.HasSuffix(mimeType, "+json") || strings.HasSuffix(mimeType, "+xml") { return true } return false } // Vary adds the given header field to the Vary response header. // This will append the header, if not already listed; otherwise, leaves it listed in the current location. func (r *DefaultRes) Vary(fields ...string) { r.Append(HeaderVary, fields...) } // Write appends p into response body. func (r *DefaultRes) Write(p []byte) (int, error) { r.c.fasthttp.Response.AppendBody(p) return len(p), nil } // Writef appends f & a into response body writer. func (r *DefaultRes) Writef(f string, a ...any) (int, error) { //nolint:wrapcheck // This must not be wrapped return fmt.Fprintf(r.c.fasthttp.Response.BodyWriter(), f, a...) } // WriteString appends s to response body. 
func (r *DefaultRes) WriteString(s string) (int, error) { r.c.fasthttp.Response.AppendBodyString(s) return len(s), nil } // Release is a method to reset Res fields when to use ReleaseCtx() func (r *DefaultRes) release() { r.c = nil } // Drop closes the underlying connection without sending any response headers or body. // This can be useful for silently terminating client connections, such as in DDoS mitigation // or when blocking access to sensitive endpoints. func (r *DefaultRes) Drop() error { //nolint:wrapcheck // error wrapping is avoided to keep the operation lightweight and focused on connection closure. return r.c.fasthttp.Conn().Close() } // End immediately flushes the current response and closes the underlying connection. // // Note: End does not work when using streaming (e.g. fasthttp's HijackConn or SendStream), // because in streaming mode the connection is managed asynchronously and ctx.Conn() may return nil. func (r *DefaultRes) End() error { ctx := r.c.fasthttp if ctx == nil { return nil } conn := ctx.Conn() if conn == nil { return nil } bw := bufio.NewWriter(conn) if err := ctx.Response.Write(bw); err != nil { return err } if err := bw.Flush(); err != nil { return err //nolint:wrapcheck // unnecessary to wrap it } return conn.Close() //nolint:wrapcheck // unnecessary to wrap it } ================================================ FILE: res_interface_gen.go ================================================ // Code generated by ifacemaker; DO NOT EDIT. package fiber import ( "bufio" "io" "github.com/valyala/fasthttp" ) // Res is an interface for response-related Ctx methods. type Res interface { // App returns the *App reference to the instance of the Fiber application App() *App // Append the specified value to the HTTP response header field. // If the header is not already set, it creates the header with the specified value. Append(field string, values ...string) // Attachment sets the HTTP response Content-Disposition header field to attachment. 
Attachment(filename ...string) // ClearCookie expires a specific cookie by key on the client side. // If no key is provided it expires all cookies that came with the request. ClearCookie(key ...string) // RequestCtx returns *fasthttp.RequestCtx that carries a deadline // a cancellation signal, and other values across API boundaries. RequestCtx() *fasthttp.RequestCtx // Cookie sets a cookie by passing a cookie struct. Cookie(cookie *Cookie) // Download transfers the file from path as an attachment. // Typically, browsers will prompt the user for download. // By default, the Content-Disposition header filename= parameter is the filepath (this typically appears in the browser dialog). // Override this default with the filename parameter. Download(file string, filename ...string) error // Response return the *fasthttp.Response object // This allows you to use all fasthttp response methods // https://godoc.org/github.com/valyala/fasthttp#Response Response() *fasthttp.Response // Format performs content-negotiation on the Accept HTTP header. // It uses Accepts to select a proper format and calls the matching // user-provided handler function. // If no accepted format is found, and a format with MediaType "default" is given, // that default handler is called. If no format is found and no default is given, // StatusNotAcceptable is sent. Format(handlers ...ResFmt) error // AutoFormat performs content-negotiation on the Accept HTTP header. // It uses Accepts to select a proper format. // The supported content types are text/html, text/plain, application/json, application/xml, application/vnd.msgpack, and application/cbor. // For more flexible content negotiation, use Format. // If the header is not specified or there is no proper format, text/plain is used. AutoFormat(body any) error // Get (a.k.a. GetRespHeader) returns the HTTP response header specified by field. // Field names are case-insensitive // Returned value is only valid within the handler. 
Do not store any references. // Make copies or use the Immutable setting instead. Get(key string, defaultValue ...string) string // GetHeaders (a.k.a GetRespHeaders) returns the HTTP response headers. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. GetHeaders() map[string][]string // JSON converts any interface or string to JSON. // Array and slice values encode as JSON arrays, // except that []byte encodes as a base64-encoded string, // and a nil slice encodes as the null JSON value. // If the ctype parameter is given, this method will set the // Content-Type header equal to ctype. If ctype is not given, // The Content-Type header will be set to application/json; charset=utf-8. JSON(data any, ctype ...string) error // MsgPack converts any interface or string to MessagePack encoded bytes. // If the ctype parameter is given, this method will set the // Content-Type header equal to ctype. If ctype is not given, // The Content-Type header will be set to application/vnd.msgpack. MsgPack(data any, ctype ...string) error // CBOR converts any interface or string to CBOR encoded bytes. // If the ctype parameter is given, this method will set the // Content-Type header equal to ctype. If ctype is not given, // The Content-Type header will be set to application/cbor. CBOR(data any, ctype ...string) error // JSONP sends a JSON response with JSONP support. // This method is identical to JSON, except that it opts-in to JSONP callback support. // By default, the callback name is simply callback. JSONP(data any, callback ...string) error // XML converts any interface or string to XML. // This method also sets the content header to application/xml; charset=utf-8. XML(data any) error // Links joins the links followed by the property to populate the response's Link HTTP header field. Links(link ...string) // Location sets the response Location HTTP header to the specified path parameter. 
Location(path string) // OriginalURL contains the original request URL. // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting to use the value outside the Handler. OriginalURL() string // Redirect returns the Redirect reference. // Use Redirect().Status() to set custom redirection status code. // If status is not specified, status defaults to 303 See Other. // You can use Redirect().To(), Redirect().Route() and Redirect().Back() for redirection. Redirect() *Redirect // ViewBind Add vars to default view var map binding to template engine. // Variables are read by the Render method and may be overwritten. ViewBind(vars Map) error // getLocationFromRoute get URL location from route using parameters getLocationFromRoute(route *Route, params Map) (string, error) // GetRouteURL generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831" GetRouteURL(routeName string, params Map) (string, error) // Render a template with data and sends a text/html response. // We support the following engines: https://github.com/gofiber/template Render(name string, bind any, layouts ...string) error renderExtensions(bind any) // Send sets the HTTP response body without copying it. // From this point onward the body argument must not be changed. Send(body []byte) error // SendEarlyHints allows the server to hint to the browser what resources a page would need // so the browser can preload them while waiting for the server's full response. Only Link // headers already written to the response will be transmitted as Early Hints. // // This is a HTTP/2+ feature but all browsers will either understand it or safely ignore it. // // NOTE: Older HTTP/1.1 non-browser clients may face compatibility issues. 
// // See: https://developer.chrome.com/docs/web-platform/early-hints and // https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Link#syntax SendEarlyHints(hints []string) error // SendFile transfers the file from the specified path. // By default, the file is not compressed. To enable compression, set SendFile.Compress to true. // The Content-Type response HTTP header field is set based on the file's extension. // If the file extension is missing or invalid, the Content-Type is detected from the file's format. SendFile(file string, config ...SendFile) error // SendStatus sets the HTTP status code and if the response body is empty, // it sets the correct status message in the body. SendStatus(status int) error // SendString sets the HTTP response body for string types. // This means no type assertion, recommended for faster performance SendString(body string) error // SendStream sets response body stream and optional body size. SendStream(stream io.Reader, size ...int) error // SendStreamWriter sets response body stream writer SendStreamWriter(streamWriter func(*bufio.Writer)) error // Set sets the response's HTTP header field to the specified key, value. Set(key, val string) setCanonical(key, val string) // Status sets the HTTP status for the response. // This method is chainable. Status(status int) Ctx // Type sets the Content-Type HTTP header to the MIME type specified by the file extension. Type(extension string, charset ...string) Ctx // Vary adds the given header field to the Vary response header. // This will append the header, if not already listed; otherwise, leaves it listed in the current location. Vary(fields ...string) // Write appends p into response body. Write(p []byte) (int, error) // Writef appends f & a into response body writer. Writef(f string, a ...any) (int, error) // WriteString appends s to response body. 
WriteString(s string) (int, error) // Release is a method to reset Res fields when to use ReleaseCtx() release() // Drop closes the underlying connection without sending any response headers or body. // This can be useful for silently terminating client connections, such as in DDoS mitigation // or when blocking access to sensitive endpoints. Drop() error // End immediately flushes the current response and closes the underlying connection. End() error } ================================================ FILE: router.go ================================================ // ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ // 🤖 GitHub Repository: https://github.com/gofiber/fiber // 📌 API Documentation: https://docs.gofiber.io package fiber import ( "fmt" "slices" "sync/atomic" "github.com/gofiber/utils/v2" utilsstrings "github.com/gofiber/utils/v2/strings" "github.com/valyala/fasthttp" ) // Router defines all router handle interface, including app and group router. type Router interface { Use(args ...any) Router Get(path string, handler any, handlers ...any) Router Head(path string, handler any, handlers ...any) Router Post(path string, handler any, handlers ...any) Router Put(path string, handler any, handlers ...any) Router Delete(path string, handler any, handlers ...any) Router Connect(path string, handler any, handlers ...any) Router Options(path string, handler any, handlers ...any) Router Trace(path string, handler any, handlers ...any) Router Patch(path string, handler any, handlers ...any) Router Add(methods []string, path string, handler any, handlers ...any) Router All(path string, handler any, handlers ...any) Router Group(prefix string, handlers ...any) Router RouteChain(path string) Register Route(prefix string, fn func(router Router), name ...string) Router Name(name string) Router } // Route is a struct that holds all metadata for each registered handler. 
type Route struct {
	// ### important: always keep in sync with the copy method "app.copyRoute" and all creations of Route struct ###

	group *Group // Group instance. used for routes in groups

	path string // Prettified (normalized) path used for matching

	// Public fields
	Method string `json:"method"` // HTTP method
	Name   string `json:"name"`   // Route's name
	//nolint:revive // Having both a Path (uppercase) and a path (lowercase) is fine
	Path        string      `json:"path"`   // Original registered route path
	Params      []string    `json:"params"` // Case-sensitive param keys
	Handlers    []Handler   `json:"-"`      // Ctx handlers
	routeParser routeParser // Parameter parser

	// Data for routing
	use      bool // USE matches path prefixes
	mount    bool // Indicates a mounted app on a specific route
	star     bool // Normalized path equals '/*' (catch-all wildcard)
	root     bool // Normalized path equals '/'
	autoHead bool // Automatically generated HEAD route
}

// match reports whether r matches the request. detectionPath is compared
// against the route's normalized path; path is used to fill params (for the
// '/*' wildcard, params[0] receives path without its leading '/').
// NOTE(review): detectionPath is presumably the case/slash-normalized form of
// path (see DefaultCtx) — confirm.
//
// Checks run in this order: root shortcut, catch-all wildcard, parameterized
// match via routeParser, then prefix match for USE routes or exact match for
// all others.
func (r *Route) match(detectionPath, path string, params *[maxParams]string) bool {
	// root detectionPath check
	if r.root && len(detectionPath) == 1 && detectionPath[0] == '/' {
		return true
	}

	// '*' wildcard matches any detectionPath
	if r.star {
		if len(path) > 1 {
			params[0] = path[1:]
		} else {
			params[0] = ""
		}
		return true
	}

	// Does this route have parameters?
	if len(r.Params) > 0 {
		// Match params using precomputed routeParser
		if r.routeParser.getMatch(detectionPath, path, params, r.use) {
			return true
		}
	}

	// Middleware route?
	if r.use {
		// Single slash or prefix match
		plen := len(r.path)
		if r.root {
			// If r.root is '/', it matches everything starting at '/'
			if detectionPath != "" && detectionPath[0] == '/' {
				return true
			}
		} else if len(detectionPath) >= plen && detectionPath[:plen] == r.path {
			// Prefix matched; hasPartialMatchBoundary presumably ensures the
			// prefix ends on a path-segment boundary (TODO confirm).
			if hasPartialMatchBoundary(detectionPath, plen) {
				return true
			}
		}
	} else if len(r.path) == len(detectionPath) && detectionPath == r.path {
		// Check exact match
		return true
	}

	// No match
	return false
}

// next advances c to the next route in its method's route tree that matches
// the request path, starting after c.indexRoute, and runs that route's first
// handler. Mounted routes are skipped, and non-USE routes are skipped while
// c.skipNonUseRoutes is set. It returns true plus the handler's error when a
// route matched (or true, nil for a matched route without handlers).
//
// When nothing further matches it returns false and:
//   - ErrNotFound if a non-USE route already matched earlier in the chain;
//   - ErrMethodNotAllowed if the path matches a non-USE route registered for
//     another request method (each such method is appended to the Allow header);
//   - ErrNotFound otherwise.
//
// The tree is the bucket for c.treePathHash; when no such bucket exists it
// falls back to bucket 0, which holds only routes without a constant-prefix
// hash (see buildTree). Note that the 405 scan calls route.match, which may
// overwrite c.values.
func (app *App) next(c *DefaultCtx) (bool, error) {
	methodInt := c.methodInt
	treeHash := c.treePathHash
	// Get stack length
	tree, ok := app.treeStack[methodInt][treeHash]
	if !ok {
		tree = app.treeStack[methodInt][0]
	}
	lenr := len(tree) - 1

	indexRoute := c.indexRoute

	// Loop over the route stack starting from previous index
	for indexRoute < lenr {
		// Increment route index
		indexRoute++

		// Get *Route
		route := tree[indexRoute]

		if route.mount {
			continue
		}

		// Check if it matches the request path
		if !route.match(utils.UnsafeString(c.detectionPath), utils.UnsafeString(c.path), &c.values) {
			continue
		}

		if c.skipNonUseRoutes && !route.use {
			continue
		}

		// Pass route reference and param values
		c.route = route

		// Non use handler matched
		if !route.use {
			c.matched = true
		}

		// Execute first handler of route
		if len(route.Handlers) > 0 {
			c.indexHandler = 0
			c.indexRoute = indexRoute
			return true, route.Handlers[0](c)
		}

		return true, nil // Stop scanning the stack
	}

	// If c.Next() does not match, return 404
	// If no match, scan stack again if other methods match the request
	// Moved from app.handler because middleware may break the route chain
	if c.matched {
		return false, ErrNotFound
	}

	exists := false
	methods := app.config.RequestMethods
	for i := range methods {
		// Skip original method
		if methodInt == i {
			continue
		}
		// Reset stack index
		indexRoute := -1
		tree, ok := app.treeStack[i][treeHash]
		if !ok {
			tree = app.treeStack[i][0]
		}
		// Get stack length
		lenr := len(tree) - 1
		// Scan this method's whole tree from the beginning
		for indexRoute < lenr {
			// Increment route index
			indexRoute++
			// Get *Route
			route := tree[indexRoute]
			// Skip use routes
			if route.use {
				continue
			}
			// Check if it matches the request path
			// No match, next route
			if route.match(utils.UnsafeString(c.detectionPath), utils.UnsafeString(c.path), &c.values) {
				// We matched
				exists = true
				// Add method to Allow header
				c.Append(HeaderAllow, methods[i])
				// Break stack loop
				break
			}
		}
		c.indexRoute = indexRoute
	}
	if exists {
		return false, ErrMethodNotAllowed
	}
	return false, ErrNotFound
}

// nextCustom is the CustomCtx counterpart of next: identical routing logic,
// but all context state is read and written through the CustomCtx accessor
// methods instead of DefaultCtx fields.
func (app *App) nextCustom(c CustomCtx) (bool, error) {
	methodInt := c.getMethodInt()
	treeHash := c.getTreePathHash()
	// Get stack length
	tree, ok := app.treeStack[methodInt][treeHash]
	if !ok {
		tree = app.treeStack[methodInt][0]
	}
	lenr := len(tree) - 1

	indexRoute := c.getIndexRoute()

	// Loop over the route stack starting from previous index
	for indexRoute < lenr {
		// Increment route index
		indexRoute++

		// Get *Route
		route := tree[indexRoute]

		if route.mount {
			continue
		}

		// Check if it matches the request path
		if !route.match(c.getDetectionPath(), c.Path(), c.getValues()) {
			continue
		}

		if c.getSkipNonUseRoutes() && !route.use {
			continue
		}

		// Pass route reference and param values
		c.setRoute(route)

		// Non use handler matched
		if !route.use {
			c.setMatched(true)
		}

		// Execute first handler of route
		if len(route.Handlers) > 0 {
			c.setIndexHandler(0)
			c.setIndexRoute(indexRoute)
			return true, route.Handlers[0](c)
		}

		return true, nil // Stop scanning the stack
	}

	// If c.Next() does not match, return 404
	// If no match, scan stack again if other methods match the request
	// Moved from app.handler because middleware may break the route chain
	if c.getMatched() {
		return false, ErrNotFound
	}

	exists := false
	methods := app.config.RequestMethods
	for i := range methods {
		// Skip original method
		if methodInt == i {
			continue
		}
		// Reset stack index
		indexRoute := -1
		tree, ok := app.treeStack[i][treeHash]
		if !ok {
			tree = app.treeStack[i][0]
		}
		// Get stack length
		lenr := len(tree) - 1
		// Scan this method's whole tree from the beginning
		for indexRoute < lenr {
			// Increment route index
			indexRoute++
			// Get *Route
			route := tree[indexRoute]
			// Skip use routes
			if route.use {
				continue
			}
			// Check if it matches the request path
			// No match, next route
			if route.match(c.getDetectionPath(), c.Path(), c.getValues()) {
				// We matched
				exists = true
				// Add method to Allow header
				c.Append(HeaderAllow, methods[i])
				// Break stack loop
				break
			}
		}
		c.setIndexRoute(indexRoute)
	}
	if exists {
		return false, ErrMethodNotAllowed
	}
	return false, ErrNotFound
}

// requestHandler handles one fasthttp request (presumably installed as the
// server's handler — confirm in app setup). It acquires a Ctx from the pool
// (released on return), answers 501 when the method is unknown to the app
// (methodInt == -1), consumes flash-message cookies when present, and
// dispatches through next (DefaultCtx) or nextCustom (custom contexts).
// A routing/handler error is passed to the app's ErrorHandler; if that handler
// itself returns an error, a bare 500 is sent.
func (app *App) requestHandler(rctx *fasthttp.RequestCtx) {
	// Acquire context from the pool
	ctx := app.AcquireCtx(rctx)
	defer app.ReleaseCtx(ctx)

	var err error
	// Attempt to match a route and execute the chain
	if d, isDefault := ctx.(*DefaultCtx); isDefault {
		// Check if the HTTP method is valid
		if d.methodInt == -1 {
			_ = d.SendStatus(StatusNotImplemented) //nolint:errcheck // Always return nil
			return
		}

		// Optional: check flash messages (hot path, see hasFlashCookie).
		if hasFlashCookie(&d.Request().Header) {
			d.Redirect().parseAndClearFlashMessages()
		}
		_, err = app.next(d)
	} else {
		// Check if the HTTP method is valid
		if ctx.getMethodInt() == -1 {
			_ = ctx.SendStatus(StatusNotImplemented) //nolint:errcheck // Always return nil
			return
		}

		// Optional: check flash messages (hot path, see hasFlashCookie).
		if hasFlashCookie(&ctx.Request().Header) {
			ctx.Redirect().parseAndClearFlashMessages()
		}
		_, err = app.nextCustom(ctx)
	}

	if err != nil {
		if catch := ctx.App().ErrorHandler(ctx, err); catch != nil {
			_ = ctx.SendStatus(StatusInternalServerError) //nolint:errcheck // Always return nil
		}
		return
	}
}

// addPrefixToRoute rewrites route in place so it lives under prefix and
// returns the same pointer. Path receives the joined original path, while
// path/routeParser are rebuilt from the normalized form. The root and star
// flags are always cleared.
// NOTE(review): Params is not re-parsed from the prefixed path — confirm
// prefixes passed here never introduce route parameters.
func (app *App) addPrefixToRoute(prefix string, route *Route) *Route {
	prefixedPath := getGroupPath(prefix, route.Path)
	prettyPath := prefixedPath
	// Case-insensitive routing: lower-case the path used for matching
	if !app.config.CaseSensitive {
		prettyPath = utilsstrings.ToLower(prettyPath)
	}
	// Non-strict routing: remove trailing slashes
	if !app.config.StrictRouting && len(prettyPath) > 1 {
		prettyPath = utils.TrimRight(prettyPath, '/')
	}

	route.Path = prefixedPath
	route.path = RemoveEscapeChar(prettyPath)
	route.routeParser = parseRoute(prettyPath, app.customConstraints...)
	route.root = false
	route.star = false

	return route
}

// copyRoute returns a shallow copy of route: the Handlers and Params slices
// (and routeParser's internals) share storage with the original. The group
// field is NOT copied; callers that need it must set it themselves.
func (*App) copyRoute(route *Route) *Route {
	return &Route{
		// Router booleans
		use:      route.use,
		mount:    route.mount,
		star:     route.star,
		root:     route.root,
		autoHead: route.autoHead,

		// Path data
		path:        route.path,
		routeParser: route.routeParser,

		// Public data
		Path:     route.Path,
		Params:   route.Params,
		Name:     route.Name,
		Method:   route.Method,
		Handlers: route.Handlers,
	}
}

// normalizePath applies the same normalization register uses for matching:
// empty becomes "/", a leading '/' is ensured, the path is lower-cased unless
// CaseSensitive is set, trailing slashes are trimmed (except for "/") unless
// StrictRouting is set, and escape characters are removed.
func (app *App) normalizePath(path string) string {
	if path == "" {
		path = "/"
	}
	if path[0] != '/' {
		path = "/" + path
	}
	if !app.config.CaseSensitive {
		path = utilsstrings.ToLower(path)
	}
	if !app.config.StrictRouting && len(path) > 1 {
		path = utils.TrimRight(path, '/')
	}
	return RemoveEscapeChar(path)
}

// RemoveRoute is used to remove a route from the stack by path.
// If no methods are specified, it will remove the route for all methods defined in the app.
// You should call RebuildTree after using this to ensure consistency of the tree.
func (app *App) RemoveRoute(path string, methods ...string) { // Normalize same as register uses norm := app.normalizePath(path) pathMatchFunc := func(r *Route) bool { return r.path == norm // compare private normalized path } app.deleteRoute(methods, pathMatchFunc) } // RemoveRouteByName is used to remove a route from the stack by name. // If no methods are specified, it will remove the route for all methods defined in the app. // You should call RebuildTree after using this to ensure consistency of the tree. func (app *App) RemoveRouteByName(name string, methods ...string) { matchFunc := func(r *Route) bool { return r.Name == name } app.deleteRoute(methods, matchFunc) } // RemoveRouteFunc is used to remove a route from the stack by a custom match function. // If no methods are specified, it will remove the route for all methods defined in the app. // You should call RebuildTree after using this to ensure consistency of the tree. // Note: The route.Path is original path, not the normalized path. func (app *App) RemoveRouteFunc(matchFunc func(r *Route) bool, methods ...string) { app.deleteRoute(methods, matchFunc) } func (app *App) deleteRoute(methods []string, matchFunc func(r *Route) bool) { if len(methods) == 0 { methods = app.config.RequestMethods } app.mutex.Lock() defer app.mutex.Unlock() removedUseRoutes := make(map[string]struct{}) for _, method := range methods { // Uppercase HTTP methods method = utilsstrings.ToUpper(method) // Get unique HTTP method identifier m := app.methodInt(method) if m == -1 { continue // Skip invalid HTTP methods } for i := len(app.stack[m]) - 1; i >= 0; i-- { route := app.stack[m][i] if !matchFunc(route) { continue // Skip if route does not match } app.stack[m] = append(app.stack[m][:i], app.stack[m][i+1:]...) app.routesRefreshed = true // Decrement global handler count. 
In middleware routes, only decrement once if _, ok := removedUseRoutes[route.path]; (route.use && slices.Equal(methods, app.config.RequestMethods) && !ok) || !route.use { if route.use { removedUseRoutes[route.path] = struct{}{} } atomic.AddUint32(&app.handlersCount, ^uint32(len(route.Handlers)-1)) //nolint:gosec // G115 - handler count is always small } if method == MethodGet && !route.use && !route.mount { app.pruneAutoHeadRouteLocked(route.path) } } } } // pruneAutoHeadRouteLocked removes an automatically generated HEAD route so a // later explicit registration can take its place without duplicating handler // chains. The caller must already hold app.mutex. func (app *App) pruneAutoHeadRouteLocked(path string) { headIndex := app.methodInt(MethodHead) if headIndex == -1 { return } norm := app.normalizePath(path) headStack := app.stack[headIndex] for i := len(headStack) - 1; i >= 0; i-- { headRoute := headStack[i] if headRoute.path != norm || headRoute.mount || headRoute.use || !headRoute.autoHead { continue } app.stack[headIndex] = append(headStack[:i], headStack[i+1:]...) 
app.routesRefreshed = true atomic.AddUint32(&app.handlersCount, ^uint32(len(headRoute.Handlers)-1)) //nolint:gosec // G115 - handler count is always small return } } func (app *App) register(methods []string, pathRaw string, group *Group, handlers ...Handler) { // A regular route requires at least one ctx handler if len(handlers) == 0 && group == nil { panic(fmt.Sprintf("missing handler/middleware in route: %s\n", pathRaw)) } // No nil handlers allowed for _, h := range handlers { if h == nil { panic(fmt.Sprintf("nil handler in route: %s\n", pathRaw)) } } // Precompute path normalization ONCE if pathRaw == "" { pathRaw = "/" } if pathRaw[0] != '/' { pathRaw = "/" + pathRaw } pathPretty := pathRaw if !app.config.CaseSensitive { pathPretty = utilsstrings.ToLower(pathPretty) } if !app.config.StrictRouting && len(pathPretty) > 1 { pathPretty = utils.TrimRight(pathPretty, '/') } pathClean := RemoveEscapeChar(pathPretty) parsedRaw := parseRoute(pathRaw, app.customConstraints...) parsedPretty := parseRoute(pathPretty, app.customConstraints...) 
isMount := group != nil && group.app != app
for _, method := range methods {
	method = utilsstrings.ToUpper(method)
	if method != methodUse && app.methodInt(method) == -1 {
		panic(fmt.Sprintf("add: invalid http method %s\n", method))
	}
	isUse := method == methodUse
	isStar := pathClean == "/*"
	isRoot := pathClean == "/"

	route := Route{
		use:   isUse,
		mount: isMount,
		star:  isStar,
		root:  isRoot,

		path:        pathClean,
		routeParser: parsedPretty,
		Params:      parsedRaw.params,
		group:       group,

		Path:     pathRaw,
		Method:   method,
		Handlers: handlers,
	}

	// Increment global handler count
	atomic.AddUint32(&app.handlersCount, uint32(len(handlers))) //nolint:gosec // G115 - handler count is always small

	// Middleware route matches all HTTP methods
	if isUse {
		// Add route to all HTTP methods stack
		for _, m := range app.config.RequestMethods {
			// Create a route copy to avoid duplicates during compression
			r := route
			app.addRoute(m, &r)
		}
	} else {
		// Add route to stack
		app.addRoute(method, &route)
	}
}
}

// addRoute appends route to the stack of the given HTTP method while holding
// app.mutex for the whole call.
//
// If the most recently registered route of this method has the same raw Path
// and the same `use` flag, and neither of the two is a mount, the new handlers
// are appended to that existing route instead of creating a second stack
// entry. Otherwise route.Method is set to method, the route is appended and
// routesRefreshed is set so that the next buildTree call regenerates the tree.
//
// Unless route is a mount, it becomes app.latestRoute and the OnRoute hooks
// run for it; a hook error panics (the deferred Unlock still releases the lock).
func (app *App) addRoute(method string, route *Route) {
	app.mutex.Lock()
	defer app.mutex.Unlock()

	// Get unique HTTP method identifier
	m := app.methodInt(method)

	// An explicit (non-mount, non-Use) HEAD route is registered: let the
	// helper drop any HEAD route for the same cleaned path that was generated
	// automatically. NOTE(review): inferred from the helper's name — confirm
	// in pruneAutoHeadRouteLocked.
	if method == MethodHead && !route.mount && !route.use {
		app.pruneAutoHeadRouteLocked(route.path)
	}

	// Prevent registering the same route twice in a row: merge its handlers
	// into the previous stack entry instead of adding a new one.
	l := len(app.stack[m])
	if l > 0 && app.stack[m][l-1].Path == route.Path && route.use == app.stack[m][l-1].use && !route.mount && !app.stack[m][l-1].mount {
		preRoute := app.stack[m][l-1]
		preRoute.Handlers = append(preRoute.Handlers, route.Handlers...)
	} else {
		route.Method = method
		// Add route to the stack
		app.stack[m] = append(app.stack[m], route)
		app.routesRefreshed = true
	}

	// Execute onRoute hooks & change latestRoute if not adding mounted route.
	// NOTE(review): in the merge branch above, route itself is not stored in
	// the stack, so latestRoute (and the hook argument) refer to a Route value
	// the router never serves.
	if !route.mount {
		app.latestRoute = route
		if err := app.hooks.executeOnRouteHooks(route); err != nil {
			panic(err)
		}
	}
}

// ensureAutoHeadRoutes acquires app.mutex and then registers automatic HEAD
// routes for GET routes that lack one (see ensureAutoHeadRoutesLocked).
func (app *App) ensureAutoHeadRoutes() {
	app.mutex.Lock()
	defer app.mutex.Unlock()

	app.ensureAutoHeadRoutesLocked()
}

// ensureAutoHeadRoutesLocked mirrors every plain GET route (neither a mount
// nor a Use middleware) as a HEAD route, unless a plain HEAD route with the
// same cleaned path already exists. It does not lock app.mutex itself; the
// caller is expected to hold it, as ensureAutoHeadRoutes does.
//
// It is a no-op when Config.DisableHeadAutoRegister is set, or when GET or
// HEAD is not among the configured request methods. Each generated route is a
// copy of the GET route (via copyRoute) flagged autoHead; it is counted in
// app.handlersCount, becomes app.latestRoute and triggers the OnRoute hooks.
// A hook error panics.
func (app *App) ensureAutoHeadRoutesLocked() {
	if app.config.DisableHeadAutoRegister {
		return
	}

	headIndex := app.methodInt(MethodHead)
	getIndex := app.methodInt(MethodGet)
	if headIndex == -1 || getIndex == -1 {
		return
	}

	// Collect the cleaned paths that already have a plain HEAD route, whether
	// it was registered explicitly or generated by an earlier call.
	headStack := app.stack[headIndex]
	existing := make(map[string]struct{}, len(headStack))
	for _, route := range headStack {
		if route.mount || route.use {
			continue
		}
		existing[route.path] = struct{}{}
	}

	if len(app.stack[getIndex]) == 0 {
		return
	}

	var added bool
	for _, route := range app.stack[getIndex] {
		if route.mount || route.use {
			continue
		}
		if _, ok := existing[route.path]; ok {
			continue
		}

		headRoute := app.copyRoute(route)
		headRoute.group = route.group
		headRoute.Method = MethodHead
		headRoute.autoHead = true
		// Fasthttp automatically omits response bodies when transmitting
		// HEAD responses, so the copied GET handler stack can execute
		// unchanged while still producing an empty body on the wire.

		headStack = append(headStack, headRoute)
		// Record the path so a second GET entry for it yields no extra HEAD route.
		existing[route.path] = struct{}{}
		app.routesRefreshed = true
		added = true

		atomic.AddUint32(&app.handlersCount, uint32(len(headRoute.Handlers))) //nolint:gosec // G115 - handler count is always small

		app.latestRoute = headRoute
		if err := app.hooks.executeOnRouteHooks(headRoute); err != nil {
			panic(err)
		}
	}

	// append may have reallocated headStack, so store it back when it grew.
	if added {
		app.stack[headIndex] = headStack
	}
}

// RebuildTree rebuilds the prefix tree from the previously registered routes.
// This method is useful when you want to register routes dynamically after the app has started.
// It is not recommended to use this method on production environments because rebuilding
// the tree is performance-intensive and not thread-safe in runtime. Since building the tree
// is only done in the startupProcess of the app, this method does not make sure that the
// routeTree is being safely changed, as it would add a great deal of overhead in the request.
// Latest benchmark results showed a degradation from 82.79 ns/op to 94.48 ns/op and can be found in:
// https://github.com/gofiber/fiber/issues/2769#issuecomment-2227385283
//
// RebuildTree holds app.mutex while rebuilding, which serializes it against
// route registration (addRoute takes the same lock). The tree is only rebuilt
// when routes changed since the last build.
func (app *App) RebuildTree() *App {
	app.mutex.Lock()
	defer app.mutex.Unlock()

	return app.buildTree()
}

// buildTree builds the per-method prefix tree (app.treeStack) from the routes
// registered in app.stack. It does nothing unless routesRefreshed is set, and
// clears that flag when done.
//
// Each route is keyed by the first three bytes of its leading constant path
// segment, packed into an int, when that segment is at least
// maxDetectionPaths bytes long. Routes without such a prefix get key 0. They
// go into bucket 0 and are also appended to every prefix bucket, presumably so
// that a lookup only has to scan the single bucket matching the request's
// prefix (NOTE(review): lookup code not shown here). Inside each bucket the
// routes keep their registration order.
func (app *App) buildTree() *App {
	// If routes haven't been refreshed, nothing to do
	if !app.routesRefreshed {
		return app
	}

	// 1) First loop: determine all possible 3-char prefixes ("treePaths") for each method
	// `method` is an index into app.stack / app.treeStack, not a method name.
	for method := range app.config.RequestMethods {
		routes := app.stack[method]
		treePaths := make([]int, len(routes))
		globalCount := 0
		prefixCounts := make(map[int]int, len(routes))

		for i, route := range routes {
			if len(route.routeParser.segs) > 0 && len(route.routeParser.segs[0].Const) >= maxDetectionPaths {
				treePaths[i] = int(route.routeParser.segs[0].Const[0])<<16 | int(route.routeParser.segs[0].Const[1])<<8 | int(route.routeParser.segs[0].Const[2])
			}
			if treePaths[i] == 0 {
				globalCount++
				continue
			}
			prefixCounts[treePaths[i]]++
		}

		// Pre-size every bucket for its own routes plus all global routes,
		// reusing the previous build's slices where their capacity suffices.
		prevBuckets := app.treeStack[method]
		tsMap := make(map[int][]*Route, len(prefixCounts)+1)
		tsMap[0] = reuseRouteBucket(prevBuckets, 0, globalCount)
		for treePath, count := range prefixCounts {
			tsMap[treePath] = reuseRouteBucket(prevBuckets, treePath, count+globalCount)
		}

		// Fill the buckets in registration order; global routes go into every bucket.
		for i, route := range routes {
			treePath := treePaths[i]
			if treePath == 0 {
				for bucket := range tsMap {
					tsMap[bucket] = append(tsMap[bucket], route)
				}
				continue
			}
			tsMap[treePath] = append(tsMap[treePath], route)
		}

		app.treeStack[method] = tsMap
	}

	// reset the flag and return
	app.routesRefreshed = false

	return app
}

// reuseRouteBucket returns an empty route slice for bucket key with capacity
// for at least capHint routes. If prev already holds a bucket for key with
// enough capacity, that bucket's backing array is reused (truncated to length
// 0) to avoid an allocation. Otherwise a new slice is allocated. prev may be
// nil.
//
// NOTE(review): because the backing array is reused, anything still holding
// the old bucket slice (for example a request iterating it while RebuildTree
// runs) observes its entries being overwritten. Pointers beyond the new length
// also stay reachable until they are overwritten.
func reuseRouteBucket(prev map[int][]*Route, key, capHint int) []*Route {
	if bucket, ok := prev[key]; ok && cap(bucket) >= capHint {
		return bucket[:0]
	}
	return make([]*Route, 0, capHint)
}

================================================
FILE: router_test.go
================================================
// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️
// 📃 GitHub Repository: https://github.com/gofiber/fiber
// 📌 API Documentation: https://docs.gofiber.io

package fiber

import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"testing"

	"github.com/gofiber/utils/v2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/valyala/fasthttp"
)

// routesFixture holds the route definitions that init decodes from
// ./.github/testdata/testRoutes.json.
var routesFixture routeJSON

// init loads the shared route fixture once for the package and panics if the
// file cannot be read or decoded.
func init() {
	dat, err := os.ReadFile("./.github/testdata/testRoutes.json")
	if err != nil {
		panic(err)
	}
	if err := json.Unmarshal(dat, &routesFixture); err != nil {
		panic(err)
	}
}

// Test_Route_Handler_Order verifies that the handlers of a single route run in
// the order in which they were passed to app.Get.
func Test_Route_Handler_Order(t *testing.T) {
	t.Parallel()
	app := New()

	var order []int

	handler1 := func(c Ctx) error {
		order = append(order, 1)
		return c.Next()
	}
	handler2 := func(c Ctx) error {
		order = append(order, 2)
		return c.Next()
	}
	handler3 := func(c Ctx) error {
		order = append(order, 3)
		return c.Next()
	}

	app.Get("/test", handler1, handler2, handler3, func(c Ctx) error {
		order = append(order, 4)
		return c.SendStatus(200)
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	expectedOrder := []int{1, 2, 3, 4}
	require.Equal(t, expectedOrder, order, "Handler order")
}

// Test_hasFlashCookieExactMatch checks which Cookie headers hasFlashCookie
// recognizes as carrying the flash cookie.
func Test_hasFlashCookieExactMatch(t *testing.T) {
	t.Parallel()

	// buildRequestWithCookie parses a raw HTTP/1.1 request that carries the
	// given Cookie header, so the header goes through fasthttp's parser.
	buildRequestWithCookie := func(t *testing.T, cookie string) *fasthttp.Request {
		t.Helper()

		rawRequest := strings.NewReader(
			"GET / HTTP/1.1\r\nHost: localhost\r\nCookie: " + cookie + "\r\n\r\n",
		)
		req :=
new(fasthttp.Request)
		require.NoError(t, req.Read(bufio.NewReader(rawRequest)))
		return req
	}

	// A cookie whose name merely starts with FlashCookieName must not match.
	req := buildRequestWithCookie(t, FlashCookieName+"X=not-the-flash-cookie")
	require.False(t, hasFlashCookie(&req.Header))

	req = buildRequestWithCookie(t, FlashCookieName+"=valid")
	require.True(t, hasFlashCookie(&req.Header))

	// A Cookie header set programmatically instead of being parsed from a raw
	// request is expected not to be detected.
	var syntheticReq fasthttp.Request
	syntheticReq.Header.Set(HeaderCookie, FlashCookieName+"=valid")
	require.False(t, hasFlashCookie(&syntheticReq.Header))
}

// Test_Route_MixedFiberAndHTTPHandlers registers a Fiber handler, a net/http
// handler and another Fiber handler on one route. It checks that the net/http
// handler writes the response and ends the chain, so the trailing Fiber
// handler never runs. A header set by the preceding Fiber handler is kept.
func Test_Route_MixedFiberAndHTTPHandlers(t *testing.T) {
	t.Parallel()

	app := New()
	var order []string

	fiberBefore := func(c Ctx) error {
		order = append(order, "fiber-before")
		c.Set("X-Fiber", "1")
		return c.Next()
	}

	httpHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		order = append(order, "http-final")
		w.Header().Set("X-HTTP", "true")
		_, err := w.Write([]byte("http"))
		assert.NoError(t, err)
	})

	fiberAfter := func(c Ctx) error {
		order = append(order, "fiber-after")
		return c.SendString("fiber")
	}

	app.Get("/mixed", fiberBefore, httpHandler, fiberAfter)

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/mixed", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
	t.Cleanup(func() {
		require.NoError(t, resp.Body.Close())
	})

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "http", string(body))
	require.Equal(t, "true", resp.Header.Get("X-HTTP"))
	require.Equal(t, "1", resp.Header.Get("X-Fiber"))
	// "fiber-after" is absent: the net/http handler terminated the chain.
	require.Equal(t, []string{"fiber-before", "http-final"}, order)
}

// Test_Route_Group_WithHTTPHandlers checks that app-level and group-level
// middleware run, in that order, before a net/http handler registered on a
// group route.
func Test_Route_Group_WithHTTPHandlers(t *testing.T) {
	t.Parallel()

	app := New()
	var order []string

	app.Use("/api", func(c Ctx) error {
		order = append(order, "app-use")
		return c.Next()
	})

	grp := app.Group("/api", func(c Ctx) error {
		order = append(order, "group-middleware")
		return c.Next()
	})

	grp.Get("/users", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		order = append(order, "http-handler")
		_, err := w.Write([]byte("users"))
		assert.NoError(t, err)
	}))

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/api/users", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
	t.Cleanup(func() {
		require.NoError(t, resp.Body.Close())
	})

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "users", string(body))
	require.Equal(t, []string{"app-use", "group-middleware", "http-handler"}, order)
}

// Test_RouteChain_WithHTTPHandlers checks that RouteChain accepts a mix of
// Fiber and net/http handlers, on both a chain and a nested chain.
func Test_RouteChain_WithHTTPHandlers(t *testing.T) {
	t.Parallel()

	app := New()

	chain := app.RouteChain("/combo")
	chain.Get(func(c Ctx) error {
		c.Set("X-Chain", "fiber")
		return c.Next()
	}, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, err := w.Write([]byte("combo"))
		assert.NoError(t, err)
	}))

	chain.RouteChain("/nested").Get(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("X-Nested", "true")
		_, err := w.Write([]byte("nested"))
		assert.NoError(t, err)
	}))

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/combo", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
	require.Equal(t, "fiber", resp.Header.Get("X-Chain"))
	t.Cleanup(func() {
		require.NoError(t, resp.Body.Close())
	})

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "combo", string(body))

	nestedResp, err := app.Test(httptest.NewRequest(MethodGet, "/combo/nested", http.NoBody))
	require.NoError(t, err)
	require.Equal(t, 200, nestedResp.StatusCode)
	require.Equal(t, "true", nestedResp.Header.Get("X-Nested"))
	t.Cleanup(func() {
		require.NoError(t, nestedResp.Body.Close())
	})

	nestedBody, err := io.ReadAll(nestedResp.Body)
	require.NoError(t, err)
	require.Equal(t, "nested", string(nestedBody))
}

// Test_Route_Match_SameLength checks that a request path identical to the
// route pattern "/:param" is treated as a parameter value, and that an
// ordinary value also matches.
func Test_Route_Match_SameLength(t *testing.T) {
	t.Parallel()

	app := New()

	app.Get("/:param", func(c Ctx) error {
		return c.SendString(c.Params("param"))
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/:param", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	body, err :=
io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, ":param", app.toString(body))

	// with param
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "test", app.toString(body))
}

// Test_Route_Match_Star checks the "/*" wildcard route end to end. It then
// calls Route.match directly to verify that the wildcard value is captured for
// a non-empty path and reset to empty for an empty path.
func Test_Route_Match_Star(t *testing.T) {
	t.Parallel()

	app := New()

	app.Get("/*", func(c Ctx) error {
		return c.SendString(c.Params("*"))
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/*", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "*", app.toString(body))

	// with param
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "test", app.toString(body))

	// without parameter
	route := Route{
		star:        true,
		path:        "/*",
		routeParser: routeParser{},
	}
	params := [maxParams]string{}

	// match writes the captured values into params, so params is passed by
	// pointer (&params). The extracted text had mis-decoded this as "¶ms".
	match := route.match("", "", &params)
	require.True(t, match)
	require.Equal(t, [maxParams]string{}, params)

	// with parameter
	match = route.match("/favicon.ico", "/favicon.ico", &params)
	require.True(t, match)
	require.Equal(t, [maxParams]string{"favicon.ico"}, params)

	// without parameter again: the value captured above must be cleared
	match = route.match("", "", &params)
	require.True(t, match)
	require.Equal(t, [maxParams]string{}, params)
}

// Test_Route_Match_Root checks that a route registered on "/" serves the root path.
func Test_Route_Match_Root(t *testing.T) {
	t.Parallel()

	app := New()

	app.Get("/", func(c Ctx) error {
		return c.SendString("root")
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	body, err := io.ReadAll(resp.Body)
require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "root", app.toString(body))
}

// Test_Route_Match_Parser checks parameter and wildcard extraction for route
// patterns that use mixed-case names.
func Test_Route_Match_Parser(t *testing.T) {
	t.Parallel()

	app := New()

	app.Get("/foo/:ParamName", func(c Ctx) error {
		return c.SendString(c.Params("ParamName"))
	})
	app.Get("/Foobar/*", func(c Ctx) error {
		return c.SendString(c.Params("*"))
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/foo/bar", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "bar", app.toString(body))

	// with star
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/Foobar/test", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "test", app.toString(body))
}

// TestAutoRegisterHeadRoutes covers the automatic HEAD registration for GET
// routes. cases[i] names the scenario that runners[i] executes. The two slices
// are paired by index, so they must keep the same length and order; the
// length is checked before the subtests run.
func TestAutoRegisterHeadRoutes(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
	}{
		{name: "auto registers head for get"},
		{name: "disable auto register config"},
		{name: "explicit head overrides auto head route"},
		{name: "auto head for grouped routes"},
		{name: "static handler auto head"},
		{name: "head without matching route returns 404"},
		{name: "late explicit get keeps explicit head"},
		{name: "route listing includes auto head"},
		{name: "head mirrors status without body"},
	}

	// requireClose closes closer and fails the test if Close returns an error.
	requireClose := func(tb testing.TB, closer io.Closer) {
		tb.Helper()
		require.NoError(tb, closer.Close())
	}

	// registerCleanup defers closing a response body until the test finishes.
	registerCleanup := func(tb testing.TB, body io.ReadCloser) {
		tb.Helper()
		tb.Cleanup(func() {
			requireClose(tb, body)
		})
	}

	runners := []func(t *testing.T){
		// auto registers head for get: HEAD returns the GET handler's status,
		// headers and Content-Length, but an empty body.
		func(t *testing.T) {
			t.Helper()
			app := New()
			app.Get("/", func(c Ctx) error {
				c.Set("X-Test", "auto")
				return c.SendString("Hello")
			})

			respHead, err := app.Test(httptest.NewRequest(MethodHead, "/", http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, respHead.Body)
			require.Equal(t, StatusOK, respHead.StatusCode)
			require.Equal(t, int64(len("Hello")), respHead.ContentLength)
			require.Equal(t, "auto", respHead.Header.Get("X-Test"))
			body, err := io.ReadAll(respHead.Body)
			require.NoError(t, err)
			require.Empty(t, body)

			respGet, err := app.Test(httptest.NewRequest(MethodGet, "/", http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, respGet.Body)
			require.Equal(t, StatusOK, respGet.StatusCode)
			require.Equal(t, int64(len("Hello")), respGet.ContentLength)
			data, err := io.ReadAll(respGet.Body)
			require.NoError(t, err)
			require.Equal(t, "Hello", string(data))
		},
		// disable auto register config: with DisableHeadAutoRegister set, HEAD
		// on a GET-only path is answered with 405.
		func(t *testing.T) {
			t.Helper()
			app := New(Config{DisableHeadAutoRegister: true})
			app.Get("/", func(c Ctx) error {
				return c.SendString("Hello")
			})

			resp, err := app.Test(httptest.NewRequest(MethodHead, "/", http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, resp.Body)
			require.Equal(t, StatusMethodNotAllowed, resp.StatusCode)
		},
		// explicit head overrides auto head route: after app.Head is
		// registered for the path, HEAD requests no longer reach the GET handler.
		func(t *testing.T) {
			t.Helper()
			app := New()
			var getCalls int
			app.Get("/override", func(c Ctx) error {
				getCalls++
				return c.SendString("GET")
			})

			respHead, err := app.Test(httptest.NewRequest(MethodHead, "/override", http.NoBody))
			require.NoError(t, err)
			require.Equal(t, StatusOK, respHead.StatusCode)
			require.Equal(t, 1, getCalls)
			requireClose(t, respHead.Body)

			var headCalls int
			app.Head("/override", func(c Ctx) error {
				headCalls++
				c.Set("X-Explicit", "true")
				return c.SendStatus(StatusNoContent)
			})

			respHead, err = app.Test(httptest.NewRequest(MethodHead, "/override", http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, respHead.Body)
			require.Equal(t, StatusNoContent, respHead.StatusCode)
			require.Equal(t, "true", respHead.Header.Get("X-Explicit"))
			body, err := io.ReadAll(respHead.Body)
			require.NoError(t, err)
			require.Empty(t, body)
			require.Equal(t, 1, getCalls)
			require.Equal(t, 1, headCalls)
		},
		// auto head for grouped routes: route params are available to the
		// handler when it is reached through HEAD.
		func(t *testing.T) {
			t.Helper()
			app := New()
			group := app.Group("/api")
			group.Get("/users/:id", func(c Ctx) error {
				c.Set("X-User", c.Params("id"))
				return c.SendString("grouped")
			})

			respHead, err := app.Test(httptest.NewRequest(MethodHead, "/api/users/42", http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, respHead.Body)
			require.Equal(t, StatusOK, respHead.StatusCode)
			require.Equal(t, "42", respHead.Header.Get("X-User"))
			body, err := io.ReadAll(respHead.Body)
			require.NoError(t, err)
			require.Empty(t, body)
		},
		// static handler auto head: for a SendFile handler, HEAD reports the
		// file size as Content-Length without sending the file content.
		func(t *testing.T) {
			t.Helper()
			const file = "./.github/testdata/testRoutes.json"
			content, err := os.ReadFile(file)
			require.NoError(t, err)

			app := New()
			app.Get("/file", func(c Ctx) error {
				return c.SendFile(file)
			})

			respHead, err := app.Test(httptest.NewRequest(MethodHead, "/file", http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, respHead.Body)
			require.Equal(t, StatusOK, respHead.StatusCode)
			require.Equal(t, int64(len(content)), respHead.ContentLength)
			body, err := io.ReadAll(respHead.Body)
			require.NoError(t, err)
			require.Empty(t, body)

			respGet, err := app.Test(httptest.NewRequest(MethodGet, "/file", http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, respGet.Body)
			data, err := io.ReadAll(respGet.Body)
			require.NoError(t, err)
			require.Equal(t, content, data)
		},
		// head without matching route returns 404.
		func(t *testing.T) {
			t.Helper()
			app := New()

			resp, err := app.Test(httptest.NewRequest(MethodHead, "/missing", http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, resp.Body)
			require.Equal(t, StatusNotFound, resp.StatusCode)
		},
		// late explicit get keeps explicit head: a HEAD route registered before
		// the GET route keeps serving HEAD requests; the GET handler is not used.
		func(t *testing.T) {
			t.Helper()
			app := New()
			var headCalls int
			app.Head("/late", func(c Ctx) error {
				headCalls++
				c.Set("X-Late", "head")
				return c.SendStatus(StatusAccepted)
			})
			var getCalls int
			app.Get("/late", func(c Ctx) error {
				getCalls++
				return c.SendString("ok")
			})

			respHead, err := app.Test(httptest.NewRequest(MethodHead, "/late", http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, respHead.Body)
			require.Equal(t, StatusAccepted, respHead.StatusCode)
			require.Equal(t, "head", respHead.Header.Get("X-Late"))
			require.Equal(t, 1, headCalls)
			require.Equal(t, 0, getCalls)

			respGet, err := app.Test(httptest.NewRequest(MethodGet, "/late",
http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, respGet.Body)
			require.Equal(t, StatusOK, respGet.StatusCode)
			require.Equal(t, 1, getCalls)
		},
		// route listing includes auto head: after startupProcess, GetRoutes
		// reports both the GET route and the HEAD route generated for it.
		func(t *testing.T) {
			t.Helper()
			app := New()
			app.Get("/list", func(c Ctx) error {
				return c.SendString("list")
			})

			app.startupProcess()

			routes := app.GetRoutes()
			var hasGet, hasHead bool
			for _, route := range routes {
				if route.Path == "/list" {
					if route.Method == MethodGet {
						hasGet = true
					}
					if route.Method == MethodHead {
						hasHead = true
					}
				}
			}

			require.True(t, hasGet)
			require.True(t, hasHead)
		},
		// head mirrors status without body: a 204 from the GET handler is
		// returned unchanged for HEAD, with an empty body.
		func(t *testing.T) {
			t.Helper()
			app := New()
			app.Get("/nocontent", func(c Ctx) error {
				return c.SendStatus(StatusNoContent)
			})

			respHead, err := app.Test(httptest.NewRequest(MethodHead, "/nocontent", http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, respHead.Body)
			require.Equal(t, StatusNoContent, respHead.StatusCode)
			body, err := io.ReadAll(respHead.Body)
			require.NoError(t, err)
			require.Empty(t, body)

			respGet, err := app.Test(httptest.NewRequest(MethodGet, "/nocontent", http.NoBody))
			require.NoError(t, err)
			registerCleanup(t, respGet.Body)
			require.Equal(t, StatusNoContent, respGet.StatusCode)
		},
	}

	// cases and runners are paired by index.
	require.Len(t, runners, len(cases))

	for i, tc := range cases {
		runner := runners[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			runner(t)
		})
	}
}

// Test_Route_Match_Middleware checks that a "/foo/*" middleware exposes the
// rest of the path as the "*" parameter.
func Test_Route_Match_Middleware(t *testing.T) {
	t.Parallel()

	app := New()

	app.Use("/foo/*", func(c Ctx) error {
		return c.SendString(c.Params("*"))
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/foo/*", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "*", app.toString(body))

	// with param
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/foo/bar/fasel", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "bar/fasel", app.toString(body))
}

// Test_Route_Match_UnescapedPath checks that, with Config.UnescapePath set, a
// percent-encoded request path matches a route registered with the decoded
// non-ASCII path. It also checks that the match stops once the option is
// turned off.
func Test_Route_Match_UnescapedPath(t *testing.T) {
	t.Parallel()

	app := New(Config{UnescapePath: true})

	app.Use("/créer", func(c Ctx) error {
		return c.SendString("test")
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/cr%C3%A9er", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "test", app.toString(body))

	// with the raw (not percent-encoded) non-ASCII path
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/créer", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	// check deactivated behavior
	app.config.UnescapePath = false
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/cr%C3%A9er", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusNotFound, resp.StatusCode, "Status code")
}

// Test_Route_Match_WithEscapeChar checks that an escaped colon in a route
// pattern (written "\\:" in Go source) matches a literal ':' instead of
// starting a parameter. It covers plain routes, group prefixes and routes that
// also contain a real parameter.
func Test_Route_Match_WithEscapeChar(t *testing.T) {
	t.Parallel()

	app := New()
	// static route and escaped part
	app.Get("/v1/some/resource/name\\:customVerb", func(c Ctx) error {
		return c.SendString("static")
	})
	// group route
	group := app.Group("/v2/\\:firstVerb")
	group.Get("/\\:customVerb", func(c Ctx) error {
		return c.SendString("group")
	})
	// route with resource param and escaped part
	app.Get("/v3/:resource/name\\:customVerb", func(c Ctx) error {
		return c.SendString(c.Params("resource"))
	})

	// check static route
	resp, err := app.Test(httptest.NewRequest(MethodGet, "/v1/some/resource/name:customVerb", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "static", app.toString(body))

	// check group route
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/v2/:firstVerb/:customVerb", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "group", app.toString(body))

	// check param route
	resp, err = app.Test(httptest.NewRequest(MethodGet, "/v3/awesome/name:customVerb", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusOK, resp.StatusCode, "Status code")

	body, err = io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "awesome", app.toString(body))
}

// Test_Route_Match_Middleware_HasPrefix checks that a "/foo" middleware also
// matches deeper paths below "/foo/".
func Test_Route_Match_Middleware_HasPrefix(t *testing.T) {
	t.Parallel()

	app := New()

	app.Use("/foo", func(c Ctx) error {
		return c.SendString("middleware")
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/foo/bar", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "middleware", app.toString(body))
}

// Test_Route_Match_Middleware_NoBoundary checks that a "/foo" middleware does
// not match a sibling path that only shares the string prefix ("/foobar").
func Test_Route_Match_Middleware_NoBoundary(t *testing.T) {
	t.Parallel()

	app := New()

	app.Use("/foo", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/foobar", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, StatusNotFound, resp.StatusCode, "Status code")
}

// Test_Route_Match_Middleware_Root checks that a "/" middleware matches any path.
func Test_Route_Match_Middleware_Root(t *testing.T) {
	t.Parallel()

	app := New()

	app.Use("/", func(c Ctx) error {
		return c.SendString("middleware")
	})

	resp, err := app.Test(httptest.NewRequest(MethodGet, "/everything", http.NoBody))
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, 200, resp.StatusCode, "Status code")

	body, err := io.ReadAll(resp.Body)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, "middleware", app.toString(body))
}

// Test_Router_Register_Missing_Handler checks that register panics when no
// handler, or a nil handler, is supplied.
func Test_Router_Register_Missing_Handler(t *testing.T) {
	t.Parallel()

	app := New()

	t.Run("No Handler", func(t *testing.T) {
		t.Parallel()

		require.PanicsWithValue(t, "missing handler/middleware in route: /doe\n", func() {
			app.register([]string{"USE"}, "/doe", nil)
		})
	})

	t.Run("Nil Handler", func(t *testing.T) {
		t.Parallel()

		require.PanicsWithValue(t, "nil handler in route: /doe\n", func() {
			app.register([]string{"USE"}, "/doe", nil, nil)
		})
	})
}

// Test_Ensure_Router_Interface_Implementation asserts that both *App and
// *Group implement Router.
func Test_Ensure_Router_Interface_Implementation(t *testing.T) {
	t.Parallel()

	var app any = (*App)(nil)
	_, ok := app.(Router)
	require.True(t, ok)

	var group any = (*Group)(nil)
	_, ok = group.(Router)
	require.True(t, ok)
}

// Test_Router_Handler_Catch_Error checks that the response falls back to 500
// when the configured ErrorHandler itself returns an error.
func Test_Router_Handler_Catch_Error(t *testing.T) {
	t.Parallel()

	app := New()
	app.config.ErrorHandler = func(_ Ctx, _ error) error {
		return errors.New("fake error")
	}

	app.Get("/", func(_ Ctx) error {
		return ErrForbidden
	})

	c := &fasthttp.RequestCtx{}

	app.Handler()(c)

	require.Equal(t, StatusInternalServerError, c.Response.Header.StatusCode())
}

// Test_Router_NotFound checks the default 404 response for an unmatched path
// when only a pass-through middleware is registered.
func Test_Router_NotFound(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use(func(c Ctx) error {
		return c.Next()
	})
	appHandler := app.Handler()
	c := &fasthttp.RequestCtx{}

	c.Request.Header.SetMethod("DELETE")
	c.URI().SetPath("/this/route/does/not/exist")

	appHandler(c)

	require.Equal(t, 404, c.Response.StatusCode())
	require.Equal(t, "Not Found", string(c.Response.Body()))
}

// Test_Router_NotFound_HTML_Inject checks that a request path carrying HTML
// markup is not reflected into the 404 response body.
func Test_Router_NotFound_HTML_Inject(t *testing.T) {
	t.Parallel()
	app := New()
	app.Use(func(c Ctx) error {
		return c.Next()
	})
	appHandler := app.Handler()
	c := &fasthttp.RequestCtx{}

	c.Request.Header.SetMethod("DELETE")
	// The injected script tag is the point of this test. Without it, the
	// test duplicates Test_Router_NotFound.
	c.URI().SetPath("/does/not/exist<script>alert('foo');</script>")

	appHandler(c)

	require.Equal(t, 404, c.Response.StatusCode())
	require.Equal(t, "Not Found", string(c.Response.Body()))
}

// registerTreeManipulationRoutes registers GET /test. Its handler adds a GET
// /dynamically-defined route at request time and calls RebuildTree, so later
// requests can reach the new route. Any middleware passed in is registered
// after that handler on the /test route.
func registerTreeManipulationRoutes(app *App, middleware ...func(Ctx) error) {
	// app.Get receives the extra handlers as a []any spread, so convert the
	// typed slice element by element.
	converted := make([]any, len(middleware))
	for i, h := range middleware {
		converted[i] = h
	}

	app.Get("/test", func(c Ctx) error {
		app.Get("/dynamically-defined", func(c Ctx) error {
			return c.SendStatus(StatusOK)
		})

		app.RebuildTree()

		return c.SendStatus(StatusOK)
	}, converted...)
} func verifyRequest(tb testing.TB, app *App, path string, expectedStatus int) *http.Response { tb.Helper() resp, err := app.Test(httptest.NewRequest(MethodGet, path, http.NoBody)) require.NoError(tb, err, "app.Test(req)") require.Equal(tb, expectedStatus, resp.StatusCode, "Status code") return resp } func verifyRouteHandlerCounts(tb testing.TB, app *App, expectedRoutesCount int) { tb.Helper() // this is taken from listen.go's printRoutesMessage app method var routes []RouteMessage for _, routeStack := range app.stack { for _, route := range routeStack { routeMsg := RouteMessage{ name: route.Name, method: route.Method, path: route.Path, } for _, handler := range route.Handlers { routeMsg.handlers += runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name() + " " } routes = append(routes, routeMsg) } } for _, route := range routes { require.Equal(tb, expectedRoutesCount, strings.Count(route.handlers, " ")) } } func verifyThereAreNoRoutes(tb testing.TB, app *App) { tb.Helper() require.Equal(tb, uint32(0), app.handlersCount) verifyRouteHandlerCounts(tb, app, 0) } func Test_App_Rebuild_Tree(t *testing.T) { t.Parallel() app := New() registerTreeManipulationRoutes(app) verifyRequest(t, app, "/dynamically-defined", StatusNotFound) verifyRequest(t, app, "/test", StatusOK) verifyRequest(t, app, "/dynamically-defined", StatusOK) } func Test_App_Remove_Route_A_B_Feature_Testing(t *testing.T) { t.Parallel() app := New() app.Get("/api/feature-a", func(c Ctx) error { app.RemoveRoute("/api/feature", MethodGet) app.RebuildTree() // Redefine route app.Get("/api/feature", func(c Ctx) error { return c.SendString("Testing feature-a") }) app.RebuildTree() return c.SendStatus(StatusOK) }) app.Get("/api/feature-b", func(c Ctx) error { app.RemoveRoute("/api/feature", MethodGet) app.RebuildTree() // Redefine route app.Get("/api/feature", func(c Ctx) error { return c.SendString("Testing feature-b") }) app.RebuildTree() return c.SendStatus(StatusOK) }) verifyRequest(t, app, 
"/api/feature-a", StatusOK) resp := verifyRequest(t, app, "/api/feature", StatusOK) body, err := io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Equal(t, "Testing feature-a", string(body), "Response Message") verifyRequest(t, app, "/api/feature-b", StatusOK) resp = verifyRequest(t, app, "/api/feature", StatusOK) body, err = io.ReadAll(resp.Body) require.NoError(t, err, "app.Test(req)") require.Equal(t, "Testing feature-b", string(body), "Response Message") } func Test_App_Remove_Route_By_Name(t *testing.T) { t.Parallel() app := New() app.Get("/api/test", func(c Ctx) error { return c.SendStatus(StatusOK) }).Name("test") app.RemoveRouteByName("test", MethodGet) app.RebuildTree() verifyRequest(t, app, "/test", StatusNotFound) verifyThereAreNoRoutes(t, app) } func Test_App_Remove_Route_By_Name_Non_Existing_Route(t *testing.T) { t.Parallel() app := New() app.RemoveRouteByName("test", MethodGet) app.RebuildTree() verifyThereAreNoRoutes(t, app) } func Test_App_Remove_Route_Nested(t *testing.T) { t.Parallel() app := New() api := app.Group("/api") v1 := api.Group("/v1") v1.Get("/test", func(c Ctx) error { return c.SendStatus(StatusOK) }) verifyRequest(t, app, "/api/v1/test", StatusOK) app.RemoveRoute("/api/v1/test", MethodGet) verifyThereAreNoRoutes(t, app) } func Test_App_Remove_Route_Parameterized(t *testing.T) { t.Parallel() app := New() app.Get("/test/:id", func(c Ctx) error { return c.SendStatus(StatusOK) }) verifyRequest(t, app, "/test/:id", StatusOK) app.RemoveRoute("/test/:id", MethodGet) verifyThereAreNoRoutes(t, app) } func Test_App_Remove_Route(t *testing.T) { t.Parallel() app := New() app.Get("/test", func(c Ctx) error { return c.SendStatus(StatusOK) }) app.RemoveRoute("/test", MethodGet) app.RebuildTree() verifyRequest(t, app, "/test", StatusNotFound) } func Test_App_Remove_Route_Non_Existing_Route(t *testing.T) { t.Parallel() app := New() app.RemoveRoute("/test", MethodGet, MethodHead) app.RebuildTree() verifyThereAreNoRoutes(t, app) } 
// Test_App_Use_StrictRoutingBoundary checks that app.Use("/api", ...) runs for
// "/api", "/api/" and "/api/users", but not for "/apiv1", which only shares
// the prefix text. Each case is run with StrictRouting both on and off.
func Test_App_Use_StrictRoutingBoundary(t *testing.T) {
	type testCase struct {
		name           string
		path           string
		expectedStatus int
		strictRouting  bool
		expectMatched  bool
	}

	testCases := []testCase{
		{
			name:           "Strict exact match",
			strictRouting:  true,
			path:           "/api",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict trailing slash partial",
			strictRouting:  true,
			path:           "/api/",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict nested partial",
			strictRouting:  true,
			path:           "/api/users",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict disallows sibling prefix",
			strictRouting:  true,
			path:           "/apiv1",
			expectMatched:  false,
			expectedStatus: StatusNotFound,
		},
		{
			name:           "Non-strict exact match",
			strictRouting:  false,
			path:           "/api",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict trailing slash partial",
			strictRouting:  false,
			path:           "/api/",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict nested partial",
			strictRouting:  false,
			path:           "/api/users",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict disallows sibling prefix",
			strictRouting:  false,
			path:           "/apiv1",
			expectMatched:  false,
			expectedStatus: StatusNotFound,
		},
	}

	for _, tt := range testCases {
		t.Run(tt.name, func(t *testing.T) {
			app := New(Config{StrictRouting: tt.strictRouting})

			// matched records whether the Use middleware ran for this path.
			matched := false
			app.Use("/api", func(c Ctx) error {
				matched = true
				return c.SendStatus(StatusOK)
			})

			resp, err := app.Test(httptest.NewRequest(MethodGet, tt.path, http.NoBody))
			require.NoError(t, err)
			require.Equal(t, tt.expectedStatus, resp.StatusCode)
			require.Equal(t, tt.expectMatched, matched)
		})
	}
}

// Test_Group_Use_StrictRoutingBoundary applies the same prefix-boundary checks
// to grp.Use("/v1", ...) on an "/api" group. "/api/v1beta" must not trigger the
// middleware. The GET routes on "/v1" and "/v1/*" give matching paths a
// terminal handler, because this middleware calls c.Next().
func Test_Group_Use_StrictRoutingBoundary(t *testing.T) {
	type testCase struct {
		name           string
		path           string
		expectedStatus int
		strictRouting  bool
		expectMatched  bool
	}

	testCases := []testCase{
		{
			name:           "Strict group exact match",
			strictRouting:  true,
			path:           "/api/v1",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict group trailing slash partial",
			strictRouting:  true,
			path:           "/api/v1/",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict group nested partial",
			strictRouting:  true,
			path:           "/api/v1/users",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Strict group disallows sibling prefix",
			strictRouting:  true,
			path:           "/api/v1beta",
			expectMatched:  false,
			expectedStatus: StatusNotFound,
		},
		{
			name:           "Non-strict group exact match",
			strictRouting:  false,
			path:           "/api/v1",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict group trailing slash partial",
			strictRouting:  false,
			path:           "/api/v1/",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict group nested partial",
			strictRouting:  false,
			path:           "/api/v1/users",
			expectMatched:  true,
			expectedStatus: StatusOK,
		},
		{
			name:           "Non-strict group disallows sibling prefix",
			strictRouting:  false,
			path:           "/api/v1beta",
			expectMatched:  false,
			expectedStatus: StatusNotFound,
		},
	}

	for _, tt := range testCases {
		t.Run(tt.name, func(t *testing.T) {
			app := New(Config{StrictRouting: tt.strictRouting})
			grp := app.Group("/api")

			matched := false
			grp.Use("/v1", func(c Ctx) error {
				matched = true
				return c.Next()
			})

			grp.Get("/v1", func(c Ctx) error {
				return c.SendStatus(StatusOK)
			})

			grp.Get("/v1/*", func(c Ctx) error {
				return c.SendStatus(StatusOK)
			})

			resp, err := app.Test(httptest.NewRequest(MethodGet, tt.path, http.NoBody))
			require.NoError(t, err)
			require.Equal(t, tt.expectedStatus, resp.StatusCode)
			require.Equal(t, tt.expectMatched, matched)
		})
	}
}

// Test_App_Remove_Route_Concurrent removes and re-registers GET /test from 10
// goroutines at once, then checks that the route is still served.
// NOTE(review): this presumably relies on RemoveRoute/Get being safe for
// concurrent use and is most meaningful under -race.
func Test_App_Remove_Route_Concurrent(t *testing.T) {
	t.Parallel()
	app := New()

	// Add test route
	app.Get("/test", func(c Ctx) error {
		return c.SendStatus(StatusOK)
	})

	// Concurrently remove and add routes
	var wg sync.WaitGroup
	for range 10 {
		wg.Go(func() {
			app.RemoveRoute("/test", MethodGet)
			app.Get("/test", func(c Ctx) error {
				return c.SendStatus(StatusOK)
			})
		})
	}

	wg.Wait()

	// Verify final state
	app.RebuildTree()
	verifyRequest(t, app, "/test", StatusOK)
}

// Test_Route_Registration_Prevent_Duplicate_With_Middleware registers the
// tree-manipulation routes twice, once with an extra middleware. It tracks
// handlersCount across a series of requests; the last two assertions show that
// repeating a request does not grow the count any further.
func Test_Route_Registration_Prevent_Duplicate_With_Middleware(t *testing.T) {
	t.Parallel()
	app := New()
	middleware := func(c Ctx) error {
		return c.Next()
	}
	registerTreeManipulationRoutes(app, middleware)
	registerTreeManipulationRoutes(app)

	verifyRequest(t, app, "/dynamically-defined", StatusNotFound)
	require.Equal(t, uint32(6), app.handlersCount)

	verifyRequest(t, app, "/test", StatusOK)
	require.Equal(t, uint32(7), app.handlersCount)

	verifyRequest(t, app, "/dynamically-defined", StatusOK)
	require.Equal(t, uint32(8), app.handlersCount)

	verifyRequest(t, app, "/test", StatusOK)
	require.Equal(t, uint32(9), app.handlersCount)

	verifyRequest(t, app, "/dynamically-defined", StatusOK)
	require.Equal(t, uint32(9), app.handlersCount)
}

// TestNormalizePath is a table test for app.normalizePath. It covers leading
// slashes, trailing slashes with and without StrictRouting, case folding with
// and without CaseSensitive, and backslash-escaped slashes.
func TestNormalizePath(t *testing.T) {
	tests := []struct {
		name          string
		path          string
		expected      string
		caseSensitive bool
		strictRouting bool
	}{
		{
			name:          "Empty path",
			path:          "",
			caseSensitive: true,
			strictRouting: true,
			expected:      "/",
		},
		{
			name:          "No leading slash",
			path:          "users",
			caseSensitive: true,
			strictRouting: true,
			expected:      "/users",
		},
		{
			name:          "With trailing slash and strict routing",
			path:          "/users/",
			caseSensitive: true,
			strictRouting: true,
			expected:      "/users/",
		},
		{
			name:          "With trailing slash and non-strict routing",
			path:          "/users/",
			caseSensitive: true,
			strictRouting: false,
			expected:      "/users",
		},
		{
			name:          "Case sensitive",
			path:          "/Users",
			caseSensitive: true,
			strictRouting: true,
			expected:      "/Users",
		},
		{
			name:          "Case insensitive",
			path:          "/Users",
			caseSensitive: false,
			strictRouting: true,
			expected:      "/users",
		},
		{
			name:          "With escape characters",
			path:          "/users\\/profile",
			caseSensitive: true,
			strictRouting: true,
			expected:      "/users/profile",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Only the config fields normalizePath reads are populated.
			app := &App{
				config: Config{
					CaseSensitive: tt.caseSensitive,
					StrictRouting: tt.strictRouting,
				},
			}
			result := app.normalizePath(tt.path)
			require.Equal(t, tt.expected, result)
		})
	}
}

// TestRemoveRoute removes routes step by step. It uses the order in which the
// handlers write into buf, the response status and app.handlersCount to check
// what remains registered after each removal.
func TestRemoveRoute(t *testing.T) {
	app := New()

	var buf strings.Builder

	app.Use(func(c Ctx) error {
		buf.WriteString("1") //nolint:errcheck // not needed
		return c.Next()
	})

	app.Post("/", func(c Ctx) error {
		buf.WriteString("2") //nolint:errcheck // not needed
		return c.SendStatus(StatusOK)
	})

	app.Use("/test", func(c Ctx) error {
		buf.WriteString("3") //nolint:errcheck // not needed
		return c.Next()
	})

	app.Get("/test", func(c Ctx) error {
		buf.WriteString("4") //nolint:errcheck // not needed
		return c.SendStatus(StatusOK)
	})

	app.Post("/test", func(c Ctx) error {
		buf.WriteString("5") //nolint:errcheck // not needed
		return c.SendStatus(StatusOK)
	})

	app.startupProcess()

	require.Equal(t, uint32(6), app.handlersCount)

	req, err := http.NewRequestWithContext(context.Background(), MethodPost, "/", http.NoBody)
	require.NoError(t, err)

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
	require.Equal(t, "12", buf.String())

	buf.Reset()

	req, err = http.NewRequestWithContext(context.Background(), MethodGet, "/test", http.NoBody)
	require.NoError(t, err)

	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, 200, resp.StatusCode)
	require.Equal(t, "134", buf.String())

	buf.Reset()

	require.Equal(t, uint32(6), app.handlersCount)

	app.RemoveRoute("/test", MethodGet)
	app.RebuildTree()

	// Removing an unknown method and a predicate that matches nothing must
	// leave the remaining routes untouched.
	app.RemoveRoute("/test", "TEST")
	app.RebuildTree()

	app.RemoveRouteFunc(func(_ *Route) bool {
		return false
	}, MethodGet)

	req, err = http.NewRequestWithContext(context.Background(), MethodGet, "/test", http.NoBody)
	require.NoError(t, err)

	resp, err = app.Test(req)
	require.NoError(t, err)

	buf.Reset()
	require.Equal(t, StatusMethodNotAllowed, resp.StatusCode)

	require.Equal(t, uint32(4), app.handlersCount)

	app.RemoveRoute("/test", MethodPost)
	app.RebuildTree()

	require.Equal(t, uint32(3), app.handlersCount)

	req, err = http.NewRequestWithContext(context.Background(), MethodPost, "/test", http.NoBody)
	require.NoError(t, err)

	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, 404, resp.StatusCode)
	require.Equal(t, "1", buf.String())

	buf.Reset()

	req, err = http.NewRequestWithContext(context.Background(), MethodGet, "/test", http.NoBody)
	require.NoError(t, err)

	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, 404, resp.StatusCode)
	require.Equal(t, "1", buf.String())

	buf.Reset()

	app.RemoveRoute("/", MethodGet, MethodPost)

	require.Equal(t, uint32(2), app.handlersCount)

	req, err = http.NewRequestWithContext(context.Background(), MethodGet, "/", http.NoBody)
	require.NoError(t, err)

	resp, err = app.Test(req)
	require.NoError(t, err)
	require.Equal(t, 404, resp.StatusCode)
	require.Empty(t, buf.String())

	buf.Reset()

	app.RemoveRoute("/test", MethodGet, MethodPost)

	require.Equal(t, uint32(2), app.handlersCount)

	app.RemoveRoute("/test", app.config.RequestMethods...)
	require.Equal(t, uint32(1), app.handlersCount)

	app.Patch("/test", func(c Ctx) error {
		buf.WriteString("6") //nolint:errcheck // not needed
		return c.SendStatus(StatusOK)
	})

	require.Equal(t, uint32(2), app.handlersCount)

	// Calling RemoveRoute with no methods must clear the remaining routes at
	// these paths; the final count of 0 asserts that.
	app.RemoveRoute("/test")
	app.RemoveRoute("/")
	app.RebuildTree()

	require.Equal(t, uint32(0), app.handlersCount)
}

//////////////////////////////////////////////
///////////////// BENCHMARKS /////////////////
//////////////////////////////////////////////

// registerDummyRoutes registers a no-op handler for every method/path pair in
// routesFixture.GitHubAPI.
func registerDummyRoutes(app *App) {
	h := func(_ Ctx) error {
		return nil
	}
	for _, r := range routesFixture.GitHubAPI {
		app.Add([]string{r.Method}, r.Path, h)
	}
}

// acquireDefaultCtxForRouterBenchmark acquires a context for fctx from app. It
// fails the benchmark if the context is not a *DefaultCtx.
func acquireDefaultCtxForRouterBenchmark(b *testing.B, app *App, fctx *fasthttp.RequestCtx) *DefaultCtx {
	b.Helper()

	ctx := app.AcquireCtx(fctx)

	defaultCtx, ok := ctx.(*DefaultCtx)
	if !ok {
		b.Fatal("AcquireCtx did not return *DefaultCtx")
	}

	return defaultCtx
}

// go test -v -run=^$ -bench=Benchmark_App_RebuildTree -benchmem -count=4
func Benchmark_App_RebuildTree(b *testing.B) {
	app := New()
	registerDummyRoutes(app)

	b.ReportAllocs()
	b.ResetTimer()

	for b.Loop() {
		// Mark the routes as refreshed each iteration so RebuildTree does
		// real work instead of returning early (presumably it short-circuits
		// when nothing changed).
		app.routesRefreshed = true
		app.RebuildTree()
	}
}

// go test -v -run=^$ -bench=Benchmark_App_MethodNotAllowed
-benchmem -count=4 func Benchmark_App_MethodNotAllowed(b *testing.B) { app := New() h := func(c Ctx) error { return c.SendString("Hello World!") } app.All("/this/is/a/", h) app.Get("/this/is/a/dummy/route/oke", h) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/this/is/a/dummy/route/oke") for b.Loop() { appHandler(c) } require.Equal(b, 405, c.Response.StatusCode()) require.Equal(b, MethodGet+", "+MethodHead, string(c.Response.Header.Peek("Allow"))) require.Equal(b, utils.StatusMessage(StatusMethodNotAllowed), string(c.Response.Body())) } // go test -v ./... -run=^$ -bench=Benchmark_Router_NotFound -benchmem -count=4 func Benchmark_Router_NotFound(b *testing.B) { b.ReportAllocs() app := New() app.Use(func(c Ctx) error { return c.Next() }) registerDummyRoutes(app) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/this/route/does/not/exist") for b.Loop() { appHandler(c) } require.Equal(b, 404, c.Response.StatusCode()) require.Equal(b, "Not Found", string(c.Response.Body())) } // go test -v ./... -run=^$ -bench=Benchmark_Router_Handler -benchmem -count=4 func Benchmark_Router_Handler(b *testing.B) { app := New() registerDummyRoutes(app) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/user/keys/1337") for b.Loop() { appHandler(c) } } func Benchmark_Router_Handler_Strict_Case(b *testing.B) { app := New(Config{ StrictRouting: true, CaseSensitive: true, }) registerDummyRoutes(app) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/user/keys/1337") for b.Loop() { appHandler(c) } } // go test -v ./... 
-run=^$ -bench=Benchmark_Router_Chain -benchmem -count=4 func Benchmark_Router_Chain(b *testing.B) { app := New() handler := func(c Ctx) error { return c.Next() } app.Get("/", handler, handler, handler, handler, handler, handler) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod(MethodGet) c.URI().SetPath("/") for b.Loop() { appHandler(c) } } // go test -v ./... -run=^$ -bench=Benchmark_Router_WithCompression -benchmem -count=4 func Benchmark_Router_WithCompression(b *testing.B) { app := New() handler := func(c Ctx) error { return c.Next() } app.Get("/", handler) app.Get("/", handler) app.Get("/", handler) app.Get("/", handler) app.Get("/", handler) app.Get("/", handler) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod(MethodGet) c.URI().SetPath("/") for b.Loop() { appHandler(c) } } // go test -run=^$ -bench=Benchmark_Startup_Process -benchmem -count=9 func Benchmark_Startup_Process(b *testing.B) { for b.Loop() { app := New() registerDummyRoutes(app) app.startupProcess() } } // go test -v ./... -run=^$ -bench=Benchmark_Router_Next -benchmem -count=4 func Benchmark_Router_Next(b *testing.B) { app := New() registerDummyRoutes(app) app.startupProcess() request := &fasthttp.RequestCtx{} request.Request.Header.SetMethod("DELETE") request.URI().SetPath("/user/keys/1337") var res bool var err error c := app.AcquireCtx(request).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed for b.Loop() { c.indexRoute = -1 res, err = app.next(c) } require.NoError(b, err) require.True(b, res) require.Equal(b, 4, c.indexRoute) } // go test -v ./... 
-run=^$ -bench=Benchmark_Router_Next_Default -benchmem -count=4 func Benchmark_Router_Next_Default(b *testing.B) { app := New() app.Get("/", func(_ Ctx) error { return nil }) h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(MethodGet) fctx.Request.SetRequestURI("/") b.ReportAllocs() for b.Loop() { h(fctx) } } // go test -benchmem -run=^$ -bench ^Benchmark_Router_Next_Default_Parallel$ github.com/gofiber/fiber/v3 -count=1 func Benchmark_Router_Next_Default_Parallel(b *testing.B) { app := New() app.Get("/", func(_ Ctx) error { return nil }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(MethodGet) fctx.Request.SetRequestURI("/") for pb.Next() { h(fctx) } }) } // go test -v ./... -run=^$ -bench=Benchmark_Router_Next_Default_Immutable -benchmem -count=4 func Benchmark_Router_Next_Default_Immutable(b *testing.B) { app := New(Config{Immutable: true}) app.Get("/", func(_ Ctx) error { return nil }) h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(MethodGet) fctx.Request.SetRequestURI("/") b.ReportAllocs() for b.Loop() { h(fctx) } } // go test -benchmem -run=^$ -bench ^Benchmark_Router_Next_Default_Parallel_Immutable$ github.com/gofiber/fiber/v3 -count=1 func Benchmark_Router_Next_Default_Parallel_Immutable(b *testing.B) { app := New(Config{Immutable: true}) app.Get("/", func(_ Ctx) error { return nil }) h := app.Handler() b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(MethodGet) fctx.Request.SetRequestURI("/") for pb.Next() { h(fctx) } }) } // go test -v ./... 
-run=^$ -bench=Benchmark_Route_Match -benchmem -count=4 func Benchmark_Route_Match(b *testing.B) { var match bool var params [maxParams]string parsed := parseRoute("/user/keys/:id") route := &Route{ use: false, root: false, star: false, routeParser: parsed, Params: parsed.params, path: "/user/keys/:id", Path: "/user/keys/:id", Method: "DELETE", } route.Handlers = append(route.Handlers, func(_ Ctx) error { return nil }) for b.Loop() { match = route.match("/user/keys/1337", "/user/keys/1337", ¶ms) } require.True(b, match) require.Equal(b, []string{"1337"}, params[0:len(parsed.params)]) } // go test -v ./... -run=^$ -bench=Benchmark_Route_Match_Star -benchmem -count=4 func Benchmark_Route_Match_Star(b *testing.B) { var match bool var params [maxParams]string parsed := parseRoute("/*") route := &Route{ use: false, root: false, star: true, routeParser: parsed, Params: parsed.params, path: "/user/keys/bla", Path: "/user/keys/bla", Method: "DELETE", } route.Handlers = append(route.Handlers, func(_ Ctx) error { return nil }) for b.Loop() { match = route.match("/user/keys/bla", "/user/keys/bla", ¶ms) } require.True(b, match) require.Equal(b, []string{"user/keys/bla"}, params[0:len(parsed.params)]) } // go test -v ./... -run=^$ -bench=Benchmark_Route_Match_Root -benchmem -count=4 func Benchmark_Route_Match_Root(b *testing.B) { var match bool var params [maxParams]string parsed := parseRoute("/") route := &Route{ use: false, root: true, star: false, path: "/", routeParser: parsed, Params: parsed.params, Path: "/", Method: "DELETE", } route.Handlers = append(route.Handlers, func(_ Ctx) error { return nil }) for b.Loop() { match = route.match("/", "/", ¶ms) } require.True(b, match) require.Equal(b, []string{}, params[0:len(parsed.params)]) } // go test -v ./... 
-run=^$ -bench=Benchmark_Router_Handler_CaseSensitive -benchmem -count=4 func Benchmark_Router_Handler_CaseSensitive(b *testing.B) { app := New() app.config.CaseSensitive = true registerDummyRoutes(app) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/user/keys/1337") for b.Loop() { appHandler(c) } } // go test -v ./... -run=^$ -bench=Benchmark_Router_Handler_Unescape -benchmem -count=4 func Benchmark_Router_Handler_Unescape(b *testing.B) { app := New() app.config.UnescapePath = true registerDummyRoutes(app) app.Delete("/créer", func(_ Ctx) error { return nil }) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod(MethodDelete) c.URI().SetPath("/cr%C3%A9er") for b.Loop() { c.URI().SetPath("/cr%C3%A9er") appHandler(c) } } // go test -run=^$ -bench=Benchmark_Router_Handler_StrictRouting -benchmem -count=4 func Benchmark_Router_Handler_StrictRouting(b *testing.B) { app := New() app.config.CaseSensitive = true registerDummyRoutes(app) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/user/keys/1337") for b.Loop() { appHandler(c) } } // go test -run=^$ -bench=Benchmark_Router_GitHub_API -benchmem -count=16 func Benchmark_Router_GitHub_API(b *testing.B) { app := New() registerDummyRoutes(app) app.startupProcess() c := &fasthttp.RequestCtx{} var match bool var err error b.ResetTimer() for i := range routesFixture.TestRoutes { b.RunParallel(func(pb *testing.PB) { c.Request.Header.SetMethod(routesFixture.TestRoutes[i].Method) for pb.Next() { c.URI().SetPath(routesFixture.TestRoutes[i].Path) ctx := app.AcquireCtx(c).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed match, err = app.next(ctx) app.ReleaseCtx(ctx) } }) require.NoError(b, err) require.True(b, match) } } type testRoute struct { Method string `json:"method"` Path string `json:"path"` } type routeJSON struct { TestRoutes []testRoute `json:"test_routes"` 
GitHubAPI []testRoute `json:"github_api"` } func newCustomApp() *App { return NewWithCustomCtx(func(app *App) CustomCtx { return &customCtx{DefaultCtx: *NewDefaultCtx(app)} }) } func Test_NextCustom_MethodNotAllowed(t *testing.T) { t.Parallel() app := newCustomApp() app.Get("/foo", func(c Ctx) error { return c.SendStatus(StatusOK) }) useRoute := &Route{use: true, path: "/foo", Path: "/foo", routeParser: parseRoute("/foo")} m := app.methodInt(MethodGet) app.stack[m] = append([]*Route{useRoute}, app.stack[m]...) app.routesRefreshed = true app.ensureAutoHeadRoutes() app.RebuildTree() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(MethodPost) fctx.Request.SetRequestURI("/foo") ctx := app.AcquireCtx(fctx) defer app.ReleaseCtx(ctx) matched, err := app.nextCustom(ctx) require.False(t, matched) require.ErrorIs(t, err, ErrMethodNotAllowed) allow := string(ctx.Response().Header.Peek(HeaderAllow)) require.Equal(t, "GET, HEAD", allow) } func Test_NextCustom_NotFound(t *testing.T) { t.Parallel() app := newCustomApp() app.RebuildTree() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(MethodGet) fctx.Request.SetRequestURI("/not-exist") ctx := app.AcquireCtx(fctx) defer app.ReleaseCtx(ctx) matched, err := app.nextCustom(ctx) require.False(t, matched) var e *Error require.ErrorAs(t, err, &e) require.Equal(t, StatusNotFound, e.Code) } func Test_RequestHandler_CustomCtx_NotImplemented(t *testing.T) { t.Parallel() app := newCustomApp() h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod("UNKNOWN") fctx.Request.SetRequestURI("/") h(fctx) require.Equal(t, StatusNotImplemented, fctx.Response.StatusCode()) } func Test_NextCustom_Matched404(t *testing.T) { t.Parallel() app := newCustomApp() app.RebuildTree() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(MethodGet) fctx.Request.SetRequestURI("/none") ctx := app.AcquireCtx(fctx) ctx.setMatched(true) defer app.ReleaseCtx(ctx) matched, err := app.nextCustom(ctx) 
require.False(t, matched) var e *Error require.ErrorAs(t, err, &e) require.Equal(t, StatusNotFound, e.Code) } func Test_NextCustom_SkipMountAndNoHandlers(t *testing.T) { t.Parallel() app := newCustomApp() m := app.methodInt(MethodGet) mountR := &Route{path: "/skip", Path: "/skip", routeParser: parseRoute("/skip"), mount: true} empty := &Route{path: "/foo", Path: "/foo", routeParser: parseRoute("/foo")} app.stack[m] = []*Route{mountR, empty} app.routesRefreshed = true app.RebuildTree() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(MethodGet) fctx.Request.SetRequestURI("/foo") ctx := app.AcquireCtx(fctx) defer app.ReleaseCtx(ctx) matched, err := app.nextCustom(ctx) require.True(t, matched) require.NoError(t, err) require.Equal(t, "/foo", ctx.Route().Path) } func Test_AddRoute_MergeHandlers(t *testing.T) { t.Parallel() app := New() count := func(_ Ctx) error { return nil } app.Get("/merge", count) app.Get("/merge", count) require.Len(t, app.stack[app.methodInt(MethodGet)], 1) require.Len(t, app.stack[app.methodInt(MethodGet)][0].Handlers, 2) } func Benchmark_App_RebuildTree_Parallel(b *testing.B) { b.ReportAllocs() b.ResetTimer() b.RunParallel(func(pb *testing.PB) { // Each worker gets its own App instance to avoid data races on shared state localApp := New() registerDummyRoutes(localApp) for pb.Next() { localApp.routesRefreshed = true localApp.RebuildTree() } }) } func Benchmark_App_MethodNotAllowed_Parallel(b *testing.B) { app := New() h := func(c Ctx) error { return c.SendString("Hello World!") } app.All("/this/is/a/", h) app.Get("/this/is/a/dummy/route/oke", h) appHandler := app.Handler() b.RunParallel(func(pb *testing.PB) { // Each worker gets its own RequestCtx to avoid data races c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/this/is/a/dummy/route/oke") for pb.Next() { appHandler(c) } }) // Single-threaded verification on a fresh context to preserve correctness checks verifyCtx := &fasthttp.RequestCtx{} 
verifyCtx.Request.Header.SetMethod("DELETE") verifyCtx.URI().SetPath("/this/is/a/dummy/route/oke") appHandler(verifyCtx) require.Equal(b, 405, verifyCtx.Response.StatusCode()) require.Equal(b, MethodGet+", "+MethodHead, string(verifyCtx.Response.Header.Peek("Allow"))) require.Equal(b, utils.StatusMessage(StatusMethodNotAllowed), string(verifyCtx.Response.Body())) } func Benchmark_Router_NotFound_Parallel(b *testing.B) { b.ReportAllocs() app := New() app.Use(func(c Ctx) error { return c.Next() }) registerDummyRoutes(app) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/this/route/does/not/exist") b.RunParallel(func(pb *testing.PB) { for pb.Next() { appHandler(c) } }) require.Equal(b, 404, c.Response.StatusCode()) require.Equal(b, "Not Found", string(c.Response.Body())) } func Benchmark_Router_Handler_Parallel(b *testing.B) { app := New() registerDummyRoutes(app) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/user/keys/1337") b.RunParallel(func(pb *testing.PB) { for pb.Next() { appHandler(c) } }) } func Benchmark_Router_Handler_Strict_Case_Parallel(b *testing.B) { app := New(Config{StrictRouting: true, CaseSensitive: true}) registerDummyRoutes(app) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/user/keys/1337") b.RunParallel(func(pb *testing.PB) { for pb.Next() { appHandler(c) } }) } func Benchmark_Router_Chain_Parallel(b *testing.B) { app := New() handler := func(c Ctx) error { return c.Next() } app.Get("/", handler, handler, handler, handler, handler, handler) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod(MethodGet) c.URI().SetPath("/") b.RunParallel(func(pb *testing.PB) { for pb.Next() { appHandler(c) } }) } func Benchmark_Router_WithCompression_Parallel(b *testing.B) { app := New() handler := func(c Ctx) error { return c.Next() } 
app.Get("/", handler) app.Get("/", handler) app.Get("/", handler) app.Get("/", handler) app.Get("/", handler) app.Get("/", handler) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod(MethodGet) c.URI().SetPath("/") b.RunParallel(func(pb *testing.PB) { for pb.Next() { appHandler(c) } }) } func Benchmark_Startup_Process_Parallel(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { app := New() registerDummyRoutes(app) app.startupProcess() } }) } func Benchmark_Router_Next_Parallel(b *testing.B) { app := New() registerDummyRoutes(app) app.startupProcess() b.RunParallel(func(pb *testing.PB) { // Each worker gets its own request and context to avoid data races. request := &fasthttp.RequestCtx{} request.Request.Header.SetMethod("DELETE") request.URI().SetPath("/user/keys/1337") c := acquireDefaultCtxForRouterBenchmark(b, app, request) for pb.Next() { c.indexRoute = -1 //nolint:errcheck // Benchmark hot path - error checked in verification _, _ = app.next(c) } }) // Single-threaded verification on a fresh context to preserve correctness checks. 
verifyRequest := &fasthttp.RequestCtx{} verifyRequest.Request.Header.SetMethod("DELETE") verifyRequest.URI().SetPath("/user/keys/1337") verifyCtx := acquireDefaultCtxForRouterBenchmark(b, app, verifyRequest) verifyCtx.indexRoute = -1 res, err := app.next(verifyCtx) require.NoError(b, err) require.True(b, res) require.Equal(b, 4, verifyCtx.indexRoute) } func Benchmark_Router_Next_Default_Immutable_Parallel(b *testing.B) { app := New(Config{Immutable: true}) app.Get("/", func(_ Ctx) error { return nil }) h := app.Handler() fctx := &fasthttp.RequestCtx{} fctx.Request.Header.SetMethod(MethodGet) fctx.Request.SetRequestURI("/") b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { h(fctx) } }) } func Benchmark_Route_Match_Parallel(b *testing.B) { parsed := parseRoute("/user/keys/:id") route := &Route{use: false, root: false, star: false, routeParser: parsed, Params: parsed.params, path: "/user/keys/:id", Path: "/user/keys/:id", Method: "DELETE"} route.Handlers = append(route.Handlers, func(_ Ctx) error { return nil }) b.RunParallel(func(pb *testing.PB) { // Each worker gets its own local variables to avoid data races var params [maxParams]string for pb.Next() { _ = route.match("/user/keys/1337", "/user/keys/1337", ¶ms) } }) // Single-threaded verification to preserve correctness checks var verifyParams [maxParams]string match := route.match("/user/keys/1337", "/user/keys/1337", &verifyParams) require.True(b, match) require.Equal(b, []string{"1337"}, verifyParams[0:len(parsed.params)]) } func Benchmark_Route_Match_Star_Parallel(b *testing.B) { var match bool var params [maxParams]string parsed := parseRoute("/*") route := &Route{use: false, root: false, star: true, routeParser: parsed, Params: parsed.params, path: "/user/keys/bla", Path: "/user/keys/bla", Method: "DELETE"} route.Handlers = append(route.Handlers, func(_ Ctx) error { return nil }) b.RunParallel(func(pb *testing.PB) { for pb.Next() { match = route.match("/user/keys/bla", "/user/keys/bla", 
¶ms) } }) require.True(b, match) require.Equal(b, []string{"user/keys/bla"}, params[0:len(parsed.params)]) } func Benchmark_Route_Match_Root_Parallel(b *testing.B) { var match bool var params [maxParams]string parsed := parseRoute("/") route := &Route{use: false, root: true, star: false, path: "/", routeParser: parsed, Params: parsed.params, Path: "/", Method: "DELETE"} route.Handlers = append(route.Handlers, func(_ Ctx) error { return nil }) b.RunParallel(func(pb *testing.PB) { for pb.Next() { match = route.match("/", "/", ¶ms) } }) require.True(b, match) require.Equal(b, []string{}, params[0:len(parsed.params)]) } func Benchmark_Router_Handler_CaseSensitive_Parallel(b *testing.B) { app := New() app.config.CaseSensitive = true registerDummyRoutes(app) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/user/keys/1337") b.RunParallel(func(pb *testing.PB) { for pb.Next() { appHandler(c) } }) } func Benchmark_Router_Handler_Unescape_Parallel(b *testing.B) { app := New() app.config.UnescapePath = true registerDummyRoutes(app) app.Delete("/créer", func(_ Ctx) error { return nil }) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod(MethodDelete) c.URI().SetPath("/cr%C3%A9er") b.RunParallel(func(pb *testing.PB) { for pb.Next() { c.URI().SetPath("/cr%C3%A9er") appHandler(c) } }) } func Benchmark_Router_Handler_StrictRouting_Parallel(b *testing.B) { app := New() app.config.CaseSensitive = true registerDummyRoutes(app) appHandler := app.Handler() c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod("DELETE") c.URI().SetPath("/user/keys/1337") b.RunParallel(func(pb *testing.PB) { for pb.Next() { appHandler(c) } }) } func Benchmark_Router_GitHub_API_Parallel(b *testing.B) { app := New() registerDummyRoutes(app) app.startupProcess() b.ResetTimer() for i := range routesFixture.TestRoutes { b.RunParallel(func(pb *testing.PB) { // Each worker gets its own RequestCtx and local variables to 
avoid data races c := &fasthttp.RequestCtx{} c.Request.Header.SetMethod(routesFixture.TestRoutes[i].Method) for pb.Next() { c.URI().SetPath(routesFixture.TestRoutes[i].Path) ctx := acquireDefaultCtxForRouterBenchmark(b, app, c) //nolint:errcheck // Benchmark hot path - error checked in verification _, _ = app.next(ctx) app.ReleaseCtx(ctx) } }) // Single-threaded verification on a fresh context to preserve correctness checks verifyC := &fasthttp.RequestCtx{} verifyC.Request.Header.SetMethod(routesFixture.TestRoutes[i].Method) verifyC.URI().SetPath(routesFixture.TestRoutes[i].Path) verifyCtx := acquireDefaultCtxForRouterBenchmark(b, app, verifyC) match, err := app.next(verifyCtx) app.ReleaseCtx(verifyCtx) require.NoError(b, err) require.True(b, match) } } ================================================ FILE: services.go ================================================ package fiber import ( "context" "errors" "fmt" "io" utilsstrings "github.com/gofiber/utils/v2/strings" ) // Service is an interface that defines the methods for a service. type Service interface { // Start starts the service, returning an error if it fails. Start(ctx context.Context) error // String returns a string representation of the service. // It is used to print a human-readable name of the service in the startup message. String() string // State returns the current state of the service. State(ctx context.Context) (string, error) // Terminate terminates the service, returning an error if it fails. Terminate(ctx context.Context) error } // hasConfiguredServices Checks if there are any services for the current application. 
func (app *App) hasConfiguredServices() bool { return len(app.configured.Services) > 0 } func (app *App) validateConfiguredServices() error { return validateServicesSlice(app.configured.Services) } func validateServicesSlice(services []Service) error { for idx, srv := range services { if srv == nil { return fmt.Errorf("fiber: service at index %d is nil", idx) } } return nil } // initServices If the app is configured to use services, this function registers // a post shutdown hook to shutdown them after the server is closed. // This function panics if there is an error starting the services. func (app *App) initServices() { if !app.hasConfiguredServices() { return } if err := app.startServices(app.servicesStartupCtx()); err != nil { panic(err) } } // servicesStartupCtx Returns the context for the services startup. // If the ServicesStartupContextProvider is not set, it returns a new background context. func (app *App) servicesStartupCtx() context.Context { if app.configured.ServicesStartupContextProvider != nil { return app.configured.ServicesStartupContextProvider() } return context.Background() } // servicesShutdownCtx Returns the context for the services shutdown. // If the ServicesShutdownContextProvider is not set, it returns a new background context. func (app *App) servicesShutdownCtx() context.Context { if app.configured.ServicesShutdownContextProvider != nil { return app.configured.ServicesShutdownContextProvider() } return context.Background() } // startServices Handles the start process of services for the current application. // Iterates over all configured services and tries to start them, returning an error if any error occurs. 
func (app *App) startServices(ctx context.Context) error { if !app.hasConfiguredServices() { return nil } var errs []error for idx, srv := range app.configured.Services { if srv == nil { return fmt.Errorf("fiber: service at index %d is nil", idx) } if err := ctx.Err(); err != nil { // Context is canceled, return an error the soonest possible, so that // the user can see the context cancellation error and act on it. return fmt.Errorf("context canceled while starting service %s: %w", srv.String(), err) } err := srv.Start(ctx) if err == nil { // mark the service as started app.state.setService(srv) continue } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("service %s start: %w", srv.String(), err) } errs = append(errs, fmt.Errorf("service %s start: %w", srv.String(), err)) } return errors.Join(errs...) } // shutdownServices Handles the shutdown process of services for the current application. // Iterates over all the started services in reverse order and tries to terminate them, // returning an error if any error occurs. func (app *App) shutdownServices(ctx context.Context) error { if app.state.ServicesLen() == 0 { return nil } var errs []error for key, srv := range app.state.Services() { if srv == nil { return fmt.Errorf("fiber: service %q is nil", key) } if err := ctx.Err(); err != nil { // Context is canceled, do a best effort to terminate the services. errs = append(errs, fmt.Errorf("service %s terminate: %w", srv.String(), err)) continue } err := srv.Terminate(ctx) if err != nil { // Best effort to terminate the services. errs = append(errs, fmt.Errorf("service %s terminate: %w", srv.String(), err)) continue } // Remove the service from the State app.state.deleteService(srv) } return errors.Join(errs...) } // logServices logs information about services and returns an error // if any configured service is nil. 
func (app *App) logServices(ctx context.Context, out io.Writer, colors *Colors) error {
	// Nothing to report when no services are configured.
	if !app.hasConfiguredServices() {
		return nil
	}

	// Fall back to the default color scheme when none is provided.
	scheme := colors
	if scheme == nil {
		scheme = &DefaultColors
	}

	// Header line: the number of services currently stored in the State.
	fmt.Fprintf(out,
		"%sINFO%s Services: \t%s%d%s\n",
		scheme.Green, scheme.Reset, scheme.Blue, app.state.ServicesLen(), scheme.Reset)
	// One line per stored service; map iteration means the order is not deterministic.
	for key, srv := range app.state.Services() {
		if srv == nil {
			return fmt.Errorf("fiber: service %q is nil", key)
		}

		var state string
		var stateColor string
		// err is the only new name here, so := assigns to the state declared above.
		state, err := srv.State(ctx)
		if err != nil {
			// A failing State call is shown as errString in red instead of aborting.
			state = errString
			stateColor = scheme.Red
		} else {
			stateColor = scheme.Blue
		}

		fmt.Fprintf(out,
			"%sINFO%s 🧩 %s[ %s ] %s%s\n",
			scheme.Green, scheme.Reset, stateColor, utilsstrings.ToUpper(state), srv.String(), scheme.Reset)
	}

	return nil
}

================================================
FILE: services_test.go
================================================
package fiber

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"strings"
	"testing"
	"time"

	"github.com/gofiber/fiber/v3/log"
	"github.com/stretchr/testify/require"
)

const (
	terminateErrorMessage = "terminate error"
	startErrorMessage     = "start error"
)

// mockService implements Service interface for testing.
// Start and Terminate return early when ctx is already done, optionally wait for
// startDelay/terminateDelay (aborting if ctx is done first), and then return the
// configured startError/terminateError.
type mockService struct {
	startError     error
	terminateError error
	stateError     error
	name           string
	started        bool
	terminated     bool
	startDelay     time.Duration
	terminateDelay time.Duration
}

func (m *mockService) Start(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return fmt.Errorf("context canceled: %w", ctx.Err())
	default:
	}

	if m.startDelay > 0 {
		timer := time.NewTimer(m.startDelay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return fmt.Errorf("context canceled: %w", ctx.Err())
		case <-timer.C:
			// Continue after delay
		}
	}

	if m.startError != nil {
		m.started = false
		return m.startError
	}

	m.started = true
	return nil
}

func (m *mockService) String() string {
	return m.name
}

// State reports "error" when stateError is set, otherwise "running", "stopped" or
// "unknown" depending on the started/terminated flags.
func (m *mockService) State(ctx context.Context) (string, error) {
	if ctx.Err() != nil {
		return "", fmt.Errorf("context canceled: %w", ctx.Err())
	}

	if m.stateError != nil {
		return "error", m.stateError
	}

	if m.started {
		return "running", nil
	}

	if m.terminated {
		return "stopped", nil
	}

	return "unknown", nil
}

func (m *mockService) Terminate(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return fmt.Errorf("context canceled: %w", ctx.Err())
	default:
	}

	if m.terminateDelay > 0 {
		timer := time.NewTimer(m.terminateDelay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return fmt.Errorf("context canceled: %w", ctx.Err())
		case <-timer.C:
			// Continue after delay
		}
	}

	if m.terminateError != nil {
		m.terminated = false
		return m.terminateError
	}

	m.started = false
	m.terminated = true
	return nil
}

func Test_HasConfiguredServices(t *testing.T) {
	testHasConfiguredServicesFn := func(t *testing.T, app *App, expected bool) {
		t.Helper()

		result := app.hasConfiguredServices()
		require.Equal(t, expected, result)
	}

	t.Run("no-services", func(t *testing.T) {
		testHasConfiguredServicesFn(t, &App{configured: Config{}}, false)
	})

	t.Run("has-services", func(t *testing.T) {
		testHasConfiguredServicesFn(t, &App{configured: Config{Services: []Service{&mockService{name: "test-dep"}}}}, true)
	})
}

func Test_InitServices(t *testing.T) {
	t.Run("no-services", func(t *testing.T) {
		app := &App{configured: Config{}}
		require.NotPanics(t, app.initServices)
	})

	t.Run("start/success", func(t *testing.T) {
		// Initialize the app using the struct and defining the state and hooks manually,
		// because we are not checking the shutdown hooks in this test.
		app := &App{
			configured: Config{
				Services: []Service{
					&mockService{name: "dep1"},
					&mockService{name: "dep2"},
				},
			},
			state: newState(),
		}
		app.hooks = newHooks(app)

		require.NotPanics(t, app.initServices)
	})

	t.Run("start/error", func(t *testing.T) {
		// Initialize the app using the struct and defining the state and hooks manually,
		// because we are not checking the shutdown hooks in this test.
		app := &App{
			configured: Config{
				Services: []Service{
					&mockService{name: "dep1", startError: errors.New(startErrorMessage + " 1")},
					&mockService{name: "dep2", startError: errors.New(startErrorMessage + " 2")},
					&mockService{name: "dep3"},
				},
			},
			state: newState(),
		}
		app.hooks = newHooks(app)

		require.Panics(t, app.initServices)
	})

	t.Run("shutdown-hooks/success", func(t *testing.T) {
		// Initialize the app using the New function to verify that the shutdown hooks are registered
		// and the app mutex is not causing a deadlock.
		app := New(Config{
			Services: []Service{&mockService{name: "dep1"}},
		})

		require.NotPanics(t, app.initServices)

		type stringsLogger struct {
			strings.Builder
		}

		var buf stringsLogger
		log.SetOutput(&buf)

		app.Hooks().executeOnPostShutdownHooks(nil)
		require.NotContains(t, buf.String(), "failed to call post shutdown hook:")
	})

	t.Run("shutdown-hooks/error", func(t *testing.T) {
		// Initialize the app using the New function to verify that the shutdown hooks are registered
		// and the app mutex is not causing a deadlock.
		app := New(Config{
			Services: []Service{
				&mockService{name: "dep1"},
				&mockService{name: "dep2", terminateError: errors.New(terminateErrorMessage + " 2")},
			},
		})

		require.NotPanics(t, app.initServices)

		type stringsLogger struct {
			strings.Builder
		}

		var buf stringsLogger
		log.SetOutput(&buf)

		app.Hooks().executeOnPostShutdownHooks(nil)
		require.Contains(t, buf.String(), "failed to shutdown services: service dep2 terminate: terminate error 2")
	})
}

func Test_StartServices(t *testing.T) {
	t.Run("no-services", func(t *testing.T) {
		app := &App{
			configured: Config{
				Services: []Service{},
			},
			state: newState(),
		}

		err := app.startServices(context.Background())
		require.NoError(t, err)
		require.Zero(t, app.state.ServicesLen())
	})

	t.Run("successful-start", func(t *testing.T) {
		app := &App{
			configured: Config{
				Services: []Service{
					&mockService{name: "dep1"},
					&mockService{name: "dep2"},
				},
			},
			state: newState(),
		}

		err := app.startServices(context.Background())
		require.NoError(t, err)
		require.Equal(t, 2, app.state.ServicesLen())
	})

	t.Run("failed-start", func(t *testing.T) {
		app := &App{
			configured: Config{
				Services: []Service{
					&mockService{name: "dep1", startError: errors.New(startErrorMessage + " 1")},
					&mockService{name: "dep2", startError: errors.New(startErrorMessage + " 2")},
					&mockService{name: "dep3"},
				},
			},
			state: newState(),
		}

		err := app.startServices(context.Background())
		require.Error(t, err)
		require.Contains(t, err.Error(), startErrorMessage+" 1")
		require.Contains(t, err.Error(), startErrorMessage+" 2")
		require.Equal(t, 1, app.state.ServicesLen())
	})

	t.Run("context", func(t *testing.T) {
		t.Run("already-canceled", func(t *testing.T) {
			app := &App{
				configured: Config{
					Services: []Service{
						&mockService{name: "dep1"},
					},
				},
				state: newState(),
			}

			// Create a context that is already canceled
			ctx, cancel := context.WithCancel(context.Background())
			cancel()

			err := app.startServices(ctx)
			require.ErrorIs(t, err, context.Canceled)
			require.Zero(t, app.state.ServicesLen())
		})

		t.Run("cancellation", func(t *testing.T) {
			// Create a service that takes some time to start
			slowDep := &mockService{name: "slow-dep", startDelay: 200 * time.Millisecond}

			app := &App{
				configured: Config{
					Services: []Service{slowDep},
				},
				state: newState(),
			}

			// Create a context whose 100ms deadline expires before the 200ms start delay ends
			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
			defer cancel()

			// Start services with a delay that is longer than the timeout
			err := app.startServices(ctx)
			require.ErrorIs(t, err, context.DeadlineExceeded)
			require.Zero(t, app.state.ServicesLen())
		})
	})
}

func Test_ShutdownServices(t *testing.T) {
	t.Run("no-services", func(t *testing.T) {
		app := &App{
			configured: Config{
				Services: []Service{},
			},
			state: newState(),
		}

		err := app.shutdownServices(context.Background())
		require.NoError(t, err)
		require.Zero(t, app.state.ServicesLen())
	})

	t.Run("successful-shutdown", func(t *testing.T) {
		srv1 := &mockService{name: "dep1"}
		srv2 := &mockService{name: "dep2"}

		// Expected state, including the two started services
		expectedState := newState()
		expectedState.setService(srv1)
		expectedState.setService(srv2)

		app := &App{
			configured: Config{
				Services: []Service{srv1, srv2},
			},
			state: expectedState,
		}

		err := app.shutdownServices(context.Background())
		require.NoError(t, err)
		require.Zero(t, app.state.ServicesLen())
	})

	t.Run("failed-shutdown", func(t *testing.T) {
		srv1 := &mockService{name: "dep1", terminateError: errors.New(terminateErrorMessage + " 1")}
		srv2 := &mockService{name: "dep2", terminateError: errors.New(terminateErrorMessage + " 2")}
		srv3 := &mockService{name: "dep3"}

		// Expected state, including the three started services
		expectedState := newState()
		expectedState.setService(srv1)
		expectedState.setService(srv2)
		expectedState.setService(srv3)

		app := &App{
			configured: Config{
				Services: []Service{srv1, srv2, srv3},
			},
			state: expectedState,
		}

		err := app.shutdownServices(context.Background())
		require.Error(t, err)
		require.Contains(t, err.Error(), terminateErrorMessage+" 1")
		require.Contains(t, err.Error(), terminateErrorMessage+" 2")
		require.Equal(t, 2, app.state.ServicesLen()) // 2 services failed to terminate
	})

	t.Run("context", func(t *testing.T) {
		t.Run("already-canceled", func(t *testing.T) {
			srv1 := &mockService{name: "dep1"}

			// Expected state, including the single started service
			expectedState := newState()
			expectedState.setService(srv1)

			app := &App{
				configured: Config{
					Services: []Service{srv1},
				},
				state: expectedState,
			}

			// Create a context that is already canceled
			ctx, cancel := context.WithCancel(context.Background())
			cancel()

			err := app.shutdownServices(ctx)
			require.Error(t, err)
			require.ErrorIs(t, err, context.Canceled)
			require.Contains(t, err.Error(), "service dep1 terminate")
			require.Equal(t, 1, app.state.ServicesLen())
		})

		t.Run("cancellation", func(t *testing.T) {
			// Create a service that takes some time to terminate
			slowDep := &mockService{name: "slow-dep", terminateDelay: 200 * time.Millisecond}

			// Expected state, including the single started service
			expectedState := newState()
			expectedState.setService(slowDep)

			app := &App{
				configured: Config{
					Services: []Service{slowDep},
				},
				state: expectedState,
			}

			// Create a new context for shutdown whose 100ms deadline expires before the 200ms terminate delay ends
			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
			defer cancel()

			// Shutdown services with a context that times out mid-termination
			err := app.shutdownServices(ctx)
			require.ErrorIs(t, err, context.DeadlineExceeded)
			require.Equal(t, 1, app.state.ServicesLen())
		})
	})
}

func Test_LogServices(t *testing.T) {
	t.Parallel()

	// Service with successful State
	runningService := &mockService{name: "running", started: true}
	// Service with State error
	errorService := &mockService{name: "error", stateError: errors.New("state error")}

	expectedState := newState()
	expectedState.setService(runningService)
	expectedState.setService(errorService)

	app := &App{
		configured: Config{
			Services: []Service{runningService, errorService},
		},
		state: expectedState,
	}

	var buf bytes.Buffer
	colors := Colors{
		Green: "\033[32m",
		Reset: "\033[0m",
		Blue:  "\033[34m",
		Red:   "\033[31m",
	}

	err := app.logServices(context.Background(), &buf, &colors)
	require.NoError(t, err)

	output := buf.String()

	// The output order follows map iteration, so only check that each line is present.
	for _, srv := range app.state.Services() {
		stateColor := colors.Blue
		state := "RUNNING"
		if _, err := srv.State(context.Background()); err != nil {
			stateColor = colors.Red
			state = "ERROR"
		}
		expected := fmt.Sprintf("%sINFO%s 🧩 %s[ %s ] %s%s\n", colors.Green, colors.Reset, stateColor, strings.ToUpper(state), srv.String(), colors.Reset)
		require.Contains(t, output, expected)
	}
}

func Test_NewConfiguredServicesNil(t *testing.T) {
	t.Parallel()

	require.PanicsWithError(t, "fiber: service at index 0 is nil", func() {
		New(Config{
			Services: []Service{nil},
		})
	})
}

func Test_ServiceContextProviders(t *testing.T) {
	t.Run("no-provider", func(t *testing.T) {
		app := &App{
			configured: Config{},
			state:      newState(),
		}

		require.Equal(t, context.Background(), app.servicesStartupCtx())
		require.Equal(t, context.Background(), app.servicesShutdownCtx())
	})

	t.Run("with-provider", func(t *testing.T) {
		ctx := context.TODO()
		app := &App{
			configured: Config{
				ServicesStartupContextProvider: func() context.Context {
					return ctx
				},
				ServicesShutdownContextProvider: func() context.Context {
					return ctx
				},
			},
			state: newState(),
		}

		require.Equal(t, ctx, app.servicesStartupCtx())
		require.Equal(t, ctx, app.servicesShutdownCtx())
	})
}

func Benchmark_StartServices(b *testing.B) {
	benchmarkFn := func(b *testing.B, services []Service) {
		b.Helper()

		for b.Loop() {
			app := New(Config{
				Services: services,
			})

			ctx := context.Background()
			if err := app.startServices(ctx); err != nil {
				b.Fatal("Expected no error but got", err)
			}
		}
	}

	b.Run("no-services", func(b *testing.B) {
		benchmarkFn(b, []Service{})
	})

	b.Run("single-service", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1"},
		})
	})

	b.Run("multiple-services", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1"},
			&mockService{name: "dep2"},
			&mockService{name: "dep3"},
		})
	})

	b.Run("multiple-services-with-delays", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1", startDelay: 1 * time.Millisecond},
			&mockService{name: "dep2", startDelay: 2 * time.Millisecond},
			&mockService{name: "dep3", startDelay: 3 * time.Millisecond},
		})
	})
}

func Benchmark_ShutdownServices(b *testing.B) {
	benchmarkFn := func(b *testing.B, services []Service) {
		b.Helper()

		for b.Loop() {
			app := New(Config{
				Services: services,
			})

			ctx := context.Background()
			if err := app.shutdownServices(ctx); err != nil {
				b.Fatal("Expected no error but got", err)
			}
		}
	}

	b.Run("no-services", func(b *testing.B) {
		benchmarkFn(b, []Service{})
	})

	b.Run("single-service", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1"},
		})
	})

	b.Run("multiple-services", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1"},
			&mockService{name: "dep2"},
			&mockService{name: "dep3"},
		})
	})

	b.Run("multiple-services-with-delays", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1", terminateDelay: 1 * time.Millisecond},
			&mockService{name: "dep2", terminateDelay: 2 * time.Millisecond},
			&mockService{name: "dep3", terminateDelay: 3 * time.Millisecond},
		})
	})
}

func Benchmark_StartServices_withContextCancellation(b *testing.B) {
	benchmarkFn := func(b *testing.B, services []Service, timeout time.Duration) {
		b.Helper()

		for b.Loop() {
			app := New(Config{
				Services: services,
			})

			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			err := app.startServices(ctx)
			// We expect an error here due to the short timeout
			if err == nil && timeout < time.Second {
				b.Fatal("Expected error due to context cancellation but got none")
			}
			cancel()
		}
	}

	b.Run("single-service/immediate-cancellation", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1", startDelay: 100 * time.Millisecond},
		}, 10*time.Millisecond)
	})

	b.Run("multiple-services/immediate-cancellation", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1", startDelay: 100 * time.Millisecond},
			&mockService{name: "dep2", startDelay: 200 * time.Millisecond},
			&mockService{name: "dep3", startDelay: 300 * time.Millisecond},
		}, 10*time.Millisecond)
	})

	b.Run("multiple-services/successful-completion", func(b *testing.B) {
		const timeout = 500 * time.Millisecond

		for b.Loop() {
			app := New(Config{
				Services: []Service{
					&mockService{name: "dep1", startDelay: 10 * time.Millisecond},
					&mockService{name: "dep2", startDelay: 20 * time.Millisecond},
					&mockService{name: "dep3", startDelay: 30 * time.Millisecond},
				},
			})

			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			err := app.startServices(ctx)
			if err != nil {
				b.Fatal("Expected no error but got", err)
			}
			cancel()
		}
	})
}

func Benchmark_ShutdownServices_withContextCancellation(b *testing.B) {
	benchmarkFn := func(b *testing.B, services []Service, timeout time.Duration) {
		b.Helper()

		for b.Loop() {
			app := New(Config{
				Services: services,
			})

			err := app.startServices(context.Background())
			if err != nil {
				b.Fatal("Expected no error during startup but got", err)
			}

			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			err = app.shutdownServices(ctx)
			// We expect an error here due to the short timeout
			if err == nil && timeout < time.Second {
				b.Fatal("Expected error due to context cancellation but got none")
			}
			cancel()
		}
	}

	b.Run("single-service/immediate-cancellation", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1", terminateDelay: 100 * time.Millisecond},
		}, 10*time.Millisecond)
	})

	b.Run("multiple-services/immediate-cancellation", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1", terminateDelay: 100 * time.Millisecond},
			&mockService{name: "dep2", terminateDelay: 200 * time.Millisecond},
			&mockService{name: "dep3", terminateDelay: 300 * time.Millisecond},
		}, 10*time.Millisecond)
	})

	b.Run("multiple-services/successful-completion", func(b *testing.B) {
		const timeout = 500 * time.Millisecond

		for b.Loop() {
			app := New(Config{
				Services: []Service{
					&mockService{name: "dep1", terminateDelay: 10 * time.Millisecond},
					&mockService{name: "dep2", terminateDelay: 20 * time.Millisecond},
					&mockService{name: "dep3", terminateDelay: 30 * time.Millisecond},
				},
			})

			err := app.startServices(context.Background())
			if err != nil {
				b.Fatal("Expected no error but got", err)
			}

			ctx, cancel := context.WithTimeout(context.Background(), timeout)
			err = app.shutdownServices(ctx)
			if err != nil {
				b.Fatal("Expected no error but got", err)
			}
			cancel()
		}
	})
}

func Benchmark_ServicesMemory(b *testing.B) {
	benchmarkFn := func(b *testing.B, services []Service) {
		b.Helper()

		b.ReportAllocs()
		var err error
		for b.Loop() {
			app := New(Config{
				Services: services,
			})

			ctx := context.Background()
			err = app.startServices(ctx)
			if err != nil {
				continue
			}
			err = app.shutdownServices(ctx)
		}

		// Only the error from the last iteration is checked.
		require.NoError(b, err)
	}

	b.Run("no-services", func(b *testing.B) {
		benchmarkFn(b, []Service{})
	})

	b.Run("single-service", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1"},
		})
	})

	b.Run("multiple-services", func(b *testing.B) {
		benchmarkFn(b, []Service{
			&mockService{name: "dep1"},
			&mockService{name: "dep2"},
			&mockService{name: "dep3"},
		})
	})
}

================================================
FILE: state.go
================================================
package fiber

import (
	"encoding/hex"
	"strings"
	"sync"

	"github.com/google/uuid"
)

const servicesStatePrefix = "gofiber-services-"

// servicesStatePrefixHash is the hex encoding of servicesStatePrefix followed by a
// random UUID generated once per process in init.
var servicesStatePrefixHash string

func init() {
	servicesStatePrefixHash = hex.EncodeToString([]byte(servicesStatePrefix + uuid.New().String()))
}

// State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies.
// It's a thread-safe implementation of a map[string]any, using sync.Map.
type State struct {
	dependencies  sync.Map
	servicePrefix string
}

// newState creates a new instance of State.
func newState() *State { // Initialize the services state prefix using a hashed random string return &State{ dependencies: sync.Map{}, servicePrefix: servicesStatePrefixHash, } } // Set sets a key-value pair in the State. func (s *State) Set(key string, value any) { s.dependencies.Store(key, value) } // Get retrieves a value from the State. func (s *State) Get(key string) (any, bool) { return s.dependencies.Load(key) } // MustGet retrieves a value from the State and panics if the key is not found. func (s *State) MustGet(key string) any { if dep, ok := s.Get(key); ok { return dep } panic("state: dependency not found!") } // Has checks if a key is present in the State. // It returns a boolean indicating if the key is present. func (s *State) Has(key string) bool { _, ok := s.Get(key) return ok } // Delete removes a key-value pair from the State. func (s *State) Delete(key string) { s.dependencies.Delete(key) } // Reset resets the State by removing all keys. func (s *State) Reset() { s.dependencies.Clear() } // Keys returns a slice containing all keys present in the State. func (s *State) Keys() []string { // Pre-allocate with estimated capacity to reduce allocations keys := make([]string, 0, 8) s.dependencies.Range(func(key, _ any) bool { keyStr, ok := key.(string) if !ok { return true } keys = append(keys, keyStr) return true }) return keys } // Len returns the number of keys in the State. func (s *State) Len() int { length := 0 s.dependencies.Range(func(_, _ any) bool { length++ return true }) return length } // GetState retrieves a value from the State and casts it to the desired type. // It returns the casted value and a boolean indicating if the cast was successful. func GetState[T any](s *State, key string) (T, bool) { dep, ok := s.Get(key) if ok { depT, okCast := dep.(T) return depT, okCast } var zeroVal T return zeroVal, false } // MustGetState retrieves a value from the State and casts it to the desired type. 
// It panics if the key is not found or if the type assertion fails. func MustGetState[T any](s *State, key string) T { dep, ok := GetState[T](s, key) if !ok { panic("state: dependency not found!") } return dep } // GetStateWithDefault retrieves a value from the State, // casting it to the desired type. If the key is not present, // it returns the provided default value. func GetStateWithDefault[T any](s *State, key string, defaultVal T) T { dep, ok := GetState[T](s, key) if !ok { return defaultVal } return dep } // GetString retrieves a string value from the State. // It returns the string and a boolean indicating successful type assertion. func (s *State) GetString(key string) (string, bool) { return GetState[string](s, key) } // GetInt retrieves an integer value from the State. // It returns the int and a boolean indicating successful type assertion. func (s *State) GetInt(key string) (int, bool) { return GetState[int](s, key) } // GetBool retrieves a boolean value from the State. // It returns the bool and a boolean indicating successful type assertion. func (s *State) GetBool(key string) (value, ok bool) { //nolint:nonamedreturns // Better idea to use named returns here return GetState[bool](s, key) } // GetFloat64 retrieves a float64 value from the State. // It returns the float64 and a boolean indicating successful type assertion. func (s *State) GetFloat64(key string) (float64, bool) { return GetState[float64](s, key) } // GetUint retrieves a uint value from the State. // It returns the uint and a boolean indicating successful type assertion. func (s *State) GetUint(key string) (uint, bool) { return GetState[uint](s, key) } // GetInt8 retrieves an int8 value from the State. // It returns the int8 and a boolean indicating successful type assertion. func (s *State) GetInt8(key string) (int8, bool) { return GetState[int8](s, key) } // GetInt16 retrieves an int16 value from the State. // It returns the int16 and a boolean indicating successful type assertion. 
func (s *State) GetInt16(key string) (int16, bool) {
	n, ok := GetState[int16](s, key)
	return n, ok
}

// GetInt32 returns the int32 stored under key.
// The boolean is false when the key is missing or the value is not an int32.
func (s *State) GetInt32(key string) (int32, bool) {
	n, ok := GetState[int32](s, key)
	return n, ok
}

// GetInt64 returns the int64 stored under key.
// The boolean is false when the key is missing or the value is not an int64.
func (s *State) GetInt64(key string) (int64, bool) {
	n, ok := GetState[int64](s, key)
	return n, ok
}

// GetUint8 returns the uint8 stored under key.
// The boolean is false when the key is missing or the value is not a uint8.
func (s *State) GetUint8(key string) (uint8, bool) {
	u, ok := GetState[uint8](s, key)
	return u, ok
}

// GetUint16 returns the uint16 stored under key.
// The boolean is false when the key is missing or the value is not a uint16.
func (s *State) GetUint16(key string) (uint16, bool) {
	u, ok := GetState[uint16](s, key)
	return u, ok
}

// GetUint32 returns the uint32 stored under key.
// The boolean is false when the key is missing or the value is not a uint32.
func (s *State) GetUint32(key string) (uint32, bool) {
	u, ok := GetState[uint32](s, key)
	return u, ok
}

// GetUint64 returns the uint64 stored under key.
// The boolean is false when the key is missing or the value is not a uint64.
func (s *State) GetUint64(key string) (uint64, bool) {
	u, ok := GetState[uint64](s, key)
	return u, ok
}

// GetUintptr returns the uintptr stored under key.
// The boolean is false when the key is missing or the value is not a uintptr.
func (s *State) GetUintptr(key string) (uintptr, bool) {
	p, ok := GetState[uintptr](s, key)
	return p, ok
}

// GetFloat32 returns the float32 stored under key.
// The boolean is false when the key is missing or the value is not a float32.
func (s *State) GetFloat32(key string) (float32, bool) {
	f, ok := GetState[float32](s, key)
	return f, ok
}

// GetComplex64 retrieves a complex64 value from the State.
// It returns the complex64 and a boolean indicating successful type assertion.
func (s *State) GetComplex64(key string) (complex64, bool) {
	return GetState[complex64](s, key)
}

// GetComplex128 retrieves a complex128 value from the State.
// It returns the complex128 and a boolean indicating successful type assertion.
func (s *State) GetComplex128(key string) (complex128, bool) {
	return GetState[complex128](s, key)
}

// serviceKey returns the State key under which a service named key is stored.
// The result is the State's servicePrefix followed by the hex encoding of key.
// Hex encoding is injective, so distinct service names always map to distinct keys,
// and the random per-process prefix makes clashes with user-defined keys unlikely.
func (s *State) serviceKey(key string) string {
	return s.servicePrefix + hex.EncodeToString([]byte(key))
}

// setService stores srv in the State under the key derived from srv.String().
// Services with the same String() value share a key, so a later call replaces an
// earlier one.
func (s *State) setService(srv Service) {
	s.Set(s.serviceKey(srv.String()), srv)
}

// deleteService removes the service identified by srv.String() from the State.
func (s *State) deleteService(srv Service) {
	s.Delete(s.serviceKey(srv.String()))
}

// serviceKeys returns the keys of all services stored in the State, that is every
// string key starting with the State's servicePrefix, in no particular order.
func (s *State) serviceKeys() []string {
	// Pre-allocate with estimated capacity to reduce allocations
	keys := make([]string, 0, 8)
	s.dependencies.Range(func(key, _ any) bool {
		if keyStr, ok := key.(string); ok && strings.HasPrefix(keyStr, s.servicePrefix) {
			keys = append(keys, keyStr)
		}
		return true
	})
	return keys
}

// Services returns a map containing all services present in the State.
// The map key is the service's internal State key (servicePrefix plus the hex-encoded
// String() value) and the value is the service itself.
func (s *State) Services() map[string]Service { keys := s.serviceKeys() services := make(map[string]Service, len(keys)) for _, key := range keys { services[key] = MustGetState[Service](s, key) } return services } // ServicesLen returns the number of keys for services in the State. func (s *State) ServicesLen() int { length := 0 s.dependencies.Range(func(key, _ any) bool { if str, ok := key.(string); ok && strings.HasPrefix(str, s.servicePrefix) { length++ } return true }) return length } // GetService returns a service present in the application's State. func GetService[T Service](s *State, key string) (T, bool) { srv, ok := GetState[T](s, s.serviceKey(key)) return srv, ok } // MustGetService returns a service present in the application's State. // It panics if the service is not found. func MustGetService[T Service](s *State, key string) T { srv, ok := GetService[T](s, key) if !ok { panic("state: service not found!") } return srv } ================================================ FILE: state_test.go ================================================ package fiber import ( "strconv" "testing" "github.com/stretchr/testify/require" ) func TestState_SetAndGet_WithApp(t *testing.T) { t.Parallel() // Create app app := New() // test setting and getting a value app.State().Set("foo", "bar") val, ok := app.State().Get("foo") require.True(t, ok) require.Equal(t, "bar", val) // test key not found _, ok = app.State().Get("unknown") require.False(t, ok) } func TestState_SetAndGet(t *testing.T) { t.Parallel() st := newState() // test setting and getting a value st.Set("foo", "bar") val, ok := st.Get("foo") require.True(t, ok) require.Equal(t, "bar", val) // test key not found _, ok = st.Get("unknown") require.False(t, ok) } func TestState_GetString(t *testing.T) { t.Parallel() st := newState() st.Set("str", "hello") s, ok := st.GetString("str") require.True(t, ok) require.Equal(t, "hello", s) // wrong type should return false st.Set("num", 123) s, ok = st.GetString("num") 
require.False(t, ok) require.Empty(t, s) // missing key should return false s, ok = st.GetString("missing") require.False(t, ok) require.Empty(t, s) } func TestState_GetInt(t *testing.T) { t.Parallel() st := newState() st.Set("num", 456) i, ok := st.GetInt("num") require.True(t, ok) require.Equal(t, 456, i) // wrong type should return zero value st.Set("str", "abc") i, ok = st.GetInt("str") require.False(t, ok) require.Equal(t, 0, i) // missing key should return zero value i, ok = st.GetInt("missing") require.False(t, ok) require.Equal(t, 0, i) } func TestState_GetBool(t *testing.T) { t.Parallel() st := newState() st.Set("flag", true) b, ok := st.GetBool("flag") require.True(t, ok) require.True(t, b) // wrong type st.Set("num", 1) b, ok = st.GetBool("num") require.False(t, ok) require.False(t, b) // missing key should return false b, ok = st.GetBool("missing") require.False(t, ok) require.False(t, b) } func TestState_GetFloat64(t *testing.T) { t.Parallel() st := newState() st.Set("pi", 3.14) f, ok := st.GetFloat64("pi") require.True(t, ok) require.InDelta(t, 3.14, f, 0.0001) // wrong type should return zero value st.Set("int", 10) f, ok = st.GetFloat64("int") require.False(t, ok) require.InDelta(t, 0.0, f, 0.0001) // missing key should return zero value f, ok = st.GetFloat64("missing") require.False(t, ok) require.InDelta(t, 0.0, f, 0.0001) } func TestState_GetUint(t *testing.T) { t.Parallel() st := newState() st.Set("uint", uint(100)) u, ok := st.GetUint("uint") require.True(t, ok) require.Equal(t, uint(100), u) st.Set("wrong", "not uint") u, ok = st.GetUint("wrong") require.False(t, ok) require.Equal(t, uint(0), u) u, ok = st.GetUint("missing") require.False(t, ok) require.Equal(t, uint(0), u) } func TestState_GetInt8(t *testing.T) { t.Parallel() st := newState() st.Set("int8", int8(10)) i, ok := st.GetInt8("int8") require.True(t, ok) require.Equal(t, int8(10), i) st.Set("wrong", "not int8") i, ok = st.GetInt8("wrong") require.False(t, ok) require.Equal(t, 
int8(0), i) i, ok = st.GetInt8("missing") require.False(t, ok) require.Equal(t, int8(0), i) } func TestState_GetInt16(t *testing.T) { t.Parallel() st := newState() st.Set("int16", int16(200)) i, ok := st.GetInt16("int16") require.True(t, ok) require.Equal(t, int16(200), i) st.Set("wrong", "not int16") i, ok = st.GetInt16("wrong") require.False(t, ok) require.Equal(t, int16(0), i) i, ok = st.GetInt16("missing") require.False(t, ok) require.Equal(t, int16(0), i) } func TestState_GetInt32(t *testing.T) { t.Parallel() st := newState() st.Set("int32", int32(3000)) i, ok := st.GetInt32("int32") require.True(t, ok) require.Equal(t, int32(3000), i) st.Set("wrong", "not int32") i, ok = st.GetInt32("wrong") require.False(t, ok) require.Equal(t, int32(0), i) i, ok = st.GetInt32("missing") require.False(t, ok) require.Equal(t, int32(0), i) } func TestState_GetInt64(t *testing.T) { t.Parallel() st := newState() st.Set("int64", int64(4000)) i, ok := st.GetInt64("int64") require.True(t, ok) require.Equal(t, int64(4000), i) st.Set("wrong", "not int64") i, ok = st.GetInt64("wrong") require.False(t, ok) require.Equal(t, int64(0), i) i, ok = st.GetInt64("missing") require.False(t, ok) require.Equal(t, int64(0), i) } func TestState_GetUint8(t *testing.T) { t.Parallel() st := newState() st.Set("uint8", uint8(20)) u, ok := st.GetUint8("uint8") require.True(t, ok) require.Equal(t, uint8(20), u) st.Set("wrong", "not uint8") u, ok = st.GetUint8("wrong") require.False(t, ok) require.Equal(t, uint8(0), u) u, ok = st.GetUint8("missing") require.False(t, ok) require.Equal(t, uint8(0), u) } func TestState_GetUint16(t *testing.T) { t.Parallel() st := newState() st.Set("uint16", uint16(300)) u, ok := st.GetUint16("uint16") require.True(t, ok) require.Equal(t, uint16(300), u) st.Set("wrong", "not uint16") u, ok = st.GetUint16("wrong") require.False(t, ok) require.Equal(t, uint16(0), u) u, ok = st.GetUint16("missing") require.False(t, ok) require.Equal(t, uint16(0), u) } func TestState_GetUint32(t 
*testing.T) { t.Parallel() st := newState() st.Set("uint32", uint32(400000)) u, ok := st.GetUint32("uint32") require.True(t, ok) require.Equal(t, uint32(400000), u) st.Set("wrong", "not uint32") u, ok = st.GetUint32("wrong") require.False(t, ok) require.Equal(t, uint32(0), u) u, ok = st.GetUint32("missing") require.False(t, ok) require.Equal(t, uint32(0), u) } func TestState_GetUint64(t *testing.T) { t.Parallel() st := newState() st.Set("uint64", uint64(5000000)) u, ok := st.GetUint64("uint64") require.True(t, ok) require.Equal(t, uint64(5000000), u) st.Set("wrong", "not uint64") u, ok = st.GetUint64("wrong") require.False(t, ok) require.Equal(t, uint64(0), u) u, ok = st.GetUint64("missing") require.False(t, ok) require.Equal(t, uint64(0), u) } func TestState_GetUintptr(t *testing.T) { t.Parallel() st := newState() var ptr uintptr = 12345 st.Set("uintptr", ptr) u, ok := st.GetUintptr("uintptr") require.True(t, ok) require.Equal(t, ptr, u) st.Set("wrong", "not uintptr") u, ok = st.GetUintptr("wrong") require.False(t, ok) require.Equal(t, uintptr(0), u) u, ok = st.GetUintptr("missing") require.False(t, ok) require.Equal(t, uintptr(0), u) } func TestState_GetFloat32(t *testing.T) { t.Parallel() st := newState() st.Set("float32", float32(3.14)) f, ok := st.GetFloat32("float32") require.True(t, ok) require.InDelta(t, float32(3.14), f, 0.0001) st.Set("wrong", "not float32") f, ok = st.GetFloat32("wrong") require.False(t, ok) require.InDelta(t, float32(0), f, 0.0001) f, ok = st.GetFloat32("missing") require.False(t, ok) require.InDelta(t, float32(0), f, 0.0001) } func TestState_GetComplex64(t *testing.T) { t.Parallel() st := newState() var c complex64 = complex(2, 3) st.Set("complex64", c) cRes, ok := st.GetComplex64("complex64") require.True(t, ok) require.Equal(t, c, cRes) st.Set("wrong", "not complex64") cRes, ok = st.GetComplex64("wrong") require.False(t, ok) require.Equal(t, complex64(0), cRes) cRes, ok = st.GetComplex64("missing") require.False(t, ok) 
require.Equal(t, complex64(0), cRes) } func TestState_GetComplex128(t *testing.T) { t.Parallel() st := newState() c := complex(4, 5) st.Set("complex128", c) cRes, ok := st.GetComplex128("complex128") require.True(t, ok) require.Equal(t, c, cRes) st.Set("wrong", "not complex128") cRes, ok = st.GetComplex128("wrong") require.False(t, ok) require.Equal(t, complex128(0), cRes) cRes, ok = st.GetComplex128("missing") require.False(t, ok) require.Equal(t, complex128(0), cRes) } func TestState_MustGet(t *testing.T) { t.Parallel() st := newState() st.Set("exists", "value") val := st.MustGet("exists") require.Equal(t, "value", val) // must-get on missing key should panic require.Panics(t, func() { _ = st.MustGet("missing") }) } func TestState_Has(t *testing.T) { t.Parallel() st := newState() st.Set("key", "value") require.True(t, st.Has("key")) } func TestState_Delete(t *testing.T) { t.Parallel() st := newState() st.Set("key", "value") st.Delete("key") _, ok := st.Get("key") require.False(t, ok) } func TestState_Reset(t *testing.T) { t.Parallel() st := newState() st.Set("a", 1) st.Set("b", 2) st.Reset() require.Equal(t, 0, st.Len()) require.Empty(t, st.Keys()) } func TestState_Keys(t *testing.T) { t.Parallel() st := newState() keys := []string{"one", "two", "three"} for _, k := range keys { st.Set(k, k) } returnedKeys := st.Keys() require.ElementsMatch(t, keys, returnedKeys) } func TestState_Keys_SkipsNonStringKeys(t *testing.T) { t.Parallel() st := newState() st.Set("one", "one") st.Set("two", "two") st.dependencies.Store(42, "value") st.dependencies.Store(struct{}{}, "value") returnedKeys := st.Keys() require.ElementsMatch(t, []string{"one", "two"}, returnedKeys) } func TestState_Keys_SkipsNonStringKeys_WithMixedOrder(t *testing.T) { t.Parallel() st := newState() st.Set("before", "value") st.dependencies.Store(42, "value") st.Set("after", "value") returnedKeys := st.Keys() require.ElementsMatch(t, []string{"before", "after"}, returnedKeys) } func TestState_Len(t 
*testing.T) { t.Parallel() st := newState() require.Equal(t, 0, st.Len()) st.Set("a", "a") require.Equal(t, 1, st.Len()) st.Set("b", "b") require.Equal(t, 2, st.Len()) st.Delete("a") require.Equal(t, 1, st.Len()) } type testCase[T any] struct { value any expected T name string key string ok bool } func runGenericTest[T any](t *testing.T, getter func(*State, string) (T, bool), tests []testCase[T]) { t.Helper() st := newState() for _, tc := range tests { st.Set(tc.key, tc.value) got, ok := getter(st, tc.key) require.Equal(t, tc.ok, ok, tc.name) require.Equal(t, tc.expected, got, tc.name) } } func TestState_GetGeneric(t *testing.T) { t.Parallel() runGenericTest(t, GetState[int], []testCase[int]{ {name: "int correct conversion", key: "num", value: 42, expected: 42, ok: true}, {name: "int wrong conversion from string", key: "str", value: "abc", expected: 0, ok: false}, }) runGenericTest(t, GetState[string], []testCase[string]{ {name: "string correct conversion", key: "strVal", value: "hello", expected: "hello", ok: true}, {name: "string wrong conversion from int", key: "intVal", value: 100, expected: "", ok: false}, }) runGenericTest(t, GetState[bool], []testCase[bool]{ {name: "bool correct conversion", key: "flag", value: true, expected: true, ok: true}, {name: "bool wrong conversion from int", key: "intFlag", value: 1, expected: false, ok: false}, }) runGenericTest(t, GetState[float64], []testCase[float64]{ {name: "float64 correct conversion", key: "pi", value: 3.14, expected: 3.14, ok: true}, {name: "float64 wrong conversion from int", key: "intVal", value: 10, expected: 0.0, ok: false}, }) } func Test_MustGetStateGeneric(t *testing.T) { t.Parallel() st := newState() st.Set("flag", true) flag := MustGetState[bool](st, "flag") require.True(t, flag) // mismatched type should panic require.Panics(t, func() { _ = MustGetState[string](st, "flag") }) // missing key should also panic require.Panics(t, func() { _ = MustGetState[string](st, "missing") }) } func 
Test_GetStateWithDefault(t *testing.T) { t.Parallel() st := newState() st.Set("flag", true) flag := GetStateWithDefault(st, "flag", false) require.True(t, flag) // mismatched type should return the default value str := GetStateWithDefault(st, "flag", "default") require.Equal(t, "default", str) // missing key should return the default value flag = GetStateWithDefault(st, "missing", false) require.False(t, flag) } func TestState_Service(t *testing.T) { t.Parallel() srv1 := &mockService{name: "test1"} // service 2 is using a very subtle name to check it is not picked up srv2 := &mockService{name: "test1 "} t.Run("set/get/ok", func(t *testing.T) { t.Parallel() st := newState() st.setService(srv1) got, ok := st.Get(st.serviceKey(srv1.String())) require.True(t, ok) require.Equal(t, srv1, got) }) t.Run("set/get/ko", func(t *testing.T) { t.Parallel() st := newState() st.setService(srv1) koSrv := &mockService{name: "ko"} got, ok := st.Get(st.serviceKey(koSrv.String())) require.False(t, ok) require.Nil(t, got) }) t.Run("len", func(t *testing.T) { t.Parallel() t.Run("empty", func(t *testing.T) { t.Parallel() st := newState() require.Equal(t, 0, st.Len()) require.Empty(t, st.serviceKeys()) }) t.Run("with-services", func(t *testing.T) { t.Parallel() st := newState() st.setService(srv1) st.setService(srv2) require.Equal(t, 2, st.Len()) require.Equal(t, 2, st.ServicesLen()) }) t.Run("with-services/with-keys", func(t *testing.T) { t.Parallel() st := newState() st.setService(srv1) st.setService(srv2) st.Set("key1", "value1") st.Set("key2", "value2") servicesLen := st.ServicesLen() require.Equal(t, 4, st.Len()) require.Equal(t, 2, servicesLen) }) }) t.Run("keys", func(t *testing.T) { t.Run("empty", func(t *testing.T) { t.Parallel() st := newState() // adding more keys to check they are not included st.Set("key1", "value1") st.Set("key2", "value2") require.Empty(t, st.serviceKeys()) }) t.Run("with-services", func(t *testing.T) { t.Parallel() st := newState() st.setService(srv1) 
st.setService(srv2) keys := st.serviceKeys() require.Len(t, keys, 2) require.Contains(t, keys, st.serviceKey(srv1.String())) require.Contains(t, keys, st.serviceKey(srv2.String())) }) t.Run("with-services/with-keys", func(t *testing.T) { t.Parallel() st := newState() st.setService(srv1) st.setService(srv2) st.Set("key1", "value1") st.Set("key2", "value2") keys := st.serviceKeys() require.Len(t, keys, 2) require.Contains(t, keys, st.serviceKey(srv1.String())) require.Contains(t, keys, st.serviceKey(srv2.String())) require.NotContains(t, keys, "key1") require.NotContains(t, keys, "key2") }) t.Run("with-non-string-key", func(t *testing.T) { t.Parallel() st := newState() st.setService(srv1) st.setService(srv2) st.dependencies.Store(42, "value") st.dependencies.Store(struct{}{}, "value") keys := st.serviceKeys() require.Len(t, keys, 2) require.Contains(t, keys, st.serviceKey(srv1.String())) require.Contains(t, keys, st.serviceKey(srv2.String())) }) t.Run("with-non-service-keys", func(t *testing.T) { t.Parallel() st := newState() st.setService(srv1) st.setService(srv2) st.Set("other", "value") st.dependencies.Store(42, "value") keys := st.serviceKeys() require.Len(t, keys, 2) require.Contains(t, keys, st.serviceKey(srv1.String())) require.Contains(t, keys, st.serviceKey(srv2.String())) }) }) t.Run("delete", func(t *testing.T) { t.Parallel() t.Run("ok", func(t *testing.T) { t.Parallel() st := newState() st.setService(srv1) st.deleteService(srv1) _, ok := st.Get(st.serviceKey(srv1.String())) require.False(t, ok) }) t.Run("missing", func(t *testing.T) { t.Parallel() st := newState() st.setService(srv1) st.deleteService(srv2) _, ok := st.Get(st.serviceKey(srv1.String())) require.True(t, ok) _, ok = st.Get(st.serviceKey(srv2.String())) require.False(t, ok) }) }) } func TestState_GetService(t *testing.T) { t.Parallel() t.Run("ok", func(t *testing.T) { t.Parallel() srv1 := &mockService{name: "test1"} st := newState() st.setService(srv1) got, ok := GetService[*mockService](st, 
srv1.String()) require.True(t, ok) require.Equal(t, srv1, got) }) t.Run("ko", func(t *testing.T) { t.Parallel() srv1 := &mockService{name: "test1"} st := newState() got, ok := GetService[*mockService](st, srv1.String()) require.False(t, ok) require.Nil(t, got) }) } func TestState_MustGetService(t *testing.T) { t.Parallel() t.Run("ok", func(t *testing.T) { t.Parallel() srv1 := &mockService{name: "test1"} st := newState() st.setService(srv1) got := MustGetService[*mockService](st, srv1.String()) require.Equal(t, srv1, got) }) t.Run("panics", func(t *testing.T) { t.Parallel() srv1 := &mockService{name: "test1"} st := newState() require.Panics(t, func() { _ = MustGetService[*mockService](st, srv1.String()) }) }) } func BenchmarkState_Set(b *testing.B) { b.ReportAllocs() st := newState() i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i) st.Set(key, i) } } func BenchmarkState_Get(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // prepopulate the state for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, i) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.Get(key) } } func BenchmarkState_GetString(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // prepopulate the state for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, strconv.Itoa(i)) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetString(key) } } func BenchmarkState_GetInt(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // prepopulate the state for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, i) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetInt(key) } } func BenchmarkState_GetBool(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // prepopulate the state for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, i%2 == 0) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetBool(key) } } func BenchmarkState_GetFloat64(b *testing.B) { b.ReportAllocs() st := 
newState() n := 1000 // prepopulate the state for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, float64(i)) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetFloat64(key) } } func BenchmarkState_MustGet(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // prepopulate the state for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, i) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.MustGet(key) } } func BenchmarkState_GetStateGeneric(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // prepopulate the state for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, i) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) GetState[int](st, key) } } func BenchmarkState_MustGetStateGeneric(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // prepopulate the state for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, i) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) MustGetState[int](st, key) } } func BenchmarkState_GetStateWithDefault(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // prepopulate the state for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, i) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) GetStateWithDefault(st, key, 0) } } func BenchmarkState_Has(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // prepopulate the state for i := range n { st.Set("key"+strconv.Itoa(i), i) } i := 0 for b.Loop() { i++ st.Has("key" + strconv.Itoa(i%n)) } } func BenchmarkState_Delete(b *testing.B) { b.ReportAllocs() for b.Loop() { st := newState() st.Set("a", 1) st.Delete("a") } } func BenchmarkState_Reset(b *testing.B) { b.ReportAllocs() for b.Loop() { st := newState() // add a fixed number of keys before clearing for j := range 100 { st.Set("key"+strconv.Itoa(j), j) } st.Reset() } } func BenchmarkState_Keys(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 for i := range n { 
st.Set("key"+strconv.Itoa(i), i) } for b.Loop() { _ = st.Keys() } } func BenchmarkState_Len(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 for i := range n { st.Set("key"+strconv.Itoa(i), i) } for b.Loop() { _ = st.Len() } } func BenchmarkState_GetUint(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with uint values. for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, uint(i)) //nolint:gosec // G115 - test values are small } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetUint(key) } } func BenchmarkState_GetInt8(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with int8 values (using modulo to stay in range). for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, int8(i%128)) //nolint:gosec // G115 - modulo keeps value in range } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetInt8(key) } } func BenchmarkState_GetInt16(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with int16 values. for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, int16(i)) //nolint:gosec // G115 - test values are small } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetInt16(key) } } func BenchmarkState_GetInt32(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with int32 values. for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, int32(i)) //nolint:gosec // G115 - test values are small } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetInt32(key) } } func BenchmarkState_GetInt64(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with int64 values. 
for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, int64(i)) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetInt64(key) } } func BenchmarkState_GetUint8(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with uint8 values. for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, uint8(i%256)) //nolint:gosec // G115 - modulo keeps value in range } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetUint8(key) } } func BenchmarkState_GetUint16(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with uint16 values. for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, uint16(i)) //nolint:gosec // G115 - test values are small } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetUint16(key) } } func BenchmarkState_GetUint32(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with uint32 values. for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, uint32(i)) //nolint:gosec // G115 - test values are small } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetUint32(key) } } func BenchmarkState_GetUint64(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with uint64 values. for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, uint64(i)) //nolint:gosec // G115 - test values are small } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetUint64(key) } } func BenchmarkState_GetUintptr(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with uintptr values. for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, uintptr(i)) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetUintptr(key) } } func BenchmarkState_GetFloat32(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with float32 values. 
for i := range n { key := "key" + strconv.Itoa(i) st.Set(key, float32(i)) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetFloat32(key) } } func BenchmarkState_GetComplex64(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with complex64 values. for i := range n { key := "key" + strconv.Itoa(i) // Create a complex64 value with both real and imaginary parts. st.Set(key, complex(float32(i), float32(i))) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetComplex64(key) } } func BenchmarkState_GetComplex128(b *testing.B) { b.ReportAllocs() st := newState() n := 1000 // Prepopulate the state with complex128 values. for i := range n { key := "key" + strconv.Itoa(i) // Create a complex128 value with both real and imaginary parts. st.Set(key, complex(float64(i), float64(i))) } i := 0 for b.Loop() { i++ key := "key" + strconv.Itoa(i%n) st.GetComplex128(key) } } func BenchmarkState_GetService(b *testing.B) { b.ReportAllocs() st := newState() srv := &mockService{name: "benchService"} st.setService(srv) for b.Loop() { _, _ = GetService[*mockService](st, srv.String()) } } func BenchmarkState_MustGetService(b *testing.B) { b.ReportAllocs() st := newState() srv := &mockService{name: "benchService"} st.setService(srv) for b.Loop() { _ = MustGetService[*mockService](st, srv.String()) } } ================================================ FILE: storage_interface.go ================================================ package fiber import ( "context" "time" ) // Storage interface for communicating with different database/key-value // providers type Storage interface { // GetWithContext gets the value for the given key with a context. // `nil, nil` is returned when the key does not exist GetWithContext(ctx context.Context, key string) ([]byte, error) // Get gets the value for the given key. 
// `nil, nil` is returned when the key does not exist Get(key string) ([]byte, error) // SetWithContext stores the given value for the given key // with an expiration value, 0 means no expiration. // Empty key or value will be ignored without an error. SetWithContext(ctx context.Context, key string, val []byte, exp time.Duration) error // Set stores the given value for the given key along // with an expiration value, 0 means no expiration. // Empty key or value will be ignored without an error. Set(key string, val []byte, exp time.Duration) error // DeleteWithContext deletes the value for the given key with a context. // It returns no error if the storage does not contain the key, DeleteWithContext(ctx context.Context, key string) error // Delete deletes the value for the given key. // It returns no error if the storage does not contain the key, Delete(key string) error // ResetWithContext resets the storage and deletes all keys with a context. ResetWithContext(ctx context.Context) error // Reset resets the storage and delete all keys. Reset() error // Close closes the storage and will stop any running garbage // collectors and open connections. Close() error }