Repository: luraproject/lura Branch: master Commit: 6d79b4ef723b Files: 158 Total size: 684.6 KB Directory structure: gitextract_fat13st9/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── feature_request.md │ │ └── report-vulnerability.md │ ├── label-commenter-config.yml │ └── workflows/ │ ├── go.yml │ ├── labels.yml │ └── lock-threads.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── async/ │ ├── asyncagent.go │ └── asyncagent_test.go ├── backoff/ │ ├── backoff.go │ └── backoff_test.go ├── config/ │ ├── config.go │ ├── config_test.go │ ├── parser.go │ ├── parser_test.go │ ├── uri.go │ └── uri_test.go ├── core/ │ └── version.go ├── docs/ │ ├── BENCHMARKS.md │ ├── CONFIG.md │ ├── OVERVIEW.md │ └── README.md ├── encoding/ │ ├── encoding.go │ ├── encoding_test.go │ ├── json_benchmark_test.go │ ├── json_test.go │ └── register.go ├── go.mod ├── go.sum ├── logging/ │ ├── log.go │ └── log_test.go ├── plugin/ │ ├── plugin.go │ └── plugin_test.go ├── proxy/ │ ├── balancing.go │ ├── balancing_benchmark_test.go │ ├── balancing_test.go │ ├── concurrent.go │ ├── concurrent_benchmark_test.go │ ├── concurrent_test.go │ ├── factory.go │ ├── factory_test.go │ ├── formatter.go │ ├── formatter_benchmark_test.go │ ├── formatter_test.go │ ├── graphql.go │ ├── graphql_test.go │ ├── headers_filter.go │ ├── headers_filter_test.go │ ├── http.go │ ├── http_benchmark_test.go │ ├── http_response.go │ ├── http_response_test.go │ ├── http_test.go │ ├── logging.go │ ├── logging_test.go │ ├── merging.go │ ├── merging_benchmark_test.go │ ├── merging_test.go │ ├── plugin/ │ │ ├── modifier.go │ │ ├── modifier_test.go │ │ └── tests/ │ │ ├── error/ │ │ │ └── main.go │ │ └── logger/ │ │ └── main.go │ ├── plugin.go │ ├── plugin_test.go │ ├── proxy.go │ ├── proxy_test.go │ ├── query_strings_filter.go │ ├── query_strings_filter_test.go │ ├── register.go │ ├── register_test.go │ ├── request.go │ ├── 
request_benchmark_test.go │ ├── request_test.go │ ├── shadow.go │ ├── shadow_test.go │ ├── stack_benchmark_test.go │ ├── stack_test.go │ ├── static.go │ └── static_test.go ├── register/ │ ├── register.go │ └── register_test.go ├── router/ │ ├── chi/ │ │ ├── endpoint.go │ │ ├── endpoint_benchmark_test.go │ │ ├── endpoint_test.go │ │ ├── router.go │ │ └── router_test.go │ ├── gin/ │ │ ├── debug.go │ │ ├── debug_test.go │ │ ├── echo.go │ │ ├── echo_test.go │ │ ├── endpoint.go │ │ ├── endpoint_benchmark_test.go │ │ ├── endpoint_test.go │ │ ├── engine.go │ │ ├── engine_test.go │ │ ├── render.go │ │ ├── render_test.go │ │ ├── router.go │ │ ├── router_test.go │ │ └── safecast.go │ ├── gorilla/ │ │ ├── router.go │ │ └── router_test.go │ ├── helper.go │ ├── helper_test.go │ ├── httptreemux/ │ │ ├── router.go │ │ └── router_test.go │ ├── mux/ │ │ ├── debug.go │ │ ├── debug_test.go │ │ ├── echo.go │ │ ├── echo_test.go │ │ ├── endpoint.go │ │ ├── endpoint_benchmark_test.go │ │ ├── endpoint_test.go │ │ ├── engine.go │ │ ├── engine_test.go │ │ ├── render.go │ │ ├── render_test.go │ │ ├── router.go │ │ └── router_test.go │ ├── negroni/ │ │ ├── router.go │ │ └── router_test.go │ └── router.go ├── sd/ │ ├── dnssrv/ │ │ ├── subscriber.go │ │ └── subscriber_test.go │ ├── loadbalancing.go │ ├── loadbalancing_benchmark_test.go │ ├── loadbalancing_test.go │ ├── register.go │ ├── register_test.go │ └── subscriber.go ├── test/ │ ├── doc.go │ └── integration_test.go └── transport/ └── http/ ├── client/ │ ├── executor.go │ ├── executor_test.go │ ├── graphql/ │ │ ├── graphql.go │ │ └── graphql_test.go │ ├── plugin/ │ │ ├── doc.go │ │ ├── executor.go │ │ ├── plugin.go │ │ ├── plugin_test.go │ │ └── tests/ │ │ └── main.go │ ├── status.go │ └── status_test.go └── server/ ├── plugin/ │ ├── doc.go │ ├── plugin.go │ ├── plugin_test.go │ ├── server.go │ └── tests/ │ └── main.go ├── server.go ├── server_test.go └── tls_test.go ================================================ FILE CONTENTS 
================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: '' assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: 1. Configuration used 2. Steps to run the software **Expected behavior** A clear and concise description of what you expected to happen. **Logs** If applicable, add any logs and debugging information. **Additional context** Add any other context about the problem here. ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: '' assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. ================================================ FILE: .github/ISSUE_TEMPLATE/report-vulnerability.md ================================================ --- name: Report vulnerability about: Report a vulnerability title: '' labels: '' assignees: '' --- For **private vulnerabilities**, write to support@devops.faith instead. **Vulnerability description** Explain the vulnerability in detail and, when possible, how to reproduce it. **Reference** E.g.: https://nvd.nist.gov/vuln/detail/CVE-2019-6486 **Additional information** Impact, Known Affected Software Configurations, etc.
================================================ FILE: .github/label-commenter-config.yml ================================================ comment: footer: | --- > This is an automated comment. Responding to the bot or mentioning it won't have any effect labels: - name: invalid labeled: issue: body: | Hi, thanks for bringing this issue to our attention. Unfortunately, this issue has been marked invalid and will be closed. The most common reasons for marking an issue or pull request as invalid are: - It's vague or not clearly actionable - It contains insufficient details to reproduce - It's a request for a plugin review or custom code review - It does not use the issue template - It's unrelated to the project (e.g., related to one of its libraries) - It does not follow the technical philosophy of the project - Violates our [Code of Conduct](https://lfprojects.org/policies/code-of-conduct/) - It's about KrakenD functionalities (KrakenD uses and maintains Lura, but is not part of the Linux Foundation) You can still make an edit or leave additional comments that lead to reopening this issue. action: close pr: body: | Hi @{{ pull_request.user.login }}, thanks for having spent the time to code and send an improvement to Lura. Unfortunately, this pull request has been marked as invalid and will be closed without merging. The most common reasons for marking an issue or pull request as invalid are: - Contains insufficient details, it's unclear for the reviewer, or it's impossible to move forward without a lot of interaction - It's unrelated to the project (e.g., related to one of its libraries) - It does not follow the philosophy of the project - Violates our [Code of Conduct](https://lfprojects.org/policies/code-of-conduct/) You can still make an edit or leave additional comments that lead to reopening this pull request. action: close - name: wontfix labeled: issue: body: | Hi, thank you for bringing this issue to our attention.
Many factors influence our product roadmaps and determine the features, fixes, and suggestions we implement. When deciding what to prioritize and work on, we combine your feedback and suggestions with insights from our development team, product analytics, research findings, and more. This information, combined with our product vision, determines what we implement and its priority order. Unfortunately, we don't foresee this issue progressing any further in the short-medium term, and we are closing it. While this issue is now closed, we continue monitoring requests for our future roadmap, **including this one**. If you have additional information you would like to provide, please share. action: close pr: body: | Hi @{{ pull_request.user.login }}, thanks for having spent the time to code and send an improvement to Lura. When deciding what to accept and include in our product, we are cautious about what we add, and the time our team needs to spend to have it done exemplary, considering all edge cases. As a result, we rarely add features, make changes that a tiny number of users need, or are out-of-scope of the project. For example, we might choose safety over having a specific additional feature that adds complexity we don't see crystal clear. Sometimes "less is more" because we can focus better on crucial functionality. Although it's never nice to reject someone's work, after evaluating your code, we think it's better not to merge it or continue putting effort into it on both sides. If this PR is to solve what you considered a bug, our understanding of the functionality does not need to match your thinking. So while this pull request is now closed, **this is not a definitive decision**. You are free to provide additional information that would help us change our minds. Lura is indirectly used in millions of servers every year, and the slightest change has an impact. 
We are doing it for all community users' benefit and to keep the code's simplicity, philosophy, and maintainability for the long run. action: close - name: duplicate labeled: issue: body: An issue like this already exists, please follow it in the other thread action: close - name: good first issue labeled: issue: body: This issue is easy for contributing. Everyone can work on this. - name: spam labeled: issue: body: | This issue has been **LOCKED** because of spam! Please do not spam messages and/or issues on the issue tracker. You may get blocked from this repository for doing so. action: close locking: lock lock_reason: spam pr: body: | This pull-request has been **LOCKED** because of spam! Please do not spam messages and/or pull-requests on this project. You may get blocked from this repository for doing so. action: close locking: lock lock_reason: spam ================================================ FILE: .github/workflows/go.yml ================================================ name: Go on: push: branches: [ master ] pull_request: branches: [ master ] jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Go uses: actions/setup-go@v3 with: go-version: "1.25" - name: Build run: go build -v ./... - name: Test run: go test -cover -race ./... 
- name: Integration Test run: go test -tags integration ./test ================================================ FILE: .github/workflows/labels.yml ================================================ name: Label commenter on: issues: types: [labeled, unlabeled] pull_request_target: types: [labeled, unlabeled] jobs: stale: uses: luraproject/.github/.github/workflows/label-commenter.yml@main ================================================ FILE: .github/workflows/lock-threads.yml ================================================ name: 'Lock Threads' on: schedule: - cron: '0 0 * * *' workflow_dispatch: permissions: issues: write pull-requests: write concurrency: group: lock jobs: action: runs-on: ubuntu-latest steps: - uses: dessant/lock-threads@v3 with: pr-inactive-days: '90' issue-inactive-days: '90' add-issue-labels: 'locked' issue-comment: > This issue was marked as resolved a long time ago and now has been automatically locked as there has not been any recent activity after it. You can still open a new issue and reference this link. pr-comment: > This pull request was marked as resolved a long time ago and now has been automatically locked as there has not been any recent activity after it. You can still open a new issue and reference this link. 
================================================ FILE: .gitignore ================================================ vendor server.rsa.crt server.rsa.key *.pem *.json *.toml *.dot *.out *.so bench_res .cover .idea *DS_Store ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. ## Our Standards Examples of behavior that contributes to creating a positive environment include: * Using welcoming and inclusive language * Being respectful of differing viewpoints and experiences * Gracefully accepting constructive criticism * Focusing on what is best for the community * Showing empathy towards other community members Examples of unacceptable behavior by participants include: * The use of sexualized language or imagery and unwelcome sexual attention or advances * Trolling, insulting/derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or electronic address, without explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Our Responsibilities Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. ## Scope This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at hello@krakend.io. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 
## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] [homepage]: http://contributor-covenant.org [version]: http://contributor-covenant.org/version/1/4/ ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing Thank you for your interest in contributing to Lura! There are several ways you can contribute and make this project more awesome; please see below: ## Reporting an Issue If you believe you have found an issue with the code, please do not hesitate to file an issue in [GitHub](https://github.com/luraproject/lura/issues). When filing the issue, please describe the problem with the maximum level of detail and the steps to reproduce the problem, including information about your environment. You can also open an issue to request help or ask a question; it's also a good way of contributing, since other users might be in a similar position. Please note we have a code of conduct; please follow it in all your interactions with the project. ## Code Contributions When contributing to this repository, it is generally a good idea to discuss the change with the owners before investing a lot of time coding. The process could be: 1. Open an issue explaining the improvement or fix you want to add 2. [Fork the project](https://github.com/luraproject/lura/fork) 3. Code it in your fork 4. Submit a [pull request](https://help.github.com/articles/creating-a-pull-request) referencing the issue Your work will then be reviewed as soon as possible (suggestions about some changes, improvements or alternatives may be given). **Don't forget to add tests** and make sure that they all pass!
# Help with Git Once the repository is forked, you should track the upstream (original) one using the following command: git remote add upstream https://github.com/luraproject/lura.git Then you should create your own branch: git checkout -b <your-username>/<issue-number>-<short-description> Once your changes are done (`git commit -am '<commit message>'`), get the upstream changes: git checkout master git pull --rebase origin master git pull --rebase upstream master git checkout <your-branch> git rebase master Finally, publish your changes: git push -f origin <your-branch> You should now be ready to make a pull request. ================================================ FILE: LICENSE ================================================ Copyright © 2021 Lura Project a Series of LF Projects, LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: Makefile ================================================ .PHONY: all test build benchmark OS := $(shell uname | tr '[:upper:]' '[:lower:]') GIT_COMMIT := $(shell git rev-parse --short=7 HEAD) all: test build generate: go generate ./...
go build -buildmode=plugin -o ./transport/http/client/plugin/tests/lura-client-example.so ./transport/http/client/plugin/tests go build -buildmode=plugin -o ./transport/http/server/plugin/tests/lura-server-example.so ./transport/http/server/plugin/tests go build -buildmode=plugin -o ./proxy/plugin/tests/lura-request-modifier-example.so ./proxy/plugin/tests/logger go build -buildmode=plugin -o ./proxy/plugin/tests/lura-error-example.so ./proxy/plugin/tests/error test: generate go test -cover -race ./... go test -tags integration ./test/... go test -tags integration ./transport/... go test -tags integration ./proxy/... benchmark: @mkdir -p bench_res @touch bench_res/${GIT_COMMIT}.out @go test -run none -bench . -benchmem ./... >> bench_res/${GIT_COMMIT}.out build: go build ./... ================================================ FILE: README.md ================================================ # The Lura Project framework [![Go Report Card](https://goreportcard.com/badge/github.com/luraproject/lura/v2)](https://goreportcard.com/report/github.com/luraproject/lura/v2) [![GoDoc](https://godoc.org/github.com/luraproject/lura/v2?status.svg)](https://godoc.org/github.com/luraproject/lura/v2) ![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/3151/badge) [![Slack Widget](https://img.shields.io/badge/join-us%20on%20slack-gray.svg?longCache=true&logo=slack&colorB=red)](https://gophers.slack.com/messages/lura) [![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2Fluraproject%2Flura.svg?type=shield&issueType=license)](https://app.fossa.com/projects/git%2Bgithub.com%2Fluraproject%2Flura%2Fv2?ref=badge_shield&issueType=license) An open framework to assemble ultra performance API Gateways with middlewares; formerly known as _KrakenD framework_, and core service of the [KrakenD API Gateway](http://www.krakend.io). 
## Motivation Consumers of REST API content (especially in microservices) often query backend services that weren't coded for the UI implementation. This is of course a good practice, but it forces UI consumers into complex implementations that must deal with the number and size of their microservices' responses. Lura is an **API Gateway** builder and proxy generator that sits between the client and all the source servers, adding a new layer that removes all the complexity from the clients, providing them only the information that the UI needs. Lura acts as an **aggregator** of many sources into single endpoints and allows you to group, wrap, transform and shrink responses. Additionally, it supports a myriad of middlewares and plugins that allow you to extend the functionality, such as adding OAuth authorization or security layers. Lura not only supports HTTP(S): because it is a set of generic libraries, you can build all types of API Gateways and proxies, including, for instance, an RPC gateway. ### Practical Example A mobile developer needs to construct a single front page that requires data from 4 different calls to their backend services, e.g.: 1) api.store.server/products 2) api.store.server/marketing-promos 3) api.users.server/users/{id_user} 4) api.users.server/shopping-cart/{id_user} The screen is very simple, and the mobile client _only_ needs to retrieve data from 4 different sources, wait for the round trip and then hand-pick only a few fields from the response. What if the mobile app could call a single endpoint? 1) lura.server/frontpage/{id_user} That's something Lura can do for you. And this is what it would look like: ![Gateway](https://luraproject.org/images/docs/lura-gateway.png) Lura would merge all the data and return only the fields you need (the difference in size in the graph). Visit the [Lura Project website](https://luraproject.org) for more information. ## What's in this repository?
The source code for the [Lura project](https://luraproject.org) framework. It is designed to work with your own middleware and extend the functionality by using small, independent, reusable components following the Unix philosophy. Use this repository if you want to **build your API Gateway from source** or if you want to **reuse the components in another application**. If you need a fully functional API Gateway, you can [download the KrakenD binary for your architecture](http://www.krakend.io/download) or [build it yourself](https://github.com/krakend/krakend-ce). ## Library Usage The Lura project is presented as a **Go library** that you can include in your own Go application to build a powerful proxy or API gateway. For a complete example, check the [KrakenD CE repository](https://github.com/krakend/krakend-ce). Of course, you will need [Go installed](https://golang.org/doc/install) on your system to compile the code. A ready-to-use example: ```go package main import ( "flag" "log" "os" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/router/gin" ) func main() { port := flag.Int("p", 0, "Port of the service") logLevel := flag.String("l", "ERROR", "Logging level") debug := flag.Bool("d", false, "Enable the debug") configFile := flag.String("c", "/etc/lura/configuration.json", "Path to the configuration filename") flag.Parse() parser := config.NewParser() serviceConfig, err := parser.Parse(*configFile) if err != nil { log.Fatal("ERROR:", err.Error()) } serviceConfig.Debug = serviceConfig.Debug || *debug if *port != 0 { serviceConfig.Port = *port } logger, _ := logging.NewLogger(*logLevel, os.Stdout, "[LURA]") routerFactory := gin.DefaultFactory(proxy.DefaultFactory(logger), logger) routerFactory.New().Run(serviceConfig) } ``` Visit the [framework overview](/docs/OVERVIEW.md) for more details about the components of the Lura project.
## Configuration file [Lura config file](/docs/CONFIG.md) ## Benchmarks Check out the [benchmark results](/docs/BENCHMARKS.md) of several Lura components. ## Contributing We are always happy to receive contributions. If you have questions, suggestions, or bugs, please open an issue. If you want to submit code, create the issue and send us a pull request for review. Read [CONTRIBUTING.md](/CONTRIBUTING.md) for more information. ## Want more? - Follow us on Twitter: [@luraproject](https://twitter.com/luraproject) - Visit our [Slack channel](https://gophers.slack.com/messages/lura) - **Read the [documentation](/docs/OVERVIEW.md)** Enjoy Lura! ## License [![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2Fluraproject%2Flura.svg?type=large)](https://app.fossa.com/projects/git%2Bgithub.com%2Fluraproject%2Flura?ref=badge_large) ================================================ FILE: SECURITY.md ================================================ # Security Policy Lura only fixes the latest version of the software, and does not patch prior versions. ## Reporting a Vulnerability Please email security@krakend.io with your discovery. As soon as we read and understand your finding, we will provide an answer with next steps and possible timelines. We want to thank you in advance for the time you have spent following up on this issue, as it helps all open source users. We develop our software in the open with the help of a global community of developers and contributors with whom we share a common understanding and trust in the free exchange of knowledge. The Lura Project DOES NOT provide cash awards for discovered vulnerabilities at this time.
Thank you ================================================ FILE: async/asyncagent.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package async starts the configured async agents: for every agent it builds a proxy pipe and hands it, together with the rest of the Options, to an ordered chain of agent factories. */ package async import ( "context" "errors" "fmt" "math" "github.com/luraproject/lura/v2/backoff" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "golang.org/x/sync/errgroup" ) // Options contains the configuration to pass to the async agent factory type Options struct { // Agent keeps the configuration for the async agent Agent *config.AsyncAgent // Endpoint encapsulates the configuration for the associated pipe Endpoint *config.EndpointConfig // Proxy is the pipe associated with the async agent Proxy proxy.Proxy // AgentPing is the channel for the agent to send ping messages AgentPing chan<- string // G is the error group responsible for managing the agents and the router itself G *errgroup.Group // ShouldContinue is a function signaling when to stop the connection retries ShouldContinue func(int) bool // BackoffF is a function encapsulating the backoff strategy BackoffF backoff.TimeToWaitBeforeRetry // Logger is the logger received by AgentStarter.Start Logger logging.Logger } // Factory is a function able to start an async agent. Returning true tells AgentStarter.Start not to try the remaining factories for that agent type Factory func(context.Context, Options) bool // AgentStarter groups a set of factories to be used. Start tries them in slice order type AgentStarter []Factory // Start runs the factories for each async agent configuration and returns the Wait method of the errgroup shared (as Options.G) by all the agents. For every agent it: names it "AsyncAgent-NN" (NN being its zero-padded index) when the name is empty; copies the consumer timeout into all its backends; builds the proxy pipe with pf, logging the error and skipping the agent when that fails; treats a non-positive Connection.MaxRetries as unlimited (math.MaxInt); and calls the factories in order until one of them returns true. The given agent configurations are modified in place. When the starter has no factories, the returned function just returns ErrNoAgents func (a AgentStarter) Start( ctx context.Context, agents []*config.AsyncAgent, logger logging.Logger, agentPing chan<- string, pf proxy.Factory, ) func() error { if len(a) == 0 { return func() error { return ErrNoAgents } } // the factories get this derived ctx, which errgroup cancels once a function started on g returns an error or g.Wait returns g, ctx := errgroup.WithContext(ctx) for i, agent := range agents { // per-iteration copies: agent is captured by the ShouldContinue closure below (required with pre-Go 1.22 loop semantics) i, agent := i, agent if agent.Name == "" { agent.Name = fmt.Sprintf("AsyncAgent-%02d", i) } logger.Debug(fmt.Sprintf("[SERVICE: AsyncAgent][%s] Starting the async agent", agent.Name)) // every backend of the agent gets the consumer timeout for i := range agent.Backend { agent.Backend[i].Timeout = agent.Consumer.Timeout
} endpoint := &config.EndpointConfig{ Endpoint: agent.Name, Timeout: agent.Consumer.Timeout, Backend: agent.Backend, ExtraConfig: agent.ExtraConfig, } p, err := pf.New(endpoint) if err != nil { logger.Error(fmt.Sprintf("[SERVICE: AsyncAgent][%s] building the proxy pipe:", agent.Name), err) continue } // a non-positive MaxRetries means retrying without limit if agent.Connection.MaxRetries <= 0 { agent.Connection.MaxRetries = math.MaxInt } opts := Options{ Agent: agent, Endpoint: endpoint, Proxy: p, AgentPing: agentPing, G: g, ShouldContinue: func(i int) bool { return i <= agent.Connection.MaxRetries }, BackoffF: backoff.GetByName(agent.Connection.BackoffStrategy), Logger: logger, } // stop at the first factory that returns true for _, f := range a { if f(ctx, opts) { break } } } return g.Wait } // ErrNoAgents is the error returned by the function built by AgentStarter.Start when the starter has no factories var ErrNoAgents = errors.New("no agent factories defined") ================================================ FILE: async/asyncagent_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package async import ( "context" "testing" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" ) // TestAgentStarter_Start_last checks that a factory returning false lets the next factory run func TestAgentStarter_Start_last(t *testing.T) { var firstAgentCalled, secondAgentCalled bool firstAgent := func(_ context.Context, opts Options) bool { // TODO: check opts firstAgentCalled = true return false } secondAgent := func(_ context.Context, opts Options) bool { // TODO: check opts secondAgentCalled = true return true } ctx, cancel := context.WithCancel(context.Background()) defer cancel() ch := make(chan string) as := AgentStarter([]Factory{firstAgent, secondAgent}) agents := []*config.AsyncAgent{ {}, } wait := as.Start(ctx, agents, logging.NoOp, (chan<- string)(ch), noopProxyFactory) if err := wait(); err != nil { t.Error(err) } if !firstAgentCalled { t.Error("first agent not called") } if !secondAgentCalled { t.Error("second agent not called") } } // TestAgentStarter_Start_first checks that a factory returning true prevents the following factories from running func TestAgentStarter_Start_first(t *testing.T) { var firstAgentCalled, secondAgentCalled bool firstAgent := func(_ context.Context, opts
Options) bool { firstAgentCalled = true return true } secondAgent := func(_ context.Context, opts Options) bool { secondAgentCalled = true return false } ctx, cancel := context.WithCancel(context.Background()) defer cancel() ch := make(chan string) as := AgentStarter([]Factory{firstAgent, secondAgent}) agents := []*config.AsyncAgent{ {}, } wait := as.Start(ctx, agents, logging.NoOp, (chan<- string)(ch), noopProxyFactory) if err := wait(); err != nil { t.Error(err) } if !firstAgentCalled { t.Error("first agent not called") } if secondAgentCalled { t.Error("second agent called") } } // noopProxyFactory returns proxy.NoopProxy and a nil error for every endpoint, so these tests never build a real pipe var noopProxyFactory = proxy.FactoryFunc(func(*config.EndpointConfig) (proxy.Proxy, error) { return proxy.NoopProxy, nil }) ================================================ FILE: backoff/backoff.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package backoff contains some basic implementations and a selector by strategy name */ package backoff import ( "math/rand" "strings" "time" ) // GetByName returns the TimeToWaitBeforeRetry function implementing the strategy. The name is matched case-insensitively ("linear", "linear-jitter", "exponential", "exponential-jitter") and any other value falls back to DefaultBackoff func GetByName(strategy string) TimeToWaitBeforeRetry { switch strings.ToLower(strategy) { case "linear": return LinearBackoff case "linear-jitter": return LinearJitterBackoff case "exponential": return ExponentialBackoff case "exponential-jitter": return ExponentialJitterBackoff } return DefaultBackoff } // TimeToWaitBeforeRetry returns the duration to wait before retrying for the // given (n-th) time type TimeToWaitBeforeRetry func(int) time.Duration // DefaultBackoffDuration is the duration returned by the DefaultBackoff var DefaultBackoffDuration = time.Second // DefaultBackoff always returns DefaultBackoffDuration, ignoring its argument func DefaultBackoff(_ int) time.Duration { return DefaultBackoffDuration } // NOTE(review): this extract looks truncated here: the ExponentialBackoff body breaks off after "1<" and the text jumps into what appears to be the middle of config/config.go // ExponentialBackoff returns ever increasing backoffs by a power of 2 func ExponentialBackoff(i int) time.Duration { return time.Duration(1< 1 { return errInvalidNoOpEncoding } e.ExtraConfig.sanitize() for
j, b := range e.Backend { // we "tell" the backend which is his parent endpoint b.ParentEndpoint = e.Endpoint b.ParentEndpointMethod = e.Method if err := s.initBackendDefaults(i, j); err != nil { return err } if err := s.initBackendURLMappings(i, j, inputSet); err != nil { return err } b.ExtraConfig.sanitize() } } return nil } func (s *ServiceConfig) paramExtractionPattern() *regexp.Regexp { if s.DisableStrictREST { return simpleURLKeysPattern } return endpointURLKeysPattern } func (*ServiceConfig) extractPlaceHoldersFromURLTemplate(subject string, pattern *regexp.Regexp) []string { matches := pattern.FindAllStringSubmatch(subject, -1) keys := make([]string, len(matches)) for k, v := range matches { keys[k] = v[1] } return keys } func (s *ServiceConfig) initEndpointDefaults(e int) { endpoint := s.Endpoints[e] if endpoint.Method == "" { endpoint.Method = "GET" } if s.CacheTTL != 0 && endpoint.CacheTTL == 0 { endpoint.CacheTTL = s.CacheTTL } if s.Timeout != 0 && endpoint.Timeout == 0 { endpoint.Timeout = s.Timeout } if endpoint.ConcurrentCalls == 0 { endpoint.ConcurrentCalls = 1 } if endpoint.OutputEncoding == "" { if s.OutputEncoding != "" { endpoint.OutputEncoding = s.OutputEncoding } else { endpoint.OutputEncoding = encoding.JSON } } } func (s *ServiceConfig) initAsyncAgentDefaults(e int) { agent := s.AsyncAgents[e] if s.Timeout != 0 && agent.Consumer.Timeout == 0 { agent.Consumer.Timeout = s.Timeout } if agent.Consumer.Workers < 1 { agent.Consumer.Workers = 1 } if agent.Connection.HealthInterval < time.Second { agent.Connection.HealthInterval = time.Second } } func (s *ServiceConfig) initBackendDefaults(e, b int) error { endpoint := s.Endpoints[e] backend := endpoint.Backend[b] if len(backend.Host) == 0 { backend.Host = s.Host } else if !backend.HostSanitizationDisabled { var err error backend.Host, err = s.uriParser.SafeCleanHosts(backend.Host) if err != nil { return err } } if backend.Method == "" { backend.Method = endpoint.Method } if endpoint.OutputEncoding 
== encoding.NOOP { backend.Encoding = encoding.NOOP } backend.Timeout = endpoint.Timeout backend.ConcurrentCalls = endpoint.ConcurrentCalls backend.Decoder = encoding.GetRegister().Get(strings.ToLower(backend.Encoding))(backend.IsCollection) for i := range backend.HeadersToPass { backend.HeadersToPass[i] = textproto.CanonicalMIMEHeaderKey(backend.HeadersToPass[i]) } if backend.SDScheme == "" { backend.SDScheme = "http" } return nil } func (s *ServiceConfig) initBackendURLMappings(e, b int, inputParams map[string]interface{}) error { backend := s.Endpoints[e].Backend[b] backend.URLPattern = s.uriParser.CleanPath(backend.URLPattern) outputParams, outputSetSize := uniqueOutput(s.extractPlaceHoldersFromURLTemplate(backend.URLPattern, simpleURLKeysPattern)) ip := fromSetToSortedSlice(inputParams) if outputSetSize > len(ip) { return &WrongNumberOfParamsError{ Endpoint: s.Endpoints[e].Endpoint, Method: s.Endpoints[e].Method, Backend: b, InputParams: ip, OutputParams: outputParams, } } title := cases.Title(language.Und) backend.URLKeys = []string{} for _, output := range outputParams { if !sequentialParamsPattern.MatchString(output) { if _, ok := inputParams[output]; !ok { return &UndefinedOutputParamError{ Param: output, Endpoint: s.Endpoints[e].Endpoint, Method: s.Endpoints[e].Method, Backend: b, InputParams: ip, OutputParams: outputParams, } } } key := title.String(output[:1]) + output[1:] backend.URLPattern = strings.ReplaceAll(backend.URLPattern, "{"+output+"}", "{{."+key+"}}") backend.URLKeys = append(backend.URLKeys, key) } return nil } func fromSetToSortedSlice(set map[string]interface{}) []string { res := make([]string, 0, len(set)) for element := range set { res = append(res, element) } sort.Strings(res) return res } func uniqueOutput(output []string) ([]string, int) { sort.Strings(output) j := 0 outputSetSize := 0 for i := 1; i < len(output); i++ { if output[j] == output[i] { continue } if !sequentialParamsPattern.MatchString(output[j]) { outputSetSize++ } j++ 
output[j] = output[i] } if j == len(output) { return output, outputSetSize } return output[:j+1], outputSetSize } func (e *EndpointConfig) validate() error { matched, err := regexp.MatchString(invalidPattern, e.Endpoint) if err != nil { return &EndpointMatchError{ Err: err, Path: e.Endpoint, Method: e.Method, } } if matched { return &EndpointPathError{Path: e.Endpoint, Method: e.Method} } if len(e.Backend) == 0 { return &NoBackendsError{Path: e.Endpoint, Method: e.Method} } return nil } // EndpointMatchError is the error returned by the configuration init process when the endpoint pattern // check fails type EndpointMatchError struct { Path string Method string Err error } // Error returns a string representation of the EndpointMatchError func (e *EndpointMatchError) Error() string { return fmt.Sprintf("ignoring the '%s %s' endpoint due to a parsing error: %s", e.Method, e.Path, e.Err.Error()) } // NoBackendsError is the error returned by the configuration init process when an endpoint // is connected to 0 backends type NoBackendsError struct { Path string Method string } // Error returns a string representation of the NoBackendsError func (n *NoBackendsError) Error() string { return "ignoring the '" + n.Method + " " + n.Path + "' endpoint, since it has 0 backends defined!" 
}

// UnsupportedVersionError is the error returned by the configuration init process when the configuration
// version is not supported.
// Have is the version found in the configuration and Want the expected one, as rendered by Error:
// "unsupported version: <Have> (want: <Want>)".
type UnsupportedVersionError struct {
	Have int
	Want int
}

// Error returns a string representation of the UnsupportedVersionError
func (u *UnsupportedVersionError) Error() string {
	return fmt.Sprintf("unsupported version: %d (want: %d)", u.Have, u.Want)
}

// EndpointPathError is the error returned by the configuration init process when an endpoint
// is using a forbidden path
type EndpointPathError struct {
	Path   string
	Method string
}

// Error returns a string representation of the EndpointPathError
func (e *EndpointPathError) Error() string {
	return "ignoring the '" + e.Method + " " + e.Path + "' endpoint, since it is invalid!!!"
}

// UndefinedOutputParamError is the error returned by the configuration init process when an output
// param is not present in the input param set.
// Param is the offending output param, Backend the index of the backend inside its endpoint, and
// InputParams/OutputParams the input and output param names that were compared.
type UndefinedOutputParamError struct {
	Endpoint     string
	Method       string
	Backend      int
	InputParams  []string
	OutputParams []string
	Param        string
}

// Error returns a string representation of the UndefinedOutputParamError
func (u *UndefinedOutputParamError) Error() string {
	return fmt.Sprintf(
		"undefined output param '%s'! endpoint: %s %s, backend: %d. input: %v, output: %v",
		u.Param,
		u.Method,
		u.Endpoint,
		u.Backend,
		u.InputParams,
		u.OutputParams,
	)
}

// WrongNumberOfParamsError is the error returned by the configuration init process when the number of output
// params is greater than the number of input params.
// Backend is the index of the backend inside its endpoint; InputParams/OutputParams are the compared names.
type WrongNumberOfParamsError struct {
	Endpoint     string
	Method       string
	Backend      int
	InputParams  []string
	OutputParams []string
}

// Error returns a string representation of the WrongNumberOfParamsError
func (w *WrongNumberOfParamsError) Error() string {
	return fmt.Sprintf(
		"input and output params do not match. endpoint: %s %s, backend: %d.
input: %v, output: %v", w.Method, w.Endpoint, w.Backend, w.InputParams, w.OutputParams, ) } func SetSequentialParamsPattern(pattern string) error { re, err := regexp.Compile(pattern) if err != nil { return err } sequentialParamsPattern = re return nil } // SetInvalidPattern sets the invalidPattern variable to the provided value. func SetInvalidPattern(pattern string) { invalidPattern = pattern } func validateAddress(address string) bool { ip := net.ParseIP(address) return ip != nil } ================================================ FILE: config/config_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package config import ( "errors" "fmt" "strings" "testing" "time" ) func TestConfig_rejectInvalidVersion(t *testing.T) { subject := ServiceConfig{} err := subject.Init() if err == nil || strings.Index(err.Error(), "unsupported version: 0 (want: 3)") != 0 { t.Error("Error expected. Got", err.Error()) } } func TestConfig_rejectInvalidEndpoints(t *testing.T) { samples := []string{ "/__debug", "/__debug/", "/__debug/foo", "/__debug/foo/bar", } for _, e := range samples { subject := ServiceConfig{Version: ConfigVersion, Endpoints: []*EndpointConfig{{Endpoint: e, Method: "GET"}}} err := subject.Init() if err == nil || err.Error() != fmt.Sprintf("ignoring the 'GET %s' endpoint, since it is invalid!!!", e) { t.Errorf("Unexpected error processing '%s': %v", e, err) } } } func TestConfig_initBackendURLMappings_ok(t *testing.T) { samples := []string{ "supu/{tupu}", "/supu/{tupu1}", "/supu.local/", "supu/{tupu_56}/{supu-5t6}?a={foo}&b={foo}", "supu/{tupu_56}{supu-5t6}?a={foo}&b={foo}", "supu/tupu{supu-5t6}?a={foo}&b={foo}", "{resp0_x}/{tupu1}/{tupu_56}{supu-5t6}?a={tupu}&b={foo}", "{resp0_x}/{tupu1}/{JWT.foo}", "{resp0_x}/{tupu1}/{JWT.http://example.com/foo_bar}", } expected := []string{ "/supu/{{.Tupu}}", "/supu/{{.Tupu1}}", "/supu.local/", "/supu/{{.Tupu_56}}/{{.Supu-5t6}}?a={{.Foo}}&b={{.Foo}}", 
"/supu/{{.Tupu_56}}{{.Supu-5t6}}?a={{.Foo}}&b={{.Foo}}", "/supu/tupu{{.Supu-5t6}}?a={{.Foo}}&b={{.Foo}}", "/{{.Resp0_x}}/{{.Tupu1}}/{{.Tupu_56}}{{.Supu-5t6}}?a={{.Tupu}}&b={{.Foo}}", "/{{.Resp0_x}}/{{.Tupu1}}/{{.JWT.foo}}", "/{{.Resp0_x}}/{{.Tupu1}}/{{.JWT.http://example.com/foo_bar}}", } backend := Backend{} endpoint := EndpointConfig{Backend: []*Backend{&backend}} subject := ServiceConfig{Endpoints: []*EndpointConfig{&endpoint}, uriParser: NewSafeURIParser()} inputSet := map[string]interface{}{ "tupu": nil, "tupu1": nil, "tupu_56": nil, "supu-5t6": nil, "foo": nil, } for i := range samples { backend.URLPattern = samples[i] if err := subject.initBackendURLMappings(0, 0, inputSet); err != nil { t.Error(err) } if backend.URLPattern != expected[i] { t.Errorf("want: %s, have: %s\n", expected[i], backend.URLPattern) } } } func TestConfig_initBackendURLMappings_tooManyOutput(t *testing.T) { backend := Backend{URLPattern: "supu/{tupu_56}/{supu-5t6}?a={foo}&b={foo}"} endpoint := EndpointConfig{ Method: "GET", Endpoint: "/some/{tupu}", Backend: []*Backend{&backend}, } subject := ServiceConfig{Endpoints: []*EndpointConfig{&endpoint}, uriParser: NewSafeURIParser()} inputSet := map[string]interface{}{ "tupu": nil, } expectedErrMsg := "input and output params do not match. endpoint: GET /some/{tupu}, backend: 0. input: [tupu], output: [foo supu-5t6 tupu_56]" err := subject.initBackendURLMappings(0, 0, inputSet) if err == nil || err.Error() != expectedErrMsg { t.Errorf("Unexpected error: %v", err) } } func TestConfig_initBackendURLMappings_undefinedOutput(t *testing.T) { backend := Backend{URLPattern: "supu/{tupu_56}/{supu-5t6}?a={foo}&b={foo}"} endpoint := EndpointConfig{Endpoint: "/", Method: "GET", Backend: []*Backend{&backend}} subject := ServiceConfig{Endpoints: []*EndpointConfig{&endpoint}, uriParser: NewSafeURIParser()} inputSet := map[string]interface{}{ "tupu": nil, "supu": nil, "foo": nil, } expectedErrMsg := "undefined output param 'supu-5t6'! 
endpoint: GET /, backend: 0. input: [foo supu tupu], output: [foo supu-5t6 tupu_56]" err := subject.initBackendURLMappings(0, 0, inputSet) if err == nil || err.Error() != expectedErrMsg { t.Errorf("error expected. have: %v", err) } } func TestConfig_init(t *testing.T) { supuBackend := Backend{ URLPattern: "/__debug/supu", } supuEndpoint := EndpointConfig{ Endpoint: "/supu", Method: "post", Timeout: 1500 * time.Millisecond, CacheTTL: 6 * time.Hour, Backend: []*Backend{&supuBackend}, OutputEncoding: "some_render", } githubBackend := Backend{ URLPattern: "/", Host: []string{"https://api.github.com"}, AllowList: []string{"authorizations_url", "code_search_url"}, } githubEndpoint := EndpointConfig{ Endpoint: "/github", Timeout: 1500 * time.Millisecond, CacheTTL: 6 * time.Hour, Backend: []*Backend{&githubBackend}, } userBackend := Backend{ URLPattern: "/users/{user}", Host: []string{"https://jsonplaceholder.typicode.com"}, Mapping: map[string]string{"email": "personal_email"}, } rssBackend := Backend{ URLPattern: "/users/{user}", Host: []string{"https://jsonplaceholder.typicode.com"}, Encoding: "rss", } postBackend := Backend{ URLPattern: "/posts/{user}", Host: []string{"https://jsonplaceholder.typicode.com"}, Group: "posts", Encoding: "xml", } userEndpoint := EndpointConfig{ Endpoint: "/users/{user}", Backend: []*Backend{&userBackend, &rssBackend, &postBackend}, } subject := ServiceConfig{ Version: ConfigVersion, Timeout: 5 * time.Second, CacheTTL: 30 * time.Minute, Host: []string{"http://127.0.0.1:8080"}, Endpoints: []*EndpointConfig{&supuEndpoint, &githubEndpoint, &userEndpoint}, } if err := subject.Init(); err != nil { t.Error("Error at the configuration init:", err.Error()) } if len(supuBackend.Host) != 1 || supuBackend.Host[0] != subject.Host[0] { t.Error("Default hosts not applied to the supu backend", supuBackend.Host) } for level, method := range map[string]string{ "userBackend": userBackend.Method, "postBackend": postBackend.Method, "userEndpoint": 
userEndpoint.Method, } { if method != "GET" { t.Errorf("Default method not applied at %s. Get: %s", level, method) } } if supuBackend.Method != "post" { t.Error("unexpected supuBackend") } if userBackend.Timeout != subject.Timeout { t.Error("default timeout not applied to the userBackend") } if userEndpoint.CacheTTL != subject.CacheTTL { t.Error("default CacheTTL not applied to the userEndpoint") } hash, err := subject.Hash() if err != nil { t.Error(err.Error()) } if hash != "/X+fgDf29kmtPpCUh9DeJBOwewpExy3IGEjeqA9zExA=" { t.Errorf("unexpected hash: %s", hash) } } func TestConfig_initKONoBackends(t *testing.T) { subject := ServiceConfig{ Version: ConfigVersion, Host: []string{"http://127.0.0.1:8080"}, Endpoints: []*EndpointConfig{ { Endpoint: "/supu", Method: "POST", Backend: []*Backend{}, }, }, } if err := subject.Init(); err == nil || err.Error() != "ignoring the 'POST /supu' endpoint, since it has 0 backends defined!" { t.Error("Unexpected error at the configuration init!", err) } } func TestConfig_initKOMultipleBackendsForNoopEncoder(t *testing.T) { subject := ServiceConfig{ Version: ConfigVersion, Host: []string{"http://127.0.0.1:8080"}, Endpoints: []*EndpointConfig{ { Endpoint: "/supu", Method: "post", OutputEncoding: "no-op", Backend: []*Backend{ { Encoding: "no-op", }, { Encoding: "no-op", }, }, }, }, } if err := subject.Init(); err != errInvalidNoOpEncoding { t.Error("Expecting an error at the configuration init!", err) } } func TestConfig_initKOInvalidHost(t *testing.T) { subject := ServiceConfig{ Version: ConfigVersion, Host: []string{"http://127.0.0.1:8080http://127.0.0.1:8080"}, Endpoints: []*EndpointConfig{ { Endpoint: "/supu", Method: "post", Backend: []*Backend{}, }, }, } err := subject.Init() if err == nil { t.Errorf("expected to fail with invalid host") return } if !errors.Is(err, errInvalidHost) { t.Errorf("expected 'errInvalidHost' got: %s", err.Error()) return } } func TestConfig_initKOInvalidDebugPattern(t *testing.T) { dp := invalidPattern 
invalidPattern = "a(b" subject := ServiceConfig{ Version: ConfigVersion, Host: []string{"http://127.0.0.1:8080"}, Endpoints: []*EndpointConfig{ { Endpoint: "/__debug/supu", Method: "GET", Backend: []*Backend{}, }, }, } if err := subject.Init(); err == nil || err.Error() != "ignoring the 'GET /__debug/supu' endpoint due to a parsing error: error parsing regexp: missing closing ): `a(b`" { t.Error("Expecting an error at the configuration init!", err) } invalidPattern = dp } func TestConfig_initKOValidSetinvalidPattern(t *testing.T) { dp := invalidPattern invalidPattern = `^[^/]|/__(debug|echo|health)(/.*)?$` subject := ServiceConfig{ Version: ConfigVersion, Host: []string{"http://127.0.0.1:8080"}, Endpoints: []*EndpointConfig{ { Endpoint: "/*", Method: "GET", Backend: []*Backend{ { URLPattern: "/", Host: []string{"https://api.github.com"}, AllowList: []string{"authorizations_url", "code_search_url"}, }, }, }, }, } if err := subject.Init(); err != nil { t.Error(err) } invalidPattern = dp } ================================================ FILE: config/parser.go ================================================ // SPDX-License-Identifier: Apache-2.0 package config import ( "encoding/json" "fmt" "os" "time" ) // Parser reads a configuration file, parses it and returns the content as an init ServiceConfig struct type Parser interface { Parse(configFile string) (ServiceConfig, error) } // ParserFunc type is an adapter to allow the use of ordinary functions as subscribers. // If f is a function with the appropriate signature, ParserFunc(f) is a Parser that calls f. 
type ParserFunc func(string) (ServiceConfig, error)

// Parse implements the Parser interface by calling f.
func (f ParserFunc) Parse(configFile string) (ServiceConfig, error) { return f(configFile) }

// NewParser creates a new parser using the json library
func NewParser() Parser {
	return NewParserWithFileReader(os.ReadFile)
}

// NewParserWithFileReader returns a Parser with the injected FileReaderFunc function
func NewParserWithFileReader(f FileReaderFunc) Parser {
	return parser{fileReader: f}
}

// parser is the default Parser: it reads the configuration with fileReader and
// decodes it with encoding/json.
type parser struct {
	fileReader FileReaderFunc
}

// Parse implements the Parser interface. It reads configFile with the injected
// FileReaderFunc, decodes it as JSON, normalizes it and runs ServiceConfig.Init.
//
// Every error is prefixed with the config file name (see CheckErr). JSON decoding
// errors are returned as *ParseError, with the row and column computed from the
// bytes that were actually decoded. When Init fails, the normalized config is
// returned along with the error.
func (p parser) Parse(configFile string) (ServiceConfig, error) {
	var result ServiceConfig
	var cfg parseableServiceConfig
	data, err := p.fileReader(configFile)
	if err != nil {
		return result, CheckErr(err, configFile)
	}
	if err = json.Unmarshal(data, &cfg); err != nil {
		// locate the error inside the decoded bytes instead of reading configFile
		// again: the injected fileReader is not necessarily backed by the file
		// system, and the file could have changed in between
		return result, checkErrWithSource(err, configFile, data)
	}
	result = cfg.normalize()

	if err = result.Init(); err != nil {
		return result, CheckErr(err, configFile)
	}

	return result, nil
}

// CheckErr returns a proper documented error: JSON syntax and type errors become a
// *ParseError (row and column computed from configFile as stored on disk, see
// NewParseError), an *os.PathError is described with the failing operation, and any
// other error is prefixed with the config file name.
func CheckErr(err error, configFile string) error {
	switch e := err.(type) {
	case *json.SyntaxError:
		return NewParseError(err, configFile, int(e.Offset))
	case *json.UnmarshalTypeError:
		return NewParseError(err, configFile, int(e.Offset))
	case *os.PathError:
		return fmt.Errorf(
			"'%s' (%s): %s",
			configFile,
			e.Op,
			e.Err.Error(),
		)
	default:
		return fmt.Errorf("'%s': %v", configFile, err)
	}
}

// checkErrWithSource works like CheckErr, but computes the row and column of JSON
// decoding errors from source, the content that was decoded.
func checkErrWithSource(err error, configFile string, source []byte) error {
	switch e := err.(type) {
	case *json.SyntaxError:
		return newParseError(err, configFile, int(e.Offset), source)
	case *json.UnmarshalTypeError:
		return newParseError(err, configFile, int(e.Offset), source)
	default:
		return CheckErr(err, configFile)
	}
}

// NewParseError returns a new ParseError. The row and column are computed from the
// current content of configFile; if it can not be read, both are reported as 0.
func NewParseError(err error, configFile string, offset int) *ParseError {
	// best effort: a read failure only degrades the row/col details of the error
	b, _ := os.ReadFile(configFile)
	return newParseError(err, configFile, offset, b)
}

// newParseError builds a ParseError locating offset inside source.
func newParseError(err error, configFile string, offset int, source []byte) *ParseError {
	row, col := getErrorRowCol(source, offset)
	return &ParseError{
		ConfigFile: configFile,
		Err:        err,
		Offset:     offset,
		Row:        row,
		Col:        col,
	}
}

// getErrorRowCol returns the zero-based row and column of the position at offset in
// source. '\r' bytes are not counted, and offsets outside [0, len(source)] are
// clamped to that range.
func getErrorRowCol(source []byte, offset int) (row, col int) {
	if offset > len(source) {
		offset = len(source)
	}
	if offset < 0 {
		offset = 0
	}
	for _, v := range source[:offset] {
		switch v {
		case '\r':
			// skipped, so a CRLF sequence counts as a single line break
		case '\n':
			col = 0
			row++
		default:
			col++
		}
	}
	return row, col
}

// ParseError is an
error containing details regarding the row and column where // an parse error occurred type ParseError struct { ConfigFile string Offset int Row int Col int Err error } // Error returns the error message for the ParseError func (p *ParseError) Error() string { return fmt.Sprintf( "'%s': %v, offset: %v, row: %v, col: %v", p.ConfigFile, p.Err.Error(), p.Offset, p.Row, p.Col, ) } // FileReaderFunc is a function used to read the content of a config file type FileReaderFunc func(string) ([]byte, error) type parseableServiceConfig struct { Name string `json:"name"` Endpoints []*parseableEndpointConfig `json:"endpoints"` AsyncAgents []*parseableAsyncAgent `json:"async_agent"` Timeout string `json:"timeout"` CacheTTL string `json:"cache_ttl"` Host []string `json:"host"` Port int `json:"port"` Address string `json:"listen_ip"` Version int `json:"version"` ExtraConfig *ExtraConfig `json:"extra_config,omitempty"` ReadTimeout string `json:"read_timeout"` WriteTimeout string `json:"write_timeout"` IdleTimeout string `json:"idle_timeout"` ReadHeaderTimeout string `json:"read_header_timeout"` MaxHeaderBytes int `json:"max_header_bytes"` DisableKeepAlives bool `json:"disable_keep_alives"` DisableCompression bool `json:"disable_compression"` DisableStrictREST bool `json:"disable_rest"` MaxIdleConns int `json:"max_idle_connections"` MaxIdleConnsPerHost int `json:"max_idle_connections_per_host"` IdleConnTimeout string `json:"idle_connection_timeout"` ResponseHeaderTimeout string `json:"response_header_timeout"` ExpectContinueTimeout string `json:"expect_continue_timeout"` OutputEncoding string `json:"output_encoding"` DialerTimeout string `json:"dialer_timeout"` DialerFallbackDelay string `json:"dialer_fallback_delay"` DialerKeepAlive string `json:"dialer_keep_alive"` Debug bool `json:"debug_endpoint"` Echo bool `json:"echo_endpoint"` Plugin *Plugin `json:"plugin,omitempty"` TLS *parseableTLS `json:"tls,omitempty"` ClientTLS *parseableClientTLS `json:"client_tls,omitempty"` UseH2C 
bool `json:"use_h2c,omitempty"` DNSCacheTTL string `json:"dns_cache_ttl"` MaxShutdownDuration string `json:"max_shutdown_wait_time"` } func (p *parseableServiceConfig) normalize() ServiceConfig { cfg := ServiceConfig{ Name: p.Name, Timeout: parseDuration(p.Timeout), CacheTTL: parseDuration(p.CacheTTL), Host: p.Host, Port: p.Port, Address: p.Address, Version: p.Version, Debug: p.Debug, Echo: p.Echo, ReadTimeout: parseDuration(p.ReadTimeout), WriteTimeout: parseDuration(p.WriteTimeout), IdleTimeout: parseDuration(p.IdleTimeout), ReadHeaderTimeout: parseDuration(p.ReadHeaderTimeout), MaxHeaderBytes: p.MaxHeaderBytes, DisableKeepAlives: p.DisableKeepAlives, DisableCompression: p.DisableCompression, DisableStrictREST: p.DisableStrictREST, MaxIdleConns: p.MaxIdleConns, MaxIdleConnsPerHost: p.MaxIdleConnsPerHost, IdleConnTimeout: parseDuration(p.IdleConnTimeout), ResponseHeaderTimeout: parseDuration(p.ResponseHeaderTimeout), ExpectContinueTimeout: parseDuration(p.ExpectContinueTimeout), DialerTimeout: parseDuration(p.DialerTimeout), DialerFallbackDelay: parseDuration(p.DialerFallbackDelay), DialerKeepAlive: parseDuration(p.DialerKeepAlive), OutputEncoding: p.OutputEncoding, Plugin: p.Plugin, UseH2C: p.UseH2C, DNSCacheTTL: parseDuration(p.DNSCacheTTL), MaxShutdownDuration: parseDuration(p.MaxShutdownDuration), } if p.TLS != nil { cfg.TLS = &TLS{ IsDisabled: p.TLS.IsDisabled, PublicKey: p.TLS.PublicKey, PrivateKey: p.TLS.PrivateKey, CaCerts: p.TLS.CaCerts, MinVersion: p.TLS.MinVersion, MaxVersion: p.TLS.MaxVersion, CurvePreferences: p.TLS.CurvePreferences, PreferServerCipherSuites: p.TLS.PreferServerCipherSuites, CipherSuites: p.TLS.CipherSuites, EnableMTLS: p.TLS.EnableMTLS, DisableSystemCaPool: p.TLS.DisableSystemCaPool, } for _, k := range p.TLS.Keys { cfg.TLS.Keys = append(cfg.TLS.Keys, TLSKeyPair(k)) } } if p.ClientTLS != nil { cfg.ClientTLS = &ClientTLS{ AllowInsecureConnections: p.ClientTLS.AllowInsecureConnections, CaCerts: p.ClientTLS.CaCerts, DisableSystemCaPool: 
p.ClientTLS.DisableSystemCaPool, MinVersion: p.ClientTLS.MinVersion, MaxVersion: p.ClientTLS.MaxVersion, CurvePreferences: p.ClientTLS.CurvePreferences, CipherSuites: p.ClientTLS.CipherSuites, ClientCerts: make([]ClientTLSCert, 0, len(p.ClientTLS.ClientCerts)), } for _, cc := range p.ClientTLS.ClientCerts { cfg.ClientTLS.ClientCerts = append(cfg.ClientTLS.ClientCerts, ClientTLSCert(cc)) } } if p.ExtraConfig != nil { cfg.ExtraConfig = *p.ExtraConfig } endpoints := make([]*EndpointConfig, 0, len(p.Endpoints)) for _, e := range p.Endpoints { endpoints = append(endpoints, e.normalize()) } cfg.Endpoints = endpoints agents := make([]*AsyncAgent, 0, len(p.AsyncAgents)) for _, a := range p.AsyncAgents { agents = append(agents, a.normalize()) } cfg.AsyncAgents = agents return cfg } type parseableTLSKeyPair struct { PublicKey string `json:"public_key"` PrivateKey string `json:"private_key"` } type parseableTLS struct { IsDisabled bool `json:"disabled"` PublicKey string `json:"public_key"` PrivateKey string `json:"private_key"` CaCerts []string `json:"ca_certs"` MinVersion string `json:"min_version"` MaxVersion string `json:"max_version"` CurvePreferences []uint16 `json:"curve_preferences"` PreferServerCipherSuites bool `json:"prefer_server_cipher_suites"` CipherSuites []uint16 `json:"cipher_suites"` EnableMTLS bool `json:"enable_mtls"` DisableSystemCaPool bool `json:"disable_system_ca_pool"` Keys []parseableTLSKeyPair `json:"keys"` } type parseableClientTLS struct { AllowInsecureConnections bool `json:"allow_insecure_connections"` CaCerts []string `json:"ca_certs"` DisableSystemCaPool bool `json:"disable_system_ca_pool"` MinVersion string `json:"min_version"` MaxVersion string `json:"max_version"` CurvePreferences []uint16 `json:"curve_preferences"` CipherSuites []uint16 `json:"cipher_suites"` ClientCerts []parseableClientTLSCert `json:"client_certs"` } type parseableClientTLSCert struct { Certificate string `json:"certificate"` PrivateKey string `json:"private_key"` } type 
parseableEndpointConfig struct { Endpoint string `json:"endpoint"` Method string `json:"method"` Backend []*parseableBackend `json:"backend"` ConcurrentCalls int `json:"concurrent_calls"` Timeout string `json:"timeout"` CacheTTL string `json:"cache_ttl"` QueryString []string `json:"input_query_strings"` ExtraConfig *ExtraConfig `json:"extra_config,omitempty"` HeadersToPass []string `json:"input_headers"` OutputEncoding string `json:"output_encoding"` } func (p *parseableEndpointConfig) normalize() *EndpointConfig { e := EndpointConfig{ Endpoint: p.Endpoint, Method: p.Method, ConcurrentCalls: p.ConcurrentCalls, Timeout: parseDuration(p.Timeout), CacheTTL: parseDuration(p.CacheTTL), QueryString: p.QueryString, HeadersToPass: p.HeadersToPass, OutputEncoding: p.OutputEncoding, } if p.ExtraConfig != nil { e.ExtraConfig = *p.ExtraConfig } backends := make([]*Backend, 0, len(p.Backend)) for _, b := range p.Backend { backends = append(backends, b.normalize()) } e.Backend = backends return &e } type parseableAsyncAgent struct { Name string `json:"name"` Connection struct { MaxRetries int `json:"max_retries"` BackoffStrategy string `json:"backoff_strategy"` HealthInterval string `json:"health_interval"` } `json:"connection"` Consumer struct { Timeout string `json:"timeout"` Workers int `json:"workers"` Topic string `json:"topic"` MaxRate float64 `json:"max_rate"` } `json:"consumer"` Encoding string `json:"encoding"` Backend []*parseableBackend `json:"backend"` ExtraConfig ExtraConfig `json:"extra_config"` } func (p *parseableAsyncAgent) normalize() *AsyncAgent { e := AsyncAgent{ Name: p.Name, Encoding: p.Encoding, Connection: Connection{ MaxRetries: p.Connection.MaxRetries, BackoffStrategy: p.Connection.BackoffStrategy, HealthInterval: parseDuration(p.Connection.HealthInterval), }, Consumer: Consumer{ Timeout: parseDuration(p.Consumer.Timeout), Workers: p.Consumer.Workers, Topic: p.Consumer.Topic, MaxRate: p.Consumer.MaxRate, }, } if p.ExtraConfig != nil { e.ExtraConfig = 
p.ExtraConfig } backends := make([]*Backend, 0, len(p.Backend)) for _, b := range p.Backend { backends = append(backends, b.normalize()) } e.Backend = backends return &e } type parseableBackend struct { Group string `json:"group"` Method string `json:"method"` Host []string `json:"host"` HostSanitizationDisabled bool `json:"disable_host_sanitize"` URLPattern string `json:"url_pattern"` AllowList []string `json:"allow"` DenyList []string `json:"deny"` Mapping map[string]string `json:"mapping"` Encoding string `json:"encoding"` IsCollection bool `json:"is_collection"` Target string `json:"target"` ExtraConfig *ExtraConfig `json:"extra_config,omitempty"` SD string `json:"sd"` HeadersToPass []string `json:"input_headers"` SDScheme string `json:"sd_scheme"` QueryStringsToPass []string `json:"input_query_strings"` } func (p *parseableBackend) normalize() *Backend { b := Backend{ Group: p.Group, Method: p.Method, Host: p.Host, HostSanitizationDisabled: p.HostSanitizationDisabled, URLPattern: p.URLPattern, Mapping: p.Mapping, Encoding: p.Encoding, IsCollection: p.IsCollection, Target: p.Target, SD: p.SD, SDScheme: p.SDScheme, AllowList: p.AllowList, DenyList: p.DenyList, HeadersToPass: p.HeadersToPass, QueryStringsToPass: p.QueryStringsToPass, } if b.SDScheme == "" { b.SDScheme = "http" } if p.ExtraConfig != nil { b.ExtraConfig = *p.ExtraConfig } return &b } func parseDuration(v string) time.Duration { d, err := time.ParseDuration(v) if err != nil { return 0 } return d } ================================================ FILE: config/parser_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package config import ( "os" "testing" ) func TestNewParser_ok(t *testing.T) { configPath := "/tmp/ok.json" configContent := []byte(`{ "version": 3, "name": "My lovely gateway", "port": 8080, "cache_ttl": "3600s", "timeout": "3s", "max_header_bytes": 10000, "tls": { "public_key": "cert.pem", "private_key": "key.pem" }, "async_agent": [ { 
"name": "agent", "connection": { "max_retries": 2 }, "consumer": { "topic": "foo.*" }, "backend": [ { "host": [ "https://api.github.com" ], "url_pattern": "/", "extra_config" : {"user":"test","hits":6,"parents":["gomez","morticia"]} } ] } ], "endpoints": [ { "endpoint": "/github", "method": "GET", "extra_config" : {"user":"test","hits":6,"parents":["gomez","morticia"]}, "backend": [ { "host": [ "https://api.github.com" ], "url_pattern": "/", "allow": [ "authorizations_url", "code_search_url" ], "extra_config" : {"user":"test","hits":6,"parents":["gomez","morticia"]} } ] }, { "endpoint": "/supu", "method": "GET", "concurrent_calls": 3, "backend": [ { "host": [ "http://127.0.0.1:8080" ], "url_pattern": "/__debug/supu" } ] }, { "endpoint": "/combination/{id}", "method": "GET", "backend": [ { "group": "first_post", "host": [ "https://jsonplaceholder.typicode.com" ], "url_pattern": "/posts/{id}", "deny": [ "userId" ] }, { "host": [ "https://jsonplaceholder.typicode.com" ], "url_pattern": "/users/{id}", "mapping": { "email": "personal_email" } } ] } ], "extra_config" : {"user":"test","hits":6,"parents":["gomez","morticia"]} }`) if err := os.WriteFile(configPath, configContent, 0644); err != nil { t.FailNow() } serviceConfig, err := NewParser().Parse(configPath) if err != nil { t.Error("Unexpected error. Got", err.Error()) } if serviceConfig.MaxHeaderBytes != 10000 { t.Errorf("unexpected max_header_bytes value. 
have %d, want 10000", serviceConfig.MaxHeaderBytes) } testExtraConfig(serviceConfig.ExtraConfig, t) if endpoints := len(serviceConfig.Endpoints); endpoints != 3 { t.Errorf("Unexpected number of endpoints: %d", endpoints) return } endpoint := serviceConfig.Endpoints[0] endpointExtraConfiguration := endpoint.ExtraConfig if endpointExtraConfiguration != nil { testExtraConfig(endpointExtraConfiguration, t) } else { t.Error("Extra config is not present in EndpointConfig") } if serviceConfig.TLS == nil { t.Error("TLS config not present") } else { if serviceConfig.TLS.PublicKey != "cert.pem" { t.Error("Unexpected TLS Public key") } if serviceConfig.TLS.PrivateKey != "key.pem" { t.Error("Unexpected TLS Private key") } } backend := endpoint.Backend[0] backendExtraConfiguration := backend.ExtraConfig if backendExtraConfiguration != nil { testExtraConfig(backendExtraConfiguration, t) } else { t.Error("Extra config is not present in BackendConfig") } if err := os.Remove(configPath); err != nil { t.FailNow() } if l := len(serviceConfig.AsyncAgents); l != 1 { t.Errorf("Unexpected number of agents. 
Have %d, want 1", l) } } func TestNewParser_errorMessages(t *testing.T) { for _, configContent := range []struct { name string path string content []byte expErr string }{ { name: "case0", path: "/tmp/ok.json", content: []byte(`{`), expErr: "'/tmp/ok.json': unexpected end of JSON input, offset: 1, row: 0, col: 1", }, { name: "case1", path: "/tmp/ok.json", content: []byte(`>`), expErr: "'/tmp/ok.json': invalid character '>' looking for beginning of value, offset: 1, row: 0, col: 1", }, { name: "case2", path: "/tmp/ok.json", content: []byte(`"`), expErr: "'/tmp/ok.json': unexpected end of JSON input, offset: 1, row: 0, col: 1", }, { name: "case3", path: "/tmp/ok.json", content: []byte(``), expErr: "'/tmp/ok.json': unexpected end of JSON input, offset: 0, row: 0, col: 0", }, { name: "case4", path: "/tmp/ok.json", content: []byte(`[{}]`), expErr: "'/tmp/ok.json': json: cannot unmarshal array into Go value of type config.parseableServiceConfig, offset: 1, row: 0, col: 1", }, { name: "case5", path: "/tmp/ok.json", content: []byte(`42`), expErr: "'/tmp/ok.json': json: cannot unmarshal number into Go value of type config.parseableServiceConfig, offset: 2, row: 0, col: 2", }, { name: "case6", path: "/tmp/ok.json", content: []byte("\r\n42"), expErr: "'/tmp/ok.json': json: cannot unmarshal number into Go value of type config.parseableServiceConfig, offset: 4, row: 1, col: 2", }, { name: "case7", path: "/tmp/ok.json", content: []byte(`{ "version": 3, "name": "My lovely gateway", "port": 8080, "cache_ttl": 3600 "timeout": "3s", "endpoints": [] }`), expErr: "'/tmp/ok.json': invalid character '\"' after object key:value pair, offset: 83, row: 5, col: 2", }, } { t.Run(configContent.name, func(t *testing.T) { if err := os.WriteFile(configContent.path, configContent.content, 0644); err != nil { t.Error(err) return } _, err := NewParser().Parse(configContent.path) if err == nil { t.Errorf("%s: Expecting error", configContent.name) return } if errMsg := err.Error(); errMsg != 
configContent.expErr { t.Errorf("%s: Unexpected error. Got '%s' want '%s'", configContent.name, errMsg, configContent.expErr) return } if err := os.Remove(configContent.path); err != nil { t.Errorf("%s: %s", err.Error(), configContent.name) return } }) } } func testExtraConfig(extraConfig map[string]interface{}, t *testing.T) { userVar := extraConfig["user"] if userVar != "test" { t.Error("User in extra config is not test") } parents, ok := extraConfig["parents"].([]interface{}) if !ok || parents[0] != "gomez" { t.Error("Parent 0 of user us not gomez") } if !ok || parents[1] != "morticia" { t.Error("Parent 1 of user us not morticia") } } func TestNewParser_unknownFile(t *testing.T) { _, err := NewParser().Parse("/nowhere/in/the/fs.json") if err == nil || err.Error() != "'/nowhere/in/the/fs.json' (open): no such file or directory" { t.Errorf("error expected. got '%v'", err) } } func TestNewParser_readingError(t *testing.T) { wrongConfigPath := "/tmp/reading.json" wrongConfigContent := []byte("{hello\ngo\n") if err := os.WriteFile(wrongConfigPath, wrongConfigContent, 0644); err != nil { t.FailNow() } expected := "'/tmp/reading.json': invalid character 'h' looking for beginning of object key string, offset: 2, row: 0, col: 2" _, err := NewParser().Parse(wrongConfigPath) if err == nil || err.Error() != expected { t.Error("Error expected. Got", err) } if err = os.Remove(wrongConfigPath); err != nil { t.FailNow() } } func TestNewParser_initError(t *testing.T) { wrongConfigPath := "/tmp/unmarshall.json" wrongConfigContent := []byte("{\"a\":42}") if err := os.WriteFile(wrongConfigPath, wrongConfigContent, 0644); err != nil { t.FailNow() } _, err := NewParser().Parse(wrongConfigPath) if err == nil || err.Error() != "'/tmp/unmarshall.json': unsupported version: 0 (want: 3)" { t.Error("Error expected. 
Got", err) } if err = os.Remove(wrongConfigPath); err != nil { t.FailNow() } } func TestParserFunc(t *testing.T) { expected := ServiceConfig{Version: 42} result, err := ParserFunc(func(_ string) (ServiceConfig, error) { return expected, nil })("path/to/the/config/file") if err != nil { t.Error(err.Error()) } if result.Version != expected.Version { t.Error("unexpected parsed config:", result) } } ================================================ FILE: config/uri.go ================================================ // SPDX-License-Identifier: Apache-2.0 package config import ( "fmt" "regexp" "strings" ) var ( endpointURLKeysPattern = regexp.MustCompile(`/\{([a-zA-Z\-_0-9]+)\}`) hostPattern = regexp.MustCompile(`(https?://)?([a-zA-Z0-9\._\-]+)(:[0-9]{2,6})?/?`) ) // URIParser defines the interface for all the URI manipulation required by KrakenD type URIParser interface { CleanHosts([]string) []string CleanHost(string) string CleanPath(string) string GetEndpointPath(string, []string) string } // Like URIParser but with safe versions of the clean host functionality that // does not panic but returns an error. 
type SafeURIParser interface { SafeCleanHosts([]string) ([]string, error) SafeCleanHost(string) (string, error) CleanPath(string) string GetEndpointPath(string, []string) string } // NewURIParser creates a new URIParser using the package variable RoutingPattern func NewURIParser() URIParser { return URI(RoutingPattern) } // NewSafeURIParser creates a safe URI parser that does not panic when cleaning hosts func NewSafeURIParser() URI { return URI(RoutingPattern) } // URI implements the URIParser interface type URI int // SafeCleanHosts applies the SafeCleanHost method to every member of the received array of hosts func (u URI) SafeCleanHosts(hosts []string) ([]string, error) { cleaned := make([]string, 0, len(hosts)) for i := range hosts { h, err := u.SafeCleanHost(hosts[i]) if err != nil { return nil, fmt.Errorf("host %s not valid: %w", hosts[i], errInvalidHost) } cleaned = append(cleaned, h) } return cleaned, nil } // CleanHosts applies the CleanHost method to every member of the received array of hosts // Panics in case of error. func (u URI) CleanHosts(hosts []string) []string { ss, e := u.SafeCleanHosts(hosts) if e != nil { panic(e) } return ss } // SafeCleanHost sanitizes the received host func (URI) SafeCleanHost(host string) (string, error) { matches := hostPattern.FindAllStringSubmatch(host, -1) if len(matches) != 1 { return "", errInvalidHost } keys := matches[0][1:] if keys[0] == "" { keys[0] = "http://" } return strings.Join(keys, ""), nil } // CleanHost sanitizes the received host. // Panics on error. 
func (u URI) CleanHost(host string) string { h, err := u.SafeCleanHost(host) if err != nil { panic(err) } return h } // CleanPath trims all the extra slashes from the received URI path func (URI) CleanPath(path string) string { return "/" + strings.TrimPrefix(path, "/") } // GetEndpointPath applies the proper replacement in the received path to generate valid route patterns func (u URI) GetEndpointPath(path string, params []string) string { result := path if u == ColonRouterPatternBuilder { for p := range params { parts := strings.Split(result, "?") parts[0] = strings.ReplaceAll(parts[0], "{"+params[p]+"}", ":"+params[p]) result = strings.Join(parts, "?") } } return result } ================================================ FILE: config/uri_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package config import "testing" func TestURIParser_cleanHosts(t *testing.T) { samples := []string{ "supu", "127.0.0.1", "https://supu.local/", "http://127.0.0.1", "supu_42.local:8080/", "http://127.0.0.1:8080", } expected := []string{ "http://supu", "http://127.0.0.1", "https://supu.local", "http://127.0.0.1", "http://supu_42.local:8080", "http://127.0.0.1:8080", } result := NewURIParser().CleanHosts(samples) for i := range result { if expected[i] != result[i] { t.Errorf("want: %s, have: %s\n", expected[i], result[i]) } } } func TestURIParser_cleanPath(t *testing.T) { samples := []string{ "supu/{tupu}", "supu/{tupu}{supu}", "/supu/{tupu}", "/supu.local/", "supu_supu.txt", "supu_42.local?a=8080", "supu/supu/supu?a=1&b=2", "debug/supu/supu?a=1&b=2", } expected := []string{ "/supu/{tupu}", "/supu/{tupu}{supu}", "/supu/{tupu}", "/supu.local/", "/supu_supu.txt", "/supu_42.local?a=8080", "/supu/supu/supu?a=1&b=2", "/debug/supu/supu?a=1&b=2", } subject := URI(BracketsRouterPatternBuilder) for i := range samples { if have := subject.CleanPath(samples[i]); expected[i] != have { t.Errorf("want: %s, have: %s\n", expected[i], have) } } } func 
TestURIParser_getEndpointPath(t *testing.T) { samples := []string{ "supu/{tupu}", "/supu/{tupu}{supu}", "/supu/{tupu}", "/supu.local/", "supu/{tupu}/{supu}?a={s}&b=2", } expected := []string{ "supu/:tupu", "/supu/:tupu{supu}", "/supu/:tupu", "/supu.local/", "supu/:tupu/:supu?a={s}&b=2", } sc := ServiceConfig{} subject := NewURIParser() for i := range samples { params := sc.extractPlaceHoldersFromURLTemplate(samples[i], sc.paramExtractionPattern()) if have := subject.GetEndpointPath(samples[i], params); expected[i] != have { t.Errorf("want: %s, have: %s\n", expected[i], have) } } } func TestURIParser_getEndpointPath_notStrictREST(t *testing.T) { samples := []string{ "supu/{tupu}", "/supu/{tupu}{supu}", "/supu/{tupu}", "/supu.local/", "supu/{tupu}/{supu}?a={s}&b=2", } expected := []string{ "supu/:tupu", "/supu/:tupu:supu", "/supu/:tupu", "/supu.local/", "supu/:tupu/:supu?a={s}&b=2", } sc := ServiceConfig{DisableStrictREST: true} subject := NewURIParser() for i := range samples { params := sc.extractPlaceHoldersFromURLTemplate(samples[i], sc.paramExtractionPattern()) if have := subject.GetEndpointPath(samples[i], params); expected[i] != have { t.Errorf("want: %s, have: %s\n", expected[i], have) } } } ================================================ FILE: core/version.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package core contains some basic constants and variables */ package core import ( "fmt" "runtime" "strings" ) // KrakendHeaderName is the name of the custom KrakenD header const KrakendHeaderName = "X-KRAKEND" // KrakendVersion is the version of the build var KrakendVersion = "undefined" // GoVersion is the version of the go compiler used at build time var GoVersion = strings.TrimPrefix(runtime.Version(), "go") // GlibcVersion is the version of the glibc used by CGO at build time var GlibcVersion = "undefined" // KrakendHeaderValue is the value of the custom KrakenD header var KrakendHeaderValue = 
fmt.Sprintf("Version %s", KrakendVersion) // KrakendUserAgent is the value of the user agent header sent to the backends var KrakendUserAgent = fmt.Sprintf("KrakenD Version %s", KrakendVersion) ================================================ FILE: docs/BENCHMARKS.md ================================================ Benchmarks --- Here you'll find some benchmarks of the different components of the Lura framework in several scenarios. # Proxy components ## Proxy middleware stack BenchmarkProxyStack_single-8 500000 9106 ns/op 1848 B/op 35 allocs/op BenchmarkProxyStack_multi/with_1_backends-8 500000 9183 ns/op 1848 B/op 35 allocs/op BenchmarkProxyStack_multi/with_2_backends-8 300000 16130 ns/op 3520 B/op 73 allocs/op BenchmarkProxyStack_multi/with_3_backends-8 200000 20780 ns/op 5097 B/op 105 allocs/op BenchmarkProxyStack_multi/with_4_backends-8 200000 22420 ns/op 6641 B/op 137 allocs/op BenchmarkProxyStack_multi/with_5_backends-8 200000 23966 ns/op 8218 B/op 169 allocs/op ## Proxy middlewares BenchmarkNewLoadBalancedMiddleware-8 10000000 435 ns/op 328 B/op 6 allocs/op BenchmarkNewConcurrentMiddleware_singleNext-8 500000 9351 ns/op 1072 B/op 18 allocs/op BenchmarkNewRequestBuilderMiddleware-8 30000000 115 ns/op 160 B/op 2 allocs/op BenchmarkNewMergeDataMiddleware/with_2_parts-8 1000000 6746 ns/op 1360 B/op 20 allocs/op BenchmarkNewMergeDataMiddleware/with_3_parts-8 500000 10179 ns/op 1488 B/op 22 allocs/op BenchmarkNewMergeDataMiddleware/with_4_parts-8 500000 10299 ns/op 1584 B/op 24 allocs/op # Response manipulation ## Response property whitelisting BenchmarkEntityFormatter_whitelistingFilter/with_0_elements_with_0_extra_fields-8 50000000 80.6 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_1_elements_with_0_extra_fields-8 10000000 441 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_2_elements_with_0_extra_fields-8 10000000 474 ns/op 384 B/op 3 allocs/op 
BenchmarkEntityFormatter_whitelistingFilter/with_3_elements_with_0_extra_fields-8 10000000 516 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_4_elements_with_0_extra_fields-8 10000000 519 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_0_elements_with_5_extra_fields-8 50000000 84.3 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_1_elements_with_5_extra_fields-8 10000000 565 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_2_elements_with_5_extra_fields-8 10000000 601 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_3_elements_with_5_extra_fields-8 10000000 638 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_4_elements_with_5_extra_fields-8 10000000 627 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_0_elements_with_10_extra_fields-8 50000000 80.7 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_1_elements_with_10_extra_fields-8 10000000 703 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_2_elements_with_10_extra_fields-8 5000000 746 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_3_elements_with_10_extra_fields-8 5000000 779 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_4_elements_with_10_extra_fields-8 5000000 785 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_0_elements_with_15_extra_fields-8 50000000 81.4 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_1_elements_with_15_extra_fields-8 5000000 845 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_2_elements_with_15_extra_fields-8 5000000 886 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_3_elements_with_15_extra_fields-8 5000000 919 ns/op 384 B/op 3 allocs/op 
BenchmarkEntityFormatter_whitelistingFilter/with_4_elements_with_15_extra_fields-8 5000000 929 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_0_elements_with_20_extra_fields-8 50000000 80.9 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_1_elements_with_20_extra_fields-8 5000000 988 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_2_elements_with_20_extra_fields-8 5000000 984 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_3_elements_with_20_extra_fields-8 5000000 998 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_4_elements_with_20_extra_fields-8 5000000 1014 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_0_elements_with_25_extra_fields-8 50000000 78.1 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_1_elements_with_25_extra_fields-8 5000000 1149 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_2_elements_with_25_extra_fields-8 3000000 1279 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_3_elements_with_25_extra_fields-8 3000000 1348 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_whitelistingFilter/with_4_elements_with_25_extra_fields-8 3000000 1349 ns/op 384 B/op 3 allocs/op ## Response property blacklisting BenchmarkEntityFormatter_blacklistingFilter/with_0_elements_with_0_extra_fields-8 50000000 82.4 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_1_elements_with_0_extra_fields-8 30000000 174 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_2_elements_with_0_extra_fields-8 20000000 205 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_3_elements_with_0_extra_fields-8 100000000 63.5 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_4_elements_with_0_extra_fields-8 100000000 62.9 ns/op 48 B/op 1 allocs/op 
BenchmarkEntityFormatter_blacklistingFilter/with_0_elements_with_5_extra_fields-8 50000000 80.5 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_1_elements_with_5_extra_fields-8 30000000 175 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_2_elements_with_5_extra_fields-8 20000000 207 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_3_elements_with_5_extra_fields-8 20000000 255 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_4_elements_with_5_extra_fields-8 20000000 299 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_0_elements_with_10_extra_fields-8 50000000 82.9 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_1_elements_with_10_extra_fields-8 30000000 162 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_2_elements_with_10_extra_fields-8 20000000 193 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_3_elements_with_10_extra_fields-8 20000000 229 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_4_elements_with_10_extra_fields-8 20000000 272 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_0_elements_with_15_extra_fields-8 50000000 76.7 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_1_elements_with_15_extra_fields-8 30000000 161 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_2_elements_with_15_extra_fields-8 20000000 195 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_3_elements_with_15_extra_fields-8 20000000 243 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_4_elements_with_15_extra_fields-8 20000000 292 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_0_elements_with_20_extra_fields-8 50000000 81.4 ns/op 48 B/op 1 allocs/op 
BenchmarkEntityFormatter_blacklistingFilter/with_1_elements_with_20_extra_fields-8 30000000 161 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_2_elements_with_20_extra_fields-8 20000000 197 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_3_elements_with_20_extra_fields-8 20000000 239 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_4_elements_with_20_extra_fields-8 20000000 289 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_0_elements_with_25_extra_fields-8 50000000 80.9 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_1_elements_with_25_extra_fields-8 30000000 176 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_2_elements_with_25_extra_fields-8 20000000 200 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_3_elements_with_25_extra_fields-8 20000000 250 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_blacklistingFilter/with_4_elements_with_25_extra_fields-8 20000000 312 ns/op 48 B/op 1 allocs/op ## Response property grouping BenchmarkEntityFormatter_grouping/with_0_elements-8 20000000 277 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_grouping/with_5_elements-8 20000000 299 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_grouping/with_10_elements-8 20000000 300 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_grouping/with_15_elements-8 20000000 298 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_grouping/with_20_elements-8 20000000 298 ns/op 384 B/op 3 allocs/op BenchmarkEntityFormatter_grouping/with_25_elements-8 20000000 298 ns/op 384 B/op 3 allocs/op ## Response property mapping BenchmarkEntityFormatter_mapping/with_0_elements_with_0_extra_fields-8 100000000 61.1 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_1_elements_with_0_extra_fields-8 100000000 63.5 ns/op 48 B/op 1 allocs/op 
BenchmarkEntityFormatter_mapping/with_2_elements_with_0_extra_fields-8 100000000 61.8 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_3_elements_with_0_extra_fields-8 100000000 63.9 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_4_elements_with_0_extra_fields-8 100000000 63.7 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_5_elements_with_0_extra_fields-8 100000000 64.0 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_0_elements_with_5_extra_fields-8 50000000 81.4 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_1_elements_with_5_extra_fields-8 20000000 177 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_2_elements_with_5_extra_fields-8 20000000 204 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_3_elements_with_5_extra_fields-8 20000000 233 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_4_elements_with_5_extra_fields-8 20000000 266 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_5_elements_with_5_extra_fields-8 20000000 295 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_0_elements_with_10_extra_fields-8 50000000 77.4 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_1_elements_with_10_extra_fields-8 30000000 163 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_2_elements_with_10_extra_fields-8 20000000 198 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_3_elements_with_10_extra_fields-8 20000000 237 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_4_elements_with_10_extra_fields-8 20000000 298 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_5_elements_with_10_extra_fields-8 20000000 331 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_0_elements_with_15_extra_fields-8 50000000 79.5 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_1_elements_with_15_extra_fields-8 30000000 171 ns/op 48 B/op 1 allocs/op 
BenchmarkEntityFormatter_mapping/with_2_elements_with_15_extra_fields-8 20000000 212 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_3_elements_with_15_extra_fields-8 20000000 265 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_4_elements_with_15_extra_fields-8 20000000 295 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_5_elements_with_15_extra_fields-8 20000000 340 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_0_elements_with_20_extra_fields-8 50000000 77.5 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_1_elements_with_20_extra_fields-8 30000000 163 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_2_elements_with_20_extra_fields-8 20000000 199 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_3_elements_with_20_extra_fields-8 20000000 237 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_4_elements_with_20_extra_fields-8 20000000 287 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_5_elements_with_20_extra_fields-8 20000000 320 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_0_elements_with_25_extra_fields-8 50000000 83.2 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_1_elements_with_25_extra_fields-8 30000000 181 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_2_elements_with_25_extra_fields-8 20000000 222 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_3_elements_with_25_extra_fields-8 20000000 275 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_4_elements_with_25_extra_fields-8 20000000 292 ns/op 48 B/op 1 allocs/op BenchmarkEntityFormatter_mapping/with_5_elements_with_25_extra_fields-8 20000000 339 ns/op 48 B/op 1 allocs/op # Request generator BenchmarkRequestGeneratePath//a-8 10000000 460 ns/op 96 B/op 10 allocs/op BenchmarkRequestGeneratePath//a/{{.Supu}}-8 10000000 522 ns/op 106 B/op 10 allocs/op 
BenchmarkRequestGeneratePath//a?b={{.Tupu}}-8 10000000 567 ns/op 136 B/op 10 allocs/op BenchmarkRequestGeneratePath//a/{{.Supu}}/foo/{{.Foo}}-8 10000000 615 ns/op 182 B/op 10 allocs/op BenchmarkRequestGeneratePath//a/{{.Supu}}/foo/{{.Foo}}/b?c={{.Tupu}}-8 10000000 655 ns/op 236 B/op 10 allocs/op # Router Handlers ## Gin BenchmarkEndpointHandler_ko-8 1000000 5440 ns/op 3026 B/op 31 allocs/op BenchmarkEndpointHandler_ok-8 1000000 6456 ns/op 3393 B/op 36 allocs/op BenchmarkEndpointHandler_ko_Parallel-8 5000000 1534 ns/op 3028 B/op 31 allocs/op BenchmarkEndpointHandler_ok_Parallel-8 5000000 1846 ns/op 3393 B/op 36 allocs/op ## Mux BenchmarkEndpointHandler_ko-8 5000000 1815 ns/op 1088 B/op 13 allocs/op BenchmarkEndpointHandler_ok-8 5000000 1693 ns/op 1088 B/op 13 allocs/op BenchmarkEndpointHandler_ko_Parallel-8 20000000 558 ns/op 1088 B/op 13 allocs/op BenchmarkEndpointHandler_ok_Parallel-8 20000000 597 ns/op 1088 B/op 13 allocs/op ================================================ FILE: docs/CONFIG.md ================================================ # Configuration file The configuration file needs to be a `json` file. The viper parser supports other formats but they haven't been as tested as the recommended one. 
## Json example { "version": 3, "name": "My lovely gateway", "port": 8080, "timeout": "10s", "cache_ttl": "3600s", "host": [ "http://127.0.0.1:8080", "http://127.0.0.2:8000", "http://127.0.0.3:9000", "http://127.0.0.4" ], "endpoints": [{ "endpoint": "/users/{user}", "method": "GET", "backend": [{ "host": [ "http://127.0.0.3:9000", "http://127.0.0.4" ], "url_pattern": "/registered/{user}", "allow": [ "some", "what" ], "mapping": { "email": "personal_email" } }, { "host": [ "http://127.0.0.1:8080" ], "url_pattern": "/users/{user}/permissions", "deny": [ "spam2", "notwanted2" ] } ], "concurrent_calls": 2, "timeout": "1000s", "cache_ttl": 3600, "input_query_strings": [ "page", "limit" ] }, { "endpoint": "/foo/bar", "method": "POST", "backend": [{ "host": [ "https://127.0.0.1:8081" ], "url_pattern": "/__debug/tupu" }], "concurrent_calls": 1, "timeout": "1000s", "cache_ttl": 3600 }, { "endpoint": "/github", "method": "GET", "backend": [{ "host": [ "https://api.github.com" ], "url_pattern": "/", "allow": [ "authorizations_url", "code_search_url" ] }], "concurrent_calls": 2, "timeout": "1000s", "cache_ttl": 3600 }, { "endpoint": "/combination/{id}/{supu}", "method": "GET", "backend": [{ "group": "first_post", "host": [ "https://jsonplaceholder.typicode.com" ], "url_pattern": "/posts/{id}?supu={supu}", "deny": [ "userId" ] }, { "host": [ "https://jsonplaceholder.typicode.com" ], "url_pattern": "/users/{id}", "mapping": { "email": "personal_email" } } ], "concurrent_calls": 3, "timeout": "1000s", "input_query_strings": [ "page", "limit" ] } ]} ================================================ FILE: docs/OVERVIEW.md ================================================ # Overview ## The Lura rules * [Reactive is key](http://www.reactivemanifesto.org/) * Reactive is key (yes, it is very very important) * Failing fast is better than succeeding slow (say it one more time!) 
* The simpler, the better * Everything is pluggable * Each request must be processed in its own request-scoped context ## The big picture The Lura framework is composed of a set of packages designed as building blocks for creating pipes and processors between an exposed endpoint and one or several API resources served by your backends. The most important packages are: 1. the `config` package defines the service. 2. the `router` package sets up the endpoints exposed to the clients. 3. the `proxy` package adds the required middlewares and components for further processing of the requests to send and the received responses sent by the backends, and also to manage the connections against those backends. The rest of the packages of the framework contain some helpers and adapters for complementary tasks, like encoding, logging or service discovery. ## The `config` package The `config` package contains the structs required for the service description. The `ServiceConfig` struct defines the entire service. It should be initialized before using it in order to be sure that all parameters have been normalized and default values have been applied. The `config` package also defines an interface for a file config parser and a parser based on the [viper](https://github.com/spf13/viper) library. ## The `router` package The `router` package contains an interface and several implementations for the Lura router layer using the `mux` router from the `net/http` package and the `httprouter` wrapped in the `gin` framework. The router layer is responsible for setting up the HTTP(S) services, binding the endpoints defined in the `ServiceConfig` struct and transforming the HTTP request into proxy requests before delegating the task to the inner layer (proxy). Once the internal proxy layer returns a proxy response, the router layer converts it into a proper HTTP response and sends it to the user.
This layer can be easily extended in order to use any HTTP router, framework or middleware of your choice. Adding transport layer adapters for other protocols (Thrift, gRPC, AMQP, NATS, etc.) is on the roadmap. As always, PRs are welcome! ## The `proxy` package The `proxy` package is where most of the Lura components and features are placed. It defines two important interfaces, designed to be stacked: * *Proxy* is a function that converts a given context and request into a response. * *Middleware* is a function that accepts one or more proxies and returns a single proxy wrapping them. This layer transforms the request received from the outer layer (router) into one or several requests to your backend services, processes the responses and returns a single response. Middlewares generate custom proxies that are chained depending on the workflow defined in the configuration until each possible branch ends in a transport-related proxy. Every one of these generated proxies is able to transform the input or even clone it several times and pass it or them to the next element in the chain. Finally, they can also modify the received response or responses, adding all kinds of features to the generated pipe. The Lura framework provides a default implementation of the proxy stack factory. ### Middlewares available * The `balancing` middleware uses some type of strategy for selecting a backend host to query. * The `concurrent` middleware improves the QoS by sending several concurrent requests to the next step of the chain and returning the first successful response, using a timeout for canceling the generated workload. * The `logging` middleware logs the received request and response and also the duration of the segment execution. * The `merging` middleware is a fork-and-join middleware.
It is intended to split the processing of the request into several concurrent processes, each one against a different backend, and to merge all the received responses from those created pipes into a single one. It applies a timeout, as the `concurrent` one does. * The `http` middleware completes the received proxy request by replacing the parameters extracted from the user request in the defined `URLPattern`. ### Proxies available * The `http` proxy translates a proxy request into an HTTP one, sends it to the backend API using an `HTTPClientFactory`, decodes the returned HTTP response with a `Decoder`, manipulates the response data with an `EntityFormatter` and returns it to the caller. ### Other components of the `proxy` package The `proxy` package also defines the `EntityFormatter`, the block responsible for enabling powerful and fast response manipulation. ================================================ FILE: docs/README.md ================================================ # The Lura Project ## How to use it Visit the [framework overview](/docs/OVERVIEW.md) for details about the components of the Lura project. A good example of how to use it can be found in the [KrakenD CE](https://github.com/krakend/krakend-ce) API Gateway project. ## Configuration file [Lura config file](/docs/CONFIG.md). ## Benchmarks Check out the [benchmark results](/docs/BENCHMARKS.md) of several Lura components. ## Contributing Read the guidelines about [contributing](../CONTRIBUTING.md). ================================================ FILE: encoding/encoding.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package encoding provides basic decoding implementations. Decode decodes HTTP responses: resp, _ := http.Get("http://api.example.com/") ...
var data map[string]interface{} err := JSONDecoder(resp.Body, &data) */ package encoding import ( "encoding/json" "io" ) // Decoder is a function that reads from the reader and decodes it // into an map of interfaces type Decoder func(io.Reader, *map[string]interface{}) error // DecoderFactory is a function that returns CollectionDecoder or an EntityDecoder type DecoderFactory func(bool) func(io.Reader, *map[string]interface{}) error // NOOP is the key for the NoOp encoding const NOOP = "no-op" // NoOpDecoder is a decoder that does nothing func NoOpDecoder(_ io.Reader, _ *map[string]interface{}) error { return nil } func noOpDecoderFactory(_ bool) func(io.Reader, *map[string]interface{}) error { return NoOpDecoder } // JSON is the key for the json encoding const JSON = "json" // NewJSONDecoder returns the right JSON decoder func NewJSONDecoder(isCollection bool) func(io.Reader, *map[string]interface{}) error { if isCollection { return JSONCollectionDecoder } return JSONDecoder } // JSONDecoder decodes a json message into a map func JSONDecoder(r io.Reader, v *map[string]interface{}) error { d := json.NewDecoder(r) d.UseNumber() return d.Decode(v) } // JSONCollectionDecoder decodes a json collection and returns a map with the array at the 'collection' key func JSONCollectionDecoder(r io.Reader, v *map[string]interface{}) error { var collection []interface{} d := json.NewDecoder(r) d.UseNumber() if err := d.Decode(&collection); err != nil { return err } *(v) = map[string]interface{}{"collection": collection} return nil } // SAFE_JSON is the key for the json encoding const SAFE_JSON = "safejson" // NewSafeJSONDecoder returns the universal json decoder func NewSafeJSONDecoder(_ bool) func(io.Reader, *map[string]interface{}) error { return SafeJSONDecoder } // SafeJSONDecoder decodes both json objects and collections func SafeJSONDecoder(r io.Reader, v *map[string]interface{}) error { d := json.NewDecoder(r) d.UseNumber() var t interface{} if err := d.Decode(&t); err != 
nil { return err } switch tt := t.(type) { case map[string]interface{}: *v = tt case []interface{}: *v = map[string]interface{}{"collection": tt} default: *v = map[string]interface{}{"content": tt} } return nil } // STRING is the key for the string encoding const STRING = "string" // NewStringDecoder returns a String decoder func NewStringDecoder(_ bool) func(io.Reader, *map[string]interface{}) error { return StringDecoder } // StringDecoder returns a map with the content of the reader under the key 'content' func StringDecoder(r io.Reader, v *map[string]interface{}) error { data, err := io.ReadAll(r) if err != nil { return err } *(v) = map[string]interface{}{"content": string(data)} return nil } ================================================ FILE: encoding/encoding_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package encoding import ( "errors" "io" "strings" "testing" "github.com/luraproject/lura/v2/register" ) func TestNoOpDecoder(t *testing.T) { decoders = initDecoderRegister() defer func() { decoders = initDecoderRegister() }() d := decoders.Get(NOOP)(false) errorMsg := erroredReader("this error should never been sent") var result map[string]interface{} if err := d(errorMsg, &result); err != nil { t.Error("Unexpected error:", err.Error()) } if result != nil { t.Error("Unexpected value:", result) } } func TestRegister(t *testing.T) { decoders = initDecoderRegister() defer func() { decoders = initDecoderRegister() }() original := GetRegister() if len(original.data.Clone()) != 4 { t.Error("Unexpected number of registered factories:", len(original.data.Clone())) } decoders = &DecoderRegister{data: register.NewUntyped()} decoders.Register("some", NewJSONDecoder) if len(decoders.data.Clone()) != 1 { t.Error("Unexpected number of registered factories:", len(decoders.data.Clone())) } } func TestGet(t *testing.T) { decoders = initDecoderRegister() defer func() { decoders = initDecoderRegister() }() if 
len(decoders.data.Clone()) != 4 { t.Error("Unexpected number of registered factories:", len(decoders.data.Clone())) } checkDecoder(t, JSON) checkDecoder(t, "some") decoders = &DecoderRegister{data: register.NewUntyped()} decoders.Register("some", NewJSONDecoder) if len(decoders.data.Clone()) != 1 { t.Error("Unexpected number of registered factories:", len(decoders.data.Clone())) } checkDecoder(t, JSON) checkDecoder(t, "some") } func TestRegister_complete_ok(t *testing.T) { decoders = initDecoderRegister() defer func() { decoders = initDecoderRegister() }() expectedMsg := "a custom message to decode" expectedResponse := map[string]interface{}{"a": 42} if err := decoders.Register("custom", func(_ bool) func(io.Reader, *map[string]interface{}) error { return func(r io.Reader, v *map[string]interface{}) error { d, err := io.ReadAll(r) if err != nil { t.Error(err) return err } if expectedMsg != string(d) { t.Errorf("unexpected msg: %s", string(d)) return errors.New("unexpected msg to decode") } *v = expectedResponse return nil } }); err != nil { t.Error(err) return } decoder := decoders.Get("custom")(false) input := strings.NewReader(expectedMsg) var result map[string]interface{} if err := decoder(input, &result); err != nil { t.Error("Unexpected error:", err.Error()) } if v, ok := result["a"]; !ok || v.(int) != 42 { t.Error("Unexpected value:", result) } } func TestRegister_complete_ko(t *testing.T) { decoders = initDecoderRegister() defer func() { decoders = initDecoderRegister() }() expectedMsg := "a custom message to decode" expectedErr := errors.New("expect me") if err := decoders.Register("custom", func(_ bool) func(io.Reader, *map[string]interface{}) error { return func(r io.Reader, v *map[string]interface{}) error { d, err := io.ReadAll(r) if err != nil { t.Error(err) return err } if expectedMsg != string(d) { t.Errorf("unexpected msg: %s", string(d)) return errors.New("unexpected msg to decode") } // v = nil return expectedErr } }); err != nil { t.Error(err) 
return } decoder := decoders.Get("custom")(false) input := strings.NewReader(expectedMsg) var result map[string]interface{} if err := decoder(input, &result); err != expectedErr { t.Error("Unexpected error:", err) } if result != nil { t.Error("Unexpected value:", result) } } func checkDecoder(t *testing.T, name string) { d := decoders.Get(name)(false) input := strings.NewReader(`{"foo": "bar"}`) var result map[string]interface{} if err := d(input, &result); err != nil { t.Error("Unexpected error:", err.Error()) } if result["foo"] != "bar" { t.Error("Unexpected value:", result["foo"]) } } type erroredReader string func (e erroredReader) Error() string { return string(e) } func (e erroredReader) Read(_ []byte) (n int, err error) { return 0, e } ================================================ FILE: encoding/json_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package encoding import ( "io" "strings" "testing" ) func BenchmarkDecoder(b *testing.B) { for _, dec := range []struct { name string decoder func(io.Reader, *map[string]interface{}) error }{ { name: "json-collection", decoder: NewJSONDecoder(true), }, { name: "json-map", decoder: NewJSONDecoder(false), }, { name: "safe-json-collection", decoder: NewSafeJSONDecoder(true), }, { name: "safe-json-map", decoder: NewSafeJSONDecoder(true), }, } { for _, tc := range []struct { name string input string }{ { name: "collection", input: `["a","b","c"]`, }, { name: "map", input: `{"foo": "bar", "supu": false, "tupu": 4.20}`, }, } { b.Run(dec.name+"/"+tc.name, func(b *testing.B) { var result map[string]interface{} for i := 0; i < b.N; i++ { _ = dec.decoder(strings.NewReader(tc.input), &result) } }) } } } ================================================ FILE: encoding/json_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package encoding import ( "encoding/json" "fmt" "strings" "testing" ) func ExampleNewJSONDecoder_map() 
{ decoder := NewJSONDecoder(false) original := strings.NewReader(`{"foo": "bar", "supu": false, "tupu": 4.20}`) var result map[string]interface{} if err := decoder(original, &result); err != nil { fmt.Println("Unexpected error:", err.Error()) } fmt.Printf("%+v\n", result) // output: // map[foo:bar supu:false tupu:4.20] } func ExampleNewJSONDecoder_collection() { decoder := NewJSONDecoder(true) original := strings.NewReader(`["foo", "bar", "supu"]`) var result map[string]interface{} if err := decoder(original, &result); err != nil { fmt.Println("Unexpected error:", err.Error()) } fmt.Printf("%+v\n", result) // output: // map[collection:[foo bar supu]] } func TestNewJSONDecoder_map(t *testing.T) { decoder := NewJSONDecoder(false) original := strings.NewReader(`{"foo": "bar", "supu": false, "tupu": 4.20}`) var result map[string]interface{} if err := decoder(original, &result); err != nil { t.Error("Unexpected error:", err.Error()) } if len(result) != 3 { t.Error("Unexpected result:", result) } if v, ok := result["foo"]; !ok || v.(string) != "bar" { t.Error("wrong result:", result) } if v, ok := result["supu"]; !ok || v.(bool) { t.Error("wrong result:", result) } if v, ok := result["tupu"]; !ok || v.(json.Number).String() != "4.20" { t.Error("wrong result:", result) } } func TestNewJSONDecoder_collection(t *testing.T) { decoder := NewJSONDecoder(true) original := strings.NewReader(`["foo", "bar", "supu"]`) var result map[string]interface{} if err := decoder(original, &result); err != nil { t.Error("Unexpected error:", err.Error()) } if len(result) != 1 { t.Error("Unexpected result:", result) } v, ok := result["collection"] if !ok { t.Error("wrong result:", result) } embedded := v.([]interface{}) if embedded[0].(string) != "foo" { t.Error("wrong result:", result) } if embedded[1].(string) != "bar" { t.Error("wrong result:", result) } if embedded[2].(string) != "supu" { t.Error("wrong result:", result) } } func TestNewJSONDecoder_ko(t *testing.T) { decoder := 
NewJSONDecoder(true) original := strings.NewReader(`3`) var result map[string]interface{} if err := decoder(original, &result); err == nil { t.Error("Expecting error!") } } func ExampleNewSafeJSONDecoder() { decoder := NewSafeJSONDecoder(true) for _, body := range []string{ `{"foo": "bar", "supu": false, "tupu": 4.20}`, `["foo", "bar", "supu"]`, } { var result map[string]interface{} if err := decoder(strings.NewReader(body), &result); err != nil { fmt.Println("Unexpected error:", err.Error()) } fmt.Printf("%+v\n", result) } // output: // map[foo:bar supu:false tupu:4.20] // map[collection:[foo bar supu]] } func TestNewSafeJSONDecoder_map(t *testing.T) { decoder := NewSafeJSONDecoder(false) original := strings.NewReader(`{"foo": "bar", "supu": false, "tupu": 4.20}`) var result map[string]interface{} if err := decoder(original, &result); err != nil { t.Error("Unexpected error:", err.Error()) } if len(result) != 3 { t.Error("Unexpected result:", result) } if v, ok := result["foo"]; !ok || v.(string) != "bar" { t.Error("wrong result:", result) } if v, ok := result["supu"]; !ok || v.(bool) { t.Error("wrong result:", result) } if v, ok := result["tupu"]; !ok || v.(json.Number).String() != "4.20" { t.Error("wrong result:", result) } } func TestNewSafeJSONDecoder_collection(t *testing.T) { decoder := NewSafeJSONDecoder(true) original := strings.NewReader(`["foo", "bar", "supu"]`) var result map[string]interface{} if err := decoder(original, &result); err != nil { t.Error("Unexpected error:", err.Error()) } if len(result) != 1 { t.Error("Unexpected result:", result) } v, ok := result["collection"] if !ok { t.Error("wrong result:", result) } embedded := v.([]interface{}) if embedded[0].(string) != "foo" { t.Error("wrong result:", result) } if embedded[1].(string) != "bar" { t.Error("wrong result:", result) } if embedded[2].(string) != "supu" { t.Error("wrong result:", result) } } func TestNewSafeJSONDecoder_other(t *testing.T) { decoder := NewSafeJSONDecoder(true) original 
:= strings.NewReader(`3`) var result map[string]interface{} if err := decoder(original, &result); err != nil { t.Error("Unexpected error:", err.Error()) } if v, ok := result["content"]; !ok || v.(json.Number).String() != "3" { t.Error("wrong result:", result) } } ================================================ FILE: encoding/register.go ================================================ // SPDX-License-Identifier: Apache-2.0 package encoding import ( "io" "github.com/luraproject/lura/v2/register" ) // GetRegister returns the package register func GetRegister() *DecoderRegister { return decoders } type untypedRegister interface { Register(name string, v interface{}) Get(name string) (interface{}, bool) Clone() map[string]interface{} } // DecoderRegister is the struct responsible of registering the decoder factories type DecoderRegister struct { data untypedRegister } // Register adds a decoder factory to the register func (r *DecoderRegister) Register(name string, dec func(bool) func(io.Reader, *map[string]interface{}) error) error { r.data.Register(name, dec) return nil } // Get returns a decoder factory from the register by name. 
If no factory is found, it returns a JSON decoder factory func (r *DecoderRegister) Get(name string) func(bool) func(io.Reader, *map[string]interface{}) error { for _, n := range []string{name, JSON} { if v, ok := r.data.Get(n); ok { if dec, ok := v.(func(bool) func(io.Reader, *map[string]interface{}) error); ok { return dec } } } return NewJSONDecoder } var ( decoders = initDecoderRegister() defaultDecoders = map[string]func(bool) func(io.Reader, *map[string]interface{}) error{ JSON: NewJSONDecoder, SAFE_JSON: NewSafeJSONDecoder, STRING: NewStringDecoder, NOOP: noOpDecoderFactory, } ) func initDecoderRegister() *DecoderRegister { r := &DecoderRegister{data: register.NewUntyped()} for k, v := range defaultDecoders { r.Register(k, v) } return r } ================================================ FILE: go.mod ================================================ module github.com/luraproject/lura/v2 go 1.25.0 require ( github.com/dimfeld/httptreemux/v5 v5.5.0 github.com/gin-contrib/sse v1.1.0 // indirect github.com/gin-gonic/gin v1.12.0 github.com/go-chi/chi/v5 v5.2.2 github.com/gorilla/mux v1.8.1 github.com/mattn/go-isatty v0.0.20 // indirect github.com/urfave/negroni/v2 v2.0.2 github.com/valyala/fastrand v1.1.0 ) require ( github.com/krakend/flatmap v1.2.0 golang.org/x/net v0.51.0 golang.org/x/sync v0.19.0 golang.org/x/text v0.34.0 ) require ( github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/sonic v1.15.0 // indirect github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/gabriel-vasile/mimetype v1.4.12 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.30.1 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-yaml v1.19.2 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect 
github.com/leodido/go-urn v1.4.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.59.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.1 // indirect go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect golang.org/x/arch v0.22.0 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/sys v0.41.0 // indirect google.golang.org/protobuf v1.36.10 // indirect ) ================================================ FILE: go.sum ================================================ github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dimfeld/httptreemux/v5 v5.5.0 h1:p8jkiMrCuZ0CmhwYLcbNbl7DDo21fozhKHQ2PccwOFQ= github.com/dimfeld/httptreemux/v5 v5.5.0/go.mod h1:QeEylH57C0v3VO0tkKraVz9oD3Uu93CKPnTLbsidvSw= github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw= 
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s= github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w= github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8= github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc= github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= github.com/go-chi/chi/v5 v5.2.2/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w= github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM= github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/krakend/flatmap v1.2.0 h1:4NPncAKH7Ca/t878kbGlc/LPWLa+m4sgBhs8aT2Q1SY= github.com/krakend/flatmap v1.2.0/go.mod h1:FyCOoggdVlWr31+aQaOFvBxlMgYfCE5yuwInLbW1/jM= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.6.0 
h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY= github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= github.com/urfave/negroni/v2 v2.0.2 h1:27gJcVxYJ2a/ytEoCHoJ7ybvyhymV4cAhGuMxkyCsrU= github.com/urfave/negroni/v2 v2.0.2/go.mod h1:SjdApKzYrObukpN/NnlejbQiZWIUjfDFzQltScGYigI= github.com/valyala/fastrand v1.1.0 h1:f+5HkLW4rsgzdNoleUOB69hyT9IlD2ZQh9GyDMfb5G8= github.com/valyala/fastrand v1.1.0/go.mod h1:HWqCzkrkg6QXT8V2EXWvXCoow7vLwOFN002oeRzjapQ= 
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE= go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= golang.org/x/arch v0.22.0 h1:c/Zle32i5ttqRXjdLyyHZESLD/bB90DCU1g9l/0YBDI= golang.org/x/arch v0.22.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= ================================================ FILE: logging/log.go ================================================ 
// SPDX-License-Identifier: Apache-2.0 /* Package logging provides a simple logger interface and implementations */ package logging import ( "errors" "io" "log" "os" "strings" ) // Logger collects logging information at several levels type Logger interface { Debug(v ...interface{}) Info(v ...interface{}) Warning(v ...interface{}) Error(v ...interface{}) Critical(v ...interface{}) Fatal(v ...interface{}) } const ( // LEVEL_DEBUG = 0 LEVEL_DEBUG = iota // LEVEL_INFO = 1 LEVEL_INFO // LEVEL_WARNING = 2 LEVEL_WARNING // LEVEL_ERROR = 3 LEVEL_ERROR // LEVEL_CRITICAL = 4 LEVEL_CRITICAL ) var ( // ErrInvalidLogLevel is returned by NewLogger when the requested log level is unknown. ErrInvalidLogLevel = errors.New("invalid log level") defaultLogger = BasicLogger{Level: LEVEL_CRITICAL, Logger: log.New(os.Stderr, "", log.LstdFlags)} logLevels = map[string]int{ "DEBUG": LEVEL_DEBUG, "INFO": LEVEL_INFO, "WARNING": LEVEL_WARNING, "ERROR": LEVEL_ERROR, "CRITICAL": LEVEL_CRITICAL, } // NoOp is a logger that writes to io.Discard. Note that its Fatal method still calls os.Exit(1). NoOp, _ = NewLogger("CRITICAL", io.Discard, "") ) // NewLogger creates a BasicLogger that writes to out, tagging every line with prefix. // The level name is case-insensitive; an unknown level returns the default stderr logger (CRITICAL level) together with ErrInvalidLogLevel. func NewLogger(level string, out io.Writer, prefix string) (BasicLogger, error) { l, ok := logLevels[strings.ToUpper(level)] if !ok { return defaultLogger, ErrInvalidLogLevel } return BasicLogger{Level: l, Prefix: prefix, Logger: log.New(out, "", log.LstdFlags)}, nil } // BasicLogger is a Logger implementation backed by a standard library *log.Logger. // Messages below Level are discarded; Critical and Fatal are always written. type BasicLogger struct { Level int Prefix string Logger *log.Logger } // Debug logs a message using DEBUG as log level. func (l BasicLogger) Debug(v ...interface{}) { if l.Level > LEVEL_DEBUG { return } l.prependLog("DEBUG:", v...) } // Info logs a message using INFO as log level. func (l BasicLogger) Info(v ...interface{}) { if l.Level > LEVEL_INFO { return } l.prependLog("INFO:", v...) } // Warning logs a message using WARNING as log level. func (l BasicLogger) Warning(v ...interface{}) { if l.Level > LEVEL_WARNING { return } l.prependLog("WARNING:", v...) } // Error logs a message using ERROR as log level.
func (l BasicLogger) Error(v ...interface{}) { if l.Level > LEVEL_ERROR { return } l.prependLog("ERROR:", v...) } // Critical logs a message using CRITICAL as log level. It is written regardless of the configured Level. func (l BasicLogger) Critical(v ...interface{}) { l.prependLog("CRITICAL:", v...) } // Fatal logs a message using FATAL as log level, regardless of the configured Level, and then calls os.Exit(1). func (l BasicLogger) Fatal(v ...interface{}) { l.prependLog("FATAL:", v...) os.Exit(1) } // prependLog writes Prefix, the level tag and v as a single line. // Println always separates operands with spaces, so an empty Prefix still produces a leading space. func (l BasicLogger) prependLog(level string, v ...interface{}) { msg := make([]interface{}, len(v)+2) msg[0] = l.Prefix msg[1] = level copy(msg[2:], v) l.Logger.Println(msg...) } ================================================ FILE: logging/log_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package logging import ( "bytes" "os" "os/exec" "regexp" "testing" ) const ( debugMsg = "Debug msg" infoMsg = "Info msg" warningMsg = "Warning msg" errorMsg = "Error msg" criticalMsg = "Critical msg" fatalMsg = "Fatal msg" ) func TestNewLogger(t *testing.T) { levels := []string{"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"} regexps := []*regexp.Regexp{ regexp.MustCompile(debugMsg), regexp.MustCompile(infoMsg), regexp.MustCompile(warningMsg), regexp.MustCompile(errorMsg), regexp.MustCompile(criticalMsg), } for i, level := range levels { output := logSomeStuff(level) for j := 0; j < i; j++ { if regexps[j].MatchString(output) { t.Errorf("The output contains an unexpected msg for the level: %s. [%s]", level, output) } } for j := i; j < len(regexps); j++ { if !regexps[j].MatchString(output) { t.Errorf("The output doesn't contain the expected msg for the level: %s. [%s]", level, output) } } } } func TestNewLogger_unknownLevel(t *testing.T) { _, err := NewLogger("UNKNOWN", bytes.NewBuffer(make([]byte, 0, 1024)), "pref") if err == nil { t.Error("The factory didn't return the expected error") return } if err != ErrInvalidLogLevel { t.Errorf("The factory didn't return the expected error.
Got: %s", err.Error()) } } func TestNewLogger_fatal(t *testing.T) { if os.Getenv("BE_CRASHER") == "1" { l, err := NewLogger("Critical", bytes.NewBuffer(make([]byte, 0, 1024)), "pref") if err != nil { t.Error("The factory returned an unexpected error:", err.Error()) return } l.Fatal("crash!!!") return } cmd := exec.Command(os.Args[0], "-test.run=TestNewLogger_fatal") cmd.Env = append(os.Environ(), "BE_CRASHER=1") err := cmd.Run() if e, ok := err.(*exec.ExitError); ok && !e.Success() { return } t.Fatalf("process ran with err %v, want exit status 1", err) } // logSomeStuff logs one message per level through a logger configured with the given level and returns the captured output. func logSomeStuff(level string) string { buff := bytes.NewBuffer(make([]byte, 0, 1024)) logger, _ := NewLogger(level, buff, "pref") logger.Debug(debugMsg) logger.Info(infoMsg) logger.Warning(warningMsg) logger.Error(errorMsg) logger.Critical(criticalMsg) return buff.String() } ================================================ FILE: plugin/plugin.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package plugin provides tools for loading and registering plugins */ package plugin import ( "os" "path/filepath" "strings" ) // Scan returns the paths (joined with folder) of the non-directory entries directly inside folder whose name contains pattern as a plain substring. // Subdirectories are not traversed. If the folder cannot be read, it returns an empty slice and the error. func Scan(folder, pattern string) ([]string, error) { files, err := os.ReadDir(folder) if err != nil { return []string{}, err } var plugins []string for _, file := range files { if !file.IsDir() && strings.Contains(file.Name(), pattern) { plugins = append(plugins, filepath.Join(folder, file.Name())) } } return plugins, nil } ================================================ FILE: plugin/plugin_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package plugin import ( "os" "testing" ) func TestScan_ok(t *testing.T) { tmpDir, err := os.MkdirTemp(".", "test") if err != nil { t.Error("unexpected error:", err.Error()) return } defer os.RemoveAll(tmpDir) f, err := os.CreateTemp(tmpDir, "test.so") if err != nil {
t.Error("unexpected error:", err.Error()) return } f.Close() defer os.RemoveAll(tmpDir) tot, err := Scan(tmpDir, ".so") if len(tot) != 1 { t.Error("unexpected number of plugins found:", tot) } if err != nil { t.Error("unexpected error:", err.Error()) } } func TestScan_noFolder(t *testing.T) { expectedErr := "open unknown: no such file or directory" tot, err := Scan("unknown", "") if len(tot) != 0 { t.Error("unexpected number of plugins loaded:", tot) } if err == nil { t.Error("expecting error!") return } if err.Error() != expectedErr { t.Error("unexpected error:", err.Error()) } } func TestScan_emptyFolder(t *testing.T) { name, err := os.MkdirTemp(".", "test") if err != nil { t.Error("unexpected error:", err.Error()) return } tot, err := Scan(name, "") if len(tot) != 0 { t.Error("unexpected number of plugins loaded:", tot) } if err != nil { t.Error("unexpected error:", err.Error()) } os.RemoveAll(name) } func TestScan_noMatches(t *testing.T) { tmpDir, err := os.MkdirTemp(".", "test") if err != nil { t.Error("unexpected error:", err.Error()) return } defer os.RemoveAll(tmpDir) f, err := os.CreateTemp(tmpDir, "test") if err != nil { t.Error("unexpected error:", err.Error()) return } f.Close() defer os.RemoveAll(tmpDir) tot, err := Scan(tmpDir, ".so") if len(tot) != 0 { t.Error("unexpected number of plugins loaded:", tot) } if err != nil { t.Error("unexpected error:", err.Error()) } } ================================================ FILE: proxy/balancing.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "net/url" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/sd" ) // NewLoadBalancedMiddleware creates proxy middleware adding the most perfomant balancer // over a default subscriber func NewLoadBalancedMiddleware(remote *config.Backend) Middleware { return 
NewLoadBalancedMiddlewareWithSubscriber(sd.GetRegister().Get(remote.SD)(remote)) } // NewLoadBalancedMiddlewareWithSubscriber creates proxy middleware adding the most perfomant balancer // over the received subscriber func NewLoadBalancedMiddlewareWithSubscriber(subscriber sd.Subscriber) Middleware { return newLoadBalancedMiddleware(logging.NoOp, sd.NewBalancer(subscriber)) } // NewRoundRobinLoadBalancedMiddleware creates proxy middleware adding a round robin balancer // over a default subscriber func NewRoundRobinLoadBalancedMiddleware(remote *config.Backend) Middleware { return NewRoundRobinLoadBalancedMiddlewareWithSubscriber(sd.GetRegister().Get(remote.SD)(remote)) } // NewRandomLoadBalancedMiddleware creates proxy middleware adding a random balancer // over a default subscriber func NewRandomLoadBalancedMiddleware(remote *config.Backend) Middleware { return NewRandomLoadBalancedMiddlewareWithSubscriber(sd.GetRegister().Get(remote.SD)(remote)) } // NewRoundRobinLoadBalancedMiddlewareWithSubscriber creates proxy middleware adding a round robin // balancer over the received subscriber func NewRoundRobinLoadBalancedMiddlewareWithSubscriber(subscriber sd.Subscriber) Middleware { return newLoadBalancedMiddleware(logging.NoOp, sd.NewRoundRobinLB(subscriber)) } // NewRandomLoadBalancedMiddlewareWithSubscriber creates proxy middleware adding a random // balancer over the received subscriber func NewRandomLoadBalancedMiddlewareWithSubscriber(subscriber sd.Subscriber) Middleware { return newLoadBalancedMiddleware(logging.NoOp, sd.NewRandomLB(subscriber)) } // NewLoadBalancedMiddlewareWithLogger creates proxy middleware adding the most perfomant balancer // over a default subscriber func NewLoadBalancedMiddlewareWithLogger(l logging.Logger, remote *config.Backend) Middleware { return NewLoadBalancedMiddlewareWithSubscriberAndLogger(l, sd.GetRegister().Get(remote.SD)(remote)) } // NewLoadBalancedMiddlewareWithSubscriberAndLogger creates proxy middleware adding the most 
perfomant balancer // over the received subscriber func NewLoadBalancedMiddlewareWithSubscriberAndLogger(l logging.Logger, subscriber sd.Subscriber) Middleware { return newLoadBalancedMiddleware(l, sd.NewBalancer(subscriber)) } // NewRoundRobinLoadBalancedMiddlewareWithLogger creates proxy middleware adding a round robin balancer // over a default subscriber func NewRoundRobinLoadBalancedMiddlewareWithLogger(l logging.Logger, remote *config.Backend) Middleware { return NewRoundRobinLoadBalancedMiddlewareWithSubscriberAndLogger(l, sd.GetRegister().Get(remote.SD)(remote)) } // NewRandomLoadBalancedMiddlewareWithLogger creates proxy middleware adding a random balancer // over a default subscriber func NewRandomLoadBalancedMiddlewareWithLogger(l logging.Logger, remote *config.Backend) Middleware { return NewRandomLoadBalancedMiddlewareWithSubscriberAndLogger(l, sd.GetRegister().Get(remote.SD)(remote)) } // NewRoundRobinLoadBalancedMiddlewareWithSubscriberAndLogger creates proxy middleware adding a round robin // balancer over the received subscriber func NewRoundRobinLoadBalancedMiddlewareWithSubscriberAndLogger(l logging.Logger, subscriber sd.Subscriber) Middleware { return newLoadBalancedMiddleware(l, sd.NewRoundRobinLB(subscriber)) } // NewRandomLoadBalancedMiddlewareWithSubscriberAndLogger creates proxy middleware adding a random // balancer over the received subscriber func NewRandomLoadBalancedMiddlewareWithSubscriberAndLogger(l logging.Logger, subscriber sd.Subscriber) Middleware { return newLoadBalancedMiddleware(l, sd.NewRandomLB(subscriber)) } func newLoadBalancedMiddleware(l logging.Logger, lb sd.Balancer) Middleware { return func(next ...Proxy) Proxy { if len(next) > 1 { l.Fatal("too many proxies for this proxy middleware: newLoadBalancedMiddleware only accepts 1 proxy, got %d", len(next)) return nil } return func(ctx context.Context, r *Request) (*Response, error) { host, err := lb.Host() if err != nil { return nil, err } r.URL, err = url.Parse(host + 
r.Path) if err != nil { return nil, err } if len(r.Query) > 0 { if len(r.URL.RawQuery) > 0 { r.URL.RawQuery += "&" + r.Query.Encode() } else { r.URL.RawQuery += r.Query.Encode() } } return next[0](ctx, r) } } } ================================================ FILE: proxy/balancing_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "strconv" "testing" "github.com/luraproject/lura/v2/logging" ) const veryLargeString = "abcdefghijklmopqrstuvwxyzabcdefghijklmopqrstuvwxyzabcdefghijklmopqrstuvwxyzabcdefghijklmopqrstuvwxyz" func BenchmarkNewLoadBalancedMiddleware(b *testing.B) { for _, tc := range []int{3, 5, 9, 13, 17, 21, 25, 50, 100} { b.Run(strconv.Itoa(tc), func(b *testing.B) { proxy := newLoadBalancedMiddleware(logging.NoOp, dummyBalancer(veryLargeString[:tc]))(dummyProxy(&Response{})) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { proxy(context.Background(), &Request{ Path: veryLargeString[:tc], }) } }) } } func BenchmarkNewLoadBalancedMiddleware_parallel3(b *testing.B) { benchmarkNewLoadBalancedMiddleware_parallel(b, veryLargeString[:3]) } func BenchmarkNewLoadBalancedMiddleware_parallel5(b *testing.B) { benchmarkNewLoadBalancedMiddleware_parallel(b, veryLargeString[:5]) } func BenchmarkNewLoadBalancedMiddleware_parallel9(b *testing.B) { benchmarkNewLoadBalancedMiddleware_parallel(b, veryLargeString[:9]) } func BenchmarkNewLoadBalancedMiddleware_parallel13(b *testing.B) { benchmarkNewLoadBalancedMiddleware_parallel(b, veryLargeString[:13]) } func BenchmarkNewLoadBalancedMiddleware_parallel17(b *testing.B) { benchmarkNewLoadBalancedMiddleware_parallel(b, veryLargeString[:17]) } func BenchmarkNewLoadBalancedMiddleware_parallel21(b *testing.B) { benchmarkNewLoadBalancedMiddleware_parallel(b, veryLargeString[:21]) } func BenchmarkNewLoadBalancedMiddleware_parallel25(b *testing.B) { benchmarkNewLoadBalancedMiddleware_parallel(b, veryLargeString[:25]) } func 
BenchmarkNewLoadBalancedMiddleware_parallel50(b *testing.B) { benchmarkNewLoadBalancedMiddleware_parallel(b, veryLargeString[:50]) } func BenchmarkNewLoadBalancedMiddleware_parallel100(b *testing.B) { benchmarkNewLoadBalancedMiddleware_parallel(b, veryLargeString[:100]) } func benchmarkNewLoadBalancedMiddleware_parallel(b *testing.B, subject string) { b.RunParallel(func(pb *testing.PB) { proxy := newLoadBalancedMiddleware(logging.NoOp, dummyBalancer(subject))(dummyProxy(&Response{})) for pb.Next() { proxy(context.Background(), &Request{ Path: subject, }) } }) } ================================================ FILE: proxy/balancing_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "errors" "net" "net/url" "testing" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/sd/dnssrv" ) func TestNewLoadBalancedMiddleware_ok(t *testing.T) { want := "supu:8080/tupu" lb := newLoadBalancedMiddleware(logging.NoOp, dummyBalancer("supu:8080")) assertion := func(ctx context.Context, request *Request) (*Response, error) { if request.URL.String() != want { t.Errorf("The middleware did not update the request URL! 
want [%s], have [%s]\n", want, request.URL) } return nil, nil } if _, err := lb(assertion)(context.Background(), &Request{ Path: "/tupu", }); err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } } func TestNewLoadBalancedMiddleware_explosiveBalancer(t *testing.T) { expected := errors.New("supu") lb := newLoadBalancedMiddleware(logging.NoOp, explosiveBalancer{expected}) if _, err := lb(explosiveProxy(t))(context.Background(), &Request{}); err != expected { t.Errorf("The middleware did not propagate the lb error\n") } } func TestNewRoundRobinLoadBalancedMiddleware(t *testing.T) { testLoadBalancedMw(t, NewRoundRobinLoadBalancedMiddleware(&config.Backend{ Host: []string{"http://127.0.0.1:8080"}, })) } func TestNewRandomLoadBalancedMiddleware(t *testing.T) { testLoadBalancedMw(t, NewRandomLoadBalancedMiddleware(&config.Backend{ Host: []string{"http://127.0.0.1:8080"}, })) } func testLoadBalancedMw(t *testing.T, lb Middleware) { for _, tc := range []struct { path string query url.Values expected string }{ { path: "/tupu", expected: "http://127.0.0.1:8080/tupu", }, { path: "/tupu?extra=true", expected: "http://127.0.0.1:8080/tupu?extra=true", }, { path: "/tupu?extra=true", query: url.Values{"some": []string{"none"}}, expected: "http://127.0.0.1:8080/tupu?extra=true&some=none", }, { path: "/tupu", query: url.Values{"some": []string{"none"}}, expected: "http://127.0.0.1:8080/tupu?some=none", }, } { assertion := func(ctx context.Context, request *Request) (*Response, error) { if request.URL.String() != tc.expected { t.Errorf("The middleware did not update the request URL! 
want [%s], have [%s]\n", tc.expected, request.URL) } return nil, nil } if _, err := lb(assertion)(context.Background(), &Request{ Path: tc.path, Query: tc.query, }); err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } } } func TestNewLoadBalancedMiddleware_parsingError(t *testing.T) { lb := NewRandomLoadBalancedMiddleware(&config.Backend{ Host: []string{"127.0.0.1:8080"}, }) assertion := func(ctx context.Context, request *Request) (*Response, error) { t.Error("The middleware didn't block the request!") return nil, nil } if _, err := lb(assertion)(context.Background(), &Request{ Path: "/tupu", }); err == nil { t.Error("The middleware didn't propagate the expected error") } } func TestNewRoundRobinLoadBalancedMiddleware_DNSSRV(t *testing.T) { defaultLookup := dnssrv.DefaultLookup dnssrv.DefaultLookup = func(service, proto, name string) (cname string, addrs []*net.SRV, err error) { return "cname", []*net.SRV{ { Port: 8080, Target: "127.0.0.1", Weight: 1, }, }, nil } testLoadBalancedMw(t, NewRoundRobinLoadBalancedMiddlewareWithSubscriber(dnssrv.New("some.service.example.tld"))) dnssrv.DefaultLookup = defaultLookup } type dummyBalancer string func (d dummyBalancer) Host() (string, error) { return string(d), nil } type explosiveBalancer struct { Error error } func (e explosiveBalancer) Host() (string, error) { return "", e.Error } ================================================ FILE: proxy/concurrent.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "errors" "fmt" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) // NewConcurrentMiddlewareWithLogger creates a proxy middleware that enables sending several requests concurrently func NewConcurrentMiddlewareWithLogger(logger logging.Logger, remote *config.Backend) Middleware { if remote.ConcurrentCalls == 1 { logger.Fatal(fmt.Sprintf("too few concurrent calls for %s %s -> 
%s: NewConcurrentMiddleware expects more than 1 concurrent call, got %d", remote.ParentEndpointMethod, remote.ParentEndpoint, remote.URLPattern, remote.ConcurrentCalls)) return nil } serviceTimeout := time.Duration(75*remote.Timeout.Nanoseconds()/100) * time.Nanosecond return func(next ...Proxy) Proxy { if len(next) > 1 { logger.Fatal(fmt.Sprintf("too many proxies for this %s %s -> %s proxy middleware: NewConcurrentMiddleware only accepts 1 proxy, got %d", remote.ParentEndpointMethod, remote.ParentEndpoint, remote.URLPattern, len(next))) return nil } return func(ctx context.Context, request *Request) (*Response, error) { localCtx, cancel := context.WithTimeout(ctx, serviceTimeout) results := make(chan *Response, remote.ConcurrentCalls) failed := make(chan error, remote.ConcurrentCalls) for i := 0; i < remote.ConcurrentCalls; i++ { if i < remote.ConcurrentCalls-1 { go processConcurrentCall(localCtx, next[0], CloneRequest(request), results, failed) } else { go processConcurrentCall(localCtx, next[0], request, results, failed) } } var response *Response var err error for i := 0; i < remote.ConcurrentCalls; i++ { select { case response = <-results: if response != nil && response.IsComplete { cancel() return response, nil } case err = <-failed: case <-ctx.Done(): } } cancel() return response, err } } } // NewConcurrentMiddlewareWithLogger creates a proxy middleware that enables sending several requests concurrently. // Is recommended to use the version with a logger param. 
func NewConcurrentMiddleware(remote *config.Backend) Middleware { return NewConcurrentMiddlewareWithLogger(logging.NoOp, remote) } var errNullResult = errors.New("invalid response") func processConcurrentCall(ctx context.Context, next Proxy, request *Request, out chan<- *Response, failed chan<- error) { localCtx, cancel := context.WithCancel(ctx) result, err := next(localCtx, request) if err != nil { failed <- err cancel() return } if result == nil { failed <- errNullResult cancel() return } select { case out <- result: case <-ctx.Done(): failed <- ctx.Err() } cancel() } ================================================ FILE: proxy/concurrent_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "testing" "time" "github.com/luraproject/lura/v2/config" ) func BenchmarkNewConcurrentMiddleware_singleNext(b *testing.B) { backend := config.Backend{ ConcurrentCalls: 3, Timeout: time.Duration(100) * time.Millisecond, } proxy := NewConcurrentMiddleware(&backend)(dummyProxy(&Response{})) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { proxy(context.Background(), &Request{}) } } ================================================ FILE: proxy/concurrent_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "sync/atomic" "testing" "time" "github.com/luraproject/lura/v2/config" ) func TestNewConcurrentMiddleware_ok(t *testing.T) { timeout := 700 totalCalls := 3 backend := config.Backend{ ConcurrentCalls: totalCalls, Timeout: time.Duration(timeout) * time.Millisecond, } expected := Response{ Data: map[string]interface{}{"supu": 42, "tupu": true, "foo": "bar"}, IsComplete: true, } mw := NewConcurrentMiddleware(&backend) mustEnd := time.After(time.Duration(timeout) * time.Millisecond) result, err := mw(dummyProxy(&expected))(context.Background(), &Request{}) if err != nil { t.Errorf("The middleware propagated an 
unexpected error: %s\n", err.Error()) } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: } if result == nil { t.Errorf("The proxy returned a null result\n") return } if !result.IsComplete { t.Errorf("The proxy returned an incomplete result: %v\n", result) } if v, ok := result.Data["supu"]; !ok || v.(int) != 42 { t.Errorf("The proxy returned an unexpected result: %v\n", result) } if v, ok := result.Data["tupu"]; !ok || !v.(bool) { t.Errorf("The proxy returned an unexpected result: %v\n", result) } if v, ok := result.Data["foo"]; !ok || v.(string) != "bar" { t.Errorf("The proxy returned an unexpected result: %v\n", result) } } func TestNewConcurrentMiddleware_okAfterKo(t *testing.T) { timeout := 700 totalCalls := 3 backend := config.Backend{ ConcurrentCalls: totalCalls, Timeout: time.Duration(timeout) * time.Millisecond, } expected := Response{ Data: map[string]interface{}{"supu": 42, "tupu": true, "foo": "bar"}, IsComplete: true, } mw := NewConcurrentMiddleware(&backend) calls := uint64(0) mock := func(_ context.Context, _ *Request) (*Response, error) { total := atomic.AddUint64(&calls, 1) if total%2 == 0 { return &expected, nil } return nil, nil } mustEnd := time.After(time.Duration(timeout) * time.Millisecond) result, err := mw(mock)(context.Background(), &Request{}) if result == nil { t.Errorf("The proxy returned a null result\n") return } if err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: } if !result.IsComplete { t.Errorf("The proxy returned an incomplete result: %v\n", result) } if v, ok := result.Data["supu"]; !ok || v.(int) != 42 { t.Errorf("The proxy returned an unexpected result: %v\n", result) } if v, ok := result.Data["tupu"]; !ok || !v.(bool) { t.Errorf("The proxy returned an unexpected result: %v\n", result) } if v, ok := result.Data["foo"]; !ok || v.(string) != "bar" { 
t.Errorf("The proxy returned an unexpected result: %v\n", result) } } func TestNewConcurrentMiddleware_timeout(t *testing.T) { timeout := 100 totalCalls := 3 backend := config.Backend{ ConcurrentCalls: totalCalls, Timeout: time.Duration(timeout) * time.Millisecond, } mw := NewConcurrentMiddleware(&backend) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) response, err := mw(delayedProxy(t, time.Duration(5*timeout)*time.Millisecond, &Response{}))(context.Background(), &Request{}) if err == nil || err.Error() != "context deadline exceeded" { t.Errorf("The middleware didn't propagate a timeout error: %s\n", err) } if response != nil { t.Errorf("We weren't expecting a response but we got one: %v\n", response) return } select { case <-mustEnd: t.Errorf("We were expecting a response at this point in time!\n") return default: } } ================================================ FILE: proxy/factory.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/sd" ) // Factory creates proxies based on the received endpoint configuration. // // Both, factories and backend factories, create proxies but factories are designed as a stack makers // because they are intended to generate the complete proxy stack for a given frontend endpoint // the app would expose and they could wrap several proxies provided by a backend factory type Factory interface { New(cfg *config.EndpointConfig) (Proxy, error) } // FactoryFunc type is an adapter to allow the use of ordinary functions as proxy factories. // If f is a function with the appropriate signature, FactoryFunc(f) is a Factory that calls f. 
type FactoryFunc func(*config.EndpointConfig) (Proxy, error) // New implements the Factory interface func (f FactoryFunc) New(cfg *config.EndpointConfig) (Proxy, error) { return f(cfg) } // DefaultFactory returns a default http proxy factory with the injected logger func DefaultFactory(logger logging.Logger) Factory { return NewDefaultFactory(httpProxy, logger) } // DefaultFactoryWithSubscriber returns a default proxy factory with the injected logger and subscriber factory func DefaultFactoryWithSubscriber(logger logging.Logger, sF sd.SubscriberFactory) Factory { return NewDefaultFactoryWithSubscriber(httpProxy, logger, sF) } // NewDefaultFactory returns a default proxy factory with the injected proxy builder and logger func NewDefaultFactory(backendFactory BackendFactory, logger logging.Logger) Factory { sf := func(remote *config.Backend) sd.Subscriber { return sd.GetRegister().Get(remote.SD)(remote) } return NewDefaultFactoryWithSubscriber(backendFactory, logger, sf) } // NewDefaultFactoryWithSubscriber returns a default proxy factory with the injected proxy builder, // logger and subscriber factory func NewDefaultFactoryWithSubscriber(backendFactory BackendFactory, logger logging.Logger, sF sd.SubscriberFactory) Factory { return defaultFactory{backendFactory, logger, sF} } type defaultFactory struct { backendFactory BackendFactory logger logging.Logger subscriberFactory sd.SubscriberFactory } // New implements the Factory interface func (pf defaultFactory) New(cfg *config.EndpointConfig) (p Proxy, err error) { switch len(cfg.Backend) { case 0: err = ErrNoBackends case 1: p, err = pf.newSingle(cfg) default: p, err = pf.newMulti(cfg) } if err != nil { return } p = NewPluginMiddleware(pf.logger, cfg)(p) p = NewStaticMiddleware(pf.logger, cfg)(p) return } func (pf defaultFactory) newMulti(cfg *config.EndpointConfig) (p Proxy, err error) { backendProxy := make([]Proxy, len(cfg.Backend)) for i, backend := range cfg.Backend { backendProxy[i] = pf.newStack(backend) } p 
= NewMergeDataMiddleware(pf.logger, cfg)(backendProxy...) p = NewFlatmapMiddleware(pf.logger, cfg)(p) return } func (pf defaultFactory) newSingle(cfg *config.EndpointConfig) (Proxy, error) { return pf.newStack(cfg.Backend[0]), nil } func (pf defaultFactory) newStack(backend *config.Backend) (p Proxy) { p = pf.backendFactory(backend) p = NewBackendPluginMiddleware(pf.logger, backend)(p) p = NewGraphQLMiddleware(pf.logger, backend)(p) p = NewFilterHeadersMiddleware(pf.logger, backend)(p) p = NewLoadBalancedMiddlewareWithSubscriberAndLogger(pf.logger, pf.subscriberFactory(backend))(p) if backend.ConcurrentCalls > 1 { p = NewConcurrentMiddlewareWithLogger(pf.logger, backend)(p) } p = NewRequestBuilderMiddlewareWithLogger(pf.logger, backend)(p) // we need to filter the input query strings before the request is constructed // so the query strings are properly added to the URL: p = NewFilterQueryStringsMiddleware(pf.logger, backend)(p) return } ================================================ FILE: proxy/factory_test.go ================================================ //go:build integration || !race // +build integration !race // SPDX-License-Identifier: Apache-2.0 package proxy import ( "bytes" "context" "net/url" "strings" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/sd" ) func TestFactoryFunc(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } factory := FactoryFunc(func(cfg *config.EndpointConfig) (Proxy, error) { return DefaultFactory(logger).New(cfg) }) if _, err := factory.New(&config.EndpointConfig{}); err != ErrNoBackends { t.Errorf("Expecting ErrNoBackends. 
Got: %v\n", err) } } func TestDefaultFactoryWithSubscriber(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } factory := DefaultFactoryWithSubscriber(logger, sd.FixedSubscriberFactory) if _, err := factory.New(&config.EndpointConfig{}); err != ErrNoBackends { t.Errorf("Expecting ErrNoBackends. Got: %v\n", err) } } func TestDefaultFactory_noBackends(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } factory := DefaultFactory(logger) if _, err := factory.New(&config.EndpointConfig{}); err != ErrNoBackends { t.Errorf("Expecting ErrNoBackends. Got: %v\n", err) } } func TestNewDefaultFactory_ok(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } expectedResponse := Response{ IsComplete: true, Data: map[string]interface{}{"foo": "bar"}, } expectedMethod := "SOME" expectedHost := "http://example.com/" expectedPath := "/foo" expectedURL := expectedHost + strings.TrimLeft(expectedPath, "/") URL, err := url.Parse(expectedHost) if err != nil { t.Errorf("building the sample url: %s\n", err.Error()) } request := Request{ Method: expectedMethod, Path: expectedPath, URL: URL, Body: newDummyReadCloser(""), } assertion := func(ctx context.Context, request *Request) (*Response, error) { if request.URL.String() != expectedURL { t.Errorf("The middlewares did not update the request URL! 
want [%s], have [%s]\n", expectedURL, request.URL) } return &expectedResponse, nil } factory := NewDefaultFactory(func(_ *config.Backend) Proxy { return assertion }, logger) backend := config.Backend{ URLPattern: expectedPath, Method: expectedMethod, } endpointSingle := config.EndpointConfig{ Backend: []*config.Backend{&backend}, } endpointMulti := config.EndpointConfig{ Backend: []*config.Backend{&backend, &backend}, ConcurrentCalls: 3, } serviceConfig := config.ServiceConfig{ Version: config.ConfigVersion, Endpoints: []*config.EndpointConfig{&endpointSingle, &endpointMulti}, Timeout: 100 * time.Millisecond, Host: []string{expectedHost}, } if err := serviceConfig.Init(); err != nil { t.Errorf("Error during the config init: %s\n", err.Error()) } proxyMulti, err := factory.New(&endpointMulti) if err != nil { t.Errorf("The factory returned an unexpected error: %s\n", err.Error()) } response, err := proxyMulti(context.Background(), &request) if err != nil { t.Errorf("The proxy middleware propagated an unexpected error: %s\n", err.Error()) } if !response.IsComplete || len(response.Data) != len(expectedResponse.Data) { t.Errorf("The proxy middleware propagated an unexpected error: %v\n", response) } proxySingle, err := factory.New(&endpointSingle) if err != nil { t.Errorf("The factory returned an unexpected error: %s\n", err.Error()) } response, err = proxySingle(context.Background(), &request) if err != nil { t.Errorf("The proxy middleware propagated an unexpected error: %s\n", err.Error()) } if !response.IsComplete || len(response.Data) != len(expectedResponse.Data) { t.Errorf("The proxy middleware propagated an unexpected error: %v\n", response) } } ================================================ FILE: proxy/formatter.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "fmt" "strings" "github.com/krakend/flatmap/tree" "github.com/luraproject/lura/v2/config" 
"github.com/luraproject/lura/v2/logging" ) // EntityFormatter formats the response data type EntityFormatter interface { Format(Response) Response } // EntityFormatterFunc holds the formatter function type EntityFormatterFunc func(Response) Response // Format implements the EntityFormatter interface func (e EntityFormatterFunc) Format(entity Response) Response { return e(entity) } type propertyFilter func(*Response) type entityFormatter struct { Target string Prefix string PropertyFilter propertyFilter Mapping map[string]string } // NewEntityFormatter creates an entity formatter with the received backend definition func NewEntityFormatter(remote *config.Backend) EntityFormatter { if ef := newFlatmapFormatter(remote.ExtraConfig, remote.Target, remote.Group); ef != nil { return ef } var propertyFilter propertyFilter if len(remote.AllowList) > 0 { propertyFilter = newAllowlistingFilter(remote.AllowList) } else { propertyFilter = newDenylistingFilter(remote.DenyList) } sanitizedMappings := make(map[string]string, len(remote.Mapping)) for i, m := range remote.Mapping { v := strings.Split(m, ".") sanitizedMappings[i] = v[0] } return entityFormatter{ Target: remote.Target, Prefix: remote.Group, PropertyFilter: propertyFilter, Mapping: sanitizedMappings, } } // Format implements the EntityFormatter interface func (e entityFormatter) Format(entity Response) Response { if e.Target != "" { extractTarget(e.Target, &entity) } if len(entity.Data) > 0 { e.PropertyFilter(&entity) } if len(entity.Data) > 0 { for formerKey, newKey := range e.Mapping { if v, ok := entity.Data[formerKey]; ok { entity.Data[newKey] = v delete(entity.Data, formerKey) } } } if e.Prefix != "" { entity.Data = map[string]interface{}{e.Prefix: entity.Data} } return entity } func extractTarget(target string, entity *Response) { for _, part := range strings.Split(target, ".") { if tmp, ok := entity.Data[part]; ok { entity.Data, ok = tmp.(map[string]interface{}) if !ok { entity.Data = map[string]interface{}{} 
return } } else { entity.Data = map[string]interface{}{} return } } } func AllowlistPrune(wlDict, inDict map[string]interface{}) bool { canDelete := true var deleteSibling bool for k, v := range inDict { deleteSibling = true if subWl, ok := wlDict[k]; ok { if subWlDict, okk := subWl.(map[string]interface{}); okk { if subInDict, isDict := v.(map[string]interface{}); isDict && !AllowlistPrune(subWlDict, subInDict) { deleteSibling = false } } else { // Allowlist leaf, maintain this branch deleteSibling = false } } if deleteSibling { delete(inDict, k) } else { canDelete = false } } return canDelete } func newAllowlistingFilter(Allowlist []string) propertyFilter { wlDict := make(map[string]interface{}) for _, k := range Allowlist { wlFields := strings.Split(k, ".") d := buildDictPath(wlDict, wlFields[:len(wlFields)-1]) d[wlFields[len(wlFields)-1]] = true } return func(entity *Response) { if AllowlistPrune(wlDict, entity.Data) { for k := range entity.Data { delete(entity.Data, k) } } } } func buildDictPath(accumulator map[string]interface{}, fields []string) map[string]interface{} { var ok bool var c map[string]interface{} var fIdx int fEnd := len(fields) p := accumulator for fIdx = 0; fIdx < fEnd; fIdx++ { if c, ok = p[fields[fIdx]].(map[string]interface{}); !ok { break } p = c } for ; fIdx < fEnd; fIdx++ { c = make(map[string]interface{}) p[fields[fIdx]] = c p = c } return p } func buildDenyTree(path []string, tree map[string]interface{}) { if len(path) == 0 { return } n := path[0] if len(path) == 1 { // this is the node to be deleted, so, any other child // that is under this node, does not need to be visited: // we "delete" any descendant from this node tree[n] = nil return } if k, ok := tree[n]; ok { if k == nil { // all this child should be deleted, so, no matter // if the entry says to delete some extra child.. 
// everything will be deleted return } childTree, ok := k.(map[string]interface{}) if !ok { // this should never happen if this algorithm is correct tree[n] = nil return } buildDenyTree(path[1:], childTree) return } // it the key does not exist, we need to keep building the children, // and at this point we know that path is at least len = 2, and that // tree[n] does not exist childTree := make(map[string]interface{}, 1) tree[n] = childTree buildDenyTree(path[1:], childTree) } func recDelete(ref map[string]interface{}, v interface{}) { m, ok := v.(map[string]interface{}) if !ok || m == nil { return } for rk, rv := range ref { dv, dok := m[rk] if !dok { continue } if rv == nil { delete(m, rk) continue } recDelete(rv.(map[string]interface{}), dv) } } func newDenylistingFilter(blacklist []string) propertyFilter { bl := make(map[string]interface{}, len(blacklist)) for _, key := range blacklist { keys := strings.Split(key, ".") buildDenyTree(keys, bl) } return func(entity *Response) { recDelete(bl, entity.Data) } } const flatmapKey = "flatmap_filter" type flatmapFormatter struct { Target string Prefix string Ops []flatmapOp } type flatmapOp struct { Type string Args [][]string } // Format implements the EntityFormatter interface func (e flatmapFormatter) Format(entity Response) Response { if e.Target != "" { extractTarget(e.Target, &entity) } e.processOps(&entity) if e.Prefix != "" { entity.Data = map[string]interface{}{e.Prefix: entity.Data} } return entity } func (e flatmapFormatter) processOps(entity *Response) { flatten, err := tree.New(entity.Data) if err != nil { return } for _, op := range e.Ops { switch op.Type { case "move": flatten.Move(op.Args[0], op.Args[1]) case "append": flatten.Append(op.Args[0], op.Args[1]) case "del": for _, k := range op.Args { flatten.Del(k) } default: } } entity.Data, _ = flatten.Get([]string{}).(map[string]interface{}) } func newFlatmapFormatter(cfg config.ExtraConfig, target, group string) *flatmapFormatter { if v, ok := 
cfg[Namespace]; ok { if e, ok := v.(map[string]interface{}); ok { if vs, ok := e[flatmapKey].([]interface{}); ok { if len(vs) == 0 { return nil } var ops []flatmapOp for _, v := range vs { m, ok := v.(map[string]interface{}) if !ok { continue } op := flatmapOp{} if t, ok := m["type"].(string); ok { op.Type = t } else { continue } if args, ok := m["args"].([]interface{}); ok { op.Args = make([][]string, len(args)) for k, arg := range args { if t, ok := arg.(string); ok { op.Args[k] = strings.Split(t, ".") } } } ops = append(ops, op) } if len(ops) == 0 { return nil } return &flatmapFormatter{ Target: target, Prefix: group, Ops: ops, } } } } return nil } // NewFlatmapMiddleware creates a proxy middleware that enables applying flatmap operations to the proxy response func NewFlatmapMiddleware(logger logging.Logger, cfg *config.EndpointConfig) Middleware { formatter := newFlatmapFormatter(cfg.ExtraConfig, "", "") return func(next ...Proxy) Proxy { if len(next) > 1 { logger.Fatal("too many proxies for this proxy middleware: NewFlatmapMiddleware only accepts 1 proxy, got %d", len(next)) return nil } if formatter == nil { return next[0] } logger.Debug( fmt.Sprintf( "[ENDPOINT: %s][Flatmap] Adding flatmap manipulator with %d operations", cfg.Endpoint, len(formatter.Ops), ), ) return func(ctx context.Context, request *Request) (*Response, error) { resp, err := next[0](ctx, request) if err != nil { return resp, err } r := formatter.Format(*resp) return &r, nil } } } ================================================ FILE: proxy/formatter_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "bytes" "fmt" "strconv" "testing" "github.com/luraproject/lura/v2/config" ) func BenchmarkEntityFormatter_allowFilter(b *testing.B) { data := map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", } for _, extraFields := range []int{0, 5, 10, 15, 20, 25} { sampleData := data for i := 0; i < extraFields; i++ 
{ sampleData[fmt.Sprintf("%d", i)] = i } for _, testCase := range [][]string{ {}, {"supu"}, {"supu", "tupu"}, {"supu", "tupu", "foo"}, {"supu", "tupu", "foo", "unknown"}, } { sample := Response{ Data: sampleData, IsComplete: true, } b.Run(fmt.Sprintf("with %d elements with %d extra fields", len(testCase), extraFields), func(b *testing.B) { f := NewEntityFormatter(&config.Backend{AllowList: testCase}) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { f.Format(sample) } }) } } } func benchmarkDeepChilds(depth, extraSiblings int) map[string]interface{} { data := make(map[string]interface{}, extraSiblings+1) for i := 0; i < extraSiblings; i++ { data[fmt.Sprintf("extra%d", i)] = "sibling_value" } if depth > 0 { data[fmt.Sprintf("child%d", depth)] = benchmarkDeepChilds(depth-1, extraSiblings) } else { data["child0"] = 1 } return data } func benchmarkDeepStructure(numTargets, targetDepth, extraFields, extraSiblings int) (map[string]interface{}, []string) { data := make(map[string]interface{}, numTargets+extraFields) targetKeys := make([]string, numTargets) for i := 0; i < numTargets; i++ { data[fmt.Sprintf("target%d", i)] = benchmarkDeepChilds(targetDepth-1, extraSiblings) } for j := 0; j < extraFields; j++ { data[fmt.Sprintf("extra%d", j)] = benchmarkDeepChilds(targetDepth-1, extraSiblings) } // create the target list for i := 0; i < numTargets; i++ { var buffer bytes.Buffer buffer.WriteString(fmt.Sprintf("target%d", i)) for j := targetDepth - 1; j >= 0; j-- { buffer.WriteString(fmt.Sprintf(".child%d", j)) } targetKeys[i] = buffer.String() } return data, targetKeys } func BenchmarkEntityFormatter_deepAllowFilter(b *testing.B) { numTargets := []int{0, 1, 2, 5, 10} depths := []int{1, 3, 7} for _, nTargets := range numTargets { for _, depth := range depths { extraFields := nTargets + depth*2 extraSiblings := nTargets data, allow := benchmarkDeepStructure(nTargets, depth, extraFields, extraSiblings) sample := Response{ Data: data, IsComplete: true, } f := 
NewEntityFormatter(&config.Backend{AllowList: allow}) b.Run(fmt.Sprintf("numTargets:%d,depth:%d,extraFields:%d,extraSiblings:%d", nTargets, depth, extraFields, extraSiblings), func(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { f.Format(sample) } }) } } } func BenchmarkEntityFormatter_denyFilter(b *testing.B) { data := map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", } for _, extraFields := range []int{0, 5, 10, 15, 20, 25} { sampleData := data for i := 0; i < extraFields; i++ { sampleData[fmt.Sprintf("%d", i)] = i } for _, testCase := range [][]string{ {}, {"supu"}, {"supu", "tupu"}, {"supu", "tupu", "foo"}, {"supu", "tupu", "foo", "unknown"}, } { sample := Response{ Data: sampleData, IsComplete: true, } b.Run(fmt.Sprintf("with %d elements with %d extra fields", len(testCase), extraFields), func(b *testing.B) { f := NewEntityFormatter(&config.Backend{DenyList: testCase}) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { f.Format(sample) } }) } } } func BenchmarkEntityFormatter_grouping(b *testing.B) { preffix := "group1" for _, extraFields := range []int{0, 5, 10, 15, 20, 25} { sampleData := make(map[string]interface{}, extraFields) for i := 0; i < extraFields; i++ { sampleData[fmt.Sprintf("%d", i)] = i } sample := Response{ Data: sampleData, IsComplete: true, } b.Run(fmt.Sprintf("with %d elements", extraFields), func(b *testing.B) { f := NewEntityFormatter(&config.Backend{Group: preffix}) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { f.Format(sample) } }) } } func BenchmarkEntityFormatter_mapping(b *testing.B) { for _, extraFields := range []int{0, 5, 10, 15, 20, 25} { sampleData := make(map[string]interface{}, extraFields) for i := 0; i < extraFields; i++ { sampleData[fmt.Sprintf("%d", i)] = i } for _, testCase := range []map[string]string{ {}, {"1": "supu"}, {"1": "supu", "2": "tupu"}, {"1": "supu", "2": "tupu", "3": "foo"}, {"1": "supu", "2": "tupu", "3": "foo", "4": "bar"}, {"1": "supu", "2": 
"tupu", "3": "foo", "4": "bar", "5": "a"}, } { sample := Response{ Data: sampleData, IsComplete: true, } b.Run(fmt.Sprintf("with %d elements with %d extra fields", len(testCase), extraFields), func(b *testing.B) { f := NewEntityFormatter(&config.Backend{Mapping: testCase}) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { f.Format(sample) } }) } } } func BenchmarkEntityFormatter_flatmapAlt(b *testing.B) { f := NewEntityFormatter(&config.Backend{ Target: "content", Group: "group", ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ flatmapKey: []interface{}{ map[string]interface{}{ "type": "del", "args": []interface{}{"c"}, }, map[string]interface{}{ "type": "move", "args": []interface{}{"supu", "SUPUUUUU"}, }, map[string]interface{}{ "type": "move", "args": []interface{}{"a.b", "a.BOOOOO"}, }, map[string]interface{}{ "type": "del", "args": []interface{}{"collection.*.b"}, }, map[string]interface{}{ "type": "del", "args": []interface{}{"collection.*.d"}, }, map[string]interface{}{ "type": "del", "args": []interface{}{"collection.*.e"}, }, map[string]interface{}{ "type": "move", "args": []interface{}{"collection.*.c", "collection.*.x"}, }, }, }, }, }) for _, size := range []int{1, 2, 5, 10, 20, 50, 100, 500} { b.Run(strconv.Itoa(size), func(b *testing.B) { sub := map[string]interface{}{ "b": true, "c": 42, "d": "tupu", "e": []interface{}{1, 2, 3, 4}, } sample := Response{ Data: map[string]interface{}{ "content": map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", "a": sub, "collection": []interface{}{sub, sub, sub, sub}, }, }, IsComplete: true, } var subCol []interface{} for i := 0; i < size; i++ { subCol = append(subCol, i) } sub["e"] = subCol var sampleSubCol []interface{} for i := 0; i < size; i++ { sampleSubCol = append(sampleSubCol, sub) } sample.Data["collection"] = sampleSubCol b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { f.Format(sample) } }) } } func BenchmarkEntityFormatter_flatmap(b *testing.B) { 
numTargets := []int{0, 1, 2, 5, 10} depths := []int{1, 3, 7} for _, nTargets := range numTargets { for _, depth := range depths { extraFields := nTargets + depth*2 extraSiblings := nTargets data, blacklist := benchmarkDeepStructure(nTargets, depth, extraFields, extraSiblings) sample := Response{ Data: data, IsComplete: true, } var cmds []interface{} for _, path := range blacklist { cmds = append(cmds, map[string]interface{}{ "type": "del", "args": []interface{}{path}, }) } f := NewEntityFormatter(&config.Backend{ ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ flatmapKey: cmds, }, }, }) b.Run(fmt.Sprintf("numTargets:%d,depth:%d,extraFields:%d,extraSiblings:%d", nTargets, depth, extraFields, extraSiblings), func(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { f.Format(sample) } }) } } } ================================================ FILE: proxy/formatter_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "reflect" "testing" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) func TestEntityFormatterFunc(t *testing.T) { expected := Response{Data: map[string]interface{}{"one": 1}, IsComplete: true} f := func(_ Response) Response { return expected } formatter := EntityFormatterFunc(f) result := formatter.Format(Response{}) if result.Data["one"].(int) != 1 { t.Error("unexpected result:", result.Data) } if !result.IsComplete { t.Error("unexpected result:", result) } } func TestEntityFormatter_newAllowFilter(t *testing.T) { sample := Response{ Data: map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", "a": map[string]interface{}{ "b": true, "c": 42, "d": "tupu", }, }, IsComplete: true, } expected := Response{ Data: map[string]interface{}{ "supu": 42, "a": map[string]interface{}{ "b": true, "c": 42, }, }, IsComplete: true, } f := NewEntityFormatter(&config.Backend{AllowList: []string{"supu", "a.b", "a.c", 
"foo.unknown"}}) result := f.Format(sample) if v, ok := result.Data["supu"]; !ok || v != expected.Data["supu"] { t.Errorf("The formatter returned an unexpected result for the field supu: %v\n", result) } v, ok := result.Data["a"] if !ok { t.Errorf("The formatter returned an unexpected result for the fields a.b & a.c: %v\n", result) } tmp := v.(map[string]interface{}) if b, okk := tmp["b"]; !okk || !b.(bool) { t.Errorf("The formatter returned an unexpected result for the field a.b: %v\n", result) } if c, okk := tmp["c"]; !okk || c.(int) != 42 { t.Errorf("The formatter returned an unexpected result for the field a.c: %v\n", result) } if len(tmp) != 2 { t.Errorf("The formatter returned an unexpected result size for the field a: %v\n", result) } if len(result.Data) != 2 || result.IsComplete != expected.IsComplete { t.Errorf("The formatter returned an unexpected result size: %v\n", result) } } func TestEntityFormatter_newAllowDeepFields(t *testing.T) { sample := Response{ Data: map[string]interface{}{ "id": 42, "tupu": map[string]interface{}{ "muku": map[string]interface{}{ "supu": 1, "muku": 2, "gutu": map[string]interface{}{ "kugu": 42, }, }, "supu": map[string]interface{}{ "supu": 3, "muku": 4, }, }, }, IsComplete: true, } expectedSupuChild := 1 var ok bool f := NewEntityFormatter(&config.Backend{AllowList: []string{"tupu.muku.supu", "tupu.muku.gutu.kugu"}}) res := f.Format(sample) var tupu map[string]interface{} var muku map[string]interface{} var gutu map[string]interface{} var kugu int var supuChild int if tupu, ok = res.Data["tupu"].(map[string]interface{}); !ok { t.Errorf("The formatter does not have field tupu\n") } if muku, ok = tupu["muku"].(map[string]interface{}); !ok { t.Errorf("The formatter does not have field tupu.muku\n") } if supuChild, ok = muku["supu"].(int); !ok || supuChild != expectedSupuChild { t.Errorf("The formatter does not have field tupu.muku.supu or wrong value\n") } if _, ok = tupu["supu"].(map[string]interface{}); ok { t.Errorf("The 
formatter should have removed tupu.supu\n") } if _, ok = muku["muku"]; ok { t.Errorf("The formatter should have removed tupu.muku.muku\n") } if gutu, ok = muku["gutu"].(map[string]interface{}); !ok { t.Errorf("The formatter does not have field tupu.muku.gutu\n") } if kugu, ok = gutu["kugu"].(int); !ok || kugu != 42 { t.Errorf("The formatter does not have field tupu.muku.gutu.kugu\n") } } func TestEntityFormatter_newDenyFilter(t *testing.T) { sample := Response{ Data: map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", "a": map[string]interface{}{ "a": map[string]interface{}{ "b": true, "c": 42, "d": "tupu", "deeper": map[string]interface{}{ "a": map[string]interface{}{ "aa": "deleteme deeper.a.aa", "bb": "do not deleteme deeper.a.bb", }, "b": map[string]interface{}{ "aa": "deleteme deeper.b.aa", "bb": "do not deleteme deeper.b.bb", }, }, }, "b": true, "c": 42, "d": "tupu", }, }, IsComplete: true, } expected := Response{ Data: map[string]interface{}{ "tupu": false, "foo": "bar", "a": map[string]interface{}{ "a": map[string]interface{}{ "c": 42, "d": "tupu", "deeper": map[string]interface{}{ "a": map[string]interface{}{ "bb": "do not deleteme deeper.a.bb", }, "b": map[string]interface{}{ "bb": "do not deleteme deeper.b.bb", }, }, }, "d": "tupu", }, }, IsComplete: true, } f := NewEntityFormatter(&config.Backend{DenyList: []string{ "supu", "a.b", "a.c", "foo.unknown", "a.a.b", "a.a.deeper.a.aa", "a.a.deeper.b.aa", }}) result := f.Format(sample) if v, ok := result.Data["tupu"]; !ok || v != expected.Data["tupu"] { t.Errorf("The formatter returned an unexpected result for the field tupu: %v\n", result) } if v, ok := result.Data["foo"]; !ok || v != expected.Data["foo"] { t.Errorf("The formatter returned an unexpected result for the field foo: %v\n", result) } v, ok := result.Data["a"] if !ok { t.Errorf("The formatter returned an unexpected result for the field a.d: %v\n", result) } tmp := v.(map[string]interface{}) if d, okk := tmp["d"]; !okk || d != "tupu" { 
t.Errorf("The formatter returned an unexpected result for the field a.d: %v\n", result) } if len(tmp) != 2 { // a.a should exist , and a.d should exist t.Errorf("The formatter returned an unexpected result size for the field a: %v\n", result) } if len(result.Data) != 3 || result.IsComplete != expected.IsComplete { t.Errorf("The formatter returned an unexpected result size: %v\n", result) } if !reflect.DeepEqual(expected.Data, result.Data) { t.Errorf("unexpected response. have: %+v, want: %+v", result.Data, expected.Data) } } func TestEntityFormatter_grouping(t *testing.T) { preffix := "group1" sample := Response{ Data: map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", }, IsComplete: true, } expected := Response{ Data: map[string]interface{}{ preffix: map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", }, }, IsComplete: true, } f := NewEntityFormatter(&config.Backend{Group: preffix}) result := f.Format(sample) if len(result.Data) != 1 || result.IsComplete != expected.IsComplete { t.Fail() } if _, ok := result.Data[preffix]; !ok { t.Fail() } group := result.Data[preffix].(map[string]interface{}) for k, expectedValue := range expected.Data[preffix].(map[string]interface{}) { if v, ok := group[k]; !ok || v != expectedValue { t.Fail() } } } func TestEntityFormatter_mapping(t *testing.T) { mapping := map[string]string{"supu": "SUPUUUUU", "tupu": "TUPUUUUU", "a.b": "a.BOOOOO"} sub := map[string]interface{}{ "b": true, "c": 42, "d": "tupu", } sample := Response{ Data: map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", "a": sub, }, IsComplete: true, } expected := Response{ Data: map[string]interface{}{ "SUPUUUUU": 42, "TUPUUUUU": false, "foo": "bar", "a": sub, }, IsComplete: true, } f := NewEntityFormatter(&config.Backend{Mapping: mapping}) result := f.Format(sample) if len(result.Data) != 4 || result.IsComplete != expected.IsComplete { t.Errorf("The formatter returned an unexpected result size: %v\n", result.Data) } for k, 
expectedValue := range expected.Data { if k == "a" { continue } if v, ok := result.Data[k]; !ok || v != expectedValue { t.Errorf("The formatter returned an unexpected result for the key %s: %v\n", k, v) } } group := result.Data["a"].(map[string]interface{}) for k, expectedValue := range expected.Data["a"].(map[string]interface{}) { if v, ok := group[k]; !ok || v != expectedValue { t.Errorf("The formatter returned an unexpected result for the key %s: %v\n", k, v) } } if len(group) != 3 { t.Errorf("The formatter returned an unexpected result size for the subentity: %v\n", group) } } func TestEntityFormatter_targeting(t *testing.T) { target := "group1" sub := map[string]interface{}{ "b": true, "c": 42, "d": "tupu", } sample := Response{ Data: map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", target: sub, }, IsComplete: true, } expected := Response{ Data: sub, IsComplete: true, } f := NewEntityFormatter(&config.Backend{Target: target}) result := f.Format(sample) if len(result.Data) != 3 || result.IsComplete != expected.IsComplete { t.Errorf("The formatter returned an unexpected result size: %v\n", result) } for k, expectedValue := range expected.Data { if v, ok := result.Data[k]; !ok || v != expectedValue { t.Errorf("The formatter returned an unexpected result for the key %s: %v\n", k, v) } } } func TestEntityFormatter_targetingNested(t *testing.T) { target := "group1" sub := map[string]interface{}{ "b": true, "c": 42, "d": "tupu", } sample := Response{ Data: map[string]interface{}{ target: map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", target: sub, }, }, IsComplete: true, } expected := Response{ Data: sub, IsComplete: true, } f := NewEntityFormatter(&config.Backend{Target: target + "." 
+ target}) result := f.Format(sample) if len(result.Data) != 3 || result.IsComplete != expected.IsComplete { t.Errorf("The formatter returned an unexpected result size: %v\n", result) } for k, expectedValue := range expected.Data { if v, ok := result.Data[k]; !ok || v != expectedValue { t.Errorf("The formatter returned an unexpected result for the key %s: %v\n", k, v) } } } func TestEntityFormatter_targetingUnknownFields(t *testing.T) { target := "group1" sample := Response{ Data: map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", }, IsComplete: true, } f := NewEntityFormatter(&config.Backend{Target: target}) result := f.Format(sample) if len(result.Data) != 0 || result.IsComplete != sample.IsComplete { t.Errorf("The formatter returned an unexpected result size: %v\n", result) } } func TestEntityFormatter_targetingNonObjects(t *testing.T) { target := "group1" sample := Response{ Data: map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", target: false, }, IsComplete: true, } f := NewEntityFormatter(&config.Backend{Target: target}) result := f.Format(sample) if len(result.Data) != 0 || result.IsComplete != sample.IsComplete { t.Errorf("The formatter returned an unexpected result size: %v\n", result) } } func TestEntityFormatter_altogether(t *testing.T) { sample := Response{ Data: map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", "a": map[string]interface{}{ "b": true, "c": 42, "d": "tupu", }, }, IsComplete: true, } expected := Response{ Data: map[string]interface{}{ "group": map[string]interface{}{ "D": "tupu", }, }, IsComplete: true, } f := NewEntityFormatter(&config.Backend{ Target: "a", AllowList: []string{"d"}, Group: "group", Mapping: map[string]string{"d": "D"}, }) result := f.Format(sample) v, ok := result.Data["group"] if !ok { t.Errorf("The formatter returned an unexpected result for the field group.D: %v\n", result) } tmp := v.(map[string]interface{}) if d, okk := tmp["D"]; !okk || d != "tupu" { t.Errorf("The 
formatter returned an unexpected result for the field group.D: %v\n", result) } if len(tmp) != 1 { t.Errorf("The formatter returned an unexpected result size for the field group: %v\n", result) } if len(result.Data) != 1 || result.IsComplete != expected.IsComplete { t.Errorf("The formatter returned an unexpected result size: %v\n", result) } } func TestEntityFormatter_flatmap(t *testing.T) { sub := map[string]interface{}{ "b": true, "c": 42, "d": "tupu", "e": []interface{}{1, 2, 3, 4}, } sample := Response{ Data: map[string]interface{}{ "content": map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", "a": sub, "collection": []interface{}{sub, sub, sub, sub}, "y": []interface{}{0, 1, 2, 3, 4, 5, 6}, "z": []interface{}{10, 11, 12, 13, 14, 15, 16}, }, }, IsComplete: true, } expected := Response{ Data: map[string]interface{}{ "group": map[string]interface{}{ "SUPUUUUU": 42, "tupu": false, "foo": "bar", "a": map[string]interface{}{ "BOOOOO": true, "c": 42, "d": "tupu", "e": []interface{}{1, 2, 3, 4}, }, "collection": []interface{}{ map[string]interface{}{"x": 42}, map[string]interface{}{"x": 42}, map[string]interface{}{"x": 42}, map[string]interface{}{"x": 42}, }, "z": []interface{}{10, 11, 12, 13, 14, 15, 16, 0, 1, 2, 3, 4, 5, 6}, }, }, IsComplete: true, } f := NewEntityFormatter(&config.Backend{ Target: "content", Group: "group", ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ flatmapKey: []interface{}{ map[string]interface{}{ "type": "del", "args": []interface{}{"c"}, }, map[string]interface{}{ "type": "append", "args": []interface{}{"y", "z"}, }, map[string]interface{}{ "type": "move", "args": []interface{}{"supu", "SUPUUUUU"}, }, map[string]interface{}{ "type": "move", "args": []interface{}{"a.b", "a.BOOOOO"}, }, map[string]interface{}{ "type": "del", "args": []interface{}{ "collection.*.b", "collection.*.d", "collection.*.e", }, }, map[string]interface{}{ "type": "move", "args": []interface{}{"collection.*.c", "collection.*.x"}, }, 
}, }, }, }) result := f.Format(sample) if len(result.Data) != len(expected.Data) || result.IsComplete != expected.IsComplete { t.Errorf("The formatter returned an unexpected result size: %v\n", result.Data) } if !reflect.DeepEqual(expected.Data, result.Data) { t.Errorf("unexpected result: %v", result.Data) } } func TestNewFlatmapMiddleware(t *testing.T) { sub := map[string]interface{}{ "b": true, "c": 42, "d": "tupu", "e": []interface{}{1, 2, 3, 4}, } sample := Response{ Data: map[string]interface{}{ "supu": 42, "tupu": false, "foo": "bar", "a": sub, "collection": []interface{}{sub, sub, sub, sub}, "y": []interface{}{0, 1, 2, 3, 4, 5, 6}, "z": []interface{}{10, 11, 12, 13, 14, 15, 16}, }, IsComplete: true, } expected := Response{ Data: map[string]interface{}{ "SUPUUUUU": 42, "tupu": false, "foo": "bar", "a": map[string]interface{}{ "BOOOOO": true, "c": 42, "d": "tupu", "e": []interface{}{1, 2, 3, 4}, }, "collection": []interface{}{ map[string]interface{}{"x": 42}, map[string]interface{}{"x": 42}, map[string]interface{}{"x": 42}, map[string]interface{}{"x": 42}, }, "z": []interface{}{10, 11, 12, 13, 14, 15, 16, 0, 1, 2, 3, 4, 5, 6}, }, IsComplete: true, } p := NewFlatmapMiddleware( logging.NoOp, &config.EndpointConfig{ ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ flatmapKey: []interface{}{ map[string]interface{}{ "type": "del", "args": []interface{}{"c"}, }, map[string]interface{}{ "type": "append", "args": []interface{}{"y", "z"}, }, map[string]interface{}{ "type": "move", "args": []interface{}{"supu", "SUPUUUUU"}, }, map[string]interface{}{ "type": "move", "args": []interface{}{"a.b", "a.BOOOOO"}, }, map[string]interface{}{ "type": "del", "args": []interface{}{"collection.*.b"}, }, map[string]interface{}{ "type": "del", "args": []interface{}{"collection.*.d"}, }, map[string]interface{}{ "type": "del", "args": []interface{}{"collection.*.e"}, }, map[string]interface{}{ "type": "move", "args": []interface{}{"collection.*.c", "collection.*.x"}, 
}, }, }, }, }, )(func(_ context.Context, _ *Request) (*Response, error) { return &sample, nil }) result, err := p(context.TODO(), nil) if err != nil { t.Error(err) } if len(result.Data) != len(expected.Data) { t.Errorf("The formatter returned an unexpected result size: %v\n", result.Data) } if result.IsComplete != expected.IsComplete { t.Errorf("The formatter returned an unexpected completion flag: %v\n", result.IsComplete) } if !reflect.DeepEqual(expected.Data, result.Data) { t.Errorf("unexpected result: %v", result.Data) } } ================================================ FILE: proxy/graphql.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "bytes" "context" "fmt" "io" "net/url" "strconv" "strings" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/transport/http/client/graphql" ) // NewGraphQLMiddleware returns a middleware with or without the GraphQL // proxy wrapping the next element (depending on the configuration). // It supports both queries and mutations. // For queries, it completes the variables object using the request params. // For mutations, it overides the defined variables with the request body. 
// The resulting request will have a proper graphql body with the query and the // variables func NewGraphQLMiddleware(logger logging.Logger, remote *config.Backend) Middleware { opt, err := graphql.GetOptions(remote.ExtraConfig) if err != nil { if err != graphql.ErrNoConfigFound { logger.Warning( fmt.Sprintf("[BACKEND: %s %s -> %s][GraphQL] %s", remote.ParentEndpoint, remote.ParentEndpoint, remote.URLPattern, err.Error())) } return emptyMiddlewareFallback(logger) } extractor := graphql.New(*opt) var generateBodyFn func(*Request) ([]byte, error) var generateQueryFn func(*Request) (url.Values, error) switch opt.Type { case graphql.OperationMutation: generateBodyFn = func(req *Request) ([]byte, error) { if req.Body == nil { return extractor.BodyFromBody(strings.NewReader("")) } defer req.Body.Close() return extractor.BodyFromBody(req.Body) } generateQueryFn = func(req *Request) (url.Values, error) { if req.Body == nil { return extractor.QueryFromBody(strings.NewReader("")) } defer req.Body.Close() return extractor.QueryFromBody(req.Body) } case graphql.OperationQuery: generateBodyFn = func(req *Request) ([]byte, error) { return extractor.BodyFromParams(req.Params) } generateQueryFn = func(req *Request) (url.Values, error) { return extractor.QueryFromParams(req.Params) } default: return emptyMiddlewareFallback(logger) } return func(next ...Proxy) Proxy { if len(next) > 1 { logger.Fatal("too many proxies for this %s %s -> %s proxy middleware: NewGraphQLMiddleware only accepts 1 proxy, got %d", remote.ParentEndpointMethod, remote.ParentEndpoint, remote.URLPattern, len(next)) return nil } logger.Debug( fmt.Sprintf( "[BACKEND: %s %s -> %s][GraphQL] Operation: %s, Method: %s", remote.ParentEndpointMethod, remote.ParentEndpoint, remote.URLPattern, opt.Type, opt.Method, ), ) if opt.Method == graphql.MethodGet { return func(ctx context.Context, req *Request) (*Response, error) { q, err := generateQueryFn(req) if err != nil { return nil, err } req.Body = 
io.NopCloser(bytes.NewReader([]byte{})) req.Method = string(opt.Method) req.Headers["Content-Length"] = []string{"0"} // even when there is no content, we just set the content-type // header to be safe if the server side checks it: req.Headers["Content-Type"] = []string{"application/json"} if req.Query != nil { for k, vs := range q { for _, v := range vs { req.Query.Add(k, v) } } } else { req.Query = q } return next[0](ctx, req) } } return func(ctx context.Context, req *Request) (*Response, error) { b, err := generateBodyFn(req) if err != nil { return nil, err } req.Body = io.NopCloser(bytes.NewReader(b)) req.Method = string(opt.Method) req.Headers["Content-Length"] = []string{strconv.Itoa(len(b))} req.Headers["Content-Type"] = []string{"application/json"} return next[0](ctx, req) } } } ================================================ FILE: proxy/graphql_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "encoding/json" "io" "reflect" "strings" "testing" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/transport/http/client/graphql" ) func TestNewGraphQLMiddleware_mutation(t *testing.T) { query := "mutation addAuthor($author: [AddAuthorInput!]!) {\n addAuthor(input: $author) {\n author {\n id\n name\n }\n }\n}\n" mw := NewGraphQLMiddleware( logging.NoOp, &config.Backend{ ExtraConfig: config.ExtraConfig{ graphql.Namespace: map[string]interface{}{ "type": "mutation", "query": query, "variables": map[string]interface{}{ "author": map[string]interface{}{ "name": "A.N. 
Author", "dob": "2000-01-01", "posts": []interface{}{}, }, }, }, }, }, ) expectedResponse := &Response{ Data: map[string]interface{}{"foo": "bar"}, } prxy := mw(func(ctx context.Context, req *Request) (*Response, error) { b, err := io.ReadAll(req.Body) req.Body.Close() if err != nil { return nil, err } var request graphql.GraphQLRequest if err := json.Unmarshal(b, &request); err != nil { return nil, err } return expectedResponse, nil }) resp, err := prxy(context.Background(), &Request{ Body: io.NopCloser(strings.NewReader(`{ "name": "foo", "dob": "bar" }`)), Params: map[string]string{}, Headers: map[string][]string{}, }) if err != nil { t.Error(err) return } if !reflect.DeepEqual(resp, expectedResponse) { t.Errorf("unexpected response: %v", resp) } } func TestNewGraphQLMiddleware_query(t *testing.T) { query := "{ q(func: uid(1)) { uid } }" mw := NewGraphQLMiddleware( logging.NoOp, &config.Backend{ ExtraConfig: config.ExtraConfig{ graphql.Namespace: map[string]interface{}{ "method": "get", "type": "query", "query": query, "variables": map[string]interface{}{ "name": "{foo}", "dob": "{bar}", "posts": []interface{}{}, }, }, }, }, ) expectedResponse := &Response{Data: map[string]interface{}{"foo": "bar"}} prxy := mw(func(ctx context.Context, req *Request) (*Response, error) { request := graphql.GraphQLRequest{ Query: req.Query.Get("query"), OperationName: req.Query.Get("operationName"), Variables: map[string]interface{}{}, } json.Unmarshal([]byte(req.Query.Get("variables")), &request.Variables) if request.Query != query { t.Errorf("unexpected query: %s", request.Query) } if len(request.Variables) != 3 { t.Errorf("unexpected variables: %v", request.Variables) } if v, ok := request.Variables["name"].(string); !ok || v != "foo" { t.Errorf("unexpected var name: %v", request.Variables["name"]) } if v, ok := request.Variables["dob"].(string); !ok || v != "bar" { t.Errorf("unexpected var dob: %v", request.Variables["dob"]) } return expectedResponse, nil }) resp, err := 
prxy(context.Background(), &Request{ Params: map[string]string{ "Foo": "foo", "Bar": "bar", }, Body: io.NopCloser(strings.NewReader("")), Headers: map[string][]string{}, }) if err != nil { t.Error(err) return } if !reflect.DeepEqual(resp, expectedResponse) { t.Errorf("unexpected response: %v", resp) } } ================================================ FILE: proxy/headers_filter.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) // NewFilterHeadersMiddleware returns a middleware with or without a header filtering // proxy wrapping the next element (depending on the configuration). func NewFilterHeadersMiddleware(logger logging.Logger, remote *config.Backend) Middleware { if len(remote.HeadersToPass) == 0 { return emptyMiddlewareFallback(logger) } return func(next ...Proxy) Proxy { if len(next) > 1 { logger.Fatal("too many proxies for this %s %s -> %s proxy middleware: NewFilterHeadersMiddleware only accepts 1 proxy, got %d", remote.ParentEndpointMethod, remote.ParentEndpoint, remote.URLPattern, len(next)) return nil } nextProxy := next[0] return func(ctx context.Context, request *Request) (*Response, error) { if len(request.Headers) == 0 { return nextProxy(ctx, request) } numHeadersToPass := 0 for _, v := range remote.HeadersToPass { if _, ok := request.Headers[v]; ok { numHeadersToPass++ } } if numHeadersToPass == len(request.Headers) { // all the headers should pass, no need to clone the headers return nextProxy(ctx, request) } // ATTENTION: this is not a clone of headers! // this just filters the headers we do not want to send: // issues and race conditions could happen the same way as when we // do not filter the headers. 
This is a design decission, and if we // want to clone the header values (because of write modifications), // that should be done at an upper level (so the approach is the same // for non filtered parallel requests). newHeaders := make(map[string][]string, numHeadersToPass) for _, v := range remote.HeadersToPass { if values, ok := request.Headers[v]; ok { newHeaders[v] = values } } return nextProxy(ctx, &Request{ Method: request.Method, URL: request.URL, Query: request.Query, Path: request.Path, Body: request.Body, Params: request.Params, Headers: newHeaders, }) } } } ================================================ FILE: proxy/headers_filter_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "testing" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) func TestNewFilterHeadersMiddleware(t *testing.T) { mw := NewFilterHeadersMiddleware( logging.NoOp, &config.Backend{ HeadersToPass: []string{ "X-This-Shall-Pass", "X-Gandalf-Will-Pass", }, }, ) var receivedReq *Request prxy := mw(func(ctx context.Context, req *Request) (*Response, error) { receivedReq = req return nil, nil }) sentReq := &Request{ Body: nil, Params: map[string]string{}, Headers: map[string][]string{ "X-This-Shall-Pass": []string{"tupu", "supu"}, "X-You-Shall-Not-Pass": []string{"Balrog"}, "X-Gandalf-Will-Pass": []string{"White", "Grey"}, "X-Drop-Tables": []string{"foo"}, }, } prxy(context.Background(), sentReq) if receivedReq == sentReq { t.Errorf("request should be different") return } if _, ok := receivedReq.Headers["X-This-Shall-Pass"]; !ok { t.Errorf("missing X-This-Shall-Pass") return } if _, ok := receivedReq.Headers["X-Gandalf-Will-Pass"]; !ok { t.Errorf("missing X-Gandalf-Will-Pass") return } if _, ok := receivedReq.Headers["X-Drop-Tables"]; ok { t.Errorf("should not be there X-Drop-Tables") return } if _, ok := receivedReq.Headers["X-You-Shall-Not-Pass"]; ok { t.Errorf("should not be 
there X-You-Shall-Not-Pass") return } // check that when headers are the expected, no need to copy sentReq = &Request{ Body: nil, Params: map[string]string{}, Headers: map[string][]string{ "X-This-Shall-Pass": []string{"tupu", "supu"}, }, } prxy(context.Background(), sentReq) if receivedReq != sentReq { t.Errorf("request should be the same, no modification of headers expected") return } } func TestNewFilterHeadersMiddlewareBlockAll(t *testing.T) { mw := NewFilterHeadersMiddleware( logging.NoOp, &config.Backend{ HeadersToPass: []string{""}, }, ) var receivedReq *Request prxy := mw(func(ctx context.Context, req *Request) (*Response, error) { receivedReq = req return nil, nil }) sentReq := &Request{ Body: nil, Params: map[string]string{}, Headers: map[string][]string{ "X-This-Shall-Pass": []string{"tupu", "supu"}, "X-You-Shall-Not-Pass": []string{"Balrog"}, }, } prxy(context.Background(), sentReq) if receivedReq == sentReq { t.Errorf("request should be different") return } if len(receivedReq.Headers) != 0 { t.Errorf("should have blocked all headers") return } } func TestNewFilterHeadersMiddlewareAllowAll(t *testing.T) { mw := NewFilterHeadersMiddleware( logging.NoOp, &config.Backend{ HeadersToPass: []string{}, }, ) var receivedReq *Request prxy := mw(func(ctx context.Context, req *Request) (*Response, error) { receivedReq = req return nil, nil }) sentReq := &Request{ Body: nil, Params: map[string]string{}, Headers: map[string][]string{ "X-This-Shall-Pass": []string{"tupu", "supu"}, "X-You-Shall-Not-Pass": []string{"Balrog"}, }, } prxy(context.Background(), sentReq) if len(receivedReq.Headers) != 2 { t.Errorf("should have let pass all headers") return } } ================================================ FILE: proxy/http.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "bytes" "context" "fmt" "io" "net/http" "strconv" "strings" "github.com/luraproject/lura/v2/config" 
"github.com/luraproject/lura/v2/encoding" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/transport/http/client" ) var httpProxy = CustomHTTPProxyFactory(client.NewHTTPClient) // HTTPProxyFactory returns a BackendFactory. The Proxies it creates will use the received net/http.Client func HTTPProxyFactory(client *http.Client) BackendFactory { return CustomHTTPProxyFactory(func(_ context.Context) *http.Client { return client }) } // CustomHTTPProxyFactory returns a BackendFactory. The Proxies it creates will use the received HTTPClientFactory func CustomHTTPProxyFactory(cf client.HTTPClientFactory) BackendFactory { return func(backend *config.Backend) Proxy { return NewHTTPProxy(backend, cf, backend.Decoder) } } // NewHTTPProxy creates a http proxy with the injected configuration, HTTPClientFactory and Decoder func NewHTTPProxy(remote *config.Backend, cf client.HTTPClientFactory, decode encoding.Decoder) Proxy { return NewHTTPProxyWithHTTPExecutor(remote, client.DefaultHTTPRequestExecutor(cf), decode) } // NewHTTPProxyWithHTTPExecutor creates a http proxy with the injected configuration, HTTPRequestExecutor and Decoder func NewHTTPProxyWithHTTPExecutor(remote *config.Backend, re client.HTTPRequestExecutor, dec encoding.Decoder) Proxy { if remote.Encoding == encoding.NOOP { return NewHTTPProxyDetailed(remote, re, client.NoOpHTTPStatusHandler, NoOpHTTPResponseParser) } ef := NewEntityFormatter(remote) rp := DefaultHTTPResponseParserFactory(HTTPResponseParserConfig{dec, ef}) return NewHTTPProxyDetailed(remote, re, client.GetHTTPStatusHandler(remote), rp) } const ( clientHTTPOptions string = "backend/http/client" clientHTTPOptionRedirectPost string = "send_body_on_redirect" ) // redirectPostReaderFactory checks if the clientHTTPOptionRedirectPost is enabled // This will read the body and return a bytes.Buffer with the body content, so we // delegate to http.NewRequest the population of request.GetBody so a redirect (307 // and 308) is executed 
while maintaining the method and the body // This is necessary since the request comes from another http.Client and it's not // a concrete type that can be copied but just a io.ReaderCloser (*http.body) func redirectPostReaderFactory(cfg *config.Backend) func(r io.ReadCloser) io.Reader { emptyFactory := func(r io.ReadCloser) io.Reader { return r } if cfg == nil || cfg.ExtraConfig == nil { return emptyFactory } v, ok := cfg.ExtraConfig[clientHTTPOptions].(map[string]interface{}) if !ok { return emptyFactory } if opt, ok := v[clientHTTPOptionRedirectPost].(bool); !ok || !opt { return emptyFactory } return func(r io.ReadCloser) io.Reader { if r == http.NoBody || r == nil { return r } buf := new(bytes.Buffer) buf.ReadFrom(r) r.Close() return buf } } // NewHTTPProxyDetailed creates a http proxy with the injected configuration, HTTPRequestExecutor, // Decoder and HTTPResponseParser func NewHTTPProxyDetailed(cfg *config.Backend, re client.HTTPRequestExecutor, ch client.HTTPStatusHandler, rp HTTPResponseParser) Proxy { bodyFactory := redirectPostReaderFactory(cfg) return func(ctx context.Context, request *Request) (*Response, error) { requestToBackend, err := http.NewRequest(strings.ToTitle(request.Method), request.URL.String(), bodyFactory(request.Body)) if err != nil { return nil, err } requestToBackend.Header = make(map[string][]string, len(request.Headers)) for k, vs := range request.Headers { tmp := make([]string, len(vs)) copy(tmp, vs) requestToBackend.Header[k] = tmp } if request.Body != nil { if v, ok := request.Headers["Content-Length"]; ok && len(v) == 1 && v[0] != "chunked" { if size, err := strconv.Atoi(v[0]); err == nil { requestToBackend.ContentLength = int64(size) } } } resp, err := re(ctx, requestToBackend) if requestToBackend.Body != nil { requestToBackend.Body.Close() } select { case <-ctx.Done(): return nil, ctx.Err() default: } if err != nil { return nil, err } resp, err = ch(ctx, resp) if err != nil { if t, ok := err.(responseError); ok { return 
&Response{ Data: map[string]interface{}{ fmt.Sprintf("error_%s", t.Name()): t, }, Metadata: Metadata{StatusCode: t.StatusCode()}, }, nil } return nil, err } return rp(ctx, resp) } } // NewRequestBuilderMiddleware creates a proxy middleware that parses the request params received // from the outer layer and generates the path to the backend endpoints var NewRequestBuilderMiddleware = func(remote *config.Backend) Middleware { return newRequestBuilderMiddleware(logging.NoOp, remote) } func NewRequestBuilderMiddlewareWithLogger(logger logging.Logger, remote *config.Backend) Middleware { return newRequestBuilderMiddleware(logger, remote) } func newRequestBuilderMiddleware(l logging.Logger, remote *config.Backend) Middleware { return func(next ...Proxy) Proxy { if len(next) > 1 { l.Fatal("too many proxies for this %s %s -> %s proxy middleware: newRequestBuilderMiddleware only accepts 1 proxy, got %d", remote.ParentEndpointMethod, remote.ParentEndpoint, remote.URLPattern, len(next)) return nil } return func(ctx context.Context, r *Request) (*Response, error) { r.GeneratePath(remote.URLPattern) r.Method = remote.Method return next[0](ctx, r) } } } type responseError interface { Error() string Name() string StatusCode() int } ================================================ FILE: proxy/http_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "testing" "github.com/luraproject/lura/v2/config" ) func BenchmarkNewRequestBuilderMiddleware(b *testing.B) { backend := config.Backend{ URLPattern: "/supu", Method: "GET", } proxy := NewRequestBuilderMiddleware(&backend)(dummyProxy(&Response{})) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { proxy(context.Background(), &Request{}) } } ================================================ FILE: proxy/http_response.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "compress/gzip" 
"context" "io" "net/http" "github.com/luraproject/lura/v2/encoding" ) // HTTPResponseParser defines how the response is parsed from http.Response to Response object type HTTPResponseParser func(context.Context, *http.Response) (*Response, error) // DefaultHTTPResponseParserConfig defines a default HTTPResponseParserConfig var DefaultHTTPResponseParserConfig = HTTPResponseParserConfig{ func(_ io.Reader, _ *map[string]interface{}) error { return nil }, EntityFormatterFunc(func(r Response) Response { return r }), } // HTTPResponseParserConfig contains the config for a given HttpResponseParser type HTTPResponseParserConfig struct { Decoder encoding.Decoder EntityFormatter EntityFormatter } // HTTPResponseParserFactory creates HTTPResponseParser from a given HTTPResponseParserConfig type HTTPResponseParserFactory func(HTTPResponseParserConfig) HTTPResponseParser // DefaultHTTPResponseParserFactory is the default implementation of HTTPResponseParserFactory func DefaultHTTPResponseParserFactory(cfg HTTPResponseParserConfig) HTTPResponseParser { return func(_ context.Context, resp *http.Response) (*Response, error) { defer resp.Body.Close() var reader io.ReadCloser switch resp.Header.Get("Content-Encoding") { case "gzip": gzipReader, err := gzip.NewReader(resp.Body) if err != nil { return nil, err } reader = gzipReader defer reader.Close() default: reader = resp.Body } var data map[string]interface{} if err := cfg.Decoder(reader, &data); err != nil { return nil, err } newResponse := Response{Data: data, IsComplete: true} newResponse = cfg.EntityFormatter.Format(newResponse) return &newResponse, nil } } // NoOpHTTPResponseParser is a HTTPResponseParser implementation that just copies the // http response body into the proxy response IO func NoOpHTTPResponseParser(ctx context.Context, resp *http.Response) (*Response, error) { return &Response{ Data: map[string]interface{}{}, IsComplete: true, Io: NewReadCloserWrapper(ctx, resp.Body), Metadata: Metadata{ StatusCode: 
resp.StatusCode, Headers: resp.Header, }, }, nil } ================================================ FILE: proxy/http_response_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "compress/gzip" "context" "io" "net/http" "net/http/httptest" "testing" "github.com/luraproject/lura/v2/encoding" ) func TestNopHTTPResponseParser(t *testing.T) { w := httptest.NewRecorder() handler := func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("header1", "value1") w.Write([]byte("some nice, interesting and long content")) } req, _ := http.NewRequest("GET", "/url", http.NoBody) handler(w, req) result, err := NoOpHTTPResponseParser(context.Background(), w.Result()) if err != nil { t.Error(err.Error()) return } if !result.IsComplete { t.Error("unexpected result") } if len(result.Data) != 0 { t.Error("unexpected result") } if result.Metadata.StatusCode != http.StatusOK { t.Error("unexpected result") } headers := result.Metadata.Headers if h, ok := headers["Header1"]; !ok || h[0] != "value1" { t.Error("unexpected result:", result.Metadata.Headers) } body, err := io.ReadAll(result.Io) if err != nil { t.Error("unexpected error:", err.Error()) } if string(body) != "some nice, interesting and long content" { t.Error("unexpected result") } } func TestDefaultHTTPResponseParser_gzipped(t *testing.T) { w := httptest.NewRecorder() handler := func(w http.ResponseWriter, _ *http.Request) { gzipWriter, _ := gzip.NewWriterLevel(w, gzip.BestSpeed) defer gzipWriter.Close() w.Header().Set("Vary", "Accept-Encoding") w.Header().Set("Content-Encoding", "gzip") w.Header().Set("Content-Type", "application/json; charset=utf-8") gzipWriter.Write([]byte(`{"msg":"some nice, interesting and long content"}`)) gzipWriter.Flush() } req, _ := http.NewRequest("GET", "/url", http.NoBody) req.Header.Add("Accept-Encoding", "gzip") handler(w, req) result, err := DefaultHTTPResponseParserFactory(HTTPResponseParserConfig{ Decoder: 
encoding.JSONDecoder, EntityFormatter: DefaultHTTPResponseParserConfig.EntityFormatter, })(context.Background(), w.Result()) if err != nil { t.Error(err) } if !result.IsComplete { t.Error("unexpected result") } if len(result.Data) != 1 { t.Error("unexpected result") } if m, ok := result.Data["msg"]; !ok || m != "some nice, interesting and long content" { t.Error("unexpected result") } } func TestDefaultHTTPResponseParser_gzipped_bad_header(t *testing.T) { w := httptest.NewRecorder() handler := func(w http.ResponseWriter, _ *http.Request) { gzipWriter, _ := gzip.NewWriterLevel(w, gzip.BestSpeed) defer gzipWriter.Close() w.Header().Set("Vary", "Accept-Encoding") w.Header().Set("Content-Encoding", "gzip") w.Header().Set("Content-Type", "application/json; charset=utf-8") gzipWriter.Write([]byte(`{"msg":"some nice, interesting and long content"}`)) gzipWriter.Flush() } req, _ := http.NewRequest("GET", "/url", http.NoBody) // explicitly disable gzip encoding req.Header.Add("Accept-Encoding", "identity;q=0") handler(w, req) result, err := DefaultHTTPResponseParserFactory(HTTPResponseParserConfig{ Decoder: encoding.JSONDecoder, EntityFormatter: DefaultHTTPResponseParserConfig.EntityFormatter, })(context.Background(), w.Result()) if err != nil { t.Error(err) } if !result.IsComplete { t.Error("unexpected result") } if len(result.Data) != 1 { t.Error("unexpected result") } if m, ok := result.Data["msg"]; !ok || m != "some nice, interesting and long content" { t.Error("unexpected result") } } func TestDefaultHTTPResponseParser_plain(t *testing.T) { w := httptest.NewRecorder() handler := func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Write([]byte(`{"msg":"some nice, interesting and long content"}`)) } req, _ := http.NewRequest("GET", "/url", http.NoBody) handler(w, req) result, err := DefaultHTTPResponseParserFactory(HTTPResponseParserConfig{ Decoder: encoding.JSONDecoder, EntityFormatter: 
DefaultHTTPResponseParserConfig.EntityFormatter, })(context.Background(), w.Result()) if err != nil { t.Error(err) } if !result.IsComplete { t.Error("unexpected result") } if len(result.Data) != 1 { t.Error("unexpected result") } if m, ok := result.Data["msg"]; !ok || m != "some nice, interesting and long content" { t.Error("unexpected result") } } ================================================ FILE: proxy/http_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httptest" "net/url" "strings" "sync/atomic" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/encoding" "github.com/luraproject/lura/v2/transport/http/client" ) func TestNewHTTPProxy_ok(t *testing.T) { expectedMethod := "GET" backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.ContentLength != 11 { t.Errorf("unexpected request size. Want: 11. Have: %d", r.ContentLength) } if h := r.Header.Get("Content-Length"); h != "11" { t.Errorf("unexpected content-length header. Want: 11. Have: %s", h) } if r.Method != expectedMethod { t.Errorf("Wrong request method. Want: %s. 
Have: %s", expectedMethod, r.Method) } if h := r.Header.Get("X-First"); h != "first" { t.Errorf("unexpected first header: %s", h) } if h := r.Header.Get("X-Second"); h != "second" { t.Errorf("unexpected second header: %s", h) } r.Header.Del("X-Second") fmt.Fprintf(w, "{\"supu\":42, \"tupu\":true, \"foo\": \"bar\"}") })) defer backendServer.Close() rpURL, _ := url.Parse(backendServer.URL) backend := config.Backend{ Decoder: encoding.JSONDecoder, } request := Request{ Method: expectedMethod, Path: "/", URL: rpURL, Body: newDummyReadCloser(`{"abc": 42}`), Headers: map[string][]string{ "X-First": {"first"}, "X-Second": {"second"}, "Content-Length": {"11"}, }, } mustEnd := time.After(time.Duration(150) * time.Millisecond) result, err := HTTPProxyFactory(http.DefaultClient)(&backend)(context.Background(), &request) if err != nil { t.Errorf("The proxy returned an unexpected error: %s\n", err.Error()) return } if result == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("Error: expected response") return default: } tmp, ok := result.Data["supu"] if !ok { t.Errorf("The proxy returned an unexpected result: %v\n", result) } supuValue, err := tmp.(json.Number).Int64() if err != nil || supuValue != 42 { t.Errorf("The proxy returned an unexpected result: %v\n", supuValue) } if v, ok := result.Data["tupu"]; !ok || !v.(bool) { t.Errorf("The proxy returned an unexpected result: %v\n", result) } if v, ok := result.Data["foo"]; !ok || v.(string) != "bar" { t.Errorf("The proxy returned an unexpected result: %v\n", result) } if v, ok := request.Headers["X-Second"]; !ok || len(v) != 1 { t.Errorf("the proxy request headers were changed: %v", request.Headers) } } func TestNewHTTPProxy_cancel(t *testing.T) { expectedMethod := "GET" backendServer := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { time.Sleep(time.Duration(300) * time.Millisecond) })) defer backendServer.Close() rpURL, _ :=
url.Parse(backendServer.URL) backend := config.Backend{ Decoder: encoding.JSONDecoder, } request := Request{ Method: expectedMethod, Path: "/", URL: rpURL, Body: newDummyReadCloser(""), } mustEnd := time.After(time.Duration(150) * time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(10)*time.Millisecond) defer cancel() response, err := httpProxy(&backend)(ctx, &request) if err == nil || err.Error() != "context deadline exceeded" { t.Errorf("The proxy didn't propagate a timeout error: %s\n", err) } if response != nil { t.Errorf("We weren't expecting a response but we got one: %v\n", response) return } select { case <-mustEnd: t.Errorf("We were expecting a response at this point in time!\n") return default: } } func TestNewHTTPProxy_badResponseBody(t *testing.T) { expectedMethod := "GET" backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { fmt.Fprintf(w, "supu") })) defer backendServer.Close() rpURL, _ := url.Parse(backendServer.URL) backend := config.Backend{ Decoder: encoding.JSONDecoder, } request := Request{ Method: expectedMethod, Path: "/", URL: rpURL, Body: newDummyReadCloser(""), } mustEnd := time.After(time.Duration(150) * time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(10)*time.Millisecond) defer cancel() response, err := httpProxy(&backend)(ctx, &request) if err == nil || err.Error() != "invalid character 's' looking for beginning of value" { t.Errorf("The proxy didn't propagate the backend error: %s\n", err) } if response != nil { t.Errorf("We weren't expecting a response but we got one: %v\n", response) } select { case <-mustEnd: t.Errorf("Error: expected response") default: } } func TestNewHTTPProxy_badStatusCode(t *testing.T) { expectedMethod := "GET" backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "booom", 500) })) defer backendServer.Close() rpURL, _ :=
url.Parse(backendServer.URL) backend := config.Backend{ Decoder: encoding.JSONDecoder, } request := Request{ Method: expectedMethod, Path: "/", URL: rpURL, Body: newDummyReadCloser(""), } mustEnd := time.After(time.Duration(150) * time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(10)*time.Millisecond) defer cancel() response, err := httpProxy(&backend)(ctx, &request) if err == nil || !strings.HasPrefix(err.Error(), "invalid status code") { t.Errorf("The proxy didn't propagate the backend error: %s\n", err) } if response != nil { t.Errorf("We weren't expecting a response but we got one: %v\n", response) } select { case <-mustEnd: t.Errorf("Error: expected response") default: } } func TestNewHTTPProxy_badStatusCode_detailed(t *testing.T) { expectedMethod := "GET" backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { http.Error(w, "booom", 500) })) defer backendServer.Close() rpURL, _ := url.Parse(backendServer.URL) backend := config.Backend{ Decoder: encoding.JSONDecoder, ExtraConfig: config.ExtraConfig{ client.Namespace: map[string]interface{}{ "return_error_details": "some", }, }, } request := Request{ Method: expectedMethod, Path: "/", URL: rpURL, Body: newDummyReadCloser(""), } mustEnd := time.After(time.Duration(150) * time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(10)*time.Millisecond) defer cancel() response, err := httpProxy(&backend)(ctx, &request) if err != nil { t.Errorf("The proxy propagated the backend error: %s", err.Error()) } if response == nil { t.Error("We were expecting a response but we got none") return } if response.Metadata.StatusCode != 500 { t.Errorf("unexpected error code: %d", response.Metadata.StatusCode) } b, _ := json.Marshal(response.Data) if string(b) != `{"error_some":{"http_status_code":500,"http_body":"booom\n","http_body_encoding":"text/plain; charset=utf-8"}}` { t.Errorf("unexpected response content: %s",
string(b)) } select { case <-mustEnd: t.Errorf("Error: expected response") default: } } func TestNewHTTPProxy_decodingError(t *testing.T) { expectedMethod := "GET" backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { fmt.Fprintf(w, `{"supu": 42}`) })) defer backendServer.Close() rpURL, _ := url.Parse(backendServer.URL) backend := config.Backend{ Decoder: func(_ io.Reader, _ *map[string]interface{}) error { return errors.New("booom") }, } request := Request{ Method: expectedMethod, Path: "/", URL: rpURL, Body: newDummyReadCloser(""), } mustEnd := time.After(time.Duration(150) * time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(10)*time.Millisecond) defer cancel() response, err := httpProxy(&backend)(ctx, &request) if err == nil || err.Error() != "booom" { t.Errorf("The proxy returned an unexpected error: %v\n", err) } if response != nil { t.Errorf("We weren't expecting a response but we got one: %v\n", response) } select { case <-mustEnd: t.Errorf("Error: expected response") default: } } func TestNewHTTPProxy_badMethod(t *testing.T) { backendServer := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { t.Error("The handler shouldn't be called") })) defer backendServer.Close() rpURL, _ := url.Parse(backendServer.URL) backend := config.Backend{ Decoder: func(_ io.Reader, _ *map[string]interface{}) error { t.Error("The decoder shouldn't be called") return nil }, } request := Request{ Method: "\n", Path: "/", URL: rpURL, Body: newDummyReadCloser(""), } mustEnd := time.After(time.Duration(150) * time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(10)*time.Millisecond) defer cancel() _, err := httpProxy(&backend)(ctx, &request) if err == nil { t.Error("The proxy didn't return the expected error") return } if err.Error() != "net/http: invalid method \"\\n\"" { t.Errorf("The proxy returned an unexpected error:
%s\n", err.Error()) return } select { case <-mustEnd: t.Errorf("Error: expected response") default: } } func TestNewHTTPProxy_requestKo(t *testing.T) { backendServer := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { t.Error("The handler shouldn't be called") })) defer backendServer.Close() rpURL, _ := url.Parse(backendServer.URL) backend := config.Backend{ Decoder: func(_ io.Reader, _ *map[string]interface{}) error { t.Error("The decoder shouldn't be called") return nil }, } request := Request{ Method: "GET", Path: "/", URL: rpURL, Body: newDummyReadCloser(""), } mustEnd := time.After(time.Duration(150) * time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(100)*time.Millisecond) defer cancel() expectedError := fmt.Errorf("MAYDAY, MAYDAY") _, err := NewHTTPProxyWithHTTPExecutor(&backend, func(_ context.Context, _ *http.Request) (*http.Response, error) { return nil, expectedError }, backend.Decoder)(ctx, &request) if err == nil { t.Error("The proxy didn't return the expected error") return } if err != expectedError { t.Errorf("The proxy returned an unexpected error: %s\n", err.Error()) return } select { case <-mustEnd: t.Errorf("Error: expected response") default: } } func TestNewRequestBuilderMiddleware_ok(t *testing.T) { expected := errors.New("error to be propagated") expectedMethod := "GET" expectedPath := "/supu" assertion := func(_ context.Context, request *Request) (*Response, error) { if request.Method != expectedMethod { err := fmt.Errorf("Wrong request method. Want: %s. Have: %s", expectedMethod, request.Method) t.Error(err.Error()) return nil, err } if request.Path != expectedPath { err := fmt.Errorf("Wrong request path. Want: %s.
Have: %s", expectedPath, request.Path) t.Error(err.Error()) return nil, err } return nil, expected } sampleBackend := config.Backend{ URLPattern: expectedPath, Method: expectedMethod, } mw := NewRequestBuilderMiddleware(&sampleBackend) response, err := mw(assertion)(context.Background(), &Request{}) if err != expected { t.Errorf("The middleware propagated an unexpected error: %v\n", err) } if response != nil { t.Errorf("We weren't expecting a response but we got one: %v\n", response) } } func TestDefaultHTTPResponseParserConfig_nopDecoder(t *testing.T) { result := map[string]interface{}{} if err := DefaultHTTPResponseParserConfig.Decoder(bytes.NewBufferString("some body"), &result); err != nil { t.Error(err.Error()) } if len(result) != 0 { t.Error("unexpected result") } } func TestDefaultHTTPResponseParserConfig_nopEntityFormatter(t *testing.T) { expected := Response{Data: map[string]interface{}{"supu": "tupu"}, IsComplete: true} result := DefaultHTTPResponseParserConfig.EntityFormatter.Format(expected) if !result.IsComplete { t.Error("unexpected result") } d, ok := result.Data["supu"] if !ok { t.Error("unexpected result") } if v, ok := d.(string); !ok || v != "tupu" { t.Error("unexpected result") } } func TestNewHTTPProxy_noopDecoder(t *testing.T) { expectedcontent := "some nice, interesting and long content" backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("header1", "value1") w.WriteHeader(http.StatusOK) w.Write([]byte(expectedcontent)) })) defer backendServer.Close() rpURL, _ := url.Parse(backendServer.URL) backend := config.Backend{ Encoding: encoding.NOOP, Decoder: encoding.NoOpDecoder, } request := Request{ Method: "GET", Path: "/", URL: rpURL, Body: newDummyReadCloser(""), } mustEnd := time.After(time.Duration(150) * time.Millisecond) result, err := HTTPProxyFactory(http.DefaultClient)(&backend)(context.Background(), &request) if err != nil { t.Errorf("The proxy returned an unexpected
error: %s\n", err.Error()) return } if result == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("Error: expected response") return default: } if len(result.Data) > 0 { t.Error("unexpected data:", result.Data) return } if result.Metadata.StatusCode != http.StatusOK { t.Error("unexpected status code:", result.Metadata.StatusCode) return } if len(result.Metadata.Headers["Header1"]) < 1 || result.Metadata.Headers["Header1"][0] != "value1" { t.Error("unexpected header:", result.Metadata.Headers) return } b := &bytes.Buffer{} if _, err := b.ReadFrom(result.Io); err != nil { t.Error(err, b.String()) return } if content := b.String(); content != expectedcontent { t.Error("unexpected content:", content) } } func TestNewHTTPProxy_redirectWithBody(t *testing.T) { var executed atomic.Uint64 expectedBody := `{"message":"redirected"}` expectedResponse := `{"message":"ok"}` backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { data, err := io.ReadAll(r.Body) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } if string(data) != expectedBody { t.Errorf("invalid data: %s", string(data)) return } executed.Add(1) w.Write([]byte(`{"message":"ok"}`)) })) defer backendServer.Close() redirServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { executed.Add(1) http.Redirect(w, r, backendServer.URL, http.StatusPermanentRedirect) })) defer redirServer.Close() rpURL, _ := url.Parse(redirServer.URL) backend := config.Backend{ Decoder: encoding.JSONDecoder, ExtraConfig: map[string]interface{}{ clientHTTPOptions: map[string]interface{}{ clientHTTPOptionRedirectPost: true, }, }, } request := Request{ Method: "POST", Path: "/", URL: rpURL, Body: newDummyReadCloser(expectedBody), } mustEnd := time.After(time.Duration(150) * time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(100)*time.Millisecond) defer cancel() 
response, err := httpProxy(&backend)(ctx, &request) if err != nil { t.Errorf("The proxy propagated a backend errors: %s\n", err) return } respData, err := json.Marshal(response.Data) if err != nil { t.Error(err) return } if string(respData) != expectedResponse { t.Errorf("unexpected response data: '%s'", string(respData)) } select { case <-mustEnd: t.Errorf("Error: expected response") default: } if executed.Load() != 2 { t.Errorf("number of executions should be 2 not %d", executed.Load()) } } ================================================ FILE: proxy/logging.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "strings" "time" "github.com/luraproject/lura/v2/logging" ) // NewLoggingMiddleware creates proxy middleware for logging requests and responses func NewLoggingMiddleware(logger logging.Logger, name string) Middleware { logPrefix := "[" + strings.ToUpper(name) + "]" return func(next ...Proxy) Proxy { if len(next) > 1 { logger.Fatal("too many proxies for this proxy middleware: NewLoggingMiddleware only accepts 1 proxy, got %d", len(next)) return nil } return func(ctx context.Context, request *Request) (*Response, error) { begin := time.Now() logger.Info(logPrefix, "Calling backend") logger.Debug(logPrefix, "Request", request) result, err := next[0](ctx, request) logger.Info(logPrefix, "Call to backend took", time.Since(begin).String()) if err != nil { logger.Warning(logPrefix, "Call to backend failed:", err.Error()) return result, err } if result == nil { logger.Warning(logPrefix, "Call to backend returned a null response") } return result, err } } } ================================================ FILE: proxy/logging_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "bytes" "context" "fmt" "strings" "testing" "github.com/luraproject/lura/v2/logging" ) func TestNewLoggingMiddleware_ok(t *testing.T) { buff := 
bytes.NewBuffer(make([]byte, 1024)) logger, _ := logging.NewLogger("DEBUG", buff, "pref") resp := &Response{IsComplete: true} mw := NewLoggingMiddleware(logger, "supu") p := mw(dummyProxy(resp)) r, err := p(context.Background(), &Request{}) if r != resp { t.Error("The proxy didn't return the expected response") return } if err != nil { t.Errorf("The proxy returned an unexpected error: %s", err.Error()) return } logMsg := buff.String() if strings.Count(logMsg, "pref") != 3 { t.Error("The logs don't have the injected prefix") } if strings.Count(logMsg, "INFO") != 2 { t.Error("The logs don't have the expected INFO messages") } if strings.Count(logMsg, "DEBU") != 1 { t.Error("The logs don't have the expected DEBUG messages") } if !strings.Contains(logMsg, "[SUPU] Calling backend") { t.Error("The logs didn't mark the start of the execution") } if !strings.Contains(logMsg, "[SUPU] Call to backend took") { t.Error("The logs didn't mark the end of the execution") } } func TestNewLoggingMiddleware_erroredResponse(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, _ := logging.NewLogger("DEBUG", buff, "pref") resp := &Response{IsComplete: true} mw := NewLoggingMiddleware(logger, "supu") expextedError := fmt.Errorf("NO-body expects the %s Inquisition!", "Spanish") p := mw(func(_ context.Context, _ *Request) (*Response, error) { return resp, expextedError }) r, err := p(context.Background(), &Request{}) if r != resp { t.Error("The proxy didn't return the expected response") return } if err != expextedError { t.Errorf("The proxy didn't return the expected error: %s", err.Error()) return } logMsg := buff.String() if strings.Count(logMsg, "pref") != 4 { t.Error("The logs don't have the injected prefix") } if strings.Count(logMsg, "INFO") != 2 { t.Error("The logs don't have the expected INFO messages") } if strings.Count(logMsg, "DEBU") != 1 { t.Error("The logs don't have the expected DEBUG messages") } if strings.Count(logMsg, "WARN") != 1 { t.Error("The logs 
don't have the expected DEBUG messages") } if !strings.Contains(logMsg, "[SUPU] Call to backend failed: NO-body expects the Spanish Inquisition!") { t.Error("The logs didn't mark the fail of the execution") } if !strings.Contains(logMsg, "[SUPU] Calling backend") { t.Error("The logs didn't mark the start of the execution") } if !strings.Contains(logMsg, "[SUPU] Call to backend took") { t.Error("The logs didn't mark the end of the execution") } } func TestNewLoggingMiddleware_nullResponse(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, _ := logging.NewLogger("DEBUG", buff, "pref") mw := NewLoggingMiddleware(logger, "supu") p := mw(dummyProxy(nil)) r, err := p(context.Background(), &Request{}) if r != nil { t.Error("The proxy didn't return the expected response") return } if err != nil { t.Errorf("The proxy returned an unexpected error: %s", err.Error()) return } logMsg := buff.String() if strings.Count(logMsg, "pref") != 4 { t.Error("The logs don't have the injected prefix") } if strings.Count(logMsg, "INFO") != 2 { t.Error("The logs don't have the expected INFO messages") } if strings.Count(logMsg, "DEBU") != 1 { t.Error("The logs don't have the expected DEBUG messages") } if strings.Count(logMsg, "WARN") != 1 { t.Error("The logs don't have the expected DEBUG messages") } if !strings.Contains(logMsg, "[SUPU] Call to backend returned a null response") { t.Error("The logs didn't mark the fail of the execution") } if !strings.Contains(logMsg, "[SUPU] Calling backend") { t.Error("The logs didn't mark the start of the execution") } if !strings.Contains(logMsg, "[SUPU] Call to backend took") { t.Error("The logs didn't mark the end of the execution") } } ================================================ FILE: proxy/merging.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "fmt" "io" "net/http" "regexp" "strconv" "strings" "time" "github.com/luraproject/lura/v2/config" 
"github.com/luraproject/lura/v2/logging" ) // NewMergeDataMiddleware creates proxy middleware for merging responses from several backends func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.EndpointConfig) Middleware { // skipcq: GO-R1005 totalBackends := len(endpointConfig.Backend) if totalBackends == 0 { logger.Fatal("all endpoints must have at least one backend: NewMergeDataMiddleware") return nil } if totalBackends == 1 { return emptyMiddlewareFallback(logger) } serviceTimeout := time.Duration(85*endpointConfig.Timeout.Nanoseconds()/100) * time.Nanosecond combiner := getResponseCombiner(endpointConfig.ExtraConfig) isSequential, sequentialReplacements := sequentialMergerConfig(endpointConfig) logger.Debug( fmt.Sprintf( "[ENDPOINT: %s][Merge] Backends: %d, sequential: %t, combiner: %s", endpointConfig.Endpoint, totalBackends, isSequential, getResponseCombinerName(endpointConfig.ExtraConfig), ), ) bfFactory := backendFiltererFactory.filtererFactory return func(next ...Proxy) Proxy { if len(next) != totalBackends { // we leave the panic here, because we do not want to continue // if this configuration is wrong, as it would lead to unexpected // behaviour. logger.Fatal("not enough proxies for this endpoint: NewMergeDataMiddleware") return nil } reqClone := func(r *Request) *Request { res := r.Clone(); return &res } filters, err := bfFactory(endpointConfig) if err != nil { logger.Error(fmt.Sprintf("[ENDPOINT: %s]%s %s", endpointConfig.Endpoint, backendFiltererFactory.logPrefix, err)) return func(_ context.Context, _ *Request) (*Response, error) { return nil, err } } if hasUnsafeBackends(endpointConfig) { reqClone = CloneRequest } if !isSequential { return parallelMerge(reqClone, serviceTimeout, combiner, filters, next...) } return sequentialMerge(reqClone, serviceTimeout, combiner, sequentialReplacements, filters, next...) 
} } // BackendFiltererFactory is a factory function that returns a list of BackendFilterer // based on the provided EndpointConfig. // The returned list must be sorted by the backend index. // The list can contain nil values, which means that the backend in that index is untouched. type BackendFiltererFactory func(*config.EndpointConfig) ([]BackendFilterer, error) // BackendFilterer evalutes the request and returns true if the backend should be used, // otherwise the backend is skipped in both normal and sequential merging. // If the backend is skipped, the response will not be merged into the final response. type BackendFilterer func(*Request) bool func defaultBackendFiltererFactory(_ *config.EndpointConfig) ([]BackendFilterer, error) { return []BackendFilterer{}, nil } type backendFiltererRegistry struct { logPrefix string filtererFactory BackendFiltererFactory } var backendFiltererFactory = backendFiltererRegistry{ filtererFactory: defaultBackendFiltererFactory, } // RegisterBackendFiltererFactory registers a new backend filterer factory // to be used by the merging middleware. // This factory is used to create a list of BackendFilterer // functions that will be used to filter backends based on the request. // Important: this function should be called everytime the middleware is created. 
func RegisterBackendFiltererFactory(logPrefix string, f BackendFiltererFactory) { backendFiltererFactory.logPrefix = logPrefix backendFiltererFactory.filtererFactory = f } func ResetBackendFiltererFactory() { backendFiltererFactory.logPrefix = "" backendFiltererFactory.filtererFactory = defaultBackendFiltererFactory } type sequentialBackendReplacement struct { backendIndex int destination string source []string fullResponse bool } func sequentialMergerConfig(cfg *config.EndpointConfig) (bool, [][]sequentialBackendReplacement) { // skipcq: GO-R1005 enabled := false totalBackends := len(cfg.Backend) sequentialReplacements := make([][]sequentialBackendReplacement, totalBackends) var propagatedParams []string if v, ok := cfg.ExtraConfig[Namespace]; ok { if e, ok := v.(map[string]interface{}); ok { if v, ok := e[isSequentialKey]; ok { c, ok := v.(bool) enabled = ok && c } if v, ok := e[sequentialPropagateKey]; ok { if a, ok := v.([]interface{}); ok { for _, p := range a { propagatedParams = append(propagatedParams, p.(string)) } } } } } var rePropagatedParams = regexp.MustCompile(`[Rr]esp(\d+)_?([\w-.]+)?`) var reUrlPatterns = regexp.MustCompile(`\{\{\.Resp(\d+)_([\w-.]+)\}\}`) destKeyGenerator := func(i string, t string) string { key := "Resp" + i if t != "" { key += "_" + t } return key } for i, b := range cfg.Backend { for _, match := range reUrlPatterns.FindAllStringSubmatch(b.URLPattern, -1) { if len(match) > 1 { backendIndex, err := strconv.Atoi(match[1]) if err != nil { continue } sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{ backendIndex: backendIndex, destination: destKeyGenerator(match[1], match[2]), source: strings.Split(match[2], "."), fullResponse: match[2] == "", }) } } if i > 0 { for _, p := range propagatedParams { for _, match := range rePropagatedParams.FindAllStringSubmatch(p, -1) { if len(match) > 1 { backendIndex, err := strconv.Atoi(match[1]) if err != nil || backendIndex >= totalBackends { continue } 
sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{ backendIndex: backendIndex, destination: destKeyGenerator(match[1], match[2]), source: strings.Split(match[2], "."), fullResponse: match[2] == "", }) } } } } } return enabled, sequentialReplacements } func hasUnsafeBackends(cfg *config.EndpointConfig) bool { if len(cfg.Backend) == 1 { return false } for _, b := range cfg.Backend { if m := strings.ToUpper(b.Method); m != http.MethodGet && m != http.MethodHead { return true } } return false } func parallelMerge( reqCloner func(*Request) *Request, timeout time.Duration, rc ResponseCombiner, filters []BackendFilterer, next ...Proxy, ) Proxy { return func(ctx context.Context, request *Request) (*Response, error) { localCtx, cancel := context.WithTimeout(ctx, timeout) proxyCount := len(next) filterCount := len(filters) parts := make(chan *Response, proxyCount) failed := make(chan error, proxyCount) for i, n := range next { if (i < filterCount) && (filters[i] != nil) && !filters[i](request) { proxyCount-- continue } go requestPart(localCtx, n, reqCloner(request), parts, failed) } acc := newIncrementalMergeAccumulator(proxyCount, rc) for i := 0; i < proxyCount; i++ { select { case err := <-failed: acc.Merge(nil, err) case response := <-parts: acc.Merge(response, nil) } } result, err := acc.Result() cancel() return result, err } } func sequentialMerge( // skipcq: GO-R1005 reqCloner func(*Request) *Request, timeout time.Duration, rc ResponseCombiner, sequentialReplacements [][]sequentialBackendReplacement, filters []BackendFilterer, next ...Proxy, ) Proxy { return func(ctx context.Context, request *Request) (*Response, error) { localCtx, cancel := context.WithTimeout(ctx, timeout) filterCount := len(filters) parts := make([]*Response, len(next)) out := make(chan *Response, 1) errCh := make(chan error, 1) sequentialMergeRegistry := map[string]string{} acc := newIncrementalMergeAccumulator(len(next), rc) TxLoop: for i, n := range next { 
if i > 0 { for _, r := range sequentialReplacements[i] { if r.backendIndex >= i || parts[r.backendIndex] == nil { continue } var v interface{} var ok bool data := parts[r.backendIndex].Data if len(r.source) > 1 { for _, k := range r.source[:len(r.source)-1] { v, ok = data[k] if !ok { break } clean, ok := v.(map[string]interface{}) if !ok { break } data = clean } } if found := sequentialMergeRegistry[r.destination]; found != "" { request.Params[r.destination] = found continue } if r.fullResponse { if parts[r.backendIndex].Io == nil { continue } buf, err := io.ReadAll(parts[r.backendIndex].Io) if err == nil { request.Params[r.destination] = string(buf) sequentialMergeRegistry[r.destination] = string(buf) } continue } v, ok = data[r.source[len(r.source)-1]] if !ok { continue } var param string switch clean := v.(type) { case []interface{}: if len(clean) == 0 { request.Params[r.destination] = "" break } var b strings.Builder for i := 0; i < len(clean)-1; i++ { fmt.Fprintf(&b, "%v,", clean[i]) } fmt.Fprintf(&b, "%v", clean[len(clean)-1]) param = b.String() case string: param = clean case int: param = strconv.Itoa(clean) case float64: param = strconv.FormatFloat(clean, 'E', -1, 32) case bool: param = strconv.FormatBool(clean) default: param = fmt.Sprintf("%v", v) } request.Params[r.destination] = param sequentialMergeRegistry[r.destination] = param } } if (i < filterCount) && (filters[i] != nil) && !filters[i](request) { parts[i] = &Response{IsComplete: true, Data: make(map[string]interface{})} acc.pending-- continue } sequentialRequestPart(localCtx, n, reqCloner(request), out, errCh) select { case err := <-errCh: if i == 0 { cancel() return nil, err } acc.Merge(nil, err) break TxLoop case response := <-out: acc.Merge(response, nil) if !response.IsComplete { break TxLoop } parts[i] = response } } result, err := acc.Result() cancel() return result, err } } type incrementalMergeAccumulator struct { pending int data *Response combiner ResponseCombiner errs []error } func 
newIncrementalMergeAccumulator(total int, combiner ResponseCombiner) *incrementalMergeAccumulator { return &incrementalMergeAccumulator{ pending: total, combiner: combiner, errs: []error{}, } } func (i *incrementalMergeAccumulator) Merge(res *Response, err error) { i.pending-- if err != nil { i.errs = append(i.errs, err) if i.data != nil { i.data.IsComplete = false } return } if res == nil { i.errs = append(i.errs, errNullResult) return } if i.data == nil { i.data = res return } i.data = i.combiner(2, []*Response{i.data, res}) } func (i *incrementalMergeAccumulator) Result() (*Response, error) { if i.data == nil { return nil, newMergeError(i.errs) } if i.pending > 0 || len(i.errs) > 0 { i.data.IsComplete = false } return i.data, newMergeError(i.errs) } func requestPart(ctx context.Context, next Proxy, request *Request, out chan<- *Response, failed chan<- error) { localCtx, cancel := context.WithCancel(ctx) in, err := next(localCtx, request) if err != nil { failed <- err cancel() return } if in == nil { failed <- errNullResult cancel() return } select { case out <- in: case <-ctx.Done(): failed <- ctx.Err() } cancel() } func sequentialRequestPart(ctx context.Context, next Proxy, request *Request, out chan<- *Response, failed chan<- error) { copyRequest := CloneRequest(request) in, err := next(ctx, request) *request = *copyRequest if err != nil { failed <- err return } if in == nil { failed <- errNullResult return } select { case out <- in: case <-ctx.Done(): failed <- ctx.Err() } } func newMergeError(errs []error) error { if len(errs) == 0 { return nil } return mergeError{errs} } type mergeError struct { errs []error } func (m mergeError) Error() string { msg := make([]string, len(m.errs)) for i, err := range m.errs { msg[i] = err.Error() } return strings.Join(msg, "\n") } func (m mergeError) Errors() []error { return m.errs } // ResponseCombiner func to merge the collected responses into a single one type ResponseCombiner func(int, []*Response) *Response // 
RegisterResponseCombiner adds a new response combiner into the internal register func RegisterResponseCombiner(name string, f ResponseCombiner) { responseCombiners.SetResponseCombiner(name, f) } const ( mergeKey = "combiner" isSequentialKey = "sequential" sequentialPropagateKey = "sequential_propagated_params" defaultCombinerName = "default" ) var responseCombiners = initResponseCombiners() func initResponseCombiners() *combinerRegister { return newCombinerRegister(map[string]ResponseCombiner{defaultCombinerName: combineData}, combineData) } func getResponseCombinerName(extra config.ExtraConfig) string { if v, ok := extra[Namespace]; ok { if e, ok := v.(map[string]interface{}); ok { if v, ok := e[mergeKey]; ok { if _, ok := responseCombiners.GetResponseCombiner(v.(string)); ok { return v.(string) } } } } return defaultCombinerName } func getResponseCombiner(extra config.ExtraConfig) ResponseCombiner { combiner := getResponseCombinerName(extra) c, _ := responseCombiners.GetResponseCombiner(combiner) return c } func combineData(total int, parts []*Response) *Response { isComplete := len(parts) == total var retResponse *Response for _, part := range parts { if part == nil || part.Data == nil { isComplete = false continue } isComplete = isComplete && part.IsComplete if retResponse == nil { retResponse = &Response{Data: part.Data, IsComplete: isComplete} continue } for k, v := range part.Data { retResponse.Data[k] = v } } if nil == retResponse { // do not allow nil data in the response: return &Response{Data: make(map[string]interface{}), IsComplete: isComplete} } retResponse.IsComplete = isComplete return retResponse } ================================================ FILE: proxy/merging_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "fmt" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) func BenchmarkNewMergeDataMiddleware(b 
*testing.B) { backend := config.Backend{} backends := make([]*config.Backend, 10) for i := range backends { backends[i] = &backend } proxies := []Proxy{ dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), } for _, totalParts := range []int{2, 3, 4, 5, 6, 7, 8, 9, 10} { b.Run(fmt.Sprintf("with %d parts", totalParts), func(b *testing.B) { endpoint := config.EndpointConfig{ Backend: backends[:totalParts], Timeout: time.Duration(100) * time.Millisecond, } proxy := NewMergeDataMiddleware(logging.NoOp, &endpoint)(proxies[:totalParts]...) 
b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { proxy(context.Background(), &Request{Params: map[string]string{}}) } }) } } func BenchmarkNewMergeDataMiddleware_sequential(b *testing.B) { backends := make([]*config.Backend, 10) pattern := "/some" var keys []string for i := range backends { b := &config.Backend{ URLKeys: make([]string, 4*i), URLPattern: pattern, } copy(b.URLKeys, keys) backends[i] = b for _, t := range []string{"int", "float", "bool", "string"} { next := fmt.Sprintf("Resp%d_%s", i, t) pattern += "/{{." + next + "}}" keys = append(keys, next) } } proxies := []Proxy{ dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"int": 1, "float": 3.14, "bool": true, "string": "wwwww"}, IsComplete: true}), } for _, totalParts := range []int{2, 3, 4, 5, 6, 7, 8, 9, 10} { b.Run(fmt.Sprintf("with %d parts", 
totalParts), func(b *testing.B) { endpoint := config.EndpointConfig{ Backend: backends[:totalParts], Timeout: time.Duration(100) * time.Millisecond, ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ isSequentialKey: true, }, }, } proxy := NewMergeDataMiddleware(logging.NoOp, &endpoint)(proxies[:totalParts]...) b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { proxy(context.Background(), &Request{Params: map[string]string{}}) } }) } } ================================================ FILE: proxy/merging_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "errors" "io" "strings" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) func TestNewMergeDataMiddleware(t *testing.T) { tests := []func(t *testing.T){ testNewMergeDataMiddleware_simpleFiltering, testNewMergeDataMiddleware_sequentialFiltering, testNewMergeDataMiddleware_empty, testNewMergeDataMiddleware_ok, testNewMergeDataMiddleware_sequential, testNewMergeDataMiddleware_sequential_unavailableParams, testNewMergeDataMiddleware_sequential_erroredBackend, testNewMergeDataMiddleware_sequential_erroredFirstBackend, testNewMergeDataMiddleware_mergeIncompleteResults, testNewMergeDataMiddleware_mergeEmptyResults, testNewMergeDataMiddleware_partialTimeout, testNewMergeDataMiddleware_partial, testNewMergeDataMiddleware_nullResponse, testNewMergeDataMiddleware_timeout, testRegisterResponseCombiner, } for _, test := range tests { ResetBackendFiltererFactory() test(t) } } func testNewMergeDataMiddleware_empty(t *testing.T) { timeout := 500 * time.Millisecond backend := config.Backend{} endpoint := config.EndpointConfig{ Backend: []*config.Backend{&backend, &backend}, Timeout: timeout, } expectedErr := errors.New("wait for me") erroredProxy := func(_ context.Context, _ *Request) (*Response, error) { return nil, expectedErr } mw := NewMergeDataMiddleware(logging.NoOp, 
&endpoint) p := mw(erroredProxy, erroredProxy) mustEnd := time.After(2 * timeout) out, err := p(context.Background(), &Request{}) mErr, ok := err.(mergeError) if !ok { t.Errorf("The middleware propagated an unexpected error: %s\n", err) return } if len(mErr.errs) != 2 { t.Errorf("The middleware propagated an unexpected error: %s\n", err) return } if mErr.errs[0] != mErr.errs[1] || mErr.errs[0] != expectedErr { t.Errorf("The middleware propagated an unexpected error: %s\n", err) return } if out != nil { t.Errorf("The proxy returned a result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: } } func testNewMergeDataMiddleware_ok(t *testing.T) { timeout := 500 backend := config.Backend{} endpoint := config.EndpointConfig{ Backend: []*config.Backend{&backend, &backend}, Timeout: time.Duration(timeout) * time.Millisecond, } mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( dummyProxy(&Response{Data: map[string]interface{}{"supu": 42}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"tupu": true}, IsComplete: true})) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{}) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: if len(out.Data) != 2 { t.Errorf("We weren't expecting a partial response but we got %v!\n", out) } if !out.IsComplete { t.Errorf("We were expecting a completed response but we got an incompleted one!\n") } } } func testNewMergeDataMiddleware_sequential(t *testing.T) { timeout := 1000 endpoint := config.EndpointConfig{ Backend: []*config.Backend{ {URLPattern: "/"}, {URLPattern: "/aaa/{{.Resp0_array}}"}, {URLPattern: 
"/aaa/{{.Resp0_int}}/{{.Resp0_string}}/{{.Resp0_bool}}/{{.Resp0_float}}/{{.Resp0_struct.foo}}"}, {URLPattern: "/aaa/{{.Resp0_int}}/{{.Resp0_string}}/{{.Resp0_bool}}/{{.Resp0_float}}/{{.Resp0_struct.foo}}?x={{.Resp1_tupu}}"}, {URLPattern: "/aaa/{{.Resp0_struct.foo}}/{{.Resp0_struct.struct.foo}}/{{.Resp0_struct.struct.struct.foo}}"}, {URLPattern: "/zzz", Encoding: "no-op"}, {URLPattern: "/hit-me"}, }, Timeout: time.Duration(timeout) * time.Millisecond, ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ isSequentialKey: true, sequentialPropagateKey: []interface{}{"resp0_propagated", "resp5"}, }, }, } expectedBody := "foo" checkBody := func(t *testing.T, r *Request) { if r.Body == nil { t.Error("empty body") return } b, err := io.ReadAll(r.Body) if err != nil { t.Error(err) return } r.Body.Close() if string(b) != expectedBody { t.Errorf("unexpected body '%s'", string(b)) } } mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( dummyProxy(&Response{Data: map[string]interface{}{ "int": 42, "string": "some", "bool": true, "float": 3.14, "struct": map[string]interface{}{ "foo": "bar", "struct": map[string]interface{}{ "foo": "bar", "struct": map[string]interface{}{ "foo": "bar", }, }, }, "array": []interface{}{"1", "2"}, "propagated": "everywhere", }, IsComplete: true}), func(_ context.Context, r *Request) (*Response, error) { checkBody(t, r) checkRequestParam(t, r, "Resp0_array", "1,2") checkRequestParam(t, r, "Resp0_propagated", "everywhere") return &Response{Data: map[string]interface{}{"tupu": "foo"}, IsComplete: true}, nil }, func(_ context.Context, r *Request) (*Response, error) { checkBody(t, r) checkRequestParam(t, r, "Resp0_int", "42") checkRequestParam(t, r, "Resp0_string", "some") checkRequestParam(t, r, "Resp0_float", "3.14E+00") checkRequestParam(t, r, "Resp0_bool", "true") checkRequestParam(t, r, "Resp0_struct.foo", "bar") checkRequestParam(t, r, "Resp0_propagated", "everywhere") return &Response{Data: map[string]interface{}{"tupu": 
"foo"}, IsComplete: true}, nil }, func(_ context.Context, r *Request) (*Response, error) { checkBody(t, r) checkRequestParam(t, r, "Resp0_int", "42") checkRequestParam(t, r, "Resp0_string", "some") checkRequestParam(t, r, "Resp0_float", "3.14E+00") checkRequestParam(t, r, "Resp0_bool", "true") checkRequestParam(t, r, "Resp0_struct.foo", "bar") checkRequestParam(t, r, "Resp1_tupu", "foo") checkRequestParam(t, r, "Resp0_propagated", "everywhere") return &Response{Data: map[string]interface{}{"aaaa": []int{1, 2, 3}}, IsComplete: true}, nil }, func(_ context.Context, r *Request) (*Response, error) { checkBody(t, r) checkRequestParam(t, r, "Resp0_struct.foo", "bar") checkRequestParam(t, r, "Resp0_struct.struct.foo", "bar") checkRequestParam(t, r, "Resp0_struct.struct.struct.foo", "bar") checkRequestParam(t, r, "Resp0_propagated", "everywhere") return &Response{Data: map[string]interface{}{"bbbb": []bool{true, false}}, IsComplete: true}, nil }, func(_ context.Context, r *Request) (*Response, error) { checkBody(t, r) checkRequestParam(t, r, "Resp0_propagated", "everywhere") return &Response{Data: map[string]interface{}{}, Io: io.NopCloser(strings.NewReader("hello")), IsComplete: true}, nil }, func(_ context.Context, r *Request) (*Response, error) { checkBody(t, r) checkRequestParam(t, r, "Resp0_propagated", "everywhere") checkRequestParam(t, r, "Resp5", "hello") return &Response{Data: map[string]interface{}{}, IsComplete: true}, nil }, ) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{ Params: map[string]string{}, Body: io.NopCloser(strings.NewReader(expectedBody)), }) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: if len(out.Data) != 10 { t.Errorf("We weren't expecting a partial response but we got 
%v!\n", out) } if !out.IsComplete { t.Errorf("We were expecting a completed response but we got an incompleted one!\n") } } } func checkRequestParam(t *testing.T, r *Request, k, v string) { if r.Params[k] != v { t.Errorf("request without the expected set of params: %s - %+v", k, r.Params) } } func testNewMergeDataMiddleware_sequential_unavailableParams(t *testing.T) { timeout := 500 endpoint := config.EndpointConfig{ Backend: []*config.Backend{ {URLPattern: "/"}, {URLPattern: "/aaa/{{.Resp2_supu}"}, {URLPattern: "/aaa/{{.Resp0_tupu}}?x={{.Resp1_tupu}}"}, }, Timeout: time.Duration(timeout) * time.Millisecond, ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ isSequentialKey: true, }, }, } mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( dummyProxy(&Response{Data: map[string]interface{}{"supu": 42}, IsComplete: true}), func(_ context.Context, r *Request) (*Response, error) { if v, ok := r.Params["Resp0_supu"]; ok || v != "" { t.Errorf("request with unexpected set of params") } return &Response{Data: map[string]interface{}{"tupu": "foo"}, IsComplete: true}, nil }, func(_ context.Context, r *Request) (*Response, error) { if v, ok := r.Params["Resp0_supu"]; ok || v != "" { t.Errorf("request with unexpected set of params") } if r.Params["Respo_tupu"] != "" { t.Errorf("request without the expected set of params") } if r.Params["Resp1_tupu"] != "foo" { t.Errorf("request without the expected set of params") } return &Response{Data: map[string]interface{}{"aaaa": []int{1, 2, 3}}, IsComplete: true}, nil }, ) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{Params: map[string]string{}}) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: if len(out.Data) != 3 { t.Errorf("We 
weren't expecting a partial response but we got %v!\n", out) } if !out.IsComplete { t.Errorf("We were expecting a completed response but we got an incompleted one!\n") } } } func testNewMergeDataMiddleware_sequential_erroredBackend(t *testing.T) { timeout := 500 endpoint := config.EndpointConfig{ Backend: []*config.Backend{ {URLPattern: "/"}, {URLPattern: "/aaa/{{.Resp0_supu}}"}, {URLPattern: "/aaa/{{.Resp0_supu}}?x={{.Resp1_tupu}}"}, }, Timeout: time.Duration(timeout) * time.Millisecond, ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ isSequentialKey: true, }, }, } expecterErr := errors.New("wait for me") mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( dummyProxy(&Response{Data: map[string]interface{}{"supu": 42}, IsComplete: true}), func(_ context.Context, r *Request) (*Response, error) { if r.Params["Resp0_supu"] != "42" { t.Errorf("request without the expected set of params") } return nil, expecterErr }, func(_ context.Context, _ *Request) (*Response, error) { return nil, nil }, ) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{Params: map[string]string{}}) if err == nil { t.Errorf("The middleware did not propagate an error\n") return } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: if len(out.Data) != 1 { t.Errorf("We weren't expecting a partial response but we got %v!\n", out) } if out.IsComplete { t.Errorf("We were not expecting a completed response!\n") } } } func testNewMergeDataMiddleware_sequential_erroredFirstBackend(t *testing.T) { timeout := 500 endpoint := config.EndpointConfig{ Backend: []*config.Backend{ {URLPattern: "/"}, {URLPattern: "/aaa/{{.Resp0_supu}}"}, {URLPattern: "/aaa/{{.Resp0_supu}}?x={{.Resp1_tupu}}"}, }, Timeout: time.Duration(timeout) * time.Millisecond, ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ 
isSequentialKey: true, }, }, } expecterErr := errors.New("wait for me") mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( func(_ context.Context, _ *Request) (*Response, error) { return nil, expecterErr }, func(_ context.Context, _ *Request) (*Response, error) { t.Error("this backend should never be called") return nil, nil }, func(_ context.Context, _ *Request) (*Response, error) { t.Error("this backend should never be called") return nil, nil }, ) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{Params: map[string]string{}}) if err != expecterErr { t.Errorf("The middleware did not propagate the expected error: %v\n", err) return } if out != nil { t.Errorf("The proxy returned a not null result %v", out) return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: } } func testNewMergeDataMiddleware_mergeIncompleteResults(t *testing.T) { timeout := 500 backend := config.Backend{} endpoint := config.EndpointConfig{ Backend: []*config.Backend{&backend, &backend}, Timeout: time.Duration(timeout) * time.Millisecond, } mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( dummyProxy(&Response{Data: map[string]interface{}{"supu": 42}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"tupu": true}, IsComplete: false})) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{}) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: if len(out.Data) != 2 { t.Errorf("We were expecting incomplete results merged but we got %v!\n", out.Data) } if out.IsComplete { t.Errorf("We were expecting an incomplete response but we got a completed one!\n") } } } func 
testNewMergeDataMiddleware_mergeEmptyResults(t *testing.T) { timeout := 500 backend := config.Backend{} endpoint := config.EndpointConfig{ Backend: []*config.Backend{&backend, &backend}, Timeout: time.Duration(timeout) * time.Millisecond, } mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( dummyProxy(&Response{Data: nil, IsComplete: false}), dummyProxy(&Response{Data: nil, IsComplete: false})) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{}) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: if len(out.Data) != 0 { t.Errorf("We were expecting empty data but we got %v!\n", out) } if out.IsComplete { t.Errorf("We were expecting an incomplete response but we got an incompleted one!\n") } } } func testNewMergeDataMiddleware_partialTimeout(t *testing.T) { timeout := 100 backend := config.Backend{Timeout: time.Duration(timeout) * time.Millisecond} endpoint := config.EndpointConfig{ Backend: []*config.Backend{&backend, &backend}, Timeout: time.Duration(timeout) * time.Millisecond, } mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( delayedProxy(t, time.Duration(timeout/2)*time.Millisecond, &Response{Data: map[string]interface{}{"supu": 42}, IsComplete: true}), delayedProxy(t, time.Duration(5*timeout)*time.Millisecond, nil)) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{}) if err == nil || err.Error() != "context deadline exceeded" { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: if len(out.Data) != 1 { t.Errorf("We 
were expecting a partial response but we got %v!\n", out) } if out.IsComplete { t.Errorf("We were expecting an incompleted response but we got a completed one!\n") } } } func testNewMergeDataMiddleware_partial(t *testing.T) { timeout := 100 backend := config.Backend{Timeout: time.Duration(timeout) * time.Millisecond} endpoint := config.EndpointConfig{ Backend: []*config.Backend{&backend, &backend}, Timeout: time.Duration(timeout) * time.Millisecond, } mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( dummyProxy(&Response{Data: map[string]interface{}{"supu": 42}, IsComplete: true}), dummyProxy(&Response{})) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{}) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: if len(out.Data) != 1 { t.Errorf("We were expecting a partial response but we got %v!\n", out) } if out.IsComplete { t.Errorf("We were expecting an incompleted response but we got a completed one!\n") } } } func testNewMergeDataMiddleware_nullResponse(t *testing.T) { timeout := 100 backend := config.Backend{Timeout: time.Duration(timeout) * time.Millisecond} endpoint := config.EndpointConfig{ Backend: []*config.Backend{&backend, &backend}, } mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := mw(NoopProxy, NoopProxy)(context.Background(), &Request{}) if err == nil { t.Errorf("The middleware did not propagate the expected error") } switch mergeErr := err.(type) { case mergeError: if len(mergeErr.errs) != 2 { t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } if mergeErr.errs[0] != mergeErr.errs[1] { t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } if 
mergeErr.errs[0] != errNullResult { t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } default: t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } if out != nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: } } func testNewMergeDataMiddleware_timeout(t *testing.T) { timeout := 100 backend := config.Backend{Timeout: time.Duration(timeout) * time.Millisecond} endpoint := config.EndpointConfig{ Backend: []*config.Backend{&backend, &backend}, Timeout: time.Duration(timeout) * time.Millisecond, } mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( delayedProxy(t, time.Duration(5*timeout)*time.Millisecond, nil), delayedProxy(t, time.Duration(5*timeout)*time.Millisecond, nil)) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{}) if err == nil { t.Errorf("The middleware did not propagate the expected error") } switch mergeErr := err.(type) { case mergeError: if len(mergeErr.errs) != 2 { t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } if mergeErr.errs[0].Error() != mergeErr.errs[1].Error() { t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } if mergeErr.errs[0].Error() != "context deadline exceeded" { t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } default: t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } if out != nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: } } func testRegisterResponseCombiner(t *testing.T) { subject := "test combiner" if len(responseCombiners.data.Clone()) != 1 { t.Error("unexpected initial size of the response combiner list:", responseCombiners.data.Clone()) } RegisterResponseCombiner(subject, 
getResponseCombiner(config.ExtraConfig{})) defer func() { responseCombiners = initResponseCombiners() }() if len(responseCombiners.data.Clone()) != 2 { t.Error("unexpected size of the response combiner list:", responseCombiners.data.Clone()) } timeout := 500 backend := config.Backend{} endpoint := config.EndpointConfig{ Backend: []*config.Backend{&backend, &backend}, Timeout: time.Duration(timeout) * time.Millisecond, ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ mergeKey: defaultCombinerName, }, }, } mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( dummyProxy(&Response{Data: map[string]interface{}{"supu": 42}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"tupu": true}, IsComplete: true})) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{}) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: if len(out.Data) != 2 { t.Errorf("We weren't expecting a partial response but we got %v!\n", out) } if !out.IsComplete { t.Errorf("We were expecting a completed response but we got an incompleted one!\n") } } } func Test_incrementalMergeAccumulator_invalidResponse(t *testing.T) { acc := newIncrementalMergeAccumulator(3, combineData) acc.Merge(nil, nil) acc.Merge(nil, nil) acc.Merge(nil, nil) res, err := acc.Result() if res != nil { t.Error("response should be nil") return } if err == nil { t.Error("expecting error") return } switch mergeErr := err.(type) { case mergeError: if len(mergeErr.errs) != 3 { t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } if mergeErr.errs[0] != mergeErr.errs[1] { t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } if mergeErr.errs[0] != mergeErr.errs[2] { t.Errorf("The 
middleware propagated an unexpected error: %s", err.Error()) } if mergeErr.errs[0] != errNullResult { t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } default: t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } } func Test_incrementalMergeAccumulator_incompleteResponse(t *testing.T) { acc := newIncrementalMergeAccumulator(3, combineData) acc.Merge(&Response{Data: make(map[string]interface{}), IsComplete: true}, nil) acc.Merge(&Response{Data: make(map[string]interface{}), IsComplete: false}, nil) acc.Merge(&Response{Data: make(map[string]interface{}), IsComplete: true}, nil) res, err := acc.Result() if res == nil { t.Error("response should not be nil") return } if err != nil { t.Errorf("unexpected error: %s", err.Error()) return } if res.IsComplete { t.Error("response should not be completed") } } func testNewMergeDataMiddleware_simpleFiltering(t *testing.T) { timeout := 500 backend := config.Backend{} endpoint := config.EndpointConfig{ Backend: []*config.Backend{&backend, &backend}, Timeout: time.Duration(timeout) * time.Millisecond, } RegisterBackendFiltererFactory("", func(_ *config.EndpointConfig) ([]BackendFilterer, error) { return []BackendFilterer{ func(_ *Request) bool { return true }, func(r *Request) bool { return r.Headers["X-Filter"][0] == "supu" }, }, nil }) mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( dummyProxy(&Response{Data: map[string]interface{}{"supu": 42}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"tupu": true}, IsComplete: true})) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{Headers: map[string][]string{"X-Filter": {"meh"}}}) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: 
if len(out.Data) != 1 || out.Data["supu"] != 42 { t.Errorf("We were expecting a response from just a backend, but we got %v!\n", out) } if !out.IsComplete { t.Errorf("We were expecting a completed response but we got an incompleted one!\n") } } } func testNewMergeDataMiddleware_sequentialFiltering(t *testing.T) { timeout := 1000 endpoint := config.EndpointConfig{ Backend: []*config.Backend{ {URLPattern: "/"}, {URLPattern: "/aaa/{{.Resp0_string}}"}, {URLPattern: "/hit-me/{{.Resp1_tupu}}"}, }, Timeout: time.Duration(timeout) * time.Millisecond, ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ isSequentialKey: true, }, }, } RegisterBackendFiltererFactory("", func(_ *config.EndpointConfig) ([]BackendFilterer, error) { return []BackendFilterer{ func(_ *Request) bool { return true }, func(_ *Request) bool { return false }, func(_ *Request) bool { return true }, }, nil }) mw := NewMergeDataMiddleware(logging.NoOp, &endpoint) p := mw( dummyProxy(&Response{Data: map[string]interface{}{"string": "some"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"tupu": "foo"}, IsComplete: true}), dummyProxy(&Response{Data: map[string]interface{}{"final": "meh"}, IsComplete: true}), ) mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond) out, err := p(context.Background(), &Request{ Params: map[string]string{}, }) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: if len(out.Data) != 2 { t.Errorf("We were expecting a response from just two backends, but we got %v!\n", out) } if !out.IsComplete { t.Errorf("We were expecting a completed response but we got an incompleted one!\n") } } } ================================================ FILE: proxy/plugin/modifier.go ================================================ // 
SPDX-License-Identifier: Apache-2.0 /* Package plugin provides tools for loading and registering proxy plugins */ package plugin import ( "context" "fmt" "plugin" "strings" "github.com/luraproject/lura/v2/logging" luraplugin "github.com/luraproject/lura/v2/plugin" "github.com/luraproject/lura/v2/register" ) const ( // Namespace is the namespace for the extra_config section Namespace = "github.com/devopsfaith/krakend/proxy/plugin" // requestNamespace is the internal namespace for the register to be used with request modifiers requestNamespace = "github.com/devopsfaith/krakend/proxy/plugin/request" // responseNamespace is the internal namespace for the register to be used with response modifiers responseNamespace = "github.com/devopsfaith/krakend/proxy/plugin/response" ) var modifierRegister = register.New() // ModifierFactory is a function that, given a config passed as a map, returns a modifier type ModifierFactory func(map[string]interface{}) func(interface{}) (interface{}, error) // GetRequestModifier returns a ModifierFactory from the request namespace by name func GetRequestModifier(name string) (ModifierFactory, bool) { return getModifier(requestNamespace, name) } // GetResponseModifier returns a ModifierFactory from the response namespace by name func GetResponseModifier(name string) (ModifierFactory, bool) { return getModifier(responseNamespace, name) } func getModifier(namespace, name string) (ModifierFactory, bool) { r, ok := modifierRegister.Get(namespace) if !ok { return nil, ok } m, ok := r.Get(name) if !ok { return nil, ok } res, ok := m.(func(map[string]interface{}) func(interface{}) (interface{}, error)) if !ok { return nil, ok } return ModifierFactory(res), ok } // RegisterModifier registers the injected modifier factory with the given name at the selected namespace func RegisterModifier( name string, modifierFactory func(map[string]interface{}) func(interface{}) (interface{}, error), appliesToRequest bool, appliesToResponse bool, ) { if 
appliesToRequest { modifierRegister.Register(requestNamespace, name, modifierFactory) } if appliesToResponse { modifierRegister.Register(responseNamespace, name, modifierFactory) } } // Registerer defines the interface for the plugins to expose in order to be able to be loaded/registered type Registerer interface { RegisterModifiers(func( name string, modifierFactory func(map[string]interface{}) func(interface{}) (interface{}, error), appliesToRequest bool, appliesToResponse bool, )) } type LoggerRegisterer interface { RegisterLogger(interface{}) } type ContextRegisterer interface { RegisterContext(context.Context) } // RegisterModifierFunc type is the function passed to the loaded Registerers type RegisterModifierFunc func( name string, modifierFactory func(map[string]interface{}) func(interface{}) (interface{}, error), appliesToRequest bool, appliesToResponse bool, ) // Load scans the given path using the pattern and registers all the found modifier plugins into the rmf func Load(path, pattern string, rmf RegisterModifierFunc) (int, error) { return LoadWithLogger(path, pattern, rmf, nil) } // LoadWithLogger scans the given path using the pattern and registers all the found modifier plugins into the rmf func LoadWithLogger(path, pattern string, rmf RegisterModifierFunc, logger logging.Logger) (int, error) { return LoadWithLoggerAndContext(context.Background(), path, pattern, rmf, logger) } func LoadWithLoggerAndContext(ctx context.Context, path, pattern string, rmf RegisterModifierFunc, logger logging.Logger) (int, error) { plugins, err := luraplugin.Scan(path, pattern) if err != nil { return 0, err } return load(ctx, plugins, rmf, logger) } func load(ctx context.Context, plugins []string, rmf RegisterModifierFunc, logger logging.Logger) (int, error) { var errors []error loadedPlugins := 0 for k, pluginName := range plugins { if err := open(ctx, pluginName, rmf, logger); err != nil { errors = append(errors, fmt.Errorf("plugin #%d (%s): %s", k, pluginName, 
err.Error())) continue } loadedPlugins++ } if len(errors) > 0 { return loadedPlugins, loaderError{errors: errors} } return loadedPlugins, nil } func open(ctx context.Context, pluginName string, rmf RegisterModifierFunc, logger logging.Logger) (err error) { defer func() { if r := recover(); r != nil { var ok bool err, ok = r.(error) if !ok { err = fmt.Errorf("%v", r) } } }() var p Plugin p, err = pluginOpener(pluginName) if err != nil { return } var r interface{} r, err = p.Lookup("ModifierRegisterer") if err != nil { return } registerer, ok := r.(Registerer) if !ok { return fmt.Errorf("modifier plugin loader: unknown type") } if logger != nil { if lr, ok := r.(LoggerRegisterer); ok { lr.RegisterLogger(logger) } } if lr, ok := r.(ContextRegisterer); ok { lr.RegisterContext(ctx) } RegisterExtraComponents(r) registerer.RegisterModifiers(rmf) return } var RegisterExtraComponents = func(interface{}) {} // Plugin is the interface of the loaded plugins type Plugin interface { Lookup(name string) (plugin.Symbol, error) } // pluginOpener keeps the plugin open function in a var for easy testing var pluginOpener = defaultPluginOpener func defaultPluginOpener(name string) (Plugin, error) { return plugin.Open(name) } type loaderError struct { errors []error } // Error implements the error interface func (l loaderError) Error() string { msgs := make([]string, len(l.errors)) for i, err := range l.errors { msgs[i] = err.Error() } return fmt.Sprintf("plugin loader found %d error(s): \n%s", len(msgs), strings.Join(msgs, "\n")) } func (l loaderError) Len() int { return len(l.errors) } func (l loaderError) Errs() []error { return l.errors } ================================================ FILE: proxy/plugin/modifier_test.go ================================================ //go:build integration || !race // +build integration !race // SPDX-License-Identifier: Apache-2.0 package plugin import ( "bytes" "context" "fmt" "io" "net/url" "strings" "testing" 
"github.com/luraproject/lura/v2/logging" ) func ExampleLoadWithLoggerAndContext() { var data []byte buf := bytes.NewBuffer(data) logger, err := logging.NewLogger("DEBUG", buf, "") if err != nil { fmt.Println(err.Error()) return } total, err := LoadWithLoggerAndContext(context.Background(), "./tests", ".so", RegisterModifier, logger) if err != nil { fmt.Println(err.Error()) return } if total != 2 { fmt.Printf("unexpected number of loaded plugins!. have %d, want 2\n", total) return } modFactory, ok := GetRequestModifier("lura-request-modifier-example-request") if !ok { fmt.Println("modifier factory not found in the register") return } modifier := modFactory(map[string]interface{}{}) input := requestWrapper{ ctx: context.WithValue(context.Background(), "myCtxKey", "some"), path: "/bar", method: "GET", headers: map[string][]string{"X-Foo": {"bar"}}, } tmp, err := modifier(input) if err != nil { fmt.Println(err.Error()) return } output, ok := tmp.(RequestWrapper) if !ok { fmt.Println("unexpected result type") return } if res := output.Path(); res != "/bar/fooo" { fmt.Printf("unexpected result path. 
have %s, want /bar/fooo\n", res) return } modFactory, ok = GetResponseModifier("lura-request-modifier-example-response") if !ok { fmt.Println("modifier factory not found in the register") return } modifier = modFactory(map[string]interface{}{}) response := responseWrapper{ ctx: context.WithValue(context.Background(), "myCtxKey", "other"), request: input, data: map[string]interface{}{"foo": "bar"}, } if _, err = modifier(response); err != nil { fmt.Println(err.Error()) return } lines := strings.Split(buf.String(), "\n") for i := range lines[:len(lines)-1] { fmt.Println(lines[i][21:]) } // output: // DEBUG: [PLUGIN: lura-error-example] Logger loaded // DEBUG: [PLUGIN: lura-request-modifier-example] Logger loaded // DEBUG: [PLUGIN: lura-request-modifier-example] Context loaded // DEBUG: [PLUGIN: lura-request-modifier-example] Request modifier injected // DEBUG: context key: some // DEBUG: params: map[] // DEBUG: headers: map[X-Foo:[bar]] // DEBUG: method: GET // DEBUG: url: // DEBUG: query: map[] // DEBUG: path: /bar/fooo // DEBUG: [PLUGIN: lura-request-modifier-example] Response modifier injected // DEBUG: Header X-Foo value: bar // DEBUG: context key: other // DEBUG: data: map[foo:bar] // DEBUG: is complete: false // DEBUG: headers: map[] // DEBUG: status code: 0 // DEBUG: original headers: map[X-Foo:[bar]] } func TestLoad(t *testing.T) { total, err := LoadWithLogger("./tests", ".so", RegisterModifier, logging.NoOp) if err != nil { t.Error(err.Error()) t.Fail() } if total != 2 { t.Errorf("unexpected number of loaded plugins!. 
have %d, want 2", total) } modFactory, ok := GetRequestModifier("lura-request-modifier-example-request") if !ok { t.Error("modifier factory not found in the register") return } modifier := modFactory(map[string]interface{}{}) input := requestWrapper{ctx: context.WithValue(context.Background(), "myCtxKey", "some"), path: "/bar"} tmp, err := modifier(input) if err != nil { t.Error(err.Error()) return } output, ok := tmp.(RequestWrapper) if !ok { t.Error("unexpected result type") return } if res := output.Path(); res != "/bar/fooo" { t.Errorf("unexpected result path. have %s, want /bar/fooo", res) } } type RequestWrapper interface { Params() map[string]string Headers() map[string][]string Body() io.ReadCloser Method() string URL() *url.URL Query() url.Values Path() string } type requestWrapper struct { ctx context.Context method string url *url.URL query url.Values path string body io.ReadCloser params map[string]string headers map[string][]string } func (r requestWrapper) Context() context.Context { return r.ctx } func (r requestWrapper) Method() string { return r.method } func (r requestWrapper) URL() *url.URL { return r.url } func (r requestWrapper) Query() url.Values { return r.query } func (r requestWrapper) Path() string { return r.path } func (r requestWrapper) Body() io.ReadCloser { return r.body } func (r requestWrapper) Params() map[string]string { return r.params } func (r requestWrapper) Headers() map[string][]string { return r.headers } type metadataWrapper struct { headers map[string][]string statusCode int } func (m metadataWrapper) Headers() map[string][]string { return m.headers } func (m metadataWrapper) StatusCode() int { return m.statusCode } type responseWrapper struct { ctx context.Context request interface{} data map[string]interface{} isComplete bool metadata metadataWrapper io io.Reader } func (r responseWrapper) Context() context.Context { return r.ctx } func (r responseWrapper) Request() interface{} { return r.request } func (r 
responseWrapper) Data() map[string]interface{} { return r.data } func (r responseWrapper) IsComplete() bool { return r.isComplete } func (r responseWrapper) Io() io.Reader { return r.io } func (r responseWrapper) Headers() map[string][]string { return r.metadata.headers } func (r responseWrapper) StatusCode() int { return r.metadata.statusCode } ================================================ FILE: proxy/plugin/tests/error/main.go ================================================ // SPDX-License-Identifier: Apache-2.0 package main import ( "errors" "fmt" "net/http" ) func main() {} var ModifierRegisterer = registerer("lura-error-example") var logger Logger = nil type registerer string func (r registerer) RegisterModifiers(f func( name string, modifierFactory func(map[string]interface{}) func(interface{}) (interface{}, error), appliesToRequest bool, appliesToResponse bool, )) { f(string(r)+"-request", r.requestModifierFactory, true, false) f(string(r)+"-response", r.reqsponseModifierFactory, false, true) } func (registerer) RegisterLogger(in interface{}) { l, ok := in.(Logger) if !ok { return } logger = l logger.Debug(fmt.Sprintf("[PLUGIN: %s] Logger loaded", ModifierRegisterer)) } func (registerer) requestModifierFactory(_ map[string]interface{}) func(interface{}) (interface{}, error) { logger.Debug(fmt.Sprintf("[PLUGIN: %s] Request modifier injected", ModifierRegisterer)) return func(_ interface{}) (interface{}, error) { logger.Debug(fmt.Sprintf("[PLUGIN: %s] Rejecting request", ModifierRegisterer)) return nil, requestErr } } func (registerer) reqsponseModifierFactory(_ map[string]interface{}) func(interface{}) (interface{}, error) { logger.Debug(fmt.Sprintf("[PLUGIN: %s] Response modifier injected", ModifierRegisterer)) return func(_ interface{}) (interface{}, error) { logger.Debug(fmt.Sprintf("[PLUGIN: %s] Replacing response", ModifierRegisterer)) return nil, responseErr } } type customError struct { error statusCode int } func (r customError) StatusCode() int { 
return r.statusCode } var ( requestErr = customError{ error: errors.New("request rejected just because"), statusCode: http.StatusTeapot, } responseErr = customError{ error: errors.New("response replaced because reasons"), statusCode: http.StatusTeapot, } ) type Logger interface { Debug(v ...interface{}) Info(v ...interface{}) Warning(v ...interface{}) Error(v ...interface{}) Critical(v ...interface{}) Fatal(v ...interface{}) } ================================================ FILE: proxy/plugin/tests/logger/main.go ================================================ // SPDX-License-Identifier: Apache-2.0 package main import ( "context" "errors" "fmt" "io" "net/url" "path" ) func main() {} var ModifierRegisterer = registerer("lura-request-modifier-example") var logger Logger = nil var ctx context.Context = context.Background() type registerer string func (r registerer) RegisterModifiers(f func( name string, modifierFactory func(map[string]interface{}) func(interface{}) (interface{}, error), appliesToRequest bool, appliesToResponse bool, )) { f(string(r)+"-request", r.requestModifierFactory, true, false) f(string(r)+"-response", r.reqsponseModifierFactory, false, true) } func (registerer) RegisterLogger(in interface{}) { l, ok := in.(Logger) if !ok { return } logger = l logger.Debug(fmt.Sprintf("[PLUGIN: %s] Logger loaded", ModifierRegisterer)) } func (registerer) RegisterContext(c context.Context) { ctx = c logger.Debug(fmt.Sprintf("[PLUGIN: %s] Context loaded", ModifierRegisterer)) } func (registerer) requestModifierFactory(_ map[string]interface{}) func(interface{}) (interface{}, error) { // check the config // return the modifier // Graceful shutdown of any service or connection managed by the plugin go func() { <-ctx.Done() logger.Debug("Shuting down the service") }() if logger == nil { fmt.Println("request modifier loaded without logger") return func(input interface{}) (interface{}, error) { req, ok := input.(RequestWrapper) if !ok { return nil, unkownTypeErr } 
return modifier(req), nil } } logger.Debug(fmt.Sprintf("[PLUGIN: %s] Request modifier injected", ModifierRegisterer)) return func(input interface{}) (interface{}, error) { req, ok := input.(RequestWrapper) if !ok { return nil, unkownTypeErr } r := modifier(req) requestCtx := req.Context() logger.Debug("context key:", requestCtx.Value("myCtxKey")) logger.Debug("params:", r.params) logger.Debug("headers:", r.headers) logger.Debug("method:", r.method) logger.Debug("url:", r.url) logger.Debug("query:", r.query) logger.Debug("path:", r.path) return r, nil } } func (registerer) reqsponseModifierFactory(_ map[string]interface{}) func(interface{}) (interface{}, error) { // check the cfg. If the modifier requires some configuration, // it should be under the name of the plugin. // ex: if this modifier required some A and B config params /* "extra_config":{ "plugin/req-resp-modifier":{ "name":["krakend-debugger"], "krakend-debugger":{ "A":"foo", "B":42 } } } */ go func() { <-ctx.Done() logger.Debug("Shuting down the service") }() // return the modifier if logger == nil { fmt.Println("response modifier loaded without logger") return func(input interface{}) (interface{}, error) { resp, ok := input.(ResponseWrapper) if !ok { return nil, unkownTypeErr } fmt.Println("data:", resp.Data()) fmt.Println("is complete:", resp.IsComplete()) fmt.Println("headers:", resp.Headers()) fmt.Println("status code:", resp.StatusCode()) return input, nil } } logger.Debug(fmt.Sprintf("[PLUGIN: %s] Response modifier injected", ModifierRegisterer)) return func(input interface{}) (interface{}, error) { resp, ok := input.(ResponseWrapper) if !ok { return nil, unkownTypeErr } if req, ok := resp.Request().(RequestWrapper); ok { for k, v := range req.Headers() { logger.Debug(fmt.Sprintf("Header %s value: %s", k, v[0])) } } respCtx := resp.Context() logger.Debug("context key:", respCtx.Value("myCtxKey")) logger.Debug("data:", resp.Data()) logger.Debug("is complete:", resp.IsComplete()) 
logger.Debug("headers:", resp.Headers()) logger.Debug("status code:", resp.StatusCode()) req, ok := resp.Request().(RequestWrapper) if ok { logger.Debug("original headers:", req.Headers()) } return resp, nil } } func modifier(req RequestWrapper) requestWrapper { return requestWrapper{ params: req.Params(), headers: req.Headers(), body: req.Body(), method: req.Method(), url: req.URL(), query: req.Query(), path: path.Join(req.Path(), "/fooo"), } } var unkownTypeErr = errors.New("unknown request type") type ResponseWrapper interface { Context() context.Context Request() interface{} Data() map[string]interface{} IsComplete() bool Headers() map[string][]string StatusCode() int } type RequestWrapper interface { Context() context.Context Params() map[string]string Headers() map[string][]string Body() io.ReadCloser Method() string URL() *url.URL Query() url.Values Path() string } type requestWrapper struct { ctx context.Context method string url *url.URL query url.Values path string body io.ReadCloser params map[string]string headers map[string][]string } func (r requestWrapper) Context() context.Context { return r.ctx } func (r requestWrapper) Method() string { return r.method } func (r requestWrapper) URL() *url.URL { return r.url } func (r requestWrapper) Query() url.Values { return r.query } func (r requestWrapper) Path() string { return r.path } func (r requestWrapper) Body() io.ReadCloser { return r.body } func (r requestWrapper) Params() map[string]string { return r.params } func (r requestWrapper) Headers() map[string][]string { return r.headers } type Logger interface { Debug(v ...interface{}) Info(v ...interface{}) Warning(v ...interface{}) Error(v ...interface{}) Critical(v ...interface{}) Fatal(v ...interface{}) } ================================================ FILE: proxy/plugin.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "fmt" "io" "net/url" 
"github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy/plugin" ) // NewPluginMiddleware returns an endpoint middleware wrapped (if required) with the plugin middleware. // The plugin middleware will try to load all the required plugins from the register and execute them in order. // RequestModifiers are executed before passing the request to the next middlware. ResponseModifiers are executed // once the response is returned from the next middleware. func NewPluginMiddleware(logger logging.Logger, endpoint *config.EndpointConfig) Middleware { cfg, ok := endpoint.ExtraConfig[plugin.Namespace].(map[string]interface{}) if !ok { return emptyMiddlewareFallback(logger) } return newPluginMiddleware(logger, "ENDPOINT", endpoint.Endpoint, cfg) } // NewBackendPluginMiddleware returns a backend middleware wrapped (if required) with the plugin middleware. // The plugin middleware will try to load all the required plugins from the register and execute them in order. // RequestModifiers are executed before passing the request to the next middlware. ResponseModifiers are executed // once the response is returned from the next middleware. 
func NewBackendPluginMiddleware(logger logging.Logger, remote *config.Backend) Middleware { cfg, ok := remote.ExtraConfig[plugin.Namespace].(map[string]interface{}) if !ok { return emptyMiddlewareFallback(logger) } return newPluginMiddleware(logger, "BACKEND", fmt.Sprintf("%s %s -> %s", remote.ParentEndpointMethod, remote.ParentEndpoint, remote.URLPattern), cfg) } func newPluginMiddleware(logger logging.Logger, tag, pattern string, cfg map[string]interface{}) Middleware { plugins, ok := cfg["name"].([]interface{}) if !ok { return emptyMiddlewareFallback(logger) } var reqModifiers []func(interface{}) (interface{}, error) var respModifiers []func(interface{}) (interface{}, error) for _, p := range plugins { name, ok := p.(string) if !ok { continue } if mf, ok := plugin.GetRequestModifier(name); ok { if fn := mf(cfg); fn != nil { reqModifiers = append(reqModifiers, fn) } continue } if mf, ok := plugin.GetResponseModifier(name); ok { if fn := mf(cfg); fn != nil { respModifiers = append(respModifiers, fn) } } } totReqModifiers, totRespModifiers := len(reqModifiers), len(respModifiers) if totReqModifiers == totRespModifiers && totRespModifiers == 0 { return emptyMiddlewareFallback(logger) } logger.Debug( fmt.Sprintf( "[%s: %s][Modifier Plugins] Adding %d request and %d response modifiers", tag, pattern, totReqModifiers, totRespModifiers, ), ) return func(next ...Proxy) Proxy { if len(next) > 1 { logger.Fatal("too many proxies for this proxy middleware: newPluginMiddleware only accepts 1 proxy, got %d tag: %s, pattern: %s", len(next), tag, pattern) return nil } if totReqModifiers == 0 { return func(ctx context.Context, r *Request) (*Response, error) { resp, err := next[0](ctx, r) if err != nil { return resp, err } return executeResponseModifiers(ctx, respModifiers, resp, newRequestWrapper(ctx, r)) } } if totRespModifiers == 0 { return func(ctx context.Context, r *Request) (*Response, error) { var err error r, err = executeRequestModifiers(ctx, reqModifiers, r) if err != 
nil { return nil, err } return next[0](ctx, r) } } return func(ctx context.Context, r *Request) (*Response, error) { var err error r, err = executeRequestModifiers(ctx, reqModifiers, r) if err != nil { return nil, err } resp, err := next[0](ctx, r) if err != nil { return resp, err } return executeResponseModifiers(ctx, respModifiers, resp, newRequestWrapper(ctx, r)) } } } func executeRequestModifiers(ctx context.Context, reqModifiers []func(interface{}) (interface{}, error), r *Request) (*Request, error) { var tmp RequestWrapper tmp = newRequestWrapper(ctx, r) for _, f := range reqModifiers { res, err := f(tmp) if err != nil { return nil, err } t, ok := res.(RequestWrapper) if !ok { continue } tmp = t } r.Method = tmp.Method() r.URL = tmp.URL() r.Query = tmp.Query() r.Path = tmp.Path() r.Body = tmp.Body() r.Params = tmp.Params() r.Headers = tmp.Headers() return r, nil } func executeResponseModifiers(ctx context.Context, respModifiers []func(interface{}) (interface{}, error), r *Response, req RequestWrapper) (*Response, error) { var tmp ResponseWrapper tmp = responseWrapper{ ctx: ctx, request: req, data: r.Data, isComplete: r.IsComplete, metadata: metadataWrapper{ headers: r.Metadata.Headers, statusCode: r.Metadata.StatusCode, }, io: r.Io, } for _, f := range respModifiers { res, err := f(tmp) if err != nil { return nil, err } t, ok := res.(ResponseWrapper) if !ok { continue } tmp = t } r.Data = tmp.Data() r.IsComplete = tmp.IsComplete() r.Io = tmp.Io() r.Metadata = Metadata{} r.Metadata.Headers = tmp.Headers() r.Metadata.StatusCode = tmp.StatusCode() return r, nil } // RequestWrapper is an interface for passing proxy request between the lura pipe and the loaded plugins type RequestWrapper interface { Params() map[string]string Headers() map[string][]string Body() io.ReadCloser Method() string URL() *url.URL Query() url.Values Path() string } // ResponseWrapper is an interface for passing proxy response between the lura pipe and the loaded plugins type 
ResponseWrapper interface { Data() map[string]interface{} Io() io.Reader IsComplete() bool Headers() map[string][]string StatusCode() int } func newRequestWrapper(ctx context.Context, r *Request) *requestWrapper { return &requestWrapper{ ctx: ctx, method: r.Method, url: r.URL, query: r.Query, path: r.Path, body: r.Body, params: r.Params, headers: r.Headers, } } type requestWrapper struct { ctx context.Context method string url *url.URL query url.Values path string body io.ReadCloser params map[string]string headers map[string][]string } func (r *requestWrapper) Context() context.Context { return r.ctx } func (r *requestWrapper) Method() string { return r.method } func (r *requestWrapper) URL() *url.URL { return r.url } func (r *requestWrapper) Query() url.Values { return r.query } func (r *requestWrapper) Path() string { return r.path } func (r *requestWrapper) Body() io.ReadCloser { return r.body } func (r *requestWrapper) Params() map[string]string { return r.params } func (r *requestWrapper) Headers() map[string][]string { return r.headers } type metadataWrapper struct { headers map[string][]string statusCode int } func (m metadataWrapper) Headers() map[string][]string { return m.headers } func (m metadataWrapper) StatusCode() int { return m.statusCode } type responseWrapper struct { ctx context.Context request interface{} data map[string]interface{} isComplete bool metadata metadataWrapper io io.Reader } func (r responseWrapper) Context() context.Context { return r.ctx } func (r responseWrapper) Request() interface{} { return r.request } func (r responseWrapper) Data() map[string]interface{} { return r.data } func (r responseWrapper) IsComplete() bool { return r.isComplete } func (r responseWrapper) Io() io.Reader { return r.io } func (r responseWrapper) Headers() map[string][]string { return r.metadata.headers } func (r responseWrapper) StatusCode() int { return r.metadata.statusCode } ================================================ FILE: proxy/plugin_test.go 
================================================ //go:build integration || !race // +build integration !race // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "fmt" "net/http" "testing" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy/plugin" ) func TestNewPluginMiddleware_logger(t *testing.T) { plugin.LoadWithLogger("./plugin/tests", ".so", plugin.RegisterModifier, logging.NoOp) validator := func(ctx context.Context, r *Request) (*Response, error) { if r.Path != "/bar/fooo/fooo" { return nil, fmt.Errorf("unexpected path %s", r.Path) } return &Response{ Data: map[string]interface{}{"foo": "bar"}, IsComplete: true, Metadata: Metadata{ Headers: map[string][]string{}, StatusCode: 0, }, }, nil } bknd := NewBackendPluginMiddleware( logging.NoOp, &config.Backend{ ExtraConfig: map[string]interface{}{ plugin.Namespace: map[string]interface{}{ "name": []interface{}{"lura-request-modifier-example-request"}, }, }, }, )(validator) p := NewPluginMiddleware( logging.NoOp, &config.EndpointConfig{ ExtraConfig: map[string]interface{}{ plugin.Namespace: map[string]interface{}{ "name": []interface{}{ "lura-request-modifier-example-request", "lura-request-modifier-example-response", }, }, }, }, )(bknd) resp, err := p(context.Background(), &Request{Path: "/bar"}) if err != nil { t.Error(err.Error()) } if resp == nil { t.Errorf("unexpected response: %v", resp) return } if v, ok := resp.Data["foo"].(string); !ok || v != "bar" { t.Errorf("unexpected foo value: %v", resp.Data["foo"]) } } func TestNewPluginMiddleware_error_request(t *testing.T) { plugin.LoadWithLogger("./plugin/tests", ".so", plugin.RegisterModifier, logging.NoOp) validator := func(ctx context.Context, r *Request) (*Response, error) { t.Error("the backend should not be called") return nil, nil } bknd := NewBackendPluginMiddleware( logging.NoOp, &config.Backend{}, )(validator) p := NewPluginMiddleware( logging.NoOp, 
&config.EndpointConfig{ ExtraConfig: map[string]interface{}{ plugin.Namespace: map[string]interface{}{ "name": []interface{}{ "lura-error-example-request", }, }, }, }, )(bknd) resp, err := p(context.Background(), &Request{Path: "/bar"}) if resp != nil { t.Errorf("unexpected response: %v", resp) return } if err == nil { t.Error("error expected") return } customErr, ok := err.(statusCodeError) if !ok { t.Errorf("unexpected error: %+v (%T)", err, err) return } if sc := customErr.StatusCode(); sc != http.StatusTeapot { t.Errorf("unexpected status code: %d", sc) } if errorMsg := err.Error(); errorMsg != "request rejected just because" { t.Errorf("unexpected error message. have: '%s'", errorMsg) } } func TestNewPluginMiddleware_error_response(t *testing.T) { plugin.LoadWithLogger("./plugin/tests", ".so", plugin.RegisterModifier, logging.NoOp) var hit bool validator := func(ctx context.Context, r *Request) (*Response, error) { hit = true return &Response{ Data: map[string]interface{}{"foo": "bar"}, IsComplete: true, Metadata: Metadata{ Headers: map[string][]string{}, }, }, nil } bknd := NewBackendPluginMiddleware( logging.NoOp, &config.Backend{}, )(validator) p := NewPluginMiddleware( logging.NoOp, &config.EndpointConfig{ ExtraConfig: map[string]interface{}{ plugin.Namespace: map[string]interface{}{ "name": []interface{}{ "lura-error-example-response", }, }, }, }, )(bknd) resp, err := p(context.Background(), &Request{Path: "/bar"}) if resp != nil { t.Errorf("unexpected response: %v", resp) return } if err == nil { t.Error("error expected") return } customErr, ok := err.(statusCodeError) if !ok { t.Errorf("unexpected error: %+v (%T)", err, err) return } if sc := customErr.StatusCode(); sc != http.StatusTeapot { t.Errorf("unexpected status code: %d", sc) } if errorMsg := err.Error(); errorMsg != "response replaced because reasons" { t.Errorf("unexpected error message. 
have: '%s'", errorMsg) } if !hit { t.Error("the backend has not been called") } } func TestNewPluginMiddleware_PoisonedPlugin(t *testing.T) { plugin.RegisterModifier("poisoned", func(map[string]interface{}) func(interface{}) (interface{}, error) { return nil }, false, true) expectedResp := &Response{ Data: map[string]interface{}{"foo": "bar"}, IsComplete: true, Metadata: Metadata{ Headers: map[string][]string{}, }, } validator := func(ctx context.Context, r *Request) (*Response, error) { return expectedResp, nil } bknd := NewBackendPluginMiddleware( logging.NoOp, &config.Backend{}, )(validator) p := NewPluginMiddleware( logging.NoOp, &config.EndpointConfig{ ExtraConfig: map[string]interface{}{ plugin.Namespace: map[string]interface{}{ "name": []interface{}{ "poisoned", }, }, }, }, )(bknd) resp, err := p(context.Background(), &Request{Path: "/bar"}) if resp != expectedResp { t.Errorf("unexpected response: %v", resp) return } if err != nil { t.Error("error expected") return } } type statusCodeError interface { error StatusCode() int } ================================================ FILE: proxy/proxy.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package proxy provides proxy and proxy middleware interfaces and implementations. 
*/ package proxy import ( "context" "errors" "io" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) // Namespace to be used in extra config const Namespace = "github.com/devopsfaith/krakend/proxy" // Metadata is the Metadata of the Response which contains Headers and StatusCode type Metadata struct { Headers map[string][]string StatusCode int } // Response is the entity returned by the proxy type Response struct { Data map[string]interface{} IsComplete bool Metadata Metadata Io io.Reader } // readCloserWrapper is Io.Reader which is closed when the Context is closed or canceled type readCloserWrapper struct { ctx context.Context rc io.ReadCloser } // NewReadCloserWrapper Creates a new closeable io.Read func NewReadCloserWrapper(ctx context.Context, in io.ReadCloser) io.Reader { wrapper := readCloserWrapper{ctx, in} go wrapper.closeOnCancel() return wrapper } func (w readCloserWrapper) Read(b []byte) (int, error) { return w.rc.Read(b) } // closeOnCancel closes the io.Reader when the context is Done func (w readCloserWrapper) closeOnCancel() { <-w.ctx.Done() w.rc.Close() } var ( // ErrNoBackends is the error returned when an endpoint has no backends defined ErrNoBackends = errors.New("all endpoints must have at least one backend") // ErrTooManyBackends is the error returned when an endpoint has too many backends defined ErrTooManyBackends = errors.New("too many backends for this proxy") // ErrTooManyProxies is the error returned when a middleware has too many proxies defined ErrTooManyProxies = errors.New("too many proxies for this proxy middleware") // ErrNotEnoughProxies is the error returned when an endpoint has not enough proxies defined ErrNotEnoughProxies = errors.New("not enough proxies for this endpoint") ) // Proxy processes a request in a given context and returns a response and an error type Proxy func(ctx context.Context, request *Request) (*Response, error) // BackendFactory creates a proxy based on the received backend 
configuration type BackendFactory func(remote *config.Backend) Proxy // Middleware adds a middleware, decorator or wrapper over a collection of proxies, // exposing a proxy interface. // // Proxy middlewares can be stacked: // // var p Proxy // p := EmptyMiddleware(NoopProxy) // response, err := p(ctx, r) type Middleware func(next ...Proxy) Proxy // EmptyMiddlewareWithLoggger is a dummy middleware, useful for testing and fallback func EmptyMiddlewareWithLogger(logger logging.Logger, next ...Proxy) Proxy { if len(next) > 1 { logger.Fatal("too many proxies for this proxy middleware: EmptyMiddleware only accepts 1 proxy, got %d", len(next)) return nil } return next[0] } func EmptyMiddleware(next ...Proxy) Proxy { return EmptyMiddlewareWithLogger(logging.NoOp, next...) } func emptyMiddlewareFallback(logger logging.Logger) Middleware { return func(next ...Proxy) Proxy { return EmptyMiddlewareWithLogger(logger, next...) } } // NoopProxy is a do nothing proxy, useful for testing func NoopProxy(_ context.Context, _ *Request) (*Response, error) { return nil, nil } ================================================ FILE: proxy/proxy_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "bytes" "context" "fmt" "io" "strings" "sync" "testing" "time" ) func TestEmptyMiddleware_ok(t *testing.T) { expected := Response{} result, err := EmptyMiddleware(dummyProxy(&expected))(context.Background(), &Request{}) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if result != &expected { t.Errorf("The middleware returned an unexpected result: %v\n", result) } } func explosiveProxy(t *testing.T) Proxy { return func(ctx context.Context, _ *Request) (*Response, error) { t.Error("This proxy shouldn't been executed!") return &Response{}, nil } } func dummyProxy(r *Response) Proxy { return func(_ context.Context, _ *Request) (*Response, error) { return r, nil } } func delayedProxy(_ 
*testing.T, timeout time.Duration, r *Response) Proxy { return func(ctx context.Context, _ *Request) (*Response, error) { select { case <-ctx.Done(): return nil, ctx.Err() case <-time.After(timeout): return r, nil } } } func newDummyReadCloser(content string) io.ReadCloser { return dummyReadCloser{strings.NewReader(content)} } type dummyReadCloser struct { reader io.Reader } func (d dummyReadCloser) Read(p []byte) (int, error) { return d.reader.Read(p) } func (dummyReadCloser) Close() error { return nil } func TestWrapper(t *testing.T) { expected := "supu" ctx, cancel := context.WithCancel(context.Background()) defer cancel() readCloser := &dummyRC{ r: bytes.NewBufferString(expected), mu: &sync.Mutex{}, } r := NewReadCloserWrapper(ctx, readCloser) var out bytes.Buffer tot, err := out.ReadFrom(r) if err != nil { t.Errorf("Total bits read: %d. Err: %s", tot, err.Error()) return } if readCloser.IsClosed() { t.Error("The subject shouldn't be closed yet") return } if tot != 4 { t.Errorf("Unexpected number of bits read: %d", tot) return } if v := out.String(); v != expected { t.Errorf("Unexpected content: %s", v) return } cancel() <-time.After(100 * time.Millisecond) if !readCloser.IsClosed() { t.Error("The subject should be already closed") return } } type dummyRC struct { r io.Reader closed bool mu *sync.Mutex } func (d *dummyRC) Read(b []byte) (int, error) { d.mu.Lock() defer d.mu.Unlock() if d.closed { return -1, fmt.Errorf("Reading from a closed source") } return d.r.Read(b) } func (d *dummyRC) Close() error { d.mu.Lock() defer d.mu.Unlock() d.closed = true return nil } func (d *dummyRC) IsClosed() bool { d.mu.Lock() defer d.mu.Unlock() res := d.closed return res } ================================================ FILE: proxy/query_strings_filter.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "net/url" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) // 
NewFilterQueryStringsMiddleware returns a middleware with or without a header filtering // proxy wrapping the next element (depending on the configuration). func NewFilterQueryStringsMiddleware(logger logging.Logger, remote *config.Backend) Middleware { if len(remote.QueryStringsToPass) == 0 { return emptyMiddlewareFallback(logger) } return func(next ...Proxy) Proxy { if len(next) > 1 { logger.Fatal("too many proxies for this %s %s -> %s proxy middleware: NewFilterQueryStringsMiddleware only accepts 1 proxy, got %d", remote.ParentEndpointMethod, remote.ParentEndpoint, remote.URLPattern, len(next)) return nil } nextProxy := next[0] return func(ctx context.Context, request *Request) (*Response, error) { if len(request.Query) == 0 { return nextProxy(ctx, request) } numQueryStringsToPass := 0 for _, v := range remote.QueryStringsToPass { if _, ok := request.Query[v]; ok { numQueryStringsToPass++ } } if numQueryStringsToPass == len(request.Query) { // all the query strings should pass, no need to clone the headers return nextProxy(ctx, request) } // ATTENTION: this is not a clone of query strings! // this just filters the query strings we do not want to send: // issues and race conditions could happen the same way as when we // do not filter the headers. This is a design decission, and if we // want to clone the query string values (because of write modifications), // that should be done at an upper level (so the approach is the same // for non filtered parallel requests). 
newQueryStrings := make(url.Values, numQueryStringsToPass) for _, v := range remote.QueryStringsToPass { if values, ok := request.Query[v]; ok { newQueryStrings[v] = values } } return nextProxy(ctx, &Request{ Method: request.Method, URL: request.URL, Query: newQueryStrings, Path: request.Path, Body: request.Body, Params: request.Params, Headers: request.Headers, }) } } } ================================================ FILE: proxy/query_strings_filter_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "testing" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) func TestNewFilterQueryStringsMiddleware(t *testing.T) { mw := NewFilterQueryStringsMiddleware( logging.NoOp, &config.Backend{ QueryStringsToPass: []string{ "oak", "cedar", }, }, ) var receivedReq *Request prxy := mw(func(ctx context.Context, req *Request) (*Response, error) { receivedReq = req return nil, nil }) sentReq := &Request{ Body: nil, Params: map[string]string{}, Query: map[string][]string{ "oak": []string{"acorn", "evergreen"}, "maple": []string{"tree", "shrub"}, "cedar": []string{"mediterranean", "himalayas"}, "willow": []string{"350"}, }, } prxy(context.Background(), sentReq) if receivedReq == sentReq { t.Errorf("request should be different") return } oak, ok := receivedReq.Query["oak"] if !ok { t.Errorf("missing 'oak'") return } if len(oak) != len(sentReq.Query["oak"]) { t.Errorf("want len(oak): %d, got %d", len(sentReq.Query["oak"]), len(oak)) return } for idx, expected := range sentReq.Query["oak"] { if expected != oak[idx] { t.Errorf("want oak[%d] = %s, got %s", idx, expected, oak[idx]) return } } if _, ok := receivedReq.Query["cedar"]; !ok { t.Errorf("missing 'cedar'") return } if _, ok := receivedReq.Query["mapple"]; ok { t.Errorf("should not be there: 'mapple'") return } if _, ok := receivedReq.Query["willow"]; ok { t.Errorf("should not be there: 'willow'") return } // check that 
when query strings are all the expected, no need to copy sentReq = &Request{ Body: nil, Params: map[string]string{}, Query: map[string][]string{ "oak": []string{"acorn", "evergreen"}, "cedar": []string{"mediterranean", "himalayas"}, }, } prxy(context.Background(), sentReq) if receivedReq != sentReq { t.Errorf("request should be the same, no modification of query string expected") return } } func TestFilterQueryStringsBlockAll(t *testing.T) { // In order to block all the query strings, we must only let pass // the 'empty' string "" mw := NewFilterQueryStringsMiddleware( logging.NoOp, &config.Backend{ QueryStringsToPass: []string{""}, }, ) var receivedReq *Request prxy := mw(func(ctx context.Context, req *Request) (*Response, error) { receivedReq = req return nil, nil }) sentReq := &Request{ Body: nil, Params: map[string]string{}, Query: map[string][]string{ "oak": []string{"acorn", "evergreen"}, "maple": []string{"tree", "shrub"}, }, } prxy(context.Background(), sentReq) if receivedReq == sentReq { t.Errorf("request should be different") return } if len(receivedReq.Query) != 0 { t.Errorf("should have blocked all query strings") return } } func TestFilterQueryStringsAllowAll(t *testing.T) { // Empty backend query strings to passa everything mw := NewFilterQueryStringsMiddleware( logging.NoOp, &config.Backend{ QueryStringsToPass: []string{}, }, ) var receivedReq *Request prxy := mw(func(ctx context.Context, req *Request) (*Response, error) { receivedReq = req return nil, nil }) sentReq := &Request{ Body: nil, Params: map[string]string{}, Query: map[string][]string{ "oak": []string{"acorn", "evergreen"}, "maple": []string{"tree", "shrub"}, }, } prxy(context.Background(), sentReq) if len(receivedReq.Query) != 2 { t.Errorf("should have passed all query strings") return } } ================================================ FILE: proxy/register.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( 
"github.com/luraproject/lura/v2/register" ) func NewRegister() *Register { return &Register{ responseCombiners, } } type Register struct { *combinerRegister } type combinerRegister struct { data *register.Untyped fallback ResponseCombiner } func newCombinerRegister(data map[string]ResponseCombiner, fallback ResponseCombiner) *combinerRegister { r := register.NewUntyped() for k, v := range data { r.Register(k, v) } return &combinerRegister{r, fallback} } func (r *combinerRegister) GetResponseCombiner(name string) (ResponseCombiner, bool) { v, ok := r.data.Get(name) if !ok { return r.fallback, ok } if rc, ok := v.(ResponseCombiner); ok { return rc, ok } return r.fallback, ok } func (r *combinerRegister) SetResponseCombiner(name string, rc ResponseCombiner) { r.data.Register(name, rc) } ================================================ FILE: proxy/register_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "reflect" "testing" ) func TestNewRegister_responseCombiner_ok(t *testing.T) { r := NewRegister() r.SetResponseCombiner("name1", func(total int, parts []*Response) *Response { if total < 0 || total >= len(parts) { return nil } return parts[total] }) rc, ok := r.GetResponseCombiner("name1") if !ok { t.Error("expecting response combiner") return } result := rc(0, []*Response{{IsComplete: true, Data: map[string]interface{}{"a": 42}}}) if result == nil { t.Error("expecting result") return } if !result.IsComplete { t.Error("expecting a complete result") return } if len(result.Data) != 1 { t.Error("unexpected result size:", len(result.Data)) return } } func TestNewRegister_responseCombiner_fallbackIfErrored(t *testing.T) { r := NewRegister() r.data.Register("errored", true) rc, ok := r.GetResponseCombiner("errored") if !ok { t.Error("expecting response combiner") return } original := &Response{IsComplete: true, Data: map[string]interface{}{"a": 42}} result := rc(1, []*Response{{Data: original.Data, 
IsComplete: original.IsComplete}}) if !reflect.DeepEqual(original.Data, result.Data) { t.Errorf("unexpected data, want=%+v | have=%+v", original.Data, result.Data) return } if result.IsComplete != original.IsComplete { t.Errorf("unexpected complete flag, want=%+v | have=%+v", original.IsComplete, result.IsComplete) return } } func TestNewRegister_responseCombiner_fallbackIfUnknown(t *testing.T) { r := NewRegister() rc, ok := r.GetResponseCombiner("unknown") if ok { t.Error("the response combiner should not be found") return } original := &Response{IsComplete: true, Data: map[string]interface{}{"a": 42}} result := rc(1, []*Response{{Data: original.Data, IsComplete: original.IsComplete}}) if !reflect.DeepEqual(original.Data, result.Data) { t.Errorf("unexpected data, want=%+v | have=%+v", original.Data, result.Data) return } if result.IsComplete != original.IsComplete { t.Errorf("unexpected complete flag, want=%+v | have=%+v", original.IsComplete, result.IsComplete) return } } ================================================ FILE: proxy/request.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "bytes" "io" "net/url" ) // Request contains the data to send to the backend type Request struct { Method string URL *url.URL Query url.Values Path string Body io.ReadCloser Params map[string]string Headers map[string][]string } // GeneratePath takes a pattern and updates the path of the request func (r *Request) GeneratePath(URLPattern string) { if len(r.Params) == 0 { r.Path = URLPattern return } buff := []byte(URLPattern) for k, v := range r.Params { var key []byte key = append(key, "{{."...) key = append(key, k...) key = append(key, "}}"...) buff = bytes.ReplaceAll(buff, key, []byte(v)) } r.Path = string(buff) } // Clone clones itself into a new request. 
The returned cloned request is not // thread-safe, so changes on request.Params and request.Headers could generate // race-conditions depending on the part of the pipe they are being executed. // For thread-safe request headers and/or params manipulation, use the proxy.CloneRequest // function. func (r *Request) Clone() Request { var clonedURL *url.URL if r.URL != nil { clonedURL, _ = url.Parse(r.URL.String()) } return Request{ Method: r.Method, URL: clonedURL, Query: r.Query, Path: r.Path, Body: r.Body, Params: r.Params, Headers: r.Headers, } } // CloneRequest returns a deep copy of the received request, so the received and the // returned proxy.Request do not share a pointer func CloneRequest(r *Request) *Request { clone := r.Clone() clone.Headers = CloneRequestHeaders(r.Headers) clone.Params = CloneRequestParams(r.Params) if r.Body == nil { return &clone } buf := new(bytes.Buffer) buf.ReadFrom(r.Body) r.Body.Close() r.Body = io.NopCloser(bytes.NewReader(buf.Bytes())) clone.Body = io.NopCloser(buf) return &clone } // CloneRequestHeaders returns a copy of the received request headers func CloneRequestHeaders(headers map[string][]string) map[string][]string { m := make(map[string][]string, len(headers)) for k, vs := range headers { tmp := make([]string, len(vs)) copy(tmp, vs) m[k] = tmp } return m } // CloneRequestParams returns a copy of the received request params func CloneRequestParams(params map[string]string) map[string]string { m := make(map[string]string, len(params)) for k, v := range params { m[k] = v } return m } ================================================ FILE: proxy/request_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import "testing" func BenchmarkRequestGeneratePath(b *testing.B) { r := Request{ Method: "GET", Params: map[string]string{ "Supu": "42", "Tupu": "false", "Foo": "bar", }, } for _, testCase := range []string{ "/a", "/a/{{.Supu}}", "/a?b={{.Tupu}}", 
"/a/{{.Supu}}/foo/{{.Foo}}", "/a/{{.Supu}}/foo/{{.Foo}}/b?c={{.Tupu}}", } { b.Run(testCase, func(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { r.GeneratePath(testCase) } }) } } ================================================ FILE: proxy/request_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "bytes" "io" "strings" "testing" ) func TestRequestGeneratePath(t *testing.T) { r := Request{ Method: "GET", Params: map[string]string{ "Supu": "42", "Tupu": "false", "Foo": "bar", }, } for i, testCase := range [][]string{ {"/a/{{.Supu}}", "/a/42"}, {"/a?b={{.Tupu}}", "/a?b=false"}, {"/a/{{.Supu}}/foo/{{.Foo}}", "/a/42/foo/bar"}, {"/a", "/a"}, } { r.GeneratePath(testCase[0]) if r.Path != testCase[1] { t.Errorf("%d: want %s, have %s", i, testCase[1], r.Path) } } } func TestRequest_Clone(t *testing.T) { r := Request{ Method: "GET", Params: map[string]string{ "Supu": "42", "Tupu": "false", "Foo": "bar", }, Headers: map[string][]string{ "Content-Type": {"application/json"}, }, } clone := r.Clone() if len(r.Params) != len(clone.Params) { t.Errorf("wrong num of params. have: %d, want: %d", len(clone.Params), len(r.Params)) return } for k, v := range r.Params { if res, ok := clone.Params[k]; !ok { t.Errorf("param %s not cloned", k) } else if res != v { t.Errorf("unexpected param %s. have: %s, want: %s", k, res, v) } } if len(r.Headers) != len(clone.Headers) { t.Errorf("wrong num of headers. have: %d, want: %d", len(clone.Headers), len(r.Headers)) return } for k, vs := range r.Headers { if res, ok := clone.Headers[k]; !ok { t.Errorf("header %s not cloned", k) } else if len(res) != len(vs) { t.Errorf("unexpected header %s. have: %v, want: %v", k, res, vs) } } r.Headers["extra"] = []string{"supu"} if len(r.Headers) != len(clone.Headers) { t.Errorf("wrong num of headers. 
have: %d, want: %d", len(clone.Headers), len(r.Headers)) return } for k, vs := range r.Headers { if res, ok := clone.Headers[k]; !ok { t.Errorf("header %s not cloned", k) } else if len(res) != len(vs) { t.Errorf("unexpected header %s. have: %v, want: %v", k, res, vs) } } } func TestCloneRequest(t *testing.T) { body := `{"a":1,"b":2}` r := Request{ Method: "POST", Params: map[string]string{ "Supu": "42", "Tupu": "false", "Foo": "bar", }, Headers: map[string][]string{ "Content-Type": {"application/json"}, }, Body: io.NopCloser(strings.NewReader(body)), } clone := CloneRequest(&r) if len(r.Params) != len(clone.Params) { t.Errorf("wrong num of params. have: %d, want: %d", len(clone.Params), len(r.Params)) return } for k, v := range r.Params { if res, ok := clone.Params[k]; !ok { t.Errorf("param %s not cloned", k) } else if res != v { t.Errorf("unexpected param %s. have: %s, want: %s", k, res, v) } } if len(r.Headers) != len(clone.Headers) { t.Errorf("wrong num of headers. have: %d, want: %d", len(clone.Headers), len(r.Headers)) return } for k, vs := range r.Headers { if res, ok := clone.Headers[k]; !ok { t.Errorf("header %s not cloned", k) } else if len(res) != len(vs) { t.Errorf("unexpected header %s. have: %v, want: %v", k, res, vs) } } r.Headers["extra"] = []string{"supu"} if _, ok := clone.Headers["extra"]; ok { t.Error("the cloned instance shares its headers with the original one") } delete(r.Params, "Supu") if _, ok := clone.Params["Supu"]; !ok { t.Error("the cloned instance shares its params with the original one") } rb, _ := io.ReadAll(r.Body) cb, _ := io.ReadAll(clone.Body) if !bytes.Equal(cb, rb) || body != string(rb) { t.Errorf("unexpected bodies. 
original: %s, returned: %s", string(rb), string(cb)) } } ================================================ FILE: proxy/shadow.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) const ( shadowKey = "shadow" shadowTimeoutKey = "shadow_timeout" ) type shadowFactory struct { f Factory } // New check the Backends for an ExtraConfig with the "shadow" param to true // implements the Factory interface. Sets the "shadow_timeout" defined in the // config; uses the backend timeout as fallback. func (s shadowFactory) New(cfg *config.EndpointConfig) (p Proxy, err error) { if len(cfg.Backend) == 0 { err = ErrNoBackends return } cfgCopy := *cfg var shadow []*config.Backend var regular []*config.Backend var maxTimeout time.Duration for _, b := range cfgCopy.Backend { if d, ok := isShadowBackend(b); ok { if maxTimeout < d { maxTimeout = d } shadow = append(shadow, b) continue } regular = append(regular, b) } cfgCopy.Backend = regular p, err = s.f.New(&cfgCopy) if len(shadow) > 0 { cfgCopy.Backend = shadow pShadow, _ := s.f.New(&cfgCopy) p = ShadowMiddlewareWithTimeout(maxTimeout, p, pShadow) } return } // NewShadowFactory creates a new shadowFactory using the provided Factory func NewShadowFactory(f Factory) Factory { return shadowFactory{f} } // ShadowMiddlewareWithLogger is a Middleware that creates a shadowProxy func ShadowMiddlewareWithLogger(logger logging.Logger, next ...Proxy) Proxy { switch len(next) { case 0: logger.Fatal("not enough proxies for this endpoint: ShadowMiddlewareWithLogger only accepts 1 or 2 proxies, got 0") return nil case 1: return next[0] case 2: return NewShadowProxy(next[0], next[1]) default: logger.Fatal("too many proxies for this proxy middleware: ShadowMiddlewareWithLogger only accepts 1 or 2 proxies, got %d", len(next)) return nil } } // ShadowMiddleware is a Middleware that creates a 
shadowProxy func ShadowMiddleware(next ...Proxy) Proxy { return ShadowMiddlewareWithLogger(logging.NoOp, next...) } // ShadowMiddlewareWithTimeoutAndLogger is a Middleware that creates a shadowProxy with a timeout in the context func ShadowMiddlewareWithTimeoutAndLogger(logger logging.Logger, timeout time.Duration, next ...Proxy) Proxy { switch len(next) { case 0: logger.Fatal("not enough proxies for this endpoint: ShadowMiddlewareWithTimeoutAndLogger only accepts 1 or 2 proxies, got 0") return nil case 1: return next[0] case 2: return NewShadowProxyWithTimeout(timeout, next[0], next[1]) default: logger.Fatal("too many proxies for this proxy middleware: ShadowMiddlewareWithTimeoutAndLogger only accepts 1 or 2 proxies, got %d", len(next)) return nil } } // ShadowMiddlewareWithTimeout is a Middleware that creates a shadowProxy with a timeout in the context func ShadowMiddlewareWithTimeout(timeout time.Duration, next ...Proxy) Proxy { return ShadowMiddlewareWithTimeoutAndLogger(logging.NoOp, timeout, next...) } // NewShadowProxy returns a Proxy that sends requests to p1 and p2 but ignores // the response of p2. func NewShadowProxy(p1, p2 Proxy) Proxy { return NewShadowProxyWithTimeout(config.DefaultTimeout, p1, p2) } // NewShadowProxyWithTimeout returns a Proxy that sends requests to p1 and p2 but ignores // the response of p2. Sets a timeout in the context. 
func NewShadowProxyWithTimeout(timeout time.Duration, p1, p2 Proxy) Proxy { return func(ctx context.Context, request *Request) (*Response, error) { shadowCtx, cancel := newContextWrapperWithTimeout(ctx, timeout) shadowRequest := CloneRequest(request) go func() { p2(shadowCtx, shadowRequest) cancel() }() return p1(ctx, request) } } func isShadowBackend(c *config.Backend) (time.Duration, bool) { duration := c.Timeout v, ok := c.ExtraConfig[Namespace] if !ok { return duration, false } e, ok := v.(map[string]interface{}) if !ok { return duration, false } k, ok := e[shadowKey] if !ok { return duration, false } if s, ok := k.(bool); !ok || !s { return duration, false } t, ok := e[shadowTimeoutKey].(string) if !ok { return duration, true } if d, err := time.ParseDuration(t); err == nil { duration = d } return duration, true } type contextWrapper struct { context.Context data context.Context } func (c contextWrapper) Value(key interface{}) interface{} { return c.data.Value(key) } func newContextWrapperWithTimeout(data context.Context, timeout time.Duration) (contextWrapper, context.CancelFunc) { ctx, cancel := context.WithTimeout(context.Background(), timeout) return contextWrapper{ Context: ctx, data: data, }, cancel } ================================================ FILE: proxy/shadow_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "bytes" "context" "errors" "sync/atomic" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) var ( extraCfg = config.ExtraConfig{ Namespace: map[string]interface{}{ "shadow": true, "shadow_timeout": "10s", }, } badExtra = config.ExtraConfig{ Namespace: map[string]interface{}{ "shadow": "string", }, } ) func newAssertionProxy(counter *uint64) Proxy { return func(ctx context.Context, request *Request) (*Response, error) { atomic.AddUint64(counter, 1) return nil, nil } } func TestIsShadowBackend(t *testing.T) { cfg := 
&config.Backend{ExtraConfig: extraCfg} badCfg := &config.Backend{ExtraConfig: badExtra} d, ok := isShadowBackend(cfg) if !ok { t.Error("The shadow backend should be true") } if d != 10*time.Second { t.Errorf("Invalid duration %s", d) } if _, ok := isShadowBackend(&config.Backend{}); ok { t.Error("The shadow backend should be false") } if _, ok := isShadowBackend(badCfg); ok { t.Error("The shadow backend should be false") } } func TestShadowMiddleware(t *testing.T) { var counter uint64 assertProxy := newAssertionProxy(&counter) p := ShadowMiddleware(assertProxy, assertProxy) p(context.Background(), &Request{}) time.Sleep(100 * time.Millisecond) if atomic.LoadUint64(&counter) != 2 { t.Errorf("The shadow proxy should have been called 2 times, not %d", counter) } } func TestShadowFactory_noBackends(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } factory := DefaultFactory(logger) sFactory := NewShadowFactory(factory) if _, err := sFactory.New(&config.EndpointConfig{}); err != ErrNoBackends { t.Errorf("Expecting ErrNoBackends. 
Got: %v\n", err) } } func TestNewShadowFactory(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } var counter uint64 assertProxy := newAssertionProxy(&counter) factory := NewDefaultFactory(func(_ *config.Backend) Proxy { return assertProxy }, logger) f := NewShadowFactory(factory) sBackend := &config.Backend{ExtraConfig: extraCfg} backend := &config.Backend{} endpointConfig := &config.EndpointConfig{Backend: []*config.Backend{sBackend, backend}} serviceConfig := config.ServiceConfig{ Version: config.ConfigVersion, Endpoints: []*config.EndpointConfig{endpointConfig}, Timeout: 100 * time.Millisecond, Host: []string{"dummy"}, } if err = serviceConfig.Init(); err != nil { t.Errorf("Error during the config init: %s\n", err.Error()) } p, err := f.New(endpointConfig) if err != nil { t.Error(err) } _, err = p(context.Background(), &Request{}) if err != nil { t.Error(err) } time.Sleep(100 * time.Millisecond) if atomic.LoadUint64(&counter) != 2 { t.Errorf("The shadow proxy should have been called 2 times, not %d", counter) } } func TestShadowMiddleware_erroredBackend(t *testing.T) { timeout := 100 * time.Millisecond p := ShadowMiddleware( delayedProxy(t, timeout, &Response{Data: map[string]interface{}{"supu": 42}, IsComplete: true}), func(_ context.Context, _ *Request) (*Response, error) { return nil, errors.New("ignore me") }, ) mustEnd := time.After(time.Duration(5 * timeout)) out, err := p(context.Background(), &Request{Params: map[string]string{}}) if err != nil { t.Errorf("unexpected error: %s\n", err.Error()) return } if out == nil { t.Errorf("The proxy returned a null result\n") return } select { case <-mustEnd: t.Errorf("We were expecting a response but we got none\n") default: if len(out.Data) != 1 { t.Errorf("We weren't expecting a partial response but we got %v!\n", out) } if !out.IsComplete { t.Errorf("We were expecting a 
completed response!\n") } } } func TestShadowMiddleware_partialTimeout(t *testing.T) { timeout := 200 * time.Millisecond p := ShadowMiddleware( delayedProxy(t, time.Duration(5*timeout), &Response{Data: map[string]interface{}{"supu": 42}}), delayedProxy(t, time.Duration(timeout/2), &Response{Data: map[string]interface{}{"supu": 42}, IsComplete: true})) ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() out, err := p(ctx, &Request{}) if err == nil || err.Error() != "context deadline exceeded" { t.Errorf("The middleware propagated an unexpected error: %s\n", err.Error()) } if out != nil { t.Errorf("The proxy did not return a null result: %+v\n", out) return } } ================================================ FILE: proxy/stack_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "fmt" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) var result interface{} func BenchmarkProxyStack_single(b *testing.B) { backend := &config.Backend{ ConcurrentCalls: 3, Timeout: time.Duration(100) * time.Millisecond, Host: []string{"supu:8080"}, Method: "GET", URLPattern: "/a/{{.Tupu}}", DenyList: []string{"map.aaaa"}, Mapping: map[string]string{"supu": "SUPUUUUU"}, } cfg := &config.EndpointConfig{ Backend: []*config.Backend{backend}, ExtraConfig: map[string]interface{}{ Namespace: map[string]interface{}{ staticKey: map[string]interface{}{ "data": map[string]interface{}{ "status": "errored", }, }, "strategy": "incomplete", }, }, } ef := NewEntityFormatter(backend) p := func(_ context.Context, _ *Request) (*Response, error) { res := ef.Format(Response{ Data: map[string]interface{}{ "supu": 42, "tupu": true, "foo": "bar", "map": map[string]interface{}{"aaaa": false}, "col": []interface{}{ map[string]interface{}{ "a": 1, "b": 2, }, }, }, IsComplete: true, }) return &res, nil } p = NewRoundRobinLoadBalancedMiddleware(backend)(p) 
p = NewConcurrentMiddleware(backend)(p) p = NewRequestBuilderMiddleware(backend)(p) p = NewStaticMiddleware(logging.NoOp, cfg)(p) request := &Request{ Method: "GET", Body: newDummyReadCloser(""), Params: map[string]string{"Tupu": "true"}, Headers: map[string][]string{}, } var r *Response b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { r, _ = p(context.Background(), request) } result = r } func BenchmarkProxyStack_multi(b *testing.B) { backend := &config.Backend{ ConcurrentCalls: 3, Timeout: time.Duration(100) * time.Millisecond, Host: []string{"supu:8080"}, Method: "GET", URLPattern: "/a/{{.Tupu}}", DenyList: []string{"map.aaaa"}, Mapping: map[string]string{"supu": "SUPUUUUU"}, } request := &Request{ Method: "GET", Body: newDummyReadCloser(""), Params: map[string]string{"Tupu": "true"}, Headers: map[string][]string{}, } for _, testCase := range [][]*config.Backend{ {backend}, {backend, backend}, {backend, backend, backend}, {backend, backend, backend, backend}, {backend, backend, backend, backend, backend}, } { b.Run(fmt.Sprintf("with %d backends", len(testCase)), func(b *testing.B) { cfg := &config.EndpointConfig{ Backend: testCase, } backendProxy := make([]Proxy, len(cfg.Backend)) for i, backend := range cfg.Backend { ef := NewEntityFormatter(backend) backendProxy[i] = func(_ context.Context, _ *Request) (*Response, error) { res := ef.Format(Response{ Data: map[string]interface{}{ "supu": 42, "tupu": true, "foo": "bar", "map": map[string]interface{}{"aaaa": false}, "col": []interface{}{ map[string]interface{}{ "a": 1, "b": 2, }, }, }, IsComplete: true, }) return &res, nil } backendProxy[i] = NewRoundRobinLoadBalancedMiddleware(backend)(backendProxy[i]) backendProxy[i] = NewConcurrentMiddleware(backend)(backendProxy[i]) backendProxy[i] = NewRequestBuilderMiddleware(backend)(backendProxy[i]) } p := NewMergeDataMiddleware(logging.NoOp, cfg)(backendProxy...) 
p = NewStaticMiddleware(logging.NoOp, cfg)(p) var r *Response b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { r, _ = p(context.Background(), request) } result = r }) } } func BenchmarkProxyStack_multipost(b *testing.B) { backendGET := &config.Backend{ ConcurrentCalls: 3, Timeout: time.Duration(100) * time.Millisecond, Host: []string{"supu:8080"}, Method: "GET", URLPattern: "/a/{{.Tupu}}", DenyList: []string{"map.aaaa"}, Mapping: map[string]string{"supu": "SUPUUUUU"}, } backendPOST := &config.Backend{ ConcurrentCalls: 3, Timeout: time.Duration(100) * time.Millisecond, Host: []string{"supu:8080"}, Method: "POST", URLPattern: "/a/{{.Tupu}}", DenyList: []string{"map.aaaa"}, Mapping: map[string]string{"supu": "SUPUUUUU"}, } request := &Request{ Method: "POST", Body: newDummyReadCloser(""), Params: map[string]string{"Tupu": "true"}, Headers: map[string][]string{}, } for _, testCase := range [][]*config.Backend{ {backendGET}, {backendGET, backendPOST}, {backendGET, backendPOST, backendGET}, {backendGET, backendPOST, backendGET, backendPOST}, {backendGET, backendPOST, backendGET, backendPOST, backendPOST}, } { b.Run(fmt.Sprintf("with %d backends", len(testCase)), func(b *testing.B) { cfg := &config.EndpointConfig{ Backend: testCase, } backendProxy := make([]Proxy, len(cfg.Backend)) for i, backend := range cfg.Backend { ef := NewEntityFormatter(backend) backendProxy[i] = func(_ context.Context, _ *Request) (*Response, error) { res := ef.Format(Response{ Data: map[string]interface{}{ "supu": 42, "tupu": true, "foo": "bar", "map": map[string]interface{}{"aaaa": false}, "col": []interface{}{ map[string]interface{}{ "a": 1, "b": 2, }, }, }, IsComplete: true, }) return &res, nil } backendProxy[i] = NewRoundRobinLoadBalancedMiddleware(backend)(backendProxy[i]) backendProxy[i] = NewConcurrentMiddleware(backend)(backendProxy[i]) backendProxy[i] = NewRequestBuilderMiddleware(backend)(backendProxy[i]) } p := NewMergeDataMiddleware(logging.NoOp, cfg)(backendProxy...) 
p = NewStaticMiddleware(logging.NoOp, cfg)(p) var r *Response b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { r, _ = p(context.Background(), request) } result = r }) } } func BenchmarkProxyStack_single_flatmap(b *testing.B) { backend := &config.Backend{ ConcurrentCalls: 3, Timeout: time.Duration(100) * time.Millisecond, Host: []string{"supu:8080"}, Method: "GET", URLPattern: "/a/{{.Tupu}}", ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ flatmapKey: []interface{}{ map[string]interface{}{ "type": "del", "args": []interface{}{"map.aaaa"}, }, map[string]interface{}{ "type": "move", "args": []interface{}{"supu", "SUPUUUUU"}, }, }, }, }, } cfg := &config.EndpointConfig{ Backend: []*config.Backend{backend}, ExtraConfig: map[string]interface{}{ Namespace: map[string]interface{}{ staticKey: map[string]interface{}{ "data": map[string]interface{}{ "status": "errored", }, }, "strategy": "incomplete", }, }, } ef := NewEntityFormatter(backend) p := func(_ context.Context, _ *Request) (*Response, error) { res := ef.Format(Response{ Data: map[string]interface{}{ "supu": 42, "tupu": true, "foo": "bar", "map": map[string]interface{}{"aaaa": false}, "col": []interface{}{ map[string]interface{}{ "a": 1, "b": 2, }, }, }, IsComplete: true, }) return &res, nil } p = NewRoundRobinLoadBalancedMiddleware(backend)(p) p = NewConcurrentMiddleware(backend)(p) p = NewRequestBuilderMiddleware(backend)(p) p = NewStaticMiddleware(logging.NoOp, cfg)(p) request := &Request{ Method: "GET", Body: newDummyReadCloser(""), Params: map[string]string{"Tupu": "true"}, Headers: map[string][]string{}, } var r *Response b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { r, _ = p(context.Background(), request) } result = r } func BenchmarkProxyStack_multi_flatmap(b *testing.B) { backend := &config.Backend{ ConcurrentCalls: 3, Timeout: time.Duration(100) * time.Millisecond, Host: []string{"supu:8080"}, Method: "GET", URLPattern: "/a/{{.Tupu}}", ExtraConfig: 
config.ExtraConfig{ Namespace: map[string]interface{}{ flatmapKey: []interface{}{ map[string]interface{}{ "type": "del", "args": []interface{}{"map.aaaa"}, }, map[string]interface{}{ "type": "move", "args": []interface{}{"supu", "SUPUUUUU"}, }, }, }, }, } request := &Request{ Method: "GET", Body: newDummyReadCloser(""), Params: map[string]string{"Tupu": "true"}, Headers: map[string][]string{}, } for _, testCase := range [][]*config.Backend{ {backend}, {backend, backend}, {backend, backend, backend}, {backend, backend, backend, backend}, {backend, backend, backend, backend, backend}, } { b.Run(fmt.Sprintf("with %d backends", len(testCase)), func(b *testing.B) { cfg := &config.EndpointConfig{ Backend: testCase, } backendProxy := make([]Proxy, len(cfg.Backend)) for i, backend := range cfg.Backend { ef := NewEntityFormatter(backend) backendProxy[i] = func(_ context.Context, _ *Request) (*Response, error) { res := ef.Format(Response{ Data: map[string]interface{}{ "supu": 42, "tupu": true, "foo": "bar", "map": map[string]interface{}{"aaaa": false}, "col": []interface{}{ map[string]interface{}{ "a": 1, "b": 2, }, }, }, IsComplete: true, }) return &res, nil } backendProxy[i] = NewRoundRobinLoadBalancedMiddleware(backend)(backendProxy[i]) backendProxy[i] = NewConcurrentMiddleware(backend)(backendProxy[i]) backendProxy[i] = NewRequestBuilderMiddleware(backend)(backendProxy[i]) } p := NewMergeDataMiddleware(logging.NoOp, cfg)(backendProxy...) 
p = NewStaticMiddleware(logging.NoOp, cfg)(p) var r *Response b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { p(context.Background(), request) } result = r }) } } ================================================ FILE: proxy/stack_test.go ================================================ //go:build integration || !race // +build integration !race // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "fmt" "io" "net/http" "net/http/httptest" "net/url" "os" "strings" "sync" "testing" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) func TestProxyStack_multi(t *testing.T) { results := map[string]int{} m := new(sync.Mutex) total := 100000 cfgPath := ".config.json" s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m.Lock() results[r.URL.String()]++ m.Unlock() w.Write([]byte("{\"foo\":42}")) })) defer s.Close() { cfgContent := `{ "version":3, "endpoints":[{ "endpoint":"/{foo}", "backend":[ { "host": ["%s"], "url_pattern": "/first/{foo}", "group": "1" }, { "host": ["%s"], "url_pattern": "/second/{foo}", "group": "2" }, { "host": ["%s"], "url_pattern": "/third/{foo}", "group": "3" } ] }] }` if err := os.WriteFile(cfgPath, []byte(fmt.Sprintf(cfgContent, s.URL, s.URL, s.URL)), 0666); err != nil { t.Error(err) return } defer os.Remove(cfgPath) } cfg, err := config.NewParser().Parse(cfgPath) if err != nil { t.Error(err) return } cfg.Normalize() factory := NewDefaultFactory(httpProxy, logging.NoOp) p, err := factory.New(cfg.Endpoints[0]) if err != nil { t.Error(err) return } for i := 0; i < total; i++ { p(context.Background(), &Request{ Method: "GET", Params: map[string]string{"Foo": "42"}, Headers: map[string][]string{}, Path: "/", Query: url.Values{}, Body: io.NopCloser(strings.NewReader("")), URL: new(url.URL), }) } for k, v := range results { if v != total { t.Errorf("the url %s was consumed %d times", k, v) } } if len(results) != 3 { t.Errorf("unexpected number of consumed 
urls. have %d, want 3", len(results)) } } ================================================ FILE: proxy/static.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "encoding/json" "fmt" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) // NewStaticMiddleware creates proxy middleware for adding static values to the processed responses func NewStaticMiddleware(logger logging.Logger, endpointConfig *config.EndpointConfig) Middleware { cfg, ok := getStaticMiddlewareCfg(endpointConfig.ExtraConfig) if !ok { return emptyMiddlewareFallback(logger) } b, _ := json.Marshal(cfg.Data) logger.Debug( fmt.Sprintf( "[ENDPOINT: %s][Static] Adding a static response using '%s' strategy. Data: %s", endpointConfig.Endpoint, cfg.Strategy, string(b), ), ) return func(next ...Proxy) Proxy { if len(next) > 1 { logger.Fatal("too many proxies for this proxy middleware: NewStaticMiddleware only accepts 1 proxy, got %d", len(next)) return nil } return func(ctx context.Context, request *Request) (*Response, error) { result, err := next[0](ctx, request) if !cfg.Match(result, err) { return result, err } if result == nil { result = &Response{Data: map[string]interface{}{}} } else if result.Data == nil { result.Data = map[string]interface{}{} } for k, v := range cfg.Data { result.Data[k] = v } return result, err } } } const ( staticKey = "static" staticAlwaysStrategy = "always" staticIfSuccessStrategy = "success" staticIfErroredStrategy = "errored" staticIfCompleteStrategy = "complete" staticIfIncompleteStrategy = "incomplete" ) type staticConfig struct { Data map[string]interface{} Strategy string Match func(*Response, error) bool } func getStaticMiddlewareCfg(extra config.ExtraConfig) (staticConfig, bool) { v, ok := extra[Namespace] if !ok { return staticConfig{}, ok } e, ok := v.(map[string]interface{}) if !ok { return staticConfig{}, ok } v, ok = e[staticKey] if !ok { return staticConfig{}, ok 
} tmp, ok := v.(map[string]interface{}) if !ok { return staticConfig{}, ok } data, ok := tmp["data"].(map[string]interface{}) if !ok { return staticConfig{}, ok } name, ok := tmp["strategy"].(string) if !ok { name = staticAlwaysStrategy } cfg := staticConfig{ Data: data, Strategy: name, Match: staticAlwaysMatch, } switch name { case staticAlwaysStrategy: cfg.Match = staticAlwaysMatch case staticIfSuccessStrategy: cfg.Match = staticIfSuccessMatch case staticIfErroredStrategy: cfg.Match = staticIfErroredMatch case staticIfCompleteStrategy: cfg.Match = staticIfCompleteMatch case staticIfIncompleteStrategy: cfg.Match = staticIfIncompleteMatch } return cfg, true } func staticAlwaysMatch(_ *Response, _ error) bool { return true } func staticIfSuccessMatch(_ *Response, err error) bool { return err == nil } func staticIfErroredMatch(_ *Response, err error) bool { return err != nil } func staticIfCompleteMatch(r *Response, err error) bool { return err == nil && r != nil && r.IsComplete } func staticIfIncompleteMatch(r *Response, _ error) bool { return r == nil || !r.IsComplete } ================================================ FILE: proxy/static_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package proxy import ( "context" "errors" "reflect" "testing" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) func TestNewStaticMiddleware_ok(t *testing.T) { endpoint := config.EndpointConfig{ ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ staticKey: map[string]interface{}{ "data": map[string]interface{}{ "new-1": true, "new-2": map[string]interface{}{"k1": 42}, "new-3": "42", }, "strategy": "incomplete", }, }, }, } mw := NewStaticMiddleware(logging.NoOp, &endpoint) p := mw(dummyProxy(&Response{Data: map[string]interface{}{"supu": 42}, IsComplete: true})) out1, err := p(context.Background(), &Request{}) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s", 
err.Error()) } if out1 == nil { t.Error("The proxy returned a null result") return } if len(out1.Data) != 1 { t.Errorf("We weren't expecting an extra partial response but we got %v!", out1) } if !out1.IsComplete { t.Errorf("We were expecting a completed response but we got an incompleted one!") } p = mw(dummyProxy(&Response{Data: map[string]interface{}{"supu": 42}})) out2, err := p(context.Background(), &Request{}) if err != nil { t.Errorf("The middleware propagated an unexpected error: %s", err.Error()) } if out2 == nil { t.Error("The proxy returned a null result") return } if len(out2.Data) != 4 { t.Errorf("We weren't expecting a partial response but we got %v!", out2) } expectedError := errors.New("expect me") p = mw(func(_ context.Context, _ *Request) (*Response, error) { return nil, expectedError }) out3, err := p(context.Background(), &Request{}) if err != expectedError { t.Errorf("The middleware propagated an unexpected error: %s", err) } if out3 == nil { t.Error("The proxy returned a null result") return } if len(out3.Data) != 3 { t.Errorf("We weren't expecting a partial response but we got %v!", out3) } } type staticMatcherTestCase struct { name string response *Response err error expected bool } func TestNewStaticMiddleware(t *testing.T) { data := map[string]interface{}{ "new-1": true, "new-2": map[string]interface{}{"k1": 42}, "new-3": "42", } extra := config.ExtraConfig{ Namespace: map[string]interface{}{ staticKey: map[string]interface{}{ "data": data, "strategy": staticIfCompleteStrategy, }, }, } mw := NewStaticMiddleware(logging.NoOp, &config.EndpointConfig{ExtraConfig: extra}) p := mw(func(_ context.Context, r *Request) (*Response, error) { return &Response{IsComplete: true}, nil }) resp, err := p(context.Background(), nil) if err != nil { t.Error(err) return } if !reflect.DeepEqual(data, resp.Data) { t.Errorf("unexpected data: %+v", resp.Data) } } func Test_staticAlwaysMatch(t *testing.T) { extra := config.ExtraConfig{ Namespace: 
map[string]interface{}{ staticKey: map[string]interface{}{ "data": map[string]interface{}{ "new-1": true, "new-2": map[string]interface{}{"k1": 42}, "new-3": "42", }, }, }, } cfg, _ := getStaticMiddlewareCfg(extra) for _, testCase := range []staticMatcherTestCase{ { name: "nil & nil", expected: true, }, { name: "nil & error", err: errors.New("ignore me"), expected: true, }, { name: "complete & nil", response: &Response{Data: map[string]interface{}{}, IsComplete: true}, expected: true, }, { name: "complete & error", response: &Response{Data: map[string]interface{}{}, IsComplete: true}, err: errors.New("ignore me"), expected: true, }, { name: "incomplete", response: &Response{}, expected: true, }, } { testStaticMatcher(t, cfg.Match, testCase) } } func Test_staticIfSuccessMatch(t *testing.T) { extra := config.ExtraConfig{ Namespace: map[string]interface{}{ staticKey: map[string]interface{}{ "data": map[string]interface{}{ "new-1": true, "new-2": map[string]interface{}{"k1": 42}, "new-3": "42", }, "strategy": staticIfSuccessStrategy, }, }, } cfg, _ := getStaticMiddlewareCfg(extra) for _, testCase := range []staticMatcherTestCase{ { name: "nil & nil", expected: true, }, { name: "nil & error", err: errors.New("ignore me"), expected: false, }, { name: "success & nil", response: &Response{}, expected: true, }, { name: "success & error", response: &Response{}, err: errors.New("ignore me"), }, } { testStaticMatcher(t, cfg.Match, testCase) } } func Test_staticIfErroredMatch(t *testing.T) { extra := config.ExtraConfig{ Namespace: map[string]interface{}{ staticKey: map[string]interface{}{ "data": map[string]interface{}{ "new-1": true, "new-2": map[string]interface{}{"k1": 42}, "new-3": "42", }, "strategy": staticIfErroredStrategy, }, }, } cfg, _ := getStaticMiddlewareCfg(extra) for _, testCase := range []staticMatcherTestCase{ { name: "nil & nil", }, { name: "nil & error", err: errors.New("ignore me"), expected: true, }, { name: "success & nil", response: &Response{}, }, { 
name: "success & error", response: &Response{}, err: errors.New("ignore me"), expected: true, }, } { testStaticMatcher(t, cfg.Match, testCase) } } func Test_staticIfCompleteMatch(t *testing.T) { extra := config.ExtraConfig{ Namespace: map[string]interface{}{ staticKey: map[string]interface{}{ "data": map[string]interface{}{ "new-1": true, "new-2": map[string]interface{}{"k1": 42}, "new-3": "42", }, "strategy": staticIfCompleteStrategy, }, }, } cfg, _ := getStaticMiddlewareCfg(extra) for _, testCase := range []staticMatcherTestCase{ { name: "nil & nil", }, { name: "nil & error", err: errors.New("ignore me"), }, { name: "success & nil", response: &Response{}, }, { name: "success & error", response: &Response{}, err: errors.New("ignore me"), }, { name: "complete", response: &Response{IsComplete: true}, expected: true, }, } { testStaticMatcher(t, cfg.Match, testCase) } } func Test_staticIfIncompleteMatch(t *testing.T) { extra := config.ExtraConfig{ Namespace: map[string]interface{}{ staticKey: map[string]interface{}{ "data": map[string]interface{}{ "new-1": true, "new-2": map[string]interface{}{"k1": 42}, "new-3": "42", }, "strategy": staticIfIncompleteStrategy, }, }, } cfg, _ := getStaticMiddlewareCfg(extra) for _, testCase := range []staticMatcherTestCase{ { name: "nil & nil", expected: true, }, { name: "nil & error", err: errors.New("ignore me"), expected: true, }, { name: "success & nil", response: &Response{}, expected: true, }, { name: "success & error", response: &Response{}, err: errors.New("ignore me"), expected: true, }, { name: "complete", response: &Response{IsComplete: true}, }, } { testStaticMatcher(t, cfg.Match, testCase) } } func testStaticMatcher(t *testing.T, marcher func(*Response, error) bool, testCase staticMatcherTestCase) { if marcher(testCase.response, testCase.err) != testCase.expected { t.Errorf( "[%s] unexepecting match result (%v) with: %v, %v", testCase.name, testCase.expected, testCase.response, testCase.err, ) } } func 
Test_getStaticMiddlewareCfg_ko(t *testing.T) { for i, cfg := range []config.ExtraConfig{ {"a": 42}, {Namespace: true}, {Namespace: map[string]interface{}{}}, {Namespace: map[string]interface{}{staticKey: 42}}, {Namespace: map[string]interface{}{staticKey: map[string]interface{}{}}}, } { if _, ok := getStaticMiddlewareCfg(cfg); ok { t.Errorf("expecting error on test #%d", i) } } } func Test_getStaticMiddlewareCfg_strategy(t *testing.T) { for _, strategy := range []string{ staticAlwaysStrategy, staticIfSuccessStrategy, staticIfErroredStrategy, staticIfCompleteStrategy, staticIfIncompleteStrategy, } { cfg := config.ExtraConfig{ Namespace: map[string]interface{}{ staticKey: map[string]interface{}{ "data": map[string]interface{}{}, "strategy": strategy, }, }, } staticCfg, ok := getStaticMiddlewareCfg(cfg) if !ok { t.Errorf("unexpecting error on test %s", strategy) } if strategy != staticCfg.Strategy { t.Errorf("wrong parsing on test %s", strategy) } } cfg := config.ExtraConfig{ Namespace: map[string]interface{}{ staticKey: map[string]interface{}{ "data": map[string]interface{}{}, }, }, } staticCfg, ok := getStaticMiddlewareCfg(cfg) if !ok { t.Error("unexpecting error parsing default strategy") } if staticAlwaysStrategy != staticCfg.Strategy { t.Error("wrong parsing default strategy") } } ================================================ FILE: register/register.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package register offers tools for creating and managing registers. 
*/ package register import "sync" // New returns an initialized Namespaced register func New() *Namespaced { return &Namespaced{data: NewUntyped()} } // Namespaced is a register able to keep track of elements stored // under namespaces and keys type Namespaced struct { data *Untyped } // Get returns the Untyped register stored under the namespace func (n *Namespaced) Get(namespace string) (*Untyped, bool) { v, ok := n.data.Get(namespace) if !ok { return nil, ok } register, ok := v.(*Untyped) return register, ok } // Register stores v at the key name of the Untyped register named namespace func (n *Namespaced) Register(namespace, name string, v interface{}) { if register, ok := n.Get(namespace); ok { register.Register(name, v) return } register := NewUntyped() register.Register(name, v) n.data.Register(namespace, register) } // AddNamespace adds a new, empty Untyped register under the give namespace (if // it did not exist) func (n *Namespaced) AddNamespace(namespace string) { if _, ok := n.Get(namespace); ok { return } n.data.Register(namespace, NewUntyped()) } // NewUntyped returns an empty Untyped register func NewUntyped() *Untyped { return &Untyped{ data: map[string]interface{}{}, mutex: &sync.RWMutex{}, } } // Untyped is a simple register, safe for concurrent access type Untyped struct { data map[string]interface{} mutex *sync.RWMutex } // Register stores v under the key name func (u *Untyped) Register(name string, v interface{}) { u.mutex.Lock() u.data[name] = v u.mutex.Unlock() } // Get returns the value stored at the key name func (u *Untyped) Get(name string) (interface{}, bool) { u.mutex.RLock() v, ok := u.data[name] u.mutex.RUnlock() return v, ok } // Clone returns a snapshot of the register func (u *Untyped) Clone() map[string]interface{} { u.mutex.RLock() res := make(map[string]interface{}, len(u.data)) for k, v := range u.data { res[k] = v } u.mutex.RUnlock() return res } ================================================ FILE: register/register_test.go 
================================================ // SPDX-License-Identifier: Apache-2.0 package register import "testing" func TestNamespaced(t *testing.T) { r := New() r.Register("namespace1", "name1", 42) r.AddNamespace("namespace1") r.AddNamespace("namespace2") r.Register("namespace2", "name2", true) nr, ok := r.Get("namespace1") if !ok { t.Error("namespace1 not found") return } if _, ok := nr.Get("name2"); ok { t.Error("name2 found into namespace1") return } v1, ok := nr.Get("name1") if !ok { t.Error("name1 not found") return } if i, ok := v1.(int); !ok || i != 42 { t.Error("unexpected value:", v1) } nr, ok = r.Get("namespace2") if !ok { t.Error("namespace2 not found") return } if _, ok := nr.Get("name1"); ok { t.Error("name1 found into namespace2") return } v2, ok := nr.Get("name2") if !ok { t.Error("name2 not found") return } if b, ok := v2.(bool); !ok || !b { t.Error("unexpected value:", v2) } } ================================================ FILE: router/chi/endpoint.go ================================================ // SPDX-License-Identifier: Apache-2.0 package chi import ( "net/http" "github.com/go-chi/chi/v5" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/router/mux" "golang.org/x/text/cases" "golang.org/x/text/language" ) // HandlerFactory creates a handler function that adapts the chi router with the injected proxy type HandlerFactory func(*config.EndpointConfig, proxy.Proxy) http.HandlerFunc // NewEndpointHandler implements the HandleFactory interface using the default ToHTTPError function func NewEndpointHandler(cfg *config.EndpointConfig, prxy proxy.Proxy) http.HandlerFunc { hf := mux.CustomEndpointHandler( mux.NewRequestBuilder(extractParamsFromEndpoint), ) return hf(cfg, prxy) } func extractParamsFromEndpoint(r *http.Request) map[string]string { ctx := r.Context() rctx := chi.RouteContext(ctx) params := map[string]string{} if len(rctx.URLParams.Keys) > 0 { title := 
cases.Title(language.Und) for _, param := range rctx.URLParams.Keys { params[title.String(param[:1])+param[1:]] = chi.URLParam(r, param) } } return params } ================================================ FILE: router/chi/endpoint_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package chi import ( "bytes" "context" "fmt" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/go-chi/chi/v5" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/proxy" ) func BenchmarkEndpointHandler_ko(b *testing.B) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, fmt.Errorf("This is %s", "a dummy error") } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } router := chi.NewRouter() router.Handle("/_chi_endpoint/", NewEndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_chi_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() for i := 0; i < b.N; i++ { w := httptest.NewRecorder() router.ServeHTTP(w, req) } } func BenchmarkEndpointHandler_ok(b *testing.B) { pResp := proxy.Response{ Data: map[string]interface{}{}, Io: io.NopCloser(&bytes.Buffer{}), IsComplete: true, Metadata: proxy.Metadata{}, } p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &pResp, nil } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } router := chi.NewRouter() router.Handle("/_chi_endpoint/", NewEndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_chi_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() for i := 0; i < b.N; i++ { w := httptest.NewRecorder() router.ServeHTTP(w, req) } } func BenchmarkEndpointHandler_ko_Parallel(b *testing.B) { p := func(_ context.Context, _ 
*proxy.Request) (*proxy.Response, error) { return nil, fmt.Errorf("This is %s", "a dummy error") } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } router := chi.NewRouter() router.Handle("/_chi_endpoint/", NewEndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_chi_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { w := httptest.NewRecorder() router.ServeHTTP(w, req) } }) } func BenchmarkEndpointHandler_ok_Parallel(b *testing.B) { pResp := proxy.Response{ Data: map[string]interface{}{}, Io: io.NopCloser(&bytes.Buffer{}), IsComplete: true, Metadata: proxy.Metadata{}, } p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &pResp, nil } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } router := chi.NewRouter() router.Handle("/_chi_endpoint/", NewEndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_chi_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { w := httptest.NewRecorder() router.ServeHTTP(w, req) } }) } ================================================ FILE: router/chi/endpoint_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package chi import ( "bytes" "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/go-chi/chi/v5" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/transport/http/server" ) func TestEndpointHandler_ok(t *testing.T) { p := func(ctx context.Context, req *proxy.Request) (*proxy.Response, error) { data, _ := json.Marshal(req.Query) if string(data) != 
`{"b":["1"],"c[]":["x","y"],"d":["1","2"]}` { t.Errorf("unexpected querystring: %s", data) } return &proxy.Response{ IsComplete: true, Data: map[string]interface{}{"supu": "tupu"}, Metadata: proxy.Metadata{ Headers: map[string][]string{"a": {"a1", "a2"}}, }, }, nil } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "{\"supu\":\"tupu\"}", expectedCache: "public, max-age=21600", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: true, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_okAllParams(t *testing.T) { p := func(_ context.Context, req *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: true, Data: map[string]interface{}{"query": req.Query, "headers": req.Headers, "params": req.Params}, Metadata: proxy.Metadata{ Headers: map[string][]string{"X-YZ": {"something"}}, StatusCode: 200, }, }, nil } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: `{"headers":{"Content-Type":["application/json"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-For":[""],"X-Forwarded-Host":["127.0.0.1:8080"]},"params":{"Param":"a"},"query":{"a":["42"],"b":["1"],"c[]":["x","y"],"d":["1","2"]}}`, expectedCache: "public, max-age=21600", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: true, queryString: []string{"*"}, headers: []string{"*"}, expectedHeaders: map[string][]string{"X-YZ": {"something"}}, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_incomplete(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: false, Data: map[string]interface{}{"foo": "bar"}, }, nil } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "{\"foo\":\"bar\"}", expectedCache: "", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } 
func TestEndpointHandler_errored(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, errors.New("this is a dummy error") } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "this is a dummy error\n", expectedCache: "", expectedContent: "text/plain; charset=utf-8", expectedStatusCode: http.StatusInternalServerError, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_errored_responseError(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, dummyResponseError{err: "this is a dummy error", status: http.StatusTeapot} } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "this is a dummy error\n", expectedCache: "", expectedContent: "text/plain; charset=utf-8", expectedStatusCode: http.StatusTeapot, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } type dummyResponseError struct { err string status int } func (d dummyResponseError) Error() string { return d.err } func (d dummyResponseError) StatusCode() int { return d.status } func TestEndpointHandler_incompleteAndErrored(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: false, Data: map[string]interface{}{"foo": "bar"}, }, errors.New("This is a dummy error") } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "{\"foo\":\"bar\"}", expectedCache: "", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_cancelEmpty(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { time.Sleep(100 * time.Millisecond) return nil, nil } endpointHandlerTestCase{ timeout: 0, proxy: p, method: "GET", expectedBody: "internal server error\n", expectedCache: "", expectedContent: "text/plain; 
charset=utf-8", expectedStatusCode: http.StatusInternalServerError, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_cancel(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { time.Sleep(100 * time.Millisecond) return &proxy.Response{ IsComplete: false, Data: map[string]interface{}{"foo": "bar"}, }, nil } endpointHandlerTestCase{ timeout: 0, proxy: p, method: "GET", expectedBody: "{\"foo\":\"bar\"}", expectedCache: "", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_noop(t *testing.T) { endpointHandlerTestCase{ timeout: time.Minute, proxy: proxy.NoopProxy, method: "GET", expectedBody: "{}", expectedCache: "", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } type endpointHandlerTestCase struct { timeout time.Duration proxy proxy.Proxy method string expectedBody string expectedCache string expectedContent string expectedHeaders map[string][]string expectedStatusCode int completed bool queryString []string headers []string } func (tc endpointHandlerTestCase) test(t *testing.T) { endpoint := &config.EndpointConfig{ Method: "GET", Timeout: tc.timeout, CacheTTL: 6 * time.Hour, QueryString: []string{"b", "c[]", "d"}, } if len(tc.queryString) > 0 { endpoint.QueryString = tc.queryString } if len(tc.headers) > 0 { endpoint.HeadersToPass = tc.headers } s := startChiServer(NewEndpointHandler(endpoint, tc.proxy)) req, _ := http.NewRequest(tc.method, "http://127.0.0.1:8080/_chi_endpoint/a?a=42&b=1&c[]=x&c[]=y&d=1&d=2", io.NopCloser(&bytes.Buffer{})) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() s.ServeHTTP(w, req) body, ioerr := io.ReadAll(w.Result().Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } w.Result().Body.Close() content := string(body) 
resp := w.Result() if resp.Header.Get("Cache-Control") != tc.expectedCache { t.Error("Cache-Control error:", resp.Header.Get("Cache-Control")) } if tc.completed && resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderCompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if !tc.completed && resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderIncompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != tc.expectedContent { t.Error("Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "Version undefined" { t.Error("X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != tc.expectedStatusCode { t.Error("Unexpected status code:", resp.StatusCode) } if content != tc.expectedBody { t.Error("Unexpected body:", content, "expected:", tc.expectedBody) } for k, v := range tc.expectedHeaders { if header := resp.Header.Get(k); v[0] != header { t.Error("Unexpected value for header:", k, header, "expected:", v[0]) } } } func startChiServer(handlerFunc http.HandlerFunc) *chi.Mux { r := chi.NewRouter() r.Handle("/_chi_endpoint/{param}", handlerFunc) return r } ================================================ FILE: router/chi/router.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package chi provides some basic implementations for building routers based on go-chi/chi */ package chi import ( "context" "net/http" "strings" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/router" "github.com/luraproject/lura/v2/router/mux" "github.com/luraproject/lura/v2/transport/http/server" ) // ChiDefaultDebugPattern is 
the default pattern used to define the debug endpoint const ChiDefaultDebugPattern = "/__debug/" const logPrefix = "[SERVICE: Chi]" // RunServerFunc is a func that will run the http Server with the given params. type RunServerFunc func(context.Context, config.ServiceConfig, http.Handler) error // Config is the struct that collects the parts the router should be builded from type Config struct { Engine chi.Router Middlewares chi.Middlewares HandlerFactory HandlerFactory ProxyFactory proxy.Factory Logger logging.Logger DebugPattern string RunServer RunServerFunc } // DefaultFactory returns a chi router factory with the injected proxy factory and logger. // It also uses a default chi router and the default HandlerFactory func DefaultFactory(proxyFactory proxy.Factory, logger logging.Logger) router.Factory { return NewFactory( Config{ Engine: chi.NewRouter(), Middlewares: chi.Middlewares{middleware.Logger}, HandlerFactory: NewEndpointHandler, ProxyFactory: proxyFactory, Logger: logger, DebugPattern: ChiDefaultDebugPattern, RunServer: server.RunServer, }, ) } // NewFactory returns a chi router factory with the injected configuration func NewFactory(cfg Config) router.Factory { if cfg.DebugPattern == "" { cfg.DebugPattern = ChiDefaultDebugPattern } return factory{cfg} } type factory struct { cfg Config } // New implements the factory interface func (rf factory) New() router.Router { return rf.NewWithContext(context.Background()) } // NewWithContext implements the factory interface func (rf factory) NewWithContext(ctx context.Context) router.Router { return chiRouter{rf.cfg, ctx, rf.cfg.RunServer} } type chiRouter struct { cfg Config ctx context.Context RunServer RunServerFunc } // Run implements the router interface func (r chiRouter) Run(cfg config.ServiceConfig) { r.cfg.Engine.Use(r.cfg.Middlewares...) 
if cfg.Debug { r.registerDebugEndpoints() } r.cfg.Engine.Get("/__health", mux.HealthHandler) server.InitHTTPDefaultTransport(cfg) r.registerKrakendEndpoints(cfg.Endpoints) r.cfg.Engine.NotFound(func(w http.ResponseWriter, r *http.Request) { w.Header().Set(server.CompleteResponseHeaderName, server.HeaderIncompleteResponseValue) http.NotFound(w, r) }) if err := r.RunServer(r.ctx, cfg, r.cfg.Engine); err != nil { r.cfg.Logger.Error(logPrefix, err.Error()) } r.cfg.Logger.Info(logPrefix, "Router execution ended") } func (r chiRouter) registerDebugEndpoints() { debugHandler := mux.DebugHandler(r.cfg.Logger) r.cfg.Engine.Get(r.cfg.DebugPattern, debugHandler) r.cfg.Engine.Post(r.cfg.DebugPattern, debugHandler) r.cfg.Engine.Put(r.cfg.DebugPattern, debugHandler) r.cfg.Engine.Patch(r.cfg.DebugPattern, debugHandler) r.cfg.Engine.Delete(r.cfg.DebugPattern, debugHandler) } func (r chiRouter) registerKrakendEndpoints(endpoints []*config.EndpointConfig) { for _, c := range endpoints { proxyStack, err := r.cfg.ProxyFactory.New(c) if err != nil { r.cfg.Logger.Error(logPrefix, "calling the ProxyFactory", err.Error()) continue } r.registerKrakendEndpoint(c.Method, c, r.cfg.HandlerFactory(c, proxyStack), len(c.Backend)) } } func (r chiRouter) registerKrakendEndpoint(method string, endpoint *config.EndpointConfig, handler http.HandlerFunc, totBackends int) { method = strings.ToTitle(method) path := endpoint.Endpoint if method != http.MethodGet && totBackends > 1 { if !router.IsValidSequentialEndpoint(endpoint) { r.cfg.Logger.Error(logPrefix, method, "endpoints with sequential proxy enabled only allow a non-GET in the last backend! 
Ignoring", path) return } } switch method { case http.MethodGet: r.cfg.Engine.Get(path, handler) case http.MethodPost: r.cfg.Engine.Post(path, handler) case http.MethodPut: r.cfg.Engine.Put(path, handler) case http.MethodPatch: r.cfg.Engine.Patch(path, handler) case http.MethodDelete: r.cfg.Engine.Delete(path, handler) default: r.cfg.Logger.Error(logPrefix, "Unsupported method", method) return } r.cfg.Logger.Debug(logPrefix, "registering the endpoint", method, path) } ================================================ FILE: router/chi/router_test.go ================================================ //go:build !race // +build !race // SPDX-License-Identifier: Apache-2.0 package chi import ( "bytes" "context" "errors" "fmt" "io" "net/http" "regexp" "strings" "testing" "time" "github.com/go-chi/chi/v5" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/transport/http/server" ) func TestDefaultFactory_ok(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(noopProxyFactory(map[string]interface{}{"supu": "tupu"}), logger).NewWithContext(ctx) expectedBody := "{\"supu\":\"tupu\"}" serviceCfg := config.ServiceConfig{ Port: 8062, Endpoints: []*config.EndpointConfig{ { Endpoint: "/get", Method: "GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/get", Method: "POST", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/post", Method: "Post", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/put", Method: "put", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/patch", Method: "PATCH", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: 
"/delete", Method: "DELETE", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, endpoint := range serviceCfg.Endpoints { req, _ := http.NewRequest(strings.ToTitle(endpoint.Method), fmt.Sprintf("http://127.0.0.1:8062%s", endpoint.Endpoint), http.NoBody) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error(endpoint.Endpoint, "Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderCompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != "application/json" { t.Error(endpoint.Endpoint, "Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "Version undefined" { t.Error(endpoint.Endpoint, "X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusOK { t.Error(endpoint.Endpoint, "Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error(endpoint.Endpoint, "Unexpected body:", content, "expected:", expectedBody) } } } func TestDefaultFactory_ko(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := NewFactory(Config{ Engine: chi.NewRouter(), Middlewares: chi.Middlewares{}, HandlerFactory: NewEndpointHandler, ProxyFactory: 
noopProxyFactory(map[string]interface{}{"supu": "tupu"}), Logger: logger, RunServer: server.RunServer, }).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Port: 8063, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: "GETTT", Backend: []*config.Backend{ {}, }, }, { Endpoint: "/empty", Method: "GETTT", Backend: []*config.Backend{}, }, { Endpoint: "/also-ignored", Method: "PUT", Backend: []*config.Backend{ {}, {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, subject := range [][]string{ {"GET", "ignored"}, {"GET", "empty"}, {"PUT", "also-ignored"}, } { req, _ := http.NewRequest(subject[0], fmt.Sprintf("http://127.0.0.1:8063/%s", subject[1]), http.NoBody) req.Header.Set("Content-Type", "application/json") checkResponseIs404(t, req) } } func TestDefaultFactory_proxyFactoryCrash(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(erroredProxyFactory{fmt.Errorf("%s", "crash!!!")}, logger).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Port: 8064, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: "GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, subject := range [][]string{{"GET", "ignored"}, {"PUT", "also-ignored"}} { req, _ := http.NewRequest(subject[0], fmt.Sprintf("http://127.0.0.1:8064/%s", subject[1]), http.NoBody) req.Header.Set("Content-Type", "application/json") checkResponseIs404(t, req) } } func TestRunServer_ko(t *testing.T) { buff := new(bytes.Buffer) logger, err := logging.NewLogger("DEBUG", buff, "") if err != nil { t.Error("building the logger:", err.Error()) return } errorMsg := 
"runServer error" runServerFunc := func(_ context.Context, _ config.ServiceConfig, _ http.Handler) error { return errors.New(errorMsg) } pf := noopProxyFactory(map[string]interface{}{"supu": "tupu"}) r := NewFactory( Config{ Engine: chi.NewRouter(), Middlewares: chi.Middlewares{}, HandlerFactory: NewEndpointHandler, ProxyFactory: pf, Logger: logger, DebugPattern: ChiDefaultDebugPattern, RunServer: runServerFunc, }, ).New() serviceCfg := config.ServiceConfig{} r.Run(serviceCfg) re := regexp.MustCompile(errorMsg) if !re.MatchString(buff.String()) { t.Errorf("the logger doesn't contain the expected msg: %s", buff.Bytes()) } } func checkResponseIs404(t *testing.T, req *http.Request) { expectedBody := "404 page not found\n" resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error("Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderIncompleteResponseValue { t.Error(req.URL.String(), server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { t.Error("Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "" { t.Error("X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusNotFound { t.Error("Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error("Unexpected body:", content, "expected:", expectedBody) } } type noopProxyFactory map[string]interface{} func (n noopProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: 
true, Data: n, }, nil }, nil } type erroredProxyFactory struct { Error error } func (e erroredProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return proxy.NoopProxy, e.Error } ================================================ FILE: router/gin/debug.go ================================================ // SPDX-License-Identifier: Apache-2.0 package gin import ( "io" "github.com/gin-gonic/gin" "github.com/luraproject/lura/v2/logging" ) // DebugHandler creates a dummy handler function, useful for quick integration tests func DebugHandler(logger logging.Logger) gin.HandlerFunc { logPrefixSecondary := "[ENDPOINT: /__debug/*]" return func(c *gin.Context) { logger.Debug(logPrefixSecondary, "Method:", c.Request.Method) logger.Debug(logPrefixSecondary, "URL:", c.Request.RequestURI) logger.Debug(logPrefixSecondary, "Query:", c.Request.URL.Query()) logger.Debug(logPrefixSecondary, "Params:", c.Params) logger.Debug(logPrefixSecondary, "Headers:", c.Request.Header) body, _ := io.ReadAll(c.Request.Body) c.Request.Body.Close() logger.Debug(logPrefixSecondary, "Body:", string(body)) c.JSON(200, gin.H{ "message": "pong", }) } } ================================================ FILE: router/gin/debug_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package gin import ( "bytes" "io" "net/http" "net/http/httptest" "testing" "github.com/gin-gonic/gin" "github.com/luraproject/lura/v2/logging" ) func TestDebugHandler(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } router := gin.New() router.GET("/_gin_endpoint/:param", DebugHandler(logger)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8088/_gin_endpoint/a?b=1", io.NopCloser(&bytes.Buffer{})) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() router.ServeHTTP(w, req) body, ioerr := io.ReadAll(w.Result().Body) 
if ioerr != nil { t.Error("reading a response:", ioerr.Error()) return } w.Result().Body.Close() expectedBody := "{\"message\":\"pong\"}" content := string(body) if w.Result().Header.Get("Cache-Control") != "" { t.Error("Cache-Control error:", w.Result().Header.Get("Cache-Control")) } if w.Result().Header.Get("Content-Type") != "application/json; charset=utf-8" { t.Error("Content-Type error:", w.Result().Header.Get("Content-Type")) } if w.Result().Header.Get("X-Krakend") != "" { t.Error("X-Krakend error:", w.Result().Header.Get("X-Krakend")) } if w.Result().StatusCode != http.StatusOK { t.Error("Unexpected status code:", w.Result().StatusCode) } if content != expectedBody { t.Error("Unexpected body:", content, "expected:", expectedBody) } } ================================================ FILE: router/gin/echo.go ================================================ // SPDX-License-Identifier: Apache-2.0 package gin import ( "io" "net/http" "github.com/gin-gonic/gin" ) type echoResponse struct { Uri string `json:"req_uri"` UriDetails map[string]string `json:"req_uri_details"` Method string `json:"req_method"` Querystring map[string][]string `json:"req_querystring"` Body string `json:"req_body"` Headers map[string][]string `json:"req_headers"` } // EchoHandler creates a dummy handler function, useful for quick integration tests func EchoHandler() gin.HandlerFunc { return func(c *gin.Context) { var body string if c.Request.Body != nil { tmp, _ := io.ReadAll(c.Request.Body) c.Request.Body.Close() body = string(tmp) } resp := echoResponse{ Uri: c.Request.RequestURI, UriDetails: map[string]string{ "user": c.Request.URL.User.String(), "host": c.Request.Host, "path": c.Request.URL.Path, "query": c.Request.URL.Query().Encode(), "fragment": c.Request.URL.Fragment, }, Method: c.Request.Method, Querystring: c.Request.URL.Query(), Body: body, Headers: c.Request.Header, } c.JSON(http.StatusOK, resp) } } ================================================ FILE: router/gin/echo_test.go 
================================================ // SPDX-License-Identifier: Apache-2.0 package gin import ( "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/gin-gonic/gin" ) func TestEchoHandler(t *testing.T) { reqBody := `{"message":"some body to send"}` expectedRespBody := `{"req_uri":"http://127.0.0.1:8088/_gin_endpoint/a?b=1","req_uri_details":{"fragment":"","host":"127.0.0.1:8088","path":"/_gin_endpoint/a","query":"b=1","user":""},"req_method":"GET","req_querystring":{"b":["1"]},"req_body":"{\"message\":\"some body to send\"}","req_headers":{"Content-Type":["application/json"]}}` expectedRespNoBody := `{"req_uri":"http://127.0.0.1:8088/_gin_endpoint/a?b=1","req_uri_details":{"fragment":"","host":"127.0.0.1:8088","path":"/_gin_endpoint/a","query":"b=1","user":""},"req_method":"GET","req_querystring":{"b":["1"]},"req_body":"","req_headers":{"Content-Type":["application/json"]}}` expectedRespString := `{"req_uri":"http://127.0.0.1:8088/_gin_endpoint/a?b=1","req_uri_details":{"fragment":"","host":"127.0.0.1:8088","path":"/_gin_endpoint/a","query":"b=1","user":""},"req_method":"GET","req_querystring":{"b":["1"]},"req_body":"Hello lura","req_headers":{"Content-Type":["application/json"]}}` gin.SetMode(gin.TestMode) router := gin.New() router.GET("/_gin_endpoint/:param", EchoHandler()) for _, tc := range []struct { name string body io.Reader resp string }{ { name: "json body", body: strings.NewReader(reqBody), resp: expectedRespBody, }, { name: "no body", body: http.NoBody, resp: expectedRespNoBody, }, { name: "string body", body: strings.NewReader("Hello lura"), resp: expectedRespString, }, } { t.Run(tc.name, func(t *testing.T) { echoRunTestRequest(t, router, tc.body, tc.resp) }) } } func echoRunTestRequest(t *testing.T, e *gin.Engine, body io.Reader, expected string) { req := httptest.NewRequest("GET", "http://127.0.0.1:8088/_gin_endpoint/a?b=1", body) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() e.ServeHTTP(w, 
req) respBody, ioerr := io.ReadAll(w.Result().Body) if ioerr != nil { t.Error("reading a response:", ioerr.Error()) return } w.Result().Body.Close() content := string(respBody) if w.Result().Header.Get("Cache-Control") != "" { t.Error("Cache-Control error:", w.Result().Header.Get("Cache-Control")) } if w.Result().Header.Get("Content-Type") != "application/json; charset=utf-8" { t.Error("Content-Type error:", w.Result().Header.Get("Content-Type")) } if w.Result().Header.Get("X-Krakend") != "" { t.Error("X-Krakend error:", w.Result().Header.Get("X-Krakend")) } if w.Result().StatusCode != http.StatusOK { t.Error("Unexpected status code:", w.Result().StatusCode) } if content != expected { t.Error("Unexpected body:", content, "expected:", expected) } } ================================================ FILE: router/gin/endpoint.go ================================================ // SPDX-License-Identifier: Apache-2.0 package gin import ( "context" "fmt" "net/textproto" "github.com/gin-gonic/gin" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/core" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/transport/http/server" ) const requestParamsAsterisk string = "*" // HandlerFactory creates a handler function that adapts the gin router with the injected proxy type HandlerFactory func(*config.EndpointConfig, proxy.Proxy) gin.HandlerFunc // ErrorResponseWriter writes the string representation of an error into the response body // and sets a Content-Type header for errors that implement the encodedResponseError interface. 
var ErrorResponseWriter = func(c *gin.Context, err error) { if te, ok := err.(encodedResponseError); ok && te.Encoding() != "" { c.Header("Content-Type", te.Encoding()) } c.Writer.WriteString(err.Error()) } // EndpointHandler implements the HandlerFactory interface using the default ToHTTPError function var EndpointHandler = CustomErrorEndpointHandler(logging.NoOp, server.DefaultToHTTPError) // CustomErrorEndpointHandler returns a HandlerFactory using the injected ToHTTPError function and logger func CustomErrorEndpointHandler(logger logging.Logger, errF server.ToHTTPError) HandlerFactory { // skipcq: GO-R1005 return func(configuration *config.EndpointConfig, prxy proxy.Proxy) gin.HandlerFunc { cacheControlHeaderValue := fmt.Sprintf("public, max-age=%d", int(configuration.CacheTTL.Seconds())) isCacheEnabled := configuration.CacheTTL.Seconds() != 0 requestGenerator := NewRequest(configuration.HeadersToPass) render := getRender(configuration) logPrefix := "[ENDPOINT: " + configuration.Endpoint + "]" return func(c *gin.Context) { requestCtx, cancel := context.WithTimeout(c, configuration.Timeout) c.Header(core.KrakendHeaderName, core.KrakendHeaderValue) response, err := prxy(requestCtx, requestGenerator(c, configuration.QueryString)) select { case <-requestCtx.Done(): if err == nil { err = server.ErrInternalError } default: } complete := server.HeaderIncompleteResponseValue if response != nil && len(response.Data) > 0 { if response.IsComplete { complete = server.HeaderCompleteResponseValue if isCacheEnabled { c.Header("Cache-Control", cacheControlHeaderValue) } } for k, vs := range response.Metadata.Headers { for _, v := range vs { c.Writer.Header().Add(k, v) } } } c.Header(server.CompleteResponseHeaderName, complete) for _, err := range c.Errors { logger.Error(logPrefix, err.Error()) } if err != nil { if t, ok := err.(multiError); ok { for i, errN := range t.Errors() { c.Error(errN) logger.Error(fmt.Sprintf("%s Error #%d: %s", logPrefix, i, errN.Error())) } } else { 
c.Error(err) logger.Error(logPrefix, err.Error()) } if response == nil { if t, ok := err.(headerResponseError); ok { for k, vs := range t.Headers() { for _, v := range vs { c.Writer.Header().Add(k, v) } } } if t, ok := err.(responseError); ok { c.Status(t.StatusCode()) } else { c.Status(errF(err)) } if returnErrorMsg { ErrorResponseWriter(c, err) } cancel() return } } render(c, response) cancel() } } } // NewRequest gets a request from the current gin context and the received query string func NewRequest(headersToSend []string) func(*gin.Context, []string) *proxy.Request { if len(headersToSend) == 0 { headersToSend = server.HeadersToSend } return func(c *gin.Context, queryString []string) *proxy.Request { params := make(map[string]string, len(c.Params)) for _, param := range c.Params { params[textproto.CanonicalMIMEHeaderKey(param.Key[:1])+param.Key[1:]] = param.Value } headers := make(map[string][]string, 3+len(headersToSend)) for _, k := range headersToSend { if k == requestParamsAsterisk { headers = c.Request.Header break } if h, ok := c.Request.Header[textproto.CanonicalMIMEHeaderKey(k)]; ok { headers[k] = h } } headers["X-Forwarded-For"] = []string{c.ClientIP()} headers["X-Forwarded-Host"] = []string{c.Request.Host} // if User-Agent is not forwarded using headersToSend, we set // the KrakenD router User Agent value if _, ok := headers["User-Agent"]; !ok { headers["User-Agent"] = server.UserAgentHeaderValue } else { headers["X-Forwarded-Via"] = server.UserAgentHeaderValue } query := make(map[string][]string, len(queryString)) queryValues := c.Request.URL.Query() for i := range queryString { if queryString[i] == requestParamsAsterisk { query = c.Request.URL.Query() break } if v, ok := queryValues[queryString[i]]; ok && len(v) > 0 { query[queryString[i]] = v } } return &proxy.Request{ Path: c.Request.URL.Path, Method: c.Request.Method, Query: query, Body: c.Request.Body, Params: params, Headers: headers, } } } type encodedResponseError interface { responseError 
Encoding() string } type responseError interface { error StatusCode() int } type headerResponseError interface { responseError Headers() map[string][]string } type multiError interface { error Errors() []error } ================================================ FILE: router/gin/endpoint_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package gin import ( "bytes" "context" "fmt" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/gin-gonic/gin" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/proxy" ) func BenchmarkEndpointHandler_ko(b *testing.B) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, fmt.Errorf("This is %s", "a dummy error") } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } gin.SetMode(gin.TestMode) router := gin.New() router.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() for i := 0; i < b.N; i++ { w := httptest.NewRecorder() router.ServeHTTP(w, req) } } func BenchmarkEndpointHandler_ok(b *testing.B) { pResp := proxy.Response{ Data: map[string]interface{}{}, Io: io.NopCloser(&bytes.Buffer{}), IsComplete: true, Metadata: proxy.Metadata{}, } p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &pResp, nil } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } gin.SetMode(gin.TestMode) router := gin.New() router.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() for i := 0; i < b.N; i++ { w := httptest.NewRecorder() router.ServeHTTP(w, req) } } 
func BenchmarkEndpointHandler_ko_Parallel(b *testing.B) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, fmt.Errorf("This is %s", "a dummy error") } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } gin.SetMode(gin.TestMode) router := gin.New() router.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { w := httptest.NewRecorder() router.ServeHTTP(w, req) } }) } func BenchmarkEndpointHandler_ok_Parallel(b *testing.B) { pResp := proxy.Response{ Data: map[string]interface{}{}, Io: io.NopCloser(&bytes.Buffer{}), IsComplete: true, Metadata: proxy.Metadata{}, } p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &pResp, nil } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } gin.SetMode(gin.TestMode) router := gin.New() router.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { w := httptest.NewRecorder() router.ServeHTTP(w, req) } }) } ================================================ FILE: router/gin/endpoint_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package gin import ( "bytes" "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "reflect" "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/transport/http/server" ) func 
TestEndpointHandler_ok(t *testing.T) { p := func(ctx context.Context, req *proxy.Request) (*proxy.Response, error) { if v, ok := ctx.Value("bool").(bool); !ok || !v { t.Errorf("unexpected bool context value: %v", v) } if v, ok := ctx.Value("int").(int); !ok || v != 42 { t.Errorf("unexpected int context value: %v", v) } if v, ok := ctx.Value("string").(string); !ok || v != "supu" { t.Errorf("unexpected string context value: %v", v) } data, _ := json.Marshal(req.Query) if string(data) != `{"b":["1"],"c[]":["x","y"],"d":["1","2"]}` { t.Errorf("unexpected querystring: %s", data) } return &proxy.Response{ IsComplete: true, Data: map[string]interface{}{"supu": "tupu"}, Metadata: proxy.Metadata{ Headers: map[string][]string{"a": {"a1", "a2"}}, }, }, nil } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "{\"supu\":\"tupu\"}", expectedCache: "public, max-age=21600", expectedContent: "application/json; charset=utf-8", expectedStatusCode: http.StatusOK, completed: true, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_okAllParams(t *testing.T) { p := func(_ context.Context, req *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: true, Data: map[string]interface{}{"query": req.Query, "headers": req.Headers, "params": req.Params}, Metadata: proxy.Metadata{ Headers: map[string][]string{"X-YZ": {"something"}}, StatusCode: 200, }, }, nil } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: `{"headers":{"Content-Type":["application/json"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-For":[""],"X-Forwarded-Host":["127.0.0.1:8080"]},"params":{"Param":"a"},"query":{"a":["42"],"b":["1"],"c[]":["x","y"],"d":["1","2"]}}`, expectedCache: "public, max-age=21600", expectedContent: "application/json; charset=utf-8", expectedStatusCode: http.StatusOK, completed: true, queryString: []string{"*"}, headers: []string{"*"}, expectedHeaders: map[string][]string{"X-YZ": {"something"}}, 
}.test(t) time.Sleep(5 * time.Millisecond) } var ctxContent = map[string]interface{}{ "bool": true, "int": 42, "string": "supu", } func TestEndpointHandler_incomplete(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: false, Data: map[string]interface{}{"foo": "bar"}, }, nil } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "{\"foo\":\"bar\"}", expectedCache: "", expectedContent: "application/json; charset=utf-8", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_errored(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, errors.New("this is a dummy error") } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "", expectedCache: "", expectedContent: "", expectedStatusCode: http.StatusInternalServerError, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_errored_responseError(t *testing.T) { expectedBody := "this is a dummy error" p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, dummyResponseError{err: expectedBody, status: http.StatusTeapot} } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "", expectedCache: "", expectedContent: "", expectedStatusCode: http.StatusTeapot, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) // Same test case but with body (return_error_msg enabled) returnErrorMsg = true endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: expectedBody, expectedCache: "", expectedContent: "", expectedStatusCode: http.StatusTeapot, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) returnErrorMsg = false } func TestEndpointHandler_errored_withHeaders(t *testing.T) { expectedBody := "this is a dummy error" p := func(_ context.Context, _ *proxy.Request) 
(*proxy.Response, error) { return nil, dummyHeadersResponseError{ dummyResponseError: dummyResponseError{err: expectedBody, status: http.StatusTeapot}, headers: map[string][]string{ "X-Header": {"header1", "header2"}, }, } } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "", expectedCache: "", expectedContent: "", expectedHeaders: map[string][]string{"X-Header": {"header1", "header2"}}, expectedStatusCode: http.StatusTeapot, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_errored_encodedResponseError(t *testing.T) { expectedBody := `{ "message": "this is a dummy error" }` p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, dummyEncodedResponseError{dummyResponseError: dummyResponseError{err: expectedBody, status: http.StatusTeapot}, encoding: "application/json"} } returnErrorMsg = true endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: expectedBody, expectedCache: "", expectedContent: "application/json", expectedStatusCode: http.StatusTeapot, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) returnErrorMsg = false } type dummyResponseError struct { err string status int } func (d dummyResponseError) Error() string { return d.err } func (d dummyResponseError) StatusCode() int { return d.status } type dummyEncodedResponseError struct { dummyResponseError encoding string } func (d dummyEncodedResponseError) Encoding() string { return d.encoding } type dummyHeadersResponseError struct { dummyResponseError headers map[string][]string } func (d dummyHeadersResponseError) Headers() map[string][]string { return d.headers } func TestEndpointHandler_incompleteAndErrored(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: false, Data: map[string]interface{}{"foo": "bar"}, }, errors.New("This is a dummy error") } endpointHandlerTestCase{ timeout: 10, proxy: p, 
method: "GET", expectedBody: "{\"foo\":\"bar\"}", expectedCache: "", expectedContent: "application/json; charset=utf-8", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_cancelEmpty(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { time.Sleep(100 * time.Millisecond) return nil, nil } endpointHandlerTestCase{ timeout: 0, proxy: p, method: "GET", expectedBody: "", expectedCache: "", expectedContent: "", expectedStatusCode: http.StatusInternalServerError, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_cancel(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { time.Sleep(100 * time.Millisecond) return &proxy.Response{ IsComplete: false, Data: map[string]interface{}{"foo": "bar"}, }, nil } endpointHandlerTestCase{ timeout: 0, proxy: p, method: "GET", expectedBody: "{\"foo\":\"bar\"}", expectedCache: "", expectedContent: "application/json; charset=utf-8", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_noop(t *testing.T) { endpointHandlerTestCase{ timeout: time.Minute, proxy: proxy.NoopProxy, method: "GET", expectedBody: "{}", expectedCache: "", expectedContent: "application/json; charset=utf-8", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestCustomErrorEndpointHandler(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } hf := CustomErrorEndpointHandler(logger, server.DefaultToHTTPError) endpoint := &config.EndpointConfig{ Method: "GET", Endpoint: "/", Timeout: time.Minute, CacheTTL: 6 * time.Hour, QueryString: []string{"b", "c[]", "d"}, } p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return 
nil, errors.New("this is a dummy error") } s := startGinServer(hf(endpoint, p)) req, _ := http.NewRequest( "GET", "http://127.0.0.1:8080/_gin_endpoint/a?a=42&b=1&c[]=x&c[]=y&d=1&d=2", io.NopCloser(&bytes.Buffer{}), ) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() s.ServeHTTP(w, req) if content := buff.String(); !strings.Contains(content, "pref ERROR: [ENDPOINT: /] this is a dummy error") { t.Error("unexpected log content", content) } } type endpointHandlerTestCase struct { timeout time.Duration proxy proxy.Proxy method string expectedBody string expectedCache string expectedContent string expectedHeaders map[string][]string expectedStatusCode int completed bool queryString []string headers []string } func (tc endpointHandlerTestCase) test(t *testing.T) { endpoint := &config.EndpointConfig{ Method: "GET", Timeout: tc.timeout, CacheTTL: 6 * time.Hour, QueryString: []string{"b", "c[]", "d"}, } if len(tc.queryString) > 0 { endpoint.QueryString = tc.queryString } if len(tc.headers) > 0 { endpoint.HeadersToPass = tc.headers } s := startGinServer(EndpointHandler(endpoint, tc.proxy)) req, _ := http.NewRequest( tc.method, "http://127.0.0.1:8080/_gin_endpoint/a?a=42&b=1&c[]=x&c[]=y&d=1&d=2", io.NopCloser(&bytes.Buffer{}), ) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() s.ServeHTTP(w, req) body, ioerr := io.ReadAll(w.Result().Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } w.Result().Body.Close() content := string(body) resp := w.Result() if resp.Header.Get("Cache-Control") != tc.expectedCache { t.Error("Cache-Control error:", resp.Header.Get("Cache-Control")) } if tc.completed && resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderCompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if !tc.completed && resp.Header.Get(server.CompleteResponseHeaderName) != 
server.HeaderIncompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != tc.expectedContent { t.Error("Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "Version undefined" { t.Error("X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != tc.expectedStatusCode { t.Error("Unexpected status code:", resp.StatusCode) } if content != tc.expectedBody { t.Error("Unexpected body:", content, "expected:", tc.expectedBody) } for k, v := range tc.expectedHeaders { h := resp.Header.Values(k) if !reflect.DeepEqual(h, v) { t.Error("Unexpected value for header:", k, h, "expected:", v) } } } func startGinServer(handlerFunc gin.HandlerFunc) *gin.Engine { gin.SetMode(gin.TestMode) r := gin.New() r.GET("/_gin_endpoint/:param", ctxMiddleware, handlerFunc) return r } func ctxMiddleware(c *gin.Context) { for k, v := range ctxContent { c.Set(k, v) } } ================================================ FILE: router/gin/engine.go ================================================ // SPDX-License-Identifier: Apache-2.0 package gin import ( "encoding/json" "errors" "fmt" "io" "net/http" "net/textproto" "net/url" "strings" "sync" "time" "github.com/gin-gonic/gin" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/core" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/transport/http/server" ) const Namespace = "github_com/luraproject/lura/router/gin" type EngineOptions struct { Logger logging.Logger Writer io.Writer Formatter gin.LogFormatter Health <-chan string } // NewEngine returns an initialized gin engine func NewEngine(cfg config.ServiceConfig, opt EngineOptions) *gin.Engine { gin.SetMode(gin.ReleaseMode) if cfg.Debug { opt.Logger.Debug(logPrefix, "Debug enabled") } engine := gin.New() engine.RedirectTrailingSlash = true engine.RedirectFixedPath = true 
engine.HandleMethodNotAllowed = true engine.ContextWithFallback = true var paths []string ginOptions := engineConfiguration{} if v, ok := cfg.ExtraConfig[Namespace]; ok { if b, err := json.Marshal(v); err == nil { if err := json.Unmarshal(b, &ginOptions); err == nil { if ginOptions.DisableRedirectTrailingSlash != nil { engine.RedirectTrailingSlash = !*ginOptions.DisableRedirectTrailingSlash } if ginOptions.DisableRedirectFixedPath != nil { engine.RedirectFixedPath = !*ginOptions.DisableRedirectFixedPath } if ginOptions.DisableHandleMethodNotAllowed != nil { engine.HandleMethodNotAllowed = !*ginOptions.DisableHandleMethodNotAllowed } if ginOptions.ForwardedByClientIP != nil { engine.ForwardedByClientIP = *ginOptions.ForwardedByClientIP } if len(ginOptions.RemoteIPHeaders) > 0 { engine.RemoteIPHeaders = ginOptions.RemoteIPHeaders for k, h := range engine.RemoteIPHeaders { engine.RemoteIPHeaders[k] = textproto.CanonicalMIMEHeaderKey(h) } } if len(ginOptions.TrustedProxies) > 0 { engine.SetTrustedProxies(ginOptions.TrustedProxies) } engine.AppEngine = ginOptions.AppEngine engine.MaxMultipartMemory = ginOptions.MaxMultipartMemory engine.RemoveExtraSlash = ginOptions.RemoveExtraSlash engine.UseH2C = ginOptions.UseH2C paths = ginOptions.LoggerSkipPaths returnErrorMsg = ginOptions.ReturnErrorMsg if ginOptions.ObfuscateVersionHeader { core.KrakendHeaderValue = "Version undefined" } } } } engine.NoRoute(func(c *gin.Context) { c.Header(server.CompleteResponseHeaderName, server.HeaderIncompleteResponseValue) }) if !ginOptions.DisableAccessLog { engine.Use( gin.LoggerWithConfig(gin.LoggerConfig{ Output: opt.Writer, SkipPaths: paths, Formatter: opt.Formatter, }), ) } engine.Use(gin.Recovery()) if !ginOptions.DisablePathDecoding { engine.Use(paramChecker()) } if !ginOptions.DisableHealthEndpoint { path := "/__health" if ginOptions.HealthPath != "" { path = ginOptions.HealthPath } engine.GET(path, healthEndpoint(opt.Health)) } return engine } func healthEndpoint(health <-chan 
string) func(*gin.Context) {
	mu := new(sync.RWMutex)
	reports := map[string]string{}

	go func() {
		for name := range health {
			mu.Lock()
			reports[name] = time.Now().String()
			mu.Unlock()
		}
	}()

	return func(c *gin.Context) {
		mu.RLock()
		defer mu.RUnlock()
		c.JSON(200, gin.H{"status": "ok", "agents": reports, "now": time.Now().String()})
	}
}

// paramChecker returns a middleware that validates every route param of the
// incoming request. It answers with http.StatusBadRequest, writes the error
// through ErrorResponseWriter and aborts the handler chain when a param:
//   - cannot be path-unescaped, or
//   - changes once path-unescaped (it carried url-encoded sequences), or
//   - contains a literal '?' or '#'.
func paramChecker() gin.HandlerFunc {
	return func(c *gin.Context) {
		for _, param := range c.Params {
			s, err := url.PathUnescape(param.Value)
			if err != nil {
				c.Status(http.StatusBadRequest)
				ErrorResponseWriter(c, fmt.Errorf("error: %s", err))
				c.Abort()
				return
			}
			// a value that differs after decoding was url-encoded; '?' and '#'
			// are query/fragment delimiters that must never reach a route param
			if s != param.Value || strings.Contains(s, "?") || strings.Contains(s, "#") {
				c.Status(http.StatusBadRequest)
				ErrorResponseWriter(c, errors.New("error: encoded url params"))
				c.Abort()
				return
			}
		}
	}
}

// engineConfiguration groups the tunable options of the gin engine; the json
// tags are the accepted configuration keys.
// NOTE(review): presumably decoded from the router extra config (router.go reads
// "auto_options" from cfg.ExtraConfig[Namespace]) — confirm in NewEngine.
type engineConfiguration struct {
	// Disables automatic redirection if the current route can't be matched but a
	// handler for the path with (without) the trailing slash exists.
	// For example if /foo/ is requested but a route only exists for /foo, the
	// client is redirected to /foo with http status code 301 for GET requests
	// and 307 for all other request methods.
	DisableRedirectTrailingSlash *bool `json:"disable_redirect_trailing_slash"`

	// If enabled, the router tries to fix the current request path, if no
	// handle is registered for it.
	// First superfluous path elements like ../ or // are removed.
	// Afterwards the router does a case-insensitive lookup of the cleaned path.
	// If a handle can be found for this route, the router makes a redirection
	// to the corrected path with status code 301 for GET requests and 307 for
	// all other request methods.
	// For example /FOO and /..//Foo could be redirected to /foo.
	// RedirectTrailingSlash is independent of this option.
	// NOTE(review): the text above describes gin's RedirectFixedPath feature;
	// given the field name, setting this flag presumably turns it off — confirm.
	DisableRedirectFixedPath *bool `json:"disable_redirect_fixed_path"`

	// If enabled, the router checks if another method is allowed for the
	// current route, if the current request can not be routed.
	// If this is the case, the request is answered with 'Method Not Allowed'
	// and HTTP status code 405.
	// If no other Method is allowed, the request is delegated to the NotFound
	// handler.
	// NOTE(review): the text above describes gin's HandleMethodNotAllowed
	// feature; given the field name, setting this flag presumably turns it off.
	DisableHandleMethodNotAllowed *bool `json:"disable_handle_method_not_allowed"`

	// If enabled, client IP will be parsed from the request's headers that
	// match those stored at `(*gin.Engine).RemoteIPHeaders`. If no IP was
	// fetched, it falls back to the IP obtained from
	// `(*gin.Context).Request.RemoteAddr`.
	ForwardedByClientIP *bool `json:"forwarded_by_client_ip"`

	// List of headers used to obtain the client IP when
	// `(*gin.Engine).ForwardedByClientIP` is `true` and
	// `(*gin.Context).Request.RemoteAddr` is matched by at least one of the
	// network origins of `(*gin.Engine).TrustedProxies`.
	RemoteIPHeaders []string `json:"remote_ip_headers"`

	// List of network origins (IPv4 addresses, IPv4 CIDRs, IPv6 addresses or
	// IPv6 CIDRs) from which to trust request's headers that contain
	// alternative client IP when `(*gin.Engine).ForwardedByClientIP` is
	// `true`.
	TrustedProxies []string `json:"trusted_proxies"`

	// #726 #755 If enabled, it will trust some headers starting with
	// 'X-AppEngine...' for better integration with that PaaS.
	AppEngine bool `json:"app_engine"`

	// Value of 'maxMemory' param that is given to http.Request's ParseMultipartForm
	// method call.
	MaxMultipartMemory int64 `json:"max_multipart_memory"`

	// RemoveExtraSlash, when enabled, lets a parameter be parsed from the URL
	// even when the path contains extra slashes.
	// See the PR #1817 and issue #1644
	RemoveExtraSlash bool `json:"remove_extra_slash"`

	// LoggerSkipPaths defines the set of paths to avoid logging
	LoggerSkipPaths []string `json:"logger_skip_paths"`

	// AutoOptions enables the autogenerated OPTIONS endpoint for all the registered paths
	AutoOptions bool `json:"auto_options"`

	// ReturnErrorMsg flags if the error msg should be returned to the client as response body
	ReturnErrorMsg bool `json:"return_error_msg"`

	// DisableHealthEndpoint marks if the health check endpoint should NOT be exposed
	DisableHealthEndpoint bool `json:"disable_health"`

	// HealthPath allows users to define a custom path for the health check endpoint
	HealthPath string `json:"health_path"`

	// DisableAccessLog blocks the injection of the router logger
	DisableAccessLog bool `json:"disable_access_log"`

	// DisablePathDecoding disables automatic validation of the url params looking for url encoded ones.
	// For example if /foo/..%252Fbar is requested and this flag is set to false, the router will
	// reject the request with http status code 400.
	DisablePathDecoding bool `json:"disable_path_decoding"`

	// ObfuscateVersionHeader flags if the version header returned by the router should replace the actual
	// version with the value "undefined"
	ObfuscateVersionHeader bool `json:"hide_version_header"`

	// UseH2C enables h2c (HTTP/2 over cleartext) support.
	UseH2C bool `json:"use_h2c"`
}

var returnErrorMsg bool

================================================
FILE: router/gin/engine_test.go
================================================
package gin

import (
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"

	"github.com/luraproject/lura/v2/config"
)

// TestNewEngine_contextIsPropagated checks that a value stored in the request
// context can be read back through the gin context inside a handler.
func TestNewEngine_contextIsPropagated(t *testing.T) {
	engine := NewEngine(
		config.ServiceConfig{},
		EngineOptions{},
	)

	type ctxKeyType string

	ctxKey := ctxKeyType("foo")
	ctxValue := "bar"

	engine.GET("/some/path", func(c *gin.Context) {
		c.String(http.StatusOK, "%v", c.Value(ctxKey))
	})

	req, _ := http.NewRequest("GET", "/some/path", http.NoBody)
	req = req.WithContext(context.WithValue(req.Context(), ctxKey, ctxValue))
	w := httptest.NewRecorder()

	engine.ServeHTTP(w, req)

	resp := w.Result()

	if sc := resp.StatusCode; sc != http.StatusOK {
		t.Errorf("unexpected status code: %d", sc)
		return
	}

	b, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Errorf("reading the response body: %s", err.Error())
		return
	}

	if string(b) != ctxValue {
		t.Errorf("unexpected value: %s", string(b))
	}
}

// TestNewEngine_paramsAreChecked checks that plain route params are accepted
// while params carrying an encoded '?' (%3f) or '#' (%23) are rejected with a
// 400 and the "encoded url params" error body.
func TestNewEngine_paramsAreChecked(t *testing.T) {
	engine := NewEngine(
		config.ServiceConfig{},
		EngineOptions{},
	)

	engine.GET("/user/:id/public", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	// assertResponse requests path and checks both the status code and the body
	assertResponse := func(path string, statusCode int, body string) {
		req, _ := http.NewRequest("GET", path, http.NoBody)
		w := httptest.NewRecorder()

		engine.ServeHTTP(w, req)

		resp := w.Result()

		if sc := resp.StatusCode; sc != statusCode {
			t.Errorf("unexpected status code: %d (expected %d)", sc, statusCode)
			return
		}

		b, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("reading the response body: %s", err.Error())
			return
		}

		if string(b) != body {
			t.Errorf("unexpected response body: '%s' (expected '%s')", string(b), body)
		}
	}

	assertResponse("/user/123/public", http.StatusOK, "ok")
	assertResponse("/user/123%3f/public", http.StatusBadRequest, "error: encoded url params")
assertResponse("/user/123%23/public", http.StatusBadRequest, "error: encoded url params") } ================================================ FILE: router/gin/render.go ================================================ // SPDX-License-Identifier: Apache-2.0 package gin import ( "io" "net/http" "sync" "github.com/gin-gonic/gin" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/encoding" "github.com/luraproject/lura/v2/proxy" ) // Render defines the signature of the functions to be use for the final response // encoding and rendering type Render func(*gin.Context, *proxy.Response) // NEGOTIATE defines the value of the OutputEncoding for the negotiated render const NEGOTIATE = "negotiate" const XML = "xml" const YAML = "yaml" var ( mutex = &sync.RWMutex{} renderRegister = map[string]Render{ encoding.STRING: stringRender, encoding.JSON: jsonRender, encoding.NOOP: noopRender, "json-collection": jsonCollectionRender, XML: xmlRender, YAML: yamlRender, } ) func init() { // the negotiated render must be registered at the init function in order // to avoid a cyclical dependency renderRegister[NEGOTIATE] = negotiatedRender } // RegisterRender allows clients to register their custom renders func RegisterRender(name string, r Render) { mutex.Lock() renderRegister[name] = r mutex.Unlock() } func getRender(cfg *config.EndpointConfig) Render { fallback := jsonRender if len(cfg.Backend) == 1 { fallback = getWithFallback(cfg.Backend[0].Encoding, fallback) } if cfg.OutputEncoding == "" { return fallback } return getWithFallback(cfg.OutputEncoding, fallback) } func getWithFallback(key string, fallback Render) Render { mutex.RLock() r, ok := renderRegister[key] mutex.RUnlock() if !ok { return fallback } return r } func negotiatedRender(c *gin.Context, response *proxy.Response) { switch c.NegotiateFormat(gin.MIMEJSON, gin.MIMEPlain, gin.MIMEXML, gin.MIMEYAML) { case gin.MIMEXML: getWithFallback(XML, jsonRender)(c, response) case gin.MIMEPlain, gin.MIMEYAML: 
getWithFallback(YAML, jsonRender)(c, response) default: getWithFallback(encoding.JSON, jsonRender)(c, response) } } func stringRender(c *gin.Context, response *proxy.Response) { status := c.Writer.Status() if response == nil { c.String(status, "") return } d, ok := response.Data["content"] if !ok { c.String(status, "") return } msg, ok := d.(string) if !ok { c.String(status, "") return } c.String(status, msg) } func jsonRender(c *gin.Context, response *proxy.Response) { status := c.Writer.Status() if response == nil { c.JSON(status, emptyResponse) return } c.JSON(status, response.Data) } func jsonCollectionRender(c *gin.Context, response *proxy.Response) { status := c.Writer.Status() if response == nil { c.JSON(status, []struct{}{}) return } col, ok := response.Data["collection"] if !ok { c.JSON(status, []struct{}{}) return } c.JSON(status, col) } func xmlRender(c *gin.Context, response *proxy.Response) { status := c.Writer.Status() if response == nil { c.XML(status, nil) return } d, ok := response.Data["content"] if !ok { c.XML(status, nil) return } c.XML(status, d) } func yamlRender(c *gin.Context, response *proxy.Response) { status := c.Writer.Status() if response == nil { c.YAML(status, emptyResponse) return } c.YAML(status, response.Data) } func noopRender(c *gin.Context, response *proxy.Response) { if response == nil { c.Status(http.StatusInternalServerError) return } for k, vs := range response.Metadata.Headers { for _, v := range vs { c.Writer.Header().Add(k, v) } } c.Status(response.Metadata.StatusCode) if response.Io == nil { return } io.Copy(c.Writer, response.Io) } var emptyResponse = gin.H{} ================================================ FILE: router/gin/render_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package gin import ( "bytes" "context" "fmt" "io" "net/http" "net/http/httptest" "reflect" "testing" "time" "github.com/gin-gonic/gin" "github.com/luraproject/lura/v2/config" 
"github.com/luraproject/lura/v2/encoding" "github.com/luraproject/lura/v2/proxy" ) func TestRender_Negotiated_ok(t *testing.T) { type A struct { B string } p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: true, Data: map[string]interface{}{"content": A{B: "supu"}}, }, nil } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, OutputEncoding: NEGOTIATE, } gin.SetMode(gin.TestMode) server := gin.New() server.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p)) for _, testData := range [][]string{ {"plain", "text/plain", "application/yaml; charset=utf-8", "content:\n b: supu\n"}, {"none", "", "application/json; charset=utf-8", `{"content":{"B":"supu"}}`}, {"json", "application/json", "application/json; charset=utf-8", `{"content":{"B":"supu"}}`}, {"xml", "application/xml", "application/xml; charset=utf-8", `supu`}, } { req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", io.NopCloser(&bytes.Buffer{})) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", testData[1]) w := httptest.NewRecorder() server.ServeHTTP(w, req) defer w.Result().Body.Close() body, ioerr := io.ReadAll(w.Result().Body) if ioerr != nil { t.Error("reading response body:", ioerr) return } content := string(body) if w.Result().Header.Get("Cache-Control") != "public, max-age=21600" { t.Error(testData[0], "Cache-Control error:", w.Result().Header.Get("Cache-Control")) } if w.Result().Header.Get("Content-Type") != testData[2] { t.Error(testData[0], "Content-Type error:", w.Result().Header.Get("Content-Type")) } if w.Result().Header.Get("X-Krakend") != "Version undefined" { t.Error(testData[0], "X-Krakend error:", w.Result().Header.Get("X-Krakend")) } if w.Result().StatusCode != http.StatusOK { t.Error(testData[0], "Unexpected status code:", w.Result().StatusCode) } if content != testData[3] { t.Error(testData[0], 
fmt.Sprintf("Unexpected body: '%s'\nexpected: '%s'", content, testData[3]))
		}
	}
}

// TestRender_Negotiated_noData checks the negotiated render for every
// supported Accept value against a complete-less response with empty Data.
func TestRender_Negotiated_noData(t *testing.T) {
	p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
		return &proxy.Response{
			Data: map[string]interface{}{},
		}, nil
	}
	endpoint := &config.EndpointConfig{
		Timeout:        time.Second,
		CacheTTL:       6 * time.Hour,
		QueryString:    []string{"b"},
		OutputEncoding: NEGOTIATE,
	}

	gin.SetMode(gin.TestMode)
	server := gin.New()
	server.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p))

	for _, testData := range [][]string{
		{"plain", "text/plain", "application/yaml; charset=utf-8", "{}\n"},
		{"none", "", "application/json; charset=utf-8", "{}"},
		{"json", "application/json", "application/json; charset=utf-8", "{}"},
		{"xml", "application/xml", "application/xml; charset=utf-8", ""},
	} {
		req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", io.NopCloser(&bytes.Buffer{}))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", testData[1])

		w := httptest.NewRecorder()
		server.ServeHTTP(w, req)

		// NOTE(review): deferred inside the loop, so the bodies are only closed
		// when the test function returns.
		defer w.Result().Body.Close()

		body, ioerr := io.ReadAll(w.Result().Body)
		if ioerr != nil {
			t.Error("reading response body:", ioerr)
			return
		}
		content := string(body)
		if w.Result().Header.Get("Content-Type") != testData[2] {
			t.Error(testData[0], "Content-Type error:", w.Result().Header.Get("Content-Type"))
		}
		if w.Result().Header.Get("X-Krakend") != "Version undefined" {
			t.Error(testData[0], "X-Krakend error:", w.Result().Header.Get("X-Krakend"))
		}
		if w.Result().StatusCode != http.StatusOK {
			t.Error(testData[0], "Unexpected status code:", w.Result().StatusCode)
		}
		if content != testData[3] {
			t.Error(testData[0], "Unexpected body:", content, "expected:", testData[3])
		}
	}
}

// TestRender_Negotiated_noResponse checks the negotiated render for every
// supported Accept value when the proxy returns a nil response.
func TestRender_Negotiated_noResponse(t *testing.T) {
	p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
		return nil, nil
	}
	endpoint := &config.EndpointConfig{
		Timeout:        time.Second,
		CacheTTL:       6 * time.Hour,
		QueryString:    []string{"b"},
		OutputEncoding: NEGOTIATE,
	}

	gin.SetMode(gin.TestMode)
	server := gin.New()
	server.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p))

	for _, testData := range [][]string{
		{"plain", "text/plain", "application/yaml; charset=utf-8", "{}\n"},
		{"none", "", "application/json; charset=utf-8", "{}"},
		{"json", "application/json", "application/json; charset=utf-8", "{}"},
		{"xml", "application/xml", "application/xml; charset=utf-8", ""},
	} {
		req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", io.NopCloser(&bytes.Buffer{}))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", testData[1])

		w := httptest.NewRecorder()
		server.ServeHTTP(w, req)

		defer w.Result().Body.Close()

		body, ioerr := io.ReadAll(w.Result().Body)
		if ioerr != nil {
			t.Error("reading response body:", ioerr)
			return
		}
		content := string(body)
		if w.Result().Header.Get("Content-Type") != testData[2] {
			t.Error(testData[0], "Content-Type error:", w.Result().Header.Get("Content-Type"))
		}
		if w.Result().Header.Get("X-Krakend") != "Version undefined" {
			t.Error(testData[0], "X-Krakend error:", w.Result().Header.Get("X-Krakend"))
		}
		if w.Result().StatusCode != http.StatusOK {
			t.Error(testData[0], "Unexpected status code:", w.Result().StatusCode)
		}
		if content != testData[3] {
			t.Error(testData[0], "Unexpected body:", content, "expected:", testData[3])
		}
	}
}

// TestRender_unknown checks that an unregistered OutputEncoding falls back to
// the JSON render, whatever the Accept header says.
func TestRender_unknown(t *testing.T) {
	p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
		return &proxy.Response{
			IsComplete: true,
			Data:       map[string]interface{}{"supu": "tupu"},
		}, nil
	}
	endpoint := &config.EndpointConfig{
		Timeout:        time.Second,
		CacheTTL:       6 * time.Hour,
		QueryString:    []string{"b"},
		OutputEncoding: "unknown",
	}

	gin.SetMode(gin.TestMode)
	server := gin.New()
	server.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p))

	expectedHeader := "application/json; charset=utf-8"
	expectedBody := `{"supu":"tupu"}`

	for _, testData := range [][]string{
		{"plain", "text/plain"},
		{"none", ""},
		{"json", "application/json"},
		{"unknown", "unknown"},
	} {
		req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", io.NopCloser(&bytes.Buffer{}))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", testData[1])

		w := httptest.NewRecorder()
		server.ServeHTTP(w, req)

		defer w.Result().Body.Close()

		body, ioerr := io.ReadAll(w.Result().Body)
		if ioerr != nil {
			t.Error("reading response body:", ioerr)
			return
		}
		content := string(body)
		if w.Result().Header.Get("Cache-Control") != "public, max-age=21600" {
			t.Error(testData[0], "Cache-Control error:", w.Result().Header.Get("Cache-Control"))
		}
		if w.Result().Header.Get("Content-Type") != expectedHeader {
			t.Error(testData[0], "Content-Type error:", w.Result().Header.Get("Content-Type"))
		}
		if w.Result().Header.Get("X-Krakend") != "Version undefined" {
			t.Error(testData[0], "X-Krakend error:", w.Result().Header.Get("X-Krakend"))
		}
		if w.Result().StatusCode != http.StatusOK {
			t.Error(testData[0], "Unexpected status code:", w.Result().StatusCode)
		}
		if content != expectedBody {
			t.Error(testData[0], "Unexpected body:", content, "expected:", expectedBody)
		}
	}
}

// TestRender_string checks that the string render writes Data["content"] as
// text/plain regardless of the Accept header.
func TestRender_string(t *testing.T) {
	expectedContent := "supu"
	expectedHeader := "text/plain; charset=utf-8"

	p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
		return &proxy.Response{
			IsComplete: true,
			Data:       map[string]interface{}{"content": expectedContent},
		}, nil
	}
	endpoint := &config.EndpointConfig{
		Timeout:        time.Second,
		CacheTTL:       6 * time.Hour,
		QueryString:    []string{"b"},
		OutputEncoding: encoding.STRING,
	}

	gin.SetMode(gin.TestMode)
	server := gin.New()
	server.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p))

	for _, testData := range [][]string{
		{"plain", "text/plain"},
		{"none", ""},
		{"json", "application/json"},
		{"unknown", "unknown"},
	} {
		req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", io.NopCloser(&bytes.Buffer{}))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", testData[1])

		w := httptest.NewRecorder()
		server.ServeHTTP(w, req)

		defer w.Result().Body.Close()

		body, ioerr := io.ReadAll(w.Result().Body)
		if ioerr != nil {
			t.Error("reading response body:", ioerr)
			return
		}
		content := string(body)
		if w.Result().Header.Get("Cache-Control") != "public, max-age=21600" {
			t.Error(testData[0], "Cache-Control error:", w.Result().Header.Get("Cache-Control"))
		}
		if w.Result().Header.Get("Content-Type") != expectedHeader {
			t.Error(testData[0], "Content-Type error:", w.Result().Header.Get("Content-Type"))
		}
		if w.Result().Header.Get("X-Krakend") != "Version undefined" {
			t.Error(testData[0], "X-Krakend error:", w.Result().Header.Get("X-Krakend"))
		}
		if w.Result().StatusCode != http.StatusOK {
			t.Error(testData[0], "Unexpected status code:", w.Result().StatusCode)
		}
		if content != expectedContent {
			t.Error(testData[0], "Unexpected body:", content, "expected:", expectedContent)
		}
	}
}

// TestRender_string_noData checks that the string render writes an empty
// text/plain body when content is not a string, is missing, or there is no
// response at all.
func TestRender_string_noData(t *testing.T) {
	expectedContent := ""
	expectedHeader := "text/plain; charset=utf-8"

	for k, p := range []proxy.Proxy{
		func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
			return &proxy.Response{
				IsComplete: false,
				Data:       map[string]interface{}{"content": 42},
			}, nil
		},
		func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
			return &proxy.Response{
				IsComplete: false,
				Data:       map[string]interface{}{},
			}, nil
		},
		func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
			return nil, nil
		},
	} {
		endpoint := &config.EndpointConfig{
			Timeout:        time.Second,
			CacheTTL:       6 * time.Hour,
			QueryString:    []string{"b"},
			OutputEncoding: encoding.STRING,
		}

		gin.SetMode(gin.TestMode)
		server := gin.New()
		server.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p))

		req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", io.NopCloser(&bytes.Buffer{}))
		req.Header.Set("Content-Type", "application/json")

		w := httptest.NewRecorder()
		server.ServeHTTP(w, req)

		defer w.Result().Body.Close()

		body, ioerr := io.ReadAll(w.Result().Body)
		if ioerr != nil {
			t.Error("reading response body:", ioerr)
			return
		}
		content := string(body)
		if w.Result().Header.Get("Content-Type") != expectedHeader {
			t.Error(k, "Content-Type error:", w.Result().Header.Get("Content-Type"))
		}
		if w.Result().Header.Get("X-Krakend") != "Version undefined" {
			t.Error(k, "X-Krakend error:", w.Result().Header.Get("X-Krakend"))
		}
		if w.Result().StatusCode != http.StatusOK {
			t.Error(k, "Unexpected status code:", w.Result().StatusCode)
		}
		if content != expectedContent {
			t.Error(k, "Unexpected body:", content, "expected:", expectedContent)
		}
	}
}

// TestRegisterRender checks that a render registered with RegisterRender is
// returned by getRender for its name and invoked exactly once.
func TestRegisterRender(t *testing.T) {
	var total int
	expected := &proxy.Response{IsComplete: true, Data: map[string]interface{}{"a": "b"}}
	name := "test render"
	RegisterRender(name, func(_ *gin.Context, resp *proxy.Response) {
		*resp = *expected
		total++
	})

	subject := getRender(&config.EndpointConfig{OutputEncoding: name})

	var c *gin.Context
	resp := proxy.Response{}
	subject(c, &resp)

	if !reflect.DeepEqual(resp, *expected) {
		t.Error("unexpected response", resp)
	}

	if total != 1 {
		t.Error("the render was called an unexpected amount of times:", total)
	}
}

// TestRender_noop checks that the noop render forwards the backend headers
// (including repeated Set-Cookie values), status code and body stream.
func TestRender_noop(t *testing.T) {
	expectedContent := "supu"
	expectedHeader := "text/plain; charset=utf-8"
	expectedSetCookieValue := []string{"test1=test1", "test2=test2"}

	p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
		return &proxy.Response{
			Metadata: proxy.Metadata{
				StatusCode: 200,
				Headers: map[string][]string{
					"Content-Type": {expectedHeader},
					"Set-Cookie":   {"test1=test1", "test2=test2"},
				},
			},
			Io: bytes.NewBufferString(expectedContent),
		}, nil
	}
	endpoint := &config.EndpointConfig{
		Timeout:        time.Second,
		CacheTTL:       6 * time.Hour,
		QueryString:    []string{"b"},
		OutputEncoding: encoding.NOOP,
	}

	gin.SetMode(gin.TestMode)
	server := gin.New()
	server.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p))

	req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", io.NopCloser(&bytes.Buffer{}))
	req.Header.Set("Content-Type", "application/json")

	w := httptest.NewRecorder()
	server.ServeHTTP(w, req)

	defer w.Result().Body.Close()

	body, ioerr := io.ReadAll(w.Result().Body)
	if ioerr != nil {
		t.Error("reading response body:", ioerr)
		return
	}
	content := string(body)
	if w.Result().Header.Get("Content-Type") != expectedHeader {
		t.Error("Content-Type error:", w.Result().Header.Get("Content-Type"))
	}
	if w.Result().Header.Get("X-Krakend") != "Version undefined" {
		t.Error("X-Krakend error:", w.Result().Header.Get("X-Krakend"))
	}
	if w.Result().StatusCode != http.StatusOK {
		t.Error("Unexpected status code:", w.Result().StatusCode)
	}
	if content != expectedContent {
		t.Error("Unexpected body:", content, "expected:", expectedContent)
	}
	gotCookie := w.Header()["Set-Cookie"]
	if !reflect.DeepEqual(gotCookie, expectedSetCookieValue) {
		t.Error("Unexpected Set-Cookie header:", gotCookie, "expected:", expectedSetCookieValue)
	}
}

// TestRender_noop_nilBody checks that the noop render answers 200 with an
// empty body and no Content-Type when the response has no body stream.
func TestRender_noop_nilBody(t *testing.T) {
	expectedContent := ""
	expectedHeader := ""

	p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
		return &proxy.Response{IsComplete: true}, nil
	}
	endpoint := &config.EndpointConfig{
		Timeout:        time.Second,
		CacheTTL:       6 * time.Hour,
		QueryString:    []string{"b"},
		OutputEncoding: encoding.NOOP,
	}

	gin.SetMode(gin.TestMode)
	server := gin.New()
	server.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p))

	req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", io.NopCloser(&bytes.Buffer{}))
	req.Header.Set("Content-Type", "application/json")

	w := httptest.NewRecorder()
	server.ServeHTTP(w, req)

	defer w.Result().Body.Close()

	body, ioerr := io.ReadAll(w.Result().Body)
	if ioerr != nil {
		t.Error("reading response body:", ioerr)
		return
	}
	content := string(body)
	if w.Result().Header.Get("Content-Type") != expectedHeader {
		t.Error("Content-Type error:", w.Result().Header.Get("Content-Type"))
	}
	if
w.Result().Header.Get("X-Krakend") != "Version undefined" { t.Error("X-Krakend error:", w.Result().Header.Get("X-Krakend")) } if w.Result().StatusCode != http.StatusOK { t.Error("Unexpected status code:", w.Result().StatusCode) } if content != expectedContent { t.Error("Unexpected body:", content, "expected:", expectedContent) } } func TestRender_noop_nilResponse(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, nil } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, OutputEncoding: encoding.NOOP, } gin.SetMode(gin.TestMode) server := gin.New() server.GET("/_gin_endpoint/:param", EndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", io.NopCloser(&bytes.Buffer{})) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() server.ServeHTTP(w, req) if w.Result().Header.Get("Content-Type") != "" { t.Error("Content-Type error:", w.Result().Header.Get("Content-Type")) } if w.Result().Header.Get("X-Krakend") != "Version undefined" { t.Error("X-Krakend error:", w.Result().Header.Get("X-Krakend")) } if w.Result().StatusCode != http.StatusInternalServerError { t.Error("Unexpected status code:", w.Result().StatusCode) } } ================================================ FILE: router/gin/router.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package gin provides some basic implementations for building routers based on gin-gonic/gin */ package gin import ( "context" "net/http" "sort" "strings" "sync" "github.com/gin-gonic/gin" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/core" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/router" "github.com/luraproject/lura/v2/transport/http/server" ) const logPrefix = "[SERVICE: Gin]" // RunServerFunc is a func that will 
run the http Server with the given params. type RunServerFunc func(context.Context, config.ServiceConfig, http.Handler) error // Config is the struct that collects the parts the router should be builded from type Config struct { Engine *gin.Engine Middlewares []gin.HandlerFunc HandlerFactory HandlerFactory ProxyFactory proxy.Factory Logger logging.Logger RunServer RunServerFunc } // DefaultFactory returns a gin router factory with the injected proxy factory and logger. // It also uses a default gin router and the default HandlerFactory func DefaultFactory(proxyFactory proxy.Factory, logger logging.Logger) router.Factory { return NewFactory( Config{ Engine: gin.Default(), Middlewares: []gin.HandlerFunc{}, HandlerFactory: EndpointHandler, ProxyFactory: proxyFactory, Logger: logger, RunServer: server.RunServer, }, ) } // NewFactory returns a gin router factory with the injected configuration func NewFactory(cfg Config) router.Factory { return factory{cfg} } type factory struct { cfg Config } // New implements the factory interface func (rf factory) New() router.Router { return rf.NewWithContext(context.Background()) } // NewWithContext implements the factory interface func (rf factory) NewWithContext(ctx context.Context) router.Router { return ginRouter{ cfg: rf.cfg, ctx: ctx, runServerF: rf.cfg.RunServer, mu: new(sync.Mutex), urlCatalog: urlCatalog{ mu: new(sync.Mutex), catalog: map[string][]string{}, }, } } type ginRouter struct { cfg Config ctx context.Context runServerF RunServerFunc mu *sync.Mutex urlCatalog urlCatalog } type urlCatalog struct { mu *sync.Mutex catalog map[string][]string } // Run completes the router initialization and executes it func (r ginRouter) Run(cfg config.ServiceConfig) { r.mu.Lock() defer r.mu.Unlock() server.InitHTTPDefaultTransport(cfg) r.registerEndpointsAndMiddlewares(cfg) r.cfg.Logger.Info("[SERVICE: Gin] Listening on port:", cfg.Port) if err := r.runServerF(r.ctx, cfg, &safeCaster{h: r.cfg.Engine.Handler()}); err != nil && err != 
http.ErrServerClosed { r.cfg.Logger.Error(logPrefix, err.Error()) } r.cfg.Logger.Info(logPrefix, "Router execution ended") } func (r ginRouter) registerEndpointsAndMiddlewares(cfg config.ServiceConfig) { if cfg.Debug { r.cfg.Engine.Any("/__debug/*param", DebugHandler(r.cfg.Logger)) } if cfg.Echo { r.cfg.Engine.Any("/__echo/*param", EchoHandler()) } endpointGroup := r.cfg.Engine.Group("/") endpointGroup.Use(r.cfg.Middlewares...) r.registerKrakendEndpoints(endpointGroup, cfg) if opts, ok := cfg.ExtraConfig[Namespace].(map[string]interface{}); ok { if v, ok := opts["auto_options"].(bool); ok && v { r.cfg.Logger.Debug(logPrefix, "Enabling the auto options endpoints") r.registerOptionEndpoints(endpointGroup) } } } func (r ginRouter) registerKrakendEndpoints(rg *gin.RouterGroup, cfg config.ServiceConfig) { proxyBuildFailedHandler := func(c *gin.Context) { c.AbortWithStatus(http.StatusInternalServerError) } // build and register the pipes and endpoints sequentially for _, c := range cfg.Endpoints { proxyStack, err := r.cfg.ProxyFactory.New(c) if err != nil { r.cfg.Logger.Error(logPrefix, "Calling the ProxyFactory", err.Error()) r.registerKrakendEndpoint(rg, c.Method, c, proxyBuildFailedHandler, 1) continue } r.registerKrakendEndpoint(rg, c.Method, c, r.cfg.HandlerFactory(c, proxyStack), len(c.Backend)) } } func (r ginRouter) registerKrakendEndpoint(rg *gin.RouterGroup, method string, e *config.EndpointConfig, h gin.HandlerFunc, total int) { method = strings.ToTitle(method) path := e.Endpoint if method != http.MethodGet && total > 1 { if !router.IsValidSequentialEndpoint(e) { r.cfg.Logger.Error(logPrefix, method, "endpoints with sequential proxy enabled only allow a non-GET in the last backend! 
Ignoring", path) return } } switch method { case http.MethodGet: rg.GET(path, h) case http.MethodPost: rg.POST(path, h) case http.MethodPut: rg.PUT(path, h) case http.MethodPatch: rg.PATCH(path, h) case http.MethodDelete: rg.DELETE(path, h) default: r.cfg.Logger.Error(logPrefix, "[ENDPOINT:", path, "] Unsupported method", method) return } r.urlCatalog.mu.Lock() defer r.urlCatalog.mu.Unlock() methods, ok := r.urlCatalog.catalog[path] if !ok { r.urlCatalog.catalog[path] = []string{method} return } r.urlCatalog.catalog[path] = append(methods, method) } func (r ginRouter) registerOptionEndpoints(rg *gin.RouterGroup) { r.urlCatalog.mu.Lock() defer r.urlCatalog.mu.Unlock() for path, methods := range r.urlCatalog.catalog { sort.Strings(methods) allowed := strings.Join(methods, ", ") rg.OPTIONS(path, func(c *gin.Context) { c.Header("Allow", allowed) c.Header(core.KrakendHeaderName, core.KrakendHeaderValue) }) } } ================================================ FILE: router/gin/router_test.go ================================================ //go:build !race // +build !race // SPDX-License-Identifier: Apache-2.0 package gin import ( "bytes" "context" "errors" "fmt" "io" "net/http" "regexp" "strings" "testing" "time" "github.com/gin-gonic/gin" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/transport/http/server" ) func TestDefaultFactory_ok(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(noopProxyFactory(map[string]interface{}{"supu": "tupu"}), logger).NewWithContext(ctx) expectedBody := "{\"supu\":\"tupu\"}" serviceCfg := config.ServiceConfig{ Port: 8072, Endpoints: []*config.EndpointConfig{ { 
Endpoint: "/some", Method: "GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/some", Method: "post", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/some", Method: "put", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/some", Method: "PATCH", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/some", Method: "DELETE", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, ExtraConfig: map[string]interface{}{ Namespace: map[string]interface{}{ "auto_options": true, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, endpoint := range serviceCfg.Endpoints { req, _ := http.NewRequest(strings.ToTitle(endpoint.Method), fmt.Sprintf("http://127.0.0.1:8072%s", endpoint.Endpoint), http.NoBody) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error("Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderCompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != "application/json; charset=utf-8" { t.Error("Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "Version undefined" { t.Error("X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusOK { t.Error("Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error("Unexpected body:", content, "expected:", expectedBody) } } req, _ := http.NewRequest("OPTIONS", "http://127.0.0.1:8072/some", http.NoBody) req.Header.Set("Content-Type", 
"application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } if allow := resp.Header.Get("Allow"); allow != "DELETE, GET, PATCH, POST, PUT" { t.Errorf("unexpected options response: '%s'", allow) } } func TestDefaultFactory_ko(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(noopProxyFactory(map[string]interface{}{"supu": "tupu"}), logger).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Echo: true, Port: 8073, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: "GETTT", Backend: []*config.Backend{ {}, }, }, { Endpoint: "/empty", Method: "GETTT", Backend: []*config.Backend{}, }, { Endpoint: "/also-ignored", Method: "PUTT", Backend: []*config.Backend{ {}, {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, subject := range [][]string{ {"GET", "ignored"}, {"GET", "empty"}, {"PUT", "also-ignored"}, } { req, _ := http.NewRequest(subject[0], fmt.Sprintf("http://127.0.0.1:8073/%s", subject[1]), http.NoBody) req.Header.Set("Content-Type", "application/json") checkResponseIs404(t, req) } } func TestDefaultFactory_proxyFactoryCrash(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(erroredProxyFactory{fmt.Errorf("%s", "crash!!!")}, logger).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Echo: true, Port: 8074, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: 
"GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) req, _ := http.NewRequest("GET", "http://127.0.0.1:8074/ignored", http.NoBody) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } if resp.Header.Get("Cache-Control") != "" { t.Error(req.URL.String(), "Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get("Content-Type") != "" { t.Error(req.URL.String(), "Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "" { t.Error(req.URL.String(), "X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusInternalServerError { t.Error(req.URL.String(), "Unexpected status code:", resp.StatusCode) } if string(body) != "" { t.Error(req.URL.String(), "Unexpected body:", string(body)) } } func TestRunServer_ko(t *testing.T) { buff := new(bytes.Buffer) logger, err := logging.NewLogger("ERROR", buff, "") if err != nil { t.Error("building the logger:", err.Error()) return } errorMsg := "runServer error" runServerFunc := func(_ context.Context, _ config.ServiceConfig, _ http.Handler) error { return errors.New(errorMsg) } pf := noopProxyFactory(map[string]interface{}{"supu": "tupu"}) r := NewFactory( Config{ Engine: gin.Default(), Middlewares: []gin.HandlerFunc{}, HandlerFactory: EndpointHandler, ProxyFactory: pf, Logger: logger, RunServer: runServerFunc, }, ).New() serviceCfg := config.ServiceConfig{} r.Run(serviceCfg) re := regexp.MustCompile(errorMsg) if !re.MatchString(buff.String()) { t.Errorf("the logger doesn't contain the expected msg: %s", buff.Bytes()) } } func checkResponseIs404(t *testing.T, req *http.Request) { expectedBody := "404 page not found" resp, err := 
http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error(req.URL.String(), "Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get("Content-Type") != "text/plain" { t.Error(req.URL.String(), "Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "" { t.Error(req.URL.String(), "X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusNotFound { t.Error(req.URL.String(), "Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error(req.URL.String(), "Unexpected body:", content, "expected:", expectedBody) } } type noopProxyFactory map[string]interface{} func (n noopProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: true, Data: n, }, nil }, nil } type erroredProxyFactory struct { Error error } func (e erroredProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return proxy.NoopProxy, e.Error } ================================================ FILE: router/gin/safecast.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package gin provides some basic implementations for building routers based on gin-gonic/gin */ package gin import ( "bufio" "errors" "net" "net/http" ) var _ http.ResponseWriter = (*safeCast)(nil) var _ http.Flusher = (*safeCast)(nil) var _ http.Hijacker = (*safeCast)(nil) var _ http.CloseNotifier = (*safeCast)(nil) var _ http.Handler = (*safeCaster)(nil) // safeCast provides fallback implementation for interfaces that are not // checked internally by gin, and that can cause a panic: // - Flusher // - Hijacker // - 
Notifier type safeCast struct { w http.ResponseWriter } func (s *safeCast) Header() http.Header { return s.w.Header() } func (s *safeCast) Write(b []byte) (int, error) { return s.w.Write(b) } func (s *safeCast) WriteHeader(statusCode int) { s.w.WriteHeader(statusCode) } func (s *safeCast) Flush() { if f, ok := s.w.(http.Flusher); ok { f.Flush() } } func (s *safeCast) Hijack() (net.Conn, *bufio.ReadWriter, error) { if h, ok := s.w.(http.Hijacker); ok { return h.Hijack() } return nil, nil, errors.New("not supported") } func (s *safeCast) CloseNotify() <-chan bool { if h, ok := s.w.(http.CloseNotifier); ok { return h.CloseNotify() } return make(<-chan bool, 1) } type safeCaster struct { h http.Handler } func (s *safeCaster) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.h.ServeHTTP(&safeCast{w}, r) } ================================================ FILE: router/gorilla/router.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package gorilla provides some basic implementations for building routers based on gorilla/mux */ package gorilla import ( "net/http" gorilla "github.com/gorilla/mux" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/router" "github.com/luraproject/lura/v2/router/mux" "github.com/luraproject/lura/v2/transport/http/server" "golang.org/x/text/cases" "golang.org/x/text/language" ) // DefaultFactory returns a net/http mux router factory with the injected proxy factory and logger func DefaultFactory(pf proxy.Factory, logger logging.Logger) router.Factory { return mux.NewFactory(DefaultConfig(pf, logger)) } // DefaultConfig returns the struct that collects the parts the router should be builded from func DefaultConfig(pf proxy.Factory, logger logging.Logger) mux.Config { return mux.Config{ Engine: gorillaEngine{gorilla.NewRouter()}, Middlewares: []mux.HandlerMiddleware{}, HandlerFactory: 
mux.CustomEndpointHandler(mux.NewRequestBuilder(gorillaParamsExtractor)), ProxyFactory: pf, Logger: logger, DebugPattern: "/__debug/{params}", EchoPattern: "/__echo/{params}", RunServer: server.RunServer, } } func gorillaParamsExtractor(r *http.Request) map[string]string { params := map[string]string{} title := cases.Title(language.Und) for key, value := range gorilla.Vars(r) { params[title.String(key)] = value } return params } type gorillaEngine struct { r *gorilla.Router } // Handle implements the mux.Engine interface from the lura router package func (g gorillaEngine) Handle(pattern, method string, handler http.Handler) { g.r.Handle(pattern, handler).Methods(method) } // ServeHTTP implements the http:Handler interface from the stdlib func (g gorillaEngine) ServeHTTP(w http.ResponseWriter, r *http.Request) { g.r.ServeHTTP(mux.NewHTTPErrorInterceptor(w), r) } ================================================ FILE: router/gorilla/router_test.go ================================================ //go:build !race // +build !race // SPDX-License-Identifier: Apache-2.0 package gorilla import ( "bytes" "context" "fmt" "io" "net/http" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/transport/http/server" ) func TestDefaultFactory_ok(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(noopProxyFactory(map[string]interface{}{"supu": "tupu"}), logger).NewWithContext(ctx) expectedBody := "{\"supu\":\"tupu\"}" serviceCfg := config.ServiceConfig{ Port: 8082, Endpoints: []*config.EndpointConfig{ { Endpoint: "/get/{id}", Method: "GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { 
Endpoint: "/post", Method: "POST", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/put", Method: "PUT", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/patch", Method: "PATCH", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/delete", Method: "DELETE", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, endpoint := range serviceCfg.Endpoints { req, _ := http.NewRequest(endpoint.Method, fmt.Sprintf("http://127.0.0.1:8082%s", endpoint.Endpoint), http.NoBody) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error(endpoint.Endpoint, "Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderCompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != "application/json" { t.Error(endpoint.Endpoint, "Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "Version undefined" { t.Error(endpoint.Endpoint, "X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusOK { t.Error(endpoint.Endpoint, "Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error(endpoint.Endpoint, "Unexpected body:", content, "expected:", expectedBody) } } } func TestDefaultFactory_ko(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := 
context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(noopProxyFactory(map[string]interface{}{"supu": "tupu"}), logger).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Port: 8083, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: "GETTT", Backend: []*config.Backend{ {}, }, }, { Endpoint: "/empty", Method: "GETTT", Backend: []*config.Backend{}, }, { Endpoint: "/also-ignored", Method: "PUT", Backend: []*config.Backend{ {}, {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, subject := range [][]string{ {"GET", "ignored"}, {"GET", "empty"}, {"PUT", "also-ignored"}, } { req, _ := http.NewRequest(subject[0], fmt.Sprintf("http://127.0.0.1:8083/%s", subject[1]), http.NoBody) req.Header.Set("Content-Type", "application/json") checkResponseIs404(t, req) } } func TestDefaultFactory_proxyFactoryCrash(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(erroredProxyFactory{fmt.Errorf("%s", "crash!!!")}, logger).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Echo: true, Port: 8084, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: "GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, subject := range [][]string{{"GET", "ignored"}, {"PUT", "also-ignored"}} { req, _ := http.NewRequest(subject[0], fmt.Sprintf("http://127.0.0.1:8084/%s", subject[1]), http.NoBody) req.Header.Set("Content-Type", "application/json") checkResponseIs404(t, req) } } func checkResponseIs404(t *testing.T, req *http.Request) { expectedBody := "404 page not found\n" resp, err := 
http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error("Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderIncompleteResponseValue { t.Error(req.URL.String(), server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { t.Error("Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "" { t.Error("X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusNotFound { t.Error("Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error("Unexpected body:", content, "expected:", expectedBody) } } type noopProxyFactory map[string]interface{} func (n noopProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: true, Data: n, }, nil }, nil } type erroredProxyFactory struct { Error error } func (e erroredProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return proxy.NoopProxy, e.Error } type identityMiddleware struct{} func (identityMiddleware) Handler(h http.Handler) http.Handler { return h } ================================================ FILE: router/helper.go ================================================ // SPDX-License-Identifier: Apache-2.0 package router import ( "github.com/luraproject/lura/v2/config" ) func IsValidSequentialEndpoint(_ *config.EndpointConfig) bool { // if endpoint.ExtraConfig[proxy.Namespace] == nil { // return false // } // proxyCfg := endpoint.ExtraConfig[proxy.Namespace].(map[string]interface{}) 
// if proxyCfg["sequential"] == false { // return false // } // for i, backend := range endpoint.Backend { // if backend.Method != http.MethodGet && (i+1) != len(endpoint.Backend) { // return false // } // } return true } ================================================ FILE: router/helper_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package router // func TestIsValidSequentialEndpoint_ok(t *testing.T) { // endpoint := &config.EndpointConfig{ // Endpoint: "/correct", // Method: "PUT", // Backend: []*config.Backend{ // { // Method: "GET", // }, // { // Method: "PUT", // }, // }, // ExtraConfig: map[string]interface{}{ // proxy.Namespace: map[string]interface{}{ // "sequential": true, // }, // }, // } // success := IsValidSequentialEndpoint(endpoint) // if !success { // t.Error("Endpoint expected valid but receive invalid") // } // } // func TestIsValidSequentialEndpoint_wrong_config_not_given(t *testing.T) { // endpoint := &config.EndpointConfig{ // Endpoint: "/correct", // Method: "PUT", // Backend: []*config.Backend{ // { // Method: "GET", // }, // { // Method: "PUT", // }, // }, // ExtraConfig: map[string]interface{}{}, // } // success := IsValidSequentialEndpoint(endpoint) // if success { // t.Error("Endpoint expected invalid but receive valid") // } // } // func TestIsValidSequentialEndpoint_wrong_config_set_false(t *testing.T) { // endpoint := &config.EndpointConfig{ // Endpoint: "/correct", // Method: "PUT", // Backend: []*config.Backend{ // { // Method: "GET", // }, // { // Method: "PUT", // }, // }, // ExtraConfig: map[string]interface{}{ // proxy.Namespace: map[string]interface{}{ // "sequential": false, // }, // }} // success := IsValidSequentialEndpoint(endpoint) // if success { // t.Error("Endpoint expected invalid but receive valid") // } // } // func TestIsValidSequentialEndpoint_wrong_order(t *testing.T) { // endpoint := &config.EndpointConfig{ // Endpoint: "/correct", // Method: "PUT", // Backend: 
[]*config.Backend{ // { // Method: "PUT", // }, // { // Method: "GET", // }, // }, // ExtraConfig: map[string]interface{}{ // proxy.Namespace: map[string]interface{}{ // "sequential": true, // }, // }, // } // success := IsValidSequentialEndpoint(endpoint) // if success { // t.Error("Endpoint expected invalid but receive valid") // } // } // func TestIsValidSequentialEndpoint_wrong_all_non_get(t *testing.T) { // endpoint := &config.EndpointConfig{ // Endpoint: "/correct", // Method: "PUT", // Backend: []*config.Backend{ // { // Method: "POST", // }, // { // Method: "PUT", // }, // }, // ExtraConfig: map[string]interface{}{ // proxy.Namespace: map[string]interface{}{ // "sequential": true, // }, // }, // } // success := IsValidSequentialEndpoint(endpoint) // if success { // t.Error("Endpoint expected invalid but receive valid") // } // } ================================================ FILE: router/httptreemux/router.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package httptreemux provides some basic implementations for building routers based on dimfeld/httptreemux */ package httptreemux import ( "net/http" "github.com/dimfeld/httptreemux/v5" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/router" "github.com/luraproject/lura/v2/router/mux" "github.com/luraproject/lura/v2/transport/http/server" "golang.org/x/text/cases" "golang.org/x/text/language" ) // DefaultFactory returns a net/http mux router factory with the injected proxy factory and logger func DefaultFactory(pf proxy.Factory, logger logging.Logger) router.Factory { return mux.NewFactory(DefaultConfig(pf, logger)) } // DefaultConfig returns the struct that collects the parts the router should be built from func DefaultConfig(pf proxy.Factory, logger logging.Logger) mux.Config { return mux.Config{ Engine: NewEngine(httptreemux.NewContextMux()), Middlewares: []mux.HandlerMiddleware{}, 
HandlerFactory: mux.CustomEndpointHandler(mux.NewRequestBuilder(ParamsExtractor)), ProxyFactory: pf, Logger: logger, DebugPattern: "/__debug/{params}", RunServer: server.RunServer, } } func ParamsExtractor(r *http.Request) map[string]string { params := map[string]string{} title := cases.Title(language.Und) for key, value := range httptreemux.ContextParams(r.Context()) { params[title.String(key)] = value } return params } func NewEngine(m *httptreemux.ContextMux) Engine { return Engine{m} } type Engine struct { r *httptreemux.ContextMux } // Handle implements the mux.Engine interface from the lura router package func (g Engine) Handle(pattern, method string, handler http.Handler) { g.r.Handle(method, pattern, handler.(http.HandlerFunc)) } // ServeHTTP implements the http:Handler interface from the stdlib func (g Engine) ServeHTTP(w http.ResponseWriter, r *http.Request) { g.r.ServeHTTP(mux.NewHTTPErrorInterceptor(w), r) } ================================================ FILE: router/httptreemux/router_test.go ================================================ //go:build !race // +build !race // SPDX-License-Identifier: Apache-2.0 package httptreemux import ( "bytes" "context" "fmt" "io" "net/http" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/transport/http/server" ) func TestDefaultFactory_ok(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(noopProxyFactory(map[string]interface{}{"supu": "tupu"}), logger).NewWithContext(ctx) expectedBody := "{\"supu\":\"tupu\"}" serviceCfg := config.ServiceConfig{ Port: 8082, Endpoints: []*config.EndpointConfig{ { Endpoint: "/get/:id", Method: 
"GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/post", Method: "POST", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/put", Method: "PUT", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/patch", Method: "PATCH", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/delete", Method: "DELETE", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() <-time.After(5 * time.Millisecond) for _, endpoint := range serviceCfg.Endpoints { req, _ := http.NewRequest(endpoint.Method, fmt.Sprintf("http://127.0.0.1:8082%s", endpoint.Endpoint), http.NoBody) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error(endpoint.Endpoint, "Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderCompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != "application/json" { t.Error(endpoint.Endpoint, "Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "Version undefined" { t.Error(endpoint.Endpoint, "X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusOK { t.Error(endpoint.Endpoint, "Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error(endpoint.Endpoint, "Unexpected body:", content, "expected:", expectedBody) } } } func TestDefaultFactory_ko(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the 
logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(noopProxyFactory(map[string]interface{}{"supu": "tupu"}), logger).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Port: 8083, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: "GETTT", Backend: []*config.Backend{ {}, }, }, { Endpoint: "/empty", Method: "GETTT", Backend: []*config.Backend{}, }, { Endpoint: "/also-ignored", Method: "PUT", Backend: []*config.Backend{ {}, {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, subject := range [][]string{ {"GET", "ignored"}, {"GET", "empty"}, {"PUT", "also-ignored"}, } { req, _ := http.NewRequest(subject[0], fmt.Sprintf("http://127.0.0.1:8083/%s", subject[1]), http.NoBody) req.Header.Set("Content-Type", "application/json") checkResponseIs404(t, req) } } func TestDefaultFactory_proxyFactoryCrash(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(erroredProxyFactory{fmt.Errorf("%s", "crash!!!")}, logger).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Port: 8084, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: "GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, subject := range [][]string{{"GET", "ignored"}, {"PUT", "also-ignored"}} { req, _ := http.NewRequest(subject[0], fmt.Sprintf("http://127.0.0.1:8084/%s", subject[1]), http.NoBody) req.Header.Set("Content-Type", "application/json") checkResponseIs404(t, req) } } func checkResponseIs404(t *testing.T, req *http.Request) { expectedBody := 
"404 page not found\n" resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error("Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderIncompleteResponseValue { t.Error(req.URL.String(), server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { t.Error("Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "" { t.Error("X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusNotFound { t.Error("Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error("Unexpected body:", content, "expected:", expectedBody) } } type noopProxyFactory map[string]interface{} func (n noopProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: true, Data: n, }, nil }, nil } type erroredProxyFactory struct { Error error } func (e erroredProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return proxy.NoopProxy, e.Error } type identityMiddleware struct{} func (identityMiddleware) Handler(h http.Handler) http.Handler { return h } ================================================ FILE: router/mux/debug.go ================================================ // SPDX-License-Identifier: Apache-2.0 package mux import ( "encoding/json" "io" "net/http" "github.com/luraproject/lura/v2/logging" ) // DebugHandler creates a dummy handler function, useful for quick integration tests func DebugHandler(logger logging.Logger) http.HandlerFunc { 
logPrefixSecondary := "[ENDPOINT /__debug/*]" return func(w http.ResponseWriter, r *http.Request) { logger.Debug(logPrefixSecondary, "Method:", r.Method) logger.Debug(logPrefixSecondary, "URL:", r.RequestURI) logger.Debug(logPrefixSecondary, "Query:", r.URL.Query()) // logger.Debug(logPrefixSecondary, "Params:", c.Params) logger.Debug(logPrefixSecondary, "Headers:", r.Header) body, _ := io.ReadAll(r.Body) r.Body.Close() logger.Debug(logPrefixSecondary, "Body:", string(body)) js, _ := json.Marshal(map[string]string{"message": "pong"}) w.Header().Set("Content-Type", "application/json") w.Write(js) } } ================================================ FILE: router/mux/debug_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package mux import ( "bytes" "io" "net/http" "net/http/httptest" "testing" "github.com/luraproject/lura/v2/logging" ) func TestDebugHandler(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } handler := DebugHandler(logger) req, _ := http.NewRequest("GET", "http://127.0.0.1:8089/_mux_debug?b=1", io.NopCloser(&bytes.Buffer{})) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() handler.ServeHTTP(w, req) body, ioerr := io.ReadAll(w.Result().Body) if ioerr != nil { t.Error("reading a response:", err.Error()) return } w.Result().Body.Close() expectedBody := "{\"message\":\"pong\"}" content := string(body) if w.Result().Header.Get("Cache-Control") != "" { t.Error("Cache-Control error:", w.Result().Header.Get("Cache-Control")) } if w.Result().Header.Get("Content-Type") != "application/json" { t.Error("Content-Type error:", w.Result().Header.Get("Content-Type")) } if w.Result().Header.Get("X-Krakend") != "" { t.Error("X-Krakend error:", w.Result().Header.Get("X-Krakend")) } if w.Result().StatusCode != http.StatusOK { t.Error("Unexpected status code:", 
w.Result().StatusCode) } if content != expectedBody { t.Error("Unexpected body:", content, "expected:", expectedBody) } } ================================================ FILE: router/mux/echo.go ================================================ // SPDX-License-Identifier: Apache-2.0 package mux import ( "encoding/json" "io" "net/http" ) type echoResponse struct { Uri string `json:"req_uri"` UriDetails map[string]string `json:"req_uri_details"` Method string `json:"req_method"` Querystring map[string][]string `json:"req_querystring"` Body string `json:"req_body"` Headers map[string][]string `json:"req_headers"` } // EchoHandler creates a dummy handler function, useful for quick integration tests func EchoHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var body string if r.Body != nil { tmp, _ := io.ReadAll(r.Body) r.Body.Close() body = string(tmp) } resp, err := json.Marshal(echoResponse{ Uri: r.RequestURI, UriDetails: map[string]string{ "user": r.URL.User.String(), "host": r.Host, "path": r.URL.Path, "query": r.URL.Query().Encode(), "fragment": r.URL.Fragment, }, Method: r.Method, Querystring: r.URL.Query(), Body: body, Headers: r.Header, }) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") w.Write(resp) } } ================================================ FILE: router/mux/echo_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package mux import ( "io" "net/http" "net/http/httptest" "strings" "testing" ) func TestEchoHandlerNew(t *testing.T) { reqBody := `{"message":"some body to send"}` expectedRespBody := `{"req_uri":"http://127.0.0.1:8088/_gin_endpoint/a?b=1","req_uri_details":{"fragment":"","host":"127.0.0.1:8088","path":"/_gin_endpoint/a","query":"b=1","user":""},"req_method":"GET","req_querystring":{"b":["1"]},"req_body":"{\"message\":\"some body to 
send\"}","req_headers":{"Content-Type":["application/json"]}}` expectedRespNoBody := `{"req_uri":"http://127.0.0.1:8088/_gin_endpoint/a?b=1","req_uri_details":{"fragment":"","host":"127.0.0.1:8088","path":"/_gin_endpoint/a","query":"b=1","user":""},"req_method":"GET","req_querystring":{"b":["1"]},"req_body":"","req_headers":{"Content-Type":["application/json"]}}` expectedRespString := `{"req_uri":"http://127.0.0.1:8088/_gin_endpoint/a?b=1","req_uri_details":{"fragment":"","host":"127.0.0.1:8088","path":"/_gin_endpoint/a","query":"b=1","user":""},"req_method":"GET","req_querystring":{"b":["1"]},"req_body":"Hello lura","req_headers":{"Content-Type":["application/json"]}}` e := EchoHandler() for _, tc := range []struct { name string body io.Reader resp string }{ { name: "json body", body: strings.NewReader(reqBody), resp: expectedRespBody, }, { name: "no body", body: http.NoBody, resp: expectedRespNoBody, }, { name: "string body", body: strings.NewReader("Hello lura"), resp: expectedRespString, }, } { t.Run(tc.name, func(t *testing.T) { echoRunTestRequest(t, e, tc.body, tc.resp) }) } } func echoRunTestRequest(t *testing.T, e http.HandlerFunc, body io.Reader, expected string) { req := httptest.NewRequest("GET", "http://127.0.0.1:8088/_gin_endpoint/a?b=1", body) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() e.ServeHTTP(w, req) respBody, ioerr := io.ReadAll(w.Result().Body) if ioerr != nil { t.Error("reading a response:", ioerr.Error()) return } w.Result().Body.Close() content := string(respBody) if w.Result().Header.Get("Cache-Control") != "" { t.Error("Cache-Control error:", w.Result().Header.Get("Cache-Control")) } if w.Result().Header.Get("Content-Type") != "application/json" { t.Error("Content-Type error:", w.Result().Header.Get("Content-Type")) } if w.Result().Header.Get("X-Krakend") != "" { t.Error("X-Krakend error:", w.Result().Header.Get("X-Krakend")) } if w.Result().StatusCode != http.StatusOK { t.Error("Unexpected status 
code:", w.Result().StatusCode) } if content != expected { t.Error("Unexpected body:", content, "expected:", expected) } } ================================================ FILE: router/mux/endpoint.go ================================================ // SPDX-License-Identifier: Apache-2.0 package mux import ( "context" "fmt" "net" "net/http" "net/textproto" "strings" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/core" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/transport/http/server" ) const requestParamsAsterisk string = "*" // HandlerFactory creates a handler function that adapts the mux router with the injected proxy type HandlerFactory func(*config.EndpointConfig, proxy.Proxy) http.HandlerFunc // EndpointHandler is a HandlerFactory that adapts the mux router with the injected proxy // and the default RequestBuilder var EndpointHandler = CustomEndpointHandler(NewRequest) // CustomEndpointHandler returns a HandlerFactory with the received RequestBuilder using the default ToHTTPError function func CustomEndpointHandler(rb RequestBuilder) HandlerFactory { return CustomEndpointHandlerWithHTTPError(rb, server.DefaultToHTTPError) } // CustomEndpointHandlerWithHTTPError returns a HandlerFactory with the received RequestBuilder func CustomEndpointHandlerWithHTTPError(rb RequestBuilder, errF server.ToHTTPError) HandlerFactory { return func(configuration *config.EndpointConfig, prxy proxy.Proxy) http.HandlerFunc { cacheControlHeaderValue := fmt.Sprintf("public, max-age=%d", int(configuration.CacheTTL.Seconds())) isCacheEnabled := configuration.CacheTTL.Seconds() != 0 render := getRender(configuration) headersToSend := configuration.HeadersToPass if len(headersToSend) == 0 { headersToSend = server.HeadersToSend } method := strings.ToTitle(configuration.Method) return func(w http.ResponseWriter, r *http.Request) { w.Header().Set(core.KrakendHeaderName, core.KrakendHeaderValue) if r.Method != method { 
w.Header().Set(server.CompleteResponseHeaderName, server.HeaderIncompleteResponseValue) http.Error(w, "", http.StatusMethodNotAllowed) return } requestCtx, cancel := context.WithTimeout(r.Context(), configuration.Timeout) response, err := prxy(requestCtx, rb(r, configuration.QueryString, headersToSend)) select { case <-requestCtx.Done(): if err == nil { err = server.ErrInternalError } default: } if response != nil && len(response.Data) > 0 { if response.IsComplete { w.Header().Set(server.CompleteResponseHeaderName, server.HeaderCompleteResponseValue) if isCacheEnabled { w.Header().Set("Cache-Control", cacheControlHeaderValue) } } else { w.Header().Set(server.CompleteResponseHeaderName, server.HeaderIncompleteResponseValue) } for k, vs := range response.Metadata.Headers { for _, v := range vs { w.Header().Add(k, v) } } } else { w.Header().Set(server.CompleteResponseHeaderName, server.HeaderIncompleteResponseValue) if err != nil { if t, ok := err.(responseError); ok { http.Error(w, err.Error(), t.StatusCode()) } else { http.Error(w, err.Error(), errF(err)) } cancel() return } } render(w, response) cancel() } } } // RequestBuilder is a function that creates a proxy.Request from the received http request type RequestBuilder func(r *http.Request, queryString, headersToSend []string) *proxy.Request // ParamExtractor is a function that extracts query params from the requested uri type ParamExtractor func(r *http.Request) map[string]string // NoopParamExtractor is a No Op ParamExtractor (returns an empty map of params) func NoopParamExtractor(_ *http.Request) map[string]string { return map[string]string{} } // NewRequest is a RequestBuilder that creates a proxy request from the received http request without // processing the uri params var NewRequest = NewRequestBuilder(NoopParamExtractor) // NewRequestBuilder gets a RequestBuilder with the received ParamExtractor as a query param // extraction mechanism func NewRequestBuilder(paramExtractor ParamExtractor) RequestBuilder 
{ return func(r *http.Request, queryString, headersToSend []string) *proxy.Request { params := paramExtractor(r) headers := make(map[string][]string, 3+len(headersToSend)) for _, k := range headersToSend { if k == requestParamsAsterisk { headers = r.Header break } if h, ok := r.Header[textproto.CanonicalMIMEHeaderKey(k)]; ok { headers[k] = h } } headers["X-Forwarded-For"] = []string{clientIP(r)} headers["X-Forwarded-Host"] = []string{r.Host} // if User-Agent is not forwarded using headersToSend, we set // the KrakenD router User Agent value if _, ok := headers["User-Agent"]; !ok { headers["User-Agent"] = server.UserAgentHeaderValue } else { headers["X-Forwarded-Via"] = server.UserAgentHeaderValue } query := make(map[string][]string, len(queryString)) queryValues := r.URL.Query() for i := range queryString { if queryString[i] == requestParamsAsterisk { query = queryValues break } if v, ok := queryValues[queryString[i]]; ok && len(v) > 0 { query[queryString[i]] = v } } return &proxy.Request{ Path: r.URL.Path, Method: r.Method, Query: query, Body: r.Body, Params: params, Headers: headers, } } } type responseError interface { error StatusCode() int } // clientIP implements a best effort algorithm to return the real client IP, it parses // X-Real-IP and X-Forwarded-For in order to work properly with reverse-proxies such us: nginx or haproxy. // Use X-Forwarded-For before X-Real-Ip as nginx uses X-Real-Ip with the proxy's IP. 
func clientIP(r *http.Request) string { clientIP := r.Header.Get("X-Forwarded-For") clientIP = strings.TrimSpace(strings.Split(clientIP, ",")[0]) if clientIP == "" { clientIP = strings.TrimSpace(r.Header.Get("X-Real-Ip")) } if clientIP != "" { return clientIP } if addr := r.Header.Get("X-Appengine-Remote-Addr"); addr != "" { return addr } if ip, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)); err == nil { return ip } return "" } ================================================ FILE: router/mux/endpoint_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package mux import ( "bytes" "context" "fmt" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/proxy" ) func BenchmarkEndpointHandler_ko(b *testing.B) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, fmt.Errorf("This is %s", "a dummy error") } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } router := http.NewServeMux() router.Handle("/_gin_endpoint/", EndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() for i := 0; i < b.N; i++ { w := httptest.NewRecorder() router.ServeHTTP(w, req) } } func BenchmarkEndpointHandler_ok(b *testing.B) { pResp := proxy.Response{ Data: map[string]interface{}{}, Io: io.NopCloser(&bytes.Buffer{}), IsComplete: true, Metadata: proxy.Metadata{}, } p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &pResp, nil } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } router := http.NewServeMux() router.Handle("/_gin_endpoint/", EndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", 
"http://127.0.0.1:8080/_gin_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() for i := 0; i < b.N; i++ { w := httptest.NewRecorder() router.ServeHTTP(w, req) } } func BenchmarkEndpointHandler_ko_Parallel(b *testing.B) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, fmt.Errorf("This is %s", "a dummy error") } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } router := http.NewServeMux() router.Handle("/_gin_endpoint/", EndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { w := httptest.NewRecorder() router.ServeHTTP(w, req) } }) } func BenchmarkEndpointHandler_ok_Parallel(b *testing.B) { pResp := proxy.Response{ Data: map[string]interface{}{}, Io: io.NopCloser(&bytes.Buffer{}), IsComplete: true, Metadata: proxy.Metadata{}, } p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &pResp, nil } endpoint := &config.EndpointConfig{ Timeout: time.Second, CacheTTL: 6 * time.Hour, QueryString: []string{"b"}, } router := http.NewServeMux() router.Handle("/_gin_endpoint/", EndpointHandler(endpoint, p)) req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_gin_endpoint/a?b=1", http.NoBody) req.Header.Set("Content-Type", "application/json") b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { for pb.Next() { w := httptest.NewRecorder() router.ServeHTTP(w, req) } }) } ================================================ FILE: router/mux/endpoint_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package mux import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httptest" "testing" "time" "github.com/luraproject/lura/v2/config" 
"github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/transport/http/server" ) func TestEndpointHandler_ok(t *testing.T) { p := func(_ context.Context, req *proxy.Request) (*proxy.Response, error) { data, _ := json.Marshal(req.Query) if string(data) != `{"b":["1"],"c[]":["x","y"],"d":["1","2"]}` { t.Errorf("unexpected querystring: %s", data) } return &proxy.Response{ IsComplete: true, Data: map[string]interface{}{"supu": "tupu"}, }, nil } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "{\"supu\":\"tupu\"}", expectedCache: "public, max-age=21600", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: true, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_okAllParams(t *testing.T) { p := func(_ context.Context, req *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: true, Data: map[string]interface{}{"query": req.Query, "headers": req.Headers, "params": req.Params}, Metadata: proxy.Metadata{ Headers: map[string][]string{"X-YZ": {"something"}}, StatusCode: 200, }, }, nil } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: `{"headers":{"Content-Type":["application/json"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-For":[""],"X-Forwarded-Host":["127.0.0.1:8081"]},"params":{},"query":{"a":["42"],"b":["1"],"c[]":["x","y"],"d":["1","2"]}}`, expectedCache: "public, max-age=21600", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: true, queryString: []string{"*"}, headers: []string{"*"}, expectedHeaders: map[string][]string{"X-YZ": {"something"}}, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_incomplete(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: false, Data: map[string]interface{}{"foo": "bar"}, }, nil } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", 
expectedBody: "{\"foo\":\"bar\"}", expectedCache: "", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_ko(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, fmt.Errorf("This is %s", "a dummy error") } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "This is a dummy error\n", expectedCache: "", expectedContent: "text/plain; charset=utf-8", expectedStatusCode: http.StatusInternalServerError, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_incompleteAndErrored(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: false, Data: map[string]interface{}{"foo": "bar"}, }, errors.New("This is a dummy error") } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "{\"foo\":\"bar\"}", expectedCache: "", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_cancel(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { time.Sleep(100 * time.Millisecond) return &proxy.Response{ IsComplete: false, Data: map[string]interface{}{"foo": "bar"}, }, nil } endpointHandlerTestCase{ timeout: 0, proxy: p, method: "GET", expectedBody: "{\"foo\":\"bar\"}", expectedCache: "", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_cancelEmpty(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { time.Sleep(100 * time.Millisecond) return nil, nil } endpointHandlerTestCase{ timeout: 0, proxy: p, method: "GET", expectedBody: server.ErrInternalError.Error() + "\n", expectedCache: "", expectedContent: 
"text/plain; charset=utf-8", expectedStatusCode: http.StatusInternalServerError, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_noop(t *testing.T) { endpointHandlerTestCase{ timeout: time.Minute, proxy: proxy.NoopProxy, method: "GET", expectedBody: "{}", expectedCache: "", expectedContent: "application/json", expectedStatusCode: http.StatusOK, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_badMethod(t *testing.T) { endpointHandlerTestCase{ timeout: 10, proxy: proxy.NoopProxy, method: "PUT", expectedBody: "\n", expectedCache: "", expectedContent: "text/plain; charset=utf-8", expectedStatusCode: http.StatusMethodNotAllowed, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } func TestEndpointHandler_errored_responseError(t *testing.T) { p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return nil, dummyResponseError{err: "this is a dummy error", status: http.StatusTeapot} } endpointHandlerTestCase{ timeout: 10, proxy: p, method: "GET", expectedBody: "this is a dummy error\n", expectedCache: "", expectedContent: "text/plain; charset=utf-8", expectedStatusCode: http.StatusTeapot, completed: false, }.test(t) time.Sleep(5 * time.Millisecond) } type dummyResponseError struct { err string status int } func (d dummyResponseError) Error() string { return d.err } func (d dummyResponseError) StatusCode() int { return d.status } type endpointHandlerTestCase struct { timeout time.Duration proxy proxy.Proxy method string expectedBody string expectedCache string expectedContent string expectedHeaders map[string][]string expectedStatusCode int completed bool queryString []string headers []string } func (tc endpointHandlerTestCase) test(t *testing.T) { endpoint := &config.EndpointConfig{ Method: "GET", Timeout: tc.timeout, CacheTTL: 6 * time.Hour, QueryString: []string{"b", "c[]", "d"}, } if len(tc.queryString) > 0 { endpoint.QueryString = tc.queryString } if 
len(tc.headers) > 0 { endpoint.HeadersToPass = tc.headers } s := startMuxServer(EndpointHandler(endpoint, tc.proxy)) req, _ := http.NewRequest(tc.method, "http://127.0.0.1:8081/_mux_endpoint?b=1&c[]=x&c[]=y&d=1&d=2&a=42", io.NopCloser(&bytes.Buffer{})) req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() s.ServeHTTP(w, req) body, ioerr := io.ReadAll(w.Result().Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } w.Result().Body.Close() content := string(body) resp := w.Result() if resp.Header.Get("Cache-Control") != tc.expectedCache { t.Error("Cache-Control error:", resp.Header.Get("Cache-Control")) } if tc.completed && resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderCompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if !tc.completed && resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderIncompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != tc.expectedContent { t.Error("Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "Version undefined" { t.Error("X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != tc.expectedStatusCode { t.Error("Unexpected status code:", resp.StatusCode) } if content != tc.expectedBody { t.Error("Unexpected body:", content, "expected:", tc.expectedBody) } for k, v := range tc.expectedHeaders { if header := resp.Header.Get(k); v[0] != header { t.Error("Unexpected value for header:", k, header, "expected:", v[0]) } } } func startMuxServer(handlerFunc http.HandlerFunc) *http.ServeMux { router := http.NewServeMux() router.Handle("/_mux_endpoint", handlerFunc) return router } ================================================ FILE: router/mux/engine.go ================================================ // 
SPDX-License-Identifier: Apache-2.0 package mux import ( "net/http" "sync" "github.com/luraproject/lura/v2/transport/http/server" ) // Engine defines the minimun required interface for the mux compatible engine type Engine interface { http.Handler Handle(pattern, method string, handler http.Handler) } // BasicEngine is a slightly customized http.ServeMux router type BasicEngine struct { handler *http.ServeMux dict map[string]map[string]http.HandlerFunc } // NewHTTPErrorInterceptor returns a HTTPErrorInterceptor over the injected response writer func NewHTTPErrorInterceptor(w http.ResponseWriter) *HTTPErrorInterceptor { return &HTTPErrorInterceptor{w, new(sync.Once)} } // HTTPErrorInterceptor is a reposnse writer that adds a header signaling incomplete response in case of // seeing a status code not equal to 200 type HTTPErrorInterceptor struct { http.ResponseWriter once *sync.Once } // WriteHeader records the status code and adds a header signaling incomplete responses func (i *HTTPErrorInterceptor) WriteHeader(code int) { i.once.Do(func() { if code != http.StatusOK { i.ResponseWriter.Header().Set(server.CompleteResponseHeaderName, server.HeaderIncompleteResponseValue) } }) i.ResponseWriter.WriteHeader(code) } // DefaultEngine returns a new engine using BasicEngine func DefaultEngine() *BasicEngine { return &BasicEngine{ handler: http.NewServeMux(), dict: map[string]map[string]http.HandlerFunc{}, } } // Handle registers a handler at a given url pattern and http method func (e *BasicEngine) Handle(pattern, method string, handler http.Handler) { if _, ok := e.dict[pattern]; !ok { e.dict[pattern] = map[string]http.HandlerFunc{} e.handler.Handle(pattern, e.registrableHandler(pattern)) } e.dict[pattern][method] = handler.ServeHTTP } // ServeHTTP adds a error interceptor and delegates the request dispatching to the // internal request multiplexer. 
func (e *BasicEngine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// each request gets its own interceptor, so a non-200 status is flagged as incomplete
	e.handler.ServeHTTP(NewHTTPErrorInterceptor(w), r)
}

// registrableHandler returns the handler mounted on the internal ServeMux for pattern. It
// dispatches by request method to the handlers registered through Handle and answers
// 405 Method Not Allowed, flagged as incomplete, when no handler matches the method.
func (e *BasicEngine) registrableHandler(pattern string) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		if handler, ok := e.dict[pattern][req.Method]; ok {
			handler(rw, req)
			return
		}

		rw.Header().Set(server.CompleteResponseHeaderName, server.HeaderIncompleteResponseValue)
		http.Error(rw, "", http.StatusMethodNotAllowed)
	})
}

================================================
FILE: router/mux/engine_test.go
================================================
// SPDX-License-Identifier: Apache-2.0

package mux

import (
	"bytes"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestEngine registers a teapot handler for PUT, POST and DELETE on "/" and checks that those
// methods reach it while an unregistered method (GET) is answered with 405.
func TestEngine(t *testing.T) {
	e := DefaultEngine()

	for _, method := range []string{"PUT", "POST", "DELETE"} {
		e.Handle("/", method, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			http.Error(rw, "hi there!", http.StatusTeapot)
		}))
	}

	for _, tc := range []struct {
		method string
		status int
	}{
		{status: http.StatusTeapot, method: "PUT"},
		{status: http.StatusTeapot, method: "POST"},
		{status: http.StatusTeapot, method: "DELETE"},
		{status: http.StatusMethodNotAllowed, method: "GET"},
	} {
		req, _ := http.NewRequest(tc.method, "http://127.0.0.1:8081/_mux_endpoint?b=1&c[]=x&c[]=y&d=1&d=2&a=42", io.NopCloser(&bytes.Buffer{}))
		w := httptest.NewRecorder()
		e.ServeHTTP(w, req)
		if sc := w.Result().StatusCode; tc.status != sc {
			t.Error("unexpected status code:", sc)
		}
	}
}

================================================
FILE: router/mux/render.go
================================================
// SPDX-License-Identifier: Apache-2.0

package mux

import (
	"encoding/json"
	"io"
	"net/http"
	"sync"

	"github.com/luraproject/lura/v2/config"
	"github.com/luraproject/lura/v2/encoding"
	"github.com/luraproject/lura/v2/proxy"
)

// Render defines the signature of the functions to be used for the final response
// encoding and rendering
type Render func(http.ResponseWriter, *proxy.Response)

// NEGOTIATE defines the value of the OutputEncoding for the negotiated render
const NEGOTIATE = "negotiate"

var (
	// mutex guards renderRegister
	mutex = &sync.RWMutex{}
	// renderRegister maps an encoding name to its Render
	renderRegister = map[string]Render{
		encoding.STRING:   stringRender,
		encoding.JSON:     jsonRender,
		encoding.NOOP:     noopRender,
		"json-collection": jsonCollectionRender,
	}
)

// RegisterRender allows clients to register their custom renders
func RegisterRender(name string, r Render) {
	mutex.Lock()
	renderRegister[name] = r
	mutex.Unlock()
}

// getRender picks the render for the endpoint: the one registered for its OutputEncoding;
// otherwise, when the endpoint has exactly one backend, the one registered for that backend's
// encoding; otherwise jsonRender.
func getRender(cfg *config.EndpointConfig) Render {
	fallback := jsonRender
	if len(cfg.Backend) == 1 {
		fallback = getWithFallback(cfg.Backend[0].Encoding, fallback)
	}

	if cfg.OutputEncoding == "" {
		return fallback
	}

	return getWithFallback(cfg.OutputEncoding, fallback)
}

// getWithFallback returns the render registered under key, or fallback if there is none
func getWithFallback(key string, fallback Render) Render {
	mutex.RLock()
	r, ok := renderRegister[key]
	mutex.RUnlock()
	if !ok {
		return fallback
	}
	return r
}

var (
	emptyResponse   = []byte("{}")
	emptyCollection = []byte("[]")
)

// jsonRender writes the response data as JSON, or {} for a nil response
func jsonRender(w http.ResponseWriter, response *proxy.Response) {
	w.Header().Set("Content-Type", "application/json")
	if response == nil {
		w.Write(emptyResponse)
		return
	}

	js, err := json.Marshal(response.Data)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Write(js)
}

// jsonCollectionRender writes the "collection" entry of the response data as JSON, or []
// when the response is nil or has no such entry
func jsonCollectionRender(w http.ResponseWriter, response *proxy.Response) {
	w.Header().Set("Content-Type", "application/json")
	if response == nil {
		w.Write(emptyCollection)
		return
	}
	col, ok := response.Data["collection"]
	if !ok {
		w.Write(emptyCollection)
		return
	}

	js, err := json.Marshal(col)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Write(js)
}

// stringRender writes the "content" entry of the response data as text/plain when it is a
// string, and an empty body otherwise
func stringRender(w http.ResponseWriter, response *proxy.Response) {
	w.Header().Set("Content-Type", "text/plain")
	if response == nil {
		w.Write([]byte{})
		return
	}
	d, ok := response.Data["content"]
	if !ok {
		w.Write([]byte{})
		return
	}
	msg, ok := d.(string)
	if !ok {
		w.Write([]byte{})
		return
	}
	w.Write([]byte(msg))
}

// noopRender copies the response metadata headers and status code (when not zero) and then
// streams response.Io unchanged; a nil response becomes an empty 500
func noopRender(w http.ResponseWriter, response *proxy.Response) {
	if response == nil {
		http.Error(w, "", http.StatusInternalServerError)
		return
	}

	for k, vs := range response.Metadata.Headers {
		for _, v := range vs {
			w.Header().Add(k, v)
		}
	}
	if response.Metadata.StatusCode != 0 {
		w.WriteHeader(response.Metadata.StatusCode)
	}

	if response.Io == nil {
		return
	}
	io.Copy(w, response.Io)
}

================================================
FILE: router/mux/render_test.go
================================================
// SPDX-License-Identifier: Apache-2.0

package mux

import (
	"bytes"
	"context"
	"io"
	"net/http"
	"net/http/httptest"
	"reflect"
	"testing"
	"time"

	"github.com/luraproject/lura/v2/config"
	"github.com/luraproject/lura/v2/encoding"
	"github.com/luraproject/lura/v2/proxy"
)

// TestRender_unknown checks that an unregistered OutputEncoding falls back to JSON whatever the Accept header.
func TestRender_unknown(t *testing.T) {
	p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
		return &proxy.Response{
			IsComplete: true,
			Data:       map[string]interface{}{"supu": "tupu"},
		}, nil
	}
	endpoint := &config.EndpointConfig{
		Timeout:        time.Second,
		CacheTTL:       6 * time.Hour,
		QueryString:    []string{"b"},
		OutputEncoding: "unknown",
		Method:         "GET",
	}

	router := http.NewServeMux()
	router.Handle("/_mux_endpoint", EndpointHandler(endpoint, p))

	expectedHeader := "application/json"
	expectedBody := `{"supu":"tupu"}`

	for _, testData := range [][]string{
		{"plain", "text/plain"},
		{"none", ""},
		{"json", "application/json"},
		{"unknown", "unknown"},
	} {
		req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_mux_endpoint?b=1", io.NopCloser(&bytes.Buffer{}))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", testData[1])

		w := httptest.NewRecorder()
		router.ServeHTTP(w, req)

		defer w.Result().Body.Close()

		body, ioerr := io.ReadAll(w.Result().Body)
		if ioerr != nil {
			t.Error("reading response body:", ioerr)
			return
		}

		content := string(body)
		if w.Result().Header.Get("Cache-Control") != "public, max-age=21600" {
			t.Error(testData[0], "Cache-Control error:", w.Result().Header.Get("Cache-Control"))
		}
		if w.Result().Header.Get("Content-Type") != expectedHeader {
			t.Error(testData[0], "Content-Type error:", w.Result().Header.Get("Content-Type"))
		}
		if w.Result().Header.Get("X-Krakend") != "Version undefined" {
			t.Error(testData[0], "X-Krakend error:", w.Result().Header.Get("X-Krakend"))
		}
		if w.Result().StatusCode != http.StatusOK {
			t.Error(testData[0], "Unexpected status code:", w.Result().StatusCode)
		}
		if content != expectedBody {
			t.Error(testData[0], "Unexpected body:", content, "expected:", expectedBody)
		}
	}
}

// TestRender_string checks the string render whatever the Accept header.
func TestRender_string(t *testing.T) {
	expectedContent := "supu"
	expectedHeader := "text/plain"

	p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
		return &proxy.Response{
			IsComplete: true,
			Data:       map[string]interface{}{"content": expectedContent},
		}, nil
	}
	endpoint := &config.EndpointConfig{
		Timeout:        time.Second,
		CacheTTL:       6 * time.Hour,
		QueryString:    []string{"b"},
		OutputEncoding: encoding.STRING,
		Method:         "GET",
	}

	router := http.NewServeMux()
	router.Handle("/_mux_endpoint", EndpointHandler(endpoint, p))

	for _, testData := range [][]string{
		{"plain", "text/plain"},
		{"none", ""},
		{"json", "application/json"},
		{"unknown", "unknown"},
	} {
		req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_mux_endpoint?b=1", io.NopCloser(&bytes.Buffer{}))
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Accept", testData[1])

		w := httptest.NewRecorder()
		router.ServeHTTP(w, req)

		defer w.Result().Body.Close()

		body, ioerr := io.ReadAll(w.Result().Body)
		if ioerr != nil {
			t.Error("reading response body:", ioerr)
			return
		}

		content := string(body)
		if w.Result().Header.Get("Cache-Control") != "public, max-age=21600" {
			t.Error(testData[0], "Cache-Control error:", w.Result().Header.Get("Cache-Control"))
		}
		if w.Result().Header.Get("Content-Type") != expectedHeader {
			t.Error(testData[0], "Content-Type error:", w.Result().Header.Get("Content-Type"))
		}
		if w.Result().Header.Get("X-Krakend") != "Version undefined" {
			t.Error(testData[0], "X-Krakend error:", w.Result().Header.Get("X-Krakend"))
		}
		if w.Result().StatusCode != http.StatusOK {
			t.Error(testData[0], "Unexpected status code:", w.Result().StatusCode)
		}
		if content != expectedContent {
			t.Error(testData[0], "Unexpected body:", content, "expected:", expectedContent)
		}
	}
}

// TestRender_string_noData checks that non-string content, missing content and nil responses
// render an empty text/plain body.
func TestRender_string_noData(t *testing.T) {
	expectedContent := ""
	expectedHeader := "text/plain"

	for k, p := range []proxy.Proxy{
		func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
			return &proxy.Response{
				IsComplete: false,
				Data:       map[string]interface{}{"content": 42},
			}, nil
		},
		func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
			return &proxy.Response{
				IsComplete: false,
				Data:       map[string]interface{}{},
			}, nil
		},
		func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
			return nil, nil
		},
	} {
		endpoint := &config.EndpointConfig{
			Timeout:        time.Second,
			CacheTTL:       6 * time.Hour,
			QueryString:    []string{"b"},
			OutputEncoding: encoding.STRING,
			Method:         "GET",
		}

		router := http.NewServeMux()
		router.Handle("/_mux_endpoint", EndpointHandler(endpoint, p))

		req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_mux_endpoint?b=1", io.NopCloser(&bytes.Buffer{}))
		req.Header.Set("Content-Type", "application/json")

		w := httptest.NewRecorder()
		router.ServeHTTP(w, req)

		defer w.Result().Body.Close()

		body, ioerr := io.ReadAll(w.Result().Body)
		if ioerr != nil {
			t.Error("reading response body:", ioerr)
			return
		}

		content := string(body)
		if w.Result().Header.Get("Content-Type") != expectedHeader {
			t.Error(k, "Content-Type error:", w.Result().Header.Get("Content-Type"))
		}
		if w.Result().Header.Get("X-Krakend") != "Version undefined" {
			t.Error(k, "X-Krakend error:", w.Result().Header.Get("X-Krakend"))
		}
		if w.Result().StatusCode != http.StatusOK {
			t.Error(k, "Unexpected status code:", w.Result().StatusCode)
		}
		if content != expectedContent {
			t.Error(k, "Unexpected body:", content, "expected:", expectedContent)
		}
	}
}

// TestRegisterRender checks that a registered render is selected by its OutputEncoding name.
func TestRegisterRender(t *testing.T) {
	var total int
	expected := &proxy.Response{IsComplete: true, Data: map[string]interface{}{"a": "b"}}
	name := "test render"

	RegisterRender(name, func(_ http.ResponseWriter, resp *proxy.Response) {
		*resp = *expected
		total++
	})

	subject := getRender(&config.EndpointConfig{OutputEncoding: name})

	w := httptest.NewRecorder()
	resp := proxy.Response{}
	subject(w, &resp)

	if !reflect.DeepEqual(resp, *expected) {
		t.Error("unexpected response", resp)
	}

	if total != 1 {
		t.Error("the render was called an unexpected amount of times:", total)
	}
}

// TestRender_noop checks that the noop render forwards headers (including repeated Set-Cookie),
// status code and body unchanged.
func TestRender_noop(t *testing.T) {
	expectedContent := "supu"
	expectedHeader := "text/plain; charset=utf-8"
	expectedSetCookieValue := []string{"test1=test1", "test2=test2"}

	p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
		return &proxy.Response{
			Metadata: proxy.Metadata{
				StatusCode: 200,
				Headers: map[string][]string{
					"Content-Type": {expectedHeader},
					"Set-Cookie":   {"test1=test1", "test2=test2"},
				},
			},
			Io: bytes.NewBufferString(expectedContent),
		}, nil
	}
	endpoint := &config.EndpointConfig{
		Method:         "GET",
		Timeout:        time.Second,
		CacheTTL:       6 * time.Hour,
		QueryString:    []string{"b"},
		OutputEncoding: encoding.NOOP,
	}

	router := http.NewServeMux()
	router.Handle("/_mux_endpoint", EndpointHandler(endpoint, p))

	req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_mux_endpoint?b=1", io.NopCloser(&bytes.Buffer{}))
	req.Header.Set("Content-Type", "application/json")

	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	defer w.Result().Body.Close()

	body, ioerr := io.ReadAll(w.Result().Body)
	if ioerr != nil {
		t.Error("reading response body:", ioerr)
		return
	}

	content := string(body)
	if w.Result().Header.Get("Content-Type") != expectedHeader {
		t.Error("Content-Type error:", w.Result().Header.Get("Content-Type"))
	}
	if w.Result().Header.Get("X-Krakend") != "Version undefined" {
		t.Error("X-Krakend error:", w.Result().Header.Get("X-Krakend"))
	}
	if w.Result().StatusCode != http.StatusOK {
		t.Error("Unexpected status code:", w.Result().StatusCode)
	}
	if content != expectedContent {
		t.Error("Unexpected body:", content, "expected:", expectedContent)
	}
	gotCookie := w.Header()["Set-Cookie"]
	if !reflect.DeepEqual(gotCookie, expectedSetCookieValue) {
		t.Error("Unexpected Set-Cookie header:", gotCookie, "expected:", expectedSetCookieValue)
	}
}

// TestRender_noop_nilBody checks that a response without Io produces an empty 200.
func TestRender_noop_nilBody(t *testing.T) {
	expectedContent := ""
	expectedHeader := ""

	p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
		return &proxy.Response{IsComplete: true}, nil
	}
	endpoint := &config.EndpointConfig{
		Method:         "GET",
		Timeout:        time.Second,
		CacheTTL:       6 * time.Hour,
		QueryString:    []string{"b"},
		OutputEncoding: encoding.NOOP,
	}

	router := http.NewServeMux()
	router.Handle("/_mux_endpoint", EndpointHandler(endpoint, p))

	req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_mux_endpoint?b=1", io.NopCloser(&bytes.Buffer{}))
	req.Header.Set("Content-Type", "application/json")

	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	defer w.Result().Body.Close()

	body, ioerr := io.ReadAll(w.Result().Body)
	if ioerr != nil {
		t.Error("reading response body:", ioerr)
		return
	}

	content := string(body)
	if w.Result().Header.Get("Content-Type") != expectedHeader {
		t.Error("Content-Type error:", w.Result().Header.Get("Content-Type"))
	}
	if w.Result().Header.Get("X-Krakend") != "Version undefined" {
		t.Error("X-Krakend error:", w.Result().Header.Get("X-Krakend"))
	}
	if w.Result().StatusCode != http.StatusOK {
		t.Error("Unexpected status code:", w.Result().StatusCode)
	}
	if content != expectedContent {
		t.Error("Unexpected body:", content, "expected:", expectedContent)
	}
}

// TestRender_noop_nilResponse checks that a nil response with the noop render becomes a 500.
func TestRender_noop_nilResponse(t *testing.T) {
	p := func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) {
		return nil, nil
	}
	endpoint := &config.EndpointConfig{
		Method:         "GET",
		Timeout:        time.Second,
		CacheTTL:       6 * time.Hour,
		QueryString:    []string{"b"},
		OutputEncoding: encoding.NOOP,
	}

	router := http.NewServeMux()
	router.Handle("/_mux_endpoint", EndpointHandler(endpoint, p))

	req, _ := http.NewRequest("GET", "http://127.0.0.1:8080/_mux_endpoint?b=1", io.NopCloser(&bytes.Buffer{}))
	req.Header.Set("Content-Type", "application/json")

	w := httptest.NewRecorder()
	router.ServeHTTP(w, req)

	if w.Result().Header.Get("Content-Type") != "text/plain; charset=utf-8" {
		t.Error("Content-Type error:", w.Result().Header.Get("Content-Type"))
	}
	if w.Result().Header.Get("X-Krakend") != "Version undefined" {
		t.Error("X-Krakend error:", w.Result().Header.Get("X-Krakend"))
	}
	if w.Result().StatusCode != http.StatusInternalServerError {
		t.Error("Unexpected status code:", w.Result().StatusCode)
	}
}

================================================
FILE: router/mux/router.go
================================================
// SPDX-License-Identifier: Apache-2.0

/*
Package mux provides some basic implementations for building routers based on net/http mux
*/
package mux

import (
	"context"
	"net/http"
	"strings"

	"github.com/luraproject/lura/v2/config"
	"github.com/luraproject/lura/v2/logging"
	"github.com/luraproject/lura/v2/proxy"
	"github.com/luraproject/lura/v2/router"
	"github.com/luraproject/lura/v2/transport/http/server"
)

// DefaultDebugPattern is the default pattern used to define the debug endpoint
const (
	DefaultDebugPattern = "/__debug/"
	// DefaultEchoPattern is the default pattern used to define the echo endpoint
	DefaultEchoPattern = "/__echo/"
	logPrefix          = "[SERVICE: Mux]"
)

// RunServerFunc is a func that will run the http Server with the given params.
type RunServerFunc func(context.Context, config.ServiceConfig, http.Handler) error // Config is the struct that collects the parts the router should be builded from type Config struct { Engine Engine Middlewares []HandlerMiddleware HandlerFactory HandlerFactory ProxyFactory proxy.Factory Logger logging.Logger DebugPattern string EchoPattern string RunServer RunServerFunc } // HandlerMiddleware is the interface for the decorators over the http.Handler type HandlerMiddleware interface { Handler(h http.Handler) http.Handler } // DefaultFactory returns a net/http mux router factory with the injected proxy factory and logger func DefaultFactory(pf proxy.Factory, logger logging.Logger) router.Factory { return factory{ Config{ Engine: DefaultEngine(), Middlewares: []HandlerMiddleware{}, HandlerFactory: EndpointHandler, ProxyFactory: pf, Logger: logger, DebugPattern: DefaultDebugPattern, EchoPattern: DefaultEchoPattern, RunServer: server.RunServer, }, } } // NewFactory returns a net/http mux router factory with the injected configuration func NewFactory(cfg Config) router.Factory { if cfg.DebugPattern == "" { cfg.DebugPattern = DefaultDebugPattern } return factory{cfg} } type factory struct { cfg Config } // New implements the factory interface func (rf factory) New() router.Router { return rf.NewWithContext(context.Background()) } // NewWithContext implements the factory interface func (rf factory) NewWithContext(ctx context.Context) router.Router { return httpRouter{rf.cfg, ctx, rf.cfg.RunServer} } type httpRouter struct { cfg Config ctx context.Context RunServer RunServerFunc } // HealthHandler is a dummy http.HandlerFunc implementation for exposing a health check endpoint func HealthHandler(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write([]byte(`{"status":"ok"}`)) } // Run implements the router interface func (r httpRouter) Run(cfg config.ServiceConfig) { if cfg.Debug { debugHandler := DebugHandler(r.cfg.Logger) for 
_, method := range []string{ http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, http.MethodHead, http.MethodOptions, http.MethodConnect, http.MethodTrace, } { r.cfg.Engine.Handle(r.cfg.DebugPattern, method, debugHandler) } } if cfg.Echo { echoHandler := EchoHandler() for _, method := range []string{ http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, http.MethodHead, http.MethodOptions, http.MethodConnect, http.MethodTrace, } { r.cfg.Engine.Handle(r.cfg.EchoPattern, method, echoHandler) } } r.cfg.Engine.Handle("/__health", "GET", http.HandlerFunc(HealthHandler)) server.InitHTTPDefaultTransport(cfg) r.registerKrakendEndpoints(cfg.Endpoints) if err := r.RunServer(r.ctx, cfg, r.handler()); err != nil { r.cfg.Logger.Error(logPrefix, err.Error()) } r.cfg.Logger.Info(logPrefix, "Router execution ended") } func (r httpRouter) registerKrakendEndpoints(endpoints []*config.EndpointConfig) { for _, c := range endpoints { proxyStack, err := r.cfg.ProxyFactory.New(c) if err != nil { r.cfg.Logger.Error(logPrefix, "Calling the ProxyFactory", err.Error()) continue } r.registerKrakendEndpoint(c.Method, c, r.cfg.HandlerFactory(c, proxyStack), len(c.Backend)) } } func (r httpRouter) registerKrakendEndpoint(method string, endpoint *config.EndpointConfig, handler http.HandlerFunc, totBackends int) { method = strings.ToTitle(method) path := endpoint.Endpoint if method != http.MethodGet && totBackends > 1 { if !router.IsValidSequentialEndpoint(endpoint) { r.cfg.Logger.Error(logPrefix, method, " endpoints with sequential proxy enabled only allow a non-GET in the last backend! 
Ignoring", path) return } } switch method { case http.MethodGet: case http.MethodPost: case http.MethodPut: case http.MethodPatch: case http.MethodDelete: default: r.cfg.Logger.Error(logPrefix, "Unsupported method", method) return } r.cfg.Logger.Debug(logPrefix, "Registering the endpoint", method, path) r.cfg.Engine.Handle(path, method, handler) } func (r httpRouter) handler() http.Handler { var handler http.Handler = r.cfg.Engine for _, middleware := range r.cfg.Middlewares { r.cfg.Logger.Debug(logPrefix, "Adding the middleware", middleware) handler = middleware.Handler(handler) } return handler } ================================================ FILE: router/mux/router_test.go ================================================ //go:build !race // +build !race // SPDX-License-Identifier: Apache-2.0 package mux import ( "bytes" "context" "errors" "fmt" "io" "net/http" "regexp" "strings" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/transport/http/server" ) func TestDefaultFactory_ok(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(noopProxyFactory(map[string]interface{}{"supu": "tupu"}), logger).NewWithContext(ctx) expectedBody := "{\"supu\":\"tupu\"}" serviceCfg := config.ServiceConfig{ Port: 8062, Endpoints: []*config.EndpointConfig{ { Endpoint: "/get", Method: "GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/get", Method: "POST", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/post", Method: "Post", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/put", Method: "put", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { 
Endpoint: "/patch", Method: "PATCH", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/delete", Method: "DELETE", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, endpoint := range serviceCfg.Endpoints { req, _ := http.NewRequest(strings.ToTitle(endpoint.Method), fmt.Sprintf("http://127.0.0.1:8062%s", endpoint.Endpoint), http.NoBody) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error(endpoint.Endpoint, "Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderCompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != "application/json" { t.Error(endpoint.Endpoint, "Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "Version undefined" { t.Error(endpoint.Endpoint, "X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusOK { t.Error(endpoint.Endpoint, "Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error(endpoint.Endpoint, "Unexpected body:", content, "expected:", expectedBody) } } } func TestDefaultFactory_ko(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := NewFactory(Config{ Engine: DefaultEngine(), Middlewares: 
[]HandlerMiddleware{identityMiddleware{}}, HandlerFactory: EndpointHandler, ProxyFactory: noopProxyFactory(map[string]interface{}{"supu": "tupu"}), Logger: logger, RunServer: server.RunServer, }).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Port: 8063, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: "GETTT", Backend: []*config.Backend{ {}, }, }, { Endpoint: "/empty", Method: "GETTT", Backend: []*config.Backend{}, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, subject := range [][]string{ {"GET", "ignored"}, {"GET", "empty"}, {"PUT", "also-ignored"}, } { req, _ := http.NewRequest(subject[0], fmt.Sprintf("http://127.0.0.1:8063/%s", subject[1]), http.NoBody) req.Header.Set("Content-Type", "application/json") checkResponseIs404(t, req) } } func TestDefaultFactory_proxyFactoryCrash(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(erroredProxyFactory{fmt.Errorf("%s", "crash!!!")}, logger).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Port: 8064, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: "GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, subject := range [][]string{{"GET", "ignored"}, {"PUT", "also-ignored"}} { req, _ := http.NewRequest(subject[0], fmt.Sprintf("http://127.0.0.1:8064/%s", subject[1]), http.NoBody) req.Header.Set("Content-Type", "application/json") checkResponseIs404(t, req) } } func TestRunServer_ko(t *testing.T) { buff := new(bytes.Buffer) logger, err := logging.NewLogger("DEBUG", buff, "") if err != nil { t.Error("building the logger:", err.Error()) return } errorMsg := 
"runServer error" runServerFunc := func(_ context.Context, _ config.ServiceConfig, _ http.Handler) error { return errors.New(errorMsg) } pf := noopProxyFactory(map[string]interface{}{"supu": "tupu"}) r := NewFactory( Config{ Engine: DefaultEngine(), Middlewares: []HandlerMiddleware{}, HandlerFactory: EndpointHandler, ProxyFactory: pf, Logger: logger, DebugPattern: DefaultDebugPattern, RunServer: runServerFunc, }, ).New() serviceCfg := config.ServiceConfig{} r.Run(serviceCfg) re := regexp.MustCompile(errorMsg) if !re.MatchString(buff.String()) { t.Errorf("the logger doesn't contain the expected msg: %s", buff.Bytes()) } } func checkResponseIs404(t *testing.T, req *http.Request) { expectedBody := "404 page not found\n" resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error("Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderIncompleteResponseValue { t.Error(req.URL.String(), server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { t.Error("Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "" { t.Error("X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusNotFound { t.Error("Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error("Unexpected body:", content, "expected:", expectedBody) } } type noopProxyFactory map[string]interface{} func (n noopProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: true, 
Data: n, }, nil }, nil } type erroredProxyFactory struct { Error error } func (e erroredProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return proxy.NoopProxy, e.Error } type identityMiddleware struct{} func (identityMiddleware) Handler(h http.Handler) http.Handler { return h } ================================================ FILE: router/negroni/router.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package negroni provides some basic implementations for building routers based on urfave/negroni */ package negroni import ( "net/http" gorilla "github.com/gorilla/mux" "github.com/urfave/negroni/v2" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/router" luragorilla "github.com/luraproject/lura/v2/router/gorilla" "github.com/luraproject/lura/v2/router/mux" ) // DefaultFactory returns a net/http mux router factory with the injected proxy factory and logger func DefaultFactory(pf proxy.Factory, logger logging.Logger, middlewares []negroni.Handler) router.Factory { return mux.NewFactory(DefaultConfig(pf, logger, middlewares)) } // DefaultConfig returns the struct that collects the parts the router should be builded from func DefaultConfig(pf proxy.Factory, logger logging.Logger, middlewares []negroni.Handler) mux.Config { return DefaultConfigWithRouter(pf, logger, NewGorillaRouter(), middlewares) } // DefaultConfigWithRouter returns the struct that collects the parts the router should be builded from with the // injected gorilla mux router func DefaultConfigWithRouter(pf proxy.Factory, logger logging.Logger, muxEngine *gorilla.Router, middlewares []negroni.Handler) mux.Config { cfg := luragorilla.DefaultConfig(pf, logger) cfg.Engine = newNegroniEngine(muxEngine, middlewares...) 
return cfg } // NewGorillaRouter is a wrapper over the default gorilla router builder func NewGorillaRouter() *gorilla.Router { return gorilla.NewRouter() } func newNegroniEngine(muxEngine *gorilla.Router, middlewares ...negroni.Handler) negroniEngine { negroniRouter := negroni.Classic() for _, m := range middlewares { negroniRouter.Use(m) } negroniRouter.UseHandler(muxEngine) return negroniEngine{muxEngine, negroniRouter} } type negroniEngine struct { r *gorilla.Router n *negroni.Negroni } // Handle implements the mux.Engine interface from the lura router package func (e negroniEngine) Handle(pattern, method string, handler http.Handler) { e.r.Handle(pattern, handler).Methods(method) } // ServeHTTP implements the http:Handler interface from the stdlib func (e negroniEngine) ServeHTTP(w http.ResponseWriter, r *http.Request) { e.n.ServeHTTP(mux.NewHTTPErrorInterceptor(w), r) } ================================================ FILE: router/negroni/router_test.go ================================================ //go:build !race // +build !race // SPDX-License-Identifier: Apache-2.0 package negroni import ( "bytes" "context" "fmt" "io" "net/http" "testing" "time" "github.com/urfave/negroni/v2" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/transport/http/server" ) func TestDefaultFactory_ok(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(noopProxyFactory(map[string]interface{}{"supu": "tupu"}), logger, []negroni.Handler{}).NewWithContext(ctx) expectedBody := "{\"supu\":\"tupu\"}" serviceCfg := config.ServiceConfig{ Port: 8052, Endpoints: []*config.EndpointConfig{ { Endpoint: "/get/{id}", Method: 
"GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/post", Method: "POST", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/put", Method: "PUT", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/patch", Method: "PATCH", Timeout: 10, Backend: []*config.Backend{ {}, }, }, { Endpoint: "/delete", Method: "DELETE", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, endpoint := range serviceCfg.Endpoints { req, _ := http.NewRequest(endpoint.Method, fmt.Sprintf("http://127.0.0.1:8052%s", endpoint.Endpoint), http.NoBody) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error(endpoint.Endpoint, "Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get("Content-Type") != "application/json" { t.Error(endpoint.Endpoint, "Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "Version undefined" { t.Error(endpoint.Endpoint, "X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusOK { t.Error(endpoint.Endpoint, "Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error(endpoint.Endpoint, "Unexpected body:", content, "expected:", expectedBody) } } } func TestDefaultFactory_middlewares(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() count := 0 pf := 
noopProxyFactory(map[string]interface{}{"supu": "tupu"}) r := DefaultFactory(pf, logger, []negroni.Handler{dummyMiddleware{&count}}).NewWithContext(ctx) expectedBody := "{\"supu\":\"tupu\"}" serviceCfg := config.ServiceConfig{ Port: 8090, Endpoints: []*config.EndpointConfig{ { Endpoint: "/get/{id}", Method: "GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, endpoint := range serviceCfg.Endpoints { req, _ := http.NewRequest(endpoint.Method, fmt.Sprintf("http://127.0.0.1:8090%s", endpoint.Endpoint), http.NoBody) req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error(endpoint.Endpoint, "Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderCompleteResponseValue { t.Error(server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != "application/json" { t.Error(endpoint.Endpoint, "Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "Version undefined" { t.Error(endpoint.Endpoint, "X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusOK { t.Error(endpoint.Endpoint, "Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error(endpoint.Endpoint, "Unexpected body:", content, "expected:", expectedBody) } } if count != 1 { t.Error("Middleware wasn't called just one time") } } func TestDefaultFactory_ko(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { 
t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(noopProxyFactory(map[string]interface{}{"supu": "tupu"}), logger, []negroni.Handler{}).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Port: 8053, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: "GETTT", Backend: []*config.Backend{ {}, }, }, { Endpoint: "/empty", Method: "GETTT", Backend: []*config.Backend{}, }, { Endpoint: "/also-ignored", Method: "PUT", Backend: []*config.Backend{ {}, {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, subject := range [][]string{ {"GET", "ignored"}, {"GET", "empty"}, {"PUT", "also-ignored"}, } { req, _ := http.NewRequest(subject[0], fmt.Sprintf("http://127.0.0.1:8053/%s", subject[1]), http.NoBody) req.Header.Set("Content-Type", "application/json") checkResponseIs404(t, req) } } func TestDefaultFactory_proxyFactoryCrash(t *testing.T) { buff := bytes.NewBuffer(make([]byte, 1024)) logger, err := logging.NewLogger("ERROR", buff, "pref") if err != nil { t.Error("building the logger:", err.Error()) return } ctx, cancel := context.WithCancel(context.Background()) defer func() { cancel() time.Sleep(5 * time.Millisecond) }() r := DefaultFactory(erroredProxyFactory{fmt.Errorf("%s", "crash!!!")}, logger, []negroni.Handler{}).NewWithContext(ctx) serviceCfg := config.ServiceConfig{ Debug: true, Port: 8054, Endpoints: []*config.EndpointConfig{ { Endpoint: "/ignored", Method: "GET", Timeout: 10, Backend: []*config.Backend{ {}, }, }, }, } go func() { r.Run(serviceCfg) }() time.Sleep(5 * time.Millisecond) for _, subject := range [][]string{{"GET", "ignored"}, {"PUT", "also-ignored"}} { req, _ := http.NewRequest(subject[0], fmt.Sprintf("http://127.0.0.1:8054/%s", subject[1]), http.NoBody) req.Header.Set("Content-Type", "application/json") checkResponseIs404(t, req) } } type 
dummyMiddleware struct { Count *int } func (d dummyMiddleware) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { *(d.Count) = *(d.Count) + 1 next(rw, r) } func checkResponseIs404(t *testing.T, req *http.Request) { expectedBody := "404 page not found\n" resp, err := http.DefaultClient.Do(req) if err != nil { t.Error("Making the request:", err.Error()) return } defer resp.Body.Close() body, ioerr := io.ReadAll(resp.Body) if ioerr != nil { t.Error("Reading the response:", ioerr.Error()) return } content := string(body) if resp.Header.Get("Cache-Control") != "" { t.Error("Cache-Control error:", resp.Header.Get("Cache-Control")) } if resp.Header.Get(server.CompleteResponseHeaderName) != server.HeaderIncompleteResponseValue { t.Error(req.URL.String(), server.CompleteResponseHeaderName, "error:", resp.Header.Get(server.CompleteResponseHeaderName)) } if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { t.Error("Content-Type error:", resp.Header.Get("Content-Type")) } if resp.Header.Get("X-Krakend") != "" { t.Error("X-Krakend error:", resp.Header.Get("X-Krakend")) } if resp.StatusCode != http.StatusNotFound { t.Error("Unexpected status code:", resp.StatusCode) } if content != expectedBody { t.Error("Unexpected body:", content, "expected:", expectedBody) } } type noopProxyFactory map[string]interface{} func (n noopProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return func(_ context.Context, _ *proxy.Request) (*proxy.Response, error) { return &proxy.Response{ IsComplete: true, Data: n, }, nil }, nil } type erroredProxyFactory struct { Error error } func (e erroredProxyFactory) New(_ *config.EndpointConfig) (proxy.Proxy, error) { return proxy.NoopProxy, e.Error } type identityMiddleware struct{} func (identityMiddleware) Handler(h http.Handler) http.Handler { return h } ================================================ FILE: router/router.go ================================================ // SPDX-License-Identifier: 
Apache-2.0 /* Package router defines some interfaces and common helpers for router adapters */ package router import ( "context" "github.com/luraproject/lura/v2/config" ) // Router sets up the public layer exposed to the users type Router interface { Run(config.ServiceConfig) } // RouterFunc type is an adapter to allow the use of ordinary functions as routers. // If f is a function with the appropriate signature, RouterFunc(f) is a Router that calls f. type RouterFunc func(config.ServiceConfig) // Run implements the Router interface func (f RouterFunc) Run(cfg config.ServiceConfig) { f(cfg) } // Factory creates new routers type Factory interface { New() Router NewWithContext(context.Context) Router } ================================================ FILE: sd/dnssrv/subscriber.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package dnssrv defines some implementations for a dns based service discovery */ package dnssrv import ( "fmt" "net" "sort" "sync" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/sd" ) // Namespace is the key for the dns sd module const Namespace = "dns" const DefaultTTL = 30 * time.Second const MinTTL = time.Second // Register registers the dns sd subscriber factory under the name defined by Namespace func Register() error { return sd.GetRegister().Register(Namespace, SubscriberFactory) } // TTL is the duration of the cached data var TTL = DefaultTTL func SetTTL(d time.Duration) { if d < MinTTL { // in case the TTL is less than the minimum, we leave what is // already set. 
return } TTL = d } // DefaultLookup is the function used for the DNS resolution var DefaultLookup = net.LookupSRV // SubscriberFactory builds a DNS_SRV Subscriber with the received config func SubscriberFactory(cfg *config.Backend) sd.Subscriber { return NewDetailedWithScheme(cfg.Host[0], DefaultLookup, TTL, cfg.SDScheme) } // New creates a DNS subscriber with the default values func New(name string) sd.Subscriber { return NewDetailed(name, DefaultLookup, TTL) } // NewDetailed creates a DNS subscriber with the received values func NewDetailed(name string, lookup lookup, ttl time.Duration) sd.Subscriber { return NewDetailedWithScheme(name, lookup, ttl, "http") } // NewDetailedWithScheme creates a DNS subscriber with the received values and the scheme to use // for the fetched server entries. func NewDetailedWithScheme(name string, lookup lookup, ttl time.Duration, scheme string) sd.Subscriber { if scheme == "" { scheme = "http" } s := subscriber{ name: name, cache: &sd.FixedSubscriber{}, mutex: &sync.RWMutex{}, ttl: ttl, lookup: lookup, scheme: scheme, } s.update() go func() { for { <-time.After(s.ttl) s.update() } }() return s } type lookup func(service, proto, name string) (cname string, addrs []*net.SRV, err error) type subscriber struct { name string cache *sd.FixedSubscriber mutex *sync.RWMutex ttl time.Duration lookup lookup scheme string } // Hosts returns a copy of the cached set of hosts. 
It is safe to call it concurrently func (s subscriber) Hosts() ([]string, error) { s.mutex.RLock() defer s.mutex.RUnlock() hs, err := s.cache.Hosts() if err != nil { return []string{}, err } res := make([]string, len(hs)) copy(res, hs) return res, nil } func (s subscriber) update() { instances, err := s.resolve() if err != nil { return } s.mutex.Lock() defer s.mutex.Unlock() if len(instances) > 100 { *(s.cache) = sd.NewRandomFixedSubscriber(instances) } else { *(s.cache) = sd.FixedSubscriber(instances) } } func (s subscriber) resolve() ([]string, error) { _, srvs, err := s.lookup("", "", s.name) if err != nil { return []string{}, err } sort.Slice( srvs, func(i, j int) bool { if srvs[i].Priority == srvs[j].Priority { if srvs[i].Weight == srvs[j].Weight { if srvs[i].Target == srvs[j].Target { return srvs[i].Port < srvs[j].Port } return srvs[i].Target < srvs[j].Target } return srvs[i].Weight > srvs[j].Weight } return srvs[i].Priority < srvs[j].Priority }, ) ws := make([]uint16, 0, len(srvs)) host := make([]string, 0, len(srvs)) for _, a := range srvs { if a.Priority > srvs[0].Priority { break } ws = append(ws, a.Weight) host = append(host, s.scheme+"://"+net.JoinHostPort(a.Target, fmt.Sprint(a.Port))) } instances := make([]string, 0, len(ws)) for i, times := range compact(ws) { for j := uint16(0); j < times; j++ { instances = append(instances, host[i]) } } return instances, nil } func compact(ws []uint16) []uint16 { tmp := normalize(ws) div := gcd(tmp) if div < 2 { return tmp } res := make([]uint16, len(tmp)) for i, w := range tmp { res[i] = w / div } return res } func normalize(ws []uint16) []uint16 { scale := 100 if l := len(ws); l > scale { scale = l } var sum int64 for _, w := range ws { sum += int64(w) } if sum <= int64(scale) { return ws } res := make([]uint16, len(ws)) for i, w := range ws { res[i] = uint16(int64(w) * int64(scale) / sum) } return res } func gcd(ws []uint16) uint16 { if len(ws) == 0 { return 0 } localGCD := func(a uint16, b uint16) uint16 { for 
b > 0 { a, b = b, a%b } return a } result := ws[0] for _, i := range ws[1:] { result = localGCD(result, i) } return result } ================================================ FILE: sd/dnssrv/subscriber_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package dnssrv import ( "errors" "fmt" "net" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/sd" ) func ExampleRegister() { if err := Register(); err != nil { fmt.Println("registering the dns module:", err.Error()) return } srvSet := []*net.SRV{ { Port: 90, Target: "foobar", Weight: 2, }, { Port: 90, Target: "127.0.0.1", Weight: 2, }, { Port: 80, Target: "127.0.0.1", Weight: 2, }, { Port: 81, Target: "127.0.0.1", Weight: 4, }, { Port: 82, Target: "127.0.0.1", Weight: 10, Priority: 2, }, { Port: 83, Target: "127.0.0.1", }, } DefaultLookup = func(service, proto, name string) (cname string, addrs []*net.SRV, err error) { return "cname", srvSet, nil } s := sd.GetRegister().Get(Namespace)(&config.Backend{Host: []string{"some.example.tld"}, SD: Namespace}) hosts, err := s.Hosts() if err != nil { fmt.Println("Getting the hosts:", err.Error()) return } for _, h := range hosts { fmt.Println(h) } // output: // http://127.0.0.1:81 // http://127.0.0.1:81 // http://127.0.0.1:80 // http://127.0.0.1:90 // http://foobar:90 } func ExampleNewDetailed() { srvSet := []*net.SRV{ { Port: 90, Target: "foobar", Weight: 2, }, { Port: 90, Target: "127.0.0.1", Weight: 2, }, { Port: 80, Target: "127.0.0.1", Weight: 2, }, { Port: 81, Target: "127.0.0.1", Weight: 4, }, { Port: 82, Target: "127.0.0.1", Weight: 10, Priority: 2, }, { Port: 83, Target: "127.0.0.1", }, } lookupFunc := func(service, proto, name string) (cname string, addrs []*net.SRV, err error) { return "cname", srvSet, nil } s := NewDetailed("some.example.tld", lookupFunc, 10*time.Second) hosts, err := s.Hosts() if err != nil { fmt.Println("Getting the hosts:", err.Error()) return } for _, h := range 
hosts { fmt.Println(h) } // output: // http://127.0.0.1:81 // http://127.0.0.1:81 // http://127.0.0.1:80 // http://127.0.0.1:90 // http://foobar:90 } func ExampleNewDetailedWithScheme() { srvSet := []*net.SRV{ { Port: 90, Target: "foobar", Weight: 2, }, { Port: 90, Target: "127.0.0.1", Weight: 2, }, { Port: 80, Target: "127.0.0.1", Weight: 2, }, { Port: 81, Target: "127.0.0.1", Weight: 4, }, { Port: 82, Target: "127.0.0.1", Weight: 10, Priority: 2, }, { Port: 83, Target: "127.0.0.1", }, } lookupFunc := func(service, proto, name string) (cname string, addrs []*net.SRV, err error) { return "cname", srvSet, nil } s := NewDetailedWithScheme("some.example.tld", lookupFunc, 10*time.Second, "https") hosts, err := s.Hosts() if err != nil { fmt.Println("Getting the hosts:", err.Error()) return } for _, h := range hosts { fmt.Println(h) } // output: // https://127.0.0.1:81 // https://127.0.0.1:81 // https://127.0.0.1:80 // https://127.0.0.1:90 // https://foobar:90 } func TestSubscriber_LoockupError(t *testing.T) { errToReturn := errors.New("Some random error") defaultLookup := func(service, proto, name string) (cname string, addrs []*net.SRV, err error) { return "cname", []*net.SRV{}, errToReturn } ttl := 1 * time.Millisecond s := NewDetailed("some.example.tld", defaultLookup, ttl) hosts, err := s.Hosts() if err != nil { t.Error("Unexpected error!", err) } if len(hosts) != 0 { t.Error("Wrong number of hosts:", len(hosts)) } } func TestSubscriber_ResolveVeryLarge(t *testing.T) { var srvSet []*net.SRV const max = 1000 for i := 0; i < max; i++ { srvSet = append(srvSet, &net.SRV{ Port: uint16(80 + i), Target: "127.0.0.1", Weight: 65535, }) } lookupFunc := func(service, proto, name string) (cname string, addrs []*net.SRV, err error) { return "cname", srvSet, nil } s := NewDetailed("large.example.tld", lookupFunc, 10*time.Second) hosts, _ := s.Hosts() if len(hosts) != max { t.Errorf("Expected %d, but got %d", max, len(hosts)) } } func Examplecompact_basicweights() { for _, tc := 
range [][]uint16{ []uint16{25, 10000, 1000}, []uint16{25, 1000, 10000, 0, 65535}, []uint16{1, 65535}, []uint16{}, []uint16{0, 0, 0, 0}, } { fmt.Println(tc, compact(tc)) } // output: // [25 10000 1000] [0 10 1] // [25 1000 10000 0 65535] [0 1 13 0 85] // [1 65535] [0 1] // [] [] // [0 0 0 0] [0 0 0 0] } func Examplecompact_custom_weights() { tc := make([]uint16, 200) for i := range tc { tc[i] = uint16(3*5*7*11*13 + i) } fmt.Println(tc[:5], compact(tc[:5])) for i := range tc { tc[i] = uint16(i * 3 * 5 * 7) } fmt.Println(tc[:5], compact(tc[:5])) // output: // [15015 15016 15017 15018 15019] [19 19 20 20 20] // [0 105 210 315 420] [0 1 2 3 4] } ================================================ FILE: sd/loadbalancing.go ================================================ // SPDX-License-Identifier: Apache-2.0 package sd import ( "errors" "runtime" "sync/atomic" "github.com/valyala/fastrand" ) // Balancer applies a balancing stategy in order to select the backend host to be used type Balancer interface { Host() (string, error) } // ErrNoHosts is the error the balancer must return when there are 0 hosts ready var ErrNoHosts = errors.New("no hosts available") // NewBalancer returns the best perfomant balancer depending on the number of available processors. // If GOMAXPROCS = 1, it returns a round robin LB due there is no contention over the atomic counter. // If GOMAXPROCS > 1, it returns a pseudo random LB optimized for scaling over the number of CPUs. func NewBalancer(subscriber Subscriber) Balancer { if p := runtime.GOMAXPROCS(-1); p == 1 { return NewRoundRobinLB(subscriber) } return NewRandomLB(subscriber) } // NewRoundRobinLB returns a new balancer using a round robin strategy and starting at a random // position in the set of hosts. 
func NewRoundRobinLB(subscriber Subscriber) Balancer { s, ok := subscriber.(FixedSubscriber) start := uint64(0) if ok { if l := len(s); l == 1 { return nopBalancer(s[0]) } else if l > 1 { start = uint64(fastrand.Uint32n(uint32(l))) } } return &roundRobinLB{ balancer: balancer{subscriber: subscriber}, counter: start, } } type roundRobinLB struct { balancer counter uint64 } // Host implements the balancer interface func (r *roundRobinLB) Host() (string, error) { hosts, err := r.hosts() if err != nil { return "", err } offset := (atomic.AddUint64(&r.counter, 1) - 1) % uint64(len(hosts)) return hosts[offset], nil } // NewRandomLB returns a new balancer using a fastrand pseudorandom generator func NewRandomLB(subscriber Subscriber) Balancer { if s, ok := subscriber.(FixedSubscriber); ok && len(s) == 1 { return nopBalancer(s[0]) } return &randomLB{ balancer: balancer{subscriber: subscriber}, rand: fastrand.Uint32n, } } type randomLB struct { balancer rand func(uint32) uint32 } // Host implements the balancer interface func (r *randomLB) Host() (string, error) { hosts, err := r.hosts() if err != nil { return "", err } return hosts[int(r.rand(uint32(len(hosts))))], nil } type balancer struct { subscriber Subscriber } func (b *balancer) hosts() ([]string, error) { hs, err := b.subscriber.Hosts() if err != nil { return hs, err } if len(hs) <= 0 { return hs, ErrNoHosts } return hs, nil } type nopBalancer string func (b nopBalancer) Host() (string, error) { return string(b), nil } ================================================ FILE: sd/loadbalancing_benchmark_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package sd import ( "fmt" "testing" ) var balancerTestsCases = [][]string{ {"a"}, {"a", "b", "c"}, {"a", "b", "c", "e", "f"}, } func BenchmarkLB(b *testing.B) { for _, tc := range []struct { name string f func([]string) Balancer }{ {name: "round_robin", f: func(hs []string) Balancer { return 
NewRoundRobinLB(FixedSubscriber(hs)) }}, {name: "random", f: func(hs []string) Balancer { return NewRandomLB(FixedSubscriber(hs)) }}, } { for _, testCase := range balancerTestsCases { b.Run(fmt.Sprintf("%s/%d", tc.name, len(testCase)), func(b *testing.B) { balancer := tc.f(testCase) b.ResetTimer() for i := 0; i < b.N; i++ { balancer.Host() } }) } } } func BenchmarkLB_parallel(b *testing.B) { for _, tc := range []struct { name string f func([]string) Balancer }{ {name: "round_robin", f: func(hs []string) Balancer { return NewRoundRobinLB(FixedSubscriber(hs)) }}, {name: "random", f: func(hs []string) Balancer { return NewRandomLB(FixedSubscriber(hs)) }}, } { for _, testCase := range balancerTestsCases { b.Run(fmt.Sprintf("%s/%d", tc.name, len(testCase)), func(b *testing.B) { balancer := tc.f(testCase) b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { balancer.Host() } }) }) } } } ================================================ FILE: sd/loadbalancing_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package sd import ( "errors" "fmt" "math" "testing" "github.com/luraproject/lura/v2/config" ) func ExampleNewRoundRobinLB() { balancer := NewRoundRobinLB(FixedSubscriber([]string{"a", "b", "c"})) // code required in order to make the test deterministic balancer.(*roundRobinLB).counter = 1 for i := 0; i < 5; i++ { h, err := balancer.Host() if err != nil { fmt.Println(err.Error()) continue } fmt.Println(h) } // output // b // c // a // b // a } func TestRoundRobinLB(t *testing.T) { for _, endpoints := range balancerTestsCases { t.Run(fmt.Sprintf("%d hosts", len(endpoints)), func(t *testing.T) { var ( n = len(endpoints) counts = make(map[string]int, n) iterations = 100000 * n want = iterations / n ) for _, e := range endpoints { counts[e] = 0 } subscriber := FixedSubscriber(endpoints) balancer := NewRoundRobinLB(subscriber) if b, ok := balancer.(*roundRobinLB); ok { b.counter = 0 } for i := 0; i < iterations; 
i++ { endpoint, err := balancer.Host() if err != nil { t.Fail() } expected := i % n if v := endpoints[expected]; v != endpoint { t.Errorf("%d: want %s, have %s", i, endpoints[expected], endpoint) } counts[endpoint]++ } for i, have := range counts { if have != want { t.Errorf("%s: want %d, have %d", i, want, have) } } }) } } func TestRoundRobinLB_noEndpoints(t *testing.T) { subscriber := FixedSubscriber{} balancer := NewRoundRobinLB(subscriber) _, err := balancer.Host() if want, have := ErrNoHosts, err; want != have { t.Errorf("want %v, have %v", want, have) } } func ExampleNewRandomLB() { balancer := NewRandomLB(FixedSubscriber([]string{"a", "b", "c"})) // code required in order to make the test deterministic { var counter uint32 balancer.(*randomLB).rand = func(max uint32) uint32 { if max != 3 { fmt.Println("unexpected max:", max) } defer func() { counter++ }() return counter % max } } for i := 0; i < 5; i++ { h, err := balancer.Host() if err != nil { fmt.Println(err.Error()) continue } fmt.Println(h) } // output // a // b // c // a // b } func TestRandomLB(t *testing.T) { var ( endpoints = []string{"a", "b", "c", "d", "e", "f", "g"} n = len(endpoints) counts = make(map[string]int, n) iterations = 1000000 want = iterations / n tolerance = want / 100 // 1% ) for _, e := range endpoints { counts[e] = 0 } subscriber := FixedSubscriber(endpoints) balancer := NewRandomLB(subscriber) for i := 0; i < iterations; i++ { endpoint, err := balancer.Host() if err != nil { t.Fail() } counts[endpoint]++ } for i, have := range counts { delta := int(math.Abs(float64(want - have))) if delta > tolerance { t.Errorf("%s: want %d, have %d, delta %d > %d tolerance", i, want, have, delta, tolerance) } } } func TestRandomLB_single(t *testing.T) { endpoints := []string{"a"} iterations := 1000000 subscriber := FixedSubscriber(endpoints) balancer := NewRandomLB(subscriber) for i := 0; i < iterations; i++ { endpoint, err := balancer.Host() if err != nil { t.Fail() } if endpoint != 
endpoints[0] { t.Errorf("unexpected host %s", endpoint) } } } func TestRandomLB_noEndpoints(t *testing.T) { subscriber := FixedSubscriberFactory(&config.Backend{}) balancer := NewRandomLB(subscriber) _, err := balancer.Host() if want, have := ErrNoHosts, err; want != have { t.Errorf("want %v, have %v", want, have) } } type erroredSubscriber string func (s erroredSubscriber) Hosts() ([]string, error) { return []string{}, errors.New(string(s)) } func TestRoundRobinLB_erroredSubscriber(t *testing.T) { want := "supu" balancer := NewRoundRobinLB(erroredSubscriber(want)) host, have := balancer.Host() if host != "" || want != have.Error() { t.Errorf("want %s, have %s", want, have.Error()) } } func TestRandomLB_erroredSubscriber(t *testing.T) { want := "supu" balancer := NewRandomLB(erroredSubscriber(want)) host, have := balancer.Host() if host != "" || want != have.Error() { t.Errorf("want %s, have %s", want, have.Error()) } } ================================================ FILE: sd/register.go ================================================ // SPDX-License-Identifier: Apache-2.0 package sd import ( "github.com/luraproject/lura/v2/register" ) // GetRegister returns the package register func GetRegister() *Register { return subscriberFactories } type untypedRegister interface { Register(name string, v interface{}) Get(name string) (interface{}, bool) } // Register is a SD register, mapping different SD subscriber factories // to their respective name, so they can be accessed by name type Register struct { data untypedRegister } func initRegister() *Register { return &Register{register.NewUntyped()} } // Register adds the SubscriberFactory to the internal register under the given // name func (r *Register) Register(name string, sf SubscriberFactory) error { r.data.Register(name, sf) return nil } // Get returns the SubscriberFactory stored under the given name. 
It falls back to // a FixedSubscriberFactory if there is no factory with that name func (r *Register) Get(name string) SubscriberFactory { tmp, ok := r.data.Get(name) if !ok { return FixedSubscriberFactory } sf, ok := tmp.(SubscriberFactory) if !ok { return FixedSubscriberFactory } return sf } var subscriberFactories = initRegister() ================================================ FILE: sd/register_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package sd import ( "testing" "github.com/luraproject/lura/v2/config" ) func TestGetRegister_Register_ok(t *testing.T) { sf1 := func(*config.Backend) Subscriber { return SubscriberFunc(func() ([]string, error) { return []string{"one"}, nil }) } sf2 := func(*config.Backend) Subscriber { return SubscriberFunc(func() ([]string, error) { return []string{"two", "three"}, nil }) } if err := GetRegister().Register("name1", sf1); err != nil { t.Error(err) } if err := GetRegister().Register("name2", sf2); err != nil { t.Error(err) } if h, err := GetRegister().Get("name1")(&config.Backend{SD: "name1"}).Hosts(); err != nil || len(h) != 1 { t.Error("error using the sd name1") } if h, err := GetRegister().Get("name2")(&config.Backend{SD: "name2"}).Hosts(); err != nil || len(h) != 2 { t.Error("error using the sd name2") } if h, err := GetRegister().Get("name2")(&config.Backend{SD: "name2"}).Hosts(); err != nil || len(h) != 2 { t.Error("error using the sd name2") } subscriberFactories = initRegister() } func TestGetRegister_Get_unknown(t *testing.T) { if h, err := GetRegister().Get("name")(&config.Backend{Host: []string{"name"}}).Hosts(); err != nil || len(h) != 1 { t.Error("error using the default sd") } } func TestGetRegister_Get_errored(t *testing.T) { subscriberFactories.data.Register("errored", true) if h, err := GetRegister().Get("errored")(&config.Backend{SD: "errored", Host: []string{"name"}}).Hosts(); err != nil || len(h) != 1 { t.Error("error using the default sd") } 
subscriberFactories = initRegister() } ================================================ FILE: sd/subscriber.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package sd defines some interfaces and implementations for service discovery */ package sd import ( "math/rand" "github.com/luraproject/lura/v2/config" ) // Subscriber keeps the set of backend hosts up to date type Subscriber interface { Hosts() ([]string, error) } // SubscriberFunc type is an adapter to allow the use of ordinary functions as subscribers. // If f is a function with the appropriate signature, SubscriberFunc(f) is a Subscriber that calls f. type SubscriberFunc func() ([]string, error) // Hosts implements the Subscriber interface by executing the wrapped function func (f SubscriberFunc) Hosts() ([]string, error) { return f() } // FixedSubscriber has a constant set of backend hosts and they never get updated type FixedSubscriber []string // Hosts implements the subscriber interface func (s FixedSubscriber) Hosts() ([]string, error) { return s, nil } // SubscriberFactory builds subscribers with the received config type SubscriberFactory func(*config.Backend) Subscriber // FixedSubscriberFactory builds a FixedSubscriber with the received config func FixedSubscriberFactory(cfg *config.Backend) Subscriber { return FixedSubscriber(cfg.Host) } // NewRandomFixedSubscriber randomizes a list of hosts and builds a FixedSubscriber with it func NewRandomFixedSubscriber(hosts []string) FixedSubscriber { res := make([]string, len(hosts)) j := 0 for _, i := range rand.Perm(len(hosts)) { res[j] = hosts[i] j++ } return FixedSubscriber(res) } ================================================ FILE: test/doc.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package test contains the integration tests for the KrakenD framework */ package test ================================================ FILE: test/integration_test.go 
================================================ //go:build integration || !race // +build integration !race // SPDX-License-Identifier: Apache-2.0 package test import ( "bytes" "context" "encoding/json" "fmt" "io" "math/rand" "net" "net/http" "net/http/httptest" "os" "strings" "testing" "text/template" "time" ginlib "github.com/gin-gonic/gin" "github.com/urfave/negroni/v2" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/proxy" "github.com/luraproject/lura/v2/router/chi" "github.com/luraproject/lura/v2/router/gin" "github.com/luraproject/lura/v2/router/gorilla" "github.com/luraproject/lura/v2/router/httptreemux" luranegroni "github.com/luraproject/lura/v2/router/negroni" "github.com/luraproject/lura/v2/transport/http/server" ) var localhostIP string func init() { ln, err := net.Listen("tcp", ":8080") if err != nil { return } go func() { conn, _ := net.Dial("tcp", "localhost:8080") <-time.After(5 * time.Second) conn.Close() }() conn, err := ln.Accept() if err != nil { return } h, _, err := net.SplitHostPort(conn.RemoteAddr().String()) if err == nil { localhostIP = h } conn.Close() } func TestKrakenD_ginRouter(t *testing.T) { ginlib.SetMode(ginlib.TestMode) ctx, cancel := context.WithCancel(context.Background()) defer cancel() testKrakenD(t, func(logger logging.Logger, cfg *config.ServiceConfig) { if cfg.ExtraConfig == nil { cfg.ExtraConfig = map[string]interface{}{} } cfg.ExtraConfig[gin.Namespace] = map[string]interface{}{ "trusted_proxies": []interface{}{"127.0.0.1/32", "::1"}, "remote_ip_headers": []interface{}{"x-forwarded-for"}, "forwarded_by_client_ip": true, "return_error_msg": true, } ignoredChan := make(chan string) opts := gin.EngineOptions{ Logger: logger, Writer: io.Discard, Health: (<-chan string)(ignoredChan), } gin.NewFactory( gin.Config{ Engine: gin.NewEngine(*cfg, opts), Middlewares: []ginlib.HandlerFunc{}, HandlerFactory: gin.EndpointHandler, ProxyFactory: 
proxy.DefaultFactory(logger), Logger: logger, RunServer: server.RunServer, }, ).NewWithContext(ctx).Run(*cfg) }) } func TestKrakenD_gorillaRouter(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() config.RoutingPattern = config.BracketsRouterPatternBuilder testKrakenD(t, func(logger logging.Logger, cfg *config.ServiceConfig) { gorilla.DefaultFactory(proxy.DefaultFactory(logger), logger).NewWithContext(ctx).Run(*cfg) }) config.RoutingPattern = config.ColonRouterPatternBuilder } func TestKrakenD_negroniRouter(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() config.RoutingPattern = config.BracketsRouterPatternBuilder testKrakenD(t, func(logger logging.Logger, cfg *config.ServiceConfig) { factory := luranegroni.DefaultFactory(proxy.DefaultFactory(logger), logger, []negroni.Handler{}) factory.NewWithContext(ctx).Run(*cfg) }) config.RoutingPattern = config.ColonRouterPatternBuilder } func TestKrakenD_httptreemuxRouter(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() testKrakenD(t, func(logger logging.Logger, cfg *config.ServiceConfig) { httptreemux.DefaultFactory(proxy.DefaultFactory(logger), logger).NewWithContext(ctx).Run(*cfg) }) } func TestKrakenD_chiRouter(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() config.RoutingPattern = config.BracketsRouterPatternBuilder testKrakenD(t, func(logger logging.Logger, cfg *config.ServiceConfig) { chi.DefaultFactory(proxy.DefaultFactory(logger), logger).NewWithContext(ctx).Run(*cfg) }) config.RoutingPattern = config.ColonRouterPatternBuilder } func testKrakenD(t *testing.T, runRouter func(logging.Logger, *config.ServiceConfig)) { cfg, err := setupBackend(t) if err != nil { t.Error(err) return } logger := logging.NoOp go runRouter(logger, cfg) <-time.After(300 * time.Millisecond) defaultHeaders := map[string]string{ "Content-Type": "application/json", "X-KrakenD-Completed": "true", 
"X-Krakend": "Version undefined", } incompleteHeader := map[string]string{ "Content-Type": "application/json", "X-KrakenD-Completed": "false", "X-Krakend": "Version undefined", } for _, tc := range []struct { name string url string method string headers map[string]string body string expBody string expHeaders map[string]string expStatusCode int }{ { name: "static", url: "/static", headers: map[string]string{}, expHeaders: incompleteHeader, expBody: `{"bar":"foobar","foo":42}`, }, { name: "param_forwarding", url: "/param_forwarding/foo/constant/bar", method: "POST", headers: map[string]string{ "Content-Type": "application/json", "Authorization": "bearer AuthorizationToken", "X-Y-Z": "x-y-z", }, body: `{"foo":"bar"}`, expHeaders: defaultHeaders, expBody: `{"path":"/foo/bar"}`, }, { name: "param_forwarding_2", url: "/param_forwarding/foo/constant/foobar", method: "POST", headers: map[string]string{ "Content-Type": "application/json", "Authorization": "bearer AuthorizationToken", "X-Y-Z": "x-y-z", }, body: `{"foo":"bar"}`, expHeaders: defaultHeaders, expBody: `{"path":"/foo/foobar"}`, }, { name: "timeout", url: "/timeout", headers: map[string]string{}, expHeaders: incompleteHeader, expBody: `{"email":"some@email.com","name":"a"}`, }, { name: "partial_with_static", url: "/partial/static", headers: map[string]string{}, expHeaders: incompleteHeader, expBody: `{"bar":"foobar","email":"some@email.com","foo":42,"name":"a"}`, }, { name: "partial", url: "/partial", headers: map[string]string{}, expHeaders: incompleteHeader, expBody: `{"email":"some@email.com","name":"a"}`, }, { name: "combination", url: "/combination", headers: map[string]string{}, expHeaders: defaultHeaders, expBody: `{"name":"a","personal_email":"some@email.com","posts":[{"body":"some content","date":"123456789"},{"body":"some other content","date":"123496789"}]}`, }, { name: "detail_error", url: "/detail_error", headers: map[string]string{}, expHeaders: incompleteHeader, expBody: 
`{"email":"some@email.com","error_backend_a":{"http_status_code":429,"http_body":"sad panda\n","http_body_encoding":"text/plain; charset=utf-8"},"name":"a"}`, }, { name: "querystring-params-no-params", url: "/querystring-params-test/no-params?a=1&b=2&c=3", headers: map[string]string{}, expHeaders: defaultHeaders, expBody: fmt.Sprintf(`{"headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-Host":["localhost:%d"]},"path":"/no-params","query":{}}`, cfg.Port), }, { name: "querystring-params-optional-query-params", url: "/querystring-params-test/query-params?a=1&b=2&c=3", headers: map[string]string{}, expHeaders: defaultHeaders, expBody: fmt.Sprintf(`{"headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-Host":["localhost:%d"]},"path":"/query-params","query":{"a":["1"],"b":["2"]}}`, cfg.Port), }, { name: "querystring-params-mandatory-query-params", url: "/querystring-params-test/url-params/some?a=1&b=2&c=3", headers: map[string]string{}, expHeaders: defaultHeaders, expBody: fmt.Sprintf(`{"headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-Host":["localhost:%d"]},"path":"/url-params","query":{"p":["some"]}}`, cfg.Port), }, { name: "querystring-params-all", url: "/querystring-params-test/all-params?a=1&b=2&c=3", headers: map[string]string{}, expHeaders: defaultHeaders, expBody: fmt.Sprintf(`{"headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-Host":["localhost:%d"]},"path":"/all-params","query":{"a":["1"],"b":["2"],"c":["3"]}}`, cfg.Port), }, { name: "header-params-none", url: "/header-params-test/no-params", headers: map[string]string{ "x-Test-1": "some", "X-TEST-2": "none", }, expHeaders: defaultHeaders, expBody: fmt.Sprintf(`{"headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-Host":["localhost:%d"]},"path":"/no-params","query":{}}`, cfg.Port), }, { name: 
"header-params-filter", url: "/header-params-test/filter-params", headers: map[string]string{ "x-tESt-1": "some", "X-TEST-2": "none", }, expHeaders: defaultHeaders, expBody: fmt.Sprintf(`{"headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-Host":["localhost:%d"],"X-Test-1":["some"]},"path":"/filter-params","query":{}}`, cfg.Port), }, { name: "header-params-all", url: "/header-params-test/all-params", headers: map[string]string{ "x-Test-1": "some", "X-TEST-2": "none", "User-Agent": "KrakenD Test", }, expHeaders: defaultHeaders, expBody: fmt.Sprintf(`{"headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Test"],"X-Forwarded-Host":["localhost:%d"],"X-Forwarded-Via":["KrakenD Version undefined"],"X-Test-1":["some"],"X-Test-2":["none"]},"path":"/all-params","query":{}}`, cfg.Port), }, { name: "sequential ok", url: "/sequential/ok/foo", expHeaders: defaultHeaders, expBody: `{"first":{"path":"/provider/foo","random":42},"second":{"path":"/recipient/42","random":42}}`, }, { name: "sequential ko first", url: "/sequential/ko/first/foo", expHeaders: map[string]string{ "X-KrakenD-Completed": "false", "X-Krakend": "Version undefined", }, expStatusCode: 500, }, { name: "sequential ko last", url: "/sequential/ko/last/foo", expHeaders: incompleteHeader, expBody: `{"random":42}`, }, { name: "redirect", url: "/redirect", expHeaders: defaultHeaders, expBody: `{"path":"/","random":42}`, }, { name: "found", url: "/found", expHeaders: defaultHeaders, expBody: `{"path":"/","random":42}`, }, { name: "flatmap del", url: "/flatmap/delete", expHeaders: defaultHeaders, expBody: `{"collection":[{"body":"some content"},{"body":"some other content"}]}`, }, { name: "flatmap rename", url: "/flatmap/rename", expHeaders: defaultHeaders, expBody: `{"collection":[{"body":"some content","created_at":"123456789"},{"body":"some other content","created_at":"123496789"}]}`, }, { name: "x-forwarded-for", url: "/x-forwarded-for", headers: map[string]string{ 
"x-forwarded-for": "123.45.67.89", }, expHeaders: defaultHeaders, expBody: fmt.Sprintf(`{"headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-For":["123.45.67.89"],"X-Forwarded-Host":["localhost:%d"]}}`, cfg.Port), }, { method: "PUT", name: "sequence-accept", url: "/sequence-accept", expHeaders: defaultHeaders, }, { method: "GET", name: "error-status-code-1", url: "/error-status-code/1", expStatusCode: 200, }, { method: "GET", name: "error-status-code-2", url: "/error-status-code/2", expStatusCode: 429, }, { method: "GET", name: "error-status-code-3", url: "/error-status-code/3", expStatusCode: 200, }, { method: "POST", name: "multipost_parallel", url: "/multipost/parallel/foo", body: `{"foo":"bar"}`, expStatusCode: 200, expHeaders: defaultHeaders, expBody: fmt.Sprintf(`{"first":{"body":"{\"foo\":\"bar\"}","headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-For":["`+localhostIP+`"],"X-Forwarded-Host":["localhost:%d"]},"method":"POST","url":"/provider/foo"},"second":{"body":"{\"foo\":\"bar\"}","headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-For":["`+localhostIP+`"],"X-Forwarded-Host":["localhost:%d"]},"method":"POST","url":"/recipient/foo"}}`, cfg.Port, cfg.Port), }, { method: "POST", name: "multipost_sequential", url: "/multipost/sequential/foo", body: `{"foo":"bar"}`, expStatusCode: 200, expHeaders: defaultHeaders, expBody: fmt.Sprintf(`{"first":{"path":"/provider/foo","random":42},"second":{"body":"{\"foo\":\"bar\"}","headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Version undefined"],"X-Forwarded-For":["`+localhostIP+`"],"X-Forwarded-Host":["localhost:%d"]},"method":"POST","url":"/recipient/42"},"third":{"body":"{\"foo\":\"bar\"}","headers":{"Accept-Encoding":["gzip"],"User-Agent":["KrakenD Version 
undefined"],"X-Forwarded-For":["`+localhostIP+`"],"X-Forwarded-Host":["localhost:%d"]},"method":"POST","url":"/recipient/42"}}`, cfg.Port, cfg.Port), }, } { tc := tc t.Run(tc.name, func(t *testing.T) { if tc.method == "" { tc.method = "GET" } var body io.Reader if tc.body != "" { body = bytes.NewBufferString(tc.body) } url := fmt.Sprintf("http://localhost:%d%s", cfg.Port, tc.url) r, _ := http.NewRequest(tc.method, url, body) for k, v := range tc.headers { r.Header.Add(k, v) } resp, err := http.DefaultClient.Do(r) if err != nil { t.Error(err) return } if resp == nil { t.Errorf("%s: nil response", resp.Request.URL.Path) return } expectedStatusCode := http.StatusOK if tc.expStatusCode != 0 { expectedStatusCode = tc.expStatusCode } if resp.StatusCode != expectedStatusCode { t.Errorf("%s: unexpected status code. have: %d, want: %d", resp.Request.URL.Path, resp.StatusCode, expectedStatusCode) } for k, v := range tc.expHeaders { if c := resp.Header.Get(k); !strings.Contains(c, v) { t.Errorf("%s: unexpected header %s: %s", resp.Request.URL.Path, k, c) } } if tc.expBody == "" { return } b, _ := io.ReadAll(resp.Body) resp.Body.Close() if tc.expBody != string(b) { t.Errorf( "%s: unexpected body: %s\n\t%s was expecting: %s", resp.Request.URL.Path, string(b), resp.Request.URL.Path, tc.expBody, ) } }) } } func setupBackend(t *testing.T) (*config.ServiceConfig, error) { data := map[string]interface{}{"port": rand.Intn(2000) + 8080} // param forwarding validation backend b1 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { if c := r.Header.Get("Content-Type"); c != "application/json" { t.Errorf("unexpected header content-type: %s", c) http.Error(rw, "bad content-type", 400) return } if c := r.Header.Get("Authorization"); c != "bearer AuthorizationToken" { t.Errorf("unexpected header Authorization: %s", c) http.Error(rw, "bad Authorization", 400) return } if c := r.Header.Get("X-Y-Z"); c != "x-y-z" { t.Errorf("unexpected header X-Y-Z: %s", c) 
http.Error(rw, "bad X-Y-Z", 400) return } body, err := io.ReadAll(r.Body) if err != nil { t.Error(err) return } if string(body) != `{"foo":"bar"}` { t.Errorf("unexpected request body: %s", string(body)) return } rw.Header().Add("Content-Type", "application/json") json.NewEncoder(rw).Encode(map[string]interface{}{"path": r.URL.Path}) })) data["b1"] = b1.URL // collection generator backend b2 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Header().Add("Content-Type", "application/json") json.NewEncoder(rw).Encode([]interface{}{ map[string]interface{}{"body": "some content", "date": "123456789"}, map[string]interface{}{"body": "some other content", "date": "123496789"}, }) })) data["b2"] = b2.URL // regular struct generator backend b3 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Header().Add("Content-Type", "application/json") json.NewEncoder(rw).Encode(map[string]interface{}{"email": "some@email.com", "name": "a"}) })) data["b3"] = b3.URL // crasher backend b4 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { http.Error(rw, "sad panda", http.StatusTooManyRequests) })) data["b4"] = b4.URL // slow backend b5 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { <-time.After(time.Second) rw.Header().Add("Content-Type", "application/json") json.NewEncoder(rw).Encode(map[string]interface{}{"email": "some@email.com", "name": "a"}) })) data["b5"] = b5.URL // querystring-forwarding backend b6 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Header().Add("Content-Type", "application/json") if ip := net.ParseIP(r.Header.Get("X-Forwarded-For")); ip == nil || !ip.IsLoopback() { http.Error(rw, "invalid X-Forwarded-For", 400) return } r.Header.Del("X-Forwarded-For") json.NewEncoder(rw).Encode(map[string]interface{}{ "path": r.URL.Path, "query": r.URL.Query(), "headers": r.Header, }) 
})) data["b6"] = b6.URL // path validator b7 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Header().Add("Content-Type", "application/json") json.NewEncoder(rw).Encode(map[string]interface{}{"path": r.URL.Path, "random": 42}) })) data["b7"] = b7.URL // redirect b8 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { http.Redirect(rw, r, b7.URL, http.StatusMovedPermanently) })) data["b8"] = b8.URL // found b9 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { http.Redirect(rw, r, b7.URL, http.StatusFound) })) data["b9"] = b9.URL // X-Forwarded-For backend b11 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Header().Add("Content-Type", "application/json") json.NewEncoder(rw).Encode(map[string]interface{}{ "headers": r.Header, }) })) data["b11"] = b11.URL // Echo backend b12 := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Header().Add("Content-Type", "application/json") b, _ := io.ReadAll(r.Body) r.Body.Close() json.NewEncoder(rw).Encode(map[string]interface{}{ "headers": r.Header, "body": string(b), "url": r.URL.String(), "method": r.Method, }) })) data["b12"] = b12.URL c, err := loadConfig(data) if err != nil { return nil, err } return c, nil } func loadConfig(data map[string]interface{}) (*config.ServiceConfig, error) { content, _ := os.ReadFile("lura.json") tmpl, err := template.New("test").Parse(string(content)) if err != nil { return nil, err } buf := new(bytes.Buffer) if err = tmpl.Execute(buf, data); err != nil { return nil, err } c, err := config.NewParserWithFileReader(func(s string) ([]byte, error) { return []byte(s), nil }).Parse(buf.String()) if err != nil { return nil, err } return &c, nil } ================================================ FILE: transport/http/client/executor.go ================================================ // SPDX-License-Identifier: 
Apache-2.0 /* Package client provides some http helpers to create http clients and executors */ package client import ( "context" "net/http" ) // HTTPRequestExecutor defines the interface of the request executor for the HTTP transport protocol type HTTPRequestExecutor func(ctx context.Context, req *http.Request) (*http.Response, error) // DefaultHTTPRequestExecutor creates a HTTPRequestExecutor with the received HTTPClientFactory func DefaultHTTPRequestExecutor(clientFactory HTTPClientFactory) HTTPRequestExecutor { return func(ctx context.Context, req *http.Request) (*http.Response, error) { return clientFactory(ctx).Do(req.WithContext(ctx)) } } // HTTPClientFactory creates http clients based with the received context type HTTPClientFactory func(ctx context.Context) *http.Client // NewHTTPClient just returns the http default client func NewHTTPClient(_ context.Context) *http.Client { return defaultHTTPClient } var defaultHTTPClient = &http.Client{} ================================================ FILE: transport/http/client/executor_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package client import ( "bytes" "context" "fmt" "io" "net/http" "net/http/httptest" "testing" ) func TestDefaultHTTPRequestExecutor(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, client") })) defer ts.Close() re := DefaultHTTPRequestExecutor(NewHTTPClient) req, _ := http.NewRequest("GET", ts.URL, io.NopCloser(&bytes.Buffer{})) resp, err := re(context.Background(), req) if err != nil { t.Error("unexpected error:", err.Error()) } if resp.StatusCode != http.StatusOK { t.Error("unexpected status code:", resp.StatusCode) } } ================================================ FILE: transport/http/client/graphql/graphql.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package graphql offers a param extractor and basic types 
for building GraphQL requests */ package graphql import ( "encoding/json" "errors" "io" "net/http" "net/url" "os" "strings" "github.com/luraproject/lura/v2/config" "golang.org/x/text/cases" "golang.org/x/text/language" ) // Namespace is the key for the backend's extra config const Namespace = "github.com/devopsfaith/krakend/transport/http/client/graphql" // OperationType contains all the operations allowed by graphql type OperationType string // OperationMethod details the method to be used with the request type OperationMethod string const ( // OperationMutation marks an operation as a mutation OperationMutation OperationType = "mutation" // OperationQuery marks an operation as a query OperationQuery OperationType = "query" MethodPost OperationMethod = http.MethodPost MethodGet OperationMethod = http.MethodGet ) // GraphQLRequest represents the graphql request body type GraphQLRequest struct { Query string `json:"query"` OperationName string `json:"operationName,omitempty"` Variables map[string]interface{} `json:"variables,omitempty"` } // Options defines a GraphQLRequest with a type, so the middlewares know what to do type Options struct { GraphQLRequest QueryPath string `json:"query_path,omitempty"` Type OperationType `json:"type"` Method OperationMethod `json:"method"` } var ErrNoConfigFound = errors.New("grapghql: no configuration found") // GetOptions extracts the Options config from the backend's extra config func GetOptions(cfg config.ExtraConfig) (*Options, error) { tmp, ok := cfg[Namespace] if !ok { return nil, ErrNoConfigFound } b, err := json.Marshal(tmp) if err != nil { return nil, err } var opt Options if err := json.Unmarshal(b, &opt); err != nil { return nil, err } opt.Method = OperationMethod(strings.ToUpper(string(opt.Method))) opt.Type = OperationType(strings.ToLower(string(opt.Type))) if opt.Method != MethodGet && opt.Method != MethodPost { opt.Method = MethodPost } if opt.QueryPath != "" { q, err := os.ReadFile(opt.QueryPath) if err != nil { 
return nil, err } opt.Query = string(q) } return &opt, nil } // New resturns a new Extractor, ready to be use on a middleware func New(opt Options) *Extractor { var replacements [][2]string title := cases.Title(language.Und) for k, v := range opt.Variables { val, ok := v.(string) if !ok { continue } if val[0] == '{' && val[len(val)-1] == '}' { replacements = append(replacements, [2]string{k, title.String(val[1:2]) + val[2:len(val)-1]}) } } if len(replacements) == 0 { b, _ := json.Marshal(opt.GraphQLRequest) return &Extractor{ cfg: opt, paramExtractor: func(map[string]string) (*GraphQLRequest, error) { return &opt.GraphQLRequest, nil }, newBody: func(_ map[string]string) ([]byte, error) { return b, nil }, } } paramExtractor := func(params map[string]string) (*GraphQLRequest, error) { val := GraphQLRequest{ Query: opt.Query, OperationName: opt.OperationName, Variables: map[string]interface{}{}, } for k, v := range opt.Variables { val.Variables[k] = v } for _, vs := range replacements { val.Variables[vs[0]] = params[vs[1]] } return &val, nil } return &Extractor{ cfg: opt, paramExtractor: paramExtractor, newBody: func(params map[string]string) ([]byte, error) { val, err := paramExtractor(params) if err != nil { return []byte{}, err } return json.Marshal(val) }, } } // Extractor exposes two extractor factories: one for the params (query) and one // for the request body (mutator) type Extractor struct { cfg Options paramExtractor func(map[string]string) (*GraphQLRequest, error) newBody func(map[string]string) ([]byte, error) } // QueryFromBody returns a url.Values containing the graphql request with the given query and the default variables // overiden by the request body func (e *Extractor) QueryFromBody(r io.Reader) (url.Values, error) { gr, err := e.fromBody(r) if err != nil { return nil, err } vars := url.Values{} vars.Add("query", gr.Query) if gr.OperationName != "" { vars.Add("operationName", gr.OperationName) } if len(gr.Variables) != 0 { encodedVars, _ := 
json.Marshal(gr.Variables) vars.Add("variables", string(encodedVars)) } return vars, nil } // BodyFromBody returns a request body containing the graphql request with the given query and the default variables // overiden by the request body func (e *Extractor) BodyFromBody(r io.Reader) ([]byte, error) { v, err := e.fromBody(r) if err != nil { return []byte{}, err } return json.Marshal(v) } func (e *Extractor) fromBody(r io.Reader) (*GraphQLRequest, error) { b, err := io.ReadAll(r) if err != nil { return nil, err } vars := map[string]interface{}{} if err := json.Unmarshal(b, &vars); err != nil { return nil, err } for k, v := range e.cfg.Variables { if _, ok := vars[k]; ok { continue } vars[k] = v } return &GraphQLRequest{ Query: e.cfg.Query, OperationName: e.cfg.OperationName, Variables: vars, }, nil } // QueryFromParams returns a url.Values containing the grapql request generated for the given query and the default // variables overiden by the request params func (e *Extractor) QueryFromParams(params map[string]string) (url.Values, error) { gr, err := e.paramExtractor(params) if err != nil { return nil, err } vars := url.Values{} vars.Add("query", gr.Query) if gr.OperationName != "" { vars.Add("operationName", gr.OperationName) } if len(gr.Variables) != 0 { encodedVars, _ := json.Marshal(gr.Variables) vars.Add("variables", string(encodedVars)) } return vars, nil } // BodyFromParams returns a request body containing the grapql request generated for the given query and the default // variables overiden by the request params func (e *Extractor) BodyFromParams(params map[string]string) ([]byte, error) { return e.newBody(params) } ================================================ FILE: transport/http/client/graphql/graphql_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package graphql import ( "fmt" "os" "strings" "github.com/luraproject/lura/v2/config" ) func ExampleExtractor() { cfg, err := GetOptions(config.ExtraConfig{ 
Namespace: map[string]interface{}{ "type": OperationQuery, "query": "{\n find_follower(func: uid(\"0x3\")) {\n name \n }\n }\n", "variables": map[string]interface{}{ "foo": "{foo}", "bar": "1234abc", }, }, }) if err != nil { fmt.Println(err) return } extractor := New(*cfg) { fmt.Println("BodyFromParams") body, err := extractor.BodyFromParams(map[string]string{ "Foo": "foobar", }) if err != nil { fmt.Println(err) return } fmt.Println(string(body)) } { fmt.Println("QueryFromParams") query, err := extractor.QueryFromParams(map[string]string{ "Foo": "foobar", }) if err != nil { fmt.Println(err) return } fmt.Println(query.Encode()) } { fmt.Println("BodyFromBody") body, err := extractor.BodyFromBody(strings.NewReader(`{ "foo": "foobar", "foo1": "foobar" }`)) if err != nil { fmt.Println(err) return } fmt.Println(string(body)) } { fmt.Println("QueryFromBody") query, err := extractor.QueryFromBody(strings.NewReader(`{ "foo": "foobar", "foo1": "foobar" }`)) if err != nil { fmt.Println(err) return } fmt.Println(query.Encode()) } // output: // BodyFromParams // {"query":"{\n find_follower(func: uid(\"0x3\")) {\n name \n }\n }\n","variables":{"bar":"1234abc","foo":"foobar"}} // QueryFromParams // query=%7B%0A++find_follower%28func%3A+uid%28%220x3%22%29%29+%7B%0A++++name+%0A++++%7D%0A++%7D%0A&variables=%7B%22bar%22%3A%221234abc%22%2C%22foo%22%3A%22foobar%22%7D // BodyFromBody // {"query":"{\n find_follower(func: uid(\"0x3\")) {\n name \n }\n }\n","variables":{"bar":"1234abc","foo":"foobar","foo1":"foobar"}} // QueryFromBody // query=%7B%0A++find_follower%28func%3A+uid%28%220x3%22%29%29+%7B%0A++++name+%0A++++%7D%0A++%7D%0A&variables=%7B%22bar%22%3A%221234abc%22%2C%22foo%22%3A%22foobar%22%2C%22foo1%22%3A%22foobar%22%7D } func ExampleExtractor_fromFile() { os.WriteFile(".graphql_query.txt", []byte("{\n find_follower(func: uid(\"0x3\")) {\n name \n }\n }\n"), 0664) defer os.Remove(".graphql_query.txt") cfg, err := GetOptions(config.ExtraConfig{ Namespace: map[string]interface{}{ 
"type": OperationQuery, "query_path": ".graphql_query.txt", "variables": map[string]interface{}{ "foo": "{foo}", "bar": "1234abc", }, }, }) if err != nil { fmt.Println(err) return } extractor := New(*cfg) { fmt.Println("BodyFromParams") body, err := extractor.BodyFromParams(map[string]string{ "Foo": "foobar", }) if err != nil { fmt.Println(err) return } fmt.Println(string(body)) } { fmt.Println("QueryFromParams") query, err := extractor.QueryFromParams(map[string]string{ "Foo": "foobar", }) if err != nil { fmt.Println(err) return } fmt.Println(query.Encode()) } { fmt.Println("BodyFromBody") body, err := extractor.BodyFromBody(strings.NewReader(`{ "foo": "foobar", "foo1": "foobar" }`)) if err != nil { fmt.Println(err) return } fmt.Println(string(body)) } { fmt.Println("QueryFromBody") query, err := extractor.QueryFromBody(strings.NewReader(`{ "foo": "foobar", "foo1": "foobar" }`)) if err != nil { fmt.Println(err) return } fmt.Println(query.Encode()) } // output: // BodyFromParams // {"query":"{\n find_follower(func: uid(\"0x3\")) {\n name \n }\n }\n","variables":{"bar":"1234abc","foo":"foobar"}} // QueryFromParams // query=%7B%0A++find_follower%28func%3A+uid%28%220x3%22%29%29+%7B%0A++++name+%0A++++%7D%0A++%7D%0A&variables=%7B%22bar%22%3A%221234abc%22%2C%22foo%22%3A%22foobar%22%7D // BodyFromBody // {"query":"{\n find_follower(func: uid(\"0x3\")) {\n name \n }\n }\n","variables":{"bar":"1234abc","foo":"foobar","foo1":"foobar"}} // QueryFromBody // query=%7B%0A++find_follower%28func%3A+uid%28%220x3%22%29%29+%7B%0A++++name+%0A++++%7D%0A++%7D%0A&variables=%7B%22bar%22%3A%221234abc%22%2C%22foo%22%3A%22foobar%22%2C%22foo1%22%3A%22foobar%22%7D } func ExampleExtractor_noReplacement() { cfg, err := GetOptions(config.ExtraConfig{ Namespace: map[string]interface{}{ "type": OperationQuery, "query": "{\n find_follower(func: uid(\"0x3\")) {\n name \n }\n }\n", "variables": map[string]interface{}{ "bar": "1234abc", }, }, }) if err != nil { fmt.Println(err) return } extractor := 
New(*cfg) { fmt.Println("BodyFromParams") body, err := extractor.BodyFromParams(map[string]string{ "Foo": "foobar", }) if err != nil { fmt.Println(err) return } fmt.Println(string(body)) } { fmt.Println("QueryFromParams") query, err := extractor.QueryFromParams(map[string]string{ "Foo": "foobar", }) if err != nil { fmt.Println(err) return } fmt.Println(query.Encode()) } { fmt.Println("BodyFromBody") body, err := extractor.BodyFromBody(strings.NewReader(`{ "foo": "foobar", "foo1": "foobar" }`)) if err != nil { fmt.Println(err) return } fmt.Println(string(body)) } { fmt.Println("QueryFromBody") query, err := extractor.QueryFromBody(strings.NewReader(`{ "foo": "foobar", "foo1": "foobar" }`)) if err != nil { fmt.Println(err) return } fmt.Println(query.Encode()) } // output: // BodyFromParams // {"query":"{\n find_follower(func: uid(\"0x3\")) {\n name \n }\n }\n","variables":{"bar":"1234abc"}} // QueryFromParams // query=%7B%0A++find_follower%28func%3A+uid%28%220x3%22%29%29+%7B%0A++++name+%0A++++%7D%0A++%7D%0A&variables=%7B%22bar%22%3A%221234abc%22%7D // BodyFromBody // {"query":"{\n find_follower(func: uid(\"0x3\")) {\n name \n }\n }\n","variables":{"bar":"1234abc","foo":"foobar","foo1":"foobar"}} // QueryFromBody // query=%7B%0A++find_follower%28func%3A+uid%28%220x3%22%29%29+%7B%0A++++name+%0A++++%7D%0A++%7D%0A&variables=%7B%22bar%22%3A%221234abc%22%2C%22foo%22%3A%22foobar%22%2C%22foo1%22%3A%22foobar%22%7D } ================================================ FILE: transport/http/client/plugin/doc.go ================================================ // SPDX-License-Identifier: Apache-2.0 //Package plugin provides plugin register interfaces for building http client plugins. // // Usage example: // // package main // // import ( // "context" // "errors" // "fmt" // "html" // "net/http" // ) // // // ClientRegisterer is the symbol the plugin loader will try to load. 
It must implement the RegisterClient interface // var ClientRegisterer = registerer("lura-example") // // type registerer string // // func (r registerer) RegisterClients(f func( // name string, // handler func(context.Context, map[string]interface{}) (http.Handler, error), // )) { // f(string(r), r.registerClients) // } // // func (r registerer) registerClients(ctx context.Context, extra map[string]interface{}) (http.Handler, error) { // // check the passed configuration and initialize the plugin // name, ok := extra["name"].(string) // if !ok { // return nil, errors.New("wrong config") // } // if name != string(r) { // return nil, fmt.Errorf("unknown register %s", name) // } // // return the actual handler wrapping or your custom logic so it can be used as a replacement for the default http client // return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { // fmt.Fprintf(w, "Hello, %q", html.EscapeString(req.URL.Path)) // }), nil // } // // func init() { // fmt.Println("lura-example client plugin loaded!!!") // } // // func main() {} package plugin ================================================ FILE: transport/http/client/plugin/executor.go ================================================ // SPDX-License-Identifier: Apache-2.0 package plugin import ( "context" "fmt" "net/http" "net/http/httptest" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/transport/http/client" ) const Namespace = "github.com/devopsfaith/krakend/transport/http/client/executor" func HTTPRequestExecutor( logger logging.Logger, next func(*config.Backend) client.HTTPRequestExecutor, ) func(*config.Backend) client.HTTPRequestExecutor { return HTTPRequestExecutorWithContext(context.Background(), logger, next) } func HTTPRequestExecutorWithContext( ctx context.Context, logger logging.Logger, next func(*config.Backend) client.HTTPRequestExecutor, ) func(*config.Backend) client.HTTPRequestExecutor { return func(cfg 
*config.Backend) client.HTTPRequestExecutor { logPrefix := fmt.Sprintf("[BACKEND: %s %s -> %s]", cfg.ParentEndpointMethod, cfg.ParentEndpoint, cfg.URLPattern) v, ok := cfg.ExtraConfig[Namespace] if !ok { return next(cfg) } extra, ok := v.(map[string]interface{}) if !ok { logger.Debug(logPrefix, "["+Namespace+"]", "Wrong extra config type for backend") return next(cfg) } // load plugin r, ok := clientRegister.Get(Namespace) if !ok { logger.Debug(logPrefix, "No plugins registered for the module") return next(cfg) } name, ok := extra["name"].(string) if !ok { logger.Debug(logPrefix, "No name defined in the extra config for", cfg.URLPattern) return next(cfg) } rawHf, ok := r.Get(name) if !ok { logger.Debug(logPrefix, "No plugin registered as", name) return next(cfg) } hf, ok := rawHf.(func(context.Context, map[string]interface{}) (http.Handler, error)) if !ok { logger.Warning(logPrefix, "Wrong plugin handler type:", name) return next(cfg) } handler, err := hf(ctx, extra) if err != nil { logger.Warning(logPrefix, "Error getting the plugin handler:", err.Error()) return next(cfg) } logger.Debug(logPrefix, "Injecting plugin", name) return func(ctx context.Context, req *http.Request) (*http.Response, error) { w := httptest.NewRecorder() handler.ServeHTTP(w, req.WithContext(ctx)) return w.Result(), nil } } } ================================================ FILE: transport/http/client/plugin/plugin.go ================================================ // SPDX-License-Identifier: Apache-2.0 package plugin import ( "context" "fmt" "net/http" "plugin" "strings" "github.com/luraproject/lura/v2/logging" luraplugin "github.com/luraproject/lura/v2/plugin" "github.com/luraproject/lura/v2/register" ) var clientRegister = register.New() func RegisterClient( name string, handler func(context.Context, map[string]interface{}) (http.Handler, error), ) { clientRegister.Register(Namespace, name, handler) } type Registerer interface { RegisterClients(func( name string, handler 
func(context.Context, map[string]interface{}) (http.Handler, error), )) } type LoggerRegisterer interface { RegisterLogger(interface{}) } type RegisterClientFunc func( name string, handler func(context.Context, map[string]interface{}) (http.Handler, error), ) func Load(path, pattern string, rcf RegisterClientFunc) (int, error) { return LoadWithLogger(path, pattern, rcf, nil) } func LoadWithLogger(path, pattern string, rcf RegisterClientFunc, logger logging.Logger) (int, error) { plugins, err := luraplugin.Scan(path, pattern) if err != nil { return 0, err } return load(plugins, rcf, logger) } func load(plugins []string, rcf RegisterClientFunc, logger logging.Logger) (int, error) { var errors []error loadedPlugins := 0 for k, pluginName := range plugins { if err := open(pluginName, rcf, logger); err != nil { errors = append(errors, fmt.Errorf("plugin #%d (%s): %s", k, pluginName, err.Error())) continue } loadedPlugins++ } if len(errors) > 0 { return loadedPlugins, loaderError{errors: errors} } return loadedPlugins, nil } func open(pluginName string, rcf RegisterClientFunc, logger logging.Logger) (err error) { defer func() { if r := recover(); r != nil { var ok bool err, ok = r.(error) if !ok { err = fmt.Errorf("%v", r) } } }() var p Plugin p, err = pluginOpener(pluginName) if err != nil { return } var r interface{} r, err = p.Lookup("ClientRegisterer") if err != nil { return } registerer, ok := r.(Registerer) if !ok { return fmt.Errorf("http-request-executor plugin loader: unknown type") } if logger != nil { if lr, ok := r.(LoggerRegisterer); ok { lr.RegisterLogger(logger) } } RegisterExtraComponents(r) registerer.RegisterClients(rcf) return } var RegisterExtraComponents = func(interface{}) {} // Plugin is the interface of the loaded plugins type Plugin interface { Lookup(name string) (plugin.Symbol, error) } // pluginOpener keeps the plugin open function in a var for easy testing var pluginOpener = defaultPluginOpener func defaultPluginOpener(name string) (Plugin, 
error) { return plugin.Open(name) } type loaderError struct { errors []error } // Error implements the error interface func (l loaderError) Error() string { msgs := make([]string, len(l.errors)) for i, err := range l.errors { msgs[i] = err.Error() } return fmt.Sprintf("plugin loader found %d error(s): \n%s", len(msgs), strings.Join(msgs, "\n")) } func (l loaderError) Len() int { return len(l.errors) } func (l loaderError) Errs() []error { return l.errors } ================================================ FILE: transport/http/client/plugin/plugin_test.go ================================================ //go:build integration || !race // +build integration !race // SPDX-License-Identifier: Apache-2.0 package plugin import ( "bytes" "context" "fmt" "io" "net/http" "testing" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "github.com/luraproject/lura/v2/transport/http/client" ) func TestLoadWithLogger(t *testing.T) { buff := new(bytes.Buffer) l, _ := logging.NewLogger("DEBUG", buff, "") total, err := LoadWithLogger("./tests", ".so", RegisterClient, l) if err != nil { t.Error(err.Error()) t.Fail() } if total != 1 { t.Errorf("unexpected number of loaded plugins!. 
have %d, want 1", total) } hre := HTTPRequestExecutor(l, func(_ *config.Backend) client.HTTPRequestExecutor { t.Error("this factory should not been called") t.Fail() return nil }) h := hre(&config.Backend{ ExtraConfig: map[string]interface{}{ Namespace: map[string]interface{}{ "name": "krakend-client-example", }, }, }) req, _ := http.NewRequest("GET", "http://some.example.tld/path", http.NoBody) resp, err := h(context.Background(), req) if err != nil { t.Errorf("unexpected error: %s", err.Error()) return } b, err := io.ReadAll(resp.Body) if err != nil { t.Error(err) return } resp.Body.Close() if string(b) != "Hello, \"/path\"" { t.Errorf("unexpected response body: %s", string(b)) } fmt.Println(buff.String()) } ================================================ FILE: transport/http/client/plugin/tests/main.go ================================================ // SPDX-License-Identifier: Apache-2.0 package main import ( "context" "errors" "fmt" "html" "net/http" ) // ClientRegisterer is the symbol the plugin loader will try to load. 
It must implement the RegisterClient interface var ClientRegisterer = registerer("krakend-client-example") type registerer string var logger Logger = nil func (registerer) RegisterLogger(v interface{}) { l, ok := v.(Logger) if !ok { return } logger = l logger.Debug(fmt.Sprintf("[PLUGIN: %s] Logger loaded", ClientRegisterer)) } func (r registerer) RegisterClients(f func( name string, handler func(context.Context, map[string]interface{}) (http.Handler, error), )) { f(string(r), r.registerClients) } func (r registerer) registerClients(_ context.Context, extra map[string]interface{}) (http.Handler, error) { // check the passed configuration and initialize the plugin name, ok := extra["name"].(string) if !ok { return nil, errors.New("wrong config") } if name != string(r) { return nil, fmt.Errorf("unknown register %s", name) } if logger == nil { // return the actual handler wrapping or your custom logic so it can be used as a replacement for the default http client return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { fmt.Fprintf(w, "Hello, %q", html.EscapeString(req.URL.Path)) }), nil } // return the actual handler wrapping or your custom logic so it can be used as a replacement for the default http client return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { fmt.Fprintf(w, "Hello, %q", html.EscapeString(req.URL.Path)) logger.Debug("request:", html.EscapeString(req.URL.Path)) }), nil } func main() {} type Logger interface { Debug(v ...interface{}) Info(v ...interface{}) Warning(v ...interface{}) Error(v ...interface{}) Critical(v ...interface{}) Fatal(v ...interface{}) } ================================================ FILE: transport/http/client/status.go ================================================ // SPDX-License-Identifier: Apache-2.0 package client import ( "bytes" "context" "errors" "fmt" "io" "net/http" "github.com/luraproject/lura/v2/config" ) // Namespace to be used in extra config const Namespace = 
"github.com/devopsfaith/krakend/http" // ErrInvalidStatusCode is the error returned by the http proxy when the received status code // is not a 200 nor a 201 var ErrInvalidStatusCode = errors.New("invalid status code") type ErrInvalidStatus struct { statusCode int errPrefix string path string } func (e *ErrInvalidStatus) Error() string { return fmt.Sprintf("invalid status code %d %s %s", e.statusCode, e.errPrefix, e.path) } func NewErrInvalidStatusCode(resp *http.Response, errPrefix string) *ErrInvalidStatus { var p string if resp.Request != nil && resp.Request.URL != nil { p = resp.Request.URL.String() } return &ErrInvalidStatus{ statusCode: resp.StatusCode, errPrefix: errPrefix, path: p, } } // HTTPStatusHandler defines how we tread the http response code type HTTPStatusHandler func(context.Context, *http.Response) (*http.Response, error) // GetHTTPStatusHandler returns a status handler. If the 'return_error_details' key is defined // at the extra config, it returns a DetailedHTTPStatusHandler. 
Otherwise, it returns a // DefaultHTTPStatusHandler func GetHTTPStatusHandler(remote *config.Backend) HTTPStatusHandler { errPrefix := fmt.Sprintf("[%s %s]:", remote.Method, remote.URLPattern) if e, ok := remote.ExtraConfig[Namespace]; ok { if m, ok := e.(map[string]interface{}); ok { if v, ok := m["return_error_details"]; ok { if b, ok := v.(string); ok && b != "" { return DetailedHTTPStatusHandlerWithErrPrefix(b, errPrefix) } } else if v, ok := m["return_error_code"].(bool); ok && v { return ErrorHTTPStatusHandlerWithErrPrefix(errPrefix) } } } return DefaultHTTPStatusHandlerWithErrPrefix(errPrefix) } // DefaultHTTPStatusHandler is the default implementation of HTTPStatusHandler func DefaultHTTPStatusHandler(_ context.Context, resp *http.Response) (*http.Response, error) { if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { return nil, ErrInvalidStatusCode } return resp, nil } // DefaultHTTPStatusHandlerWithErrPrefix is the default implementation of HTTPStatusHandler // with information about the failing status code, and the failed request func DefaultHTTPStatusHandlerWithErrPrefix(errPrefix string) HTTPStatusHandler { return func(_ context.Context, resp *http.Response) (*http.Response, error) { if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { return nil, NewErrInvalidStatusCode(resp, errPrefix) } return resp, nil } } // ErrorHTTPStatusHandler is a HTTPStatusHandler that returns the status code as part of the error details func ErrorHTTPStatusHandler(ctx context.Context, resp *http.Response) (*http.Response, error) { if _, err := DefaultHTTPStatusHandler(ctx, resp); err == nil { return resp, nil } return resp, newHTTPResponseError(resp) } // ErrorHTTPStatusHandlerWithErrPrefix is a HTTPStatusHandler that returns the status code as part of the error details func ErrorHTTPStatusHandlerWithErrPrefix(errPrefix string) HTTPStatusHandler { defaultH := DefaultHTTPStatusHandlerWithErrPrefix(errPrefix) return 
func(ctx context.Context, resp *http.Response) (*http.Response, error) { if _, err := defaultH(ctx, resp); err == nil { return resp, nil } return resp, newHTTPResponseError(resp) } } // NoOpHTTPStatusHandler is a NO-OP implementation of HTTPStatusHandler func NoOpHTTPStatusHandler(_ context.Context, resp *http.Response) (*http.Response, error) { return resp, nil } // DetailedHTTPStatusHandler is a HTTPStatusHandler implementation func DetailedHTTPStatusHandler(name string) HTTPStatusHandler { return func(ctx context.Context, resp *http.Response) (*http.Response, error) { if _, err := DefaultHTTPStatusHandler(ctx, resp); err == nil { return resp, nil } return resp, NamedHTTPResponseError{ HTTPResponseError: newHTTPResponseError(resp), name: name, } } } // DetailedHTTPStatusHandlerWithErrPrefix is a HTTPStatusHandlers that // can receive an error prefix to be added when an error happens to help // identify the endpoint using this handler. func DetailedHTTPStatusHandlerWithErrPrefix(name, errPrefix string) HTTPStatusHandler { defaultH := DefaultHTTPStatusHandlerWithErrPrefix(errPrefix) return func(ctx context.Context, resp *http.Response) (*http.Response, error) { if _, err := defaultH(ctx, resp); err == nil { return resp, nil } return resp, NamedHTTPResponseError{ HTTPResponseError: newHTTPResponseError(resp), name: name, } } } func newHTTPResponseError(resp *http.Response) HTTPResponseError { body, err := io.ReadAll(resp.Body) if err != nil { body = []byte{} } resp.Body.Close() resp.Body = io.NopCloser(bytes.NewBuffer(body)) return HTTPResponseError{ Code: resp.StatusCode, Msg: string(body), Enc: resp.Header.Get("Content-Type"), } } // HTTPResponseError is the error to be returned by the ErrorHTTPStatusHandler type HTTPResponseError struct { Code int `json:"http_status_code"` Msg string `json:"http_body,omitempty"` Enc string `json:"http_body_encoding,omitempty"` } // Error returns the error message func (r HTTPResponseError) Error() string { return r.Msg } // 
StatusCode returns the status code returned by the backend func (r HTTPResponseError) StatusCode() int { return r.Code } // Encoding returns the content type returned by the backend func (r HTTPResponseError) Encoding() string { return r.Enc } // NamedHTTPResponseError is the error to be returned by the DetailedHTTPStatusHandler type NamedHTTPResponseError struct { HTTPResponseError name string } // Name returns the name of the backend where the error happened func (r NamedHTTPResponseError) Name() string { return r.name } ================================================ FILE: transport/http/client/status_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package client import ( "bytes" "context" "fmt" "io" "net/http" "strings" "testing" "github.com/luraproject/lura/v2/config" ) func TestDetailedHTTPStatusHandler(t *testing.T) { expectedErrName := "some" expectedEncoding := "application/json; charset=utf-8" cfg := &config.Backend{ ExtraConfig: config.ExtraConfig{ Namespace: map[string]interface{}{ "return_error_details": expectedErrName, }, }, } sh := GetHTTPStatusHandler(cfg) for _, code := range []int{http.StatusOK, http.StatusCreated} { resp := &http.Response{ StatusCode: code, Body: io.NopCloser(bytes.NewBufferString(`{"foo":"bar"}`)), } r, err := sh(context.Background(), resp) if r != resp { t.Errorf("#%d unexpected response: %v", code, r) return } if err != nil { t.Errorf("#%d unexpected error: %s", code, err.Error()) return } } for i, code := range statusCodes { msg := http.StatusText(code) resp := &http.Response{ StatusCode: code, Body: io.NopCloser(bytes.NewBufferString(fmt.Sprintf(`{"msg":%q}`, msg))), Header: http.Header{"Content-Type": []string{expectedEncoding}}, } r, err := sh(context.Background(), resp) if r != resp { t.Errorf("#%d unexpected response: %v", i, r) return } e, ok := err.(NamedHTTPResponseError) if !ok { t.Errorf("#%d unexpected error type %T: %s", i, err, err.Error()) return } if 
e.StatusCode() != code { t.Errorf("#%d unexpected status code: %d", i, e.Code) return } if e.Error() != fmt.Sprintf(`{"msg":%q}`, msg) { t.Errorf("#%d unexpected message: %s", i, e.Msg) return } if e.Name() != expectedErrName { t.Errorf("#%d unexpected error name: %s", i, e.name) return } if e.Encoding() != expectedEncoding { t.Errorf("#%d unexpected encoding: %s", i, e.Enc) } } } func TestDefaultHTTPStatusHandler(t *testing.T) { sh := GetHTTPStatusHandler(&config.Backend{}) for _, code := range []int{http.StatusOK, http.StatusCreated} { resp := &http.Response{ StatusCode: code, Body: io.NopCloser(bytes.NewBufferString(`{"foo":"bar"}`)), } r, err := sh(context.Background(), resp) if r != resp { t.Errorf("#%d unexpected response: %v", code, r) return } if err != nil { t.Errorf("#%d unexpected error: %s", code, err.Error()) return } } for _, code := range statusCodes { msg := http.StatusText(code) resp := &http.Response{ StatusCode: code, Body: io.NopCloser(bytes.NewBufferString(msg)), } r, err := sh(context.Background(), resp) if r != nil { t.Errorf("#%d unexpected response: %v", code, r) return } if !strings.HasPrefix(err.Error(), "invalid status code") { t.Errorf("#%d unexpected error: %v", code, err) return } } } var statusCodes = []int{ http.StatusBadRequest, http.StatusUnauthorized, http.StatusPaymentRequired, http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed, http.StatusNotAcceptable, http.StatusProxyAuthRequired, http.StatusRequestTimeout, http.StatusConflict, http.StatusGone, http.StatusLengthRequired, http.StatusPreconditionFailed, http.StatusRequestEntityTooLarge, http.StatusRequestURITooLong, http.StatusUnsupportedMediaType, http.StatusRequestedRangeNotSatisfiable, http.StatusExpectationFailed, http.StatusTeapot, // http.StatusMisdirectedRequest, http.StatusUnprocessableEntity, http.StatusLocked, http.StatusFailedDependency, http.StatusUpgradeRequired, http.StatusPreconditionRequired, http.StatusTooManyRequests, 
http.StatusRequestHeaderFieldsTooLarge, http.StatusUnavailableForLegalReasons, http.StatusInternalServerError, http.StatusNotImplemented, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusHTTPVersionNotSupported, http.StatusVariantAlsoNegotiates, http.StatusInsufficientStorage, http.StatusLoopDetected, http.StatusNotExtended, http.StatusNetworkAuthenticationRequired, } ================================================ FILE: transport/http/server/plugin/doc.go ================================================ // SPDX-License-Identifier: Apache-2.0 //Package plugin provides plugin register interfaces for building http handler plugins. // // Usage example: // // package main // // import ( // "context" // "errors" // "fmt" // "html" // "net/http" // ) // // // HandlerRegisterer is the symbol the plugin loader will try to load. It must implement the Registerer interface // var HandlerRegisterer = registerer("lura-example") // // type registerer string // // func (r registerer) RegisterHandlers(f func( // name string, // handler func(context.Context, map[string]interface{}, http.Handler) (http.Handler, error), // )) { // f(string(r), r.registerHandlers) // } // // func (r registerer) registerHandlers(ctx context.Context, extra map[string]interface{}, _ http.Handler) (http.Handler, error) { // // check the passed configuration and initialize the plugin // name, ok := extra["name"].(string) // if !ok { // return nil, errors.New("wrong config") // } // if name != string(r) { // return nil, fmt.Errorf("unknown register %s", name) // } // // return the actual handler wrapping or your custom logic so it can be used as a replacement for the default http handler // return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { // fmt.Fprintf(w, "Hello, %q", html.EscapeString(req.URL.Path)) // }), nil // } // // func init() { // fmt.Println("lura-example handler plugin loaded!!!") // } // // func main() {} package plugin 
================================================ FILE: transport/http/server/plugin/plugin.go ================================================ // SPDX-License-Identifier: Apache-2.0 package plugin import ( "context" "fmt" "net/http" "plugin" "strings" "github.com/luraproject/lura/v2/logging" luraplugin "github.com/luraproject/lura/v2/plugin" "github.com/luraproject/lura/v2/register" ) var serverRegister = register.New() func RegisterHandler( name string, handler func(context.Context, map[string]interface{}, http.Handler) (http.Handler, error), ) { serverRegister.Register(Namespace, name, handler) } type Registerer interface { RegisterHandlers(func( name string, handler func(context.Context, map[string]interface{}, http.Handler) (http.Handler, error), )) } type LoggerRegisterer interface { RegisterLogger(interface{}) } type RegisterHandlerFunc func( name string, handler func(context.Context, map[string]interface{}, http.Handler) (http.Handler, error), ) func Load(path, pattern string, rcf RegisterHandlerFunc) (int, error) { return LoadWithLogger(path, pattern, rcf, nil) } func LoadWithLogger(path, pattern string, rcf RegisterHandlerFunc, logger logging.Logger) (int, error) { plugins, err := luraplugin.Scan(path, pattern) if err != nil { return 0, err } return load(plugins, rcf, logger) } func load(plugins []string, rcf RegisterHandlerFunc, logger logging.Logger) (int, error) { var errors []error loadedPlugins := 0 for k, pluginName := range plugins { if err := open(pluginName, rcf, logger); err != nil { errors = append(errors, fmt.Errorf("plugin #%d (%s): %s", k, pluginName, err.Error())) continue } loadedPlugins++ } if len(errors) > 0 { return loadedPlugins, loaderError{errors: errors} } return loadedPlugins, nil } func open(pluginName string, rcf RegisterHandlerFunc, logger logging.Logger) (err error) { defer func() { if r := recover(); r != nil { var ok bool err, ok = r.(error) if !ok { err = fmt.Errorf("%v", r) } } }() var p Plugin p, err = 
pluginOpener(pluginName) if err != nil { return } var r interface{} r, err = p.Lookup("HandlerRegisterer") if err != nil { return } registerer, ok := r.(Registerer) if !ok { return fmt.Errorf("http-server-handler plugin loader: unknown type") } if logger != nil { if lr, ok := r.(LoggerRegisterer); ok { lr.RegisterLogger(logger) } } RegisterExtraComponents(r) registerer.RegisterHandlers(rcf) return } var RegisterExtraComponents = func(interface{}) {} // Plugin is the interface of the loaded plugins type Plugin interface { Lookup(name string) (plugin.Symbol, error) } // pluginOpener keeps the plugin open function in a var for easy testing var pluginOpener = defaultPluginOpener func defaultPluginOpener(name string) (Plugin, error) { return plugin.Open(name) } type loaderError struct { errors []error } // Error implements the error interface func (l loaderError) Error() string { msgs := make([]string, len(l.errors)) for i, err := range l.errors { msgs[i] = err.Error() } return fmt.Sprintf("plugin loader found %d error(s): \n%s", len(msgs), strings.Join(msgs, "\n")) } func (l loaderError) Len() int { return len(l.errors) } func (l loaderError) Errs() []error { return l.errors } ================================================ FILE: transport/http/server/plugin/plugin_test.go ================================================ //go:build integration || !race // +build integration !race // SPDX-License-Identifier: Apache-2.0 package plugin import ( "bytes" "context" "fmt" "io" "net/http" "net/http/httptest" "testing" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) func TestLoadWithLogger(t *testing.T) { buff := new(bytes.Buffer) l, _ := logging.NewLogger("DEBUG", buff, "") total, err := LoadWithLogger("./tests", ".so", RegisterHandler, l) if err != nil { t.Error(err.Error()) t.Fail() } if total != 1 { t.Errorf("unexpected number of loaded plugins!. 
have %d, want 1", total) } var handler http.Handler hre := New(l, func(_ context.Context, _ config.ServiceConfig, h http.Handler) error { handler = h return nil }) if err := hre( context.Background(), config.ServiceConfig{ ExtraConfig: map[string]interface{}{ Namespace: map[string]interface{}{ "name": "krakend-server-example", }, }, }, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Error("this handler should not been called") }), ); err != nil { t.Error(err) return } req, _ := http.NewRequest("GET", "http://some.example.tld/path", http.NoBody) w := httptest.NewRecorder() handler.ServeHTTP(w, req) resp := w.Result() if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } b, err := io.ReadAll(resp.Body) if err != nil { t.Error(err) return } resp.Body.Close() if string(b) != "Hello, \"/path\"" { t.Errorf("unexpected response body: %s", string(b)) } fmt.Println(buff.String()) } ================================================ FILE: transport/http/server/plugin/server.go ================================================ // SPDX-License-Identifier: Apache-2.0 package plugin import ( "context" "net/http" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" ) const Namespace = "github_com/devopsfaith/krakend/transport/http/server/handler" const logPrefix = "[PLUGIN: Server]" type RunServer func(context.Context, config.ServiceConfig, http.Handler) error func New(logger logging.Logger, next RunServer) RunServer { return func(ctx context.Context, cfg config.ServiceConfig, handler http.Handler) error { v, ok := cfg.ExtraConfig[Namespace] if !ok { return next(ctx, cfg, handler) } extra, ok := v.(map[string]interface{}) if !ok { logger.Debug(logPrefix, "Wrong extra_config type") return next(ctx, cfg, handler) } // load plugin(s) r, ok := serverRegister.Get(Namespace) if !ok { logger.Debug(logPrefix, "No plugins registered for the module") return next(ctx, cfg, handler) } name, nameOk := 
extra["name"].(string) fifoRaw, fifoOk := extra["name"].([]interface{}) if !nameOk && !fifoOk { logger.Debug(logPrefix, "No plugins required in the extra config") return next(ctx, cfg, handler) } var fifo []string if !fifoOk { fifo = []string{name} } else { for _, x := range fifoRaw { if v, ok := x.(string); ok { fifo = append(fifo, v) } } } for _, name := range fifo { rawHf, ok := r.Get(name) if !ok { logger.Error(logPrefix, "No plugin registered as", name) continue } hf, ok := rawHf.(func(context.Context, map[string]interface{}, http.Handler) (http.Handler, error)) if !ok { logger.Error(logPrefix, "Wrong plugin handler type:", name) continue } handlerWrapper, err := hf(ctx, extra, handler) if err != nil { logger.Error(logPrefix, "Error getting the plugin handler:", err.Error()) continue } logger.Info(logPrefix, "Injecting plugin", name) handler = handlerWrapper } return next(ctx, cfg, handler) } } ================================================ FILE: transport/http/server/plugin/tests/main.go ================================================ // SPDX-License-Identifier: Apache-2.0 package main import ( "context" "fmt" "html" "net/http" ) // HandlerRegisterer is the symbol the plugin loader will try to load. 
It must implement the Registerer interface var HandlerRegisterer = registerer("krakend-server-example") type registerer string var logger Logger = nil func (registerer) RegisterLogger(v interface{}) { l, ok := v.(Logger) if !ok { return } logger = l logger.Debug(fmt.Sprintf("[PLUGIN: %s] Logger loaded", HandlerRegisterer)) } func (r registerer) RegisterHandlers(f func( name string, handler func(context.Context, map[string]interface{}, http.Handler) (http.Handler, error), )) { f(string(r), r.registerHandlers) } func (registerer) registerHandlers(_ context.Context, _ map[string]interface{}, _ http.Handler) (http.Handler, error) { // check the passed configuration and initialize the plugin // possible config example: /* "extra_config":{ "plugin/http-server":{ "name":["krakend-server-example"], "krakend-server-example":{ "A":"foo", "B":42 } } } */ if logger == nil { // return the actual handler wrapping or your custom logic so it can be used as a replacement for the default http handler return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { fmt.Fprintf(w, "Hello, %q", html.EscapeString(req.URL.Path)) }), nil } // return the actual handler wrapping or your custom logic so it can be used as a replacement for the default http handler return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { fmt.Fprintf(w, "Hello, %q", html.EscapeString(req.URL.Path)) logger.Debug("request:", html.EscapeString(req.URL.Path)) }), nil } func main() {} type Logger interface { Debug(v ...interface{}) Info(v ...interface{}) Warning(v ...interface{}) Error(v ...interface{}) Critical(v ...interface{}) Fatal(v ...interface{}) } ================================================ FILE: transport/http/server/server.go ================================================ // SPDX-License-Identifier: Apache-2.0 /* Package server provides tools to create http servers and handlers wrapping the lura router */ package server import ( "context" "crypto/tls" "crypto/x509" "errors" 
"fmt" "net" "net/http" "os" "sync" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/core" "github.com/luraproject/lura/v2/logging" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) // ToHTTPError translates an error into a HTTP status code type ToHTTPError func(error) int // DefaultToHTTPError is a ToHTTPError transalator that always returns an // internal server error func DefaultToHTTPError(_ error) int { return http.StatusInternalServerError } const ( // HeaderCompleteResponseValue is the value of the CompleteResponseHeader // if the response is complete HeaderCompleteResponseValue = "true" // HeaderIncompleteResponseValue is the value of the CompleteResponseHeader // if the response is not complete HeaderIncompleteResponseValue = "false" ) var ( // CompleteResponseHeaderName is the header to flag incomplete responses to the client CompleteResponseHeaderName = "X-Krakend-Completed" // HeadersToSend are the headers to pass from the router request to the proxy HeadersToSend = []string{"Content-Type"} // UserAgentHeaderValue is the value of the User-Agent header to add to the proxy request UserAgentHeaderValue = []string{core.KrakendUserAgent} // ErrInternalError is the error returned by the router when something went wrong ErrInternalError = errors.New("internal server error") // ErrPrivateKey is the error returned by the router when the private key is not defined ErrPrivateKey = errors.New("private key not defined") // ErrPublicKey is the error returned by the router when the public key is not defined ErrPublicKey = errors.New("public key not defined") loggerPrefix = "[SERVICE: HTTP Server]" ) // InitHTTPDefaultTransport ensures the default HTTP transport is configured just once per execution func InitHTTPDefaultTransport(cfg config.ServiceConfig) { InitHTTPDefaultTransportWithLogger(cfg, nil) } func InitHTTPDefaultTransportWithLogger(cfg config.ServiceConfig, logger logging.Logger) { if logger == nil { logger = logging.NoOp } 
if cfg.AllowInsecureConnections { if cfg.ClientTLS == nil { cfg.ClientTLS = &config.ClientTLS{} } cfg.ClientTLS.AllowInsecureConnections = true } onceTransportConfig.Do(func() { http.DefaultTransport = NewTransport(cfg, logger) }) } func NewTransport(cfg config.ServiceConfig, logger logging.Logger) *http.Transport { return &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: cfg.DialerTimeout, KeepAlive: cfg.DialerKeepAlive, FallbackDelay: cfg.DialerFallbackDelay, DualStack: true, }).DialContext, DisableCompression: cfg.DisableCompression, DisableKeepAlives: cfg.DisableKeepAlives, MaxIdleConns: cfg.MaxIdleConns, MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost, IdleConnTimeout: cfg.IdleConnTimeout, ResponseHeaderTimeout: cfg.ResponseHeaderTimeout, ExpectContinueTimeout: cfg.ExpectContinueTimeout, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: ParseClientTLSConfigWithLogger(cfg.ClientTLS, logger), } } // RunServer runs a http.Server with the given handler and configuration. // It configures the TLS layer if required by the received configuration. 
func RunServer(ctx context.Context, cfg config.ServiceConfig, handler http.Handler) error { return RunServerWithLoggerFactory(nil)(ctx, cfg, handler) } func RunServerWithLoggerFactory(l logging.Logger) func(context.Context, config.ServiceConfig, http.Handler) error { return func(ctx context.Context, cfg config.ServiceConfig, handler http.Handler) error { done := make(chan error) s := NewServerWithLogger(cfg, handler, l) if s.TLSConfig == nil { go func() { done <- s.ListenAndServe() }() } else { if cfg.TLS.PublicKey != "" || cfg.TLS.PrivateKey != "" { cfg.TLS.Keys = append(cfg.TLS.Keys, config.TLSKeyPair{ PublicKey: cfg.TLS.PublicKey, PrivateKey: cfg.TLS.PrivateKey, }) } if len(cfg.TLS.Keys) == 0 { return ErrPublicKey } for _, k := range cfg.TLS.Keys { if k.PublicKey == "" { return ErrPublicKey } if k.PrivateKey == "" { return ErrPrivateKey } cert, err := tls.LoadX509KeyPair(k.PublicKey, k.PrivateKey) if err != nil { return err } s.TLSConfig.Certificates = append(s.TLSConfig.Certificates, cert) } go func() { // since we already use the list of certificates in the config // we do not need to specify the files for public and private key here done <- s.ListenAndServeTLS("", "") }() } select { case err := <-done: return err case <-ctx.Done(): if cfg.MaxShutdownDuration <= 0 { return s.Shutdown(context.Background()) } withTimeout, cancel := context.WithTimeout(context.Background(), cfg.MaxShutdownDuration) defer cancel() return s.Shutdown(withTimeout) } } } // NewServer returns a http.Server ready to serve the injected handler func NewServer(cfg config.ServiceConfig, handler http.Handler) *http.Server { return NewServerWithLogger(cfg, handler, nil) } func NewServerWithLogger(cfg config.ServiceConfig, handler http.Handler, logger logging.Logger) *http.Server { if cfg.UseH2C { handler = h2c.NewHandler(handler, &http2.Server{}) } return &http.Server{ Addr: net.JoinHostPort(cfg.Address, fmt.Sprintf("%d", cfg.Port)), Handler: handler, ReadTimeout: cfg.ReadTimeout, 
WriteTimeout: cfg.WriteTimeout, ReadHeaderTimeout: cfg.ReadHeaderTimeout, IdleTimeout: cfg.IdleTimeout, TLSConfig: ParseTLSConfigWithLogger(cfg.TLS, logger), MaxHeaderBytes: cfg.MaxHeaderBytes, } } // ParseTLSConfig creates a tls.Config from the TLS section of the service configuration func ParseTLSConfig(cfg *config.TLS) *tls.Config { return ParseTLSConfigWithLogger(cfg, nil) } func ParseTLSConfigWithLogger(cfg *config.TLS, logger logging.Logger) *tls.Config { if cfg == nil { return nil } if cfg.IsDisabled { return nil } if logger == nil { logger = logging.NoOp } tlsConfig := &tls.Config{ MinVersion: parseTLSVersion(cfg.MinVersion), MaxVersion: parseTLSVersion(cfg.MaxVersion), CurvePreferences: parseCurveIDs(cfg.CurvePreferences), CipherSuites: parseCipherSuites(cfg.CipherSuites), } if !cfg.EnableMTLS { return tlsConfig } certPool := loadCertPool(cfg.DisableSystemCaPool, cfg.CaCerts, logger) for _, cert := range cfg.Keys { caCert, err := os.ReadFile(cert.PublicKey) if err != nil { logger.Error(fmt.Sprintf("%s Cannot load public key %s: %s", loggerPrefix, cert.PublicKey, err.Error())) continue } certPool.AppendCertsFromPEM(caCert) } tlsConfig.ClientCAs = certPool tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert return tlsConfig } func ParseClientTLSConfigWithLogger(cfg *config.ClientTLS, logger logging.Logger) *tls.Config { if cfg == nil { return nil } return &tls.Config{ InsecureSkipVerify: cfg.AllowInsecureConnections, RootCAs: loadCertPool(cfg.DisableSystemCaPool, cfg.CaCerts, logger), MinVersion: parseTLSVersion(cfg.MinVersion), MaxVersion: parseTLSVersion(cfg.MaxVersion), CurvePreferences: parseCurveIDs(cfg.CurvePreferences), CipherSuites: parseCipherSuites(cfg.CipherSuites), Certificates: loadClientCerts(cfg.ClientCerts, logger), } } func loadCertPool(disableSystemCaPool bool, caCerts []string, logger logging.Logger) *x509.CertPool { certPool := x509.NewCertPool() if !disableSystemCaPool { if systemCertPool, err := x509.SystemCertPool(); err == nil { 
certPool = systemCertPool } else { logger.Error(fmt.Sprintf("%s Cannot load system CA pool: %s", loggerPrefix, err.Error())) } } for _, path := range caCerts { if ca, err := os.ReadFile(path); err == nil { certPool.AppendCertsFromPEM(ca) } else { logger.Error(fmt.Sprintf("%s Cannot load certificate CA %s: %s", loggerPrefix, path, err.Error())) } } return certPool } func loadClientCerts(certFiles []config.ClientTLSCert, logger logging.Logger) []tls.Certificate { certs := make([]tls.Certificate, 0, len(certFiles)) for _, certAndKey := range certFiles { cert, err := tls.LoadX509KeyPair(certAndKey.Certificate, certAndKey.PrivateKey) if err != nil { logger.Error(fmt.Sprintf("%s Cannot load client certificate %s, %s: %s", loggerPrefix, certAndKey.Certificate, certAndKey.PrivateKey, err.Error())) continue } certs = append(certs, cert) } return certs } func parseTLSVersion(key string) uint16 { if v, ok := versions[key]; ok { return v } return tls.VersionTLS13 } func parseCurveIDs(curvePreferences []uint16) []tls.CurveID { l := len(curvePreferences) if l == 0 { return defaultCurves } curves := make([]tls.CurveID, len(curvePreferences)) for i := range curves { curves[i] = tls.CurveID(curvePreferences[i]) } return curves } func parseCipherSuites(cipherSuites []uint16) []uint16 { l := len(cipherSuites) if l == 0 { return defaultCipherSuites } cs := make([]uint16, l) for i := range cs { cs[i] = uint16(cipherSuites[i]) } return cs } var ( onceTransportConfig sync.Once defaultCurves = []tls.CurveID{ tls.CurveP521, tls.CurveP384, tls.CurveP256, } defaultCipherSuites = []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, } versions = map[string]uint16{ "SSL3.0": tls.VersionSSL30, "TLS10": tls.VersionTLS10, "TLS11": tls.VersionTLS11, "TLS12": tls.VersionTLS12, 
"TLS13": tls.VersionTLS13, } ) ================================================ FILE: transport/http/server/server_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package server import ( "context" "crypto/tls" "crypto/x509" "errors" "fmt" "html" "log" "math/rand" "net" "net/http" "os" "strings" "testing" "time" "github.com/luraproject/lura/v2/config" "github.com/luraproject/lura/v2/logging" "golang.org/x/net/http2" ) func init() { rand.Seed(time.Now().Unix()) } func TestRunServer_TLS(t *testing.T) { testKeysAreAvailable(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() port := newPort() done := make(chan error) go func() { done <- RunServer( ctx, config.ServiceConfig{ Port: port, TLS: &config.TLS{ PublicKey: "cert.pem", PrivateKey: "key.pem", CaCerts: []string{"ca.pem"}, }, }, http.HandlerFunc(dummyHandler), ) }() client, err := httpsClient("cert.pem") if err != nil { t.Error(err) return } <-time.After(100 * time.Millisecond) resp, err := client.Get(fmt.Sprintf("https://localhost:%d", port)) if err != nil { t.Error(err) return } if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } // now lets initialize the global default transport and use a regular // client to connect to the server InitHTTPDefaultTransport(config.ServiceConfig{ ClientTLS: &config.ClientTLS{ CaCerts: []string{"ca.pem"}, DisableSystemCaPool: true, }, }) rawClient := http.Client{} resp, err = rawClient.Get(fmt.Sprintf("https://localhost:%d", port)) if err != nil { t.Error(err) return } if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } cancel() if err = <-done; err != nil { t.Error(err) } } func TestRunServer_MTLS(t *testing.T) { testKeysAreAvailable(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() port := 36517 done := make(chan error) cfg := config.ServiceConfig{ Port: port, TLS: &config.TLS{ Keys: []config.TLSKeyPair{ { 
PublicKey: "cert.pem", PrivateKey: "key.pem", }, }, CaCerts: []string{"ca.pem"}, EnableMTLS: true, }, ClientTLS: &config.ClientTLS{ AllowInsecureConnections: false, // we do not check the server cert CaCerts: []string{"ca.pem"}, ClientCerts: []config.ClientTLSCert{ { Certificate: "cert.pem", PrivateKey: "key.pem", }, }, }, } go func() { done <- RunServer(ctx, cfg, http.HandlerFunc(dummyHandler)) }() client, err := mtlsClient("cert.pem", "key.pem") if err != nil { t.Error(err) return } <-time.After(1000 * time.Millisecond) resp, err := client.Get(fmt.Sprintf("https://localhost:%d", port)) if err != nil { t.Error(err) return } if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } logger := logging.NoOp // since test are run in a suite, and `InitHTTPDefaultTransportWithLogger` is // used to setup the `http.DefaultTransport` global variable once, we need to // create a client here like if it was using the default created with the // clientTLS config. // This is a copy of the code we can find inside // InitHTTPDefaultTransportWithLogger(serviceConfig, nil): transport := NewTransport(cfg, logger) defClient := http.Client{ Transport: transport, } resp, err = defClient.Get(fmt.Sprintf("https://localhost:%d", port)) if err != nil { t.Error(err) return } if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } cancel() if err = <-done; err != nil { t.Error(err) } } func TestRunServer_MTLSOldConfigFormat(t *testing.T) { testKeysAreAvailable(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() port := 36517 done := make(chan error) cfg := config.ServiceConfig{ Port: port, TLS: &config.TLS{ PublicKey: "cert.pem", PrivateKey: "key.pem", CaCerts: []string{"ca.pem"}, EnableMTLS: true, }, ClientTLS: &config.ClientTLS{ AllowInsecureConnections: false, // we do not check the server cert CaCerts: []string{"ca.pem"}, ClientCerts: []config.ClientTLSCert{ { Certificate: "cert.pem", PrivateKey: 
"key.pem", }, }, }, } go func() { done <- RunServer(ctx, cfg, http.HandlerFunc(dummyHandler)) }() client, err := mtlsClient("cert.pem", "key.pem") if err != nil { t.Error(err) return } <-time.After(1000 * time.Millisecond) resp, err := client.Get(fmt.Sprintf("https://localhost:%d", port)) if err != nil { t.Error(err) return } if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } logger := logging.NoOp // since test are run in a suite, and `InitHTTPDefaultTransportWithLogger` is // used to setup the `http.DefaultTransport` global variable once, we need to // create a client here like if it was using the default created with the // clientTLS config. // This is a copy of the code we can find inside // InitHTTPDefaultTransportWithLogger(serviceConfig, nil): transport := NewTransport(cfg, logger) defClient := http.Client{ Transport: transport, } resp, err = defClient.Get(fmt.Sprintf("https://localhost:%d", port)) if err != nil { t.Error(err) return } if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } cancel() if err = <-done; err != nil { t.Error(err) } } func TestRunServer_plain(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() port := newPort() done := make(chan error) go func() { done <- RunServer( ctx, config.ServiceConfig{Port: port}, http.HandlerFunc(dummyHandler), ) }() <-time.After(100 * time.Millisecond) resp, err := http.Get(fmt.Sprintf("http://localhost:%d", port)) if err != nil { t.Error(err) return } if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } cancel() if err = <-done; err != nil { t.Error(err) } } func TestRunServer_h2c(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() port := newPort() done := make(chan error) go func() { done <- RunServer( ctx, config.ServiceConfig{ Port: port, UseH2C: true, }, http.HandlerFunc(dummyHandler), ) }() <-time.After(100 * 
time.Millisecond) client := h2cClient() resp, err := client.Get(fmt.Sprintf("http://localhost:%d", port)) if err != nil { t.Error(err) return } if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } cancel() if err = <-done; err != nil { t.Error(err) } } func TestRunServer_disabledTLS(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() done := make(chan error) port := newPort() go func() { done <- RunServer( ctx, config.ServiceConfig{ Port: port, TLS: &config.TLS{ IsDisabled: true, }}, http.HandlerFunc(dummyHandler), ) }() <-time.After(100 * time.Millisecond) resp, err := http.Get(fmt.Sprintf("http://localhost:%d", port)) if err != nil { t.Error(err) return } if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } cancel() if err = <-done; err != nil { t.Error(err) } } func TestRunServer_err(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() done := make(chan error) for _, tc := range []struct { cfg *config.TLS err error }{ { cfg: &config.TLS{}, err: ErrPublicKey, }, { cfg: &config.TLS{ PublicKey: "unknown", }, err: ErrPrivateKey, }, } { go func() { done <- RunServer( ctx, config.ServiceConfig{TLS: tc.cfg}, http.HandlerFunc(dummyHandler), ) }() if err := <-done; err != tc.err { t.Error(err) } } } func TestRunServer_errBadKeys(t *testing.T) { done := make(chan error) ctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { done <- RunServer( ctx, config.ServiceConfig{TLS: &config.TLS{ PublicKey: "unknown", PrivateKey: "unknown", }}, http.HandlerFunc(dummyHandler), ) }() if err := <-done; err == nil || err.Error() != "open unknown: no such file or directory" { t.Error(err) } } func Test_parseTLSVersion(t *testing.T) { for _, tc := range []struct { in string out uint16 }{ {in: "SSL3.0", out: tls.VersionSSL30}, {in: "TLS10", out: tls.VersionTLS10}, {in: "TLS11", out: tls.VersionTLS11}, {in: "TLS12", 
out: tls.VersionTLS12}, {in: "TLS13", out: tls.VersionTLS13}, {in: "Unknown", out: tls.VersionTLS13}, } { if res := parseTLSVersion(tc.in); res != tc.out { t.Errorf("input %s generated output %d. expected: %d", tc.in, res, tc.out) } } } func Test_parseCurveIDs(t *testing.T) { original := []uint16{1, 2, 3} cs := parseCurveIDs(original) for k, v := range cs { if original[k] != uint16(v) { t.Errorf("unexpected curves %v. expected: %v", cs, original) } } } func Test_parseCipherSuites(t *testing.T) { original := []uint16{1, 2, 3} cs := parseCipherSuites(original) for k, v := range cs { if original[k] != uint16(v) { t.Errorf("unexpected ciphersuites %v. expected: %v", cs, original) } } } func dummyHandler(rw http.ResponseWriter, req *http.Request) { fmt.Fprintf(rw, "Hello, %q", html.EscapeString(req.URL.Path)) } func testKeysAreAvailable(t *testing.T) { files, err := os.ReadDir(".") if err != nil { log.Fatal(err) } for _, k := range []string{"cert.pem", "key.pem"} { var exists bool for _, file := range files { if file.Name() == k { exists = true break } } if !exists { t.Errorf("file %s not present", k) } } } func httpsClient(cert string) (*http.Client, error) { cer, err := os.ReadFile(cert) if err != nil { return nil, err } roots := x509.NewCertPool() ok := roots.AppendCertsFromPEM(cer) if !ok { return nil, errors.New("failed to parse root certificate") } tlsConf := &tls.Config{ MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, PreferServerCipherSuites: true, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, }, RootCAs: roots, } return &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConf}}, nil } func mtlsClient(certPath, keyPath string) (*http.Client, error) { cert, err 
:= tls.LoadX509KeyPair(certPath, keyPath) if err != nil { return nil, err } cacer, err := os.ReadFile(certPath) if err != nil { return nil, err } roots := x509.NewCertPool() ok := roots.AppendCertsFromPEM(cacer) if !ok { return nil, errors.New("failed to parse root certificate") } tlsConf := &tls.Config{ MinVersion: tls.VersionTLS12, CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, PreferServerCipherSuites: true, CipherSuites: []uint16{ tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, }, RootCAs: roots, Certificates: []tls.Certificate{cert}, } return &http.Client{Transport: &http.Transport{TLSClientConfig: tlsConf}}, nil } // h2cClient initializes client which executes cleartext http2 requests func h2cClient() *http.Client { return &http.Client{ Transport: &http2.Transport{ DialTLSContext: func(_ context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { return net.Dial(network, addr) }, AllowHTTP: true, }, } } // newPort returns random port numbers to avoid port collisions during the tests func newPort() int { return 16666 + rand.Intn(40000) // skipcq: GSC-G404 } func TestRunServer_MultipleTLS(t *testing.T) { testKeysAreAvailable(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() port := newPort() done := make(chan error) go func() { done <- RunServer( ctx, config.ServiceConfig{ Port: port, TLS: &config.TLS{ CaCerts: []string{"ca.pem", "exampleca.pem"}, Keys: []config.TLSKeyPair{ { PublicKey: "cert.pem", PrivateKey: "key.pem", }, { PublicKey: "examplecert.pem", PrivateKey: "examplekey.pem", }, }, }, }, http.HandlerFunc(dummyHandler), ) }() client, err := httpsClient("cert.pem") if err != nil { t.Error(err) return } <-time.After(100 * time.Millisecond) resp, err := 
client.Get(fmt.Sprintf("https://localhost:%d", port)) if err != nil { t.Error(err) return } if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } client, err = httpsClient("examplecert.pem") if err != nil { t.Error(err) return } _, err = client.Get(fmt.Sprintf("https://127.0.0.1:%d", port)) // should fail, because it will be served with cert.pem if err == nil || strings.Contains(err.Error(), "bad certificate") { t.Error("expected to have 'bad certificate' error") return } req, _ := http.NewRequest("GET", fmt.Sprintf("https://example.com:%d", port), http.NoBody) overrideHostTransport(client) resp, err = client.Do(req) if err != nil { t.Error(err) return } if resp.StatusCode != 200 { t.Errorf("unexpected status code: %d", resp.StatusCode) return } cancel() if err = <-done; err != nil { t.Error(err) } } // overrideHostTransport subtitutes the actual address that the request will // connecto (overriding the dns resolution). func overrideHostTransport(client *http.Client) { t := http.DefaultTransport.(*http.Transport).Clone() if client.Transport != nil { if tt, ok := client.Transport.(*http.Transport); ok { t = tt } } myDialer := &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, DualStack: true, } t.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { _, port, err := net.SplitHostPort(address) if err != nil { return nil, err } overrideAddress := net.JoinHostPort("127.0.0.1", port) return myDialer.DialContext(ctx, network, overrideAddress) } client.Transport = t } ================================================ FILE: transport/http/server/tls_test.go ================================================ // SPDX-License-Identifier: Apache-2.0 package server import ( "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "fmt" "log" "math/big" "net" "os" "time" ) type certDef struct { Prefix string IPAddresses []string DNSNames []string } func (c certDef) Org() string 
{ if c.Prefix == "" { return "Acme Co" } return c.Prefix + " " + "Acme Co" } func init() { certs := []certDef{ certDef{ Prefix: "", IPAddresses: []string{"127.0.0.1", "::1"}, DNSNames: []string{"localhost"}, }, certDef{ Prefix: "example", IPAddresses: []string{"127.0.0.1"}, DNSNames: []string{"example.com"}, }, } for _, cd := range certs { if err := generateNamedCert(cd); err != nil { log.Fatal(err.Error()) } } } func generateNamedCert(hostCert certDef) error { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return fmt.Errorf("Failed to generate private key: %v", err) } keyUsage := x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment notBefore := time.Now() notAfter := notBefore.Add(1000000 * time.Hour) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return fmt.Errorf("Failed to generate serial number: %v", err) } template := x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{ Organization: []string{hostCert.Org()}, }, NotBefore: notBefore, NotAfter: notAfter, KeyUsage: keyUsage, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, } for _, strIP := range hostCert.IPAddresses { if ip := net.ParseIP(strIP); ip != nil { template.IPAddresses = append(template.IPAddresses, ip) } } template.DNSNames = append(template.DNSNames, hostCert.DNSNames...) 
template.IsCA = true template.KeyUsage |= x509.KeyUsageCertSign derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { return fmt.Errorf("Failed to create certificate: %v", err) } caBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { return fmt.Errorf("Failed to create ca: %v", err) } serverCert := hostCert.Prefix + "cert.pem" serverKey := hostCert.Prefix + "key.pem" caCert := hostCert.Prefix + "ca.pem" certOut, err := os.Create(serverCert) if err != nil { return fmt.Errorf("Failed to open %s for writing: %v", serverCert, err) } if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { return fmt.Errorf("Failed to write data to %s: %v", serverCert, err) } if err := certOut.Close(); err != nil { return fmt.Errorf("Error closing %s: %v", serverCert, err) } log.Printf("wrote %s\n", serverCert) keyOut, err := os.OpenFile(serverKey, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return fmt.Errorf("Failed to open %s for writing: %v", serverKey, err) } privBytes, err := x509.MarshalPKCS8PrivateKey(priv) if err != nil { return fmt.Errorf("Unable to marshal private key: %v", err) } if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { return fmt.Errorf("Failed to write data to %s: %v", serverKey, err) } if err := keyOut.Close(); err != nil { return fmt.Errorf("Error closing %s: %v", serverKey, err) } log.Printf("wrote %s\n", serverKey) caOut, err := os.Create(caCert) if err != nil { return fmt.Errorf("Failed to open %s for writing: %v", caCert, err) } if err := pem.Encode(caOut, &pem.Block{Type: "CERTIFICATE", Bytes: caBytes}); err != nil { return fmt.Errorf("Failed to write data to %s: %v", caCert, err) } if err := caOut.Close(); err != nil { return fmt.Errorf("Error closing %s: %v", caCert, err) } log.Printf("wrote %s\n", caCert) return nil }