Repository: cloudwego/hertz Branch: main Commit: 925126d9ae0d Files: 450 Total size: 2.5 MB Directory structure: gitextract_i3d5wkwv/ ├── .codecov.yml ├── .gitattributes ├── .github/ │ ├── CODEOWNERS │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── feature_request.md │ │ └── question.md │ ├── PULL_REQUEST_TEMPLATE.md │ ├── labels.json │ └── workflows/ │ ├── cmd-tests.yml │ ├── invalid_question.yml │ ├── labeler.yml │ ├── pr-check.yml │ ├── unit-tests.yml │ └── vulncheck.yml ├── .gitignore ├── .golangci.yaml ├── .licenserc.yaml ├── .typos.toml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── NOTICE ├── README.md ├── README_cn.md ├── ROADMAP.md ├── cmd/ │ └── hz/ │ ├── app/ │ │ └── app.go │ ├── config/ │ │ ├── argument.go │ │ └── cmd.go │ ├── doc.go │ ├── generator/ │ │ ├── client.go │ │ ├── custom_files.go │ │ ├── file.go │ │ ├── handler.go │ │ ├── layout.go │ │ ├── layout_tpl.go │ │ ├── model/ │ │ │ ├── define.go │ │ │ ├── expr.go │ │ │ ├── golang/ │ │ │ │ ├── constant.go │ │ │ │ ├── enum.go │ │ │ │ ├── file.go │ │ │ │ ├── function.go │ │ │ │ ├── init.go │ │ │ │ ├── oneof.go │ │ │ │ ├── struct.go │ │ │ │ ├── typedef.go │ │ │ │ └── variable.go │ │ │ └── model.go │ │ ├── model.go │ │ ├── model_test.go │ │ ├── package.go │ │ ├── package_tpl.go │ │ ├── router.go │ │ ├── router_test.go │ │ ├── template.go │ │ └── template_funcs.go │ ├── go.mod │ ├── go.sum │ ├── main.go │ ├── meta/ │ │ ├── const.go │ │ ├── manifest.go │ │ └── manifest_test.go │ ├── protobuf/ │ │ ├── api/ │ │ │ ├── api.pb.go │ │ │ └── api.proto │ │ ├── ast.go │ │ ├── plugin.go │ │ ├── plugin_stubs.go │ │ ├── plugin_test.go │ │ ├── resolver.go │ │ ├── tag_test.go │ │ ├── tags.go │ │ └── test_data/ │ │ ├── protobuf_tag_test.out │ │ └── test_tag.proto │ ├── test_hz_unix.sh │ ├── test_hz_windows.sh │ ├── testdata/ │ │ ├── protobuf2/ │ │ │ ├── api.proto │ │ │ ├── other/ │ │ │ │ ├── other.proto │ │ │ │ └── other_base.proto │ │ │ └── psm/ │ │ │ ├── base.proto │ │ │ └── psm.proto │ │ 
├── protobuf3/ │ │ │ ├── api.proto │ │ │ ├── other/ │ │ │ │ ├── other.proto │ │ │ │ └── other_base.proto │ │ │ └── psm/ │ │ │ ├── base.proto │ │ │ └── psm.proto │ │ └── thrift/ │ │ ├── common.thrift │ │ ├── data/ │ │ │ ├── basic_data.thrift │ │ │ └── data.thrift │ │ └── psm.thrift │ ├── thrift/ │ │ ├── ast.go │ │ ├── plugin.go │ │ ├── plugin_test.go │ │ ├── resolver.go │ │ ├── tag_test.go │ │ ├── tags.go │ │ ├── test_data/ │ │ │ ├── test_tag.thrift │ │ │ └── thrift_tag_test.out │ │ └── thriftgo_util.go │ └── util/ │ ├── ast.go │ ├── ast_test.go │ ├── data.go │ ├── data_test.go │ ├── env.go │ ├── fs.go │ ├── logs/ │ │ ├── api.go │ │ └── std.go │ ├── string.go │ ├── tool_install.go │ └── tool_install_test.go ├── examples/ │ ├── html_rendering/ │ │ ├── index.tmpl │ │ ├── main.go │ │ └── template.html │ └── standard/ │ └── main.go ├── go.mod ├── go.sum ├── internal/ │ ├── bytesconv/ │ │ ├── bytesconv.go │ │ ├── bytesconv_32.go │ │ ├── bytesconv_32_test.go │ │ ├── bytesconv_64.go │ │ ├── bytesconv_64_test.go │ │ ├── bytesconv_table.go │ │ ├── bytesconv_table_gen.go │ │ ├── bytesconv_test.go │ │ ├── bytesconv_timing_test.go │ │ ├── doc.go │ │ └── errors.go │ ├── bytestr/ │ │ └── bytes.go │ ├── nocopy/ │ │ └── nocopy.go │ ├── stats/ │ │ ├── stats_util.go │ │ ├── stats_util_test.go │ │ ├── tracer.go │ │ └── tracer_test.go │ ├── tagexpr/ │ │ ├── LICENSE │ │ ├── README.md │ │ ├── example_test.go │ │ ├── expr.go │ │ ├── expr_test.go │ │ ├── handler.go │ │ ├── selector.go │ │ ├── selector_test.go │ │ ├── spec_func.go │ │ ├── spec_func_test.go │ │ ├── spec_operand.go │ │ ├── spec_operator.go │ │ ├── spec_range.go │ │ ├── spec_range_test.go │ │ ├── spec_selector.go │ │ ├── spec_test.go │ │ ├── tagexpr.go │ │ ├── tagexpr_test.go │ │ ├── tagparser.go │ │ ├── tagparser_test.go │ │ ├── utils.go │ │ └── validator/ │ │ ├── README.md │ │ ├── default.go │ │ ├── example_test.go │ │ ├── func.go │ │ ├── validator.go │ │ └── validator_test.go │ ├── test/ │ │ └── mock/ │ │ └── binder/ │ │ 
├── binder.go │ │ └── binder_test.go │ └── testutils/ │ ├── testutils.go │ └── testutils_test.go ├── licenses/ │ ├── LICENSE-echo.txt │ ├── LICENSE-fasthttp.txt │ ├── LICENSE-fsnotify │ ├── LICENSE-gin.txt │ ├── LICENSE-go-version.txt │ ├── LICENSE-protobuf.txt │ ├── LICENSE-protoreflect.txt │ ├── LICENSE-sprig.txt │ └── LICENSE-yaml.txt ├── pkg/ │ ├── app/ │ │ ├── client/ │ │ │ ├── client.go │ │ │ ├── client_test.go │ │ │ ├── client_unix_test.go │ │ │ ├── client_windows_test.go │ │ │ ├── discovery/ │ │ │ │ ├── discovery.go │ │ │ │ └── discovery_test.go │ │ │ ├── loadbalance/ │ │ │ │ ├── lbcache.go │ │ │ │ ├── lbcache_test.go │ │ │ │ ├── loadbalance.go │ │ │ │ ├── weight_random.go │ │ │ │ └── weight_random_test.go │ │ │ ├── middleware.go │ │ │ ├── middleware_test.go │ │ │ ├── option.go │ │ │ ├── option_test.go │ │ │ └── retry/ │ │ │ ├── option.go │ │ │ ├── retry.go │ │ │ └── retry_test.go │ │ ├── context.go │ │ ├── context_test.go │ │ ├── fs.go │ │ ├── fs_test.go │ │ ├── middlewares/ │ │ │ ├── client/ │ │ │ │ └── sd/ │ │ │ │ ├── discovery.go │ │ │ │ ├── discovery_test.go │ │ │ │ ├── options.go │ │ │ │ └── options_test.go │ │ │ └── server/ │ │ │ ├── basic_auth/ │ │ │ │ ├── basic_auth.go │ │ │ │ ├── basic_auth_test.go │ │ │ │ └── doc.go │ │ │ └── recovery/ │ │ │ ├── option.go │ │ │ ├── option_test.go │ │ │ ├── recovery.go │ │ │ └── recovery_test.go │ │ └── server/ │ │ ├── binding/ │ │ │ ├── binder.go │ │ │ ├── binder_test.go │ │ │ ├── config.go │ │ │ ├── default.go │ │ │ ├── internal/ │ │ │ │ └── decoder/ │ │ │ │ ├── base_type_decoder.go │ │ │ │ ├── customized_type_decoder.go │ │ │ │ ├── decoder.go │ │ │ │ ├── getter.go │ │ │ │ ├── gjson_required.go │ │ │ │ ├── map_type_decoder.go │ │ │ │ ├── multipart_file_decoder.go │ │ │ │ ├── reflect.go │ │ │ │ ├── slice_getter.go │ │ │ │ ├── slice_type_decoder.go │ │ │ │ ├── sonic_required.go │ │ │ │ ├── struct_type_decoder.go │ │ │ │ ├── tag.go │ │ │ │ ├── text_decoder.go │ │ │ │ ├── util.go │ │ │ │ └── util_test.go │ │ │ ├── 
reflect.go │ │ │ ├── reflect_internal_test.go │ │ │ ├── reflect_test.go │ │ │ ├── tagexpr_bind_test.go │ │ │ ├── testdata/ │ │ │ │ ├── hello.pb.go │ │ │ │ └── hello.proto │ │ │ ├── validator.go │ │ │ └── validator_test.go │ │ ├── hertz.go │ │ ├── hertz_test.go │ │ ├── hertz_unix_test.go │ │ ├── mocks_test.go │ │ ├── option.go │ │ ├── option_test.go │ │ ├── registry/ │ │ │ ├── registry.go │ │ │ └── registry_test.go │ │ ├── render/ │ │ │ ├── data.go │ │ │ ├── doc.go │ │ │ ├── html.go │ │ │ ├── html_test.go │ │ │ ├── json.go │ │ │ ├── json_test.go │ │ │ ├── protobuf.go │ │ │ ├── render.go │ │ │ ├── render_test.go │ │ │ ├── text.go │ │ │ └── xml.go │ │ └── server_bench_test.go │ ├── common/ │ │ ├── adaptor/ │ │ │ ├── handler.go │ │ │ ├── handler_test.go │ │ │ ├── request.go │ │ │ ├── request_test.go │ │ │ ├── response.go │ │ │ ├── utils.go │ │ │ └── utils_test.go │ │ ├── bytebufferpool/ │ │ │ ├── bytebuffer.go │ │ │ ├── bytebuffer_test.go │ │ │ ├── doc.go │ │ │ ├── pool.go │ │ │ └── pool_test.go │ │ ├── compress/ │ │ │ ├── compress.go │ │ │ ├── compress_test.go │ │ │ └── doc.go │ │ ├── config/ │ │ │ ├── client_option.go │ │ │ ├── client_option_test.go │ │ │ ├── option.go │ │ │ ├── option_test.go │ │ │ ├── request_option.go │ │ │ └── request_option_test.go │ │ ├── errors/ │ │ │ ├── errors.go │ │ │ └── errors_test.go │ │ ├── hlog/ │ │ │ ├── consts.go │ │ │ ├── default.go │ │ │ ├── default_test.go │ │ │ ├── hlog.go │ │ │ ├── hlog_test.go │ │ │ ├── log.go │ │ │ ├── system.go │ │ │ └── system_test.go │ │ ├── json/ │ │ │ ├── sonic.go │ │ │ └── std.go │ │ ├── stackless/ │ │ │ ├── doc.go │ │ │ ├── func.go │ │ │ ├── func_test.go │ │ │ ├── func_timing_test.go │ │ │ ├── writer.go │ │ │ └── writer_test.go │ │ ├── test/ │ │ │ ├── assert/ │ │ │ │ ├── assert.go │ │ │ │ └── assert_test.go │ │ │ └── mock/ │ │ │ ├── body_data.go │ │ │ ├── body_data_test.go │ │ │ ├── network.go │ │ │ ├── network_test.go │ │ │ ├── reader.go │ │ │ ├── reader_test.go │ │ │ ├── writer.go │ │ │ └── 
writer_test.go │ │ ├── testdata/ │ │ │ ├── conf/ │ │ │ │ └── p_s_m.yaml │ │ │ ├── proto/ │ │ │ │ ├── test.pb.go │ │ │ │ └── test.proto │ │ │ ├── template/ │ │ │ │ ├── htmltemplate.html │ │ │ │ └── index.tmpl │ │ │ └── test.txt │ │ ├── timer/ │ │ │ ├── doc.go │ │ │ ├── timer.go │ │ │ └── timer_test.go │ │ ├── tracer/ │ │ │ ├── stats/ │ │ │ │ ├── event.go │ │ │ │ ├── event_test.go │ │ │ │ └── status.go │ │ │ ├── traceinfo/ │ │ │ │ ├── httpstats.go │ │ │ │ ├── interface.go │ │ │ │ └── traceinfo.go │ │ │ └── tracer.go │ │ ├── ut/ │ │ │ ├── context.go │ │ │ ├── context_test.go │ │ │ ├── request.go │ │ │ ├── request_test.go │ │ │ ├── response.go │ │ │ └── response_test.go │ │ └── utils/ │ │ ├── bufpool.go │ │ ├── chunk.go │ │ ├── chunk_test.go │ │ ├── env.go │ │ ├── ioutil.go │ │ ├── ioutil_test.go │ │ ├── netaddr.go │ │ ├── netaddr_test.go │ │ ├── network.go │ │ ├── network_test.go │ │ ├── path.go │ │ ├── path_test.go │ │ ├── utils.go │ │ └── utils_test.go │ ├── network/ │ │ ├── connection.go │ │ ├── dialer/ │ │ │ ├── dialer.go │ │ │ ├── dialer_test.go │ │ │ └── netpoll.go │ │ ├── dialer.go │ │ ├── netpoll/ │ │ │ ├── connection.go │ │ │ ├── connection_test.go │ │ │ ├── dial.go │ │ │ ├── dial_test.go │ │ │ ├── transport.go │ │ │ └── transport_test.go │ │ ├── standard/ │ │ │ ├── connection.go │ │ │ ├── connection_test.go │ │ │ ├── dial.go │ │ │ ├── dial_test.go │ │ │ ├── transport.go │ │ │ ├── transport_test.go │ │ │ └── unix_test.go │ │ ├── transport.go │ │ ├── utils.go │ │ ├── utils_test.go │ │ ├── version.go │ │ ├── writer.go │ │ └── writer_test.go │ ├── protocol/ │ │ ├── args.go │ │ ├── args_test.go │ │ ├── client/ │ │ │ ├── client.go │ │ │ └── client_test.go │ │ ├── consts/ │ │ │ ├── default.go │ │ │ ├── fs.go │ │ │ ├── headers.go │ │ │ ├── http2.go │ │ │ ├── methods.go │ │ │ └── status.go │ │ ├── cookie.go │ │ ├── cookie_test.go │ │ ├── doc.go │ │ ├── header.go │ │ ├── header_test.go │ │ ├── header_timing_test.go │ │ ├── http1/ │ │ │ ├── client.go │ │ │ ├── 
client_test.go │ │ │ ├── client_unix_test.go │ │ │ ├── ext/ │ │ │ │ ├── common.go │ │ │ │ ├── common_test.go │ │ │ │ ├── error.go │ │ │ │ ├── headerscanner.go │ │ │ │ ├── headerscanner_test.go │ │ │ │ ├── stream.go │ │ │ │ └── stream_test.go │ │ │ ├── factory/ │ │ │ │ ├── client.go │ │ │ │ └── server.go │ │ │ ├── proxy/ │ │ │ │ └── proxy.go │ │ │ ├── req/ │ │ │ │ ├── header.go │ │ │ │ ├── header_test.go │ │ │ │ ├── request.go │ │ │ │ └── request_test.go │ │ │ ├── resp/ │ │ │ │ ├── header.go │ │ │ │ ├── header_test.go │ │ │ │ ├── response.go │ │ │ │ ├── response_test.go │ │ │ │ ├── writer.go │ │ │ │ └── writer_test.go │ │ │ ├── server.go │ │ │ ├── server_test.go │ │ │ └── server_timing_test.go │ │ ├── multipart.go │ │ ├── multipart_test.go │ │ ├── request.go │ │ ├── request_test.go │ │ ├── response.go │ │ ├── response_test.go │ │ ├── server.go │ │ ├── sse/ │ │ │ ├── event.go │ │ │ ├── event_test.go │ │ │ ├── example_test.go │ │ │ ├── reader.go │ │ │ ├── reader_test.go │ │ │ ├── request.go │ │ │ ├── request_test.go │ │ │ ├── utils.go │ │ │ ├── utils_test.go │ │ │ ├── writer.go │ │ │ └── writer_test.go │ │ ├── suite/ │ │ │ ├── client.go │ │ │ └── server.go │ │ ├── trailer.go │ │ ├── trailer_test.go │ │ ├── uri.go │ │ ├── uri_test.go │ │ ├── uri_timing_test.go │ │ ├── uri_unix.go │ │ ├── uri_unix_test.go │ │ ├── uri_windows.go │ │ └── uri_windows_test.go │ └── route/ │ ├── consts/ │ │ └── const.go │ ├── engine.go │ ├── engine_test.go │ ├── netpoll.go │ ├── param/ │ │ └── param.go │ ├── routergroup.go │ ├── routergroup_test.go │ ├── routes_test.go │ ├── routes_timing_test.go │ ├── tree.go │ └── tree_test.go ├── scripts/ │ ├── .utils/ │ │ ├── check_go_mod.sh │ │ ├── check_version.sh │ │ └── funcs.sh │ ├── release-hotfix.sh │ └── release.sh └── version.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .codecov.yml 
================================================ github_checks: annotations: false ignore: - "images/.*" - "examples/.*" - "cmd/.*" coverage: status: project: default: # pass if coverage drops by no more than 0.05% # this is possibly caused by unstable coverage. threshold: 0.05% patch: default: target: 80% ================================================ FILE: .gitattributes ================================================ # Handle line endings automatically for files detected as text # and leave all files detected as binary untouched. * text=auto eol=lf # Force bash scripts to always use LF line endings so that if a repo is accessed # in Unix via a file share from Windows, the scripts will work. *.sh text eol=lf ================================================ FILE: .github/CODEOWNERS ================================================ # For more information, please refer to https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners * @cloudwego/hertz-reviewers @cloudwego/hertz-approvers @cloudwego/hertz-maintainers ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: '' assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: 1. Go to '...' 2. Click on '....' 3. Scroll down to '....' 4. See error **Expected behavior** A clear and concise description of what you expected to happen. **Screenshots** If applicable, add screenshots to help explain your problem. **Hertz version:** Please provide the version of Hertz you are using. **Environment:** The output of `go env`. **Additional context** Add any other context about the problem here. 
================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: '' assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. ================================================ FILE: .github/ISSUE_TEMPLATE/question.md ================================================ --- name: Question about: Ask a question, so we can help you easily title: '' labels: '' assignees: '' --- **Describe the Question** A clear and concise description of what the question is. **Reproducible Code** Please construct a minimal, complete, and reproducible example so that we can get the same error. Also tell us how to reproduce it, for example how you send the request and what request you send. **Expected behavior** A clear and concise description of what you expected to happen. **Screenshots** If applicable, add screenshots to help explain your question. **Hertz version:** Please provide the version of Hertz you are using. **Environment:** The output of `go env`. **Additional context** Add any other context about the question here. ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ #### What type of PR is this? #### Check the PR title. - [ ] This PR title matches the format: \<type\>(optional scope): \<description\> - [ ] The description of this PR title is user-oriented and clear enough for others to understand.
- [ ] Attach the PR updating the user documentation if the current PR requires user awareness at the usage level. [User docs repo](https://github.com/cloudwego/cloudwego.github.io) #### (Optional) Translate the PR title into Chinese. #### (Optional) More detailed description for this PR(en: English/zh: Chinese). en: zh(optional): #### (Optional) Which issue(s) this PR fixes: #### (Optional) The PR that updates user documentation: ================================================ FILE: .github/labels.json ================================================ { "labels": { "invalid_issue": { "name": "invalid issue", "colour": "#CF2E1F", "description": "invalid issue (not related to Hertz or described in document or not enough information provided)" } }, "issue": { "invalid_issue": { "requires": 3, "conditions": [ { "type": "descriptionMatches", "pattern": "/^((?!Describe the bug).)*$/is" }, { "type": "descriptionMatches", "pattern": "/^((?!Is your feature request related to a problem\\? Please describe).)*$/is" }, { "type": "descriptionMatches", "pattern": "/^((?!Describe the Question).)*$/is" } ] } } } ================================================ FILE: .github/workflows/cmd-tests.yml ================================================ name: Cmd Tests on: push: paths: - 'cmd/**' pull_request: paths: - 'cmd/**' jobs: hz-test-unix: runs-on: [ self-hosted, Linux, X64 ] steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: stable cache: false # don't use cache for self-hosted runners - name: Setup Environment run: | echo "GOPATH=$(go env GOPATH)" >> $GITHUB_ENV echo "$(go env GOPATH)/bin" >> $GITHUB_PATH - name: Hz Test run: | cd cmd/hz sh test_hz_unix.sh hz-test-windows: runs-on: windows-latest steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: stable - name: Install Protobuf shell: pwsh run: | Invoke-WebRequest 
https://github.com/protocolbuffers/protobuf/releases/download/v3.19.4/protoc-3.19.4-win64.zip -OutFile protoc-3.19.4-win64.zip Expand-Archive protoc-3.19.4-win64.zip -DestinationPath protoc-3.19.4-win64 $GOPATH=go env GOPATH Copy-Item -Path protoc-3.19.4-win64\bin\protoc.exe -Destination $GOPATH\bin Copy-Item -Path protoc-3.19.4-win64\include\* -Destination cmd\hz\testdata\include\google -Recurse protoc --version - name: Hz Test run: | cd cmd/hz sh test_hz_windows.sh ================================================ FILE: .github/workflows/invalid_question.yml ================================================ name: "Close Invalid Issue" on: schedule: - cron: "0 0,8,16 * * *" permissions: contents: read jobs: stale: permissions: issues: write runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues uses: actions/stale@v6 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been marked as invalid question, please give more information by following the `issue` template. The issue will be closed in 1 days if no further activity occurs." 
stale-issue-label: "stale" days-before-stale: 0 days-before-close: 1 remove-stale-when-updated: true only-labels: "invalid issue" ================================================ FILE: .github/workflows/labeler.yml ================================================ name: "Labeler" on: issues: types: [ opened, edited, reopened ] jobs: triage: if: contains(github.event.issue.labels.*.name, 'invalid issue') || join(github.event.issue.labels) == '' runs-on: ubuntu-latest name: Label issues steps: - name: check out uses: actions/checkout@v3 - name: labeler uses: jbinda/super-labeler-action@develop with: GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" ================================================ FILE: .github/workflows/pr-check.yml ================================================ name: Pull Request Check on: [ pull_request ] jobs: compliant: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Check License Header uses: apache/skywalking-eyes/header@v0.4.0 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Check Spell uses: crate-ci/typos@master lint: runs-on: [ self-hosted, Linux, X64 ] steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: stable cache: false # don't use cache for self-hosted runners - name: Golangci Lint # https://golangci-lint.run/ uses: golangci/golangci-lint-action@v8 with: version: latest only-new-issues: true ================================================ FILE: .github/workflows/unit-tests.yml ================================================ name: Unit Tests on: [push, pull_request] jobs: unit-test-x64: strategy: matrix: version: ["1.19", "1.20", oldstable, stable] runs-on: [self-hosted, Linux, X64] steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: ${{ matrix.version }} cache: false # don't use cache for self-hosted runners - name: Unit Test run: go test -race ./... 
unit-test-arm64: strategy: matrix: version: ["1.19", "1.20", oldstable, stable] runs-on: [self-hosted, ARM64] steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: ${{ matrix.version }} cache: false # don't use cache for self-hosted runners - name: Unit Test run: go test -race ./... unit-test-no-netpoll: runs-on: [self-hosted, Linux] steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: stable cache: false # don't use cache for self-hosted runners - name: Unit Test run: HERTZ_NO_NETPOLL=1 go test -race ./... ut-windows: runs-on: [self-hosted, Windows] steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: stable cache: false # don't use cache for self-hosted runners - name: Unit Test run: go test -race ./... code-cov: runs-on: [self-hosted, Linux, X64] steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: stable cache: false # don't use cache for self-hosted runners - name: Unit Test run: go test -coverpkg=./... -coverprofile=coverage.txt ./... - uses: codecov/codecov-action@v5 with: fail_ci_if_error: true ================================================ FILE: .github/workflows/vulncheck.yml ================================================ name: Run govulncheck on: push: branches: - develop paths: - "**" - "!**.md" pull_request: paths: - "**" - "!**.md" jobs: vulncheck-check: runs-on: ubuntu-latest env: GO111MODULE: on steps: - name: Fetch Repository uses: actions/checkout@v4 - name: Install Go uses: actions/setup-go@v5 with: go-version: stable check-latest: true cache: false - name: Install Govulncheck run: go install golang.org/x/vuln/cmd/govulncheck@latest - name: Run Govulncheck run: govulncheck ./... 
================================================ FILE: .gitignore ================================================ .idea .vscode pkg/app/fs.go.hertz.gz coverage.txt coverage.out # OSX trash .DS_Store # test benchmark tmp output cpu.out mem.out *.test ================================================ FILE: .golangci.yaml ================================================ version: "2" linters: default: none enable: - govet - ineffassign - staticcheck - unconvert - unused settings: staticcheck: checks: - all - -SA5008 exclusions: generated: lax presets: - comments - common-false-positives - legacy - std-error-handling paths: - third_party$ - builtin$ - examples$ formatters: enable: - goimports exclusions: generated: lax paths: - third_party$ - builtin$ - examples$ ================================================ FILE: .licenserc.yaml ================================================ header: license: spdx-id: Apache-2.0 copyright-owner: CloudWeGo Authors paths: - '**/*.go' - '**/*.s' paths-ignore: - pkg/common/testdata/** - cmd/hz/protobuf/api comment: on-failure ================================================ FILE: .typos.toml ================================================ # Typo check: https://github.com/crate-ci/typos [files] extend-exclude = ["go.mod", "go.sum"] [default.extend-identifiers] pn = "pn" # packageName in cmd/hz ConnTLSer = "ConnTLSer" OPTIO = "OPTIO" # b[:5] for OPTION [default.extend-words] typ = "typ" # type Flate = "Flate" # flate algorithm [default] # Only ignore these exact "weird-case" examples used to demonstrate case-insensitive matching. 
extend-ignore-re = [ # Comments showing header normalization examples: "\\bcONTENT-lenGTH\\b", # Test intentionally using odd case: "\\bconnecTION\\b", ] ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. ## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any 
behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. ## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at conduct@cloudwego.io. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. 
This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. ================================================ FILE: CONTRIBUTING.md ================================================ # How to Contribute ## Your First Pull Request We use github for our codebase. You can start by reading [How To Pull Request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests). ## Without Semantic Versioning We keep the stable code in branch `main` like `golang.org/x`. 
Development is based on branch `develop`. We promise **forward compatibility** by adding a new package directory with a `v2`/`v3` suffix when code has breaking changes. ## Branch Organization We use [git-flow](https://nvie.com/posts/a-successful-git-branching-model/) as our branch organization, also known as [FDD](https://en.wikipedia.org/wiki/Feature-driven_development) ## Bugs ### 1. How to Find Known Issues We are using [GitHub Issues](https://github.com/cloudwego/hertz/issues) for our public bugs. We keep a close eye on this and try to make it clear when we have an internal fix in progress. Before filing a new task, try to make sure your problem doesn’t already exist. ### 2. Reporting New Issues Providing reduced test code is the recommended way to report issues. It can be placed: - Directly in the issue - In the [Golang Playground](https://play.golang.org/) ### 3. Security Bugs Please do not disclose security bugs in public issues. Contact us by [Support Email](mailto:conduct@cloudwego.io) ## How to Get in Touch - [Email](mailto:conduct@cloudwego.io) ## Submit a Pull Request Before you submit your Pull Request (PR) consider the following guidelines: 1. Search [GitHub](https://github.com/cloudwego/hertz/pulls) for an open or closed PR that relates to your submission. You don't want to duplicate existing efforts. 2. Please submit an issue instead of a PR if you have a better suggestion for formatting tools. We won't accept large numbers of file changes directly without an issue statement and assignment. 3. Be sure that the issue describes the problem you're fixing, or documents the design for the feature you'd like to add. Before accepting your work, we need to conduct some checks and evaluations, so it is better to discuss the design with us first. 4. [Fork](https://docs.github.com/en/github/getting-started-with-github/fork-a-repo) the cloudwego/hertz repo. 5.
In your forked repository, make your changes in a new git branch: ``` git checkout -b my-fix-branch develop ``` 6. Create your patch, including appropriate test cases. 7. Follow our [Style Guides](#code-style-guides). 8. Commit your changes using a descriptive commit message that follows [AngularJS Git Commit Message Conventions](https://docs.google.com/document/d/1QrDFcIiPjSLDn3EL15IJygNPiHORgU1_OOAqWjiDU5Y/edit). Adherence to these conventions is necessary because release notes are automatically generated from these messages. 9. Push your branch to GitHub: ``` git push origin my-fix-branch ``` 10. In GitHub, send a pull request to `hertz:develop` with a clear and unambiguous title. ## Contribution Prerequisites - Our development environment keeps up with [Go Official](https://golang.org/project/). - Check your code thoroughly with lint tools before submitting your pull request: [gofmt](https://golang.org/pkg/cmd/gofmt/) and [golangci-lint](https://github.com/golangci/golangci-lint) - You are familiar with [GitHub](https://github.com) - You may need to be familiar with [Actions](https://github.com/features/actions) (our default workflow tool). ## Code Style Guides See [Go Code Review Comments](https://github.com/golang/go/wiki/CodeReviewComments). Good resources: - [Effective Go](https://golang.org/doc/effective_go) - [Pingcap General advice](https://pingcap.github.io/style-guide/general.html) - [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md) ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: Makefile ================================================ SHELL := /bin/bash .PHONY: \ help \ coverage \ vet \ lint \ fmt \ version all: imports fmt lint vet errors build help: @echo 'Usage: make ... ' @echo '' @echo 'Available targets are:' @echo '' @echo ' help Show this help screen.' @echo ' coverage Report code tests coverage.' @echo ' vet Run go vet.' @echo ' lint Run golint.' @echo ' fmt Run go fmt.' @echo ' version Display Go version.' @echo '' @echo 'Targets run by default are: lint, vet.' @echo '' print-%: @echo $* = $($*) deps: go get golang.org/x/lint/golint coverage: go test $(go list ./... | grep -v examples) -coverprofile coverage.txt ./... vet: go vet ./... lint: deps golint ./... fmt: go install mvdan.cc/gofumpt@latest gofumpt -l -w -extra . 
pre-dev: make pre-commit pre-commit: bash script/pre-commit-hook release: package-release sign-release version: @go version ================================================ FILE: NOTICE ================================================ CloudWeGo Copyright 2022 CloudWeGo Authors Apache Thrift Copyright (C) 2006 - 2019, The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). ================================================ FILE: README.md ================================================ # Hertz English | [中文](README_cn.md) [![Release](https://img.shields.io/github/v/release/cloudwego/hertz)](https://github.com/cloudwego/hertz/releases) [![WebSite](https://img.shields.io/website?up_message=cloudwego&url=https%3A%2F%2Fwww.cloudwego.io%2F)](https://www.cloudwego.io/) [![License](https://img.shields.io/github/license/cloudwego/hertz)](https://github.com/cloudwego/hertz/blob/main/LICENSE) [![Go Report Card](https://goreportcard.com/badge/github.com/cloudwego/hertz)](https://goreportcard.com/report/github.com/cloudwego/hertz) [![OpenIssue](https://img.shields.io/github/issues/cloudwego/hertz)](https://github.com/cloudwego/hertz/issues) [![ClosedIssue](https://img.shields.io/github/issues-closed/cloudwego/hertz)](https://github.com/cloudwego/hertz/issues?q=is%3Aissue+is%3Aclosed) ![Stars](https://img.shields.io/github/stars/cloudwego/hertz) ![Forks](https://img.shields.io/github/forks/cloudwego/hertz) Hertz [həːts] is a high-usability, high-performance and high-extensibility Golang HTTP framework that helps developers build microservices. It was originally a fork of [fasthttp](https://github.com/valyala/fasthttp) and inspired by [gin](https://github.com/gin-gonic/gin), [echo](https://github.com/labstack/echo) and combined with the internal requirements in ByteDance. At present, it has been widely used inside ByteDance. Nowadays, more and more microservices use Golang. 
If you have requirements for microservice performance and want a framework that can fully meet customized internal requirements, Hertz will be a good choice. ## Basic Features - High usability During development, writing correct code quickly is often what matters most. Therefore, throughout Hertz's iterations, we have actively listened to users' feedback and continued to polish the framework, aiming to provide a better user experience and help users write correct code faster. - High performance Hertz uses the self-developed high-performance network library Netpoll by default. In some special scenarios, Hertz has certain advantages in QPS and latency compared to Go net. For performance data, please refer to the Echo data in the figures below. Comparison of four frameworks: ![Performance](images/performance-4.png) Comparison of three frameworks: ![Performance](images/performance-3.png) For detailed performance data, please refer to [hertz-benchmark](https://github.com/cloudwego/hertz-benchmark). - High extensibility Hertz adopts a layered design, providing many interfaces and default extension implementations, and users can also write their own extensions. Thanks to this layered design, the framework is highly extensible. At present, only stable capabilities are open-sourced to the community. For future plans, refer to the [RoadMap](ROADMAP.md). - Multi-protocol support The Hertz framework natively supports HTTP/1.1 and ALPN. In addition, thanks to the layered design, Hertz even supports custom protocol parsing logic to meet any protocol-layer extension needs. - Network layer switching capability Hertz can switch between Netpoll and Go net on demand, so users can choose the appropriate network library for different scenarios. Hertz also supports extending the network library in the form of plug-ins.
## Documentation ### [Getting Started](https://www.cloudwego.io/docs/hertz/getting-started/) ### Example The Hertz-Examples repository provides code out of the box. [more](https://www.cloudwego.io/zh/docs/hertz/tutorials/example/) ### Basic Features Contains introduction and use of general middleware, context selection, data binding, data rendering, direct access, logging, error handling. [more](https://www.cloudwego.io/zh/docs/hertz/tutorials/basic-feature/) ### Observability Contains instrumentation, logging, tracing. [more](https://www.cloudwego.io/docs/hertz/tutorials/observability/) ### Framework Extension Contains network library extensions, protocol extensions, logger extensions, monitoring extensions. [more](https://www.cloudwego.io/zh/docs/hertz/tutorials/framework-exten/) ### Reference Framework configurable items list. [more](https://www.cloudwego.io/zh/docs/hertz/reference/) ### FAQ Frequently Asked Questions. [more](https://www.cloudwego.io/zh/docs/hertz/faq/) ## Performance Performance testing can only provide a relative reference. In production, there are many factors that can affect actual performance. We provide the [hertz-benchmark](https://github.com/cloudwego/hertz-benchmark) project to track and compare the performance of Hertz and other frameworks in different situations for reference. ## Related Projects - [Netpoll](https://github.com/cloudwego/netpoll): A high-performance network library. Hertz integrated by default. - [Example](https://github.com/cloudwego/hertz-examples): Use examples of Hertz. ## Extensions [Hertz-contrib](https://github.com/hertz-contrib) is a partial extension library of Hertz, which users can integrate into Hertz through options according to their needs, built and maintained by the community. 
| Extensions | Description | |----------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | [Autotls](https://github.com/hertz-contrib/autotls) | Adds Let's Encrypt support to Hertz. | | [Http2](https://github.com/hertz-contrib/http2) | HTTP/2 support for Hertz. | | [Websocket](https://github.com/hertz-contrib/websocket) | Enables Hertz to support the WebSocket protocol. | | [Etag](https://github.com/hertz-contrib/etag) | Supports the ETag (entity tag) HTTP response header for Hertz. | | [Limiter](https://github.com/hertz-contrib/limiter) | Provides a rate limiter based on the BBR algorithm. | | [Monitor-prometheus](https://github.com/hertz-contrib/monitor-prometheus) | Provides service monitoring based on Prometheus. | | [Obs-opentelemetry](https://github.com/hertz-contrib/obs-opentelemetry) | Hertz's OpenTelemetry extension that supports metrics, logging and tracing, and works out of the box. | | [Opensergo](https://github.com/hertz-contrib/opensergo) | The OpenSergo extension. | | [Pprof](https://github.com/hertz-contrib/pprof) | Extension for integrating pprof with Hertz. | | [Registry](https://github.com/hertz-contrib/registry) | Provides service registry and discovery functions. So far, the supported service discovery extensions are nacos, consul, etcd, eureka, polaris, servicecomb, zookeeper and redis. | | [Sentry](https://github.com/hertz-contrib/hertzsentry) | The Sentry extension provides unified interfaces to help users perform real-time error monitoring. | | [Tracer](https://github.com/hertz-contrib/tracer) | Distributed tracing based on OpenTracing. | | [Basicauth](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/basic_auth) | Basicauth middleware provides HTTP basic authentication.
| | [Jwt](https://github.com/hertz-contrib/jwt) | JWT extension. | | [Keyauth](https://github.com/hertz-contrib/keyauth) | Provides token-based authentication. | | [Requestid](https://github.com/hertz-contrib/requestid) | Adds a request ID to the response. | | [Sessions](https://github.com/hertz-contrib/sessions) | Session middleware with multi-state store support. | | [Casbin](https://github.com/hertz-contrib/casbin) | Supports various access control models via Casbin. | | [Cors](https://github.com/hertz-contrib/cors) | Provides cross-origin resource sharing (CORS) support. | | [Csrf](https://github.com/hertz-contrib/csrf) | CSRF middleware is used to prevent cross-site request forgery attacks. | | [Secure](https://github.com/hertz-contrib/secure) | Secure middleware with multiple configuration items. | | [Gzip](https://github.com/hertz-contrib/gzip) | A Gzip extension with multiple options. | | [I18n](https://github.com/hertz-contrib/i18n) | Helps translate Hertz programs into multiple languages. | | [Lark](https://github.com/hertz-contrib/lark-hertz) | Uses Hertz to handle Lark/Feishu card messages and event callbacks. | | [Loadbalance](https://github.com/hertz-contrib/loadbalance) | Provides load balancing algorithms for Hertz. | | [Logger](https://github.com/hertz-contrib/logger) | Logger extension for Hertz, which provides support for the zap, logrus and zerolog logging frameworks. | | [Recovery](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/recovery) | Recovery middleware for Hertz. | | [Reverseproxy](https://github.com/hertz-contrib/reverseproxy) | Implements a reverse proxy. | | [Swagger](https://github.com/hertz-contrib/swagger) | Automatically generates RESTful API documentation with Swagger 2.0.
| | [Cache](https://github.com/hertz-contrib/cache) | Hertz middleware for cache HTTP response with multi-backend support | ## Blogs - [ByteDance Practice on Go Network Library](https://www.cloudwego.io/blog/2020/05/24/bytedance-practices-on-go-network-library/) - [Ultra-large-scale Enterprise-level Microservice HTTP Framework — Hertz is Officially Open Source!](https://www.cloudwego.io/zh/blog/2022/06/21/%E8%B6%85%E5%A4%A7%E8%A7%84%E6%A8%A1%E7%9A%84%E4%BC%81%E4%B8%9A%E7%BA%A7%E5%BE%AE%E6%9C%8D%E5%8A%A1-http-%E6%A1%86%E6%9E%B6-hertz-%E6%AD%A3%E5%BC%8F%E5%BC%80%E6%BA%90/) - [ByteDance Open Source Go HTTP Framework Hertz Design Practice](https://www.cloudwego.io/zh/blog/2022/06/21/%E5%AD%97%E8%8A%82%E8%B7%B3%E5%8A%A8%E5%BC%80%E6%BA%90-go-http-%E6%A1%86%E6%9E%B6-hertz-%E8%AE%BE%E8%AE%A1%E5%AE%9E%E8%B7%B5/) - [Help ByteDance Reduce Costs and Increase Efficiency, the Design Practice for Large-scale Enterprise-level HTTP Framework Hertz](https://www.cloudwego.io/zh/blog/2022/09/27/%E5%8A%A9%E5%8A%9B%E5%AD%97%E8%8A%82%E9%99%8D%E6%9C%AC%E5%A2%9E%E6%95%88%E5%A4%A7%E8%A7%84%E6%A8%A1%E4%BC%81%E4%B8%9A%E7%BA%A7-http-%E6%A1%86%E6%9E%B6-hertz-%E8%AE%BE%E8%AE%A1%E5%AE%9E%E8%B7%B5/) - [Getting Started with Hertz: Performance Testing Guide](https://www.cloudwego.io/blog/2023/02/24/getting-started-with-hertz-performance-testing-guide/) ## Contributing [Contributing](https://github.com/cloudwego/hertz/blob/main/CONTRIBUTING.md) ## RoadMap [Hertz RoadMap](ROADMAP.md) ## License Hertz is distributed under the [Apache License, version 2.0](https://github.com/cloudwego/hertz/blob/main/LICENSE). The licenses of third party dependencies of Hertz are explained [here](https://github.com/cloudwego/hertz/blob/main/licenses). 
## Community - Email: [conduct@cloudwego.io](mailto:conduct@cloudwego.io) - How to become a member: [COMMUNITY MEMBERSHIP](https://github.com/cloudwego/community/blob/main/COMMUNITY_MEMBERSHIP.md) - Issues: [Issues](https://github.com/cloudwego/hertz/issues) - Slack: Join our CloudWeGo community [Slack Channel](https://join.slack.com/t/cloudwego/shared_invite/zt-tmcbzewn-UjXMF3ZQsPhl7W3tEDZboA). - Lark: Scan the QR code below with [Lark](https://www.larksuite.com/zh_cn/download) to join our CloudWeGo/hertz user group. ![LarkGroup](images/lark_group.png) ## Contributors Thank you for your contribution to Hertz! [![Contributors](https://contrib.rocks/image?repo=cloudwego/hertz)](https://github.com/cloudwego/hertz/graphs/contributors) ## Landscapes

  

CloudWeGo enriches the CNCF CLOUD NATIVE Landscape.

================================================ FILE: README_cn.md ================================================ # Hertz [English](README.md) | 中文 [![Release](https://img.shields.io/github/v/release/cloudwego/hertz)](https://github.com/cloudwego/hertz/releases) [![WebSite](https://img.shields.io/website?up_message=cloudwego&url=https%3A%2F%2Fwww.cloudwego.io%2F)](https://www.cloudwego.io/) [![License](https://img.shields.io/github/license/cloudwego/hertz)](https://github.com/cloudwego/hertz/blob/main/LICENSE) [![Go Report Card](https://goreportcard.com/badge/github.com/cloudwego/hertz)](https://goreportcard.com/report/github.com/cloudwego/hertz) [![OpenIssue](https://img.shields.io/github/issues/cloudwego/hertz)](https://github.com/cloudwego/hertz/issues) [![ClosedIssue](https://img.shields.io/github/issues-closed/cloudwego/hertz)](https://github.com/cloudwego/hertz/issues?q=is%3Aissue+is%3Aclosed) ![Stars](https://img.shields.io/github/stars/cloudwego/hertz) ![Forks](https://img.shields.io/github/forks/cloudwego/hertz) Hertz[həːts] 是一个 Golang 微服务 HTTP 框架,最初fork自[fasthttp](https://github.com/valyala/fasthttp),并在设计时参考了其他开源框架[gin](https://github.com/gin-gonic/gin)、[echo](https://github.com/labstack/echo) 的优势,并结合字节跳动内部的需求,使其具有高易用性、高 性能、高扩展性等特点,目前在字节跳动内部已广泛使用。如今越来越多的微服务选择使用 Golang,如果对微服务性能有要求,又希望框架能够充分满足内部的可定制化需求,Hertz 会是一个不错的选择。 ## 框架特点 - 高易用性 在开发过程中,快速写出来正确的代码往往是更重要的。因此,在 Hertz 在迭代过程中,积极听取用户意见,持续打磨框架,希望为用户提供一个更好的使用体验,帮助用户更快的写出正确的代码。 - 高性能 Hertz 默认使用自研的高性能网络库 Netpoll,在一些特殊场景相较于 go net,Hertz 在 QPS、时延上均具有一定优势。关于性能数据,可参考下图 Echo 数据。 四个框架的对比: ![Performance](images/performance-4.png) 三个框架的对比: ![Performance](images/performance-3.png) 关于详细的性能数据,可参考 [hertz-benchmark](https://github.com/cloudwego/hertz-benchmark)。 - 高扩展性 Hertz 采用了分层设计,提供了较多的接口以及默认的扩展实现,用户也可以自行扩展。同时得益于框架的分层设计,框架的扩展性也会大很多。目前仅将稳定的能力开源给社区,更多的规划参考 [RoadMap](ROADMAP.md)。 - 多协议支持 Hertz 框架原生提供 HTTP/1.1 及 ALPN 协议支持。除此之外,由于分层设计,Hertz 甚至支持自定义构建协议解析逻辑,以满足协议层扩展的任意需求。 - 网络层切换能力 Hertz 实现了 Netpoll 和 Golang 原生网络库 
间按需切换能力,用户可以针对不同的场景选择合适的网络库,同时也支持以插件的方式为 Hertz 扩展网络库实现。 ## 详细文档 ### [快速开始](https://www.cloudwego.io/zh/docs/hertz/getting-started/) ### Example Hertz-Examples 仓库提供了开箱即用的代码,[详见](https://www.cloudwego.io/zh/docs/hertz/tutorials/example/)。 ### 用户指南 ### 基本特性 包含通用中间件的介绍和使用,上下文选择,数据绑定,数据渲染,直连访问,日志,错误处理,[详见文档](https://www.cloudwego.io/zh/docs/hertz/tutorials/basic-feature/) ### 可观测性 包含日志,链路追踪,埋点,[详见文档](https://www.cloudwego.io/zh/docs/hertz/tutorials/observability/) ### 框架扩展 包含网络库扩展,协议扩展,日志扩展,监控扩展,服务注册与发现扩展,[详见文档](https://www.cloudwego.io/zh/docs/hertz/tutorials/framework-exten/) ### 参考 框架可配置项一览,[详见文档](https://www.cloudwego.io/zh/docs/hertz/reference/) ### FAQ 常见问题排查,[详见文档](https://www.cloudwego.io/zh/docs/hertz/faq/) ## 框架性能 性能测试只能提供相对参考,工业场景下,有诸多因素可以影响实际的性能表现 我们提供了 [hertz-benchmark](https://github.com/cloudwego/hertz-benchmark) 项目用来长期追踪和比较 Hertz 与其他框架在不同情况下的性能数据以供参考 ## 相关项目 - [Netpoll](https://github.com/cloudwego/netpoll): 自研高性能网络库,Hertz 默认集成 - [Example](https://github.com/cloudwego/hertz-examples): Hertz 使用例子 ## 相关拓展 [hertz-Contrib](https://github.com/hertz-contrib) 是 Hertz 扩展生态所在组织,提供服务注册发现、可观测、安全、流量治理、协议、HTTP 通用能力等扩展,由社区共建与维护 | 拓展 | 描述 | |----------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------| | [Autotls](https://github.com/hertz-contrib/autotls) | 为 Hertz 支持 Let's Encrypt 。 | | [Http2](https://github.com/hertz-contrib/http2) | 提供对 HTTP2 的支持。 | | [Websocket](https://github.com/hertz-contrib/websocket) | 使 Hertz 支持 Websocket 协议。 | | [Etag](https://github.com/hertz-contrib/etag) | 提供 ETag HTTP 响应标头。 | | [Limiter](https://github.com/hertz-contrib/limiter) | 提供了基于 bbr 算法的限流器。 | | [Monitor-prometheus](https://github.com/hertz-contrib/monitor-prometheus) | 提供基于 Prometheus 服务监控功能。 | | [Obs-opentelemetry](https://github.com/hertz-contrib/obs-opentelemetry) | Hertz 的 Opentelemetry 扩展,支持 Metric、Logger、Tracing 并且达到开箱即用。 | | 
[Opensergo](https://github.com/hertz-contrib/opensergo) | Opensergo 扩展。 | | [Pprof](https://github.com/hertz-contrib/pprof) | Hertz 集成 Pprof 的扩展。 | | [Registry](https://github.com/hertz-contrib/registry) | 提供服务注册与发现功能。到现在为止,支持的服务发现拓展有 nacos, consul, etcd, eureka, polaris, servicecomb, zookeeper, redis。 | | [Sentry](https://github.com/hertz-contrib/hertzsentry) | Sentry 拓展提供了一些统一的接口来帮助用户进行实时的错误监控。 | | [Tracer](https://github.com/hertz-contrib/tracer) | 基于 Opentracing 的链路追踪。 | | [Basicauth](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/basic_auth) | Basicauth 中间件能够提供 HTTP 基本身份验证。 | | [Jwt](https://github.com/hertz-contrib/jwt) | Jwt 拓展。 | | [Keyauth](https://github.com/hertz-contrib/keyauth) | 提供基于 token 的身份验证。 | | [Requestid](https://github.com/hertz-contrib/requestid) | 在 response 中添加 request id。 | | [Sessions](https://github.com/hertz-contrib/sessions) | 具有多状态存储支持的 Session 中间件。 | | [Casbin](https://github.com/hertz-contrib/casbin) | 通过 Casbin 支持各种访问控制模型。 | | [Cors](https://github.com/hertz-contrib/cors) | 提供跨域资源共享支持。 | | [Csrf](https://github.com/hertz-contrib/csrf) | Csrf 中间件用于防止跨站点请求伪造攻击。 | | [Secure](https://github.com/hertz-contrib/secure) | 具有多配置项的 Secure 中间件。 | | [Gzip](https://github.com/hertz-contrib/gzip) | 含多个可选项的 Gzip 拓展。 | | [I18n](https://github.com/hertz-contrib/i18n) | 可帮助将 Hertz 程序翻译成多种语言。 | | [Lark](https://github.com/hertz-contrib/lark-hertz) | 在 Hertz 中处理 Lark/飞书的卡片消息和事件的回调。 | | [Loadbalance](https://github.com/hertz-contrib/loadbalance) | 提供适用于 Hertz 的负载均衡算法。 | | [Logger](https://github.com/hertz-contrib/logger) | Hertz 的日志拓展,提供了对 zap、logrus、zerologs 日志框架的支持。 | | [Recovery](https://github.com/cloudwego/hertz/tree/develop/pkg/app/middlewares/server/recovery) | Hertz 的异常恢复中间件。 | | [Reverseproxy](https://github.com/hertz-contrib/reverseproxy) | 实现反向代理。 | | [Swagger](https://github.com/hertz-contrib/swagger) | 使用 Swagger 2.0 自动生成 RESTful API 文档。 | | [Cache](https://github.com/hertz-contrib/cache) | 用于缓存 HTTP 接口内容的 Hertz 
中间件,支持多种客户端。 | ## 相关文章 - [字节跳动在 Go 网络库上的实践](https://www.cloudwego.io/zh/blog/2020/05/24/%E5%AD%97%E8%8A%82%E8%B7%B3%E5%8A%A8%E5%9C%A8-go-%E7%BD%91%E7%BB%9C%E5%BA%93%E4%B8%8A%E7%9A%84%E5%AE%9E%E8%B7%B5/) - [超大规模的企业级微服务 HTTP 框架 — Hertz 正式开源!](https://www.cloudwego.io/zh/blog/2022/06/21/%E8%B6%85%E5%A4%A7%E8%A7%84%E6%A8%A1%E7%9A%84%E4%BC%81%E4%B8%9A%E7%BA%A7%E5%BE%AE%E6%9C%8D%E5%8A%A1-http-%E6%A1%86%E6%9E%B6-hertz-%E6%AD%A3%E5%BC%8F%E5%BC%80%E6%BA%90/) - [字节跳动开源 Go HTTP 框架 Hertz 设计实践](https://www.cloudwego.io/zh/blog/2022/06/21/%E5%AD%97%E8%8A%82%E8%B7%B3%E5%8A%A8%E5%BC%80%E6%BA%90-go-http-%E6%A1%86%E6%9E%B6-hertz-%E8%AE%BE%E8%AE%A1%E5%AE%9E%E8%B7%B5/) - [助力字节降本增效,大规模企业级 HTTP 框架 Hertz 设计实践](https://www.cloudwego.io/zh/blog/2022/09/27/%E5%8A%A9%E5%8A%9B%E5%AD%97%E8%8A%82%E9%99%8D%E6%9C%AC%E5%A2%9E%E6%95%88%E5%A4%A7%E8%A7%84%E6%A8%A1%E4%BC%81%E4%B8%9A%E7%BA%A7-http-%E6%A1%86%E6%9E%B6-hertz-%E8%AE%BE%E8%AE%A1%E5%AE%9E%E8%B7%B5/) - [HTTP 框架 Hertz 实践入门:性能测试指南](https://www.cloudwego.io/zh/blog/2023/02/24/http-%E6%A1%86%E6%9E%B6-hertz-%E5%AE%9E%E8%B7%B5%E5%85%A5%E9%97%A8%E6%80%A7%E8%83%BD%E6%B5%8B%E8%AF%95%E6%8C%87%E5%8D%97/) ## 贡献代码 [Contributing](https://github.com/cloudwego/hertz/blob/main/CONTRIBUTING.md) ## RoadMap [Hertz RoadMap](ROADMAP.md) ## 开源许可 Hertz 基于[Apache License 2.0](https://github.com/cloudwego/hertz/blob/main/LICENSE) 许可证,其依赖的三方组件的开源许可见 [Licenses](https://github.com/cloudwego/hertz/blob/main/licenses)。 ## 联系我们 - Email: conduct@cloudwego.io - 如何成为 member: [COMMUNITY MEMBERSHIP](https://github.com/cloudwego/community/blob/main/COMMUNITY_MEMBERSHIP.md) - Issues: [Issues](https://github.com/cloudwego/hertz/issues) - Slack: 加入我们的 [Slack 频道](https://join.slack.com/t/cloudwego/shared_invite/zt-tmcbzewn-UjXMF3ZQsPhl7W3tEDZboA) - 飞书用户群([注册飞书](https://www.larksuite.com/zh_cn/download)进群) ![LarkGroup](images/lark_group_cn.png) ## 贡献者 感谢您对 Hertz 作出的贡献! 
[![Contributors](https://contrib.rocks/image?repo=cloudwego/hertz)](https://github.com/cloudwego/hertz/graphs/contributors) ## Landscapes

  

CloudWeGo 丰富了 CNCF 云原生生态

================================================ FILE: ROADMAP.md ================================================ # Hertz RoadMap Starting in 2025, instead of developing new features, we will focus on optimizing core functionality and the user experience. The following is a list of planned projects. - [ ] Address `RequestContext` issues under concurrency. - [ ] Refactor binding and validator for better extensibility, and deprecate the built-in implementations. - [ ] Make `netpoll` optional, and make `pkg/network/standard` the default. - [ ] Enhance the Content-Encoding extension interface for better extensibility. - [ ] Optimize `pkg/common/adaptor`, and deprecate implementations that have `net/http` alternatives. - [ ] Deprecate the built-in protobuf code generator in favor of [cloudwego/prutal](https://github.com/cloudwego/prutal). All users are encouraged to provide suggestions on the projects listed above, or to submit proposals for enhancing current features. ================================================ FILE: cmd/hz/app/app.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
*/ package app import ( "errors" "fmt" "os" "path/filepath" "strings" "github.com/cloudwego/hertz/cmd/hz/config" "github.com/cloudwego/hertz/cmd/hz/generator" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/protobuf" "github.com/cloudwego/hertz/cmd/hz/thrift" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/hertz/cmd/hz/util/logs" "github.com/urfave/cli/v2" ) // global args. MUST fork it when use var globalArgs = config.NewArgument() func New(c *cli.Context) error { args, err := globalArgs.Parse(c, meta.CmdNew) if err != nil { return cli.Exit(err, meta.LoadError) } setLogVerbose(args.Verbose) logs.Debugf("args: %#v\n", args) exist, err := util.PathExist(filepath.Join(args.OutDir, meta.ManifestFile)) if err != nil { return cli.Exit(err, meta.LoadError) } if exist && !args.ForceNew { return cli.Exit(fmt.Errorf("the current is already a hertz project, if you want to regenerate it you can specify \"-force\""), meta.LoadError) } err = GenerateLayout(args) if err != nil { return cli.Exit(err, meta.GenerateLayoutError) } err = TriggerPlugin(args) if err != nil { return cli.Exit(err, meta.PluginError) } // ".hz" file converges to the hz tool manifest := new(meta.Manifest) args.InitManifest(manifest) err = manifest.Persist(args.OutDir) if err != nil { return cli.Exit(fmt.Errorf("persist manifest failed: %v", err), meta.PersistError) } if !args.NeedGoMod && args.IdlType == meta.IdlThrift { logs.Warn(meta.AddThriftReplace) } return nil } func Update(c *cli.Context) error { // begin to update args, err := globalArgs.Parse(c, meta.CmdUpdate) if err != nil { return cli.Exit(err, meta.LoadError) } setLogVerbose(args.Verbose) logs.Debugf("Args: %#v\n", args) manifest := new(meta.Manifest) err = manifest.InitAndValidate(args.OutDir) if err != nil { return cli.Exit(err, meta.LoadError) } // update argument by ".hz", can automatically get "handler_dir"/"model_dir"/"router_dir" args.UpdateByManifest(manifest) err = TriggerPlugin(args) if err 
!= nil { return cli.Exit(err, meta.PluginError) } // If the "handler_dir"/"model_dir" is updated, write it back to ".hz" args.UpdateManifest(manifest) err = manifest.Persist(args.OutDir) if err != nil { return cli.Exit(fmt.Errorf("persist manifest failed: %v", err), meta.PersistError) } return nil } func Model(c *cli.Context) error { args, err := globalArgs.Parse(c, meta.CmdModel) if err != nil { return cli.Exit(err, meta.LoadError) } setLogVerbose(args.Verbose) logs.Debugf("Args: %#v\n", args) err = TriggerPlugin(args) if err != nil { return cli.Exit(err, meta.PluginError) } return nil } func Client(c *cli.Context) error { args, err := globalArgs.Parse(c, meta.CmdClient) if err != nil { return cli.Exit(err, meta.LoadError) } setLogVerbose(args.Verbose) logs.Debugf("Args: %#v\n", args) err = TriggerPlugin(args) if err != nil { return cli.Exit(err, meta.PluginError) } return nil } func PluginMode() { mode := os.Getenv(meta.EnvPluginMode) if len(os.Args) <= 1 && mode != "" { switch mode { case meta.ThriftPluginName: plugin := new(thrift.Plugin) os.Exit(plugin.Run()) case meta.ProtocPluginName: plugin := new(protobuf.Plugin) os.Exit(plugin.Run()) } } } func Init() *cli.App { // flags verboseFlag := cli.BoolFlag{Name: "verbose,vv", Usage: "turn on verbose mode", Destination: &globalArgs.Verbose} idlFlag := cli.StringSliceFlag{Name: "idl", Usage: "Specify the IDL file path. 
(.thrift or .proto)"} moduleFlag := cli.StringFlag{Name: "module", Aliases: []string{"mod"}, Usage: "Specify the Go module name.", Destination: &globalArgs.Gomod} serviceNameFlag := cli.StringFlag{Name: "service", Usage: "Specify the service name.", Destination: &globalArgs.ServiceName} outDirFlag := cli.StringFlag{Name: "out_dir", Usage: "Specify the project path.", Destination: &globalArgs.OutDir} handlerDirFlag := cli.StringFlag{Name: "handler_dir", Usage: "Specify the handler relative path (based on \"out_dir\").", Destination: &globalArgs.HandlerDir} modelDirFlag := cli.StringFlag{Name: "model_dir", Usage: "Specify the model relative path (based on \"out_dir\").", Destination: &globalArgs.ModelDir} routerDirFlag := cli.StringFlag{Name: "router_dir", Usage: "Specify the router relative path (based on \"out_dir\").", Destination: &globalArgs.RouterDir} useFlag := cli.StringFlag{Name: "use", Usage: "Specify the model package to import for handler.", Destination: &globalArgs.Use} baseDomainFlag := cli.StringFlag{Name: "base_domain", Usage: "Specify the request domain.", Destination: &globalArgs.BaseDomain} clientDirFlag := cli.StringFlag{Name: "client_dir", Usage: "Specify the client path. If not specified, IDL generated path is used for 'client' command; no client code is generated for 'new' command", Destination: &globalArgs.ClientDir} forceClientDirFlag := cli.StringFlag{Name: "force_client_dir", Usage: "Specify the client path, and won't use namespaces as subpaths", Destination: &globalArgs.ForceClientDir} optPkgFlag := cli.StringSliceFlag{Name: "option_package", Aliases: []string{"P"}, Usage: "Specify the package path. ({include_path}={import_path})"} includesFlag := cli.StringSliceFlag{Name: "proto_path", Aliases: []string{"I"}, Usage: "Add an IDL search path for includes. 
(Valid only if idl is protobuf)"} excludeFilesFlag := cli.StringSliceFlag{Name: "exclude_file", Aliases: []string{"E"}, Usage: "Specify the files that do not need to be updated."} thriftOptionsFlag := cli.StringSliceFlag{Name: "thriftgo", Aliases: []string{"t"}, Usage: "Specify arguments for the thriftgo. ({flag}={value})"} protoOptionsFlag := cli.StringSliceFlag{Name: "protoc", Aliases: []string{"p"}, Usage: "Specify arguments for the protoc. ({flag}={value})"} thriftPluginsFlag := cli.StringSliceFlag{Name: "thrift-plugins", Usage: "Specify plugins for the thriftgo. ({plugin_name}:{options})"} protoPluginsFlag := cli.StringSliceFlag{Name: "protoc-plugins", Usage: "Specify plugins for the protoc. ({plugin_name}:{options}:{out_dir})"} noRecurseFlag := cli.BoolFlag{Name: "no_recurse", Usage: "Generate master model only.", Destination: &globalArgs.NoRecurse} forceNewFlag := cli.BoolFlag{Name: "force", Aliases: []string{"f"}, Usage: "Force new a project, which will overwrite the generated files", Destination: &globalArgs.ForceNew} forceUpdateClientFlag := cli.BoolFlag{Name: "force_client", Usage: "Force update 'hertz_client.go'", Destination: &globalArgs.ForceUpdateClient} enableExtendsFlag := cli.BoolFlag{Name: "enable_extends", Usage: "Parse 'extends' for thrift IDL", Destination: &globalArgs.EnableExtends} sortRouterFlag := cli.BoolFlag{Name: "sort_router", Usage: "Sort router register code, to avoid code difference", Destination: &globalArgs.SortRouter} jsonEnumStrFlag := cli.BoolFlag{Name: "json_enumstr", Usage: "Use string instead of num for json enums when idl is thrift.", Destination: &globalArgs.JSONEnumStr} queryEnumIntFlag := cli.BoolFlag{Name: "query_enumint", Usage: "Use num instead of string for query enum parameter.", Destination: &globalArgs.QueryEnumAsInt} unsetOmitemptyFlag := cli.BoolFlag{Name: "unset_omitempty", Usage: "Remove 'omitempty' tag for generated struct.", Destination: &globalArgs.UnsetOmitempty} protoCamelJSONTag := cli.BoolFlag{Name: 
"pb_camel_json_tag", Usage: "Convert Name style for json tag to camel(Only works protobuf).", Destination: &globalArgs.ProtobufCamelJSONTag} snakeNameFlag := cli.BoolFlag{Name: "snake_tag", Usage: "Use snake_case style naming for tags. (Only works for 'form', 'query', 'json')", Destination: &globalArgs.SnakeName} rmTagFlag := cli.StringSliceFlag{Name: "rm_tag", Usage: "Remove the default tag(json/query/form). If the annotation tag is set explicitly, it will not be removed."} customLayout := cli.StringFlag{Name: "customize_layout", Usage: "Specify the path for layout template.", Destination: &globalArgs.CustomizeLayout} customLayoutData := cli.StringFlag{Name: "customize_layout_data_path", Usage: "Specify the path for layout template render data.", Destination: &globalArgs.CustomizeLayoutData} customPackage := cli.StringFlag{Name: "customize_package", Usage: "Specify the path for package template.", Destination: &globalArgs.CustomizePackage} handlerByMethod := cli.BoolFlag{Name: "handler_by_method", Usage: "Generate a separate handler file for each method.", Destination: &globalArgs.HandlerByMethod} trimGoPackage := cli.StringFlag{Name: "trim_gopackage", Aliases: []string{"trim_pkg"}, Usage: "Trim the prefix of go_package for protobuf.", Destination: &globalArgs.TrimGoPackage} // client flag enableClientOptionalFlag := cli.BoolFlag{Name: "enable_optional", Usage: "Optional field do not transfer for thrift if not set.(Only works for query tag)", Destination: &globalArgs.EnableClientOptional} // app app := cli.NewApp() app.Name = "hz" app.Usage = "A idl parser and code generator for Hertz projects" app.Version = meta.Version // The default separator for multiple parameters is modified to ";" app.SliceFlagSeparator = ";" // global flags app.Flags = []cli.Flag{ &verboseFlag, } // Commands app.Commands = []*cli.Command{ { Name: meta.CmdNew, Usage: "Generate a new Hertz project", Flags: []cli.Flag{ &idlFlag, &serviceNameFlag, &moduleFlag, &outDirFlag, &handlerDirFlag, 
&modelDirFlag, &routerDirFlag, &clientDirFlag, &useFlag, &includesFlag, &thriftOptionsFlag, &protoOptionsFlag, &optPkgFlag, &trimGoPackage, &noRecurseFlag, &forceNewFlag, &enableExtendsFlag, &sortRouterFlag, &jsonEnumStrFlag, &unsetOmitemptyFlag, &protoCamelJSONTag, &snakeNameFlag, &rmTagFlag, &excludeFilesFlag, &customLayout, &customLayoutData, &customPackage, &handlerByMethod, &protoPluginsFlag, &thriftPluginsFlag, }, Action: New, }, { Name: meta.CmdUpdate, Usage: "Update an existing Hertz project", Flags: []cli.Flag{ &idlFlag, &moduleFlag, &outDirFlag, &handlerDirFlag, &modelDirFlag, &clientDirFlag, &useFlag, &includesFlag, &thriftOptionsFlag, &protoOptionsFlag, &optPkgFlag, &trimGoPackage, &noRecurseFlag, &enableExtendsFlag, &sortRouterFlag, &jsonEnumStrFlag, &unsetOmitemptyFlag, &protoCamelJSONTag, &snakeNameFlag, &rmTagFlag, &excludeFilesFlag, &customPackage, &handlerByMethod, &protoPluginsFlag, &thriftPluginsFlag, }, Action: Update, }, { Name: meta.CmdModel, Usage: "Generate model code only", Flags: []cli.Flag{ &idlFlag, &moduleFlag, &outDirFlag, &modelDirFlag, &includesFlag, &thriftOptionsFlag, &protoOptionsFlag, &noRecurseFlag, &trimGoPackage, &jsonEnumStrFlag, &unsetOmitemptyFlag, &protoCamelJSONTag, &snakeNameFlag, &rmTagFlag, &excludeFilesFlag, }, Action: Model, }, { Name: meta.CmdClient, Usage: "Generate hertz client based on IDL", Flags: []cli.Flag{ &idlFlag, &moduleFlag, &baseDomainFlag, &modelDirFlag, &clientDirFlag, &useFlag, &forceClientDirFlag, &forceUpdateClientFlag, &includesFlag, &thriftOptionsFlag, &protoOptionsFlag, &noRecurseFlag, &enableExtendsFlag, &trimGoPackage, &jsonEnumStrFlag, &enableClientOptionalFlag, &queryEnumIntFlag, &unsetOmitemptyFlag, &protoCamelJSONTag, &snakeNameFlag, &rmTagFlag, &excludeFilesFlag, &customPackage, &protoPluginsFlag, &thriftPluginsFlag, }, Action: Client, }, } return app } func setLogVerbose(verbose bool) { if verbose { logs.SetLevel(logs.LevelDebug) } else { logs.SetLevel(logs.LevelWarn) } } func 
GenerateLayout(args *config.Argument) error { lg := &generator.LayoutGenerator{ TemplateGenerator: generator.TemplateGenerator{ OutputDir: args.OutDir, Excludes: args.Excludes, }, } layout := generator.Layout{ GoModule: args.Gomod, ServiceName: args.ServiceName, UseApacheThrift: args.IdlType == meta.IdlThrift, HasIdl: 0 != len(args.IdlPaths), ModelDir: args.ModelDir, HandlerDir: args.HandlerDir, RouterDir: args.RouterDir, NeedGoMod: args.NeedGoMod, } if args.CustomizeLayout == "" { // generate by default err := lg.GenerateByService(layout) if err != nil { return fmt.Errorf("generating layout failed: %v", err) } } else { // generate by customized layout configPath, dataPath := args.CustomizeLayout, args.CustomizeLayoutData logs.Infof("get customized layout info, layout_config_path: %s, template_data_path: %s", configPath, dataPath) exist, err := util.PathExist(configPath) if err != nil { return fmt.Errorf("check customized layout config file exist failed: %v", err) } if !exist { return errors.New("layout_config_path doesn't exist") } lg.ConfigPath = configPath // generate by service info if dataPath == "" { err := lg.GenerateByService(layout) if err != nil { return fmt.Errorf("generating layout failed: %v", err) } } else { // generate by customized data err := lg.GenerateByConfig(dataPath) if err != nil { return fmt.Errorf("generating layout failed: %v", err) } } } err := lg.Persist() if err != nil { return fmt.Errorf("generating layout failed: %v", err) } return nil } func TriggerPlugin(args *config.Argument) error { if len(args.IdlPaths) == 0 { return nil } cmd, err := config.BuildPluginCmd(args) if err != nil { return fmt.Errorf("build plugin command failed: %v", err) } compiler, err := config.IdlTypeToCompiler(args.IdlType) if err != nil { return fmt.Errorf("get compiler failed: %v", err) } logs.Debugf("begin to trigger plugin, compiler: %s, idl_paths: %v", compiler, args.IdlPaths) buf, err := cmd.CombinedOutput() if err != nil { out := 
strings.TrimSpace(string(buf)) if !strings.HasSuffix(out, meta.TheUseOptionMessage) { return fmt.Errorf("plugin %s_gen_hertz returns error: %v, cause:\n%v", compiler, err, string(buf)) } } // If len(buf) != 0, the plugin returned the log. if len(buf) != 0 { fmt.Println(string(buf)) } logs.Debugf("end run plugin %s_gen_hertz", compiler) return nil } ================================================ FILE: cmd/hz/config/argument.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package config import ( "fmt" "os" "path/filepath" "strings" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/hertz/cmd/hz/util/logs" "github.com/urfave/cli/v2" ) type Argument struct { // Mode meta.Mode // operating mode(0-compiler, 1-plugin) CmdType string // command type Verbose bool // print verbose log Cwd string // execution path OutDir string // output path HandlerDir string // handler path ModelDir string // model path RouterDir string // router path ClientDir string // client path BaseDomain string // request domain ForceClientDir string // client dir (not use namespace as a subpath) IdlType string // idl type IdlPaths []string // master idl path RawOptPkg []string // user-specified package import path OptPkgMap map[string]string Includes []string PkgPrefix string TrimGoPackage string // trim go_package for protobuf, avoid to generate multiple directory Gopath string // $GOPATH Gosrc string // $GOPATH/src Gomod string Gopkg string // $GOPATH/src/{{gopkg}} ServiceName string // service name Use string NeedGoMod bool JSONEnumStr bool QueryEnumAsInt bool UnsetOmitempty bool ProtobufCamelJSONTag bool ProtocOptions []string // options to pass through to protoc ThriftOptions []string // options to pass through to thriftgo for go flag ProtobufPlugins []string ThriftPlugins []string SnakeName bool RmTags []string Excludes []string NoRecurse bool HandlerByMethod bool ForceNew bool ForceUpdateClient bool SnakeStyleMiddleware bool EnableExtends bool SortRouter bool // client flag EnableClientOptional bool CustomizeLayout string CustomizeLayoutData string CustomizePackage string ModelBackend string } func NewArgument() *Argument { return &Argument{ OptPkgMap: make(map[string]string), Includes: make([]string, 0, 4), Excludes: make([]string, 0, 4), ProtocOptions: make([]string, 0, 4), ThriftOptions: make([]string, 0, 4), } } // Parse initializes a new argument based on its own information func (arg *Argument) 
Parse(c *cli.Context, cmd string) (*Argument, error) { // v2 cli cannot put the StringSlice flag to struct, so we need to parse it here arg.parseStringSlice(c) args := arg.Fork() args.CmdType = cmd err := args.checkPath() if err != nil { return nil, err } err = args.checkIDL() if err != nil { return nil, err } err = args.checkPackage() if err != nil { return nil, err } return args, nil } func (arg *Argument) parseStringSlice(c *cli.Context) { arg.IdlPaths = c.StringSlice("idl") arg.Includes = c.StringSlice("proto_path") arg.Excludes = c.StringSlice("exclude_file") arg.RawOptPkg = c.StringSlice("option_package") arg.ThriftOptions = c.StringSlice("thriftgo") arg.ProtocOptions = c.StringSlice("protoc") arg.ThriftPlugins = c.StringSlice("thrift-plugins") arg.ProtobufPlugins = c.StringSlice("protoc-plugins") arg.RmTags = c.StringSlice("rm_tag") } func (arg *Argument) UpdateByManifest(m *meta.Manifest) { if arg.HandlerDir == "" && m.HandlerDir != "" { logs.Infof("use \"handler_dir\" in \".hz\" as the handler generated dir\n") arg.HandlerDir = m.HandlerDir } if arg.ModelDir == "" && m.ModelDir != "" { logs.Infof("use \"model_dir\" in \".hz\" as the model generated dir\n") arg.ModelDir = m.ModelDir } if len(m.RouterDir) != 0 { logs.Infof("use \"router_dir\" in \".hz\" as the router generated dir\n") arg.RouterDir = m.RouterDir } } // checkPath sets the project path and verifies that the model、handler、router and client path is compliant func (arg *Argument) checkPath() error { dir, err := os.Getwd() if err != nil { return fmt.Errorf("get current path failed: %s", err) } arg.Cwd = dir if arg.OutDir == "" { arg.OutDir = dir } if !filepath.IsAbs(arg.OutDir) { ap := filepath.Join(arg.Cwd, arg.OutDir) arg.OutDir = ap } if arg.ModelDir != "" && filepath.IsAbs(arg.ModelDir) { return fmt.Errorf("model path %s must be relative to out_dir", arg.ModelDir) } if arg.HandlerDir != "" && filepath.IsAbs(arg.HandlerDir) { return fmt.Errorf("handler path %s must be relative to out_dir", 
arg.HandlerDir) } if arg.RouterDir != "" && filepath.IsAbs(arg.RouterDir) { return fmt.Errorf("router path %s must be relative to out_dir", arg.RouterDir) } if arg.ClientDir != "" && filepath.IsAbs(arg.ClientDir) { return fmt.Errorf("router path %s must be relative to out_dir", arg.ClientDir) } return nil } // checkIDL check if the idl path exists, set and check the idl type func (arg *Argument) checkIDL() error { for i, path := range arg.IdlPaths { abPath, err := filepath.Abs(path) if err != nil { return fmt.Errorf("idl path %s is not absolute", path) } ext := filepath.Ext(abPath) if ext == "" || ext[0] != '.' { return fmt.Errorf("idl path %s is not a valid file", path) } ext = ext[1:] switch ext { case meta.IdlThrift: arg.IdlType = meta.IdlThrift case meta.IdlProto: arg.IdlType = meta.IdlProto default: return fmt.Errorf("IDL type %s is not supported", ext) } arg.IdlPaths[i] = abPath } return nil } func (arg *Argument) IsUpdate() bool { return arg.CmdType == meta.CmdUpdate } func (arg *Argument) IsNew() bool { return arg.CmdType == meta.CmdNew } // checkPackage check and set the gopath、 module and package name func (arg *Argument) checkPackage() error { gopath, err := util.GetGOPATH() if err != nil { return fmt.Errorf("get gopath failed: %s", err) } if gopath == "" { return fmt.Errorf("GOPATH is not set") } arg.Gopath = gopath arg.Gosrc = filepath.Join(gopath, "src") // Generate the project under gopath, use the relative path as the package name if strings.HasPrefix(arg.Cwd, arg.Gosrc) { if gopkg, err := filepath.Rel(arg.Gosrc, arg.Cwd); err != nil { return fmt.Errorf("get relative path to GOPATH/src failed: %s", err) } else { arg.Gopkg = gopkg } } if len(arg.Gomod) == 0 { // not specified "go module" // search go.mod recursively module, path, ok := util.SearchGoMod(arg.Cwd, true) if ok { // find go.mod in upper level, use it as project module, don't generate go.mod rel, err := filepath.Rel(path, arg.Cwd) if err != nil { return fmt.Errorf("can not get relative 
path, err :%v", err) } arg.Gomod = filepath.Join(module, rel) logs.Debugf("find module '%s' from '%s/go.mod', so use it as module name", module, path) } if len(arg.Gomod) == 0 { // don't find go.mod in upper level, use relative path as module name, generate go.mod logs.Debugf("use gopath's relative path '%s' as the module name", arg.Gopkg) // gopkg will be "" under non-gopath arg.Gomod = arg.Gopkg arg.NeedGoMod = true } arg.Gomod = filepath.ToSlash(arg.Gomod) } else { // specified "go module" // search go.mod in current path module, path, ok := util.SearchGoMod(arg.Cwd, false) if ok { // go.mod exists in current path, check module name, don't generate go.mod if module != arg.Gomod { return fmt.Errorf("module name given by the '-module/mod' option ('%s') is not consist with the name defined in go.mod ('%s' from %s), try to remove '-module/mod' option in your command\n", arg.Gomod, module, path) } } else { // go.mod don't exist in current path, generate go.mod arg.NeedGoMod = true } } if len(arg.Gomod) == 0 { return fmt.Errorf("can not get go module, please specify a module name with the '-module/mod' flag") } if len(arg.RawOptPkg) > 0 { arg.OptPkgMap = make(map[string]string, len(arg.RawOptPkg)) for _, op := range arg.RawOptPkg { ps := strings.SplitN(op, "=", 2) if len(ps) != 2 { return fmt.Errorf("invalid option package: %s", op) } arg.OptPkgMap[ps[0]] = ps[1] } arg.RawOptPkg = nil } return nil } func (arg *Argument) Pack() ([]string, error) { data, err := util.PackArgs(arg) if err != nil { return nil, fmt.Errorf("pack argument failed: %s", err) } return data, nil } func (arg *Argument) Unpack(data []string) error { err := util.UnpackArgs(data, arg) if err != nil { return fmt.Errorf("unpack argument failed: %s", err) } return nil } // Fork can copy its own parameters to a new argument func (arg *Argument) Fork() *Argument { args := NewArgument() *args = *arg util.CopyString2StringMap(arg.OptPkgMap, args.OptPkgMap) util.CopyStringSlice(&arg.Includes, &args.Includes) 
util.CopyStringSlice(&arg.Excludes, &args.Excludes) util.CopyStringSlice(&arg.ProtocOptions, &args.ProtocOptions) util.CopyStringSlice(&arg.ThriftOptions, &args.ThriftOptions) return args } func (arg *Argument) GetGoPackage() (string, error) { if arg.Gomod != "" { return arg.Gomod, nil } else if arg.Gopkg != "" { return arg.Gopkg, nil } return "", fmt.Errorf("project package name is not set") } func IdlTypeToCompiler(idlType string) (string, error) { switch idlType { case meta.IdlProto: return meta.TpCompilerProto, nil case meta.IdlThrift: return meta.TpCompilerThrift, nil default: return "", fmt.Errorf("IDL type %s is not supported", idlType) } } func (arg *Argument) ModelPackagePrefix() (string, error) { ret := arg.Gomod if arg.ModelDir == "" { path, err := util.RelativePath(meta.ModelDir) if !strings.HasPrefix(path, "/") { path = "/" + path } if err != nil { return "", err } ret += path } else { path, err := util.RelativePath(arg.ModelDir) if err != nil { return "", err } ret += "/" + path } return filepath.ToSlash(ret), nil } func (arg *Argument) ModelOutDir() string { ret := arg.OutDir if arg.ModelDir == "" { ret = filepath.Join(ret, meta.ModelDir) } else { ret = filepath.Join(ret, arg.ModelDir) } return ret } func (arg *Argument) GetHandlerDir() (string, error) { if arg.HandlerDir == "" { return util.RelativePath(meta.HandlerDir) } return util.RelativePath(arg.HandlerDir) } func (arg *Argument) GetModelDir() (string, error) { if arg.ModelDir == "" { return util.RelativePath(meta.ModelDir) } return util.RelativePath(arg.ModelDir) } func (arg *Argument) GetRouterDir() (string, error) { if arg.RouterDir == "" { return util.RelativePath(meta.RouterDir) } return util.RelativePath(arg.RouterDir) } func (arg *Argument) GetClientDir() (string, error) { if arg.ClientDir == "" { return "", nil } return util.RelativePath(arg.ClientDir) } func (arg *Argument) InitManifest(m *meta.Manifest) { m.Version = meta.Version m.HandlerDir = arg.HandlerDir m.ModelDir = arg.ModelDir 
m.RouterDir = arg.RouterDir } func (arg *Argument) UpdateManifest(m *meta.Manifest) { m.Version = meta.Version if arg.HandlerDir != m.HandlerDir { m.HandlerDir = arg.HandlerDir } if arg.HandlerDir != m.ModelDir { m.ModelDir = arg.ModelDir } // "router_dir" must not be defined by "update" command } ================================================ FILE: cmd/hz/config/cmd.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package config import ( "fmt" "os" "os/exec" "path/filepath" "strings" "syscall" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/hertz/cmd/hz/util/logs" ) func lookupTool(idlType string) (string, error) { tool := meta.TpCompilerThrift if idlType == meta.IdlProto { tool = meta.TpCompilerProto } path, err := exec.LookPath(tool) logs.Debugf("[DEBUG]path:%v", path) if err != nil { goPath, err := util.GetGOPATH() if err != nil { return "", fmt.Errorf("get 'GOPATH' failed for find %s : %v", tool, path) } path = filepath.Join(goPath, "bin", tool) } isExist, err := util.PathExist(path) if err != nil { return "", fmt.Errorf("check '%s' path error: %v", path, err) } if !isExist { if tool == meta.TpCompilerThrift { // If thriftgo does not exist, the latest version will be installed automatically. 
err := util.InstallAndCheckThriftgo() if err != nil { return "", fmt.Errorf("can't install '%s' automatically, please install it manually for https://github.com/cloudwego/thriftgo, err : %v", tool, err) } } else { // todo: protoc automatic installation return "", fmt.Errorf("%s is not installed, please install it first", tool) } } if tool == meta.TpCompilerThrift { // If thriftgo exists, the version is detected; if the version is lower than v0.2.0 then the latest version of thriftgo is automatically installed. err := util.CheckAndUpdateThriftgo() if err != nil { return "", fmt.Errorf("update thriftgo version failed, please install it manually for https://github.com/cloudwego/thriftgo, err: %v", err) } } return path, nil } // link removes the previous symbol link and rebuilds a new one. func link(src, dst string) error { err := syscall.Unlink(dst) if err != nil && !os.IsNotExist(err) { return fmt.Errorf("unlink %q: %s", dst, err) } err = os.Symlink(src, dst) if err != nil { return fmt.Errorf("symlink %q: %s", dst, err) } return nil } func BuildPluginCmd(args *Argument) (*exec.Cmd, error) { exe, err := os.Executable() if err != nil { return nil, fmt.Errorf("failed to detect current executable, err: %v", err) } argPacks, err := args.Pack() if err != nil { return nil, err } kas := strings.Join(argPacks, ",") path, err := lookupTool(args.IdlType) if err != nil { return nil, err } cmd := &exec.Cmd{ Path: path, } if args.IdlType == meta.IdlThrift { // thriftgo os.Setenv(meta.EnvPluginMode, meta.ThriftPluginName) cmd.Args = append(cmd.Args, meta.TpCompilerThrift) for _, inc := range args.Includes { cmd.Args = append(cmd.Args, "-i", inc) } if args.Verbose { cmd.Args = append(cmd.Args, "-v") } thriftOpt, err := args.GetThriftgoOptions() if err != nil { return nil, err } cmd.Args = append(cmd.Args, "-o", args.ModelOutDir(), "-g", thriftOpt, "-p", "hertz="+exe+":"+kas, ) for _, p := range args.ThriftPlugins { cmd.Args = append(cmd.Args, "-p", p) } if !args.NoRecurse { cmd.Args 
= append(cmd.Args, "-r") } } else { // protoc os.Setenv(meta.EnvPluginMode, meta.ProtocPluginName) cmd.Args = append(cmd.Args, meta.TpCompilerProto) for _, inc := range args.Includes { cmd.Args = append(cmd.Args, "-I", inc) } for _, inc := range args.IdlPaths { cmd.Args = append(cmd.Args, "-I", filepath.Dir(inc)) } cmd.Args = append(cmd.Args, "--plugin=protoc-gen-hertz="+exe, "--hertz_out="+args.OutDir, "--hertz_opt="+kas, ) for _, p := range args.ProtobufPlugins { pluginParams := strings.Split(p, ":") if len(pluginParams) != 3 { logs.Warnf("Failed to get the correct protoc plugin parameters for %. "+ "Please specify the protoc plugin in the form of \"plugin_name:options:out_dir\"", p) os.Exit(1) } // pluginParams[0] -> plugin name, pluginParams[1] -> plugin options, pluginParams[2] -> out_dir cmd.Args = append(cmd.Args, fmt.Sprintf("--%s_out=%s", pluginParams[0], pluginParams[2]), fmt.Sprintf("--%s_opt=%s", pluginParams[0], pluginParams[1]), ) } for _, kv := range args.ProtocOptions { cmd.Args = append(cmd.Args, "--"+kv) } } cmd.Args = append(cmd.Args, args.IdlPaths...) logs.Infof(strings.Join(cmd.Args, " ")) logs.Flush() return cmd, nil } func (arg *Argument) GetThriftgoOptions() (string, error) { defaultOpt := "reserve_comments,gen_json_tag=false," prefix, err := arg.ModelPackagePrefix() if err != nil { return "", err } arg.ThriftOptions = append(arg.ThriftOptions, "package_prefix="+prefix) if arg.JSONEnumStr { arg.ThriftOptions = append(arg.ThriftOptions, "json_enum_as_text") } gas := "go:" + defaultOpt + strings.Join(arg.ThriftOptions, ",") return gas, nil } ================================================ FILE: cmd/hz/doc.go ================================================ // Copyright 2022 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // Package "github.com/cloudwego/hertz/cmd/hz" contains packages for building the hz command line tool. // APIs exported by packages under this directory do not promise any backward // compatibility, so please do not rely on them. package main ================================================ FILE: cmd/hz/generator/client.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package generator import ( "path" "path/filepath" "strings" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/util" ) type ClientMethod struct { *HttpMethod BodyParamsCode string QueryParamsCode string PathParamsCode string HeaderParamsCode string FormValueCode string FormFileCode string } type ClientConfig struct { QueryEnumAsInt bool } type ClientFile struct { Config ClientConfig FilePath string PackageName string ServiceName string BaseDomain string Imports map[string]*model.Model ClientMethods []*ClientMethod } func (pkgGen *HttpPackageGenerator) genClient(pkg *HttpPackage, clientDir string) error { for _, s := range pkg.Services { cliDir := util.SubDir(clientDir, util.ToSnakeCase(s.Name)) if len(pkgGen.ForceClientDir) != 0 { cliDir = pkgGen.ForceClientDir } hertzClientPath := filepath.Join(cliDir, hertzClientTplName) isExist, err := util.PathExist(hertzClientPath) if err != nil { return err } baseDomain := s.BaseDomain if len(pkgGen.BaseDomain) != 0 { baseDomain = pkgGen.BaseDomain } client := ClientFile{ FilePath: filepath.Join(cliDir, util.ToSnakeCase(s.Name)+".go"), PackageName: util.ToSnakeCase(filepath.Base(cliDir)), ServiceName: util.ToCamelCase(s.Name), ClientMethods: s.ClientMethods, BaseDomain: baseDomain, Config: ClientConfig{QueryEnumAsInt: pkgGen.QueryEnumAsInt}, } if !isExist || pkgGen.ForceUpdateClient { err := pkgGen.TemplateGenerator.Generate(client, hertzClientTplName, hertzClientPath, false) if err != nil { return err } } client.Imports = make(map[string]*model.Model, len(client.ClientMethods)) for _, m := range client.ClientMethods { // Iterate over the request and return parameters of the method to get import path. 
for key, mm := range m.Models { if v, ok := client.Imports[mm.PackageName]; ok && v.Package != mm.Package { client.Imports[key] = mm continue } client.Imports[mm.PackageName] = mm } } if len(pkgGen.UseDir) != 0 { oldModelPkg := util.SubPackage(pkgGen.ProjPackage, filepath.Clean(pkgGen.ModelDir)) newModelPkg := path.Clean(pkgGen.UseDir) for _, m := range client.ClientMethods { for _, mm := range m.Models { mm.Package = strings.Replace(mm.Package, oldModelPkg, newModelPkg, 1) } } } err = pkgGen.TemplateGenerator.Generate(client, idlClientName, client.FilePath, false) if err != nil { return err } } return nil } ================================================ FILE: cmd/hz/generator/custom_files.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package generator import ( "bytes" "fmt" "io/ioutil" "path/filepath" "strings" "text/template" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/hertz/cmd/hz/util/logs" ) type FilePathRenderInfo struct { MasterIDLName string // master IDL name GenPackage string // master IDL generate code package HandlerDir string // handler generate dir ModelDir string // model generate dir RouterDir string // router generate dir ProjectDir string // projectDir GoModule string // go module ServiceName string // service name, changed as services are traversed MethodName string // method name, changed as methods are traversed HandlerGenPath string // "api.gen_path" value } type IDLPackageRenderInfo struct { FilePathRenderInfo ServiceInfos *HttpPackage } type CustomizedFileForMethod struct { *HttpMethod FilePath string FilePackage string ServiceInfo *Service // service info for this method IDLPackageInfo *IDLPackageRenderInfo // IDL info for this service } type CustomizedFileForService struct { *Service FilePath string FilePackage string IDLPackageInfo *IDLPackageRenderInfo // IDL info for this service } type CustomizedFileForIDL struct { *IDLPackageRenderInfo FilePath string FilePackage string } // todo: 1. 
how to import other file, if the other file name is a template // genCustomizedFile generate customized file template func (pkgGen *HttpPackageGenerator) genCustomizedFile(pkg *HttpPackage) error { filePathRenderInfo := FilePathRenderInfo{ MasterIDLName: pkg.IdlName, GenPackage: pkg.Package, HandlerDir: pkgGen.HandlerDir, ModelDir: pkgGen.ModelDir, RouterDir: pkgGen.RouterDir, ProjectDir: pkgGen.OutputDir, GoModule: pkgGen.ProjPackage, // methodName & serviceName will change as traverse } idlPackageRenderInfo := IDLPackageRenderInfo{ FilePathRenderInfo: filePathRenderInfo, ServiceInfos: pkg, } for _, tplInfo := range pkgGen.tplsInfo { // the default template has been automatically generated by the tool, so skip if tplInfo.Default { continue } // loop generate file if tplInfo.LoopService || tplInfo.LoopMethod { loopMethod := tplInfo.LoopMethod loopService := tplInfo.LoopService if loopService && !loopMethod { // only loop service for _, service := range idlPackageRenderInfo.ServiceInfos.Services { filePathRenderInfo.ServiceName = service.Name err := pkgGen.genLoopService(tplInfo, filePathRenderInfo, service, &idlPackageRenderInfo) if err != nil { return err } } } else { // loop service & method, because if loop method, the service must be looped for _, service := range idlPackageRenderInfo.ServiceInfos.Services { for _, method := range service.Methods { filePathRenderInfo.ServiceName = service.Name filePathRenderInfo.MethodName = method.Name filePathRenderInfo.HandlerGenPath = method.OutputDir err := pkgGen.genLoopMethod(tplInfo, filePathRenderInfo, method, service, &idlPackageRenderInfo) if err != nil { return err } } } } } else { // generate customized file single err := pkgGen.genSingleCustomizedFile(tplInfo, filePathRenderInfo, idlPackageRenderInfo) if err != nil { return err } } } return nil } // renderFilePath used to render file path template to get real file path func renderFilePath(tplInfo *Template, filePathRenderInfo FilePathRenderInfo) (string, error) { 
tpl, err := template.New(tplInfo.Path).Funcs(funcMap).Parse(tplInfo.Path) if err != nil { return "", fmt.Errorf("parse file path template(%s) failed, err: %v", tplInfo.Path, err) } filePath := bytes.NewBuffer(nil) err = tpl.Execute(filePath, filePathRenderInfo) if err != nil { return "", fmt.Errorf("render file path template(%s) failed, err: %v", tplInfo.Path, err) } return filePath.String(), nil } func renderInsertKey(tplInfo *Template, data interface{}) (string, error) { tpl, err := template.New(tplInfo.UpdateBehavior.InsertKey).Funcs(funcMap).Parse(tplInfo.UpdateBehavior.InsertKey) if err != nil { return "", fmt.Errorf("parse insert key template(%s) failed, err: %v", tplInfo.UpdateBehavior.InsertKey, err) } insertKey := bytes.NewBuffer(nil) err = tpl.Execute(insertKey, data) if err != nil { return "", fmt.Errorf("render insert key template(%s) failed, err: %v", tplInfo.UpdateBehavior.InsertKey, err) } return insertKey.String(), nil } // renderImportTpl will render import template // it will return the []string, like blow: // ["import", alias "import", import] // other format will be error func renderImportTpl(tplInfo *Template, data interface{}) ([]string, error) { var importList []string for _, impt := range tplInfo.UpdateBehavior.ImportTpl { tpl, err := template.New(impt).Funcs(funcMap).Parse(impt) if err != nil { return nil, fmt.Errorf("parse import template(%s) failed, err: %v", impt, err) } imptContent := bytes.NewBuffer(nil) err = tpl.Execute(imptContent, data) if err != nil { return nil, fmt.Errorf("render import template(%s) failed, err: %v", impt, err) } importList = append(importList, imptContent.String()) } var ret []string for _, impts := range importList { // 'import render result' may have multiple imports if strings.Contains(impts, "\n") { for _, impt := range strings.Split(impts, "\n") { ret = append(ret, impt) } } else { ret = append(ret, impts) } } return ret, nil } // renderAppendContent used to render append content for 'update' command func 
renderAppendContent(tplInfo *Template, renderInfo interface{}) (string, error) { tpl, err := template.New(tplInfo.Path).Funcs(funcMap).Parse(tplInfo.UpdateBehavior.AppendTpl) if err != nil { return "", fmt.Errorf("parse append content template(%s) failed, err: %v", tplInfo.Path, err) } appendContent := bytes.NewBuffer(nil) err = tpl.Execute(appendContent, renderInfo) if err != nil { return "", fmt.Errorf("render append content template(%s) failed, err: %v", tplInfo.Path, err) } return appendContent.String(), nil } // appendUpdateFile used to append content to file for 'update' command func appendUpdateFile(tplInfo *Template, renderInfo interface{}, fileContent []byte) ([]byte, error) { // render insert content appendContent, err := renderAppendContent(tplInfo, renderInfo) if err != nil { return []byte(""), err } buf := bytes.NewBuffer(nil) _, err = buf.Write(fileContent) if err != nil { return []byte(""), fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) } // "\r\n" && "\n" has the same suffix if !bytes.HasSuffix(buf.Bytes(), []byte("\n")) { _, err = buf.WriteString("\n") if err != nil { return []byte(""), fmt.Errorf("write file(%s) line break failed, err: %v", tplInfo.Path, err) } } _, err = buf.WriteString(appendContent) if err != nil { return []byte(""), fmt.Errorf("append file(%s) failed, err: %v", tplInfo.Path, err) } return buf.Bytes(), nil } func getInsertImportContent(tplInfo *Template, renderInfo interface{}, fileContent []byte) ([][2]string, error) { importContent, err := renderImportTpl(tplInfo, renderInfo) if err != nil { return nil, err } var imptSlice [][2]string for _, impt := range importContent { // import has to format // 1. alias "import" // 2. "import" // 3. import (can not contain '"') impt = strings.TrimSpace(impt) if !strings.Contains(impt, "\"") { // 3. 
import (can not contain '"') if bytes.Contains(fileContent, []byte(impt)) { continue } imptSlice = append(imptSlice, [2]string{"", impt}) } else { if !strings.HasSuffix(impt, "\"") { return nil, fmt.Errorf("import can not has suffix \"\"\", for file: %s", tplInfo.Path) } if strings.HasPrefix(impt, "\"") { // 2. "import" if bytes.Contains(fileContent, []byte(impt[1:len(impt)-1])) { continue } imptSlice = append(imptSlice, [2]string{"", impt[1 : len(impt)-1]}) } else { // 3. alias "import" idx := strings.Index(impt, "\"") if idx == -1 { return nil, fmt.Errorf("error import format for file: %s", tplInfo.Path) } if bytes.Contains(fileContent, []byte(impt[idx+1:len(impt)-1])) { continue } imptSlice = append(imptSlice, [2]string{impt[:idx], impt[idx+1 : len(impt)-1]}) } } } return imptSlice, nil } // genLoopService used to generate files by 'service' func (pkgGen *HttpPackageGenerator) genLoopService(tplInfo *Template, filePathRenderInfo FilePathRenderInfo, service *Service, idlPackageRenderInfo *IDLPackageRenderInfo) error { filePath, err := renderFilePath(tplInfo, filePathRenderInfo) if err != nil { return err } // determine if a custom file exists exist, err := util.PathExist(filePath) if err != nil { return fmt.Errorf("judge file(%s) exists failed, err: %v", filePath, err) } if !exist { // create file data := CustomizedFileForService{ Service: service, FilePath: filePath, FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), IDLPackageInfo: idlPackageRenderInfo, } err = pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) if err != nil { return err } } else { switch tplInfo.UpdateBehavior.Type { case Skip: // do nothing logs.Infof("do not update file '%s', because the update behavior is 'Unchanged'", filePath) case Cover: // re-generate logs.Infof("re-generate file '%s', because the update behavior is 'Regenerate'", filePath) data := CustomizedFileForService{ Service: service, FilePath: filePath, FilePackage: 
util.SplitPackage(filepath.Dir(filePath), ""), IDLPackageInfo: idlPackageRenderInfo, } err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) if err != nil { return err } case Append: // todo: append logic need to be optimized for method fileContent, err := ioutil.ReadFile(filePath) if err != nil { return err } var appendContent []byte // get insert content if tplInfo.UpdateBehavior.AppendKey == "method" { for _, method := range service.Methods { data := CustomizedFileForMethod{ HttpMethod: method, FilePath: filePath, FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), ServiceInfo: service, IDLPackageInfo: idlPackageRenderInfo, } insertKey, err := renderInsertKey(tplInfo, data) if err != nil { return err } if bytes.Contains(fileContent, []byte(insertKey)) { continue } imptSlice, err := getInsertImportContent(tplInfo, data, fileContent) if err != nil { return err } // insert new import to the fileContent for _, impt := range imptSlice { if bytes.Contains(fileContent, []byte(impt[1])) { continue } fileContent, err = util.AddImportForContent(fileContent, impt[0], impt[1]) // insert import error do not influence the generated file if err != nil { logs.Warnf("can not add import(%s) for file(%s), err: %v\n", impt[1], filePath, err) } } appendContent, err = appendUpdateFile(tplInfo, data, appendContent) if err != nil { return err } } if len(tplInfo.UpdateBehavior.AppendLocation) == 0 { // default, append to end of file buf := bytes.NewBuffer(nil) _, err = buf.Write(fileContent) if err != nil { return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) } _, err = buf.Write(appendContent) if err != nil { return fmt.Errorf("append file(%s) failed, err: %v", tplInfo.Path, err) } logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'method'", filePath) pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) } else { // 'append location', append new content after 
'append location' part := bytes.Split(fileContent, []byte(tplInfo.UpdateBehavior.AppendLocation)) if len(part) == 0 { return fmt.Errorf("can not find append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) } if len(part) != 2 { return fmt.Errorf("do not support multiple append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) } buf := bytes.NewBuffer(nil) err = writeBytes(buf, part[0], []byte(tplInfo.UpdateBehavior.AppendLocation), appendContent, part[1]) if err != nil { return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) } logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'method'", filePath) pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) } } else { logs.Warnf("Loop 'service' field for '%s' only append content by appendKey for 'method', so cannot append content", filePath) } default: // do nothing logs.Warnf("unknown update behavior, do nothing") } } return nil } // genLoopMethod used to generate files by 'method' func (pkgGen *HttpPackageGenerator) genLoopMethod(tplInfo *Template, filePathRenderInfo FilePathRenderInfo, method *HttpMethod, service *Service, idlPackageRenderInfo *IDLPackageRenderInfo) error { filePath, err := renderFilePath(tplInfo, filePathRenderInfo) if err != nil { return err } // determine if a custom file exists exist, err := util.PathExist(filePath) if err != nil { return fmt.Errorf("judge file(%s) exists failed, err: %v", filePath, err) } if !exist { // create file data := CustomizedFileForMethod{ HttpMethod: method, FilePath: filePath, FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), ServiceInfo: service, IDLPackageInfo: idlPackageRenderInfo, } err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) if err != nil { return err } } else { switch tplInfo.UpdateBehavior.Type { case Skip: // do nothing logs.Infof("do not update file '%s', because the 
update behavior is 'Unchanged'", filePath) case Cover: // re-generate logs.Infof("re-generate file '%s', because the update behavior is 'Regenerate'", filePath) data := CustomizedFileForMethod{ HttpMethod: method, FilePath: filePath, FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), ServiceInfo: service, IDLPackageInfo: idlPackageRenderInfo, } err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) if err != nil { return err } case Append: // for loop method, no need to append something; so do nothing logs.Warnf("do not append content for file '%s', because the update behavior is 'Append' and loop 'method' have no need to append content", filePath) default: // do nothing logs.Warnf("unknown update behavior, do nothing") } } return nil } // genSingleCustomizedFile used to generate file by 'master IDL' func (pkgGen *HttpPackageGenerator) genSingleCustomizedFile(tplInfo *Template, filePathRenderInfo FilePathRenderInfo, idlPackageRenderInfo IDLPackageRenderInfo) error { // generate file single filePath, err := renderFilePath(tplInfo, filePathRenderInfo) if err != nil { return err } // determine if a custom file exists exist, err := util.PathExist(filePath) if err != nil { return fmt.Errorf("judge file(%s) exists failed, err: %v", filePath, err) } if !exist { // create file data := CustomizedFileForIDL{ IDLPackageRenderInfo: &idlPackageRenderInfo, FilePath: filePath, FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), } err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) if err != nil { return err } } else { switch tplInfo.UpdateBehavior.Type { case Skip: // do nothing logs.Infof("do not update file '%s', because the update behavior is 'Unchanged'", filePath) case Cover: // re-generate logs.Infof("re-generate file '%s', because the update behavior is 'Regenerate'", filePath) data := CustomizedFileForIDL{ IDLPackageRenderInfo: &idlPackageRenderInfo, FilePath: filePath, FilePackage: 
util.SplitPackage(filepath.Dir(filePath), ""), } err := pkgGen.TemplateGenerator.Generate(data, tplInfo.Path, filePath, false) if err != nil { return err } case Append: // todo: append logic need to be optimized for single file fileContent, err := ioutil.ReadFile(filePath) if err != nil { return err } if tplInfo.UpdateBehavior.AppendKey == "method" { var appendContent []byte for _, service := range idlPackageRenderInfo.ServiceInfos.Services { for _, method := range service.Methods { data := CustomizedFileForMethod{ HttpMethod: method, FilePath: filePath, FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), ServiceInfo: service, IDLPackageInfo: &idlPackageRenderInfo, } insertKey, err := renderInsertKey(tplInfo, data) if err != nil { return err } if bytes.Contains(fileContent, []byte(insertKey)) { continue } imptSlice, err := getInsertImportContent(tplInfo, data, fileContent) if err != nil { return err } for _, impt := range imptSlice { if bytes.Contains(fileContent, []byte(impt[1])) { continue } fileContent, err = util.AddImportForContent(fileContent, impt[0], impt[1]) if err != nil { logs.Warnf("can not add import(%s) for file(%s)\n", impt[1], filePath) } } appendContent, err = appendUpdateFile(tplInfo, data, appendContent) if err != nil { return err } } } if len(tplInfo.UpdateBehavior.AppendLocation) == 0 { // default, append to end of file buf := bytes.NewBuffer(nil) _, err = buf.Write(fileContent) if err != nil { return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) } _, err = buf.Write(appendContent) if err != nil { return fmt.Errorf("append file(%s) failed, err: %v", tplInfo.Path, err) } logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'method'", filePath) pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) } else { // 'append location', append new content after 'append location' part := bytes.Split(fileContent, []byte(tplInfo.UpdateBehavior.AppendLocation)) if 
len(part) == 0 { return fmt.Errorf("can not find append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) } if len(part) != 2 { return fmt.Errorf("do not support multiple append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) } buf := bytes.NewBuffer(nil) err = writeBytes(buf, part[0], []byte(tplInfo.UpdateBehavior.AppendLocation), appendContent, part[1]) if err != nil { return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) } logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'method'", filePath) pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) } } else if tplInfo.UpdateBehavior.AppendKey == "service" { var appendContent []byte for _, service := range idlPackageRenderInfo.ServiceInfos.Services { data := CustomizedFileForService{ Service: service, FilePath: filePath, FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), IDLPackageInfo: &idlPackageRenderInfo, } insertKey, err := renderInsertKey(tplInfo, data) if err != nil { return err } if bytes.Contains(fileContent, []byte(insertKey)) { continue } imptSlice, err := getInsertImportContent(tplInfo, data, fileContent) if err != nil { return err } for _, impt := range imptSlice { if bytes.Contains(fileContent, []byte(impt[1])) { continue } fileContent, err = util.AddImportForContent(fileContent, impt[0], impt[1]) if err != nil { logs.Warnf("can not add import(%s) for file(%s)\n", impt[1], filePath) } } appendContent, err = appendUpdateFile(tplInfo, data, appendContent) if err != nil { return err } } if len(tplInfo.UpdateBehavior.AppendLocation) == 0 { // default, append to end of file buf := bytes.NewBuffer(nil) _, err = buf.Write(fileContent) if err != nil { return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) } _, err = buf.Write(appendContent) if err != nil { return fmt.Errorf("append file(%s) failed, err: %v", tplInfo.Path, err) } 
logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'service'", filePath) pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) } else { // 'append location', append new content after 'append location' part := bytes.Split(fileContent, []byte(tplInfo.UpdateBehavior.AppendLocation)) if len(part) == 0 { return fmt.Errorf("can not find append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) } if len(part) != 2 { return fmt.Errorf("do not support multiple append location '%s' for file '%s'\n", tplInfo.UpdateBehavior.AppendLocation, filePath) } buf := bytes.NewBuffer(nil) err = writeBytes(buf, part[0], []byte(tplInfo.UpdateBehavior.AppendLocation), appendContent, part[1]) if err != nil { return fmt.Errorf("write file(%s) failed, err: %v", tplInfo.Path, err) } logs.Infof("append content for file '%s', because the update behavior is 'Append' and appendKey is 'service'", filePath) pkgGen.files = append(pkgGen.files, File{filePath, buf.String(), false, ""}) } } else { // add append content to the file directly data := CustomizedFileForIDL{ IDLPackageRenderInfo: &idlPackageRenderInfo, FilePath: filePath, FilePackage: util.SplitPackage(filepath.Dir(filePath), ""), } file, err := appendUpdateFile(tplInfo, data, fileContent) if err != nil { return err } pkgGen.files = append(pkgGen.files, File{filePath, string(file), false, ""}) } default: // do nothing logs.Warnf("unknown update behavior, do nothing") } } return nil } func writeBytes(buf *bytes.Buffer, bytes ...[]byte) error { for _, b := range bytes { _, err := buf.Write(b) if err != nil { return err } } return nil } ================================================ FILE: cmd/hz/generator/file.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package generator import ( "fmt" "go/format" "path/filepath" "strings" "github.com/cloudwego/hertz/cmd/hz/util" ) type File struct { Path string Content string NoRepeat bool FileTplName string } // Lint is used to statically analyze and format go code func (file *File) Lint() error { name := filepath.Base(file.Path) if strings.HasSuffix(name, ".go") { out, err := format.Source(util.Str2Bytes(file.Content)) if err != nil { return fmt.Errorf("lint file '%s' failed, err: %v", name, err.Error()) } file.Content = util.Bytes2Str(out) } return nil } ================================================ FILE: cmd/hz/generator/handler.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package generator import ( "bytes" "fmt" "io/ioutil" "path" "path/filepath" "strings" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/hertz/cmd/hz/util/logs" ) type HttpMethod struct { Name string HTTPMethod string Comment string RequestTypeName string RequestTypePackage string RequestTypeRawName string ReturnTypeName string ReturnTypePackage string ReturnTypeRawName string Path string Serializer string OutputDir string RefPackage string // handler import dir RefPackageAlias string // handler import alias ModelPackage map[string]string GenHandler bool // Whether to generate one handler, when an idl interface corresponds to multiple http method // Annotations map[string]string Models map[string]*model.Model } type Handler struct { FilePath string PackageName string ProjPackage string Imports map[string]*model.Model Methods []*HttpMethod } type SingleHandler struct { *HttpMethod FilePath string PackageName string ProjPackage string } type Client struct { Handler ServiceName string } func (pkgGen *HttpPackageGenerator) genHandler(pkg *HttpPackage, handlerDir, handlerPackage string, root *RouterNode) error { for _, s := range pkg.Services { var handler Handler if pkgGen.HandlerByMethod { // generate handler by method for _, m := range s.Methods { filePath := filepath.Join(handlerDir, m.OutputDir, util.ToSnakeCase(m.Name)+".go") handler = Handler{ FilePath: filePath, PackageName: util.SplitPackage(filepath.Dir(filePath), ""), Methods: []*HttpMethod{m}, ProjPackage: pkgGen.ProjPackage, } if err := pkgGen.processHandler(&handler, root, handlerDir, m.OutputDir, true); err != nil { return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) } if m.GenHandler { if err := pkgGen.updateHandler(handler, handlerTplName, handler.FilePath, false); err != nil { return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) } } } } else { // generate handler 
service tmpHandlerDir := handlerDir tmpHandlerPackage := handlerPackage if len(s.ServiceGenDir) != 0 { tmpHandlerDir = s.ServiceGenDir tmpHandlerPackage = util.SubPackage(pkgGen.ProjPackage, strings.TrimPrefix(tmpHandlerDir, "/")) } handler = Handler{ FilePath: filepath.Join(tmpHandlerDir, util.ToSnakeCase(s.Name)+".go"), PackageName: util.SplitPackage(tmpHandlerPackage, ""), Methods: s.Methods, ProjPackage: pkgGen.ProjPackage, } for _, m := range s.Methods { m.RefPackage = tmpHandlerPackage m.RefPackageAlias = util.BaseName(tmpHandlerPackage, "") } if err := pkgGen.processHandler(&handler, root, "", "", false); err != nil { return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) } // Avoid generating duplicate handlers when IDL interface corresponds to multiple http methods methods := handler.Methods handler.Methods = []*HttpMethod{} for _, m := range methods { if m.GenHandler { handler.Methods = append(handler.Methods, m) } } if err := pkgGen.updateHandler(handler, handlerTplName, handler.FilePath, false); err != nil { return fmt.Errorf("generate handler %s failed, err: %v", handler.FilePath, err.Error()) } } if len(pkgGen.ClientDir) != 0 { clientDir := util.SubDir(pkgGen.ClientDir, pkg.Package) clientPackage := util.SubPackage(pkgGen.ProjPackage, clientDir) client := Client{} client.Handler = handler client.ServiceName = s.Name client.PackageName = util.SplitPackage(clientPackage, "") client.FilePath = filepath.Join(clientDir, util.ToSnakeCase(s.Name)+".go") if err := pkgGen.updateClient(client, clientTplName, client.FilePath, false); err != nil { return fmt.Errorf("generate client %s failed, err: %v", client.FilePath, err.Error()) } } } return nil } func (pkgGen *HttpPackageGenerator) processHandler(handler *Handler, root *RouterNode, handlerDir, projectOutDir string, handlerByMethod bool) error { singleHandlerPackage := "" if handlerByMethod { singleHandlerPackage = util.SubPackage(pkgGen.ProjPackage, filepath.Join(handlerDir, 
projectOutDir)) } handler.Imports = make(map[string]*model.Model, len(handler.Methods)) for _, m := range handler.Methods { // Iterate over the request and return parameters of the method to get import path. for key, mm := range m.Models { if v, ok := handler.Imports[mm.PackageName]; ok && v.Package != mm.Package { handler.Imports[key] = mm continue } handler.Imports[mm.PackageName] = mm } err := root.Update(m, handler.PackageName, singleHandlerPackage, pkgGen.SortRouter) if err != nil { return err } } if len(pkgGen.UseDir) != 0 { oldModelPkg := util.SubPackage(pkgGen.ProjPackage, filepath.Clean(pkgGen.ModelDir)) newModelPkg := path.Clean(pkgGen.UseDir) for _, m := range handler.Methods { for _, mm := range m.Models { mm.Package = strings.Replace(mm.Package, oldModelPkg, newModelPkg, 1) } } } handler.Format() return nil } func (pkgGen *HttpPackageGenerator) updateHandler(handler interface{}, handlerTpl, filePath string, noRepeat bool) error { if pkgGen.tplsInfo[handlerTpl].Disable { return nil } isExist, err := util.PathExist(filePath) if err != nil { return err } if !isExist { return pkgGen.TemplateGenerator.Generate(handler, handlerTpl, filePath, noRepeat) } if pkgGen.HandlerByMethod { // method by handler, do not need to insert new content return nil } file, err := ioutil.ReadFile(filePath) if err != nil { return err } // insert new model imports for alias, model := range handler.(Handler).Imports { if bytes.Contains(file, []byte(model.Package)) { continue } file, err = util.AddImportForContent(file, alias, model.Package) if err != nil { return err } } // insert customized imports if tplInfo, exist := pkgGen.TemplateGenerator.tplsInfo[handlerTpl]; exist { if len(tplInfo.UpdateBehavior.ImportTpl) != 0 { imptSlice, err := getInsertImportContent(tplInfo, handler, file) if err != nil { return err } for _, impt := range imptSlice { if bytes.Contains(file, []byte(impt[1])) { continue } file, err = util.AddImportForContent(file, impt[0], impt[1]) if err != nil { 
logs.Warnf("can not add import(%s) for file(%s), err: %v\n", impt[1], filePath, err) } } } } // insert new handler for _, method := range handler.(Handler).Methods { if bytes.Contains(file, []byte(fmt.Sprintf("func %s(", method.Name))) { continue } // Generate additional handlers using templates handlerSingleTpl := pkgGen.tpls[handlerSingleTplName] if handlerSingleTpl == nil { return fmt.Errorf("tpl %s not found", handlerSingleTplName) } data := SingleHandler{ HttpMethod: method, FilePath: handler.(Handler).FilePath, PackageName: handler.(Handler).PackageName, ProjPackage: handler.(Handler).ProjPackage, } handlerFunc := bytes.NewBuffer(nil) err = handlerSingleTpl.Execute(handlerFunc, data) if err != nil { return fmt.Errorf("execute template \"%s\" failed, %v", handlerSingleTplName, err) } buf := bytes.NewBuffer(nil) _, err = buf.Write(file) if err != nil { return fmt.Errorf("write handler \"%s\" failed, %v", method.Name, err) } _, err = buf.Write(handlerFunc.Bytes()) if err != nil { return fmt.Errorf("write handler \"%s\" failed, %v", method.Name, err) } file = buf.Bytes() } pkgGen.files = append(pkgGen.files, File{filePath, string(file), false, ""}) return nil } func (pkgGen *HttpPackageGenerator) updateClient(client interface{}, clientTpl, filePath string, noRepeat bool) error { isExist, err := util.PathExist(filePath) if err != nil { return err } if !isExist { return pkgGen.TemplateGenerator.Generate(client, clientTpl, filePath, noRepeat) } logs.Infof("Client file:%s has been generated, so don't update it", filePath) return nil } func (m *HttpMethod) InitComment() { text := strings.TrimLeft(strings.TrimSpace(m.Comment), "/") if text == "" { text = "// " + m.Name + " ." 
} else if strings.HasPrefix(text, m.Name) { text = "// " + text } else { text = "// " + m.Name + " " + text } text = strings.Replace(text, "\n", "\n// ", -1) if !strings.Contains(text, "@router ") { text += "\n// @router " + m.Path } m.Comment = text + " [" + m.HTTPMethod + "]" } func MapSerializer(serializer string) string { switch serializer { case "json": return "JSON" case "thrift": return "Thrift" case "pb": return "ProtoBuf" default: return "JSON" } } func (h *Handler) Format() { for _, m := range h.Methods { m.Serializer = MapSerializer(m.Serializer) m.InitComment() } } ================================================ FILE: cmd/hz/generator/layout.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package generator import ( "encoding/json" "errors" "fmt" "io/ioutil" "path/filepath" "reflect" "strings" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/util" "gopkg.in/yaml.v2" ) // Layout contains the basic information of idl type Layout struct { OutDir string GoModule string ServiceName string UseApacheThrift bool HasIdl bool NeedGoMod bool ModelDir string HandlerDir string RouterDir string } // LayoutGenerator contains the information generated by generating the layout template type LayoutGenerator struct { ConfigPath string TemplateGenerator } var ( layoutConfig = defaultLayoutConfig packageConfig = defaultPkgConfig ) func SetDefaultTemplateConfig() { layoutConfig = defaultLayoutConfig packageConfig = defaultPkgConfig } func (lg *LayoutGenerator) Init() error { config := layoutConfig // unmarshal from user-defined config file if it exists if lg.ConfigPath != "" { cdata, err := ioutil.ReadFile(lg.ConfigPath) if err != nil { return fmt.Errorf("read layout config from %s failed, err: %v", lg.ConfigPath, err.Error()) } config = TemplateConfig{} if err = yaml.Unmarshal(cdata, &config); err != nil { return fmt.Errorf("unmarshal layout config failed, err: %v", err.Error()) } } if reflect.DeepEqual(config, TemplateConfig{}) { return errors.New("empty config") } lg.Config = &config return lg.TemplateGenerator.Init() } // checkInited initialize template definition func (lg *LayoutGenerator) checkInited() error { if lg.tpls == nil || lg.dirs == nil { if err := lg.Init(); err != nil { return fmt.Errorf("init layout config failed, err: %v", err.Error()) } } return nil } func (lg *LayoutGenerator) Generate(data map[string]interface{}) error { if err := lg.checkInited(); err != nil { return err } return lg.TemplateGenerator.Generate(data, "", "", false) } func (lg *LayoutGenerator) GenerateByService(service Layout) error { if err := lg.checkInited(); err != nil { return err } if len(service.HandlerDir) != 0 { // override the default 
"biz/handler/ping.go" to "HANDLER_DIR/ping.go" defaultPingDir := defaultHandlerDir + sp + "ping.go" if tpl, exist := lg.tpls[defaultPingDir]; exist { delete(lg.tpls, defaultPingDir) newPingDir := filepath.Clean(service.HandlerDir + sp + "ping.go") lg.tpls[newPingDir] = tpl } } if len(service.RouterDir) != 0 { defaultRegisterDir := defaultRouterDir + sp + registerTplName if tpl, exist := lg.tpls[defaultRegisterDir]; exist { delete(lg.tpls, defaultRegisterDir) newRegisterDir := filepath.Clean(service.RouterDir + sp + registerTplName) lg.tpls[newRegisterDir] = tpl } } if !service.NeedGoMod { gomodFile := "go.mod" if _, exist := lg.tpls[gomodFile]; exist { delete(lg.tpls, gomodFile) } } if util.IsWindows() { buildSh := "build.sh" bootstrapSh := defaultScriptDir + sp + "bootstrap.sh" if _, exist := lg.tpls[buildSh]; exist { delete(lg.tpls, buildSh) } if _, exist := lg.tpls[bootstrapSh]; exist { delete(lg.tpls, bootstrapSh) } } sd, err := serviceToLayoutData(service) if err != nil { return err } rd, err := serviceToRouterData(service) if err != nil { return err } if service.HasIdl { for k := range lg.tpls { if strings.Contains(k, registerTplName) { delete(lg.tpls, k) break } } } data := map[string]interface{}{ "*": sd, layoutConfig.Layouts[routerIndex].Path: rd, // router.go layoutConfig.Layouts[routerGenIndex].Path: rd, // router_gen.go } return lg.Generate(data) } // serviceToLayoutData stores go mod, serviceName, UseApacheThrift mapping func serviceToLayoutData(service Layout) (map[string]interface{}, error) { goMod := service.GoModule if goMod == "" { return nil, errors.New("please specify go-module") } handlerPkg := filepath.Base(defaultHandlerDir) if len(service.HandlerDir) != 0 { handlerPkg = filepath.Base(service.HandlerDir) } routerPkg := filepath.Base(defaultRouterDir) if len(service.RouterDir) != 0 { routerPkg = filepath.Base(service.RouterDir) } serviceName := service.ServiceName if len(serviceName) == 0 { serviceName = meta.DefaultServiceName } return 
map[string]interface{}{ "GoModule": goMod, "ServiceName": serviceName, "UseApacheThrift": service.UseApacheThrift, "HandlerPkg": handlerPkg, "RouterPkg": routerPkg, }, nil } // serviceToRouterData stores the registers function, router import path, handler import path func serviceToRouterData(service Layout) (map[string]interface{}, error) { routerDir := sp + defaultRouterDir handlerDir := sp + defaultHandlerDir if len(service.RouterDir) != 0 { routerDir = sp + service.RouterDir } if len(service.HandlerDir) != 0 { handlerDir = sp + service.HandlerDir } return map[string]interface{}{ "Registers": []string{}, "RouterPkgPath": service.GoModule + filepath.ToSlash(routerDir), "HandlerPkgPath": service.GoModule + filepath.ToSlash(handlerDir), }, nil } func (lg *LayoutGenerator) GenerateByConfig(configPath string) error { if err := lg.checkInited(); err != nil { return err } buf, err := ioutil.ReadFile(configPath) if err != nil { return fmt.Errorf("read data file '%s' failed, err: %v", configPath, err.Error()) } var data map[string]interface{} if err := json.Unmarshal(buf, &data); err != nil { return fmt.Errorf("unmarshal json data failed, err: %v", err.Error()) } return lg.Generate(data) } func (lg *LayoutGenerator) Degenerate() error { return lg.TemplateGenerator.Degenerate() } ================================================ FILE: cmd/hz/generator/layout_tpl.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and
 * limitations under the License.
 */

package generator

import "path/filepath"

//-----------------------------------Default Layout-----------------------------------------

// Default project-relative directories for the scaffold. sp is the OS path
// separator, so these paths are OS-native.
const (
	sp                = string(filepath.Separator)
	defaultBizDir     = "biz"
	defaultModelDir   = "biz" + sp + "model"
	defaultHandlerDir = "biz" + sp + "handler"
	defaultServiceDir = "biz" + sp + "service"
	defaultDalDir     = "biz" + sp + "dal"
	defaultScriptDir  = "script"
	defaultConfDir    = "conf"
	defaultRouterDir  = "biz" + sp + "router"
	defaultClientDir  = "biz" + sp + "client"
)

// routerGenIndex and routerIndex are the positions of the RegisterFile
// ("router_gen.go") and "router.go" entries in defaultLayoutConfig.Layouts.
// They are used to index layoutConfig.Layouts, so they must stay in sync with
// the slice order below.
const (
	routerGenIndex = 8
	routerIndex    = 9
	RegisterFile   = "router_gen.go"
)

// defaultLayoutConfig is the built-in scaffold. The first four entries carry
// only a Path ending in sp and no Body. NOTE(review): presumably these only
// create the directory - confirm in TemplateGenerator. The remaining entries
// are file templates.
var defaultLayoutConfig = TemplateConfig{
	Layouts: []Template{
		{
			Path: defaultDalDir + sp,
		},
		{
			Path: defaultHandlerDir + sp,
		},
		{
			Path: defaultModelDir + sp,
		},
		{
			Path: defaultServiceDir + sp,
		},
		{
			Path: "main.go",
			Body: `// Code generated by hertz generator. package main import ( "github.com/cloudwego/hertz/pkg/app/server" ) func main() { h := server.Default() register(h) h.Spin() } `,
		},
		{
			// Removed by GenerateByService when Layout.NeedGoMod is false.
			Path:   "go.mod",
			Delims: [2]string{"{{", "}}"},
			Body:   `module {{.GoModule}} {{- if .UseApacheThrift}} replace github.com/apache/thrift => github.com/apache/thrift v0.13.0 {{- end}} `,
		},
		{
			Path: ".gitignore",
			Body: `*.o *.a *.so _obj _test *.[568vq] [568vq].out *.cgo1.go *.cgo2.c _cgo_defun.c _cgo_gotypes.go _cgo_export.* _testmain.go *.exe *.exe~ *.test *.prof *.rar *.zip *.gz *.psd *.bmd *.cfg *.pptx *.log *nohup.out *settings.pyc *.sublime-project *.sublime-workspace !.gitkeep .DS_Store /.idea /.vscode /output *.local.yml dumped_hertz_remote_config.json `,
		},
		{
			// Relocated into Layout.HandlerDir by GenerateByService when set.
			Path: defaultHandlerDir + sp + "ping.go",
			Body: `// Code generated by hertz generator. package {{.HandlerPkg}} import ( "context" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol/consts" ) // Ping .
func Ping(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, utils.H{ "message": "pong", }) } `,
		},
		{
			// index routerGenIndex: register() wires GeneratedRegister and customizedRegister.
			Path: RegisterFile,
			Body: `// Code generated by hertz generator. DO NOT EDIT. package main import ( "github.com/cloudwego/hertz/pkg/app/server" router "{{.RouterPkgPath}}" ) // register registers all routers. func register(r *server.Hertz) { router.GeneratedRegister(r) customizedRegister(r) } `,
		},
		{
			// index routerIndex.
			// NOTE(review): the generated doc comment names "customizeRegister"
			// but the generated function is "customizedRegister".
			Path: "router.go",
			Body: `// Code generated by hertz generator. package main import ( "github.com/cloudwego/hertz/pkg/app/server" handler "{{.HandlerPkgPath}}" ) // customizeRegister registers customize routers. func customizedRegister(r *server.Hertz){ r.GET("/ping", handler.Ping) // your code ... } `,
		},
		{
			// GeneratedRegister body is spliced around insertPointNew so later
			// runs can insert routes there. Relocated into Layout.RouterDir when set.
			Path: defaultRouterDir + sp + registerTplName,
			Body: `// Code generated by hertz generator. DO NOT EDIT. package {{.RouterPkg}} import ( "github.com/cloudwego/hertz/pkg/app/server" ) // GeneratedRegister registers routers generated by IDL. func GeneratedRegister(r *server.Hertz){ ` + insertPointNew + ` } `,
		},
		{
			// Skipped on Windows by GenerateByService.
			Path: "build.sh",
			Body: `#!/bin/bash RUN_NAME={{.ServiceName}} mkdir -p output/bin cp script/* output 2>/dev/null chmod +x output/bootstrap.sh go build -o output/bin/${RUN_NAME}`,
		},
		{
			// Skipped on Windows by GenerateByService.
			Path: defaultScriptDir + sp + "bootstrap.sh",
			Body: `#!/bin/bash CURDIR=$(cd $(dirname $0); pwd) BinaryName={{.ServiceName}} echo "$CURDIR/bin/${BinaryName}" exec $CURDIR/bin/${BinaryName}`,
		},
	},
}

================================================
FILE: cmd/hz/generator/model/define.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package model var ( BaseTypes = []*Type{TypeBool, TypeByte, TypeInt8, TypeInt16, TypeInt32, TypeInt64, TypeUint8, TypeUint16, TypeUint32, TypeUint64, TypeFloat64, TypeString, TypeBinary} ContainerTypes = []*Type{TypeBaseList, TypeBaseMap, TypeBaseSet} BaseModel = Model{} ) var ( TypeBool = &Type{ Name: "bool", Scope: &BaseModel, Kind: KindBool, } TypeByte = &Type{ Name: "int8", Scope: &BaseModel, Kind: KindInt8, } TypePbByte = &Type{ Name: "byte", Scope: &BaseModel, Kind: KindInt8, } TypeUint8 = &Type{ Name: "uint8", Scope: &BaseModel, Kind: KindInt8, } TypeUint16 = &Type{ Name: "uint16", Scope: &BaseModel, Kind: KindInt16, } TypeUint32 = &Type{ Name: "uint32", Scope: &BaseModel, Kind: KindInt32, } TypeUint64 = &Type{ Name: "uint64", Scope: &BaseModel, Kind: KindInt64, } TypeUint = &Type{ Name: "uint", Scope: &BaseModel, Kind: KindInt, } TypeInt8 = &Type{ Name: "int8", Scope: &BaseModel, Kind: KindInt8, } TypeInt16 = &Type{ Name: "int16", Scope: &BaseModel, Kind: KindInt16, } TypeInt32 = &Type{ Name: "int32", Scope: &BaseModel, Kind: KindInt32, } TypeInt64 = &Type{ Name: "int64", Scope: &BaseModel, Kind: KindInt64, } TypeInt = &Type{ Name: "int", Scope: &BaseModel, Kind: KindInt, } TypeFloat32 = &Type{ Name: "float32", Scope: &BaseModel, Kind: KindFloat64, } TypeFloat64 = &Type{ Name: "float64", Scope: &BaseModel, Kind: KindFloat64, } TypeString = &Type{ Name: "string", Scope: &BaseModel, Kind: KindString, } TypeBinary = &Type{ Name: "binary", Scope: &BaseModel, Kind: KindSlice, Category: CategoryBinary, Extra: []*Type{TypePbByte}, } TypeBaseMap = &Type{ Name: 
"map", Scope: &BaseModel, Kind: KindMap, Category: CategoryMap, } TypeBaseSet = &Type{ Name: "set", Scope: &BaseModel, Kind: KindSlice, Category: CategorySet, } TypeBaseList = &Type{ Name: "list", Scope: &BaseModel, Kind: KindSlice, Category: CategoryList, } ) func NewCategoryType(typ *Type, cg Category) *Type { cyp := *typ cyp.Category = cg return &cyp } func NewStructType(name string, cg Category) *Type { return &Type{ Name: name, Scope: nil, Kind: KindStruct, Category: cg, Indirect: false, Extra: nil, HasNew: true, } } func NewFuncType(name string, cg Category) *Type { return &Type{ Name: name, Scope: nil, Kind: KindFunc, Category: cg, Indirect: false, Extra: nil, HasNew: false, } } func IsBaseType(typ *Type) bool { for _, t := range BaseTypes { if typ == t { return true } } return false } func NewEnumType(name string, cg Category) *Type { return &Type{ Name: name, Scope: &BaseModel, Kind: KindInt, Category: cg, Indirect: false, Extra: nil, HasNew: true, } } func NewOneofType(name string) *Type { return &Type{ Name: name, Scope: &BaseModel, Kind: KindInterface, Indirect: false, Extra: nil, HasNew: true, } } ================================================ FILE: cmd/hz/generator/model/expr.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package model import ( "fmt" "strconv" ) type BoolExpression struct { Src bool } func (boolExpr BoolExpression) Expression() string { if boolExpr.Src { return "true" } else { return "false" } } type StringExpression struct { Src string } func (stringExpr StringExpression) Expression() string { return fmt.Sprintf("%q", stringExpr.Src) } type NumberExpression struct { Src string } func (numExpr NumberExpression) Expression() string { return numExpr.Src } type ListExpression struct { ElementType *Type Elements []Literal } type IntExpression struct { Src int } func (intExpr IntExpression) Expression() string { return strconv.Itoa(intExpr.Src) } type DoubleExpression struct { Src float64 } func (doubleExpr DoubleExpression) Expression() string { return strconv.FormatFloat(doubleExpr.Src, 'f', -1, 64) } func (listExpr ListExpression) Expression() string { ret := "[]" + listExpr.ElementType.Name + "{\n" for _, e := range listExpr.Elements { ret += e.Expression() + ",\n" } ret += "\n}" return ret } type MapExpression struct { KeyType *Type ValueType *Type Elements map[string]Literal } func (mapExpr MapExpression) Expression() string { ret := "map[" + mapExpr.KeyType.Name + "]" + mapExpr.ValueType.Name + "{\n" for k, e := range mapExpr.Elements { ret += fmt.Sprintf("%q: %s,\n", k, e.Expression()) } ret += "\n}" return ret } ================================================ FILE: cmd/hz/generator/model/golang/constant.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and
 * limitations under the License.
 */

package golang

// constants is the text/template source of the "Constants" template. It emits
// one typed const declaration: `const <Name> <resolved Type> = <Value.Expression>`.
var constants = ` {{define "Constants"}} const {{.Name}} {{.Type.ResolveName ROOT}} = {{.Value.Expression}} {{end}} `

================================================
FILE: cmd/hz/generator/model/golang/enum.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package golang

// enum is the text/template source of the "Enum" template. It declares the
// enum type with underlying type .GoType and one const per value, named
// <EnumType>_<Name>. It also emits a String method and a
// <EnumType>FromString parser. Both use SnakeCase of EnumType+Name as the text
// form; FromString returns an error built with fmt.Errorf for unknown text.
// When Features.MarshalEnumToText is enabled, MarshalText/UnmarshalText
// wrappers are added as well.
var enum = ` {{define "Enum"}} {{- $EnumType := (Identify .Name)}} type {{$EnumType}} {{.GoType}} const ( {{- range $i, $e := .Values}} {{$EnumType}}_{{$e.Name}} {{$EnumType}} = {{$e.Value.Expression}} {{- end}} ) func (p {{$EnumType}}) String() string { switch p { {{- range $i, $e := .Values}} case {{$EnumType}}_{{$e.Name}}: return "{{printf "%s%s" $EnumType $e.Name | SnakeCase}}" {{- end}} } return "" } func {{$EnumType}}FromString(s string) ({{$EnumType}}, error) { switch s { {{- range $i, $e := .Values}} case "{{printf "%s%s" $EnumType $e.Name | SnakeCase}}": return {{$EnumType}}_{{$e.Name}}, nil {{- end}} } return {{$EnumType}}(0), fmt.Errorf("not a valid {{$EnumType}} string") } {{- if Features.MarshalEnumToText}} func (p {{$EnumType}}) MarshalText() ([]byte, error) { return []byte(p.String()), nil } func (p *{{$EnumType}}) UnmarshalText(text []byte) error { q, err := {{$EnumType}}FromString(string(text)) if err != nil { return err } *p = q return nil } {{- end}} {{end}} `
================================================
FILE: cmd/hz/generator/model/golang/file.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package golang

// file is the root text/template. It binds $ROOT and emits the
// "Code generated by hz." header, the package clause, and an import block:
// "fmt" plus one entry per .Imports value, aliased by that model's
// PackageName. It then expands Typedef, Constants, Variables, Function, Enum,
// Oneof, Struct and Method for each element of the matching collection.
//
// NOTE(review): "fmt" is imported unconditionally, but only the Enum and
// Struct templates visibly use it - a file with neither may get an unused
// import unless a formatter cleans it up; confirm.
// NOTE(review): the map key $alias is ignored in favor of $model.PackageName,
// so two imports sharing a PackageName would collide - confirm this cannot
// happen.
var file = `{{$ROOT := . -}} // Code generated by hz. package {{.PackageName}} import ( "fmt" {{- range $alias, $model := .Imports}} {{$model.PackageName}} "{{$model.Package}}" {{- end}} ) {{- range .Typedefs}} {{template "Typedef" .}} {{- end}} {{- range .Constants}} {{template "Constants" .}} {{- end}} {{- range .Variables}} {{template "Variables" .}} {{- end}} {{- range .Functions}} {{template "Function" .}} {{- end}} {{- range .Enums}} {{template "Enum" .}} {{- end}} {{- range .Oneofs}} {{template "Oneof" .}} {{- end}} {{- range .Structs}} {{template "Struct" .}} {{- end}} {{- range .Methods}} {{template "Method" .}} {{- end}} `

================================================
FILE: cmd/hz/generator/model/golang/function.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package golang

// function is the text/template source of the "Function" and "FuncBody"
// templates. FuncBody renders `Name(arg T, ...) (R1, ...) { Code }`: arguments
// are comma-separated, and the result list is parenthesized only when .Rets is
// non-empty. Function prefixes FuncBody with "func ".
var function = ` {{define "Function"}} func {{template "FuncBody" . -}} {{end}}{{/* define "Function" */}} {{define "FuncBody"}} {{- .Name -}}( {{- range $i, $arg := .Args -}} {{- if gt $i 0}}, {{end -}} {{$arg.Name}} {{$arg.Type.ResolveName ROOT}} {{- end -}}{{/* range */}}) {{- if gt (len .Rets) 0}} ({{end -}} {{- range $i, $ret := .Rets -}} {{- if gt $i 0}}, {{end -}} {{$ret.Type.ResolveName ROOT}} {{- end -}}{{/* range */}} {{- if gt (len .Rets) 0}}) {{end -}}{ {{.Code}} } {{end}}{{/* define "FuncBody" */}} `

// method is the text/template source of the "Method" template. It emits the
// `func (<ReceiverName> <ReceiverType>)` prefix, then reuses FuncBody for .Function.
var method = ` {{define "Method"}} func ({{.ReceiverName}} {{.ReceiverType.ResolveName ROOT}}) {{- template "FuncBody" .Function -}} {{end}} `

================================================
FILE: cmd/hz/generator/model/golang/init.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package golang import ( "fmt" "strings" "text/template" ) var tpls *template.Template var list = map[string]string{ "file": file, "typedef": typedef, "constants": constants, "variables": variables, "function": function, "enum": enum, "struct": structLike, "method": method, "oneof": oneof, } /***********************Export API*******************************/ func Template() (*template.Template, error) { if tpls != nil { return tpls, nil } tpls = new(template.Template) tpls = tpls.Funcs(funcMap) var err error for k, li := range list { tpls, err = tpls.Parse(li) if err != nil { return nil, fmt.Errorf("parse template '%s' failed, err: %v", k, err.Error()) } } return tpls, nil } func List() map[string]string { return list } /***********************Template Funcs**************************/ var funcMap = template.FuncMap{ "Features": getFeatures, "Identify": identify, "CamelCase": camelCase, "SnakeCase": snakeCase, "GetTypedefReturnStr": getTypedefReturnStr, } func Funcs(name string, fn interface{}) error { if _, ok := funcMap[name]; ok { return fmt.Errorf("duplicate function: %s has been registered", name) } funcMap[name] = fn return nil } func identify(name string) string { return name } func camelCase(name string) string { return name } func snakeCase(name string) string { return name } func getTypedefReturnStr(name string) string { if strings.Contains(name, ".") { idx := strings.LastIndex(name, ".") return name[:idx] + "." 
+ "New" + name[idx+1:] + "()" } return "New" + name + "()" } /***********************Template Options**************************/ type feature struct { MarshalEnumToText bool TypedefAsTypeAlias bool } var features = feature{} func getFeatures() feature { return features } func SetOption(opt string) error { switch opt { case "MarshalEnumToText": features.MarshalEnumToText = true case "TypedefAsTypeAlias": features.TypedefAsTypeAlias = true } return nil } var Options = []string{ "MarshalEnumToText", "TypedefAsTypeAlias", } func GetOptions() []string { return Options } ================================================ FILE: cmd/hz/generator/model/golang/oneof.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

package golang

// oneof is the text/template source of the "Oneof" template. It emits:
//   - an interface named .InterfaceName with a single marker method;
//   - for each choice, a wrapper struct <MessageName>_<ChoiceName> holding the
//     value, which implements that interface;
//   - for each choice, a Get<ChoiceName> accessor on the message. The accessor
//     returns the wrapped value when Get<OneofName>() holds that wrapper, and
//     the type's ResolveDefaultValue otherwise.
var oneof = ` {{define "Oneof"}} type {{$.InterfaceName}} interface { {{$.InterfaceName}}() } {{range $i, $f := .Choices}} type {{$f.MessageName}}_{{$f.ChoiceName}} struct { {{$f.ChoiceName}} {{$f.Type.ResolveName ROOT}} } {{end}} {{range $i, $f := .Choices}} func (*{{$f.MessageName}}_{{$f.ChoiceName}}) {{$.InterfaceName}}() {} {{end}} {{range $i, $f := .Choices}} func (p *{{$f.MessageName}}) Get{{$f.ChoiceName}}() {{$f.Type.ResolveName ROOT}} { if p, ok := p.Get{{$.OneofName}}().(*{{$f.MessageName}}_{{$f.ChoiceName}}); ok { return p.{{$f.ChoiceName}} } return {{$f.Type.ResolveDefaultValue}} } {{end}} {{end}} `

================================================
FILE: cmd/hz/generator/model/golang/struct.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package golang

// StructLike is the code template for struct, union, and exception.
var structLike = ` {{define "Struct"}} {{- $TypeName := (Identify .Name) -}} {{$MessageLeadingComments := .LeadingComments}} {{if ne (len $MessageLeadingComments) 0}} //{{$MessageLeadingComments}} {{end -}} type {{$TypeName}} struct { {{- range $i, $f := .Fields}} {{- $FieldLeadingComments := $f.LeadingComments}} {{$FieldTrailingComments := $f.TrailingComments -}} {{- if ne (len $FieldLeadingComments) 0 -}} //{{$FieldLeadingComments}} {{end -}} {{- if $f.IsPointer -}} {{$f.Name}} *{{$f.Type.ResolveName ROOT}} {{$f.GenGoTags}}{{if ne (len $FieldTrailingComments) 0}} //{{$FieldTrailingComments}}{{end -}} {{- else -}} {{$f.Name}} {{$f.Type.ResolveName ROOT}} {{$f.GenGoTags}}{{if ne (len $FieldTrailingComments) 0}} //{{$FieldTrailingComments}}{{end -}} {{- end -}} {{- end}} } func New{{$TypeName}}() *{{$TypeName}} { return &{{$TypeName}}{ {{template "StructLikeDefault" .}} } } {{template "FieldGetOrSet" .}} {{if eq .Category 14}} func (p *{{$TypeName}}) CountSetFields{{$TypeName}}() int { count := 0 {{- range $i, $f := .Fields}} {{- if $f.Type.IsSettable}} if p.IsSet{{$f.Name}}() { count++ } {{- end}} {{- end}} return count } {{- end}} func (p *{{$TypeName}}) String() string { if p == nil { return "" } return fmt.Sprintf("{{$TypeName}}(%+v)", *p) } {{- if eq .Category 15}} func (p *{{$TypeName}}) Error() string { return p.String() } {{- end}} {{- end}}{{/* define "StructLike" */}} {{- define "StructLikeDefault"}} {{- range $i, $f := .Fields}} {{- if $f.IsSetDefault}} {{$f.Name}}: {{$f.DefaultValue.Expression}}, {{- end}} {{- end}} {{- end -}}{{/* define "StructLikeDefault" */}} {{- define "FieldGetOrSet"}} {{- $TypeName := (Identify .Name)}} {{- range $i, $f := .Fields}} {{$FieldName := $f.Name}} {{$FieldTypeName := $f.Type.ResolveName ROOT}} {{- if $f.Type.IsSettable}} func (p *{{$TypeName}}) IsSet{{$FieldName}}() bool { return p.{{$FieldName}} != nil } {{- end}}{{/* IsSettable . 
*/}} func (p *{{$TypeName}}) Get{{$FieldName}}() {{$FieldTypeName}} { {{- if $f.Type.IsSettable}} if !p.IsSet{{$FieldName}}() { return {{with $f.DefaultValue}}{{$f.DefaultValue.Expression}}{{else}}nil{{end}} } {{- end}} {{- if $f.IsPointer}} return *p.{{$FieldName}} {{else}} return p.{{$FieldName}} {{- end -}} } func (p *{{$TypeName}}) Set{{$FieldName}}(val {{$FieldTypeName}}) { {{- if $f.IsPointer}} *p.{{$FieldName}} = val {{else}} p.{{$FieldName}} = val {{- end -}} } {{- end}}{{/* range .Fields */}} {{- end}}{{/* define "FieldGetOrSet" */}} ` ================================================ FILE: cmd/hz/generator/model/golang/typedef.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package golang // Typedef . var typedef = ` {{define "Typedef"}} {{- $NewTypeName := (Identify .Alias)}} {{- $OldTypeName := .Type.ResolveNameForTypedef ROOT}} type {{$NewTypeName}} = {{$OldTypeName}} {{if eq .Type.Kind 25}}{{if .Type.HasNew}} func New{{$NewTypeName}}() *{{$NewTypeName}} { return {{(GetTypedefReturnStr $OldTypeName)}} } {{- end}}{{- end}} {{- end}} ` ================================================ FILE: cmd/hz/generator/model/golang/variable.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package golang var variables = ` {{- define "Variables"}} var {{.Name}} {{.Type.ResolveName ROOT}} = {{.Value.Expression}} {{end}} ` ================================================ FILE: cmd/hz/generator/model/model.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package model import ( "errors" "fmt" "strings" ) type Kind uint const ( KindInvalid Kind = iota KindBool KindInt KindInt8 KindInt16 KindInt32 KindInt64 KindUint KindUint8 KindUint16 KindUint32 KindUint64 KindUintptr KindFloat32 KindFloat64 KindComplex64 KindComplex128 KindArray KindChan KindFunc KindInterface KindMap KindPtr KindSlice KindString KindStruct KindUnsafePointer ) type Category int64 const ( CategoryConstant Category = 1 CategoryBinary Category = 8 CategoryMap Category = 9 CategoryList Category = 10 CategorySet Category = 11 CategoryEnum Category = 12 CategoryStruct Category = 13 CategoryUnion Category = 14 CategoryException Category = 15 CategoryTypedef Category = 16 CategoryService Category = 17 ) type Model struct { FilePath string Package string Imports map[string]*Model //{{import}}:Model // rendering data PackageName string // Imports map[string]string //{{alias}}:{{import}} Typedefs []TypeDef Constants []Constant Variables []Variable Functions []Function Enums []Enum Structs []Struct Methods []Method Oneofs []Oneof } func (m Model) IsEmpty() bool { return len(m.Typedefs) == 0 && len(m.Constants) == 0 && len(m.Variables) == 0 && len(m.Functions) == 0 && len(m.Enums) == 0 && len(m.Structs) == 0 && len(m.Methods) == 0 } type Models []*Model func (a *Models) MergeMap(b map[string]*Model) { for _, v := range b { insert := true for _, p := range *a { if p == v { insert = false } } if insert { *a = append(*a, v) } } return } func (a *Models) MergeArray(b []*Model) { for _, v := range b { insert := true for _, p := range *a { if p == v { insert = false } } if insert { *a = append(*a, v) } } return } type RequiredNess int const ( RequiredNess_Default RequiredNess = 0 RequiredNess_Required RequiredNess = 1 RequiredNess_Optional RequiredNess = 2 ) type Type struct { Name string Scope *Model Kind Kind Indirect bool Category Category Extra []*Type // [{key_type},{value_type}] for map, [{element_type}] for list or set HasNew bool } func (rt *Type) 
ResolveDefaultValue() string {
	// ResolveDefaultValue returns the Go source literal for the zero value of
	// rt. The oneof getter template emits it as the fallback return value, so
	// it must be a valid literal for every kind: all numeric kinds (including
	// uint8 and uintptr, which were previously missing and fell through to the
	// invalid "nil") map to "0"; nillable kinds map to "nil".
	if rt == nil {
		return ""
	}
	switch rt.Kind {
	case KindInt, KindInt8, KindInt16, KindInt32, KindInt64,
		KindUint, KindUint8, KindUint16, KindUint32, KindUint64, KindUintptr,
		KindFloat32, KindFloat64, KindComplex64, KindComplex128:
		return "0"
	case KindBool:
		return "false"
	case KindString:
		return "\"\""
	default:
		return "nil"
	}
}

// ResolveNameForTypedef returns the Go spelling of rt for the right-hand side
// of a typedef, qualified with rt's package name when it lives in a package
// other than scope's. Element types of slices, maps and channels are resolved
// with ResolveName. Unlike ResolveName, a struct kind is NOT prefixed with '*'.
// It errors when rt is nil or a composite kind has the wrong number of Extra
// element types.
func (rt *Type) ResolveNameForTypedef(scope *Model) (string, error) {
	if rt == nil {
		return "", errors.New("type is nil")
	}
	name := rt.Name
	if rt.Scope == nil {
		return rt.Name, nil
	}

	switch rt.Kind {
	case KindArray, KindSlice:
		if len(rt.Extra) != 1 {
			return "", fmt.Errorf("the type: %s should have 1 extra type, but has %d", rt.Name, len(rt.Extra))
		}
		resolveName, err := rt.Extra[0].ResolveName(scope)
		if err != nil {
			return "", err
		}
		name = fmt.Sprintf("[]%s", resolveName)
	case KindMap:
		if len(rt.Extra) != 2 {
			return "", fmt.Errorf("the type: %s should have 2 extra types, but has %d", rt.Name, len(rt.Extra))
		}
		resolveKey, err := rt.Extra[0].ResolveName(scope)
		if err != nil {
			return "", err
		}
		resolveValue, err := rt.Extra[1].ResolveName(scope)
		if err != nil {
			return "", err
		}
		name = fmt.Sprintf("map[%s]%s", resolveKey, resolveValue)
	case KindChan:
		if len(rt.Extra) != 1 {
			return "", fmt.Errorf("the type: %s should have 1 extra type, but has %d", rt.Name, len(rt.Extra))
		}
		resolveName, err := rt.Extra[0].ResolveName(scope)
		if err != nil {
			return "", err
		}
		name = fmt.Sprintf("chan %s", resolveName)
	}

	// Types from BaseModel are never package-qualified.
	if scope != nil && rt.Scope != &BaseModel && rt.Scope.Package != scope.Package {
		name = rt.Scope.PackageName + "." + name
	}
	return name, nil
}

// ResolveName returns the Go spelling of rt as seen from scope: qualified with
// rt's package name when it lives in another package, and prefixed with '*'
// for struct kinds. Composite kinds are rebuilt from their resolved Extra
// element types.
func (rt *Type) ResolveName(scope *Model) (string, error) {
	if rt == nil {
		return "", fmt.Errorf("type is nil")
	}
	name := rt.Name
	if rt.Scope == nil {
		if rt.Kind == KindStruct {
			return "*" + rt.Name, nil
		}
		return rt.Name, nil
	}

	// A typedef is referred to by its own name; its underlying type is not
	// expanded.
	if rt.Category == CategoryTypedef {
		if scope != nil && rt.Scope != &BaseModel && rt.Scope.Package != scope.Package {
			name = rt.Scope.PackageName + "."
+ name
		}
		if rt.Kind == KindStruct {
			name = "*" + name
		}
		return name, nil
	}

	switch rt.Kind {
	case KindArray, KindSlice:
		if len(rt.Extra) != 1 {
			return "", fmt.Errorf("the type: %s should have 1 extra type, but has %d", rt.Name, len(rt.Extra))
		}
		resolveName, err := rt.Extra[0].ResolveName(scope)
		if err != nil {
			return "", err
		}
		name = fmt.Sprintf("[]%s", resolveName)
	case KindMap:
		if len(rt.Extra) != 2 {
			return "", fmt.Errorf("the type: %s should have 2 extra types, but has %d", rt.Name, len(rt.Extra))
		}
		resolveKey, err := rt.Extra[0].ResolveName(scope)
		if err != nil {
			return "", err
		}
		resolveValue, err := rt.Extra[1].ResolveName(scope)
		if err != nil {
			return "", err
		}
		name = fmt.Sprintf("map[%s]%s", resolveKey, resolveValue)
	case KindChan:
		if len(rt.Extra) != 1 {
			return "", fmt.Errorf("the type: %s should have 1 extra type, but has %d", rt.Name, len(rt.Extra))
		}
		resolveName, err := rt.Extra[0].ResolveName(scope)
		if err != nil {
			return "", err
		}
		name = fmt.Sprintf("chan %s", resolveName)
	}

	// Types from BaseModel are never package-qualified.
	if scope != nil && rt.Scope != &BaseModel && rt.Scope.Package != scope.Package {
		name = rt.Scope.PackageName + "." + name
	}
	if rt.Kind == KindStruct {
		name = "*" + name
	}
	return name, nil
}

// IsBinary reports whether rt is an IDL binary field represented as a byte
// slice or array.
func (rt *Type) IsBinary() bool {
	return rt.Category == CategoryBinary && (rt.Kind == KindSlice || rt.Kind == KindArray)
}

// IsBaseType reports whether rt's kind orders before KindComplex64: bool,
// the integer kinds, uintptr and the float kinds. Note that this also
// includes KindInvalid and excludes KindString.
func (rt *Type) IsBaseType() bool {
	return rt.Kind < KindComplex64
}

// IsSettable reports whether rt's Go representation can be nil. The struct
// template emits IsSet<Field> (`p.Field != nil`) only for such fields.
func (rt *Type) IsSettable() bool {
	switch rt.Kind {
	case KindArray, KindChan, KindFunc, KindInterface, KindMap, KindPtr, KindSlice, KindUnsafePointer:
		return true
	}
	return false
}

type TypeDef struct {
	Scope *Model
	Alias string
	Type  *Type
}

type Constant struct {
	Scope *Model
	Name  string
	Type  *Type
	Value Literal
}

// Literal is a value that can render itself as a Go expression.
type Literal interface {
	Expression() string
}

type Variable struct {
	Scope *Model
	Name  string
	Type  *Type
	Value Literal
}

type Function struct {
	Scope *Model
	Name  string
	Args  []Variable
	Rets  []Variable
	Code  string
}

type Method struct {
	Scope        *Model
	ReceiverName string
	ReceiverType *Type
	ByPtr        bool
	Function
}

type Enum struct {
	Scope  *Model
	Name   string
	GoType string
	Values []Constant
}

type Struct struct {
	Scope           *Model
	Name            string
	Fields          []Field
	Category        Category
	LeadingComments string
}

type Field struct {
	Scope            *Struct
	Name             string
	Type             *Type
	IsSetDefault     bool
	DefaultValue     Literal
	Required         RequiredNess
	Tags             Tags
	LeadingComments  string
	TrailingComments string
	IsPointer        bool
}

type Oneof struct {
	MessageName   string
	OneofName     string
	InterfaceName string
	Choices       []Choice
}

type Choice struct {
	MessageName string
	ChoiceName  string
	Type        *Type
}

// Tags is an ordered list of struct tags; it implements sort.Interface by Key.
type Tags []Tag

type Tag struct {
	Key       string
	Value     string
	IsDefault bool // default tag
}

// String renders the tags as space-separated key:"value" pairs, each value
// quoted with %q.
func (ts Tags) String() string {
	ret := make([]string, 0, len(ts))
	for _, t := range ts {
		ret = append(ret, fmt.Sprintf("%v:%q", t.Key, t.Value))
	}
	return strings.Join(ret, " ")
}

// Remove drops every tag whose Key equals name. It builds a fresh slice, so
// the previous backing array is left untouched.
func (ts *Tags) Remove(name string) {
	ret := make([]Tag, 0, len(*ts))
	for _, t := range *ts {
		if t.Key != name {
			ret = append(ret, t)
		}
	}
	*ts = ret
}

func (ts Tags) Len() int {
	return len(ts)
}

func (ts Tags) Less(i, j int) bool {
	return ts[i].Key < ts[j].Key
}

func (ts Tags) Swap(i, j int) {
ts[i], ts[j] = ts[j], ts[i] } func (f Field) GenGoTags() string { if len(f.Tags) == 0 { return "" } return fmt.Sprintf("`%s`", f.Tags.String()) } ================================================ FILE: cmd/hz/generator/model.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package generator import ( "fmt" "path/filepath" "strings" "text/template" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/generator/model/golang" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/util" ) //---------------------------------Backend---------------------------------- type Option string const ( OptionMarshalEnumToText Option = "MarshalEnumToText" OptionTypedefAsTypeAlias Option = "TypedefAsTypeAlias" ) type Backend interface { Template() (*template.Template, error) List() map[string]string SetOption(opts string) error GetOptions() []string Funcs(name string, fn interface{}) error } type GolangBackend struct{} func (gb *GolangBackend) Template() (*template.Template, error) { return golang.Template() } func (gb *GolangBackend) List() map[string]string { return golang.List() } func (gb *GolangBackend) SetOption(opts string) error { return golang.SetOption(opts) } func (gb *GolangBackend) GetOptions() []string { return golang.GetOptions() } func (gb *GolangBackend) Funcs(name string, fn interface{}) error { return golang.Funcs(name, fn) } func 
switchBackend(backend meta.Backend) Backend { switch backend { case meta.BackendGolang: return &GolangBackend{} } return loadThirdPartyBackend(string(backend)) } func loadThirdPartyBackend(plugin string) Backend { panic("no implement yet!") } /**********************Generating*************************/ func (pkgGen *HttpPackageGenerator) LoadBackend(backend meta.Backend) error { bd := switchBackend(backend) if bd == nil { return fmt.Errorf("no found backend '%s'", backend) } for _, opt := range pkgGen.Options { if err := bd.SetOption(string(opt)); err != nil { return fmt.Errorf("set option %s error, err: %v", opt, err.Error()) } } err := bd.Funcs("ROOT", func() *model.Model { return pkgGen.curModel }) if err != nil { return fmt.Errorf("register global function in model template failed, err: %v", err.Error()) } tpl, err := bd.Template() if err != nil { return fmt.Errorf("load backend %s failed, err: %v", backend, err.Error()) } if pkgGen.tpls == nil { pkgGen.tpls = map[string]*template.Template{} } pkgGen.tpls[modelTplName] = tpl pkgGen.loadedBackend = bd return nil } func (pkgGen *HttpPackageGenerator) GenModel(data *model.Model, gen bool) error { if pkgGen.processedModels == nil { pkgGen.processedModels = map[*model.Model]bool{} } if _, ok := pkgGen.processedModels[data]; !ok { var path string var updatePackage bool if strings.HasPrefix(data.Package, pkgGen.ProjPackage) && data.PackageName != pkgGen.ProjPackage { path = data.Package[len(pkgGen.ProjPackage):] } else { path = data.Package updatePackage = true } modelDir := util.SubDir(pkgGen.ModelDir, path) if updatePackage { data.Package = util.SubPackage(pkgGen.ProjPackage, modelDir) } data.FilePath = filepath.Join(modelDir, util.BaseNameAndTrim(data.FilePath)+".go") pkgGen.processedModels[data] = true } for _, dep := range data.Imports { if err := pkgGen.GenModel(dep, false); err != nil { return fmt.Errorf("generate model %s failed, err: %v", dep.FilePath, err.Error()) } } if gen && !data.IsEmpty() { 
pkgGen.curModel = data removeDuplicateImport(data) err := pkgGen.TemplateGenerator.Generate(data, modelTplName, data.FilePath, false) pkgGen.curModel = nil return err } return nil } // Idls with the same Package do not need to refer to each other func removeDuplicateImport(data *model.Model) { for k, v := range data.Imports { if data.Package == v.Package { delete(data.Imports, k) } } } ================================================ FILE: cmd/hz/generator/model_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package generator import ( "testing" "text/template" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/meta" ) type StringValue struct { src string } func (sv *StringValue) Expression() string { return sv.src } func TestIdlGenerator_GenModel(t *testing.T) { typeModel := &model.Type{ Name: "Model", Kind: model.KindStruct, Indirect: true, } typeErr := &model.Type{ Name: "error", Kind: model.KindInterface, Indirect: false, } type fields struct { ConfigPath string OutputDir string Backend meta.Backend handlerDir string routerDir string modelDir string ProjPackage string Config *TemplateConfig tpls map[string]*template.Template } type args struct { data *model.Model } tests := []struct { name string fields fields args args wantErr bool }{ { name: "", fields: fields{ OutputDir: "./testdata", Backend: meta.BackendGolang, }, args: args{ data: &model.Model{ FilePath: "idl/main.thrift", Package: "model/psm", PackageName: "psm", Imports: map[string]*model.Model{ "base": { Package: "model/base", PackageName: "base", }, }, Typedefs: []model.TypeDef{ { Alias: "HerztModel", Type: typeModel, }, }, Constants: []model.Constant{ { Name: "OBJ", Type: typeErr, Value: &StringValue{"fmt.Errorf(\"EOF\")"}, }, }, Variables: []model.Variable{ { Name: "Object", Type: typeModel, Value: &StringValue{"&Model{}"}, }, }, Functions: []model.Function{ { Name: "Init", Args: nil, Rets: []model.Variable{ { Name: "err", Type: typeErr, }, }, Code: "return nil", }, }, Enums: []model.Enum{ { Name: "Sex", Values: []model.Constant{ { Name: "Male", Type: &model.Type{ Name: "int", Kind: model.KindInt, Indirect: false, Category: 1, }, Value: &StringValue{"1"}, }, { Name: "Femal", Type: &model.Type{ Name: "int", Kind: model.KindInt, Indirect: false, Category: 1, }, Value: &StringValue{"2"}, }, }, }, }, Structs: []model.Struct{ { Name: "Model", Fields: []model.Field{ { Name: "A", Type: &model.Type{ Name: "[]byte", Kind: model.KindSlice, Indirect: false, Category: 
model.CategoryBinary, }, IsSetDefault: true, DefaultValue: &StringValue{"[]byte(\"\")"}, }, { Name: "B", Type: &model.Type{ Name: "Base", Kind: model.KindStruct, Indirect: false, }, }, }, Category: model.CategoryUnion, }, }, Methods: []model.Method{ { ReceiverName: "self", ReceiverType: typeModel, ByPtr: true, Function: model.Function{ Name: "Bind", Args: []model.Variable{ { Name: "c", Type: &model.Type{ Name: "RequestContext", Scope: &model.Model{ PackageName: "hertz", }, Kind: model.KindStruct, Indirect: true, }, }, }, Rets: []model.Variable{ { Name: "error", Type: typeErr, }, }, Code: "return nil", }, }, }, }, }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { self := &HttpPackageGenerator{ ConfigPath: tt.fields.ConfigPath, Backend: tt.fields.Backend, HandlerDir: tt.fields.handlerDir, RouterDir: tt.fields.routerDir, ModelDir: tt.fields.modelDir, ProjPackage: tt.fields.ProjPackage, TemplateGenerator: TemplateGenerator{ OutputDir: tt.fields.OutputDir, Config: tt.fields.Config, tpls: tt.fields.tpls, }, Options: []Option{ OptionTypedefAsTypeAlias, OptionMarshalEnumToText, }, } err := self.LoadBackend(meta.BackendGolang) if err != nil { t.Fatal(err) } if err := self.GenModel(tt.args.data, true); (err != nil) != tt.wantErr { t.Errorf("IdlGenerator.GenModel() error = %v, wantErr %v", err, tt.wantErr) } if err := self.Persist(); err != nil { t.Fatal(err) } }) } } ================================================ FILE: cmd/hz/generator/package.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package generator import ( "errors" "fmt" "io/ioutil" "path/filepath" "reflect" "text/template" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/util" "gopkg.in/yaml.v2" ) type HttpPackage struct { IdlName string Package string Services []*Service Models []*model.Model RouterInfo *Router } type Service struct { Name string Methods []*HttpMethod ClientMethods []*ClientMethod Models []*model.Model // all dependency models BaseDomain string // base domain for client code ServiceGroup string // service level router group ServiceGenDir string // handler_dir for handler_by_service } // HttpPackageGenerator is used to record the configuration related to generating hertz http code. type HttpPackageGenerator struct { ConfigPath string // package template path Backend meta.Backend // model template Options []Option CmdType string ProjPackage string // go module for project HandlerDir string RouterDir string ModelDir string // like: biz/model or biz\model (Windows) UseDir string // XXX: should be UsePkg, not a filepath? 
ClientDir string // client dir for "new"/"update" command IdlClientDir string // client dir for "client" command ForceClientDir string // client dir without namespace for "client" command BaseDomain string // request domain for "client" command QueryEnumAsInt bool // client code use number for query parameter ServiceGenDir string NeedModel bool HandlerByMethod bool // generate handler files with method dimension SnakeStyleMiddleware bool // use snake name style for middleware SortRouter bool ForceUpdateClient bool // force update 'hertz_client.go' loadedBackend Backend curModel *model.Model processedModels map[*model.Model]bool TemplateGenerator } func (pkgGen *HttpPackageGenerator) Init() error { defaultConfig := packageConfig customConfig := TemplateConfig{} // unmarshal from user-defined config file if it exists if pkgGen.ConfigPath != "" { cdata, err := ioutil.ReadFile(pkgGen.ConfigPath) if err != nil { return fmt.Errorf("read layout config from %s failed, err: %v", pkgGen.ConfigPath, err.Error()) } if err = yaml.Unmarshal(cdata, &customConfig); err != nil { return fmt.Errorf("unmarshal layout config failed, err: %v", err.Error()) } if reflect.DeepEqual(customConfig, TemplateConfig{}) { return errors.New("empty config") } } if pkgGen.tpls == nil { pkgGen.tpls = make(map[string]*template.Template, len(defaultConfig.Layouts)) } if pkgGen.tplsInfo == nil { pkgGen.tplsInfo = make(map[string]*Template, len(defaultConfig.Layouts)) } // extract routerTplName/middlewareTplName/handlerTplName/registerTplName/modelTplName/clientTplName directories // load default template for _, layout := range defaultConfig.Layouts { // default template use "fileName" as template name path := filepath.Base(layout.Path) err := pkgGen.loadLayout(layout, path, true) if err != nil { return err } } // override the default template, other customized file template will be loaded by "TemplateGenerator.Init" for _, layout := range customConfig.Layouts { if !IsDefaultPackageTpl(layout.Path) { 
continue } err := pkgGen.loadLayout(layout, layout.Path, true) if err != nil { return err } } pkgGen.Config = &customConfig // load Model tpl if need if pkgGen.Backend != "" { if err := pkgGen.LoadBackend(pkgGen.Backend); err != nil { return fmt.Errorf("load model template failed, err: %v", err.Error()) } } pkgGen.processedModels = make(map[*model.Model]bool) pkgGen.TemplateGenerator.isPackageTpl = true return pkgGen.TemplateGenerator.Init() } func (pkgGen *HttpPackageGenerator) checkInited() (bool, error) { if pkgGen.tpls == nil { if err := pkgGen.Init(); err != nil { return false, fmt.Errorf("init layout config failed, err: %v", err.Error()) } } return pkgGen.ConfigPath == "", nil } func (pkgGen *HttpPackageGenerator) Generate(pkg *HttpPackage) error { if _, err := pkgGen.checkInited(); err != nil { return err } if len(pkg.Models) != 0 { for _, m := range pkg.Models { if err := pkgGen.GenModel(m, pkgGen.NeedModel); err != nil { return fmt.Errorf("generate model %s failed, err: %v", m.FilePath, err.Error()) } } } if pkgGen.CmdType == meta.CmdClient { // default client dir clientDir := pkgGen.IdlClientDir // user specify client dir if len(pkgGen.ClientDir) != 0 { clientDir = pkgGen.ClientDir } if err := pkgGen.genClient(pkg, clientDir); err != nil { return err } if err := pkgGen.genCustomizedFile(pkg); err != nil { return err } return nil } // this is for handler_by_service, the handler_dir is {$HANDLER_DIR}/{$PKG} handlerDir := util.SubDir(pkgGen.HandlerDir, pkg.Package) if pkgGen.HandlerByMethod { handlerDir = pkgGen.HandlerDir } handlerPackage := util.SubPackage(pkgGen.ProjPackage, handlerDir) routerDir := util.SubDir(pkgGen.RouterDir, pkg.Package) routerPackage := util.SubPackage(pkgGen.ProjPackage, routerDir) root := NewRouterTree() if err := pkgGen.genHandler(pkg, handlerDir, handlerPackage, root); err != nil { return err } if err := pkgGen.genRouter(pkg, root, handlerPackage, routerDir, routerPackage); err != nil { return err } if err := 
pkgGen.genCustomizedFile(pkg); err != nil { return err } return nil } ================================================ FILE: cmd/hz/generator/package_tpl.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package generator var ( routerTplName = "router.go" middlewareTplName = "middleware.go" middlewareSingleTplName = "middleware_single.go" handlerTplName = "handler.go" handlerSingleTplName = "handler_single.go" modelTplName = "model.go" registerTplName = "register.go" clientTplName = "client.go" // generate a default client for server hertzClientTplName = "hertz_client.go" // underlying client for client command idlClientName = "idl_client.go" // client of service for quick call insertPointNew = "//INSERT_POINT: DO NOT DELETE THIS LINE!" 
insertPointPatternNew = `//INSERT_POINT\: DO NOT DELETE THIS LINE\!` ) var templateNameSet = map[string]string{ routerTplName: routerTplName, middlewareTplName: middlewareTplName, middlewareSingleTplName: middlewareSingleTplName, handlerTplName: handlerTplName, handlerSingleTplName: handlerSingleTplName, modelTplName: modelTplName, registerTplName: registerTplName, clientTplName: clientTplName, hertzClientTplName: hertzClientTplName, idlClientName: idlClientName, } func IsDefaultPackageTpl(name string) bool { if _, exist := templateNameSet[name]; exist { return true } return false } var defaultPkgConfig = TemplateConfig{ Layouts: []Template{ { Path: defaultHandlerDir + sp + handlerTplName, Delims: [2]string{"{{", "}}"}, Body: `// Code generated by hertz generator. package {{.PackageName}} import ( "context" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/protocol/consts" {{- range $k, $v := .Imports}} {{$k}} "{{$v.Package}}" {{- end}} ) {{range $_, $MethodInfo := .Methods}} {{$MethodInfo.Comment}} func {{$MethodInfo.Name}}(ctx context.Context, c *app.RequestContext) { var err error {{if ne $MethodInfo.RequestTypeName "" -}} var req {{$MethodInfo.RequestTypeName}} err = c.BindAndValidate(&req) if err != nil { c.String(consts.StatusBadRequest, err.Error()) return } {{end}} resp := new({{$MethodInfo.ReturnTypeName}}) c.{{.Serializer}}(consts.StatusOK, resp) } {{end}} `, }, { Path: defaultRouterDir + sp + routerTplName, Delims: [2]string{"{{", "}}"}, Body: `// Code generated by hertz generator. DO NOT EDIT. package {{$.PackageName}} import ( "github.com/cloudwego/hertz/pkg/app/server" {{- range $k, $v := .HandlerPackages}} {{$k}} "{{$v}}" {{- end}} ) /* This file will register all the routes of the services in the master idl. And it will update automatically when you use the "update" command for the idl. So don't modify the contents of the file, or your code will be deleted when it is updated. 
*/ {{define "g"}} {{- if eq .Path "/"}}r {{- else}}{{.GroupName}}{{end}} {{- end}} {{define "G"}} {{- if ne .Handler ""}} {{- .GroupName}}.{{.HttpMethod}}("{{.Path}}", append({{.HandlerMiddleware}}Mw(), {{.Handler}})...) {{- end}} {{- if ne (len .Children) 0}} {{.MiddleWare}} := {{template "g" .}}.Group("{{.Path}}", {{.GroupMiddleware}}Mw()...) {{- end}} {{- range $_, $router := .Children}} {{- if ne .Handler ""}} {{template "G" $router}} {{- else}} { {{template "G" $router}} } {{- end}} {{- end}} {{- end}} // Register register routes based on the IDL 'api.${HTTP Method}' annotation. func Register(r *server.Hertz) { {{template "G" .Router}} } `, }, { Path: defaultRouterDir + sp + registerTplName, Body: `// Code generated by hertz generator. DO NOT EDIT. package {{.PackageName}} import ( "github.com/cloudwego/hertz/pkg/app/server" {{$.DepPkgAlias}} "{{$.DepPkg}}" ) // GeneratedRegister registers routers generated by IDL. func GeneratedRegister(r *server.Hertz){ ` + insertPointNew + ` {{$.DepPkgAlias}}.Register(r) } `, }, // Model tpl is imported by model generator. Here only decides model directory. { Path: defaultModelDir + sp + modelTplName, Body: ``, }, { Path: defaultRouterDir + sp + middlewareTplName, Delims: [2]string{"{{", "}}"}, Body: `// Code generated by hertz generator. package {{$.PackageName}} import ( "github.com/cloudwego/hertz/pkg/app" ) {{define "M"}} {{- if ne .Children.Len 0}} func {{.GroupMiddleware}}Mw() []app.HandlerFunc { // your code... return nil } {{end}} {{- if ne .Handler ""}} func {{.HandlerMiddleware}}Mw() []app.HandlerFunc { // your code... return nil } {{end}} {{range $_, $router := $.Children}}{{template "M" $router}}{{end}} {{- end}} {{template "M" .Router}} `, }, { Path: defaultClientDir + sp + clientTplName, Delims: [2]string{"{{", "}}"}, Body: `// Code generated by hertz generator. 
package {{$.PackageName}} import ( "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/common/config" ) type {{.ServiceName}}Client struct { client * client.Client } func New{{.ServiceName}}Client(opt ...config.ClientOption) (*{{.ServiceName}}Client, error) { c, err := client.NewClient(opt...) if err != nil { return nil, err } return &{{.ServiceName}}Client{ client: c, }, nil } `, }, { Path: defaultHandlerDir + sp + handlerSingleTplName, Delims: [2]string{"{{", "}}"}, Body: ` {{.Comment}} func {{.Name}}(ctx context.Context, c *app.RequestContext) { var err error {{if ne .RequestTypeName "" -}} var req {{.RequestTypeName}} err = c.BindAndValidate(&req) if err != nil { c.String(consts.StatusBadRequest, err.Error()) return } {{end}} resp := new({{.ReturnTypeName}}) c.{{.Serializer}}(consts.StatusOK, resp) } `, }, { Path: defaultRouterDir + sp + middlewareSingleTplName, Delims: [2]string{"{{", "}}"}, Body: ` func {{.MiddleWare}}Mw() []app.HandlerFunc { // your code... return nil } `, }, { Path: defaultRouterDir + sp + hertzClientTplName, Delims: [2]string{"{{", "}}"}, Body: hertzClientTpl, }, { Path: defaultRouterDir + sp + idlClientName, Delims: [2]string{"{{", "}}"}, Body: idlClientTpl, }, }, } var hertzClientTpl = `// Code generated by hz. package {{.PackageName}} import ( "context" "encoding/json" "encoding/xml" "fmt" "io" "net/http" "net/url" "reflect" "regexp" "strings" hertz_client "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/client" ) type use interface { Use(mws ...hertz_client.Middleware) } // Definition of global data and types. 
type ResponseResultDecider func(statusCode int, rawResponse *protocol.Response) (isError bool) type ( bindRequestBodyFunc func(c *cli, r *request) (contentType string, body io.Reader, err error) beforeRequestFunc func(*cli, *request) error afterResponseFunc func(*cli, *response) error ) var ( hdrContentTypeKey = http.CanonicalHeaderKey("Content-Type") hdrContentEncodingKey = http.CanonicalHeaderKey("Content-Encoding") plainTextType = "text/plain; charset=utf-8" jsonContentType = "application/json; charset=utf-8" formContentType = "multipart/form-data" jsonCheck = regexp.MustCompile(` + "`(?i:(application|text)/(json|.*\\+json|json\\-.*)(; |$))`)\n" + `xmlCheck = regexp.MustCompile(` + "`(?i:(application|text)/(xml|.*\\+xml)(; |$))`)\n" + ` ) // Configuration of client type Option struct { f func(*Options) } type Options struct { hostUrl string doer client.Doer header http.Header requestBodyBind bindRequestBodyFunc responseResultDecider ResponseResultDecider middlewares []hertz_client.Middleware clientOption []config.ClientOption } func getOptions(ops ...Option) *Options { opts := &Options{} for _, do := range ops { do.f(opts) } return opts } // WithHertzClientOption is used to pass configuration for the hertz client func WithHertzClientOption(opt ...config.ClientOption) Option { return Option{func(op *Options) { op.clientOption = append(op.clientOption, opt...) }} } // WithHertzClientMiddleware is used to register the middleware for the hertz client func WithHertzClientMiddleware(mws ...hertz_client.Middleware) Option { return Option{func(op *Options) { op.middlewares = append(op.middlewares, mws...) 
}} } // WithHertzClient is used to register a custom hertz client func WithHertzClient(client client.Doer) Option { return Option{func(op *Options) { op.doer = client }} } // WithHeader is used to add the default header, which is carried by every request func WithHeader(header http.Header) Option { return Option{func(op *Options) { op.header = header }} } // WithResponseResultDecider configure custom deserialization of http response to response struct func WithResponseResultDecider(decider ResponseResultDecider) Option { return Option{func(op *Options) { op.responseResultDecider = decider }} } func withHostUrl(HostUrl string) Option { return Option{func(op *Options) { op.hostUrl = HostUrl }} } // underlying client type cli struct { hostUrl string doer client.Doer header http.Header bindRequestBody bindRequestBodyFunc responseResultDecider ResponseResultDecider beforeRequest []beforeRequestFunc afterResponse []afterResponseFunc } func (c *cli) Use(mws ...hertz_client.Middleware) error { u, ok := c.doer.(use) if !ok { return errors.NewPublic("doer does not support middleware, choose the right doer.") } u.Use(mws...) return nil } func newClient(opts *Options) (*cli, error) { if opts.requestBodyBind == nil { opts.requestBodyBind = defaultRequestBodyBind } if opts.responseResultDecider == nil { opts.responseResultDecider = defaultResponseResultDecider } if opts.doer == nil { cli, err := hertz_client.NewClient(opts.clientOption...) 
if err != nil { return nil, err } opts.doer = cli } c := &cli{ hostUrl: opts.hostUrl, doer: opts.doer, header: opts.header, bindRequestBody: opts.requestBodyBind, responseResultDecider: opts.responseResultDecider, beforeRequest: []beforeRequestFunc{ parseRequestURL, parseRequestHeader, createHTTPRequest, }, afterResponse: []afterResponseFunc{ parseResponseBody, }, } if len(opts.middlewares) != 0 { if err := c.Use(opts.middlewares...); err != nil { return nil, err } } return c, nil } func (c *cli) execute(req *request) (*response, error) { var err error for _, f := range c.beforeRequest { if err = f(c, req); err != nil { return nil, err } } if hostHeader := req.header.Get("Host"); hostHeader != "" { req.rawRequest.Header.SetHost(hostHeader) } resp := protocol.Response{} err = c.doer.Do(req.ctx, req.rawRequest, &resp) response := &response{ request: req, rawResponse: &resp, } if err != nil { return response, err } body, err := resp.BodyE() if err != nil { return nil, err } if strings.EqualFold(resp.Header.Get(hdrContentEncodingKey), "gzip") && resp.Header.ContentLength() != 0 { body, err = resp.BodyGunzip() if err != nil { return nil, err } } response.bodyByte = body response.size = int64(len(response.bodyByte)) // Apply Response middleware for _, f := range c.afterResponse { if err = f(c, response); err != nil { break } } return response, err } // r get request func (c *cli) r() *request { return &request{ queryParam: url.Values{}, header: http.Header{}, pathParam: map[string]string{}, formParam: map[string]string{}, fileParam: map[string]string{}, client: c, queryEnumAsInt: {{.Config.QueryEnumAsInt}}, } } type response struct { request *request rawResponse *protocol.Response bodyByte []byte size int64 } // statusCode method returns the HTTP status code for the executed request. func (r *response) statusCode() int { if r.rawResponse == nil { return 0 } return r.rawResponse.StatusCode() } // body method returns HTTP response as []byte array for the executed request. 
func (r *response) body() []byte { if r.rawResponse == nil { return []byte{} } return r.bodyByte } // Header method returns the response headers func (r *response) header() http.Header { if r.rawResponse == nil { return http.Header{} } h := http.Header{} r.rawResponse.Header.VisitAll(func(key, value []byte) { h.Add(string(key), string(value)) }) return h } type request struct { client *cli url string method string queryEnumAsInt bool queryParam url.Values header http.Header pathParam map[string]string formParam map[string]string fileParam map[string]string bodyParam interface{} rawRequest *protocol.Request ctx context.Context requestOptions []config.RequestOption result interface{} Error interface{} } func (r *request) setContext(ctx context.Context) *request { r.ctx = ctx return r } func (r *request) context() context.Context { return r.ctx } func (r *request) setHeader(header, value string) *request { r.header.Set(header, value) return r } func (r *request) addHeader(header, value string) *request { r.header.Add(header, value) return r } func (r *request) addHeaders(params map[string]string) *request { for k, v := range params { r.addHeader(k, v) } return r } func (r *request) setQueryParam(param string, value interface{}) *request { if value == nil { return r } v := reflect.ValueOf(value) if v.Kind() == reflect.Pointer && v.IsNil() { return r } switch v.Kind() { case reflect.Slice, reflect.Array: for index := 0; index < v.Len(); index++ { if r.queryEnumAsInt && (v.Index(index).Kind() == reflect.Int32 || v.Index(index).Kind() == reflect.Int64) { r.queryParam.Add(param, fmt.Sprintf("%d", v.Index(index).Interface())) } else { r.queryParam.Add(param, fmt.Sprint(v.Index(index).Interface())) } } case reflect.Int32, reflect.Int64: if r.queryEnumAsInt { r.queryParam.Add(param, fmt.Sprintf("%d", v.Interface())) } else { r.queryParam.Add(param, fmt.Sprint(v)) } default: r.queryParam.Set(param, fmt.Sprint(v)) } return r } func (r *request) setResult(res interface{}) 
*request { r.result = res return r } func (r *request) setError(err interface{}) *request { r.Error = err return r } func (r *request) setHeaders(headers map[string]string) *request { for h, v := range headers { r.setHeader(h, v) } return r } func (r *request) setQueryParams(params map[string]interface{}) *request { for p, v := range params { r.setQueryParam(p, v) } return r } func (r *request) setPathParams(params map[string]string) *request { for p, v := range params { r.pathParam[p] = v } return r } func (r *request) setFormParams(params map[string]string) *request { for p, v := range params { r.formParam[p] = v } return r } func (r *request) setFormFileParams(params map[string]string) *request { for p, v := range params { r.fileParam[p] = v } return r } func (r *request) setBodyParam(body interface{}) *request { r.bodyParam = body return r } func (r *request) setRequestOption(option ...config.RequestOption) *request { r.requestOptions = append(r.requestOptions, option...) return r } func (r *request) execute(method, url string) (*response, error) { r.method = method r.url = url return r.client.execute(r) } func parseRequestURL(c *cli, r *request) error { if len(r.pathParam) > 0 { for p, v := range r.pathParam { r.url = strings.Replace(r.url, ":"+p, url.PathEscape(v), -1) } } // Parsing request URL reqURL, err := url.Parse(r.url) if err != nil { return err } // If request.URL is relative path then added c.HostURL into // the request URL otherwise request.URL will be used as-is if !reqURL.IsAbs() { r.url = reqURL.String() if len(r.url) > 0 && r.url[0] != '/' { r.url = "/" + r.url } reqURL, err = url.Parse(c.hostUrl + r.url) if err != nil { return err } } // Adding Query Param query := make(url.Values) for k, v := range r.queryParam { // remove query param from client level by key // since overrides happens for that key in the request query.Del(k) for _, iv := range v { query.Add(k, iv) } } if len(query) > 0 { if isStringEmpty(reqURL.RawQuery) { reqURL.RawQuery = 
query.Encode() } else { reqURL.RawQuery = reqURL.RawQuery + "&" + query.Encode() } } r.url = reqURL.String() return nil } func isStringEmpty(str string) bool { return len(strings.TrimSpace(str)) == 0 } func parseRequestHeader(c *cli, r *request) error { hdr := make(http.Header) if c.header != nil { for k := range c.header { hdr[k] = append(hdr[k], c.header[k]...) } } for k := range r.header { hdr.Del(k) hdr[k] = append(hdr[k], r.header[k]...) } if len(r.formParam) != 0 || len(r.fileParam) != 0 { hdr.Add(hdrContentTypeKey, formContentType) } r.header = hdr return nil } // detectContentType method is used to figure out "request.Body" content type for request header func detectContentType(body interface{}) string { contentType := plainTextType kind := reflect.Indirect(reflect.ValueOf(body)).Kind() switch kind { case reflect.Struct, reflect.Map: contentType = jsonContentType case reflect.String: contentType = plainTextType default: if b, ok := body.([]byte); ok { contentType = http.DetectContentType(b) } else if kind == reflect.Slice { contentType = jsonContentType } } return contentType } func defaultRequestBodyBind(c *cli, r *request) (contentType string, body io.Reader, err error) { if !isPayloadSupported(r.method) { return } var bodyBytes []byte contentType = r.header.Get(hdrContentTypeKey) if isStringEmpty(contentType) { contentType = detectContentType(r.bodyParam) r.header.Set(hdrContentTypeKey, contentType) } kind := reflect.Indirect(reflect.ValueOf(r.bodyParam)).Kind() if isJSONType(contentType) && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) { bodyBytes, err = json.Marshal(r.bodyParam) } else if isXMLType(contentType) && (kind == reflect.Struct) { bodyBytes, err = xml.Marshal(r.bodyParam) } if err != nil { return } return contentType, strings.NewReader(string(bodyBytes)), nil } func isPayloadSupported(m string) bool { return !(m == http.MethodHead || m == http.MethodOptions || m == http.MethodGet || m == http.MethodDelete) } func 
createHTTPRequest(c *cli, r *request) (err error) { contentType, body, err := c.bindRequestBody(c, r) if !isStringEmpty(contentType) { r.header.Set(hdrContentTypeKey, contentType) } if err == nil { r.rawRequest = protocol.NewRequest(r.method, r.url, body) if contentType == formContentType && isPayloadSupported(r.method) { if r.rawRequest.IsBodyStream() { r.rawRequest.ResetBody() } r.rawRequest.SetMultipartFormData(r.formParam) r.rawRequest.SetFiles(r.fileParam) } for key, values := range r.header { for _, val := range values { r.rawRequest.Header.Add(key, val) } } r.rawRequest.SetOptions(r.requestOptions...) } return err } func silently(_ ...interface{}) {} // defaultResponseResultDecider method returns true if HTTP status code >= 400 otherwise false. func defaultResponseResultDecider(statusCode int, rawResponse *protocol.Response) bool { return statusCode > 399 } // IsJSONType method is to check JSON content type or not func isJSONType(ct string) bool { return jsonCheck.MatchString(ct) } // IsXMLType method is to check XML content type or not func isXMLType(ct string) bool { return xmlCheck.MatchString(ct) } func parseResponseBody(c *cli, res *response) (err error) { if res.statusCode() == http.StatusNoContent { return } // Handles only JSON or XML content type ct := res.header().Get(hdrContentTypeKey) isError := c.responseResultDecider(res.statusCode(), res.rawResponse) if isError { if res.request.Error != nil { if isJSONType(ct) || isXMLType(ct) { err = unmarshalContent(ct, res.bodyByte, res.request.Error) } } else { jsonByte, jsonErr := json.Marshal(map[string]interface{}{ "status_code": res.rawResponse.StatusCode(), "body": string(res.bodyByte), }) if jsonErr != nil { return jsonErr } err = errors.NewPublic(string(jsonByte)) } } else if res.request.result != nil { if isJSONType(ct) || isXMLType(ct) { err = unmarshalContent(ct, res.bodyByte, res.request.result) return } } return } // unmarshalContent content into object from JSON or XML func unmarshalContent(ct 
string, b []byte, d interface{}) (err error) { if isJSONType(ct) { err = json.Unmarshal(b, d) } else if isXMLType(ct) { err = xml.Unmarshal(b, d) } return } ` var idlClientTpl = `// Code generated by hertz generator. package {{.PackageName}} import ( "context" "fmt" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/protocol" {{- range $k, $v := .Imports}} {{$k}} "{{$v.Package}}" {{- end}} ) // unused protection var ( _ = fmt.Formatter(nil) ) type Client interface { {{range $_, $MethodInfo := .ClientMethods}} {{$MethodInfo.Name}}(context context.Context, req *{{$MethodInfo.RequestTypeName}}, reqOpt ...config.RequestOption) (resp *{{$MethodInfo.ReturnTypeName}}, rawResponse *protocol.Response, err error) {{end}} } type {{.ServiceName}}Client struct { client *cli } func New{{.ServiceName}}Client(hostUrl string, ops ...Option) (Client, error) { opts := getOptions(append(ops, withHostUrl(hostUrl))...) cli, err := newClient(opts) if err != nil { return nil, err } return &{{.ServiceName}}Client{ client: cli, }, nil } {{range $_, $MethodInfo := .ClientMethods}} func (s *{{$.ServiceName}}Client) {{$MethodInfo.Name}}(context context.Context, req *{{$MethodInfo.RequestTypeName}}, reqOpt ...config.RequestOption) (resp *{{$MethodInfo.ReturnTypeName}}, rawResponse *protocol.Response, err error) { httpResp := &{{$MethodInfo.ReturnTypeName}}{} ret, err := s.client.r(). setContext(context). setQueryParams(map[string]interface{}{ {{$MethodInfo.QueryParamsCode}} }). setPathParams(map[string]string{ {{$MethodInfo.PathParamsCode}} }). addHeaders(map[string]string{ {{$MethodInfo.HeaderParamsCode}} }). setFormParams(map[string]string{ {{$MethodInfo.FormValueCode}} }). setFormFileParams(map[string]string{ {{$MethodInfo.FormFileCode}} }). {{$MethodInfo.BodyParamsCode}} setRequestOption(reqOpt...). setResult(httpResp). 
execute("{{if EqualFold $MethodInfo.HTTPMethod "Any"}}POST{{else}}{{ $MethodInfo.HTTPMethod }}{{end}}", "{{$MethodInfo.Path}}") if err != nil { return nil, nil, err } resp = httpResp rawResponse = ret.rawResponse return resp, rawResponse, nil } {{end}} var defaultClient, _ = New{{.ServiceName}}Client("{{.BaseDomain}}") func ConfigDefaultClient(ops ...Option) (err error) { defaultClient, err = New{{.ServiceName}}Client("{{.BaseDomain}}", ops...) return } {{range $_, $MethodInfo := .ClientMethods}} func {{$MethodInfo.Name}}(context context.Context, req *{{$MethodInfo.RequestTypeName}}, reqOpt ...config.RequestOption) (resp *{{$MethodInfo.ReturnTypeName}}, rawResponse *protocol.Response, err error) { return defaultClient.{{$MethodInfo.Name}}(context, req, reqOpt...) } {{end}} ` ================================================ FILE: cmd/hz/generator/router.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package generator import ( "bytes" "fmt" "io/ioutil" "math" "path/filepath" "regexp" "sort" "strconv" "strings" "unicode" "github.com/cloudwego/hertz/cmd/hz/util" ) type Router struct { FilePath string PackageName string HandlerPackages map[string]string // {{basename}}:{{import_path}} Router *RouterNode } type RouterNode struct { GroupName string // current group name(the parent middleware name), used to register route. 
example: {{.GroupName}}.{{HttpMethod}} MiddleWare string // current node middleware, used to be group name for children. HandlerMiddleware string GroupMiddleware string PathPrefix string Path string Parent *RouterNode Children childrenRouterInfo Handler string // {{HandlerPackage}}.{{HandlerName}} HandlerPackage string HandlerPackageAlias string HttpMethod string } type RegisterInfo struct { PackageName string DepPkgAlias string DepPkg string } // NewRouterTree contains "/" as root node func NewRouterTree() *RouterNode { return &RouterNode{ GroupName: "root", MiddleWare: "root", GroupMiddleware: "root", Path: "/", Parent: nil, } } func (routerNode *RouterNode) Sort() { sort.Sort(routerNode.Children) } func (routerNode *RouterNode) Update(method *HttpMethod, handlerType, handlerPkg string, sortRouter bool) error { if method.Path == "" { return fmt.Errorf("empty path for method '%s'", method.Name) } paths := strings.Split(method.Path, "/") if paths[0] == "" { paths = paths[1:] } parent, last := routerNode.FindNearest(paths, method.HTTPMethod, sortRouter) if last == len(paths) { return fmt.Errorf("path '%s' has been registered", method.Path) } name := util.ToVarName(paths[:last]) parent.Insert(name, method, handlerType, paths[last:], handlerPkg, sortRouter) parent.Sort() return nil } func (routerNode *RouterNode) RawHandlerName() string { parts := strings.Split(routerNode.Handler, ".") handlerName := parts[len(parts)-1] return handlerName } // DyeGroupName traverses the routing tree in depth and names the handler/group middleware for each node. // If snakeStyleMiddleware is set to true, the name style of the middleware will use snake name style. 
func (routerNode *RouterNode) DyeGroupName(snakeStyleMiddleware bool) error { groups := []string{"root"} hook := func(layer int, node *RouterNode) error { node.GroupName = groups[layer] if node.MiddleWare == "" { pname := node.Path if len(pname) > 1 && pname[0] == '/' { pname = pname[1:] } if node.Parent != nil { node.PathPrefix = node.Parent.PathPrefix + "_" + util.ToGoFuncName(pname) } else { node.PathPrefix = "_" + util.ToGoFuncName(pname) } handlerMiddlewareName := "" isLeafNode := false if len(node.Handler) != 0 { handlerMiddlewareName = node.RawHandlerName() // If it is a leaf node, then "group middleware name" and "handler middleware name" are the same if len(node.Children) == 0 { pname = handlerMiddlewareName isLeafNode = true } } pname = convertToMiddlewareName(pname) handlerMiddlewareName = convertToMiddlewareName(handlerMiddlewareName) if isLeafNode { name, err := util.GetMiddlewareUniqueName(pname) if err != nil { return fmt.Errorf("get unique name for middleware '%s' failed, err: %v", name, err) } pname = name handlerMiddlewareName = name } else { var err error pname, err = util.GetMiddlewareUniqueName(pname) if err != nil { return fmt.Errorf("get unique name for middleware '%s' failed, err: %v", pname, err) } handlerMiddlewareName, err = util.GetMiddlewareUniqueName(handlerMiddlewareName) if err != nil { return fmt.Errorf("get unique name for middleware '%s' failed, err: %v", handlerMiddlewareName, err) } } node.MiddleWare = "_" + pname if len(node.Handler) != 0 { node.HandlerMiddleware = "_" + handlerMiddlewareName if snakeStyleMiddleware { node.HandlerMiddleware = "_" + node.RawHandlerName() } } node.GroupMiddleware = node.MiddleWare if snakeStyleMiddleware { node.GroupMiddleware = node.PathPrefix } } if layer >= len(groups)-1 { groups = append(groups, node.MiddleWare) } else { groups[layer+1] = node.MiddleWare } return nil } // Deep traversal from the 0th level of the routing tree. 
err := routerNode.DFS(0, hook) return err } func (routerNode *RouterNode) DFS(i int, hook func(layer int, node *RouterNode) error) error { if routerNode == nil { return nil } err := hook(i, routerNode) if err != nil { return err } for _, n := range routerNode.Children { err = n.DFS(i+1, hook) if err != nil { return err } } return nil } var handlerPkgMap map[string]string func (routerNode *RouterNode) Insert(name string, method *HttpMethod, handlerType string, paths []string, handlerPkg string, sortRouter bool) { cur := routerNode for i, p := range paths { c := &RouterNode{ Path: "/" + p, Parent: cur, } if i == len(paths)-1 { // generate handler by method if len(handlerPkg) != 0 { // get a unique package alias for every handler pkgAlias := filepath.Base(handlerPkg) pkgAlias = util.ToVarName([]string{pkgAlias}) val, exist := handlerPkgMap[handlerPkg] if !exist { pkgAlias, _ = util.GetHandlerPackageUniqueName(pkgAlias) if len(handlerPkgMap) == 0 { handlerPkgMap = make(map[string]string, 10) } handlerPkgMap[handlerPkg] = pkgAlias } else { pkgAlias = val } c.HandlerPackageAlias = pkgAlias c.Handler = pkgAlias + "." + method.Name c.HandlerPackage = handlerPkg method.RefPackage = c.HandlerPackage method.RefPackageAlias = c.HandlerPackageAlias } else { // generate handler by service c.Handler = handlerType + "." + method.Name if len(method.RefPackage) != 0 { c.Handler = method.RefPackageAlias + "." 
+ method.Name c.HandlerPackageAlias = method.RefPackageAlias c.HandlerPackage = method.RefPackage } } c.HttpMethod = getHttpMethod(method.HTTPMethod) } if cur.Children == nil { cur.Children = make([]*RouterNode, 0, 1) } cur.Children = append(cur.Children, c) if sortRouter { sort.Sort(cur.Children) } cur = c } } func getHttpMethod(method string) string { if strings.EqualFold(method, "Any") { return "Any" } return strings.ToUpper(method) } func (routerNode *RouterNode) FindNearest(paths []string, method string, sortRouter bool) (*RouterNode, int) { ns := len(paths) cur := routerNode i := 0 path := paths[i] for j := 0; j < len(cur.Children); j++ { c := cur.Children[j] tmpMethod := "" // group do not have http method if i == ns { // only i==ns, the path is http method node tmpMethod = method } if ("/" + path) == c.Path { if sortRouter && !strings.EqualFold(c.HttpMethod, tmpMethod) { continue } i++ if i == ns { return cur, i - 1 } path = paths[i] cur = c j = -1 } } return cur, i } type childrenRouterInfo []*RouterNode // Len is the number of elements in the collection. func (c childrenRouterInfo) Len() int { return len(c) } // Less reports whether the element with // index i should sort before the element with index j. func (c childrenRouterInfo) Less(i, j int) bool { if c[i].HttpMethod == "" && c[j].HttpMethod != "" { return false } if c[i].HttpMethod != "" && c[j].HttpMethod == "" { return true } // remove non-litter char // eg. /a -> a // /:a -> a ci := removeNonLetterPrefix(c[i].Path) cj := removeNonLetterPrefix(c[j].Path) // if ci == cj, use HTTP method for sort, preventing sorting inconsistencies if ci == cj { return c[i].HttpMethod < c[j].HttpMethod } return ci < cj } func removeNonLetterPrefix(str string) string { for i, char := range str { if unicode.IsLetter(char) || unicode.IsDigit(char) { return str[i:] } } return str } // Swap swaps the elements with indexes i and j. 
func (c childrenRouterInfo) Swap(i, j int) { c[i], c[j] = c[j], c[i] } var ( regRegisterV3 = regexp.MustCompile(insertPointPatternNew) regImport = regexp.MustCompile(`import \(\n`) ) func (pkgGen *HttpPackageGenerator) updateRegister(pkg, rDir, pkgName string) error { if pkgGen.tplsInfo[registerTplName].Disable { return nil } register := RegisterInfo{ PackageName: filepath.Base(rDir), DepPkgAlias: strings.ReplaceAll(pkgName, "/", "_"), DepPkg: pkg, } registerPath := filepath.Join(rDir, registerTplName) isExist, err := util.PathExist(registerPath) if err != nil { return err } if !isExist { return pkgGen.TemplateGenerator.Generate(register, registerTplName, registerPath, false) } file, err := ioutil.ReadFile(registerPath) if err != nil { return fmt.Errorf("read register '%s' failed, err: %v", registerPath, err.Error()) } insertReg := register.DepPkgAlias + ".Register(r)\n" if !checkDupRegister(file, insertReg) { file, err = util.AddImport(registerPath, register.DepPkgAlias, register.DepPkg) if err != nil { return err } subIndexReg := regRegisterV3.FindSubmatchIndex(file) if len(subIndexReg) != 2 || subIndexReg[0] < 1 { return fmt.Errorf("wrong format %s: insert-point '%s' not found", string(file), insertPointPatternNew) } bufReg := bytes.NewBuffer(nil) bufReg.Write(file[:subIndexReg[1]]) bufReg.WriteString("\n\t" + insertReg) bufReg.Write(file[subIndexReg[1]:]) pkgGen.files = append(pkgGen.files, File{registerPath, string(bufReg.Bytes()), false, registerTplName}) } return nil } func checkDupRegister(file []byte, insertReg string) bool { return bytes.Contains(file, []byte("\t"+insertReg)) || bytes.Contains(file, []byte(" "+insertReg)) } func appendMw(mws []string, mw string) ([]string, string) { for i := 0; true; i++ { if i == math.MaxInt { break } if !stringsIncludes(mws, mw) { mws = append(mws, mw) break } mw += strconv.Itoa(i) } return mws, mw } func stringsIncludes(strs []string, str string) bool { for _, s := range strs { if s == str { return true } } return 
false } func (pkgGen *HttpPackageGenerator) genRouter(pkg *HttpPackage, root *RouterNode, handlerPackage, routerDir, routerPackage string) error { err := root.DyeGroupName(pkgGen.SnakeStyleMiddleware) if err != nil { return err } router := Router{ FilePath: filepath.Join(routerDir, util.BaseNameAndTrim(pkg.IdlName)+".go"), PackageName: filepath.Base(routerDir), HandlerPackages: map[string]string{ util.BaseName(handlerPackage, ""): handlerPackage, }, Router: root, } handlerMap := make(map[string]string) hook := func(layer int, node *RouterNode) error { if len(node.HandlerPackage) != 0 { handlerMap[node.HandlerPackageAlias] = node.HandlerPackage } return nil } root.DFS(0, hook) if len(handlerMap) != 0 { router.HandlerPackages = handlerMap } if pkgGen.SnakeStyleMiddleware { // unique middleware name for SnakeStyleMiddleware mws := []string{} hook := func(layer int, node *RouterNode) error { if len(node.Children) == 0 { return nil } groupMwName := node.GroupMiddleware handlerMwName := node.HandlerMiddleware if len(groupMwName) != 0 { mws, groupMwName = appendMw(mws, groupMwName) } if len(handlerMwName) != 0 { mws, handlerMwName = appendMw(mws, handlerMwName) } if groupMwName != node.GroupMiddleware { node.GroupMiddleware = groupMwName } if handlerMwName != node.HandlerMiddleware { node.HandlerMiddleware = handlerMwName } return nil } root.DFS(0, hook) } // store router info pkg.RouterInfo = &router if !pkgGen.tplsInfo[routerTplName].Disable { if err := pkgGen.TemplateGenerator.Generate(router, routerTplName, router.FilePath, false); err != nil { return fmt.Errorf("generate router %s failed, err: %v", router.FilePath, err.Error()) } } if err := pkgGen.updateMiddlewareReg(router, middlewareTplName, filepath.Join(routerDir, "middleware.go")); err != nil { return fmt.Errorf("generate middleware %s failed, err: %v", filepath.Join(routerDir, "middleware.go"), err.Error()) } if err := pkgGen.updateRegister(routerPackage, pkgGen.RouterDir, pkg.Package); err != nil { return 
fmt.Errorf("update register for %s failed, err: %v", filepath.Join(routerDir, registerTplName), err.Error()) } return nil } func (pkgGen *HttpPackageGenerator) updateMiddlewareReg(router interface{}, middlewareTpl, filePath string) error { if pkgGen.tplsInfo[middlewareTpl].Disable { return nil } isExist, err := util.PathExist(filePath) if err != nil { return err } if !isExist { return pkgGen.TemplateGenerator.Generate(router, middlewareTpl, filePath, false) } var middlewareList []string _ = router.(Router).Router.DFS(0, func(layer int, node *RouterNode) error { // non-leaf node will generate group middleware if node.Children.Len() > 0 && len(node.GroupMiddleware) > 0 { middlewareList = append(middlewareList, node.GroupMiddleware) } if len(node.HandlerMiddleware) > 0 { middlewareList = append(middlewareList, node.HandlerMiddleware) } return nil }) file, err := ioutil.ReadFile(filePath) if err != nil { return err } for _, mw := range middlewareList { mwNamePattern := fmt.Sprintf(" %sMw", mw) if pkgGen.SnakeStyleMiddleware { mwNamePattern = fmt.Sprintf(" %s_mw", mw) } if bytes.Contains(file, []byte(mwNamePattern)) { continue } middlewareSingleTpl := pkgGen.tpls[middlewareSingleTplName] if middlewareSingleTpl == nil { return fmt.Errorf("tpl %s not found", middlewareSingleTplName) } data := make(map[string]string, 1) data["MiddleWare"] = mw middlewareFunc := bytes.NewBuffer(nil) err = middlewareSingleTpl.Execute(middlewareFunc, data) if err != nil { return fmt.Errorf("execute template \"%s\" failed, %v", middlewareSingleTplName, err) } buf := bytes.NewBuffer(nil) _, err = buf.Write(file) if err != nil { return fmt.Errorf("write middleware \"%s\" failed, %v", mw, err) } _, err = buf.Write(middlewareFunc.Bytes()) if err != nil { return fmt.Errorf("write middleware \"%s\" failed, %v", mw, err) } file = buf.Bytes() } pkgGen.files = append(pkgGen.files, File{filePath, string(file), false, middlewareTplName}) return nil } // convertToMiddlewareName converts a route path to a 
middleware name func convertToMiddlewareName(path string) string { path = util.ToVarName([]string{path}) path = strings.ToLower(path) return path } ================================================ FILE: cmd/hz/generator/router_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package generator import "testing" func Test_checkDupRegister(t *testing.T) { type args struct { file []byte insertReg string } tests := []struct { name string args args want bool }{ { name: "dup tab", args: args{ file: []byte("package main\n\nimport (\n\t\"hertz.io/hertz/pkg/app\"\n)\n\nfunc register() {\n\tapp.Register(r)\n}"), insertReg: "app.Register(r)\n", }, want: true, }, { name: "dup space", args: args{ file: []byte("package main\n\nimport (\n\t\"hertz.io/hertz/pkg/app\"\n)\n\nfunc register() {\n app.Register(r)\n}"), insertReg: "app.Register(r)\n", }, want: true, }, { name: "not dup prefix", args: args{ file: []byte("package main\n\nimport (\n\t\"hertz.io/hertz/pkg/app_2\"\n)\n\nfunc register() {\n\tapp_2.Register(r)\n}"), insertReg: "app.Register(r)\n", }, want: false, }, { name: "not dup subfix", args: args{ file: []byte("package main\n\nimport (\n\t\"hertz.io/hertz/pkg/xapp\"\n)\n\nfunc register() {\n xapp.Register(r)\n}"), insertReg: "app.Register(r)\n", }, want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := checkDupRegister(tt.args.file, tt.args.insertReg); got 
!= tt.want { t.Errorf("checkDupRegister() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: cmd/hz/generator/template.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package generator import ( "bytes" "errors" "fmt" "os" "path/filepath" "strings" "text/template" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/hertz/cmd/hz/util/logs" ) var DefaultDelimiters = [2]string{"{{", "}}"} type TemplateConfig struct { Layouts []Template `yaml:"layouts"` } const ( Skip = "skip" Cover = "cover" Append = "append" ) type Template struct { Default bool // Is it the default template Path string `yaml:"path"` // The generated path and its filename, such as biz/handler/ping.go Delims [2]string `yaml:"delims"` // Template Action Instruction Identifier, default: "{{}}" Body string `yaml:"body"` // Render template, currently only supports go template syntax Disable bool `yaml:"disable"` // Disable generating file, used to disable default package template LoopMethod bool `yaml:"loop_method"` // Loop generate files based on "method" LoopService bool `yaml:"loop_service"` // Loop generate files based on "service" UpdateBehavior UpdateBehavior `yaml:"update_behavior"` // Update command behavior; 0:unchanged, 1:regenerate, 2:append } type UpdateBehavior struct { Type string `yaml:"type"` // Update behavior type: 
skip/cover/append // the following variables are used for append update AppendKey string `yaml:"append_key"` // Append content based in key; for example: 'method'/'service' InsertKey string `yaml:"insert_key"` // Insert content by "insert_key" AppendTpl string `yaml:"append_content_tpl"` // Append content if UpdateBehavior is "append" ImportTpl []string `yaml:"import_tpl"` // Import insert template AppendLocation string `yaml:"append_location"` // AppendLocation specifies the location of append, the default is the end of the file } // TemplateGenerator contains information about the output template type TemplateGenerator struct { OutputDir string Config *TemplateConfig Excludes []string tpls map[string]*template.Template // "template name" -> "Template", it is used get the "parsed template" directly tplsInfo map[string]*Template // "template name" -> "template info", it is used to get the original "template information" dirs map[string]bool isPackageTpl bool files []File excludedFiles map[string]*File } func (tg *TemplateGenerator) Init() error { if tg.Config == nil { return errors.New("config not set yet") } if tg.tpls == nil { tg.tpls = make(map[string]*template.Template, len(tg.Config.Layouts)) } if tg.tplsInfo == nil { tg.tplsInfo = make(map[string]*Template, len(tg.Config.Layouts)) } if tg.dirs == nil { tg.dirs = make(map[string]bool) } for _, l := range tg.Config.Layouts { if tg.isPackageTpl && IsDefaultPackageTpl(l.Path) { continue } // check if is a directory var noFile bool if strings.HasSuffix(l.Path, string(filepath.Separator)) { noFile = true } path := l.Path if filepath.IsAbs(path) { return fmt.Errorf("absolute template path '%s' is not allowed", path) } dir := filepath.Dir(path) isExist, err := util.PathExist(filepath.Join(tg.OutputDir, dir)) if err != nil { return fmt.Errorf("check directory '%s' failed, err: %v", dir, err.Error()) } if isExist { tg.dirs[dir] = true } else { tg.dirs[dir] = false } if noFile { continue } // parse templates if _, ok := 
tg.tpls[path]; ok { continue } err = tg.loadLayout(l, path, false) if err != nil { return err } } excludes := make(map[string]*File, len(tg.Excludes)) for _, f := range tg.Excludes { excludes[f] = &File{} } tg.excludedFiles = excludes return nil } func (tg *TemplateGenerator) loadLayout(layout Template, tplName string, isDefaultTpl bool) error { delims := DefaultDelimiters if layout.Delims[0] != "" && layout.Delims[1] != "" { delims = layout.Delims } // insert template funcs tpl := template.New(tplName).Funcs(funcMap) tpl = tpl.Delims(delims[0], delims[1]) var err error if tpl, err = tpl.Parse(layout.Body); err != nil { return fmt.Errorf("parse template '%s' failed, err: %v", tplName, err.Error()) } layout.Default = isDefaultTpl tg.tpls[tplName] = tpl tg.tplsInfo[tplName] = &layout return nil } func (tg *TemplateGenerator) Generate(input interface{}, tplName, filepath string, noRepeat bool) error { // check if "*" (global scope) data exists, and stores it to all var all map[string]interface{} if data, ok := input.(map[string]interface{}); ok { ad, ok := data["*"] if ok { all = ad.(map[string]interface{}) } if all == nil { all = map[string]interface{}{} } all["hzVersion"] = meta.Version } file := bytes.NewBuffer(nil) if tplName != "" { tpl := tg.tpls[tplName] if tpl == nil { return fmt.Errorf("tpl %s not found", tplName) } if err := tpl.Execute(file, input); err != nil { return fmt.Errorf("render template '%s' failed, err: %v", tplName, err.Error()) } in := File{filepath, string(file.Bytes()), noRepeat, tplName} tg.files = append(tg.files, in) return nil } for path, tpl := range tg.tpls { file.Reset() var fd interface{} // search and merge rendering data if data, ok := input.(map[string]interface{}); ok { td := map[string]interface{}{} tmp, ok := data[path] if ok { td = tmp.(map[string]interface{}) } for k, v := range all { td[k] = v } fd = td } else { fd = input } if err := tpl.Execute(file, fd); err != nil { return fmt.Errorf("render template '%s' failed, err: 
%v", path, err.Error()) } in := File{path, string(file.Bytes()), noRepeat, tpl.Name()} tg.files = append(tg.files, in) } return nil } func (tg *TemplateGenerator) Persist() error { files := tg.files outPath := tg.OutputDir if !filepath.IsAbs(outPath) { outPath, _ = filepath.Abs(outPath) } for _, data := range files { // check for -E flags if _, ok := tg.excludedFiles[filepath.Join(data.Path)]; ok { continue } // lint file if err := data.Lint(); err != nil { return err } // create rendered file abPath := filepath.Join(outPath, data.Path) abDir := filepath.Dir(abPath) isExist, err := util.PathExist(abDir) if err != nil { return fmt.Errorf("check directory '%s' failed, err: %v", abDir, err.Error()) } if !isExist { if err := os.MkdirAll(abDir, os.FileMode(0o744)); err != nil { return fmt.Errorf("mkdir %s failed, err: %v", abDir, err.Error()) } } err = func() error { fileMode := os.FileMode(0o644) if strings.HasSuffix(abPath, ".sh") { fileMode = os.FileMode(0o755) } if err := os.WriteFile(abPath, []byte(data.Content), fileMode); err != nil { return fmt.Errorf("write file '%s' failed, err: %v", abPath, err) } return nil }() if err != nil { return err } } tg.files = tg.files[:0] return nil } func (tg *TemplateGenerator) GetFormatAndExcludedFiles() ([]File, error) { var files []File outPath := tg.OutputDir if !filepath.IsAbs(outPath) { outPath, _ = filepath.Abs(outPath) } for _, data := range tg.Files() { if _, ok := tg.excludedFiles[filepath.Join(data.Path)]; ok { continue } // check repeat files logs.Infof("Write %s", data.Path) isExist, err := util.PathExist(filepath.Join(data.Path)) if err != nil { return nil, fmt.Errorf("check file '%s' failed, err: %v", data.Path, err.Error()) } if isExist && data.NoRepeat { if data.FileTplName == handlerTplName { logs.Warnf("Handler file(%s) has been generated.\n If you want to re-generate it, please copy and delete the file to prevent the already written code from being deleted.", data.Path) } else if data.FileTplName == 
routerTplName { logs.Warnf("Router file(%s) has been generated.\n If you want to re-generate it, please delete the file.", data.Path) } else { logs.Warnf("file '%s' already exists, so drop the generated file", data.Path) } continue } // lint file if err := data.Lint(); err != nil { logs.Warnf("Lint file: %s failed:\n %s\n", data.Path, data.Content) } files = append(files, data) } return files, nil } func (tg *TemplateGenerator) Files() []File { return tg.files } func (tg *TemplateGenerator) Degenerate() error { outPath := tg.OutputDir if !filepath.IsAbs(outPath) { outPath, _ = filepath.Abs(outPath) } for path := range tg.tpls { abPath := filepath.Join(outPath, path) if err := os.RemoveAll(abPath); err != nil { return fmt.Errorf("remove file '%s' failed, err: %v", path, err.Error()) } } for dir, exist := range tg.dirs { if !exist { abDir := filepath.Join(outPath, dir) if err := os.RemoveAll(abDir); err != nil { return fmt.Errorf("remove directory '%s' failed, err: %v", dir, err.Error()) } } } return nil } ================================================ FILE: cmd/hz/generator/template_funcs.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package generator import ( "strings" "text/template" "github.com/Masterminds/sprig/v3" "github.com/cloudwego/hertz/cmd/hz/util" ) var funcMap = func() template.FuncMap { m := template.FuncMap{ "GetUniqueHandlerOutDir": getUniqueHandlerOutDir, "ToSnakeCase": util.ToSnakeCase, "Split": strings.Split, "Trim": strings.Trim, "EqualFold": strings.EqualFold, } for key, f := range sprig.TxtFuncMap() { m[key] = f } return m }() // getUniqueHandlerOutDir uses to get unique "api.handler_path" func getUniqueHandlerOutDir(methods []*HttpMethod) (ret []string) { outDirMap := make(map[string]string) for _, method := range methods { if _, exist := outDirMap[method.OutputDir]; !exist { outDirMap[method.OutputDir] = method.OutputDir ret = append(ret, method.OutputDir) } } return ret } ================================================ FILE: cmd/hz/go.mod ================================================ module github.com/cloudwego/hertz/cmd/hz go 1.16 require ( github.com/Masterminds/sprig/v3 v3.2.3 github.com/cloudwego/thriftgo v0.4.2-0.20250604064713-0e1e704080b1 github.com/hashicorp/go-version v1.5.0 github.com/jhump/protoreflect v1.12.0 github.com/urfave/cli/v2 v2.23.0 golang.org/x/tools v0.6.0 google.golang.org/protobuf v1.28.0 gopkg.in/yaml.v2 v2.4.0 ) ================================================ FILE: cmd/hz/go.sum ================================================ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7YgDP83g= github.com/Masterminds/semver/v3 v3.2.0/go.mod 
h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= github.com/bytedance/gopkg v0.1.1 h1:3azzgSkiaw79u24a+w9arfH8OfnQQ4MHUt9lJFREEaE= github.com/bytedance/gopkg v0.1.1/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudwego/gopkg v0.1.4 h1:EoQiCG4sTonTPHxOGE0VlQs+sQR+Hsi2uN0qqwu8O50= github.com/cloudwego/gopkg v0.1.4/go.mod h1:FQuXsRWRsSqJLsMVd5SYzp8/Z1y5gXKnVvRrWUOsCMI= github.com/cloudwego/thriftgo v0.4.2-0.20250604064713-0e1e704080b1 h1:iuQJK+ZEtb0uA9cTjWW65gj2R0UB03GFZRD8IwAbDaE= github.com/cloudwego/thriftgo v0.4.2-0.20250604064713-0e1e704080b1/go.mod h1:/D4zRAEj1t3/Tq1bVGDMnRt3wxpHfalXfZWvq/n4YmY= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= 
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= 
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/go-version v1.5.0 h1:O293SZ2Eg+AAYijkVK3jR786Am1bhDEh2GHT0tIVE5E= github.com/hashicorp/go-version v1.5.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4= github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA= github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ= github.com/jhump/protoreflect v1.11.0/go.mod h1:U7aMIjN0NWq9swDP7xDdoMfRHb35uiuTd3Z9nFXJf5E= github.com/jhump/protoreflect v1.12.0 h1:1NQ4FpWMgn3by/n1X0fbeKEUxP1wBt7+Oitpv01HR10= github.com/jhump/protoreflect v1.12.0/go.mod h1:JytZfP5d0r8pVNLZvai7U/MCuTWITgrI4tTg7puQFKI= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY= github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify 
v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/urfave/cli/v2 v2.23.0 h1:pkly7gKIeYv3olPAeNajNpLjeJrmTPYCoZWaV+2VfvE= github.com/urfave/cli/v2 v2.23.0/go.mod h1:1CNUng3PtjQMtRzJO4FMXBQvkGtuYRxxiR9xMa7jMwI= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod 
h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= 
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod 
h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.38.0 h1:/9BgsAsa5nWe26HqOlvlgJnqBuktYOLCgjCPqsa56W0= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= 
google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod 
h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= ================================================ FILE: cmd/hz/main.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package main import ( "os" "github.com/cloudwego/hertz/cmd/hz/app" "github.com/cloudwego/hertz/cmd/hz/util/logs" ) func main() { // run in plugin mode app.PluginMode() // run in normal mode Run() } func Run() { defer func() { logs.Flush() }() cli := app.Init() err := cli.Run(os.Args) if err != nil { logs.Errorf("%v\n", err) } } ================================================ FILE: cmd/hz/meta/const.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package meta import ( "path/filepath" "runtime" ) // Version hz version const Version = "v0.9.7" const DefaultServiceName = "hertz_service" // Mode hz run modes type Mode int // SysType is the running program's operating system type const SysType = runtime.GOOS const WindowsOS = "windows" const EnvPluginMode = "HERTZ_PLUGIN_MODE" // hz Commands const ( CmdUpdate = "update" CmdNew = "new" CmdModel = "model" CmdClient = "client" ) // hz IDLs const ( IdlThrift = "thrift" IdlProto = "proto" ) // Third-party Compilers const ( TpCompilerThrift = "thriftgo" TpCompilerProto = "protoc" ) // hz Plugins const ( ProtocPluginName = "protoc-gen-hertz" ThriftPluginName = "thrift-gen-hertz" ) // hz Errors const ( LoadError = 1 GenerateLayoutError = 2 PersistError = 3 PluginError = 4 ) // Package Dir const ( ModelDir = "biz" + string(filepath.Separator) + "model" RouterDir = "biz" + string(filepath.Separator) + "router" HandlerDir = "biz" + string(filepath.Separator) + "handler" ) // Backend Model Backends type Backend string const ( BackendGolang Backend = "golang" ) // template const value const ( SetBodyParam = "setBodyParam(req).\n" ) // TheUseOptionMessage indicates that the generating of 'model code' is aborted due to the -use option for thrift IDL. const TheUseOptionMessage = "'model code' is not generated due to the '-use' option" const AddThriftReplace = "do not generate 'go.mod', please add 'replace github.com/apache/thrift => github.com/apache/thrift v0.13.0' to your 'go.mod'" ================================================ FILE: cmd/hz/meta/manifest.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package meta import ( "bytes" "fmt" "os" "path/filepath" gv "github.com/hashicorp/go-version" "gopkg.in/yaml.v2" ) const ManifestFile = ".hz" type Manifest struct { Version string `yaml:"hz version"` HandlerDir string `yaml:"handlerDir"` ModelDir string `yaml:"modelDir"` RouterDir string `yaml:"routerDir"` } var GoVersion *gv.Version func init() { // valid by unit test already, so no need to check error GoVersion, _ = gv.NewVersion(Version) } func (manifest *Manifest) InitAndValidate(dir string) error { m, err := loadConfigFile(filepath.Join(dir, ManifestFile)) if err != nil { return fmt.Errorf("can not load \".hz\", err: %v", err) } if len(m.Version) == 0 { return fmt.Errorf("can not get hz version form \".hz\", current project doesn't belong to hertz framework") } *manifest = *m _, err = gv.NewVersion(manifest.Version) if err != nil { return fmt.Errorf("invalid hz version in \".hz\", err: %v", err) } return nil } const hzTitle = "// Code generated by hz. DO NOT EDIT." 
func (manifest *Manifest) String() string { conf, _ := yaml.Marshal(*manifest) return hzTitle + "\n\n" + string(conf) } func (manifest *Manifest) Persist(dir string) error { file := filepath.Join(dir, ManifestFile) fd, err := os.OpenFile(file, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.FileMode(0o644)) if err != nil { return err } defer fd.Close() _, err = fd.WriteString(manifest.String()) return err } // loadConfigFile load config file from path func loadConfigFile(path string) (*Manifest, error) { file, err := os.ReadFile(path) if err != nil { return nil, err } var manifest Manifest file = bytes.TrimPrefix(file, []byte(hzTitle)) if err = yaml.Unmarshal(file, &manifest); err != nil { return nil, fmt.Errorf("decode \".hz\" failed, err: %v", err) } return &manifest, nil } ================================================ FILE: cmd/hz/meta/manifest_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package meta import ( "testing" gv "github.com/hashicorp/go-version" ) func TestValidate(t *testing.T) { _, err := gv.NewVersion(Version) if err != nil { t.Fatalf("not a valid version: %s", err) } } ================================================ FILE: cmd/hz/protobuf/api/api.pb.go ================================================ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.30.0 // protoc v3.21.12 // source: api.proto package api import ( reflect "reflect" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" descriptorpb "google.golang.org/protobuf/types/descriptorpb" ) const ( // Verify that this generated code is sufficiently up-to-date. _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) // Verify that runtime/protoimpl is sufficiently up-to-date. _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) var file_api_proto_extTypes = []protoimpl.ExtensionInfo{ { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50101, Name: "api.raw_body", Tag: "bytes,50101,opt,name=raw_body", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50102, Name: "api.query", Tag: "bytes,50102,opt,name=query", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50103, Name: "api.header", Tag: "bytes,50103,opt,name=header", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50104, Name: "api.cookie", Tag: "bytes,50104,opt,name=cookie", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50105, Name: "api.body", Tag: "bytes,50105,opt,name=body", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50106, Name: "api.path", Tag: "bytes,50106,opt,name=path", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50107, Name: "api.vd", Tag: "bytes,50107,opt,name=vd", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50108, Name: "api.form", Tag: "bytes,50108,opt,name=form", Filename: 
"api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50109, Name: "api.js_conv", Tag: "bytes,50109,opt,name=js_conv", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50110, Name: "api.file_name", Tag: "bytes,50110,opt,name=file_name", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50111, Name: "api.none", Tag: "bytes,50111,opt,name=none", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50131, Name: "api.form_compatible", Tag: "bytes,50131,opt,name=form_compatible", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50132, Name: "api.js_conv_compatible", Tag: "bytes,50132,opt,name=js_conv_compatible", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50133, Name: "api.file_name_compatible", Tag: "bytes,50133,opt,name=file_name_compatible", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 50134, Name: "api.none_compatible", Tag: "bytes,50134,opt,name=none_compatible", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.FieldOptions)(nil), ExtensionType: (*string)(nil), Field: 51001, Name: "api.go_tag", Tag: "bytes,51001,opt,name=go_tag", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50201, Name: "api.get", Tag: "bytes,50201,opt,name=get", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50202, Name: "api.post", Tag: "bytes,50202,opt,name=post", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50203, Name: "api.put", Tag: 
"bytes,50203,opt,name=put", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50204, Name: "api.delete", Tag: "bytes,50204,opt,name=delete", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50205, Name: "api.patch", Tag: "bytes,50205,opt,name=patch", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50206, Name: "api.options", Tag: "bytes,50206,opt,name=options", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50207, Name: "api.head", Tag: "bytes,50207,opt,name=head", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50208, Name: "api.any", Tag: "bytes,50208,opt,name=any", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50301, Name: "api.gen_path", Tag: "bytes,50301,opt,name=gen_path", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50302, Name: "api.api_version", Tag: "bytes,50302,opt,name=api_version", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50303, Name: "api.tag", Tag: "bytes,50303,opt,name=tag", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50304, Name: "api.name", Tag: "bytes,50304,opt,name=name", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50305, Name: "api.api_level", Tag: "bytes,50305,opt,name=api_level", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50306, Name: "api.serializer", Tag: 
"bytes,50306,opt,name=serializer", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50307, Name: "api.param", Tag: "bytes,50307,opt,name=param", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50308, Name: "api.baseurl", Tag: "bytes,50308,opt,name=baseurl", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50309, Name: "api.handler_path", Tag: "bytes,50309,opt,name=handler_path", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*string)(nil), Field: 50331, Name: "api.handler_path_compatible", Tag: "bytes,50331,opt,name=handler_path_compatible", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.EnumValueOptions)(nil), ExtensionType: (*int32)(nil), Field: 50401, Name: "api.http_code", Tag: "varint,50401,opt,name=http_code", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.ServiceOptions)(nil), ExtensionType: (*string)(nil), Field: 50402, Name: "api.base_domain", Tag: "bytes,50402,opt,name=base_domain", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.ServiceOptions)(nil), ExtensionType: (*string)(nil), Field: 50731, Name: "api.base_domain_compatible", Tag: "bytes,50731,opt,name=base_domain_compatible", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.ServiceOptions)(nil), ExtensionType: (*string)(nil), Field: 50732, Name: "api.service_path", Tag: "bytes,50732,opt,name=service_path", Filename: "api.proto", }, { ExtendedType: (*descriptorpb.MessageOptions)(nil), ExtensionType: (*string)(nil), Field: 50830, Name: "api.reserve", Tag: "bytes,50830,opt,name=reserve", Filename: "api.proto", }, } // Extension fields to descriptorpb.FieldOptions. 
var ( // optional string raw_body = 50101; E_RawBody = &file_api_proto_extTypes[0] // optional string query = 50102; E_Query = &file_api_proto_extTypes[1] // optional string header = 50103; E_Header = &file_api_proto_extTypes[2] // optional string cookie = 50104; E_Cookie = &file_api_proto_extTypes[3] // optional string body = 50105; E_Body = &file_api_proto_extTypes[4] // optional string path = 50106; E_Path = &file_api_proto_extTypes[5] // optional string vd = 50107; E_Vd = &file_api_proto_extTypes[6] // optional string form = 50108; E_Form = &file_api_proto_extTypes[7] // optional string js_conv = 50109; E_JsConv = &file_api_proto_extTypes[8] // optional string file_name = 50110; E_FileName = &file_api_proto_extTypes[9] // optional string none = 50111; E_None = &file_api_proto_extTypes[10] // 50131~50160 used to extend field option by hz // // optional string form_compatible = 50131; E_FormCompatible = &file_api_proto_extTypes[11] // optional string js_conv_compatible = 50132; E_JsConvCompatible = &file_api_proto_extTypes[12] // optional string file_name_compatible = 50133; E_FileNameCompatible = &file_api_proto_extTypes[13] // optional string none_compatible = 50134; E_NoneCompatible = &file_api_proto_extTypes[14] // optional string go_tag = 51001; E_GoTag = &file_api_proto_extTypes[15] ) // Extension fields to descriptorpb.MethodOptions. 
var ( // optional string get = 50201; E_Get = &file_api_proto_extTypes[16] // optional string post = 50202; E_Post = &file_api_proto_extTypes[17] // optional string put = 50203; E_Put = &file_api_proto_extTypes[18] // optional string delete = 50204; E_Delete = &file_api_proto_extTypes[19] // optional string patch = 50205; E_Patch = &file_api_proto_extTypes[20] // optional string options = 50206; E_Options = &file_api_proto_extTypes[21] // optional string head = 50207; E_Head = &file_api_proto_extTypes[22] // optional string any = 50208; E_Any = &file_api_proto_extTypes[23] // optional string gen_path = 50301; E_GenPath = &file_api_proto_extTypes[24] // The path specified by the user when the client code is generated, with a higher priority than api_version // optional string api_version = 50302; E_ApiVersion = &file_api_proto_extTypes[25] // Specify the value of the :version variable in path when the client code is generated // optional string tag = 50303; E_Tag = &file_api_proto_extTypes[26] // rpc tag, can be multiple, separated by commas // optional string name = 50304; E_Name = &file_api_proto_extTypes[27] // Name of rpc // optional string api_level = 50305; E_ApiLevel = &file_api_proto_extTypes[28] // Interface Level // optional string serializer = 50306; E_Serializer = &file_api_proto_extTypes[29] // Serialization method // optional string param = 50307; E_Param = &file_api_proto_extTypes[30] // Whether client requests take public parameters // optional string baseurl = 50308; E_Baseurl = &file_api_proto_extTypes[31] // Baseurl used in ttnet routing // optional string handler_path = 50309; E_HandlerPath = &file_api_proto_extTypes[32] // handler_path specifies the path to generate the method // 50331~50360 used to extend method option by hz // // optional string handler_path_compatible = 50331; E_HandlerPathCompatible = &file_api_proto_extTypes[33] // handler_path specifies the path to generate the method ) // Extension fields to descriptorpb.EnumValueOptions. 
var ( // optional int32 http_code = 50401; E_HttpCode = &file_api_proto_extTypes[34] ) // Extension fields to descriptorpb.ServiceOptions. var ( // optional string base_domain = 50402; E_BaseDomain = &file_api_proto_extTypes[35] // 50731~50760 used to extend service option by hz // // optional string base_domain_compatible = 50731; E_BaseDomainCompatible = &file_api_proto_extTypes[36] // optional string service_path = 50732; E_ServicePath = &file_api_proto_extTypes[37] ) // Extension fields to descriptorpb.MessageOptions. var ( // optional string reserve = 50830; E_Reserve = &file_api_proto_extTypes[38] ) var File_api_proto protoreflect.FileDescriptor var file_api_proto_rawDesc = []byte{ 0x0a, 0x09, 0x61, 0x70, 0x69, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x03, 0x61, 0x70, 0x69, 0x1a, 0x20, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x3a, 0x3a, 0x0a, 0x08, 0x72, 0x61, 0x77, 0x5f, 0x62, 0x6f, 0x64, 0x79, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xb5, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x72, 0x61, 0x77, 0x42, 0x6f, 0x64, 0x79, 0x3a, 0x35, 0x0a, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xb6, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x71, 0x75, 0x65, 0x72, 0x79, 0x3a, 0x37, 0x0a, 0x06, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xb7, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x68, 
0x65, 0x61, 0x64, 0x65, 0x72, 0x3a, 0x37, 0x0a, 0x06, 0x63, 0x6f, 0x6f, 0x6b, 0x69, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xb8, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x63, 0x6f, 0x6f, 0x6b, 0x69, 0x65, 0x3a, 0x33, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xb9, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x3a, 0x33, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xba, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x3a, 0x2f, 0x0a, 0x02, 0x76, 0x64, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xbb, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x76, 0x64, 0x3a, 0x33, 0x0a, 0x04, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xbc, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x6f, 0x72, 0x6d, 0x3a, 0x38, 0x0a, 0x07, 0x6a, 0x73, 0x5f, 0x63, 0x6f, 0x6e, 0x76, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xbd, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6a, 0x73, 0x43, 0x6f, 0x6e, 0x76, 0x3a, 0x3c, 0x0a, 0x09, 0x66, 
0x69, 0x6c, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xbe, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x3a, 0x33, 0x0a, 0x04, 0x6e, 0x6f, 0x6e, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xbf, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x6f, 0x6e, 0x65, 0x3a, 0x48, 0x0a, 0x0f, 0x66, 0x6f, 0x72, 0x6d, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd3, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x66, 0x6f, 0x72, 0x6d, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x4d, 0x0a, 0x12, 0x6a, 0x73, 0x5f, 0x63, 0x6f, 0x6e, 0x76, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd4, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x10, 0x6a, 0x73, 0x43, 0x6f, 0x6e, 0x76, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x51, 0x0a, 0x14, 0x66, 0x69, 0x6c, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd5, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x66, 0x69, 0x6c, 0x65, 0x4e, 
0x61, 0x6d, 0x65, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x48, 0x0a, 0x0f, 0x6e, 0x6f, 0x6e, 0x65, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xd6, 0x87, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x6f, 0x6e, 0x65, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x36, 0x0a, 0x06, 0x67, 0x6f, 0x5f, 0x74, 0x61, 0x67, 0x12, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xb9, 0x8e, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x67, 0x6f, 0x54, 0x61, 0x67, 0x3a, 0x32, 0x0a, 0x03, 0x67, 0x65, 0x74, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x99, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x67, 0x65, 0x74, 0x3a, 0x34, 0x0a, 0x04, 0x70, 0x6f, 0x73, 0x74, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9a, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x6f, 0x73, 0x74, 0x3a, 0x32, 0x0a, 0x03, 0x70, 0x75, 0x74, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9b, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x70, 0x75, 0x74, 0x3a, 0x38, 0x0a, 0x06, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 
0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9c, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x3a, 0x36, 0x0a, 0x05, 0x70, 0x61, 0x74, 0x63, 0x68, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9d, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x70, 0x61, 0x74, 0x63, 0x68, 0x3a, 0x3a, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9e, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x3a, 0x34, 0x0a, 0x04, 0x68, 0x65, 0x61, 0x64, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9f, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x65, 0x61, 0x64, 0x3a, 0x32, 0x0a, 0x03, 0x61, 0x6e, 0x79, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xa0, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x61, 0x6e, 0x79, 0x3a, 0x3b, 0x0a, 0x08, 0x67, 0x65, 0x6e, 0x5f, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xfd, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x67, 0x65, 0x6e, 0x50, 0x61, 0x74, 0x68, 0x3a, 0x41, 0x0a, 0x0b, 0x61, 0x70, 0x69, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xfe, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x70, 0x69, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x3a, 0x32, 0x0a, 0x03, 0x74, 0x61, 0x67, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xff, 0x88, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x74, 0x61, 0x67, 0x3a, 0x34, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x80, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x3a, 0x3d, 0x0a, 0x09, 0x61, 0x70, 0x69, 0x5f, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x81, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x70, 0x69, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x3a, 0x40, 0x0a, 0x0a, 0x73, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x72, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x82, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x73, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x72, 0x3a, 0x36, 0x0a, 0x05, 0x70, 0x61, 0x72, 0x61, 0x6d, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x83, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x70, 0x61, 0x72, 0x61, 
0x6d, 0x3a, 0x3a, 0x0a, 0x07, 0x62, 0x61, 0x73, 0x65, 0x75, 0x72, 0x6c, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x84, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x62, 0x61, 0x73, 0x65, 0x75, 0x72, 0x6c, 0x3a, 0x43, 0x0a, 0x0c, 0x68, 0x61, 0x6e, 0x64, 0x6c, 0x65, 0x72, 0x5f, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x85, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x68, 0x61, 0x6e, 0x64, 0x6c, 0x65, 0x72, 0x50, 0x61, 0x74, 0x68, 0x3a, 0x58, 0x0a, 0x17, 0x68, 0x61, 0x6e, 0x64, 0x6c, 0x65, 0x72, 0x5f, 0x70, 0x61, 0x74, 0x68, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x12, 0x1e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x9b, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x68, 0x61, 0x6e, 0x64, 0x6c, 0x65, 0x72, 0x50, 0x61, 0x74, 0x68, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x40, 0x0a, 0x09, 0x68, 0x74, 0x74, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x21, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xe1, 0x89, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x08, 0x68, 0x74, 0x74, 0x70, 0x43, 0x6f, 0x64, 0x65, 0x3a, 0x42, 0x0a, 0x0b, 0x62, 0x61, 0x73, 0x65, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 
0x73, 0x18, 0xe2, 0x89, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x62, 0x61, 0x73, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x3a, 0x57, 0x0a, 0x16, 0x62, 0x61, 0x73, 0x65, 0x5f, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xab, 0x8c, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x14, 0x62, 0x61, 0x73, 0x65, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x74, 0x69, 0x62, 0x6c, 0x65, 0x3a, 0x44, 0x0a, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x5f, 0x70, 0x61, 0x74, 0x68, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0xac, 0x8c, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x50, 0x61, 0x74, 0x68, 0x3a, 0x3b, 0x0a, 0x07, 0x72, 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x12, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x8e, 0x8d, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x72, 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x42, 0x06, 0x5a, 0x04, 0x2f, 0x61, 0x70, 0x69, } var file_api_proto_goTypes = []interface{}{ (*descriptorpb.FieldOptions)(nil), // 0: google.protobuf.FieldOptions (*descriptorpb.MethodOptions)(nil), // 1: google.protobuf.MethodOptions (*descriptorpb.EnumValueOptions)(nil), // 2: google.protobuf.EnumValueOptions (*descriptorpb.ServiceOptions)(nil), // 3: google.protobuf.ServiceOptions (*descriptorpb.MessageOptions)(nil), // 4: google.protobuf.MessageOptions } var file_api_proto_depIdxs = []int32{ 0, // 0: api.raw_body:extendee -> 
google.protobuf.FieldOptions 0, // 1: api.query:extendee -> google.protobuf.FieldOptions 0, // 2: api.header:extendee -> google.protobuf.FieldOptions 0, // 3: api.cookie:extendee -> google.protobuf.FieldOptions 0, // 4: api.body:extendee -> google.protobuf.FieldOptions 0, // 5: api.path:extendee -> google.protobuf.FieldOptions 0, // 6: api.vd:extendee -> google.protobuf.FieldOptions 0, // 7: api.form:extendee -> google.protobuf.FieldOptions 0, // 8: api.js_conv:extendee -> google.protobuf.FieldOptions 0, // 9: api.file_name:extendee -> google.protobuf.FieldOptions 0, // 10: api.none:extendee -> google.protobuf.FieldOptions 0, // 11: api.form_compatible:extendee -> google.protobuf.FieldOptions 0, // 12: api.js_conv_compatible:extendee -> google.protobuf.FieldOptions 0, // 13: api.file_name_compatible:extendee -> google.protobuf.FieldOptions 0, // 14: api.none_compatible:extendee -> google.protobuf.FieldOptions 0, // 15: api.go_tag:extendee -> google.protobuf.FieldOptions 1, // 16: api.get:extendee -> google.protobuf.MethodOptions 1, // 17: api.post:extendee -> google.protobuf.MethodOptions 1, // 18: api.put:extendee -> google.protobuf.MethodOptions 1, // 19: api.delete:extendee -> google.protobuf.MethodOptions 1, // 20: api.patch:extendee -> google.protobuf.MethodOptions 1, // 21: api.options:extendee -> google.protobuf.MethodOptions 1, // 22: api.head:extendee -> google.protobuf.MethodOptions 1, // 23: api.any:extendee -> google.protobuf.MethodOptions 1, // 24: api.gen_path:extendee -> google.protobuf.MethodOptions 1, // 25: api.api_version:extendee -> google.protobuf.MethodOptions 1, // 26: api.tag:extendee -> google.protobuf.MethodOptions 1, // 27: api.name:extendee -> google.protobuf.MethodOptions 1, // 28: api.api_level:extendee -> google.protobuf.MethodOptions 1, // 29: api.serializer:extendee -> google.protobuf.MethodOptions 1, // 30: api.param:extendee -> google.protobuf.MethodOptions 1, // 31: api.baseurl:extendee -> google.protobuf.MethodOptions 1, // 32: 
api.handler_path:extendee -> google.protobuf.MethodOptions 1, // 33: api.handler_path_compatible:extendee -> google.protobuf.MethodOptions 2, // 34: api.http_code:extendee -> google.protobuf.EnumValueOptions 3, // 35: api.base_domain:extendee -> google.protobuf.ServiceOptions 3, // 36: api.base_domain_compatible:extendee -> google.protobuf.ServiceOptions 3, // 37: api.service_path:extendee -> google.protobuf.ServiceOptions 4, // 38: api.reserve:extendee -> google.protobuf.MessageOptions 39, // [39:39] is the sub-list for method output_type 39, // [39:39] is the sub-list for method input_type 39, // [39:39] is the sub-list for extension type_name 0, // [0:39] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name } func init() { file_api_proto_init() } func file_api_proto_init() { if File_api_proto != nil { return } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_api_proto_rawDesc, NumEnums: 0, NumMessages: 0, NumExtensions: 39, NumServices: 0, }, GoTypes: file_api_proto_goTypes, DependencyIndexes: file_api_proto_depIdxs, ExtensionInfos: file_api_proto_extTypes, }.Build() File_api_proto = out.File file_api_proto_rawDesc = nil file_api_proto_goTypes = nil file_api_proto_depIdxs = nil } ================================================ FILE: cmd/hz/protobuf/api/api.proto ================================================ syntax = "proto2"; package api; import "google/protobuf/descriptor.proto"; option go_package = "/api"; extend google.protobuf.FieldOptions { optional string raw_body = 50101; optional string query = 50102; optional string header = 50103; optional string cookie = 50104; optional string body = 50105; optional string path = 50106; optional string vd = 50107; optional string form = 50108; optional string js_conv = 50109; optional string file_name = 50110; optional string none = 50111; // 50131~50160 used to extend field option by hz 
optional string form_compatible = 50131; optional string js_conv_compatible = 50132; optional string file_name_compatible = 50133; optional string none_compatible = 50134; // 50135 is reserved to vt_compatible // optional FieldRules vt_compatible = 50135; optional string go_tag = 51001; } extend google.protobuf.MethodOptions { optional string get = 50201; optional string post = 50202; optional string put = 50203; optional string delete = 50204; optional string patch = 50205; optional string options = 50206; optional string head = 50207; optional string any = 50208; optional string gen_path = 50301; // The path specified by the user when the client code is generated, with a higher priority than api_version optional string api_version = 50302; // Specify the value of the :version variable in path when the client code is generated optional string tag = 50303; // rpc tag, can be multiple, separated by commas optional string name = 50304; // Name of rpc optional string api_level = 50305; // Interface Level optional string serializer = 50306; // Serialization method optional string param = 50307; // Whether client requests take public parameters optional string baseurl = 50308; // Baseurl used in ttnet routing optional string handler_path = 50309; // handler_path specifies the path to generate the method // 50331~50360 used to extend method option by hz optional string handler_path_compatible = 50331; // handler_path specifies the path to generate the method } extend google.protobuf.EnumValueOptions { optional int32 http_code = 50401; // 50431~50460 used to extend enum option by hz } extend google.protobuf.ServiceOptions { optional string base_domain = 50402; // 50731~50760 used to extend service option by hz optional string base_domain_compatible = 50731; optional string service_path = 50732; } extend google.protobuf.MessageOptions { // optional FieldRules msg_vt = 50111; optional string reserve = 50830; // 50831 is reserved to msg_vt_compatible // optional FieldRules
msg_vt_compatible = 50831; } ================================================ FILE: cmd/hz/protobuf/ast.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package protobuf import ( "fmt" "path/filepath" "sort" "strings" "github.com/cloudwego/hertz/cmd/hz/generator" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/protobuf/api" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/hertz/cmd/hz/util/logs" "github.com/jhump/protoreflect/desc" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/runtime/protoimpl" "google.golang.org/protobuf/types/descriptorpb" ) var BaseProto = descriptorpb.FileDescriptorProto{} // getGoPackage get option go_package // If pkgMap is specified, the specified value is used as the go_package; // If go package is not specified, then the value of package is used as go_package. 
func getGoPackage(f *descriptorpb.FileDescriptorProto, pkgMap map[string]string) string { if f.Options == nil { f.Options = new(descriptorpb.FileOptions) } if f.Options.GoPackage == nil { f.Options.GoPackage = new(string) } goPkg := f.Options.GetGoPackage() // if go_package has ";", for example go_package="/a/b/c;d", we will use "/a/b/c" as go_package if strings.Contains(goPkg, ";") { pkg := strings.Split(goPkg, ";") if len(pkg) == 2 { logs.Warnf("The go_package of the file(%s) is \"%s\", hz will use \"%s\" as the go_package.", f.GetName(), goPkg, pkg[0]) goPkg = pkg[0] } } if goPkg == "" { goPkg = f.GetPackage() } if opt, ok := pkgMap[f.GetName()]; ok { return opt } return goPkg } func switchBaseType(typ descriptorpb.FieldDescriptorProto_Type) *model.Type { switch typ { case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE, descriptorpb.FieldDescriptorProto_TYPE_GROUP: return nil case descriptorpb.FieldDescriptorProto_TYPE_INT64: return model.TypeInt64 case descriptorpb.FieldDescriptorProto_TYPE_INT32: return model.TypeInt32 case descriptorpb.FieldDescriptorProto_TYPE_UINT64: return model.TypeUint64 case descriptorpb.FieldDescriptorProto_TYPE_UINT32: return model.TypeUint32 case descriptorpb.FieldDescriptorProto_TYPE_FIXED64: return model.TypeUint64 case descriptorpb.FieldDescriptorProto_TYPE_FIXED32: return model.TypeUint32 case descriptorpb.FieldDescriptorProto_TYPE_BOOL: return model.TypeBool case descriptorpb.FieldDescriptorProto_TYPE_STRING: return model.TypeString case descriptorpb.FieldDescriptorProto_TYPE_BYTES: return model.TypeBinary case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: return model.TypeInt32 case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: return model.TypeInt64 case descriptorpb.FieldDescriptorProto_TYPE_SINT32: return model.TypeInt32 case descriptorpb.FieldDescriptorProto_TYPE_SINT64: return model.TypeInt64 case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: return model.TypeFloat64 case 
descriptorpb.FieldDescriptorProto_TYPE_FLOAT: return model.TypeFloat32 } return nil } func astToService(ast *descriptorpb.FileDescriptorProto, resolver *Resolver, cmdType string, gen *protogen.Plugin) ([]*generator.Service, error) { resolver.ExportReferred(true, false) ss := ast.GetService() out := make([]*generator.Service, 0, len(ss)) var merges model.Models for _, s := range ss { service := &generator.Service{ Name: s.GetName(), } service.BaseDomain = "" domainAnno := getCompatibleAnnotation(s.GetOptions(), api.E_BaseDomain, api.E_BaseDomainCompatible) if cmdType == meta.CmdClient { val, ok := domainAnno.(string) if ok && len(val) != 0 { service.BaseDomain = val } } ms := s.GetMethod() methods := make([]*generator.HttpMethod, 0, len(ms)) clientMethods := make([]*generator.ClientMethod, 0, len(ms)) servicePathAnno := checkFirstOption(api.E_ServicePath, s.GetOptions()) servicePath := "" if val, ok := servicePathAnno.(string); ok { servicePath = val } for _, m := range ms { rs := getAllOptions(HttpMethodOptions, m.GetOptions()) if len(rs) == 0 { continue } httpOpts := httpOptions{} for k, v := range rs { httpOpts = append(httpOpts, httpOption{ method: k, path: v.(string), }) } // turn the map into a slice and sort it to make sure getting the results in the same order every time sort.Sort(httpOpts) var handlerOutDir string genPath := getCompatibleAnnotation(m.GetOptions(), api.E_HandlerPath, api.E_HandlerPathCompatible) handlerOutDir, ok := genPath.(string) if !ok || len(handlerOutDir) == 0 { handlerOutDir = "" } if len(handlerOutDir) == 0 { handlerOutDir = servicePath } // protoGoInfo can get generated "Go Info" for proto file. 
// the type name may be different between "***.proto" and "***.pb.go" protoGoInfo, exist := gen.FilesByPath[ast.GetName()] if !exist { return nil, fmt.Errorf("file(%s) can not exist", ast.GetName()) } methodGoInfo, err := getMethod(protoGoInfo, s, m) if err != nil { return nil, err } inputGoType := methodGoInfo.Input outputGoType := methodGoInfo.Output reqName := m.GetInputType() sb, err := resolver.ResolveIdentifier(reqName) if err != nil { return nil, err } reqName = util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") + "." + inputGoType.GoIdent.GoName reqRawName := inputGoType.GoIdent.GoName reqPackage := util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") respName := m.GetOutputType() st, err := resolver.ResolveIdentifier(respName) if err != nil { return nil, err } respName = util.BaseName(st.Scope.GetOptions().GetGoPackage(), "") + "." + outputGoType.GoIdent.GoName respRawName := outputGoType.GoIdent.GoName respPackage := util.BaseName(sb.Scope.GetOptions().GetGoPackage(), "") var serializer string sl, sv := checkFirstOptions(SerializerOptions, m.GetOptions()) if sl != "" { serializer = sv.(string) } method := &generator.HttpMethod{ Name: util.CamelString(m.GetName()), HTTPMethod: httpOpts[0].method, Path: httpOpts[0].path, Serializer: serializer, OutputDir: handlerOutDir, GenHandler: true, } goOptMapAlias := make(map[string]string, 1) refs := resolver.ExportReferred(false, true) method.Models = make(map[string]*model.Model, len(refs)) for _, ref := range refs { if val, exist := method.Models[ref.Model.PackageName]; exist { if val.Package == ref.Model.Package { method.Models[ref.Model.PackageName] = ref.Model goOptMapAlias[ref.Model.Package] = ref.Model.PackageName } else { file := filepath.Base(ref.Model.FilePath) fileName := strings.Split(file, ".") newPkg := fileName[len(fileName)-2] + "_" + val.PackageName method.Models[newPkg] = ref.Model goOptMapAlias[ref.Model.Package] = newPkg } continue } method.Models[ref.Model.PackageName] = ref.Model 
goOptMapAlias[ref.Model.Package] = ref.Model.PackageName } merges = service.Models merges.MergeMap(method.Models) if goOptMapAlias[sb.Scope.GetOptions().GetGoPackage()] != "" { reqName = goOptMapAlias[sb.Scope.GetOptions().GetGoPackage()] + "." + inputGoType.GoIdent.GoName } if goOptMapAlias[sb.Scope.GetOptions().GetGoPackage()] != "" { respName = goOptMapAlias[st.Scope.GetOptions().GetGoPackage()] + "." + outputGoType.GoIdent.GoName } method.RequestTypeName = reqName method.RequestTypeRawName = reqRawName method.RequestTypePackage = reqPackage method.ReturnTypeName = respName method.ReturnTypeRawName = respRawName method.ReturnTypePackage = respPackage methods = append(methods, method) for idx, anno := range httpOpts { if idx == 0 { continue } tmp := *method tmp.HTTPMethod = anno.method tmp.Path = anno.path tmp.GenHandler = false methods = append(methods, &tmp) } if cmdType == meta.CmdClient { clientMethod := &generator.ClientMethod{} clientMethod.HttpMethod = method err := parseAnnotationToClient(clientMethod, gen, ast, s, m) if err != nil { return nil, err } clientMethods = append(clientMethods, clientMethod) } } service.ClientMethods = clientMethods service.Methods = methods service.Models = merges out = append(out, service) } return out, nil } func getCompatibleAnnotation(options proto.Message, anno, compatibleAnno *protoimpl.ExtensionInfo) interface{} { if proto.HasExtension(options, anno) { return checkFirstOption(anno, options) } else if proto.HasExtension(options, compatibleAnno) { return checkFirstOption(compatibleAnno, options) } return nil } func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen.Plugin, ast *descriptorpb.FileDescriptorProto, s *descriptorpb.ServiceDescriptorProto, m *descriptorpb.MethodDescriptorProto) error { file, exist := gen.FilesByPath[ast.GetName()] if !exist { return fmt.Errorf("file(%s) can not exist", ast.GetName()) } method, err := getMethod(file, s, m) if err != nil { return err } // pb input type 
must be message inputType := method.Input var ( hasBodyAnnotation bool hasFormAnnotation bool ) for _, f := range inputType.Fields { hasAnnotation := false isStringFieldType := false if f.Desc.Kind() == protoreflect.StringKind { isStringFieldType = true } if proto.HasExtension(f.Desc.Options(), api.E_Query) { hasAnnotation = true queryAnnos := proto.GetExtension(f.Desc.Options(), api.E_Query) val := checkSnakeName(queryAnnos.(string)) clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } if proto.HasExtension(f.Desc.Options(), api.E_Path) { hasAnnotation = true pathAnnos := proto.GetExtension(f.Desc.Options(), api.E_Path) val := pathAnnos.(string) if isStringFieldType { clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } else { clientMethod.PathParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName) } } if proto.HasExtension(f.Desc.Options(), api.E_Header) { hasAnnotation = true headerAnnos := proto.GetExtension(f.Desc.Options(), api.E_Header) val := headerAnnos.(string) if isStringFieldType { clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } else { clientMethod.HeaderParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName) } } if formAnnos := getCompatibleAnnotation(f.Desc.Options(), api.E_Form, api.E_FormCompatible); formAnnos != nil { hasAnnotation = true hasFormAnnotation = true val := checkSnakeName(formAnnos.(string)) if isStringFieldType { clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } else { clientMethod.FormValueCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", val, f.GoName) } } if proto.HasExtension(f.Desc.Options(), api.E_Body) { hasAnnotation = true hasBodyAnnotation = true } if fileAnnos := getCompatibleAnnotation(f.Desc.Options(), api.E_FileName, api.E_FileNameCompatible); fileAnnos != nil { hasAnnotation = true hasFormAnnotation = true val := fileAnnos.(string) 
clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } if proto.HasExtension(f.Desc.Options(), api.E_Cookie) { hasAnnotation = true // cookie do nothing } if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") { clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(string(f.Desc.Name())), f.GoName) } } clientMethod.BodyParamsCode = meta.SetBodyParam if hasBodyAnnotation && hasFormAnnotation { clientMethod.FormValueCode = "" clientMethod.FormFileCode = "" } if !hasBodyAnnotation && hasFormAnnotation { clientMethod.BodyParamsCode = "" } return nil } func getMethod(file *protogen.File, s *descriptorpb.ServiceDescriptorProto, m *descriptorpb.MethodDescriptorProto) (*protogen.Method, error) { for _, f := range file.Services { if f.Desc.Name() == protoreflect.Name(s.GetName()) { for _, method := range f.Methods { if string(method.Desc.Name()) == m.GetName() { return method, nil } } } } return nil, fmt.Errorf("can not find method: %s", m.GetName()) } //---------------------------------Model-------------------------------- func astToModel(ast *descriptorpb.FileDescriptorProto, rs *Resolver) (*model.Model, error) { main := rs.mainPkg.Model if main == nil { main = new(model.Model) } mainFileDes := rs.files.PbReflect[ast.GetName()] isProto3 := mainFileDes.IsProto3() // Enums ems := ast.GetEnumType() enums := make([]model.Enum, 0, len(ems)) for _, e := range ems { em := model.Enum{ Scope: main, Name: e.GetName(), GoType: "int32", } es := e.GetValue() vs := make([]model.Constant, 0, len(es)) for _, ee := range es { vs = append(vs, model.Constant{ Scope: main, Name: ee.GetName(), Type: model.TypeInt32, Value: model.IntExpression{Src: int(ee.GetNumber())}, }) } em.Values = vs enums = append(enums, em) } main.Enums = enums // Structs sts := ast.GetMessageType() structs := make([]model.Struct, 0, len(sts)*2) oneofs := make([]model.Oneof, 0, 1) for _, st := range sts { stMessage := 
mainFileDes.FindMessage(ast.GetPackage() + "." + st.GetName()) stLeadingComments := getMessageLeadingComments(stMessage) s := model.Struct{ Scope: main, Name: st.GetName(), Category: model.CategoryStruct, LeadingComments: stLeadingComments, } ns := st.GetNestedType() nestedMessageInfoMap := getNestedMessageInfoMap(stMessage) for _, nt := range ns { if IsMapEntry(nt) { continue } nestedMessageInfo := nestedMessageInfoMap[nt.GetName()] nestedMessageLeadingComment := getMessageLeadingComments(nestedMessageInfo) s := model.Struct{ Scope: main, Name: st.GetName() + "_" + nt.GetName(), Category: model.CategoryStruct, LeadingComments: nestedMessageLeadingComment, } fs := nt.GetField() ns := nt.GetNestedType() vs := make([]model.Field, 0, len(fs)) oneofMap := make(map[string]model.Field) oneofType, err := resolveOneof(nestedMessageInfo, oneofMap, rs, isProto3, s, ns) if err != nil { return nil, err } oneofs = append(oneofs, oneofType...) choiceSet := make(map[string]bool) for _, f := range fs { if field, exist := oneofMap[f.GetName()]; exist { if _, ex := choiceSet[field.Name]; !ex { choiceSet[field.Name] = true vs = append(vs, field) } continue } dv := f.GetDefaultValue() fieldLeadingComments, fieldTrailingComments := getFiledComments(f, nestedMessageInfo) t, err := rs.ResolveType(f, ns) if err != nil { return nil, err } field := model.Field{ Scope: &s, Name: util.CamelString(f.GetName()), Type: t, LeadingComments: fieldLeadingComments, TrailingComments: fieldTrailingComments, IsPointer: isPointer(f, isProto3), } if dv != "" { field.IsSetDefault = true field.DefaultValue, err = parseDefaultValue(f.GetType(), f.GetDefaultValue()) if err != nil { return nil, err } } err = injectTagsToModel(f, &field, true) if err != nil { return nil, err } vs = append(vs, field) } checkDuplicatedFileName(vs) s.Fields = vs structs = append(structs, s) } fs := st.GetField() vs := make([]model.Field, 0, len(fs)) oneofMap := make(map[string]model.Field) oneofType, err := resolveOneof(stMessage, 
oneofMap, rs, isProto3, s, ns) if err != nil { return nil, err } oneofs = append(oneofs, oneofType...) choiceSet := make(map[string]bool) for _, f := range fs { if field, exist := oneofMap[f.GetName()]; exist { if _, ex := choiceSet[field.Name]; !ex { choiceSet[field.Name] = true vs = append(vs, field) } continue } dv := f.GetDefaultValue() fieldLeadingComments, fieldTrailingComments := getFiledComments(f, stMessage) t, err := rs.ResolveType(f, ns) if err != nil { return nil, err } field := model.Field{ Scope: &s, Name: util.CamelString(f.GetName()), Type: t, LeadingComments: fieldLeadingComments, TrailingComments: fieldTrailingComments, IsPointer: isPointer(f, isProto3), } if dv != "" { field.IsSetDefault = true field.DefaultValue, err = parseDefaultValue(f.GetType(), f.GetDefaultValue()) if err != nil { return nil, err } } err = injectTagsToModel(f, &field, true) if err != nil { return nil, err } vs = append(vs, field) } checkDuplicatedFileName(vs) s.Fields = vs structs = append(structs, s) } main.Oneofs = oneofs main.Structs = structs // In case of only the service refers another model, therefore scanning service is necessary ss := ast.GetService() for _, s := range ss { ms := s.GetMethod() for _, m := range ms { _, err := rs.ResolveIdentifier(m.GetInputType()) if err != nil { return nil, err } _, err = rs.ResolveIdentifier(m.GetOutputType()) if err != nil { return nil, err } } } return main, nil } // getMessageLeadingComments can get struct LeadingComment func getMessageLeadingComments(stMessage *desc.MessageDescriptor) string { if stMessage == nil { return "" } stComments := stMessage.GetSourceInfo().GetLeadingComments() stComments = formatComments(stComments) return stComments } // getFiledComments can get field LeadingComments and field TailingComments for field func getFiledComments(f *descriptorpb.FieldDescriptorProto, stMessage *desc.MessageDescriptor) (string, string) { if stMessage == nil { return "", "" } fieldNum := f.GetNumber() field := 
stMessage.FindFieldByNumber(fieldNum) fieldInfo := field.GetSourceInfo() fieldLeadingComments := fieldInfo.GetLeadingComments() fieldTailingComments := fieldInfo.GetTrailingComments() fieldLeadingComments = formatComments(fieldLeadingComments) fieldTailingComments = formatComments(fieldTailingComments) return fieldLeadingComments, fieldTailingComments } // formatComments can format the comments for beauty func formatComments(comments string) string { if len(comments) == 0 { return "" } comments = util.TrimLastChar(comments) comments = util.AddSlashForComments(comments) return comments } // getNestedMessageInfoMap can get all nested struct func getNestedMessageInfoMap(stMessage *desc.MessageDescriptor) map[string]*desc.MessageDescriptor { nestedMessage := stMessage.GetNestedMessageTypes() nestedMessageInfoMap := make(map[string]*desc.MessageDescriptor, len(nestedMessage)) for _, nestedMsg := range nestedMessage { nestedMsgName := nestedMsg.GetName() nestedMessageInfoMap[nestedMsgName] = nestedMsg } return nestedMessageInfoMap } func parseDefaultValue(typ descriptorpb.FieldDescriptorProto_Type, val string) (model.Literal, error) { switch typ { case descriptorpb.FieldDescriptorProto_TYPE_BYTES, descriptorpb.FieldDescriptorProto_TYPE_STRING: return model.StringExpression{Src: val}, nil case descriptorpb.FieldDescriptorProto_TYPE_BOOL: return model.BoolExpression{Src: val == "true"}, nil case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE, descriptorpb.FieldDescriptorProto_TYPE_FLOAT, descriptorpb.FieldDescriptorProto_TYPE_INT64, descriptorpb.FieldDescriptorProto_TYPE_UINT64, descriptorpb.FieldDescriptorProto_TYPE_INT32, descriptorpb.FieldDescriptorProto_TYPE_FIXED64, descriptorpb.FieldDescriptorProto_TYPE_FIXED32, descriptorpb.FieldDescriptorProto_TYPE_UINT32, descriptorpb.FieldDescriptorProto_TYPE_ENUM, descriptorpb.FieldDescriptorProto_TYPE_SFIXED32, descriptorpb.FieldDescriptorProto_TYPE_SFIXED64, descriptorpb.FieldDescriptorProto_TYPE_SINT32, 
descriptorpb.FieldDescriptorProto_TYPE_SINT64: return model.NumberExpression{Src: val}, nil default: return nil, fmt.Errorf("unsupported type %s", typ.String()) } } func isPointer(f *descriptorpb.FieldDescriptorProto, isProto3 bool) bool { if f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE || f.GetType() == descriptorpb.FieldDescriptorProto_TYPE_BYTES { return false } if !isProto3 { if f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REPEATED { return false } return true } switch f.GetLabel() { case descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL: if !f.GetProto3Optional() { return false } return true default: return false } } func resolveOneof(stMessage *desc.MessageDescriptor, oneofMap map[string]model.Field, rs *Resolver, isProto3 bool, s model.Struct, ns []*descriptorpb.DescriptorProto) ([]model.Oneof, error) { oneofs := make([]model.Oneof, 0, 1) if len(stMessage.GetOneOfs()) != 0 { for _, oneof := range stMessage.GetOneOfs() { if isProto3 { if oneof.IsSynthetic() { continue } } oneofName := oneof.GetName() messageName := s.Name typeName := "is" + messageName + "_" + oneofName field := model.Field{ Scope: &s, Name: util.CamelString(oneofName), Type: model.NewOneofType(typeName), IsPointer: false, } oneofComment := oneof.GetSourceInfo().GetLeadingComments() oneofComment = formatComments(oneofComment) var oneofLeadingComments string if oneofComment == "" { oneofLeadingComments = fmt.Sprintf(" Types that are assignable to %s:\n", oneofName) } else { oneofLeadingComments = fmt.Sprintf("%s\n//\n// Types that are assignable to %s:\n", oneofComment, oneofName) } for idx, ch := range oneof.GetChoices() { if idx == len(oneof.GetChoices())-1 { oneofLeadingComments = oneofLeadingComments + fmt.Sprintf("// *%s_%s", messageName, ch.GetName()) } else { oneofLeadingComments = oneofLeadingComments + fmt.Sprintf("// *%s_%s\n", messageName, ch.GetName()) } } field.LeadingComments = oneofLeadingComments choices := make([]model.Choice, 0, 
len(oneof.GetChoices())) for _, ch := range oneof.GetChoices() { t, err := rs.ResolveType(ch.AsFieldDescriptorProto(), ns) if err != nil { return nil, err } choice := model.Choice{ MessageName: messageName, ChoiceName: ch.GetName(), Type: t, } choices = append(choices, choice) oneofMap[ch.GetName()] = field } oneofType := model.Oneof{ MessageName: messageName, OneofName: oneofName, InterfaceName: typeName, Choices: choices, } oneofs = append(oneofs, oneofType) } } return oneofs, nil } func getNewFieldName(fieldName string, fieldNameSet map[string]bool) string { if _, ex := fieldNameSet[fieldName]; ex { fieldName = fieldName + "_" return getNewFieldName(fieldName, fieldNameSet) } return fieldName } func checkDuplicatedFileName(vs []model.Field) { fieldNameSet := make(map[string]bool) for i := 0; i < len(vs); i++ { if _, ex := fieldNameSet[vs[i].Name]; ex { newName := getNewFieldName(vs[i].Name, fieldNameSet) fieldNameSet[newName] = true vs[i].Name = newName } else { fieldNameSet[vs[i].Name] = true } } } ================================================ FILE: cmd/hz/protobuf/plugin.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * Copyright (c) 2018 The Go Authors. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are * met: * * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above * copyright notice, this list of conditions and the following disclaimer * in the documentation and/or other materials provided with the * distribution. * * Neither the name of Google Inc. nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package protobuf import ( "encoding/json" "fmt" "io/ioutil" "os" "path" "path/filepath" "strings" "github.com/cloudwego/hertz/cmd/hz/config" "github.com/cloudwego/hertz/cmd/hz/generator" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/hertz/cmd/hz/util/logs" gengo "google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/runtime/protoimpl" "google.golang.org/protobuf/types/descriptorpb" "google.golang.org/protobuf/types/pluginpb" ) type Plugin struct { *protogen.Plugin Package string Recursive bool OutDir string ModelDir string UseDir string IdlClientDir string RmTags RemoveTags PkgMap map[string]string logger *logs.StdLogger } type RemoveTags []string func (rm *RemoveTags) Exist(tag string) bool { for _, rmTag := range *rm { if rmTag == tag { return true } } return false } var debugPlugin = os.Getenv("HERTZ_DEBUG_PLUGIN") != "" func (plugin *Plugin) Run() int { plugin.setLogger() args := &config.Argument{} defer func() { if args == nil { return } if args.Verbose { verboseLog := plugin.recvVerboseLogger() if len(verboseLog) != 0 { fmt.Fprintf(os.Stderr, verboseLog) } } else { warning := plugin.recvWarningLogger() if len(warning) != 0 { fmt.Fprintf(os.Stderr, warning) } } }() // read protoc request in, err := ioutil.ReadAll(os.Stdin) if err != nil { logs.Errorf("read request failed: %s\n", err.Error()) return meta.PluginError } req := &pluginpb.CodeGeneratorRequest{} err = proto.Unmarshal(in, req) if err != nil { logs.Errorf("unmarshal request failed: %s\n", err.Error()) return meta.PluginError } args, err = plugin.parseArgs(*req.Parameter) if err != nil { logs.Errorf("parse args failed: %s\n", err.Error()) return meta.PluginError } if debugPlugin { os.WriteFile("./req.pb", in, 0644) js, err := json.Marshal(args) if err != nil { 
logs.Errorf("marshal request failed: %s\n", err.Error()) return meta.PluginError } os.WriteFile("./args.json", js, 0644) } CheckTagOption(args) // generate err = plugin.Handle(req, args) if err != nil { logs.Errorf("generate failed: %s\n", err.Error()) return meta.PluginError } return 0 } func (plugin *Plugin) setLogger() { plugin.logger = logs.NewStdLogger(logs.LevelInfo) plugin.logger.Defer = true plugin.logger.ErrOnly = true logs.SetLogger(plugin.logger) } func (plugin *Plugin) recvWarningLogger() string { warns := plugin.logger.Warn() plugin.logger.Flush() logs.SetLogger(logs.NewStdLogger(logs.LevelInfo)) return warns } func (plugin *Plugin) recvVerboseLogger() string { info := plugin.logger.Out() warns := plugin.logger.Warn() verboseLog := string(info) + warns plugin.logger.Flush() logs.SetLogger(logs.NewStdLogger(logs.LevelInfo)) return verboseLog } func (plugin *Plugin) parseArgs(param string) (*config.Argument, error) { args := new(config.Argument) params := strings.Split(param, ",") err := args.Unpack(params) if err != nil { return nil, err } plugin.Package, err = args.GetGoPackage() if err != nil { return nil, err } plugin.Recursive = !args.NoRecurse plugin.ModelDir, err = args.GetModelDir() if err != nil { return nil, err } plugin.OutDir = args.OutDir plugin.PkgMap = args.OptPkgMap plugin.UseDir = args.Use return args, nil } func (plugin *Plugin) Response(resp *pluginpb.CodeGeneratorResponse) error { out, err := proto.Marshal(resp) if err != nil { return fmt.Errorf("marshal response failed: %s", err.Error()) } _, err = os.Stdout.Write(out) if err != nil { return fmt.Errorf("write response failed: %s", err.Error()) } return nil } func (plugin *Plugin) Handle(req *pluginpb.CodeGeneratorRequest, args *config.Argument) error { plugin.fixGoPackage(req, plugin.PkgMap, args.TrimGoPackage) // new plugin opts := protogen.Options{} gen, err := opts.New(req) plugin.Plugin = gen plugin.RmTags = args.RmTags if err != nil { return fmt.Errorf("new protoc plugin failed: 
%s", err.Error()) } // plugin start working err = plugin.GenerateFiles(gen) if err != nil { // Error within the plugin will be responded by the plugin. // But if the plugin does not response correctly, the error is returned to the upper level. err := fmt.Errorf("generate model file failed: %s", err.Error()) gen.Error(err) resp := gen.Response() err2 := plugin.Response(resp) if err2 != nil { return err } return nil } if args.CmdType == meta.CmdModel { resp := gen.Response() // plugin stop working err = plugin.Response(resp) if err != nil { return fmt.Errorf("write response failed: %s", err.Error()) } return nil } files := gen.Request.ProtoFile maps := make(map[string]*descriptorpb.FileDescriptorProto, len(files)) for _, file := range files { maps[file.GetName()] = file } main := maps[gen.Request.FileToGenerate[len(gen.Request.FileToGenerate)-1]] deps := make(map[string]*descriptorpb.FileDescriptorProto, len(main.GetDependency())) for _, dep := range main.GetDependency() { if f, ok := maps[dep]; !ok { err := fmt.Errorf("dependency file not found: %s", dep) gen.Error(err) resp := gen.Response() err2 := plugin.Response(resp) if err2 != nil { return err } return nil } else { deps[dep] = f } } pkgFiles, err := plugin.genHttpPackage(main, deps, args) if err != nil { err := fmt.Errorf("generate package files failed: %s", err.Error()) gen.Error(err) resp := gen.Response() err2 := plugin.Response(resp) if err2 != nil { return err } return nil } // construct plugin response resp := gen.Response() // all files that need to be generated are returned to protoc for _, pkgFile := range pkgFiles { filePath := pkgFile.Path content := pkgFile.Content renderFile := &pluginpb.CodeGeneratorResponse_File{ Name: &filePath, Content: &content, } resp.File = append(resp.File, renderFile) } // plugin stop working err = plugin.Response(resp) if err != nil { return fmt.Errorf("write response failed: %s", err.Error()) } return nil } // fixGoPackage will update go_package to store all the model 
files in ${model_dir} func (plugin *Plugin) fixGoPackage(req *pluginpb.CodeGeneratorRequest, pkgMap map[string]string, trimGoPackage string) { gopkg := plugin.Package for _, f := range req.ProtoFile { if strings.HasPrefix(f.GetPackage(), "google.protobuf") { continue } if len(trimGoPackage) != 0 && strings.HasPrefix(f.GetOptions().GetGoPackage(), trimGoPackage) { *f.Options.GoPackage = strings.TrimPrefix(*f.Options.GoPackage, trimGoPackage) } opt := getGoPackage(f, pkgMap) if !strings.Contains(opt, gopkg) { if strings.HasPrefix(opt, "/") { opt = gopkg + opt } else { opt = gopkg + "/" + opt } } impt := plugin.fixModelPathAndPackage(opt) *f.Options.GoPackage = impt } } // fixModelPathAndPackage will modify the go_package to adapt the go_package of the hz, // for example adding the go module and model dir. func (plugin *Plugin) fixModelPathAndPackage(pkg string) (impt string) { if strings.HasPrefix(pkg, plugin.Package) { // NOTE: no idea why we need util.ImportToSanitizedPath // it seems like we only need to convert the last part of a package from /a.b.c -> /a_b_c // "cloudwego/hertz/biz/model/a/b/c" -> "/biz/model/a/b/c" impt = util.ImportToSanitizedPath(pkg[len(plugin.Package):]) impt = filepath.ToSlash(impt) // we always use package path instead of filepath in this func // "/biz/model/a/b/c" -> "biz/model/a/b/c" impt = strings.TrimPrefix(impt, "/") } if plugin.ModelDir != "" && plugin.ModelDir != "." 
{ modelPkg := filepath.ToSlash(plugin.ModelDir) if !strings.HasPrefix(impt, modelPkg) { impt = path.Join(modelPkg, impt) // make sure all models under plugin.ModelDir } } impt = path.Join(plugin.Package, impt) return } func (plugin *Plugin) GenerateFiles(pluginPb *protogen.Plugin) error { idl := pluginPb.Request.FileToGenerate[len(pluginPb.Request.FileToGenerate)-1] pluginPb.SupportedFeatures = gengo.SupportedFeatures for _, f := range pluginPb.Files { if f.Proto.GetName() == idl { err := plugin.GenerateFile(pluginPb, f) if err != nil { return err } impt := string(f.GoImportPath) if strings.HasPrefix(impt, plugin.Package) { impt = impt[len(plugin.Package):] } plugin.IdlClientDir = impt } else if plugin.Recursive { if strings.HasPrefix(f.Proto.GetPackage(), "google.protobuf") { continue } err := plugin.GenerateFile(pluginPb, f) if err != nil { return err } } } return nil } func (plugin *Plugin) GenerateFile(gen *protogen.Plugin, f *protogen.File) error { impt := string(f.GoImportPath) if strings.HasPrefix(impt, plugin.Package) { impt = impt[len(plugin.Package):] } f.GeneratedFilenamePrefix = filepath.Join(filepath.FromSlash(impt), util.BaseName(f.Proto.GetName(), ".proto")) f.Generate = true // if use third-party model, no model code is generated within the project if len(plugin.UseDir) != 0 { return nil } file, err := generateFile(gen, f, plugin.RmTags) if err != nil || file == nil { return fmt.Errorf("generate file %s failed: %s", f.Proto.GetName(), err.Error()) } return nil } // generateFile generates the contents of a .pb.go file. 
func generateFile(gen *protogen.Plugin, file *protogen.File, rmTags RemoveTags) (*protogen.GeneratedFile, error) { filename := file.GeneratedFilenamePrefix + ".pb.go" g := gen.NewGeneratedFile(filename, file.GoImportPath) f := newFileInfo(file) genStandaloneComments(g, f, int32(FileDescriptorProto_Syntax_field_number)) genGeneratedHeader(gen, g, f) genStandaloneComments(g, f, int32(FileDescriptorProto_Package_field_number)) packageDoc := genPackageKnownComment(f) g.P(packageDoc, "package ", f.GoPackageName) g.P() // Emit a static check that enforces a minimum version of the proto package. if gengo.GenerateVersionMarkers { g.P("const (") g.P("// Verify that this generated code is sufficiently up-to-date.") g.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimpl.GenVersion, " - ", protoimplPackage.Ident("MinVersion"), ")") g.P("// Verify that runtime/protoimpl is sufficiently up-to-date.") g.P("_ = ", protoimplPackage.Ident("EnforceVersion"), "(", protoimplPackage.Ident("MaxVersion"), " - ", protoimpl.GenVersion, ")") g.P(")") g.P() } for i, imps := 0, f.Desc.Imports(); i < imps.Len(); i++ { genImport(gen, g, f, imps.Get(i)) } for _, enum := range f.allEnums { genEnum(g, f, enum) } var err error for _, message := range f.allMessages { err = genMessage(g, f, message, rmTags) if err != nil { return nil, err } } genExtensions(g, f) genReflectFileDescriptor(gen, g, f) return g, nil } func genMessage(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, rmTags RemoveTags) error { if m.Desc.IsMapEntry() { return nil } // Message type declaration. 
g.Annotate(m.GoIdent.GoName, m.Location) leadingComments := appendDeprecationSuffix(m.Comments.Leading, m.Desc.Options().(*descriptorpb.MessageOptions).GetDeprecated()) g.P(leadingComments, "type ", m.GoIdent, " struct {") err := genMessageFields(g, f, m, rmTags) if err != nil { return err } g.P("}") g.P() genMessageKnownFunctions(g, f, m) genMessageDefaultDecls(g, f, m) genMessageMethods(g, f, m) genMessageOneofWrapperTypes(g, f, m) return nil } func genMessageFields(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, rmTags RemoveTags) error { sf := f.allMessageFieldsByPtr[m] genMessageInternalFields(g, f, m, sf) var err error for _, field := range m.Fields { err = genMessageField(g, f, m, field, sf, rmTags) if err != nil { return err } } return nil } func genMessageField(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo, field *protogen.Field, sf *structFields, rmTags RemoveTags) error { if oneof := field.Oneof; oneof != nil && !oneof.Desc.IsSynthetic() { // It would be a bit simpler to iterate over the oneofs below, // but generating the field here keeps the contents of the Go // struct in the same order as the contents of the source // .proto file. if oneof.Fields[0] != field { return nil // only generate for first appearance } tags := structTags{ {"protobuf_oneof", string(oneof.Desc.Name())}, } if m.isTracked { tags = append(tags, gotrackTags...) 
} g.Annotate(m.GoIdent.GoName+"."+oneof.GoName, oneof.Location) leadingComments := oneof.Comments.Leading if leadingComments != "" { leadingComments += "\n" } ss := []string{fmt.Sprintf(" Types that are assignable to %s:\n", oneof.GoName)} for _, field := range oneof.Fields { ss = append(ss, "\t*"+field.GoIdent.GoName+"\n") } leadingComments += protogen.Comments(strings.Join(ss, "")) g.P(leadingComments, oneof.GoName, " ", oneofInterfaceName(oneof), tags) sf.append(oneof.GoName) return nil } goType, pointer := fieldGoType(g, f, field) if pointer { goType = "*" + goType } tags := structTags{ {"protobuf", fieldProtobufTagValue(field)}, //{"json", fieldJSONTagValue(field)}, } if field.Desc.IsMap() { key := field.Message.Fields[0] val := field.Message.Fields[1] tags = append(tags, structTags{ {"protobuf_key", fieldProtobufTagValue(key)}, {"protobuf_val", fieldProtobufTagValue(val)}, }...) } err := injectTagsToStructTags(field.Desc, &tags, true, rmTags) if err != nil { return err } if m.isTracked { tags = append(tags, gotrackTags...) 
} name := field.GoName if field.Desc.IsWeak() { name = WeakFieldPrefix_goname + name } g.Annotate(m.GoIdent.GoName+"."+name, field.Location) leadingComments := appendDeprecationSuffix(field.Comments.Leading, field.Desc.Options().(*descriptorpb.FieldOptions).GetDeprecated()) g.P(leadingComments, name, " ", goType, tags, trailingComment(field.Comments.Trailing)) sf.append(field.GoName) return nil } func (plugin *Plugin) getIdlInfo(ast *descriptorpb.FileDescriptorProto, deps map[string]*descriptorpb.FileDescriptorProto, args *config.Argument) (*generator.HttpPackage, error) { if ast == nil { return nil, fmt.Errorf("ast is nil") } pkg := getGoPackage(ast, map[string]string{}) main := &model.Model{ FilePath: ast.GetName(), Package: pkg, PackageName: util.BaseName(pkg, ""), } fileInfo := FileInfos{ Official: deps, PbReflect: nil, } rs, err := NewResolver(ast, fileInfo, main, map[string]string{}) if err != nil { return nil, fmt.Errorf("new protobuf resolver failed, err:%v", err) } err = rs.LoadAll(ast) if err != nil { return nil, err } services, err := astToService(ast, rs, args.CmdType, plugin.Plugin) if err != nil { return nil, err } var models model.Models for _, s := range services { models.MergeArray(s.Models) } return &generator.HttpPackage{ Services: services, IdlName: ast.GetName(), Package: util.BaseName(pkg, ""), Models: models, }, nil } func (plugin *Plugin) genHttpPackage(ast *descriptorpb.FileDescriptorProto, deps map[string]*descriptorpb.FileDescriptorProto, args *config.Argument) ([]generator.File, error) { options := CheckTagOption(args) idl, err := plugin.getIdlInfo(ast, deps, args) if err != nil { return nil, err } customPackageTemplate := args.CustomizePackage pkg, err := args.GetGoPackage() if err != nil { return nil, err } handlerDir, err := args.GetHandlerDir() if err != nil { return nil, err } routerDir, err := args.GetRouterDir() if err != nil { return nil, err } modelDir, err := args.GetModelDir() if err != nil { return nil, err } clientDir, err 
:= args.GetClientDir() if err != nil { return nil, err } sg := generator.HttpPackageGenerator{ ConfigPath: customPackageTemplate, HandlerDir: handlerDir, RouterDir: routerDir, ModelDir: modelDir, UseDir: args.Use, ClientDir: clientDir, TemplateGenerator: generator.TemplateGenerator{ OutputDir: args.OutDir, Excludes: args.Excludes, }, ProjPackage: pkg, Options: options, HandlerByMethod: args.HandlerByMethod, CmdType: args.CmdType, IdlClientDir: plugin.IdlClientDir, ForceClientDir: args.ForceClientDir, BaseDomain: args.BaseDomain, QueryEnumAsInt: args.QueryEnumAsInt, SnakeStyleMiddleware: args.SnakeStyleMiddleware, SortRouter: args.SortRouter, ForceUpdateClient: args.ForceUpdateClient, } if args.ModelBackend != "" { sg.Backend = meta.Backend(args.ModelBackend) } generator.SetDefaultTemplateConfig() err = sg.Generate(idl) if err != nil { return nil, fmt.Errorf("generate http package error: %v", err) } files, err := sg.GetFormatAndExcludedFiles() if err != nil { return nil, fmt.Errorf("persist http package error: %v", err) } return files, nil } ================================================ FILE: cmd/hz/protobuf/plugin_stubs.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * Copyright (c) 2018 The Go Authors. All rights reserved. 
* * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are * met: * * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above * copyright notice, this list of conditions and the following disclaimer * in the documentation and/or other materials provided with the * distribution. * * Neither the name of Google Inc. nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package protobuf import ( "fmt" "strconv" "strings" "unicode" "unicode/utf8" _ "unsafe" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/reflect/protoreflect" ) // Field numbers for google.protobuf.FileDescriptorProto. 
const ( FileDescriptorProto_Name_field_number protoreflect.FieldNumber = 1 FileDescriptorProto_Package_field_number protoreflect.FieldNumber = 2 FileDescriptorProto_Dependency_field_number protoreflect.FieldNumber = 3 FileDescriptorProto_PublicDependency_field_number protoreflect.FieldNumber = 10 FileDescriptorProto_WeakDependency_field_number protoreflect.FieldNumber = 11 FileDescriptorProto_MessageType_field_number protoreflect.FieldNumber = 4 FileDescriptorProto_EnumType_field_number protoreflect.FieldNumber = 5 FileDescriptorProto_Service_field_number protoreflect.FieldNumber = 6 FileDescriptorProto_Extension_field_number protoreflect.FieldNumber = 7 FileDescriptorProto_Options_field_number protoreflect.FieldNumber = 8 FileDescriptorProto_SourceCodeInfo_field_number protoreflect.FieldNumber = 9 FileDescriptorProto_Syntax_field_number protoreflect.FieldNumber = 12 ) const WeakFieldPrefix_goname = "XXX_weak_" type fileInfo struct { *protogen.File allEnums []*enumInfo allMessages []*messageInfo allExtensions []*extensionInfo allEnumsByPtr map[*enumInfo]int // value is index into allEnums allMessagesByPtr map[*messageInfo]int // value is index into allMessages allMessageFieldsByPtr map[*messageInfo]*structFields // needRawDesc specifies whether the generator should emit logic to provide // the legacy raw descriptor in GZIP'd form. // This is updated by enum and message generation logic as necessary, // and checked at the end of file generation. 
needRawDesc bool } type enumInfo struct { *protogen.Enum genJSONMethod bool genRawDescMethod bool } type messageInfo struct { *protogen.Message genRawDescMethod bool genExtRangeMethod bool isTracked bool hasWeak bool } type extensionInfo struct { *protogen.Extension } type structFields struct { count int unexported map[int]string } func (sf *structFields) append(name string) { if r, _ := utf8.DecodeRuneInString(name); !unicode.IsUpper(r) { if sf.unexported == nil { sf.unexported = make(map[int]string) } sf.unexported[sf.count] = name } sf.count++ } type structTags [][2]string func (tags structTags) String() string { if len(tags) == 0 { return "" } var ss []string for _, tag := range tags { // NOTE: When quoting the value, we need to make sure the backtick // character does not appear. Convert all cases to the escaped hex form. key := tag[0] val := strings.Replace(strconv.Quote(tag[1]), "`", `\x60`, -1) ss = append(ss, fmt.Sprintf("%s:%s", key, val)) } return "`" + strings.Join(ss, " ") + "`" } type goImportPath interface { String() string Ident(string) protogen.GoIdent } type trailingComment protogen.Comments func (c trailingComment) String() string { s := strings.TrimSuffix(protogen.Comments(c).String(), "\n") if strings.Contains(s, "\n") { // We don't support multi-lined trailing comments as it is unclear // how to best render them in the generated code. 
return "" } return s } //go:linkname gotrackTags google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.gotrackTags var gotrackTags structTags var ( protoPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/proto") protoifacePackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/runtime/protoiface") protoimplPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/runtime/protoimpl") protojsonPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/encoding/protojson") protoreflectPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/reflect/protoreflect") protoregistryPackage goImportPath = protogen.GoImportPath("google.golang.org/protobuf/reflect/protoregistry") ) //go:linkname newFileInfo google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.newFileInfo func newFileInfo(file *protogen.File) *fileInfo //go:linkname genPackageKnownComment google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genPackageKnownComment func genPackageKnownComment(f *fileInfo) protogen.Comments //go:linkname genStandaloneComments google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genStandaloneComments func genStandaloneComments(g *protogen.GeneratedFile, f *fileInfo, n int32) //go:linkname genGeneratedHeader google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genGeneratedHeader func genGeneratedHeader(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo) //go:linkname genImport google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genImport func genImport(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo, imp protoreflect.FileImport) //go:linkname genEnum google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genEnum func genEnum(g *protogen.GeneratedFile, f *fileInfo, e *enumInfo) //go:linkname genMessageInternalFields google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genMessageInternalFields func genMessageInternalFields(g 
*protogen.GeneratedFile, f *fileInfo, m *messageInfo, sf *structFields) //go:linkname genExtensions google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genExtensions func genExtensions(g *protogen.GeneratedFile, f *fileInfo) //go:linkname genReflectFileDescriptor google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genReflectFileDescriptor func genReflectFileDescriptor(gen *protogen.Plugin, g *protogen.GeneratedFile, f *fileInfo) //go:linkname appendDeprecationSuffix google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.appendDeprecationSuffix func appendDeprecationSuffix(prefix protogen.Comments, deprecated bool) protogen.Comments //go:linkname genMessageDefaultDecls google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genMessageDefaultDecls func genMessageDefaultDecls(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) //go:linkname genMessageKnownFunctions google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genMessageKnownFunctions func genMessageKnownFunctions(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) //go:linkname genMessageMethods google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genMessageMethods func genMessageMethods(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) //go:linkname genMessageOneofWrapperTypes google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.genMessageOneofWrapperTypes func genMessageOneofWrapperTypes(g *protogen.GeneratedFile, f *fileInfo, m *messageInfo) //go:linkname oneofInterfaceName google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.oneofInterfaceName func oneofInterfaceName(oneof *protogen.Oneof) string //go:linkname fieldGoType google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.fieldGoType func fieldGoType(g *protogen.GeneratedFile, f *fileInfo, field *protogen.Field) (goType string, pointer bool) //go:linkname fieldProtobufTagValue google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.fieldProtobufTagValue func 
fieldProtobufTagValue(field *protogen.Field) string //go:linkname fieldJSONTagValue google.golang.org/protobuf/cmd/protoc-gen-go/internal_gengo.fieldJSONTagValue func fieldJSONTagValue(field *protogen.Field) string ================================================ FILE: cmd/hz/protobuf/plugin_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package protobuf import ( "io/ioutil" "strings" "testing" "github.com/cloudwego/hertz/cmd/hz/meta" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/pluginpb" ) func TestPlugin_Handle(t *testing.T) { return // where is "../testdata/request_protoc.out" ?????? 
in, err := ioutil.ReadFile("../testdata/request_protoc.out") if err != nil { t.Fatal(err) } req := &pluginpb.CodeGeneratorRequest{} err = proto.Unmarshal(in, req) if err != nil { t.Fatalf("unmarshal stdin request error: %v", err) } // prepare args plu := &Plugin{} plu.setLogger() args, _ := plu.parseArgs(*req.Parameter) plu.Handle(req, args) plu.recvWarningLogger() } func TestFixModelPathAndPackage(t *testing.T) { plu := &Plugin{} plu.Package = "cloudwego/hertz" plu.ModelDir = meta.ModelDir // default model dir ret1 := [][]string{ {"a/b/c", "cloudwego/hertz/biz/model/a/b/c"}, {"biz/model/a/b/c", "cloudwego/hertz/biz/model/a/b/c"}, {"cloudwego/hertz/a/b/c", "cloudwego/hertz/biz/model/a/b/c"}, {"cloudwego/hertz/biz/model/a/b/c", "cloudwego/hertz/biz/model/a/b/c"}, } for _, r := range ret1 { tmp := r[0] if !strings.Contains(tmp, plu.Package) { if strings.HasPrefix(tmp, "/") { tmp = plu.Package + tmp } else { tmp = plu.Package + "/" + tmp } } result := plu.fixModelPathAndPackage(tmp) if result != r[1] { t.Fatalf("want go package: %s, but get: %s", r[1], result) } } plu.ModelDir = "model_test" // customized model dir ret2 := [][]string{ {"a/b/c", "cloudwego/hertz/model_test/a/b/c"}, {"model_test/a/b/c", "cloudwego/hertz/model_test/a/b/c"}, {"cloudwego/hertz/a/b/c", "cloudwego/hertz/model_test/a/b/c"}, {"cloudwego/hertz/model_test/a/b/c", "cloudwego/hertz/model_test/a/b/c"}, } for _, r := range ret2 { tmp := r[0] if !strings.Contains(tmp, plu.Package) { if strings.HasPrefix(tmp, "/") { tmp = plu.Package + tmp } else { tmp = plu.Package + "/" + tmp } } result := plu.fixModelPathAndPackage(tmp) if result != r[1] { t.Fatalf("want go package: %s, but get: %s", r[1], result) } } } ================================================ FILE: cmd/hz/protobuf/resolver.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the 
License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package protobuf import ( "fmt" "strings" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/jhump/protoreflect/desc" "google.golang.org/protobuf/types/descriptorpb" ) type Symbol struct { Space string Name string IsValue bool Type *model.Type Value interface{} Scope *descriptorpb.FileDescriptorProto } type NameSpace map[string]*Symbol var ( ConstTrue = Symbol{ IsValue: true, Type: model.TypeBool, Value: true, Scope: &BaseProto, } ConstFalse = Symbol{ IsValue: true, Type: model.TypeBool, Value: false, Scope: &BaseProto, } ConstEmptyString = Symbol{ IsValue: true, Type: model.TypeString, Value: "", Scope: &BaseProto, } ) type PackageReference struct { IncludeBase string IncludePath string Model *model.Model Ast *descriptorpb.FileDescriptorProto Referred bool } func getReferPkgMap(pkgMap map[string]string, incs []*descriptorpb.FileDescriptorProto, mainModel *model.Model) (map[string]*PackageReference, error) { var err error out := make(map[string]*PackageReference, len(pkgMap)) pkgAliasMap := make(map[string]string, len(incs)) // bugfix: add main package to avoid namespace conflict mainPkg := mainModel.Package mainPkgName := mainModel.PackageName mainPkgName, err = util.GetPackageUniqueName(mainPkgName) if err != nil { return nil, err } pkgAliasMap[mainPkg] = mainPkgName for _, inc := range incs { pkg := getGoPackage(inc, pkgMap) path := inc.GetName() base := util.BaseName(path, ".proto") fileName := inc.GetName() pkgName := util.BaseName(pkg, "") if pn, exist := pkgAliasMap[pkg]; exist { pkgName = pn } 
else { pkgName, err = util.GetPackageUniqueName(pkgName) pkgAliasMap[pkg] = pkgName if err != nil { return nil, fmt.Errorf("get package unique name failed, err: %v", err) } } out[fileName] = &PackageReference{base, path, &model.Model{ FilePath: path, Package: pkg, PackageName: pkgName, }, inc, false} } return out, nil } type FileInfos struct { Official map[string]*descriptorpb.FileDescriptorProto PbReflect map[string]*desc.FileDescriptor } type Resolver struct { // idl symbols rootName string root NameSpace deps map[string]NameSpace // exported models mainPkg PackageReference refPkgs map[string]*PackageReference files FileInfos } func updateFiles(fileName string, files FileInfos) (FileInfos, error) { file, _ := files.PbReflect[fileName] if file == nil { return FileInfos{}, fmt.Errorf("%s not found", fileName) } fileDep := file.GetDependencies() maps := make(map[string]*descriptorpb.FileDescriptorProto, len(fileDep)+1) sourceInfoMap := make(map[string]*desc.FileDescriptor, len(fileDep)+1) for _, dep := range fileDep { ast := dep.AsFileDescriptorProto() maps[dep.GetName()] = ast sourceInfoMap[dep.GetName()] = dep } ast := file.AsFileDescriptorProto() maps[file.GetName()] = ast sourceInfoMap[file.GetName()] = file newFileInfo := FileInfos{ Official: maps, PbReflect: sourceInfoMap, } return newFileInfo, nil } func NewResolver(ast *descriptorpb.FileDescriptorProto, files FileInfos, model *model.Model, pkgMap map[string]string) (*Resolver, error) { file := ast.GetName() deps := ast.GetDependency() var err error if files.PbReflect != nil { files, err = updateFiles(file, files) if err != nil { return nil, err } } incs := make([]*descriptorpb.FileDescriptorProto, 0, len(deps)) for _, dep := range deps { if v, ok := files.Official[dep]; ok { incs = append(incs, v) } else { return nil, fmt.Errorf("%s not found", dep) } } pm, err := getReferPkgMap(pkgMap, incs, model) if err != nil { return nil, fmt.Errorf("get package map failed, err: %v", err) } return &Resolver{ root: 
make(NameSpace), deps: make(map[string]NameSpace), refPkgs: pm, files: files, mainPkg: PackageReference{ IncludeBase: util.BaseName(file, ".proto"), IncludePath: file, Model: model, Ast: ast, Referred: false, }, }, nil } func (resolver *Resolver) GetRefModel(includeBase string) (*model.Model, error) { if includeBase == "" { return resolver.mainPkg.Model, nil } ref, ok := resolver.refPkgs[includeBase] if !ok { return nil, fmt.Errorf("%s not found", includeBase) } return ref.Model, nil } func (resolver *Resolver) getBaseType(f *descriptorpb.FieldDescriptorProto, nested []*descriptorpb.DescriptorProto) (*model.Type, error) { bt := switchBaseType(f.GetType()) if bt != nil { return checkListType(bt, f.GetLabel()), nil } nt := getNestedType(f, nested) if nt != nil { fields := nt.GetField() if IsMapEntry(nt) { t := *model.TypeBaseMap tk, err := resolver.ResolveType(fields[0], nt.GetNestedType()) if err != nil { return nil, err } tv, err := resolver.ResolveType(fields[1], nt.GetNestedType()) if err != nil { return nil, err } t.Extra = []*model.Type{tk, tv} return &t, nil } } return nil, nil } func IsMapEntry(nt *descriptorpb.DescriptorProto) bool { fields := nt.GetField() return len(fields) == 2 && fields[0].GetName() == "key" && fields[1].GetName() == "value" } func checkListType(typ *model.Type, label descriptorpb.FieldDescriptorProto_Label) *model.Type { if label == descriptorpb.FieldDescriptorProto_LABEL_REPEATED { t := *model.TypeBaseList t.Extra = []*model.Type{typ} return &t } return typ } func getNestedType(f *descriptorpb.FieldDescriptorProto, nested []*descriptorpb.DescriptorProto) *descriptorpb.DescriptorProto { tName := f.GetTypeName() entry := util.SplitPackageName(tName, "") for _, nt := range nested { if nt.GetName() == entry { return nt } } return nil } func (resolver *Resolver) ResolveType(f *descriptorpb.FieldDescriptorProto, nested []*descriptorpb.DescriptorProto) (*model.Type, error) { bt, err := resolver.getBaseType(f, nested) if err != nil { return 
nil, err } if bt != nil { return bt, nil } tName := f.GetTypeName() symbol, err := resolver.ResolveIdentifier(tName) if err != nil { return nil, err } deepType := checkListType(symbol.Type, f.GetLabel()) return deepType, nil } func (resolver *Resolver) ResolveIdentifier(id string) (ret *Symbol, err error) { ret = resolver.Get(id) if ret == nil { return nil, fmt.Errorf("not found identifier %s", id) } var ref *PackageReference if _, ok := resolver.deps[ret.Space]; ok { ref = resolver.refPkgs[ret.Scope.GetName()] if ref != nil { ref.Referred = true ret.Type.Scope = ref.Model } } // bugfix: root & dep file has the same package(namespace), the 'ret' will miss the namespace match for root. // This results in a lack of dependencies in the generated handlers. if ref == nil && ret.Scope == resolver.mainPkg.Ast { resolver.mainPkg.Referred = true ret.Type.Scope = resolver.mainPkg.Model } return } func (resolver *Resolver) getFieldType(f *descriptorpb.FieldDescriptorProto, nested []*descriptorpb.DescriptorProto) (*model.Type, error) { dt, err := resolver.getBaseType(f, nested) if err != nil { return nil, err } if dt != nil { return dt, nil } sb := resolver.Get(f.GetTypeName()) if sb != nil { return sb.Type, nil } return nil, fmt.Errorf("not found type %s", f.GetTypeName()) } func (resolver *Resolver) Get(name string) *Symbol { if strings.HasPrefix(name, "."+resolver.rootName) { id := strings.TrimPrefix(name, "."+resolver.rootName+".") if v, ok := resolver.root[id]; ok { return v } } // directly map first var space string if idx := strings.LastIndex(name, "."); idx >= 0 && idx < len(name)-1 { space = strings.TrimLeft(name[:idx], ".") } if ns, ok := resolver.deps[space]; ok { id := strings.TrimPrefix(name, "."+space+".") if s, ok := ns[id]; ok { return s } } // iterate check nested type in dependencies for s, m := range resolver.deps { if strings.HasPrefix(name, "."+s) { id := strings.TrimPrefix(name, "."+s+".") if s, ok := m[id]; ok { return s } } } return nil } func (resolver 
*Resolver) ExportReferred(all, needMain bool) (ret []*PackageReference) { for _, v := range resolver.refPkgs { if all { ret = append(ret, v) } else if v.Referred { ret = append(ret, v) } v.Referred = false } if needMain && (all || resolver.mainPkg.Referred) { ret = append(ret, &resolver.mainPkg) } resolver.mainPkg.Referred = false return } func (resolver *Resolver) LoadAll(ast *descriptorpb.FileDescriptorProto) error { var err error resolver.root, err = resolver.LoadOne(ast) if err != nil { return fmt.Errorf("load main idl failed: %s", err) } resolver.rootName = ast.GetPackage() includes := ast.GetDependency() astMap := make(map[string]NameSpace, len(includes)) for _, dep := range includes { file, ok := resolver.files.Official[dep] if !ok { return fmt.Errorf("not found included idl %s", dep) } depNamespace, err := resolver.LoadOne(file) if err != nil { return fmt.Errorf("load idl '%s' failed: %s", dep, err) } ns, existed := astMap[file.GetPackage()] if existed { depNamespace = mergeNamespace(ns, depNamespace) } astMap[file.GetPackage()] = depNamespace } resolver.deps = astMap return nil } func mergeNamespace(first, second NameSpace) NameSpace { for k, v := range second { if _, existed := first[k]; !existed { first[k] = v } } return first } func LoadBaseIdentifier(ast *descriptorpb.FileDescriptorProto) map[string]*Symbol { ret := make(NameSpace, len(ast.GetEnumType())+len(ast.GetMessageType())+len(ast.GetExtension())+len(ast.GetService())) ret["true"] = &ConstTrue ret["false"] = &ConstFalse ret[`""`] = &ConstEmptyString ret["bool"] = &Symbol{ Type: model.TypeBool, Scope: ast, } ret["uint32"] = &Symbol{ Type: model.TypeUint32, Scope: ast, } ret["uint64"] = &Symbol{ Type: model.TypeUint64, Scope: ast, } ret["fixed32"] = &Symbol{ Type: model.TypeUint32, Scope: ast, } ret["fixed64"] = &Symbol{ Type: model.TypeUint64, Scope: ast, } ret["int32"] = &Symbol{ Type: model.TypeInt32, Scope: ast, } ret["int64"] = &Symbol{ Type: model.TypeInt64, Scope: ast, } ret["sint32"] = 
&Symbol{ Type: model.TypeInt32, Scope: ast, } ret["sint64"] = &Symbol{ Type: model.TypeInt64, Scope: ast, } ret["sfixed32"] = &Symbol{ Type: model.TypeInt32, Scope: ast, } ret["sfixed64"] = &Symbol{ Type: model.TypeInt64, Scope: ast, } ret["double"] = &Symbol{ Type: model.TypeFloat64, Scope: ast, } ret["float"] = &Symbol{ Type: model.TypeFloat32, Scope: ast, } ret["string"] = &Symbol{ Type: model.TypeString, Scope: ast, } ret["bytes"] = &Symbol{ Type: model.TypeBinary, Scope: ast, } return ret } func (resolver *Resolver) LoadOne(ast *descriptorpb.FileDescriptorProto) (NameSpace, error) { ret := LoadBaseIdentifier(ast) space := util.BaseName(ast.GetPackage(), "") prefix := "." + space for _, e := range ast.GetEnumType() { name := strings.TrimLeft(e.GetName(), prefix) ret[e.GetName()] = &Symbol{ Name: name, Space: space, IsValue: false, Value: e, Scope: ast, Type: model.NewEnumType(name, model.CategoryEnum), } for _, ee := range e.GetValue() { name := strings.TrimLeft(ee.GetName(), prefix) ret[ee.GetName()] = &Symbol{ Name: name, Space: space, IsValue: true, Value: ee, Scope: ast, Type: model.NewCategoryType(model.TypeInt, model.CategoryEnum), } } } for _, mt := range ast.GetMessageType() { name := strings.TrimLeft(mt.GetName(), prefix) ret[mt.GetName()] = &Symbol{ Name: name, Space: space, IsValue: false, Value: mt, Scope: ast, Type: model.NewStructType(name, model.CategoryStruct), } for _, nt := range mt.GetNestedType() { ntname := name + "_" + nt.GetName() ret[name+"."+nt.GetName()] = &Symbol{ Name: ntname, Space: space, IsValue: false, Value: nt, Scope: ast, Type: model.NewStructType(ntname, model.CategoryStruct), } } } for _, s := range ast.GetService() { name := strings.TrimLeft(s.GetName(), prefix) ret[s.GetName()] = &Symbol{ Name: name, Space: space, IsValue: false, Value: s, Scope: ast, Type: model.NewFuncType(name, model.CategoryService), } } return ret, nil } func (resolver *Resolver) GetFiles() FileInfos { return resolver.files } 
================================================ FILE: cmd/hz/protobuf/tag_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package protobuf import ( "io/ioutil" "strings" "testing" "google.golang.org/protobuf/compiler/protogen" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/pluginpb" ) func TestTagGenerate(t *testing.T) { type TagStruct struct { Annotation string GeneratedTag string ActualTag string } tagList := []TagStruct{ { Annotation: "query", GeneratedTag: "protobuf:\"bytes,1,opt,name=QueryTag\" json:\"QueryTag,omitempty\" query:\"query\"", }, { Annotation: "raw_body", GeneratedTag: "protobuf:\"bytes,2,opt,name=RawBodyTag\" json:\"RawBodyTag,omitempty\" raw_body:\"raw_body\"", }, { Annotation: "path", GeneratedTag: "protobuf:\"bytes,3,opt,name=PathTag\" json:\"PathTag,omitempty\" path:\"path\"", }, { Annotation: "form", GeneratedTag: "protobuf:\"bytes,4,opt,name=FormTag\" form:\"form\" json:\"FormTag,omitempty\"", }, { Annotation: "cookie", GeneratedTag: "protobuf:\"bytes,5,opt,name=CookieTag\" cookie:\"cookie\" json:\"CookieTag,omitempty\"", }, { Annotation: "header", GeneratedTag: "protobuf:\"bytes,6,opt,name=HeaderTag\" header:\"header\" json:\"HeaderTag,omitempty\"", }, { Annotation: "body", GeneratedTag: "bytes,7,opt,name=BodyTag\" form:\"body\" json:\"body,omitempty\"", }, { Annotation: "go.tag", GeneratedTag: "bytes,8,opt,name=GoTag\" form:\"form\" 
goTag:\"tag\" header:\"header\" json:\"json\" query:\"query\"", }, { Annotation: "vd", GeneratedTag: "bytes,9,opt,name=VdTag\" form:\"VdTag\" json:\"VdTag,omitempty\" query:\"VdTag\" vd:\"$!='?'\"", }, { Annotation: "non", GeneratedTag: "bytes,10,opt,name=DefaultTag\" form:\"DefaultTag\" json:\"DefaultTag,omitempty\" query:\"DefaultTag\"", }, { Annotation: "query required", GeneratedTag: "bytes,11,req,name=ReqQuery\" json:\"ReqQuery,required\" query:\"query,required\"", }, { Annotation: "query optional", GeneratedTag: "bytes,12,opt,name=OptQuery\" json:\"OptQuery,omitempty\" query:\"query\"", }, { Annotation: "body required", GeneratedTag: "protobuf:\"bytes,13,req,name=ReqBody\" form:\"body,required\" json:\"body,required\"", }, { Annotation: "body optional", GeneratedTag: "protobuf:\"bytes,14,opt,name=OptBody\" form:\"body\" json:\"body,omitempty\"", }, { Annotation: "go.tag required", GeneratedTag: "protobuf:\"bytes,15,req,name=ReqGoTag\" form:\"ReqGoTag,required\" json:\"json\" query:\"ReqGoTag,required\"", }, { Annotation: "go.tag optional", GeneratedTag: "bytes,16,opt,name=OptGoTag\" form:\"OptGoTag\" json:\"json\" query:\"OptGoTag\"", }, { Annotation: "go tag cover query", GeneratedTag: "bytes,17,req,name=QueryGoTag\" json:\"QueryGoTag,required\" query:\"queryTag\"", }, } in, err := ioutil.ReadFile("./test_data/protobuf_tag_test.out") if err != nil { t.Fatal(err) } req := &pluginpb.CodeGeneratorRequest{} err = proto.Unmarshal(in, req) if err != nil { t.Fatalf("unmarshal stdin request error: %v", err) } opts := protogen.Options{} gen, err := opts.New(req) for _, f := range gen.Files { if f.Proto.GetName() == "test_tag.proto" { fileInfo := newFileInfo(f) for _, message := range fileInfo.allMessages { for idx, field := range message.Fields { tags := structTags{ {"protobuf", fieldProtobufTagValue(field)}, } err = injectTagsToStructTags(field.Desc, &tags, true, nil) if err != nil { t.Fatal(err) } var actualTag string for i, tag := range tags { if i == 0 { 
actualTag = tag[0] + ":" + "\"" + tag[1] + "\"" } else { actualTag = actualTag + " " + tag[0] + ":" + "\"" + tag[1] + "\"" } } tagList[idx].ActualTag = actualTag } } } } for i := range tagList { if !strings.Contains(tagList[i].ActualTag, tagList[i].GeneratedTag) { t.Fatalf("expected tag: '%s', but autual tag: '%s'", tagList[i].GeneratedTag, tagList[i].ActualTag) } } } ================================================ FILE: cmd/hz/protobuf/tags.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package protobuf import ( "fmt" "sort" "strconv" "strings" "github.com/cloudwego/hertz/cmd/hz/config" "github.com/cloudwego/hertz/cmd/hz/generator" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/protobuf/api" "github.com/cloudwego/hertz/cmd/hz/util" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/runtime/protoimpl" "google.golang.org/protobuf/types/descriptorpb" ) var ( jsonSnakeName = false unsetOmitempty = false protobufCamelJSONTagStyle = false ) func CheckTagOption(args *config.Argument) (ret []generator.Option) { if args == nil { return } if args.SnakeName { jsonSnakeName = true } if args.UnsetOmitempty { unsetOmitempty = true } if args.JSONEnumStr { ret = append(ret, generator.OptionMarshalEnumToText) } if args.ProtobufCamelJSONTag { protobufCamelJSONTagStyle = true } return ret } func checkSnakeName(name string) string { if jsonSnakeName { name = util.ToSnakeCase(name) } return name } var ( HttpMethodOptions = map[*protoimpl.ExtensionInfo]string{ api.E_Get: "GET", api.E_Post: "POST", api.E_Put: "PUT", api.E_Patch: "PATCH", api.E_Delete: "DELETE", api.E_Options: "OPTIONS", api.E_Head: "HEAD", api.E_Any: "Any", } BindingTags = map[*protoimpl.ExtensionInfo]string{ api.E_Path: "path", api.E_Query: "query", api.E_Header: "header", api.E_Cookie: "cookie", api.E_Body: "json", // Do not change the relative order of "api.E_Form" and "api.E_Body", so that "api.E_Form" can overwrite the form tag generated by "api.E_Body" api.E_Form: "form", api.E_FormCompatible: "form", api.E_RawBody: "raw_body", } ValidatorTags = map[*protoimpl.ExtensionInfo]string{api.E_Vd: "vd"} SerializerOptions = map[*protoimpl.ExtensionInfo]string{api.E_Serializer: "serializer"} ) type httpOption struct { method string path string } type httpOptions []httpOption func (s httpOptions) Len() int { return len(s) } func (s httpOptions) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func (s 
httpOptions) Less(i, j int) bool { return s[i].method < s[j].method } func getAllOptions(extensions map[*protoimpl.ExtensionInfo]string, opts ...protoreflect.ProtoMessage) map[string]interface{} { out := map[string]interface{}{} for _, opt := range opts { for e, t := range extensions { if proto.HasExtension(opt, e) { v := proto.GetExtension(opt, e) out[t] = v } } } return out } func checkFirstOptions(extensions map[*protoimpl.ExtensionInfo]string, opts ...protoreflect.ProtoMessage) (string, interface{}) { for _, opt := range opts { for e, t := range extensions { if proto.HasExtension(opt, e) { v := proto.GetExtension(opt, e) return t, v } } } return "", nil } func checkFirstOption(ext *protoimpl.ExtensionInfo, opts ...protoreflect.ProtoMessage) interface{} { for _, opt := range opts { if proto.HasExtension(opt, ext) { v := proto.GetExtension(opt, ext) return v } } return nil } func checkOption(ext *protoimpl.ExtensionInfo, opts ...protoreflect.ProtoMessage) (ret []interface{}) { for _, opt := range opts { if proto.HasExtension(opt, ext) { v := proto.GetExtension(opt, ext) ret = append(ret, v) } } return } func tag(k string, v interface{}) model.Tag { return model.Tag{ Key: k, Value: fmt.Sprintf("%v", v), } } //-----------------------------------For Compiler--------------------------- func defaultBindingTags(f *descriptorpb.FieldDescriptorProto) []model.Tag { opts := f.GetOptions() out := make([]model.Tag, 3) if v := checkFirstOption(api.E_Body, opts); v != nil { val := getJsonValue(f, v.(string)) out[0] = tag("json", val) } else { out[0] = jsonTag(f) } if v := checkFirstOption(api.E_Query, opts); v != nil { val := checkRequire(f, v.(string)) out[1] = tag(BindingTags[api.E_Query], val) } else { val := checkRequire(f, checkSnakeName(f.GetName())) out[1] = tag(BindingTags[api.E_Query], val) } if v := checkFirstOption(api.E_Form, opts); v != nil { val := checkRequire(f, v.(string)) out[2] = tag(BindingTags[api.E_Form], val) } else { val := checkRequire(f, 
checkSnakeName(f.GetName())) out[2] = tag(BindingTags[api.E_Form], val) } return out } func jsonTag(f *descriptorpb.FieldDescriptorProto) (ret model.Tag) { ret.Key = "json" ret.Value = checkSnakeName(f.GetJsonName()) if v := checkFirstOption(api.E_JsConv, f.GetOptions()); v != nil { ret.Value += ",string" } else if v := checkFirstOption(api.E_JsConvCompatible, f.GetOptions()); v != nil { ret.Value += ",string" } if !unsetOmitempty && f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL { ret.Value += ",omitempty" } else if f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { ret.Value += ",required" } return } func injectTagsToModel(f *descriptorpb.FieldDescriptorProto, gf *model.Field, needDefault bool) error { as := f.GetOptions() tags := gf.Tags if tags == nil { tags = make([]model.Tag, 0, 4) } // binding tags if needDefault { tags = append(tags, defaultBindingTags(f)...) } for k, v := range BindingTags { if vv := checkFirstOption(k, as); vv != nil { tags.Remove(v) if v == "json" { vv = getJsonValue(f, vv.(string)) } else { vv = checkRequire(f, vv.(string)) } tags = append(tags, tag(v, vv)) } } // validator tags for k, v := range ValidatorTags { for _, vv := range checkOption(k, as) { tags = append(tags, tag(v, vv)) } } // go.tags for _, v := range checkOption(api.E_GoTag, as) { gts := util.SplitGoTags(v.(string)) for _, gt := range gts { sp := strings.SplitN(gt, ":", 2) if len(sp) != 2 { return fmt.Errorf("invalid go tag: %s", v) } vv, err := strconv.Unquote(sp[1]) if err != nil { return fmt.Errorf("invalid go.tag value: %s, err: %v", sp[1], err.Error()) } key := sp[0] tags.Remove(key) tags = append(tags, model.Tag{ Key: key, Value: vv, }) } } sort.Sort(tags) gf.Tags = tags return nil } func getJsonValue(f *descriptorpb.FieldDescriptorProto, val string) string { if v := checkFirstOption(api.E_JsConv, f.GetOptions()); v != nil { val += ",string" } else if v := checkFirstOption(api.E_JsConvCompatible, f.GetOptions()); v != nil { val += 
",string" } if !unsetOmitempty && f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_OPTIONAL { val += ",omitempty" } else if f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { val += ",required" } return val } func checkRequire(f *descriptorpb.FieldDescriptorProto, val string) string { if f.GetLabel() == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { val += ",required" } return val } //-------------------------For plugin--------------------------------- func m2s(mt model.Tag) (ret [2]string) { ret[0] = mt.Key ret[1] = mt.Value return ret } func reflectJsonTag(f protoreflect.FieldDescriptor) (ret model.Tag) { ret.Key = "json" if protobufCamelJSONTagStyle { ret.Value = checkSnakeName(f.JSONName()) } else { ret.Value = checkSnakeName(string(f.Name())) } if v := checkFirstOption(api.E_Body, f.Options()); v != nil { ret.Value += ",string" } if descriptorpb.FieldDescriptorProto_Label(f.Cardinality()) == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { ret.Value += ",required" } else if !unsetOmitempty { ret.Value += ",omitempty" } return } func defaultBindingStructTags(f protoreflect.FieldDescriptor) []model.Tag { opts := f.Options() out := make([]model.Tag, 3) bindingTags := []*protoimpl.ExtensionInfo{ api.E_Path, api.E_Query, api.E_Form, api.E_FormCompatible, api.E_Header, api.E_Cookie, api.E_Body, api.E_RawBody, } // If the user provides an annotation, return json tag directly for _, tag := range bindingTags { if vv := checkFirstOption(tag, opts); vv != nil { out[0] = reflectJsonTag(f) return out[:1] } } if v := checkFirstOption(api.E_Body, opts); v != nil { val := getStructJsonValue(f, v.(string)) out[0] = tag("json", val) } else { t := reflectJsonTag(f) t.IsDefault = true out[0] = t } if v := checkFirstOption(api.E_Query, opts); v != nil { val := checkStructRequire(f, v.(string)) out[1] = tag(BindingTags[api.E_Query], val) } else { val := checkStructRequire(f, checkSnakeName(string(f.Name()))) t := tag(BindingTags[api.E_Query], val) 
t.IsDefault = true out[1] = t } if v := checkFirstOption(api.E_Form, opts); v != nil { val := checkStructRequire(f, v.(string)) out[2] = tag(BindingTags[api.E_Form], val) } else { if v := checkFirstOption(api.E_FormCompatible, opts); v != nil { // compatible form_compatible val := checkStructRequire(f, v.(string)) t := tag(BindingTags[api.E_Form], val) t.IsDefault = true out[2] = t } else { val := checkStructRequire(f, checkSnakeName(string(f.Name()))) t := tag(BindingTags[api.E_Form], val) t.IsDefault = true out[2] = t } } return out } func injectTagsToStructTags(f protoreflect.FieldDescriptor, out *structTags, needDefault bool, rmTags RemoveTags) error { as := f.Options() // binding tags tags := model.Tags(make([]model.Tag, 0, 6)) if needDefault { tags = append(tags, defaultBindingStructTags(f)...) } for k, v := range BindingTags { if vv := checkFirstOption(k, as); vv != nil { tags.Remove(v) // body annotation will generate "json" & "form" tag for protobuf if v == "json" { formVal := vv vv = getStructJsonValue(f, vv.(string)) formVal = checkStructRequire(f, formVal.(string)) tags = append(tags, tag("form", formVal)) } else { vv = checkStructRequire(f, vv.(string)) } tags = append(tags, tag(v, vv)) } } // validator tags for k, v := range ValidatorTags { if vv := checkFirstOption(k, as); vv != nil { tags = append(tags, tag(v, vv)) } } if v := checkFirstOption(api.E_GoTag, as); v != nil { gts := util.SplitGoTags(v.(string)) for _, gt := range gts { sp := strings.SplitN(gt, ":", 2) if len(sp) != 2 { return fmt.Errorf("invalid go tag: %s", v) } vv, err := strconv.Unquote(sp[1]) if err != nil { return fmt.Errorf("invalid go.tag value: %s, err: %v", sp[1], err.Error()) } key := sp[0] tags.Remove(key) tags = append(tags, model.Tag{ Key: key, Value: vv, }) } } disableTag := false if vv := checkFirstOption(api.E_None, as); vv != nil { if strings.EqualFold(vv.(string), "true") { disableTag = true } } else if vv := checkFirstOption(api.E_NoneCompatible, as); vv != nil { if 
strings.EqualFold(vv.(string), "true") { disableTag = true } } for _, t := range tags { if t.IsDefault && rmTags.Exist(t.Key) { tags.Remove(t.Key) } } sort.Sort(tags) for _, t := range tags { if disableTag { *out = append(*out, [2]string{t.Key, "-"}) } else { *out = append(*out, m2s(t)) } } return nil } func getStructJsonValue(f protoreflect.FieldDescriptor, val string) string { if v := checkFirstOption(api.E_JsConv, f.Options()); v != nil { val += ",string" } else if v := checkFirstOption(api.E_JsConvCompatible, f.Options()); v != nil { val += ",string" } if descriptorpb.FieldDescriptorProto_Label(f.Cardinality()) == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { val += ",required" } else if !unsetOmitempty { val += ",omitempty" } return val } func checkStructRequire(f protoreflect.FieldDescriptor, val string) string { if descriptorpb.FieldDescriptorProto_Label(f.Cardinality()) == descriptorpb.FieldDescriptorProto_LABEL_REQUIRED { val += ",required" } return val } ================================================ FILE: cmd/hz/protobuf/test_data/test_tag.proto ================================================ syntax = "proto2"; package test; option go_package = "cloudwego.hertz.hz"; import "api.proto"; message MultiTagReq { // basic feature optional string QueryTag = 1 [(api.query)="query"]; optional string RawBodyTag = 2 [(api.raw_body)="raw_body"]; optional string PathTag = 3 [(api.path)="path"]; optional string FormTag = 4 [(api.form)="form"]; optional string CookieTag = 5 [(api.cookie)="cookie"]; optional string HeaderTag = 6 [(api.header)="header"]; optional string BodyTag = 7 [(api.body)="body"]; optional string GoTag = 8 [(api.go_tag)="json:\"json\" query:\"query\" form:\"form\" header:\"header\" goTag:\"tag\""]; optional string VdTag = 9 [(api.vd)="$!='?'"]; optional string DefaultTag = 10; // optional / required required string ReqQuery = 11 [(api.query)="query"]; optional string OptQuery = 12 [(api.query)="query"]; required string ReqBody = 13 
[(api.body)="body"]; optional string OptBody = 14 [(api.body)="body"]; required string ReqGoTag = 15 [(api.go_tag)="json:\"json\""]; optional string OptGoTag = 16 [(api.go_tag)="json:\"json\""]; // gotag cover feature required string QueryGoTag = 17 [(api.query)="query", (api.go_tag)="query:\"queryTag\""]; } ================================================ FILE: cmd/hz/test_hz_unix.sh ================================================ #! /usr/bin/env bash set -e # const value define moduleName="github.com/cloudwego/hertz/cmd/hz/test" curDir=`pwd` thriftIDL=$curDir"/testdata/thrift/psm.thrift" protobuf2IDL=$curDir"/testdata/protobuf2/psm/psm.proto" proto2Search=$curDir"/testdata/protobuf2" protobuf3IDL=$curDir"/testdata/protobuf3/psm/psm.proto" proto3Search=$curDir"/testdata/protobuf3" protoSearch="/usr/local/include" compile_hz() { go build -o hz } PATH_BIN=$PWD/bin mkdir -p $PATH_BIN export PATH=$PATH_BIN:$PATH install_dependent_tools() { # install thriftgo go install github.com/cloudwego/thriftgo@latest # install protoc wget https://github.com/protocolbuffers/protobuf/releases/download/v3.19.4/protoc-3.19.4-linux-x86_64.zip unzip -d protoc-3.19.4-linux-x86_64 protoc-3.19.4-linux-x86_64.zip cp protoc-3.19.4-linux-x86_64/bin/protoc $PATH_BIN cp -r protoc-3.19.4-linux-x86_64/include/google $PATH_BIN } go_tidy_build() { # make sure we get the latest version for testing go get github.com/cloudwego/hertz@develop go mod tidy && go build . } test_thrift() { mkdir -p test cd test ../hz new --idl=$thriftIDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router go_tidy_build ../hz update --idl=$thriftIDL ../hz model --idl=$thriftIDL --model_dir=hertz_model ../hz client --idl=$thriftIDL --client_dir=hertz_client cd .. 
rm -rf test } test_protobuf2() { # test protobuf2 mkdir -p test cd test ../hz new -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router go_tidy_build ../hz update -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL ../hz model -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --model_dir=hertz_model ../hz client -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --client_dir=hertz_client cd .. rm -rf test } test_protobuf3() { # test protobuf2 mkdir -p test cd test ../hz new -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router go_tidy_build ../hz update -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL ../hz model -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --model_dir=hertz_model ../hz client -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --client_dir=hertz_client cd .. rm -rf test } main() { compile_hz install_dependent_tools echo "test thrift......" test_thrift echo "test protobuf2......" test_protobuf2 echo "test protobuf3......" test_protobuf3 echo "hz execute success" } main ================================================ FILE: cmd/hz/test_hz_windows.sh ================================================ #! /usr/bin/env bash set -e # const value define moduleName="github.com/cloudwego/hertz/cmd/hz/test" curDir=`pwd` thriftIDL=$curDir"/testdata/thrift/psm.thrift" protobuf2IDL=$curDir"/testdata/protobuf2/psm/psm.proto" proto2Search=$curDir"/testdata/protobuf2" protobuf3IDL=$curDir"/testdata/protobuf3/psm/psm.proto" proto3Search=$curDir"/testdata/protobuf3" protoSearch=$curDir"/testdata/include" compile_hz() { go install . 
} install_dependent_tools() { # install thriftgo go install github.com/cloudwego/thriftgo@latest } go_tidy_build() { # make sure we get the latest version for testing go get github.com/cloudwego/hertz@develop go mod tidy && go build . } test_thrift() { # test thrift mkdir -p test cd test hz new --idl=$thriftIDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router go_tidy_build hz update --idl=$thriftIDL hz model --idl=$thriftIDL --model_dir=hertz_model hz client --idl=$thriftIDL --client_dir=hertz_client cd .. rm -rf test } test_protobuf2() { # test protobuf2 mkdir -p test cd test hz new -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router go_tidy_build hz update -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL hz model -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --model_dir=hertz_model hz client -I=$protoSearch -I=$proto2Search --idl=$protobuf2IDL --client_dir=hertz_client cd .. rm -rf test } test_protobuf3() { # test protobuf2 mkdir -p test cd test hz new -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --mod=$moduleName -f --model_dir=hertz_model --handler_dir=hertz_handler --router_dir=hertz_router go_tidy_build hz update -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL hz model -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --model_dir=hertz_model hz client -I=$protoSearch -I=$proto3Search --idl=$protobuf3IDL --client_dir=hertz_client cd .. rm -rf test } main() { compile_hz install_dependent_tools # todo: add thrift test when thriftgo fixed windows echo "test thrift......" test_thrift echo "test protobuf2......" test_protobuf2 echo "test protobuf3......" 
test_protobuf3 echo "hz execute success" } main ================================================ FILE: cmd/hz/testdata/protobuf2/api.proto ================================================ syntax = "proto2"; package api; import "google/protobuf/descriptor.proto"; option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/api"; extend google.protobuf.FieldOptions { optional string raw_body = 50101; optional string query = 50102; optional string header = 50103; optional string cookie = 50104; optional string body = 50105; optional string path = 50106; optional string vd = 50107; optional string form = 50108; optional string js_conv = 50109; optional string file_name = 50110; optional string none = 50111; // 50131~50160 used to extend field option by hz optional string form_compatible = 50131; optional string js_conv_compatible = 50132; optional string file_name_compatible = 50133; optional string none_compatible = 50134; optional string go_tag = 51001; } extend google.protobuf.MethodOptions { optional string get = 50201; optional string post = 50202; optional string put = 50203; optional string delete = 50204; optional string patch = 50205; optional string options = 50206; optional string head = 50207; optional string any = 50208; optional string gen_path = 50301; // The path specified by the user when the client code is generated, with a higher priority than api_version optional string api_version = 50302; // Specify the value of the :version variable in path when the client code is generated optional string tag = 50303; // rpc tag, can be multiple, separated by commas optional string name = 50304; // Name of rpc optional string api_level = 50305; // Interface Level optional string serializer = 50306; // Serialization method optional string param = 50307; // Whether client requests take public parameters optional string baseurl = 50308; // Baseurl used in ttnet routing optional string handler_path = 50309; // handler_path specifies the path to generate 
the method // 50331~50360 used to extend method option by hz optional string handler_path_compatible = 50331; // handler_path specifies the path to generate the method } extend google.protobuf.EnumValueOptions { optional int32 http_code = 50401; // 50431~50460 used to extend enum option by hz } extend google.protobuf.ServiceOptions { optional string base_domain = 50402; // 50731~50760 used to extend service option by hz optional string base_domain_compatible = 50731; } ================================================ FILE: cmd/hz/testdata/protobuf2/other/other.proto ================================================ syntax = "proto2"; package hertz.other; import "other/other_base.proto"; option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/other"; message OtherType { optional string IsBaseString = 1; optional OtherBaseType IsOtherBaseType = 2; } ================================================ FILE: cmd/hz/testdata/protobuf2/other/other_base.proto ================================================ syntax = "proto2"; package hertz.other; option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/other"; message OtherBaseType { optional string IsOtherBaseTypeString = 1; } ================================================ FILE: cmd/hz/testdata/protobuf2/psm/base.proto ================================================ syntax = "proto2"; package base; option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/psm"; message Base { optional string IsBaseString = 1; } enum BaseEnumType { TWEET = 0; RETWEET = 1; } ================================================ FILE: cmd/hz/testdata/protobuf2/psm/psm.proto ================================================ syntax = "proto2"; package psm; option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/psm"; import "api.proto"; import "base.proto"; import "other/other.proto"; enum EnumType { TWEET = 0; RETWEET = 1; } message UnusedMessageType { optional string IsUnusedMessageType 
= 1; } message BaseType { optional base.Base IsBaseType = 1; } message MultiTypeReq { // basic type (leading comments) optional bool IsBoolOpt = 1; required bool IsBoolReq = 2; optional int32 IsInt32Opt = 3; required int32 IsInt32Req = 4; optional int64 IsInt64Opt = 5; optional uint32 IsUInt32Opt = 6; optional uint64 IsUInt64Opt = 7; optional sint32 IsSInt32Opt = 8; optional sint64 IsSInt64Opt = 9; optional fixed32 IsFix32Opt = 10; optional fixed64 IsFix64Opt = 11; optional sfixed32 IsSFix32Opt = 12; optional sfixed64 IsSFix64Opt = 13; optional double IsDoubleOpt = 14; required double IsDoubleReq = 15; optional float IsFloatOpt = 16; optional string IsStringOpt = 17; required string IsStringReq = 18; optional bytes IsBytesOpt = 19; optional bytes IsBytesReq = 20; // slice repeated string IsRepeatedString = 21; repeated BaseType IsRepeatedBaseType = 22; // map map IsStringMap = 23; map IsBaseTypeMap = 24; // oneof // multiple comments oneof TestOneof { string IsOneofString = 25; BaseType IsOneofBaseType = 26; int32 IsOneofInt = 100; bool IsOneofBool = 101; double IsOneoDouble = 102; bytes IsOneofBytes = 103; } // this is oneof2, one field in oneof oneof TestOneof2 { string IsOneof2String = 104; } message NestedMessageType { optional string IsNestedString = 1; optional BaseType IsNestedBaseType = 2; repeated BaseType IsNestedRepeatedBaseType = 3; // nested oneof oneof NestedMsgOneof { string IsNestedMsgOneofString = 4; EnumType IsNestedMsgOneofEnumType = 5; } } // nested message optional NestedMessageType IsNestedType = 27; // other dependency optional base.Base IsCurrentPackageBase = 28; optional hertz.other.OtherType IsOtherType = 29; // enum optional EnumType IsEnumTypeOpt = 30; required EnumType IsEnumTypeReq = 31; repeated EnumType IsEnumTypeList = 32; optional base.BaseEnumType IsBaseEnumType = 33; } message MultiTagReq { optional string QueryTag = 1 [(api.query) = "query", (api.none) = "true"]; optional string RawBodyTag = 2 [(api.raw_body) = "raw_body"]; 
optional string CookieTag = 3 [(api.cookie) = "cookie"]; optional string BodyTag = 4 [(api.body) = "body"]; optional string PathTag = 5 [(api.path) = "path"]; optional string VdTag = 6 [(api.vd) = "$!='?'"]; optional string FormTag = 7 [(api.form) = "form"]; optional string DefaultTag = 8 [(api.go_tag) = "FFF:\"fff\" json:\"json\""]; } message CompatibleAnnoReq { optional string FormCompatibleTag = 1 [(api.form_compatible) = "form"]; optional string FilenameCompatibleTag = 2 [(api.file_name_compatible) = "file_name"]; optional string NoneCompatibleTag = 3 [(api.none_compatible) = "true"]; optional string JsConvCompatibleTag = 4 [(api.js_conv_compatible) = "true"]; } message Resp { optional string Resp = 1; } message MultiNameStyleMessage { optional string hertz = 1; optional string Hertz = 2; optional string hertz_demo = 3; optional string hertz_demo_idl = 4; optional string hertz_Idl = 5; optional string hertzDemo = 6; optional string h = 7; optional string H = 8; optional string hertz_ = 9; } service Hertz { rpc Method1(MultiTypeReq) returns(Resp) { option (api.get) = "/company/department/group/user:id/name"; } rpc Method2(MultiTypeReq) returns(Resp) { option (api.post) = "/company/department/group/user:id/sex"; } rpc Method3(MultiTypeReq) returns(Resp) { option (api.put) = "/company/department/group/user:id/number"; } rpc Method4(MultiTypeReq) returns(Resp) { option (api.delete) = "/company/department/group/user:id/age"; } rpc Method5(MultiTagReq) returns(Resp) { option (api.options) = "/school/class/student/name"; } rpc Method6(MultiTagReq) returns(Resp) { option (api.head) = "/school/class/student/number"; } rpc Method7(MultiTagReq) returns(Resp) { option (api.patch) = "/school/class/student/sex"; } rpc Method8(MultiTagReq) returns(Resp) { option (api.any) = "/school/class/student/grade/*subjects"; } } ================================================ FILE: cmd/hz/testdata/protobuf3/api.proto ================================================ syntax = "proto2"; 
package api; import "google/protobuf/descriptor.proto"; option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/api"; extend google.protobuf.FieldOptions { optional string raw_body = 50101; optional string query = 50102; optional string header = 50103; optional string cookie = 50104; optional string body = 50105; optional string path = 50106; optional string vd = 50107; optional string form = 50108; optional string js_conv = 50109; optional string file_name = 50110; optional string none = 50111; // 50131~50160 used to extend field option by hz optional string form_compatible = 50131; optional string js_conv_compatible = 50132; optional string file_name_compatible = 50133; optional string none_compatible = 50134; optional string go_tag = 51001; } extend google.protobuf.MethodOptions { optional string get = 50201; optional string post = 50202; optional string put = 50203; optional string delete = 50204; optional string patch = 50205; optional string options = 50206; optional string head = 50207; optional string any = 50208; optional string gen_path = 50301; // The path specified by the user when the client code is generated, with a higher priority than api_version optional string api_version = 50302; // Specify the value of the :version variable in path when the client code is generated optional string tag = 50303; // rpc tag, can be multiple, separated by commas optional string name = 50304; // Name of rpc optional string api_level = 50305; // Interface Level optional string serializer = 50306; // Serialization method optional string param = 50307; // Whether client requests take public parameters optional string baseurl = 50308; // Baseurl used in ttnet routing optional string handler_path = 50309; // handler_path specifies the path to generate the method // 50331~50360 used to extend method option by hz optional string handler_path_compatible = 50331; // handler_path specifies the path to generate the method } extend 
google.protobuf.EnumValueOptions { optional int32 http_code = 50401; // 50431~50460 used to extend enum option by hz } extend google.protobuf.ServiceOptions { optional string base_domain = 50402; // 50731~50760 used to extend service option by hz optional string base_domain_compatible = 50731; } ================================================ FILE: cmd/hz/testdata/protobuf3/other/other.proto ================================================ syntax = "proto2"; package hertz.other; import "other/other_base.proto"; option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/other"; message OtherType { optional string IsBaseString = 1; optional OtherBaseType IsOtherBaseType = 2; } ================================================ FILE: cmd/hz/testdata/protobuf3/other/other_base.proto ================================================ syntax = "proto2"; package hertz.other; option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/other"; message OtherBaseType { optional string IsOtherBaseTypeString = 1; } ================================================ FILE: cmd/hz/testdata/protobuf3/psm/base.proto ================================================ syntax = "proto2"; package base; option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/psm"; message Base { optional string IsBaseString = 1; } ================================================ FILE: cmd/hz/testdata/protobuf3/psm/psm.proto ================================================ syntax = "proto3"; package psm; option go_package = "github.com/cloudwego/hertz/cmd/hz/test/hertz_model/psm"; import "api.proto"; import "base.proto"; import "other/other.proto"; enum EnumType { TWEET = 0; RETWEET = 1; } message UnusedMessageType { optional string IsUnusedMessageType = 1; } message BaseType { optional base.Base IsBaseType = 1; } message MultiTypeReq { // basic type (leading comments) optional bool IsBoolOpt = 1; optional int32 IsInt32Opt = 3; int64 IsInt64Default = 5; optional uint32 
IsUInt32Opt = 6; uint64 IsUInt64Default = 7; optional sint32 IsSInt32Opt = 8; sint64 IsSInt64Default = 9; optional fixed32 IsFix32Opt = 10; optional fixed64 IsFix64Opt = 11; optional sfixed32 IsSFix32Opt = 12; optional sfixed64 IsSFix64Opt = 13; optional double IsDoubleOpt = 14; optional float IsFloatOpt = 16; optional string IsStringOpt = 17; optional bytes IsBytesOpt = 19; bytes IsBytesDefault = 20; // slice repeated string IsRepeatedString = 21; repeated BaseType IsRepeatedBaseType = 22; // map map IsStringMap = 23; map IsBaseTypeMap = 24; // oneof oneof TestOneof { string IsOneofString = 25; BaseType IsOneofBaseTypeString = 26; } oneof TestOneof2 { string IsOneofString2 = 100; } // nested message message NestedMessageType { oneof NestedOneof { string YYY = 4; string GGG = 5; } optional string IsNestedString = 1; optional BaseType IsNestedBaseType = 2; repeated BaseType IsNestedRepeatedBaseType = 3; } optional NestedMessageType IsNestedType = 27; // other dependency optional base.Base IsCurrentPackageBase = 28; optional hertz.other.OtherType IsOtherType = 29; // enum optional EnumType IsEnumTypeOpt = 30; EnumType IsEnumTypeDefault = 31; } message MultiTagReq { optional string QueryTag = 1 [(api.query) = "query", (api.none) = "true"]; optional string RawBodyTag = 2 [(api.raw_body)="raw_body"]; optional string CookieTag = 3 [(api.cookie)="cookie"]; optional string BodyTag = 4 [(api.body)="body"]; optional string PathTag = 5 [(api.path)="path"]; optional string VdTag = 6 [(api.vd)="$!='?'"]; optional string DefaultTag = 7; oneof TestOneof { string IsOneofString = 25; BaseType IsOneofBaseTypeString = 26; } } message CompatibleAnnoReq { optional string FormCompatibleTag = 1 [(api.form_compatible) = "form"]; optional string FilenameCompatibleTag = 2 [(api.file_name_compatible) = "file_name"]; optional string NoneCompatibleTag = 3 [(api.none_compatible) = "true"]; optional string JsConvCompatibleTag = 4 [(api.js_conv_compatible) = "true"]; } message Resp { optional 
string Resp = 1; } service Hertz { rpc Method1(MultiTypeReq) returns(Resp) { option (api.get)="/company/department/group/user:id/name"; } rpc Method2(MultiTypeReq) returns(Resp) { option (api.post)="/company/department/group/user:id/sex"; } rpc Method3(MultiTypeReq) returns(Resp) { option (api.put)="/company/department/group/user:id/number"; } rpc Method4(MultiTypeReq) returns(Resp) { option (api.delete)="/company/department/group/user:id/age"; } rpc Method5(MultiTagReq) returns(Resp) { option (api.options)="/school/class/student/name"; } rpc Method6(MultiTagReq) returns(Resp) { option (api.head)="/school/class/student/number"; } rpc Method7(MultiTagReq) returns(Resp) { option (api.patch)="/school/class/student/sex"; } rpc Method8(MultiTagReq) returns(Resp) { option (api.any)="/school/class/student/grade/*subjects"; } } ================================================ FILE: cmd/hz/testdata/thrift/common.thrift ================================================ namespace go toutiao.middleware.hertz struct CommonType { 1: required string IsCommonString; 2: optional string TTT; 3: required bool HHH; 4: required Base GGG; } struct Base { 1: optional string AAA; 2: optional i32 BBB; } ================================================ FILE: cmd/hz/testdata/thrift/data/basic_data.thrift ================================================ namespace go toutiao.middleware.hertz_data struct BasicDataType { 1: optional string IsBasicDataString; } ================================================ FILE: cmd/hz/testdata/thrift/data/data.thrift ================================================ include "basic_data.thrift" namespace go toutiao.middleware.hertz_data struct DataType { 1: optional basic_data.BasicDataType IsDataString; } ================================================ FILE: cmd/hz/testdata/thrift/psm.thrift ================================================ include "common.thrift" include "data/data.thrift" namespace go toutiao.middleware.hertz const string STRING_CONST = 
"hertz"; enum EnumType { TWEET, RETWEET = 2, } typedef i32 MyInteger struct BaseType { 1: string GoTag = "test" (go.tag="json:\"go\" goTag:\"tag\""); 2: optional string IsBaseString = "test"; 3: optional common.CommonType IsDepCommonType = {"IsCommonString":"test", "TTT":"test", "HHH":true, "GGG": {"AAA":"test","BBB":32}}; 4: optional EnumType IsBaseTypeEnum = 1; } typedef common.CommonType FFF typedef BaseType MyBaseType struct MultiTypeReq { // basic type (leading comments) 1: optional bool IsBoolOpt = true; // trailing comments 2: required bool IsBoolReq; 3: optional byte IsByteOpt = 8; 4: required byte IsByteReq; //5: optional i8 IsI8Opt; // unsupported i8, suggest byte //6: required i8 IsI8Req = 5; // default 7: optional i16 IsI16Opt = 16; 8: optional i32 IsI32Opt; 9: optional i64 IsI64Opt; 10: optional double IsDoubleOpt; 11: required double IsDoubleReq; 12: optional string IsStringOpt = "test"; 13: required string IsStringReq; 14: optional list IsList; 22: required list IsListReq; 15: optional set IsSet; 16: optional map IsMap; 21: optional map IsStructMap; // struct type 17: optional BaseType IsBaseType; // use struct name 18: optional MyBaseType IsMyBaseType; // use typedef for struct 19: optional common.CommonType IsCommonType = {"IsCommonString": "fffff"}; 20: optional data.DataType IsDataType; // multi-dependent struct } typedef data.DataType IsMyDataType struct MultiTagReq { 1: string QueryTag (api.query="query"); 2: string RawBodyTag (api.raw_body="raw_body"); 3: string PathTag (api.path="path"); 4: string FormTag (api.form="form"); 5: string CookieTag (api.cookie="cookie"); 6: string HeaderTag (api.header="header"); 7: string ProtobufTag (api.protobuf="protobuf"); 8: string BodyTag (api.body="body"); 9: string GoTag (go.tag="json:\"go\" goTag:\"tag\""); 10: string VdTag (api.vd="$!='?'"); 11: string DefaultTag; } struct Resp { 1: string Resp = "this is Resp"; } struct MultiNameStyleReq { 1: optional string hertz; 2: optional string Hertz; 3: optional 
string hertz_demo; 4: optional string hertz_demo_idl; 5: optional string hertz_Idl; 6: optional string hertzDemo; 7: optional string h; 8: optional string H; 9: optional string hertz_; } struct MultiDefaultReq { 1: optional bool IsBoolOpt = true; 2: required bool IsBoolReq = false; 3: optional i32 IsI32Opt = 32; 4: required i32 IsI32Req = 32; 5: optional string IsStringOpt = "test"; 6: required string IsStringReq = "test"; 14: optional list IsListOpt = ["test", "ttt", "sdsds"]; 22: required list IsListReq = ["test", "ttt", "sdsds"]; 15: optional set IsSet = ["test", "ttt", "sdsds"]; 16: optional map IsMapOpt = {"test": "ttt", "ttt": "lll"}; 17: required map IsMapReq = {"test": "ttt", "ttt": "lll"}; 21: optional map IsStructMapOpt = {"test": {"GoTag":"fff", "IsBaseTypeEnum":1, "IsBaseString":"ddd", "IsDepCommonType": {"IsCommonString":"fffffff", "TTT":"ttt", "HHH":true, "GGG": {"AAA":"test","BBB":32}}}}; 25: required map IsStructMapReq = {"test": {"GoTag":"fff", "IsBaseTypeEnum":1, "IsBaseString":"ddd", "IsDepCommonType": {"IsCommonString":"fffffff", "TTT":"ttt", "HHH":true, "GGG": {"AAA":"test","BBB":32}}}}; 23: optional common.CommonType IsDepCommonTypeOpt = {"IsCommonString":"fffffff", "TTT":"ttt", "HHH":true, "GGG": {"AAA":"test","BBB":32}}; 24: required common.CommonType IsDepCommonTypeReq = {"IsCommonString":"fffffff", "TTT":"ttt", "HHH":true, "GGG": {"AAA":"test","BBB":32}}; } typedef map IsTypedefContainer service Hertz { Resp Method1(1: MultiTypeReq request) (api.get="/company/department/group/user:id/name", api.handler_path="v1"); Resp Method2(1: MultiTagReq request) (api.post="/company/department/group/user:id/sex", api.handler_path="v1"); Resp Method3(1: BaseType request) (api.put="/company/department/group/user:id/number", api.handler_path="v1"); Resp Method4(1: data.DataType request) (api.delete="/company/department/group/user:id/age", api.handler_path="v1"); Resp Method5(1: MultiTypeReq request) (api.options="/school/class/student/name", 
api.handler_path="v2"); Resp Method6(1: MultiTagReq request) (api.head="/school/class/student/number", api.handler_path="v2"); Resp Method7(1: MultiTagReq request) (api.patch="/school/class/student/sex", api.handler_path="v2"); Resp Method8(1: BaseType request) (api.any="/school/class/student/grade/*subjects", api.handler_path="v2"); Resp Method9(1: IsTypedefContainer request) (api.get="/typedef/container", api.handler_path="v2"); Resp Method10(1: map request) (api.get="/container", api.handler_path="v2"); } ================================================ FILE: cmd/hz/thrift/ast.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package thrift import ( "fmt" "sort" "strings" "github.com/cloudwego/hertz/cmd/hz/config" "github.com/cloudwego/hertz/cmd/hz/generator" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/hertz/cmd/hz/util/logs" "github.com/cloudwego/thriftgo/generator/golang" "github.com/cloudwego/thriftgo/generator/golang/styles" "github.com/cloudwego/thriftgo/parser" "github.com/cloudwego/thriftgo/semantic" ) /*---------------------------Import-----------------------------*/ func getGoPackage(ast *parser.Thrift, pkgMap map[string]string) string { filePackage := ast.GetFilename() if opt, ok := pkgMap[filePackage]; ok { return opt } else { goPackage := ast.GetNamespaceOrReferenceName("go") if goPackage != "" { return util.SplitPackage(goPackage, "") } // If namespace is not declared, the file name (without the extension) is used as the package name return util.SplitPackage(filePackage, ".thrift") } } /*---------------------------Service-----------------------------*/ func astToService(ast *parser.Thrift, resolver *Resolver, args *config.Argument) ([]*generator.Service, error) { ss := ast.GetServices() out := make([]*generator.Service, 0, len(ss)) var models model.Models extendServices := getExtendServices(ast) for _, s := range ss { // if the service is extended, it is not processed if extendServices.exist(s.Name) && args.EnableExtends { logs.Debugf("%s is extended, so skip it\n", s.Name) continue } resolver.ExportReferred(true, false) service := &generator.Service{ Name: s.GetName(), } service.BaseDomain = "" domainAnno := getAnnotation(s.Annotations, ApiBaseDomain) if len(domainAnno) == 1 { if args.CmdType == meta.CmdClient { service.BaseDomain = domainAnno[0] } } service.ServiceGroup = "" groupAnno := getAnnotation(s.Annotations, ApiServiceGroup) if len(groupAnno) == 1 { if args.CmdType != meta.CmdClient { service.ServiceGroup = groupAnno[0] } } service.ServiceGenDir = "" 
serviceGenDirAnno := getAnnotation(s.Annotations, ApiServiceGenDir) if len(serviceGenDirAnno) == 1 { if args.CmdType != meta.CmdClient { service.ServiceGenDir = serviceGenDirAnno[0] } } ms := s.GetFunctions() if len(s.Extends) != 0 && args.EnableExtends { // all the services that are extended to the current service extendsFuncs, err := getAllExtendFunction(s, ast, resolver, args) if err != nil { return nil, fmt.Errorf("parser extend function failed, err=%v", err) } ms = append(ms, extendsFuncs...) } methods := make([]*generator.HttpMethod, 0, len(ms)) clientMethods := make([]*generator.ClientMethod, 0, len(ms)) servicePathAnno := getAnnotation(s.Annotations, ApiServicePath) servicePath := "" if len(servicePathAnno) > 0 { servicePath = servicePathAnno[0] } for _, m := range ms { rs := getAnnotations(m.Annotations, HttpMethodAnnotations) if len(rs) == 0 { continue } httpAnnos := httpAnnotations{} for k, v := range rs { httpAnnos = append(httpAnnos, httpAnnotation{ method: k, path: v, }) } // turn the map into a slice and sort it to make sure getting the results in the same order every time sort.Sort(httpAnnos) handlerOutDir := servicePath genPaths := getAnnotation(m.Annotations, ApiGenPath) if len(genPaths) == 1 { handlerOutDir = genPaths[0] } else if len(genPaths) > 0 { return nil, fmt.Errorf("too many 'api.handler_path' for %s", m.Name) } hmethod, path := httpAnnos[0].method, httpAnnos[0].path if len(path) == 0 || path[0] == "" { return nil, fmt.Errorf("invalid api.%s for %s.%s: %s", hmethod, s.Name, m.Name, path) } var reqName, reqRawName, reqPackage string if len(m.Arguments) >= 1 { if len(m.Arguments) > 1 { logs.Warnf("function '%s' has more than one argument, but only the first can be used in hertz now", m.GetName()) } var err error reqName, err = resolver.ResolveTypeName(m.Arguments[0].GetType()) if err != nil { return nil, err } if strings.Contains(reqName, ".") && !m.Arguments[0].GetType().Category.IsContainerType() { // If reqName contains "." 
, then it must be of the form "pkg.name". // so reqRawName='name', reqPackage='pkg' names := strings.Split(reqName, ".") if len(names) != 2 { return nil, fmt.Errorf("request name: %s is wrong", reqName) } reqRawName = names[1] reqPackage = names[0] } } var respName, respRawName, respPackage string if !m.Oneway { var err error respName, err = resolver.ResolveTypeName(m.GetFunctionType()) if err != nil { return nil, err } if strings.Contains(respName, ".") && !m.GetFunctionType().Category.IsContainerType() { names := strings.Split(respName, ".") if len(names) != 2 { return nil, fmt.Errorf("response name: %s is wrong", respName) } // If respName contains "." , then it must be of the form "pkg.name". // so respRawName='name', respPackage='pkg' respRawName = names[1] respPackage = names[0] } } sr, _ := util.GetFirstKV(getAnnotations(m.Annotations, SerializerTags)) method := &generator.HttpMethod{ Name: util.CamelString(m.GetName()), HTTPMethod: hmethod, RequestTypeName: reqName, RequestTypeRawName: reqRawName, RequestTypePackage: reqPackage, ReturnTypeName: respName, ReturnTypeRawName: respRawName, ReturnTypePackage: respPackage, Path: path[0], Serializer: sr, OutputDir: handlerOutDir, GenHandler: true, // Annotations: m.Annotations, } refs := resolver.ExportReferred(false, true) method.Models = make(map[string]*model.Model, len(refs)) for _, ref := range refs { if v, ok := method.Models[ref.Model.PackageName]; ok && (v.Package != ref.Model.Package) { return nil, fmt.Errorf("Package name: %s redeclared in %s and %s ", ref.Model.PackageName, v.Package, ref.Model.Package) } method.Models[ref.Model.PackageName] = ref.Model } models.MergeMap(method.Models) methods = append(methods, method) for idx, anno := range httpAnnos { for i := 0; i < len(anno.path); i++ { if idx == 0 && i == 0 { // idx==0 && i==0 has been added above continue } newMethod, err := newHTTPMethod(s, m, method, i, anno) if err != nil { return nil, err } methods = append(methods, newMethod) } } if 
args.CmdType == meta.CmdClient { clientMethod := &generator.ClientMethod{} clientMethod.HttpMethod = method rt, err := resolver.ResolveIdentifier(m.Arguments[0].GetType().GetName()) if err != nil { return nil, err } err = parseAnnotationToClient(clientMethod, m.Arguments[0].GetType(), rt, args.EnableClientOptional) if err != nil { return nil, err } clientMethods = append(clientMethods, clientMethod) } } service.ClientMethods = clientMethods service.Methods = methods service.Models = models out = append(out, service) } return out, nil } func newHTTPMethod(s *parser.Service, m *parser.Function, method *generator.HttpMethod, i int, anno httpAnnotation) (*generator.HttpMethod, error) { newMethod := *method hmethod, path := anno.method, anno.path if path[i] == "" { return nil, fmt.Errorf("invalid api.%s for %s.%s: %s", hmethod, s.Name, m.Name, path[i]) } newMethod.HTTPMethod = hmethod newMethod.Path = path[i] newMethod.GenHandler = false return &newMethod, nil } func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Type, symbol ResolvedSymbol, enableOptional bool) error { if p == nil { return fmt.Errorf("get type failed for parse annotatoon to client") } typeName := p.GetName() if strings.Contains(typeName, ".") { ret := strings.Split(typeName, ".") typeName = ret[len(ret)-1] } scope, err := golang.BuildScope(thriftgoUtil, symbol.Scope) if err != nil { return fmt.Errorf("can not build scope for %s", p.Name) } thriftgoUtil.SetRootScope(scope) st := scope.StructLike(typeName) if st == nil { logs.Infof("the type '%s' for method '%s' is base type, so skip parse client info\n") return nil } var ( hasBodyAnnotation bool hasFormAnnotation bool ) for _, field := range st.Fields() { hasAnnotation := false isStringFieldType := false isOptional := false if field.GetType().String() == "string" { isStringFieldType = true } if field.GetRequiredness() == parser.FieldType_Optional { isOptional = true } if anno := getAnnotation(field.Annotations, AnnotationQuery); 
len(anno) > 0 { hasAnnotation = true query := checkSnakeName(anno[0]) if isOptional && enableOptional { clientMethod.QueryParamsCode += fmt.Sprintf("%q: func() interface{} {\n\t\t\t\tif req.IsSet%s() {\n\t\t\t\t\treturn req.Get%s()\n\t\t\t\t} else {\n\t\t\t\t\treturn nil\n\t\t\t\t}}(),\n", query, field.GoName().String(), field.GoName().String()) } else { clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", query, field.GoName().String()) } } if anno := getAnnotation(field.Annotations, AnnotationPath); len(anno) > 0 { hasAnnotation = true path := anno[0] if isStringFieldType { clientMethod.PathParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", path, field.GoName().String()) } else { clientMethod.PathParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", path, field.GoName().String()) } } if anno := getAnnotation(field.Annotations, AnnotationHeader); len(anno) > 0 { hasAnnotation = true header := anno[0] if isStringFieldType { clientMethod.HeaderParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", header, field.GoName().String()) } else { clientMethod.HeaderParamsCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", header, field.GoName().String()) } } if anno := getAnnotation(field.Annotations, AnnotationForm); len(anno) > 0 { hasAnnotation = true form := checkSnakeName(anno[0]) hasFormAnnotation = true if isStringFieldType { clientMethod.FormValueCode += fmt.Sprintf("%q: req.Get%s(),\n", form, field.GoName().String()) } else { clientMethod.FormValueCode += fmt.Sprintf("%q: fmt.Sprint(req.Get%s()),\n", form, field.GoName().String()) } } if anno := getAnnotation(field.Annotations, AnnotationBody); len(anno) > 0 { hasAnnotation = true hasBodyAnnotation = true } if anno := getAnnotation(field.Annotations, AnnotationFileName); len(anno) > 0 { hasAnnotation = true fileName := anno[0] hasFormAnnotation = true clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", fileName, field.GoName().String()) } if anno := getAnnotation(field.Annotations, 
AnnotationCookie); len(anno) > 0 { hasAnnotation = true // cookie do nothing } if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") { if isOptional && enableOptional { clientMethod.QueryParamsCode += fmt.Sprintf("%q: func() interface{} {\n\t\t\t\tif req.IsSet%s() {\n\t\t\t\t\treturn req.Get%s()\n\t\t\t\t} else {\n\t\t\t\t\treturn nil\n\t\t\t\t}}(),\n", checkSnakeName(field.GetName()), field.GoName().String(), field.GoName().String()) } else { clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(field.GetName()), field.GoName().String()) } } } clientMethod.BodyParamsCode = meta.SetBodyParam if hasBodyAnnotation && hasFormAnnotation { clientMethod.FormValueCode = "" clientMethod.FormFileCode = "" } if !hasBodyAnnotation && hasFormAnnotation { clientMethod.BodyParamsCode = "" } return nil } type extendServiceList []string func (svr extendServiceList) exist(serviceName string) bool { for _, s := range svr { if s == serviceName { return true } } return false } func getExtendServices(ast *parser.Thrift) (res extendServiceList) { for a := range ast.DepthFirstSearch() { for _, svc := range a.Services { if len(svc.Extends) > 0 { res = append(res, svc.Extends) } } } return } func getAllExtendFunction(svc *parser.Service, ast *parser.Thrift, resolver *Resolver, args *config.Argument) (res []*parser.Function, err error) { if len(svc.Extends) == 0 { return } parts := semantic.SplitType(svc.Extends) switch len(parts) { case 1: if resolver.mainPkg.Ast.Filename == ast.Filename { // extended current service for master IDL extendSvc, found := ast.GetService(parts[0]) if found { funcs := extendSvc.GetFunctions() // determine if it still has extends extendFuncs, err := getAllExtendFunction(extendSvc, ast, resolver, args) if err != nil { return nil, err } res = append(res, append(funcs, extendFuncs...)...) 
} return res, nil } else { // extended current service for other IDL extendSvc, found := ast.GetService(parts[0]) if found { base, err := addResolverDependency(resolver, ast, args) if err != nil { return nil, err } funcs := extendSvc.GetFunctions() for _, f := range funcs { processExtendsType(f, base) } extendFuncs, err := getAllExtendFunction(extendSvc, ast, resolver, args) if err != nil { return nil, err } res = append(res, append(funcs, extendFuncs...)...) } return res, nil } case 2: refAst, found := ast.GetReference(parts[0]) base, err := addResolverDependency(resolver, refAst, args) if err != nil { return nil, err } // ff the service extends from other files, it has to resolve the dependencies of other files as well for _, dep := range refAst.Includes { _, err := addResolverDependency(resolver, dep.Reference, args) if err != nil { return nil, err } } if found { extendSvc, found := refAst.GetService(parts[1]) if found { funcs := extendSvc.GetFunctions() for _, f := range funcs { processExtendsType(f, base) } extendFuncs, err := getAllExtendFunction(extendSvc, refAst, resolver, args) if err != nil { return nil, err } res = append(res, append(funcs, extendFuncs...)...) } } return res, nil } return res, nil } func processExtendsType(f *parser.Function, base string) { // the method of other file is extended, and the package of req/resp needs to be changed // ex. base.thrift -> Resp Method(Req){} // base.Resp Method(base.Req){} if len(f.Arguments) > 0 { if f.Arguments[0].Type.Category.IsContainerType() { switch f.Arguments[0].Type.Category { case parser.Category_Set, parser.Category_List: if !strings.Contains(f.Arguments[0].Type.ValueType.Name, ".") && f.Arguments[0].Type.ValueType.Category.IsStruct() { f.Arguments[0].Type.ValueType.Name = base + "." 
+ f.Arguments[0].Type.ValueType.Name } case parser.Category_Map: if !strings.Contains(f.Arguments[0].Type.ValueType.Name, ".") && f.Arguments[0].Type.ValueType.Category.IsStruct() { f.Arguments[0].Type.ValueType.Name = base + "." + f.Arguments[0].Type.ValueType.Name } if !strings.Contains(f.Arguments[0].Type.KeyType.Name, ".") && f.Arguments[0].Type.KeyType.Category.IsStruct() { f.Arguments[0].Type.KeyType.Name = base + "." + f.Arguments[0].Type.KeyType.Name } } } else { if !strings.Contains(f.Arguments[0].Type.Name, ".") && f.Arguments[0].Type.Category.IsStruct() { f.Arguments[0].Type.Name = base + "." + f.Arguments[0].Type.Name } } } if f.FunctionType.Category.IsContainerType() { switch f.FunctionType.Category { case parser.Category_Set, parser.Category_List: if !strings.Contains(f.FunctionType.ValueType.Name, ".") && f.FunctionType.ValueType.Category.IsStruct() { f.FunctionType.ValueType.Name = base + "." + f.FunctionType.ValueType.Name } case parser.Category_Map: if !strings.Contains(f.FunctionType.ValueType.Name, ".") && f.FunctionType.ValueType.Category.IsStruct() { f.FunctionType.ValueType.Name = base + "." + f.FunctionType.ValueType.Name } if !strings.Contains(f.FunctionType.KeyType.Name, ".") && f.FunctionType.KeyType.Category.IsStruct() { f.FunctionType.KeyType.Name = base + "." + f.FunctionType.KeyType.Name } } } else { if !strings.Contains(f.FunctionType.Name, ".") && f.FunctionType.Category.IsStruct() { f.FunctionType.Name = base + "." 
+ f.FunctionType.Name } } } func getUniqueResolveDependentName(name string, resolver *Resolver) string { rawName := name for i := 0; i < 10000; i++ { if _, exist := resolver.deps[name]; !exist { return name } name = rawName + fmt.Sprint(i) } return name } func addResolverDependency(resolver *Resolver, ast *parser.Thrift, args *config.Argument) (string, error) { namespace, err := resolver.LoadOne(ast) if err != nil { return "", err } baseName := util.BaseName(ast.Filename, ".thrift") if refPkg, exist := resolver.refPkgs[baseName]; !exist { resolver.deps[baseName] = namespace } else { if ast.Filename != refPkg.Ast.Filename { baseName = getUniqueResolveDependentName(baseName, resolver) resolver.deps[baseName] = namespace } } pkg := getGoPackage(ast, args.OptPkgMap) impt := ast.Filename pkgName := util.SplitPackageName(pkg, "") pkgName, err = util.GetPackageUniqueName(pkgName) if err != nil { return "", err } ref := &PackageReference{baseName, impt, &model.Model{ FilePath: ast.Filename, Package: pkg, PackageName: pkgName, }, ast, false} if _, exist := resolver.refPkgs[baseName]; !exist { resolver.refPkgs[baseName] = ref } return baseName, nil } /*---------------------------Model-----------------------------*/ var BaseThrift = parser.Thrift{} var baseTypes = map[string]string{ "bool": "bool", "byte": "int8", "i8": "int8", "i16": "int16", "i32": "int32", "i64": "int64", "double": "float64", "string": "string", "binary": "[]byte", } func switchBaseType(typ *parser.Type) *model.Type { switch typ.Name { case "bool": return model.TypeBool case "byte": return model.TypeByte case "i8": return model.TypeInt8 case "i16": return model.TypeInt16 case "i32": return model.TypeInt32 case "i64": return model.TypeInt64 case "int": return model.TypeInt case "double": return model.TypeFloat64 case "string": return model.TypeString case "binary": return model.TypeBinary } return nil } func newBaseType(typ *model.Type, cg model.Category) *model.Type { cyp := *typ cyp.Category = cg return 
&cyp
}

// newStructType returns a model type of kind model.KindStruct with the given
// name and category. No scope is attached and HasNew is set.
func newStructType(name string, cg model.Category) *model.Type {
	return &model.Type{
		Name:     name,
		Scope:    nil,
		Kind:     model.KindStruct,
		Category: cg,
		Indirect: false,
		Extra:    nil,
		HasNew:   true,
	}
}

// newEnumType returns a model type of kind model.KindInt, scoped to
// model.BaseModel, with the given name and category.
func newEnumType(name string, cg model.Category) *model.Type {
	return &model.Type{
		Name:     name,
		Scope:    &model.BaseModel,
		Kind:     model.KindInt,
		Category: cg,
	}
}

// newFuncType returns a model type of kind model.KindFunc with the given
// name and category. No scope is attached and HasNew is false.
func newFuncType(name string, cg model.Category) *model.Type {
	return &model.Type{
		Name:     name,
		Scope:    nil,
		Kind:     model.KindFunc,
		Category: cg,
		Indirect: false,
		Extra:    nil,
		HasNew:   false,
	}
}

// getFieldType maps a thrift type to a model type. A base type (via
// getBaseType, whose error is ignored) takes precedence. Otherwise the type
// name is looked up with resolver.Get, and an "unknown type" error is returned
// when neither succeeds.
func (resolver *Resolver) getFieldType(typ *parser.Type) (*model.Type, error) {
	if dt, _ := resolver.getBaseType(typ); dt != nil {
		return dt, nil
	}
	sb := resolver.Get(typ.Name)
	if sb != nil {
		return sb.Type, nil
	}
	return nil, fmt.Errorf("unknown type: %s", typ.Name)
}

// ResolvedSymbol is a *Symbol together with the name it was resolved as
// (Base) and the prefix Expression uses to qualify it (Src).
type ResolvedSymbol struct {
	Base string
	Src  string
	*Symbol
}

// Expression renders the Go expression that refers to the symbol.
//
// Base is converted with NameStyle. On failure a warning is logged and the raw
// name is used. Base types are mapped through baseTypes. When Src is non-empty
// the result is "Src.base", except that a non-value base type is returned
// unqualified.
func (rs ResolvedSymbol) Expression() string {
	base, err := NameStyle.Identify(rs.Base)
	if err != nil {
		logs.Warnf("%s naming style for %s failed, fall back to %s, please refer to the variable manually!", NameStyle.Name(), rs.Base, rs.Base)
		base = rs.Base
	}
	// base type no need to do name style
	if model.IsBaseType(rs.Type) {
		// base type mapping
		if val, exist := baseTypes[rs.Base]; exist {
			base = val
		}
	}
	if rs.Src != "" {
		if !rs.IsValue && model.IsBaseType(rs.Type) {
			return base
		}
		return fmt.Sprintf("%s.%s", rs.Src, base)
	}
	return base
}

func astToModel(ast *parser.Thrift, rs *Resolver) (*model.Model, error) {
	main := rs.mainPkg.Model
	if main == nil {
		main = new(model.Model)
	}

	// typedefs
	tds := ast.GetTypedefs()
	typdefs := make([]model.TypeDef, 0, len(tds))
	for _, t := range tds {
		td := model.TypeDef{
			Scope: main,
			Alias: t.Alias,
		}
		if bt, err := rs.ResolveType(t.Type); bt == nil || err != nil {
			return nil, fmt.Errorf("%s has no type definition, error: %s", t.String(), err)
		} else {
			td.Type = bt
		}
		typdefs = append(typdefs, td)
	}
	main.Typedefs = typdefs

	// constants
	cts := ast.GetConstants()
	constants
:= make([]model.Constant, 0, len(cts)) variables := make([]model.Variable, 0, len(cts)) for _, c := range cts { ft, err := rs.ResolveType(c.Type) if err != nil { return nil, err } if ft.Name == model.TypeBaseList.Name || ft.Name == model.TypeBaseMap.Name || ft.Name == model.TypeBaseSet.Name { resolveValue, err := rs.ResolveConstantValue(c.Value) if err != nil { return nil, err } vt := model.Variable{ Scope: main, Name: c.Name, Type: ft, Value: resolveValue, } variables = append(variables, vt) } else { resolveValue, err := rs.ResolveConstantValue(c.Value) if err != nil { return nil, err } ct := model.Constant{ Scope: main, Name: c.Name, Type: ft, Value: resolveValue, } constants = append(constants, ct) } } main.Constants = constants main.Variables = variables // Enums ems := ast.GetEnums() enums := make([]model.Enum, 0, len(ems)) for _, e := range ems { em := model.Enum{ Scope: main, Name: e.GetName(), GoType: "int64", } vs := make([]model.Constant, 0, len(e.Values)) for _, ee := range e.Values { vs = append(vs, model.Constant{ Scope: main, Name: ee.Name, Type: model.TypeInt64, Value: model.IntExpression{Src: int(ee.Value)}, }) } em.Values = vs enums = append(enums, em) } main.Enums = enums // Structs sts := make([]*parser.StructLike, 0, len(ast.Structs)) sts = append(sts, ast.Structs...) 
structs := make([]model.Struct, 0, len(ast.Structs)+len(ast.Unions)+len(ast.Exceptions)) for _, st := range sts { s := model.Struct{ Scope: main, Name: st.GetName(), Category: model.CategoryStruct, LeadingComments: removeCommentsSlash(st.GetReservedComments()), } vs := make([]model.Field, 0, len(st.Fields)) for _, f := range st.Fields { fieldName, _ := (&styles.ThriftGo{}).Identify(f.Name) isP, err := isPointer(f, rs) if err != nil { return nil, err } resolveType, err := rs.ResolveType(f.Type) if err != nil { return nil, err } field := model.Field{ Scope: &s, Name: fieldName, Type: resolveType, // IsSetDefault: f.IsSetDefault(), LeadingComments: removeCommentsSlash(f.GetReservedComments()), IsPointer: isP, } err = injectTags(f, &field, true, true) if err != nil { return nil, err } vs = append(vs, field) } checkDuplicatedFileName(vs) s.Fields = vs structs = append(structs, s) } sts = make([]*parser.StructLike, 0, len(ast.Unions)) sts = append(sts, ast.Unions...) for _, st := range sts { s := model.Struct{ Scope: main, Name: st.GetName(), Category: model.CategoryUnion, LeadingComments: removeCommentsSlash(st.GetReservedComments()), } vs := make([]model.Field, 0, len(st.Fields)) for _, f := range st.Fields { fieldName, _ := (&styles.ThriftGo{}).Identify(f.Name) isP, err := isPointer(f, rs) if err != nil { return nil, err } resolveType, err := rs.ResolveType(f.Type) if err != nil { return nil, err } field := model.Field{ Scope: &s, Name: fieldName, Type: resolveType, LeadingComments: removeCommentsSlash(f.GetReservedComments()), IsPointer: isP, } err = injectTags(f, &field, true, true) if err != nil { return nil, err } vs = append(vs, field) } checkDuplicatedFileName(vs) s.Fields = vs structs = append(structs, s) } sts = make([]*parser.StructLike, 0, len(ast.Exceptions)) sts = append(sts, ast.Exceptions...) 
for _, st := range sts { s := model.Struct{ Scope: main, Name: st.GetName(), Category: model.CategoryException, LeadingComments: removeCommentsSlash(st.GetReservedComments()), } vs := make([]model.Field, 0, len(st.Fields)) for _, f := range st.Fields { fieldName, _ := (&styles.ThriftGo{}).Identify(f.Name) isP, err := isPointer(f, rs) if err != nil { return nil, err } resolveType, err := rs.ResolveType(f.Type) if err != nil { return nil, err } field := model.Field{ Scope: &s, Name: fieldName, Type: resolveType, LeadingComments: removeCommentsSlash(f.GetReservedComments()), IsPointer: isP, } err = injectTags(f, &field, true, true) if err != nil { return nil, err } vs = append(vs, field) } checkDuplicatedFileName(vs) s.Fields = vs structs = append(structs, s) } main.Structs = structs // In case of only the service refers another model, therefore scanning service is necessary ss := ast.GetServices() var err error for _, s := range ss { for _, m := range s.GetFunctions() { _, err = rs.ResolveType(m.GetFunctionType()) if err != nil { return nil, err } for _, a := range m.GetArguments() { _, err = rs.ResolveType(a.GetType()) if err != nil { return nil, err } } } } return main, nil } // removeCommentsSlash can remove double slash for comments with thrift func removeCommentsSlash(comments string) string { if comments == "" { return "" } return comments[2:] } func isPointer(f *parser.Field, rs *Resolver) (bool, error) { typ, err := rs.ResolveType(f.GetType()) if err != nil { return false, err } if typ == nil { return false, fmt.Errorf("can not get type: %s for %s", f.GetType(), f.GetName()) } if typ.Kind == model.KindStruct || typ.Kind == model.KindMap || typ.Kind == model.KindSlice { return false, nil } if f.GetRequiredness().IsOptional() { return true, nil } else { return false, nil } } func getNewFieldName(fieldName string, fieldNameSet map[string]bool) string { if _, ex := fieldNameSet[fieldName]; ex { fieldName = fieldName + "_" return getNewFieldName(fieldName, 
fieldNameSet) } return fieldName } func checkDuplicatedFileName(vs []model.Field) { fieldNameSet := make(map[string]bool) for i := 0; i < len(vs); i++ { if _, ex := fieldNameSet[vs[i].Name]; ex { newName := getNewFieldName(vs[i].Name, fieldNameSet) fieldNameSet[newName] = true vs[i].Name = newName } else { fieldNameSet[vs[i].Name] = true } } } ================================================ FILE: cmd/hz/thrift/plugin.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package thrift import ( "encoding/json" "errors" "fmt" "io/ioutil" "os" "path/filepath" "strings" "github.com/cloudwego/hertz/cmd/hz/config" "github.com/cloudwego/hertz/cmd/hz/generator" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/hertz/cmd/hz/util/logs" "github.com/cloudwego/thriftgo/generator/backend" "github.com/cloudwego/thriftgo/generator/golang" "github.com/cloudwego/thriftgo/generator/golang/styles" "github.com/cloudwego/thriftgo/parser" thriftgo_plugin "github.com/cloudwego/thriftgo/plugin" ) type Plugin struct { req *thriftgo_plugin.Request args *config.Argument logger *logs.StdLogger rmTags []string } var debugPlugin = os.Getenv("HERTZ_DEBUG_PLUGIN") != "" func NewPlugin(args *config.Argument, req *thriftgo_plugin.Request) *Plugin { ret := &Plugin{ args: args, req: req, } ret.setLogger() return ret } func (plugin *Plugin) Run() int { plugin.setLogger() args := &config.Argument{} defer func() { if args == nil { return } if args.Verbose { verboseLog := plugin.recvVerboseLogger() if len(verboseLog) != 0 { fmt.Fprintf(os.Stderr, verboseLog) } } else { warning := plugin.recvWarningLogger() if len(warning) != 0 { fmt.Fprintf(os.Stderr, warning) } } }() in, err := plugin.handleRequest() if err != nil { logs.Errorf("handle request failed: %s", err.Error()) return meta.PluginError } args, err = plugin.parseArgs() if err != nil { logs.Errorf("parse args failed: %s", err.Error()) return meta.PluginError } if debugPlugin { os.WriteFile("./req.tf", in, 0644) js, err := json.Marshal(args) if err != nil { logs.Errorf("marshal request failed: %s\n", err.Error()) return meta.PluginError } os.WriteFile("./args.json", js, 0644) } res, err := plugin.Handle(args) if err != nil { logs.Errorf("handle failed: %s", err.Error()) return meta.PluginError } if res != nil { if err = plugin.response(res); err != nil { logs.Errorf("response failed: %s", err.Error()) return 
meta.PluginError } } return 0 } func (plugin *Plugin) Handle(args *config.Argument) (*thriftgo_plugin.Response, error) { plugin.rmTags = args.RmTags if args.CmdType == meta.CmdModel { // check tag options for model mode CheckTagOption(plugin.args) res, err := plugin.GetResponse(nil, args.OutDir) if err != nil { logs.Errorf("get response failed: %s", err.Error()) return nil, err } if err := plugin.response(res); err != nil { logs.Errorf("response failed: %s", err.Error()) return nil, err } return nil, nil } err := plugin.initNameStyle() if err != nil { logs.Errorf("init naming style failed: %s", err.Error()) return nil, err } options := CheckTagOption(plugin.args) pkgInfo, err := plugin.getPackageInfo() if err != nil { logs.Errorf("get http package info failed: %s", err.Error()) return nil, err } customPackageTemplate := args.CustomizePackage pkg, err := args.GetGoPackage() if err != nil { logs.Errorf("get go package failed: %s", err.Error()) return nil, err } handlerDir, err := args.GetHandlerDir() if err != nil { logs.Errorf("get handler dir failed: %s", err.Error()) return nil, err } routerDir, err := args.GetRouterDir() if err != nil { logs.Errorf("get router dir failed: %s", err.Error()) return nil, err } modelDir, err := args.GetModelDir() if err != nil { logs.Errorf("get model dir failed: %s", err.Error()) return nil, err } clientDir, err := args.GetClientDir() if err != nil { logs.Errorf("get client dir failed: %s", err.Error()) return nil, err } sg := generator.HttpPackageGenerator{ ConfigPath: customPackageTemplate, HandlerDir: handlerDir, RouterDir: routerDir, ModelDir: modelDir, UseDir: args.Use, ClientDir: clientDir, TemplateGenerator: generator.TemplateGenerator{ OutputDir: args.OutDir, Excludes: args.Excludes, }, ProjPackage: pkg, Options: options, HandlerByMethod: args.HandlerByMethod, CmdType: args.CmdType, IdlClientDir: util.SubDir(modelDir, pkgInfo.Package), ForceClientDir: args.ForceClientDir, BaseDomain: args.BaseDomain, QueryEnumAsInt: 
args.QueryEnumAsInt, SnakeStyleMiddleware: args.SnakeStyleMiddleware, SortRouter: args.SortRouter, ForceUpdateClient: args.ForceUpdateClient, } if args.ModelBackend != "" { sg.Backend = meta.Backend(args.ModelBackend) } generator.SetDefaultTemplateConfig() err = sg.Generate(pkgInfo) if err != nil { logs.Errorf("generate package failed: %s", err.Error()) return nil, err } if len(args.Use) != 0 { err = sg.Persist() if err != nil { logs.Errorf("persist file failed within '-use' option: %s", err.Error()) return nil, err } res := thriftgo_plugin.BuildErrorResponse(errors.New(meta.TheUseOptionMessage).Error()) err = plugin.response(res) if err != nil { logs.Errorf("response failed: %s", err.Error()) return nil, err } return nil, nil } files, err := sg.GetFormatAndExcludedFiles() if err != nil { logs.Errorf("format file failed: %s", err.Error()) return nil, err } res, err := plugin.GetResponse(files, sg.OutputDir) if err != nil { logs.Errorf("get response failed: %s", err.Error()) return nil, err } return res, nil } func (plugin *Plugin) setLogger() { plugin.logger = logs.NewStdLogger(logs.LevelInfo) plugin.logger.Defer = true plugin.logger.ErrOnly = true logs.SetLogger(plugin.logger) } func (plugin *Plugin) recvWarningLogger() string { warns := plugin.logger.Warn() plugin.logger.Flush() logs.SetLogger(logs.NewStdLogger(logs.LevelInfo)) return warns } func (plugin *Plugin) recvVerboseLogger() string { info := plugin.logger.Out() warns := plugin.logger.Warn() verboseLog := string(info) + warns plugin.logger.Flush() logs.SetLogger(logs.NewStdLogger(logs.LevelInfo)) return verboseLog } func (plugin *Plugin) handleRequest() ([]byte, error) { data, err := ioutil.ReadAll(os.Stdin) if err != nil { return nil, fmt.Errorf("read request failed: %s", err.Error()) } req, err := thriftgo_plugin.UnmarshalRequest(data) if err != nil { return data, fmt.Errorf("unmarshal request failed: %s", err.Error()) } plugin.req = req // init thriftgo utils thriftgoUtil = 
golang.NewCodeUtils(backend.DummyLogFunc()) thriftgoUtil.HandleOptions(req.GeneratorParameters) return data, nil } func (plugin *Plugin) parseArgs() (*config.Argument, error) { if plugin.req == nil { return nil, fmt.Errorf("request is nil") } args := new(config.Argument) err := args.Unpack(plugin.req.PluginParameters) if err != nil { logs.Errorf("unpack args failed: %s", err.Error()) } plugin.args = args return args, nil } // initNameStyle initializes the naming style based on the "naming_style" option for thrift. func (plugin *Plugin) initNameStyle() error { if len(plugin.args.ThriftOptions) == 0 { return nil } for _, opt := range plugin.args.ThriftOptions { parts := strings.SplitN(opt, "=", 2) if len(parts) == 2 && parts[0] == "naming_style" { NameStyle = styles.NewNamingStyle(parts[1]) if NameStyle == nil { return fmt.Errorf("do not support \"%s\" naming style", parts[1]) } break } } return nil } func (plugin *Plugin) getPackageInfo() (*generator.HttpPackage, error) { req := plugin.req args := plugin.args ast := req.GetAST() if ast == nil { return nil, fmt.Errorf("no ast") } logs.Infof("Processing %s", ast.GetFilename()) pkgMap := args.OptPkgMap pkg := getGoPackage(ast, pkgMap) main := &model.Model{ FilePath: ast.Filename, Package: pkg, PackageName: util.SplitPackageName(pkg, ""), } rs, err := NewResolver(ast, main, pkgMap) if err != nil { return nil, fmt.Errorf("new thrift resolver failed, err:%v", err) } err = rs.LoadAll(ast) if err != nil { return nil, err } idlPackage := getGoPackage(ast, pkgMap) if idlPackage == "" { return nil, fmt.Errorf("go package for '%s' is not defined", ast.GetFilename()) } services, err := astToService(ast, rs, args) if err != nil { return nil, err } var models model.Models for _, s := range services { models.MergeArray(s.Models) } return &generator.HttpPackage{ Services: services, IdlName: ast.GetFilename(), Package: idlPackage, Models: models, }, nil } func (plugin *Plugin) response(res *thriftgo_plugin.Response) error { data, err 
:= thriftgo_plugin.MarshalResponse(res) if err != nil { return fmt.Errorf("marshal response failed: %s", err.Error()) } _, err = os.Stdout.Write(data) if err != nil { return fmt.Errorf("write response failed: %s", err.Error()) } return nil } func (plugin *Plugin) InsertTag() ([]*thriftgo_plugin.Generated, error) { var res []*thriftgo_plugin.Generated if plugin.args.NoRecurse { outPath := plugin.req.OutputPath packageName := getGoPackage(plugin.req.AST, nil) fileName := util.BaseNameAndTrim(plugin.req.AST.GetFilename()) + ".go" outPath = filepath.Join(outPath, packageName, fileName) for _, st := range plugin.req.AST.Structs { stName := st.GetName() for _, f := range st.Fields { fieldName := f.GetName() tagString, err := getTagString(f, plugin.rmTags) if err != nil { return nil, err } insertPointer := "struct." + stName + "." + fieldName + "." + "tag" gen := &thriftgo_plugin.Generated{ Content: tagString, Name: &outPath, InsertionPoint: &insertPointer, } res = append(res, gen) } } return res, nil } for ast := range plugin.req.AST.DepthFirstSearch() { outPath := plugin.req.OutputPath packageName := getGoPackage(ast, nil) fileName := util.BaseNameAndTrim(ast.GetFilename()) + ".go" outPath = filepath.Join(outPath, packageName, fileName) for _, st := range ast.Structs { stName := st.GetName() for _, f := range st.Fields { fieldName := f.GetName() tagString, err := getTagString(f, plugin.rmTags) if err != nil { return nil, err } insertPointer := "struct." + stName + "." + fieldName + "." 
+ "tag" gen := &thriftgo_plugin.Generated{ Content: tagString, Name: &outPath, InsertionPoint: &insertPointer, } res = append(res, gen) } } } return res, nil } func (plugin *Plugin) GetResponse(files []generator.File, outputDir string) (*thriftgo_plugin.Response, error) { var contents []*thriftgo_plugin.Generated for _, file := range files { filePath := filepath.Join(outputDir, file.Path) content := &thriftgo_plugin.Generated{ Content: file.Content, Name: &filePath, } contents = append(contents, content) } insertTag, err := plugin.InsertTag() if err != nil { return nil, err } contents = append(contents, insertTag...) return &thriftgo_plugin.Response{ Contents: contents, }, nil } func getTagString(f *parser.Field, rmTags []string) (string, error) { field := model.Field{} err := injectTags(f, &field, true, false) if err != nil { return "", err } disableTag := false if v := getAnnotation(f.Annotations, AnnotationNone); len(v) > 0 { if strings.EqualFold(v[0], "true") { disableTag = true } } for _, rmTag := range rmTags { for _, t := range field.Tags { if t.IsDefault && strings.EqualFold(t.Key, rmTag) { field.Tags.Remove(t.Key) } } } var tagString string tags := field.Tags for idx, tag := range tags { value := tag.Value if disableTag { value = "-" } if idx == 0 { tagString += " " + tag.Key + ":\"" + value + "\"" + " " } else if idx == len(tags)-1 { tagString += tag.Key + ":\"" + value + "\"" } else { tagString += tag.Key + ":\"" + value + "\"" + " " } } return tagString, nil } ================================================ FILE: cmd/hz/thrift/plugin_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package thrift import ( "io/ioutil" "testing" "github.com/cloudwego/hertz/cmd/hz/generator" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/thriftgo/plugin" ) func TestRun(t *testing.T) { data, err := ioutil.ReadFile("../testdata/request_thrift.out") if err != nil { t.Fatal(err) } req, err := plugin.UnmarshalRequest(data) if err != nil { t.Fatal(err) } plu := new(Plugin) plu.setLogger() plu.req = req _, err = plu.parseArgs() if err != nil { t.Fatal(err) } options := CheckTagOption(plu.args) pkgInfo, err := plu.getPackageInfo() if err != nil { t.Fatal(err) } args := plu.args customPackageTemplate := args.CustomizePackage pkg, err := args.GetGoPackage() if err != nil { t.Fatal(err) } handlerDir, err := args.GetHandlerDir() if err != nil { t.Fatal(err) } routerDir, err := args.GetRouterDir() if err != nil { t.Fatal(err) } modelDir, err := args.GetModelDir() if err != nil { t.Fatal(err) } clientDir, err := args.GetClientDir() if err != nil { t.Fatal(err) } sg := generator.HttpPackageGenerator{ ConfigPath: customPackageTemplate, HandlerDir: handlerDir, RouterDir: routerDir, ModelDir: modelDir, UseDir: args.Use, ClientDir: clientDir, TemplateGenerator: generator.TemplateGenerator{ OutputDir: args.OutDir, Excludes: args.Excludes, }, ProjPackage: pkg, Options: options, HandlerByMethod: args.HandlerByMethod, CmdType: args.CmdType, ForceClientDir: args.ForceClientDir, BaseDomain: args.BaseDomain, QueryEnumAsInt: args.QueryEnumAsInt, SnakeStyleMiddleware: args.SnakeStyleMiddleware, SortRouter: args.SortRouter, } if args.ModelBackend != "" { sg.Backend = 
meta.Backend(args.ModelBackend) } err = sg.Generate(pkgInfo) if err != nil { t.Fatalf("generate package failed: %v", err) } files, err := sg.GetFormatAndExcludedFiles() if err != nil { return } res, err := plu.GetResponse(files, sg.OutputDir) if err != nil { return } plu.response(res) if err != nil { return } } ================================================ FILE: cmd/hz/thrift/resolver.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package thrift import ( "fmt" "strings" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/thriftgo/parser" ) var ( ConstTrue = Symbol{ IsValue: true, Type: model.TypeBool, Value: true, Scope: &BaseThrift, } ConstFalse = Symbol{ IsValue: true, Type: model.TypeBool, Value: false, Scope: &BaseThrift, } ConstEmptyString = Symbol{ IsValue: true, Type: model.TypeString, Value: "", Scope: &BaseThrift, } ) type PackageReference struct { IncludeBase string IncludePath string Model *model.Model Ast *parser.Thrift Referred bool } func getReferPkgMap(pkgMap map[string]string, incs []*parser.Include, mainModel *model.Model) (map[string]*PackageReference, error) { var err error out := make(map[string]*PackageReference, len(pkgMap)) pkgAliasMap := make(map[string]string, len(incs)) // bugfix: add main package to avoid namespace conflict mainPkg := mainModel.Package mainPkgName := mainModel.PackageName mainPkgName, err = util.GetPackageUniqueName(mainPkgName) if err != nil { return nil, err } pkgAliasMap[mainPkg] = mainPkgName for _, inc := range incs { pkg := getGoPackage(inc.Reference, pkgMap) impt := inc.GetPath() base := util.BaseNameAndTrim(impt) pkgName := util.SplitPackageName(pkg, "") if pn, exist := pkgAliasMap[pkg]; exist { pkgName = pn } else { pkgName, err = util.GetPackageUniqueName(pkgName) pkgAliasMap[pkg] = pkgName if err != nil { return nil, fmt.Errorf("get package unique name failed, err: %v", err) } } out[base] = &PackageReference{base, impt, &model.Model{ FilePath: inc.Path, Package: pkg, PackageName: pkgName, }, inc.Reference, false} } return out, nil } type Symbol struct { IsValue bool Type *model.Type Value interface{} Scope *parser.Thrift } type NameSpace map[string]*Symbol type Resolver struct { // idl symbols root NameSpace deps map[string]NameSpace // exported models mainPkg PackageReference refPkgs map[string]*PackageReference } func NewResolver(ast *parser.Thrift, model *model.Model, 
pkgMap map[string]string) (*Resolver, error) { pm, err := getReferPkgMap(pkgMap, ast.GetIncludes(), model) if err != nil { return nil, fmt.Errorf("get package map failed, err: %v", err) } file := ast.GetFilename() return &Resolver{ root: make(NameSpace), deps: make(map[string]NameSpace), refPkgs: pm, mainPkg: PackageReference{ IncludeBase: util.BaseNameAndTrim(file), IncludePath: ast.GetFilename(), Model: model, Ast: ast, Referred: false, }, }, nil } func (resolver *Resolver) GetRefModel(includeBase string) (*model.Model, error) { if includeBase == "" { return resolver.mainPkg.Model, nil } ref, ok := resolver.refPkgs[includeBase] if !ok { return nil, fmt.Errorf("not found include %s", includeBase) } return ref.Model, nil } func (resolver *Resolver) getBaseType(typ *parser.Type) (*model.Type, bool) { tt := switchBaseType(typ) if tt != nil { return tt, true } if typ.Name == "map" { t := *model.TypeBaseMap return &t, false } if typ.Name == "list" { t := *model.TypeBaseList return &t, false } if typ.Name == "set" { t := *model.TypeBaseList return &t, false } return nil, false } func (resolver *Resolver) ResolveType(typ *parser.Type) (*model.Type, error) { bt, base := resolver.getBaseType(typ) if bt != nil { if base { return bt, nil } else { if typ.Name == model.TypeBaseMap.Name { resolveKey, err := resolver.ResolveType(typ.KeyType) if err != nil { return nil, err } resolveValue, err := resolver.ResolveType(typ.ValueType) if err != nil { return nil, err } bt.Extra = append(bt.Extra, resolveKey, resolveValue) } else if typ.Name == model.TypeBaseList.Name || typ.Name == model.TypeBaseSet.Name { resolveValue, err := resolver.ResolveType(typ.ValueType) if err != nil { return nil, err } bt.Extra = append(bt.Extra, resolveValue) } else { return nil, fmt.Errorf("invalid DefinitionType(%+v)", bt) } return bt, nil } } id := typ.GetName() rs, err := resolver.ResolveIdentifier(id) if err != nil { return nil, err } sb := rs.Symbol if sb == nil { return nil, fmt.Errorf("not found 
identifier %s", id) } return sb.Type, nil } func (resolver *Resolver) ResolveConstantValue(constant *parser.ConstValue) (model.Literal, error) { switch constant.Type { case parser.ConstType_ConstInt: return model.IntExpression{Src: int(constant.TypedValue.GetInt())}, nil case parser.ConstType_ConstDouble: return model.DoubleExpression{Src: constant.TypedValue.GetDouble()}, nil case parser.ConstType_ConstLiteral: return model.StringExpression{Src: constant.TypedValue.GetLiteral()}, nil case parser.ConstType_ConstList: eleType, err := switchConstantType(constant.Type) if err != nil { return nil, err } ret := model.ListExpression{ ElementType: eleType, } for _, i := range constant.TypedValue.List { elem, err := resolver.ResolveConstantValue(i) if err != nil { return nil, err } ret.Elements = append(ret.Elements, elem) } return ret, nil case parser.ConstType_ConstMap: keyType, err := switchConstantType(constant.TypedValue.Map[0].Key.Type) if err != nil { return nil, err } valueType, err := switchConstantType(constant.TypedValue.Map[0].Value.Type) if err != nil { return nil, err } ret := model.MapExpression{ KeyType: keyType, ValueType: valueType, Elements: make(map[string]model.Literal, len(constant.TypedValue.Map)), } for _, v := range constant.TypedValue.Map { value, err := resolver.ResolveConstantValue(v.Value) if err != nil { return nil, err } ret.Elements[v.Key.String()] = value } return ret, nil case parser.ConstType_ConstIdentifier: return resolver.ResolveIdentifier(*constant.TypedValue.Identifier) } return model.StringExpression{Src: constant.String()}, nil } func (resolver *Resolver) ResolveIdentifier(id string) (ret ResolvedSymbol, err error) { sb := resolver.Get(id) if sb == nil { return ResolvedSymbol{}, fmt.Errorf("identifier '%s' not found", id) } ret.Symbol = sb ret.Base = id if sb.Scope == &BaseThrift { return } if sb.Scope == resolver.mainPkg.Ast { resolver.mainPkg.Referred = true ret.Src = resolver.mainPkg.Model.PackageName return } idx := 
strings.LastIndex(id, ".") depName := id[:idx] typeName := id[idx+1:] if ref, ok := resolver.refPkgs[depName]; ok { ref.Referred = true ret.Base = typeName ret.Src = ref.Model.PackageName if ret.Type == nil { ret.Type = &model.Type{} } ret.Type.Scope = ref.Model } else { return ResolvedSymbol{}, fmt.Errorf("can't resolve identifier '%s'", id) } return } func (resolver *Resolver) ResolveTypeName(typ *parser.Type) (string, error) { if typ.GetIsTypedef() { rt, err := resolver.ResolveIdentifier(typ.GetName()) if err != nil { return "", err } return rt.Expression(), nil } switch typ.GetCategory() { case parser.Category_Map: keyType, err := resolver.ResolveTypeName(typ.GetKeyType()) if err != nil { return "", err } if typ.GetKeyType().GetCategory().IsStruct() { keyType = "*" + keyType } valueType, err := resolver.ResolveTypeName(typ.GetValueType()) if err != nil { return "", err } if typ.GetValueType().GetCategory().IsStruct() { valueType = "*" + valueType } return fmt.Sprintf("map[%s]%s", keyType, valueType), nil case parser.Category_List, parser.Category_Set: // list/set -> []element for thriftgo // valueType refers the element type for list/set elemType, err := resolver.ResolveTypeName(typ.GetValueType()) if err != nil { return "", err } if typ.GetValueType().GetCategory().IsStruct() { elemType = "*" + elemType } return fmt.Sprintf("[]%s", elemType), err } rt, err := resolver.ResolveIdentifier(typ.GetName()) if err != nil { return "", err } return rt.Expression(), nil } func (resolver *Resolver) Get(name string) *Symbol { s, ok := resolver.root[name] if ok { return s } if strings.Contains(name, ".") { idx := strings.LastIndex(name, ".") depName := name[:idx] typeName := name[idx+1:] if ref, ok := resolver.deps[depName]; ok { if ss, ok := ref[typeName]; ok { return ss } } } return nil } func (resolver *Resolver) ExportReferred(all, needMain bool) (ret []*PackageReference) { for _, v := range resolver.refPkgs { if all { ret = append(ret, v) v.Referred = false } else if 
v.Referred { ret = append(ret, v) v.Referred = false } } if needMain && (all || resolver.mainPkg.Referred) { ret = append(ret, &resolver.mainPkg) } resolver.mainPkg.Referred = false return } func (resolver *Resolver) LoadAll(ast *parser.Thrift) error { var err error resolver.root, err = resolver.LoadOne(ast) if err != nil { return fmt.Errorf("load root package: %s", err) } includes := ast.GetIncludes() astMap := make(map[string]NameSpace, len(includes)) for _, dep := range includes { bName := util.BaseName(dep.Path, ".thrift") astMap[bName], err = resolver.LoadOne(dep.Reference) if err != nil { return fmt.Errorf("load idl %s: %s", dep.Path, err) } } resolver.deps = astMap for _, td := range ast.Typedefs { name := td.GetAlias() if _, ex := resolver.root[name]; ex { if resolver.root[name].Type != nil { typ := newTypedefType(resolver.root[name].Type, name) resolver.root[name].Type = &typ continue } } sym := resolver.Get(td.Type.GetName()) typ := newTypedefType(sym.Type, name) resolver.root[name].Type = &typ } return nil } func LoadBaseIdentifier() NameSpace { ret := make(NameSpace, 16) ret["true"] = &ConstTrue ret["false"] = &ConstFalse ret[`""`] = &ConstEmptyString ret["bool"] = &Symbol{ Type: model.TypeBool, Scope: &BaseThrift, } ret["byte"] = &Symbol{ Type: model.TypeByte, Scope: &BaseThrift, } ret["i8"] = &Symbol{ Type: model.TypeInt8, Scope: &BaseThrift, } ret["i16"] = &Symbol{ Type: model.TypeInt16, Scope: &BaseThrift, } ret["i32"] = &Symbol{ Type: model.TypeInt32, Scope: &BaseThrift, } ret["i64"] = &Symbol{ Type: model.TypeInt64, Scope: &BaseThrift, } ret["int"] = &Symbol{ Type: model.TypeInt, Scope: &BaseThrift, } ret["double"] = &Symbol{ Type: model.TypeFloat64, Scope: &BaseThrift, } ret["string"] = &Symbol{ Type: model.TypeString, Scope: &BaseThrift, } ret["binary"] = &Symbol{ Type: model.TypeBinary, Scope: &BaseThrift, } ret["list"] = &Symbol{ Type: model.TypeBaseList, Scope: &BaseThrift, } ret["set"] = &Symbol{ Type: model.TypeBaseSet, Scope: &BaseThrift, 
} ret["map"] = &Symbol{ Type: model.TypeBaseMap, Scope: &BaseThrift, } return ret } func (resolver *Resolver) LoadOne(ast *parser.Thrift) (NameSpace, error) { ret := LoadBaseIdentifier() for _, e := range ast.Enums { prefix := e.GetName() ret[prefix] = &Symbol{ IsValue: false, Value: e, Scope: ast, Type: newEnumType(prefix, model.CategoryEnum), } for _, ee := range e.Values { name := prefix + "." + ee.GetName() if _, exist := ret[name]; exist { return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) } ret[name] = &Symbol{ IsValue: true, Value: ee, Scope: ast, Type: newBaseType(model.TypeInt, model.CategoryEnum), } } } for _, e := range ast.Constants { name := e.GetName() if _, exist := ret[name]; exist { return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) } gt, _ := resolver.getBaseType(e.Type) ret[name] = &Symbol{ IsValue: true, Value: e, Scope: ast, Type: gt, } } for _, e := range ast.Structs { name := e.GetName() if _, exist := ret[name]; exist { return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) } ret[name] = &Symbol{ IsValue: false, Value: e, Scope: ast, Type: newStructType(name, model.CategoryStruct), } } for _, e := range ast.Unions { name := e.GetName() if _, exist := ret[name]; exist { return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) } ret[name] = &Symbol{ IsValue: false, Value: e, Scope: ast, Type: newStructType(name, model.CategoryStruct), } } for _, e := range ast.Exceptions { name := e.GetName() if _, exist := ret[name]; exist { return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) } ret[name] = &Symbol{ IsValue: false, Value: e, Scope: ast, Type: newStructType(name, model.CategoryStruct), } } for _, e := range ast.Services { name := e.GetName() if _, exist := ret[name]; exist { return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) } ret[name] = &Symbol{ IsValue: false, Value: e, Scope: ast, Type: 
newFuncType(name, model.CategoryService), } } for _, td := range ast.Typedefs { name := td.GetAlias() if _, exist := ret[name]; exist { return nil, fmt.Errorf("duplicated identifier '%s' in %s", name, ast.Filename) } gt, _ := resolver.getBaseType(td.Type) if gt == nil { sym := ret[td.Type.Name] if sym != nil { gt = sym.Type } } ret[name] = &Symbol{ IsValue: false, Value: td, Scope: ast, Type: gt, } } return ret, nil } func switchConstantType(constant parser.ConstType) (*model.Type, error) { switch constant { case parser.ConstType_ConstInt: return model.TypeInt, nil case parser.ConstType_ConstDouble: return model.TypeFloat64, nil case parser.ConstType_ConstLiteral: return model.TypeString, nil default: return nil, fmt.Errorf("unknown constant type %d", constant) } } func newTypedefType(t *model.Type, name string) model.Type { tmp := t typ := *tmp typ.Name = name typ.Category = model.CategoryTypedef return typ } ================================================ FILE: cmd/hz/thrift/tag_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package thrift import ( "io/ioutil" "strings" "testing" "github.com/cloudwego/hertz/cmd/hz/config" "github.com/cloudwego/thriftgo/plugin" ) func TestInsertTag(t *testing.T) { data, err := ioutil.ReadFile("./test_data/thrift_tag_test.out") if err != nil { t.Fatal(err) } req, err := plugin.UnmarshalRequest(data) if err != nil { t.Fatal(err) } plu := new(Plugin) plu.req = req plu.args = new(config.Argument) type TagStruct struct { Annotation string GeneratedTag string ActualTag string } tagList := []TagStruct{ { Annotation: "query", GeneratedTag: "json:\"DefaultQueryTag\" query:\"query\"", }, { Annotation: "raw_body", GeneratedTag: "json:\"RawBodyTag\" raw_body:\"raw_body\"", }, { Annotation: "path", GeneratedTag: "json:\"PathTag\" path:\"path\"", }, { Annotation: "form", GeneratedTag: "form:\"form\" json:\"FormTag\"", }, { Annotation: "cookie", GeneratedTag: "cookie:\"cookie\" json:\"CookieTag\"", }, { Annotation: "header", GeneratedTag: "header:\"header\" json:\"HeaderTag\"", }, { Annotation: "body", GeneratedTag: "form:\"body\" json:\"body\"", }, { Annotation: "go.tag", GeneratedTag: "", }, { Annotation: "vd", GeneratedTag: "form:\"VdTag\" json:\"VdTag\" query:\"VdTag\" vd:\"$!='?'\"", }, { Annotation: "non", GeneratedTag: "form:\"DefaultTag\" json:\"DefaultTag\" query:\"DefaultTag\"", }, { Annotation: "query required", GeneratedTag: "json:\"ReqQuery,required\" query:\"query,required\"", }, { Annotation: "query optional", GeneratedTag: "json:\"OptQuery,omitempty\" query:\"query\"", }, { Annotation: "body required", GeneratedTag: "form:\"body,required\" json:\"body,required\"", }, { Annotation: "body optional", GeneratedTag: "form:\"body\" json:\"body,omitempty\"", }, { Annotation: "go.tag required", GeneratedTag: "form:\"ReqGoTag,required\" query:\"ReqGoTag,required\"", }, { Annotation: "go.tag optional", GeneratedTag: "form:\"OptGoTag\" query:\"OptGoTag\"", }, { Annotation: "go tag cover query", GeneratedTag: "form:\"QueryGoTag,required\" 
json:\"QueryGoTag,required\"", }, } tags, err := plu.InsertTag() if err != nil { t.Fatal(err) } for i, tag := range tags { tagList[i].ActualTag = tag.Content if !strings.Contains(tagList[i].ActualTag, tagList[i].GeneratedTag) { t.Fatalf("expected tag: '%s', but autual tag: '%s'", tagList[i].GeneratedTag, tagList[i].ActualTag) } } } ================================================ FILE: cmd/hz/thrift/tags.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package thrift import ( "fmt" "sort" "strconv" "strings" "github.com/cloudwego/hertz/cmd/hz/config" "github.com/cloudwego/hertz/cmd/hz/generator" "github.com/cloudwego/hertz/cmd/hz/generator/model" "github.com/cloudwego/hertz/cmd/hz/util" "github.com/cloudwego/thriftgo/parser" ) const ( AnnotationQuery = "api.query" AnnotationForm = "api.form" AnnotationPath = "api.path" AnnotationHeader = "api.header" AnnotationCookie = "api.cookie" AnnotationBody = "api.body" AnnotationRawBody = "api.raw_body" AnnotationJsConv = "api.js_conv" AnnotationNone = "api.none" AnnotationFileName = "api.file_name" AnnotationValidator = "api.vd" AnnotationGoTag = "go.tag" ) const ( ApiGet = "api.get" ApiPost = "api.post" ApiPut = "api.put" ApiPatch = "api.patch" ApiDelete = "api.delete" ApiOptions = "api.options" ApiHEAD = "api.head" ApiAny = "api.any" ApiPath = "api.path" ApiSerializer = "api.serializer" ApiGenPath = "api.handler_path" ) const ( ApiBaseDomain = "api.base_domain" ApiServiceGroup = "api.service_group" ApiServiceGenDir = "api.service_gen_dir" // handler_dir for handler_by_service ApiServicePath = "api.service_path" // declare the path to the service's handler according to this annotation for handler_by_method ) var ( HttpMethodAnnotations = map[string]string{ ApiGet: "GET", ApiPost: "POST", ApiPut: "PUT", ApiPatch: "PATCH", ApiDelete: "DELETE", ApiOptions: "OPTIONS", ApiHEAD: "HEAD", ApiAny: "ANY", } HttpMethodOptionAnnotations = map[string]string{ ApiGenPath: "handler_path", } BindingTags = map[string]string{ AnnotationPath: "path", AnnotationQuery: "query", AnnotationHeader: "header", AnnotationCookie: "cookie", AnnotationBody: "json", AnnotationForm: "form", AnnotationRawBody: "raw_body", } SerializerTags = map[string]string{ ApiSerializer: "serializer", } ValidatorTags = map[string]string{AnnotationValidator: "vd"} ) var ( jsonSnakeName = false unsetOmitempty = false ) func CheckTagOption(args *config.Argument) []generator.Option { var ret []generator.Option if args 
== nil { return ret } if args.SnakeName { jsonSnakeName = true } if args.UnsetOmitempty { unsetOmitempty = true } if args.JSONEnumStr { ret = append(ret, generator.OptionMarshalEnumToText) } return ret } func checkSnakeName(name string) string { if jsonSnakeName { name = util.ToSnakeCase(name) } return name } func getAnnotation(input parser.Annotations, target string) []string { if len(input) == 0 { return nil } for _, anno := range input { if strings.ToLower(anno.Key) == target { return anno.Values } } return []string{} } type httpAnnotation struct { method string path []string } type httpAnnotations []httpAnnotation func (s httpAnnotations) Len() int { return len(s) } func (s httpAnnotations) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func (s httpAnnotations) Less(i, j int) bool { return s[i].method < s[j].method } func getAnnotations(input parser.Annotations, targets map[string]string) map[string][]string { if len(input) == 0 || len(targets) == 0 { return nil } out := map[string][]string{} for k, t := range targets { var ret *parser.Annotation for _, anno := range input { if strings.ToLower(anno.Key) == k { ret = anno break } } if ret == nil { continue } out[t] = ret.Values } return out } func defaultBindingTags(f *parser.Field) []model.Tag { out := make([]model.Tag, 3) bindingTags := []string{ AnnotationQuery, AnnotationForm, AnnotationPath, AnnotationHeader, AnnotationCookie, AnnotationBody, AnnotationRawBody, } for _, tag := range bindingTags { if v := getAnnotation(f.Annotations, tag); len(v) > 0 { out[0] = jsonTag(f) return out[:1] } } if v := getAnnotation(f.Annotations, AnnotationBody); len(v) > 0 { val := getJsonValue(f, v[0]) out[0] = tag("json", val) } else { t := jsonTag(f) t.IsDefault = true out[0] = t } if v := getAnnotation(f.Annotations, AnnotationQuery); len(v) > 0 { val := checkRequire(f, v[0]) out[1] = tag(BindingTags[AnnotationQuery], val) } else { val := checkRequire(f, checkSnakeName(f.Name)) t := tag(BindingTags[AnnotationQuery], val) 
t.IsDefault = true out[1] = t } if v := getAnnotation(f.Annotations, AnnotationForm); len(v) > 0 { val := checkRequire(f, v[0]) out[2] = tag(BindingTags[AnnotationForm], val) } else { val := checkRequire(f, checkSnakeName(f.Name)) t := tag(BindingTags[AnnotationForm], val) t.IsDefault = true out[2] = t } return out } func jsonTag(f *parser.Field) (ret model.Tag) { ret.Key = "json" ret.Value = checkSnakeName(f.Name) if v := getAnnotation(f.Annotations, AnnotationJsConv); len(v) > 0 { ret.Value += ",string" } if !unsetOmitempty && f.Requiredness == parser.FieldType_Optional { ret.Value += ",omitempty" } else if f.Requiredness == parser.FieldType_Required { ret.Value += ",required" } return } func tag(k, v string) model.Tag { return model.Tag{ Key: k, Value: v, } } func annotationToTags(as parser.Annotations, targets map[string]string) (tags []model.Tag) { rets := getAnnotations(as, targets) for k, v := range rets { for _, vv := range v { tags = append(tags, model.Tag{ Key: k, Value: vv, }) } } return } func injectTags(f *parser.Field, gf *model.Field, needDefault, needGoTag bool) error { as := f.Annotations if as == nil { as = parser.Annotations{} } tags := gf.Tags if tags == nil { tags = make([]model.Tag, 0, len(as)) } if needDefault { tags = append(tags, defaultBindingTags(f)...) } // binding tags bts := annotationToTags(as, BindingTags) for _, t := range bts { key := t.Key tags.Remove(key) if key == "json" { formVal := t.Value t.Value = getJsonValue(f, t.Value) formVal = checkRequire(f, formVal) tags = append(tags, tag("form", formVal)) } else { t.Value = checkRequire(f, t.Value) } tags = append(tags, t) } // validator tags tags = append(tags, annotationToTags(as, ValidatorTags)...) 
// the tag defined by gotag with higher priority checkGoTag(as, &tags) // go.tags for compiler mode if needGoTag { rets := getAnnotation(as, AnnotationGoTag) for _, v := range rets { gts := util.SplitGoTags(v) for _, gt := range gts { sp := strings.SplitN(gt, ":", 2) if len(sp) != 2 { return fmt.Errorf("invalid go tag: %s", v) } vv, err := strconv.Unquote(sp[1]) if err != nil { return fmt.Errorf("invalid go.tag value: %s, err: %v", sp[1], err.Error()) } key := sp[0] tags.Remove(key) tags = append(tags, model.Tag{ Key: key, Value: vv, }) } } } sort.Sort(tags) gf.Tags = tags return nil } func getJsonValue(f *parser.Field, val string) string { if v := getAnnotation(f.Annotations, AnnotationJsConv); len(v) > 0 { val += ",string" } if !unsetOmitempty && f.Requiredness == parser.FieldType_Optional { val += ",omitempty" } else if f.Requiredness == parser.FieldType_Required { val += ",required" } return val } func checkRequire(f *parser.Field, val string) string { if f.Requiredness == parser.FieldType_Required { val += ",required" } return val } // checkGoTag removes the tag defined in gotag func checkGoTag(as parser.Annotations, tags *model.Tags) error { rets := getAnnotation(as, AnnotationGoTag) for _, v := range rets { gts := util.SplitGoTags(v) for _, gt := range gts { sp := strings.SplitN(gt, ":", 2) if len(sp) != 2 { return fmt.Errorf("invalid go tag: %s", v) } key := sp[0] tags.Remove(key) } } return nil } ================================================ FILE: cmd/hz/thrift/test_data/test_tag.thrift ================================================ namespace go cloudwego.hertz.hz struct MultiTagReq { // basic feature 1: string DefaultQueryTag (api.query="query"); 2: string RawBodyTag (api.raw_body="raw_body"); 3: string PathTag (api.path="path"); 4: string FormTag (api.form="form"); 5: string CookieTag (api.cookie="cookie"); 6: string HeaderTag (api.header="header"); 7: string BodyTag (api.body="body"); 8: string GoTag (go.tag="json:\"json\" query:\"query\" 
form:\"form\" header:\"header\" goTag:\"tag\""); 9: string VdTag (api.vd="$!='?'"); 10: string DefaultTag; // optional / required 11: required string ReqQuery (api.query="query"); 12: optional string OptQuery (api.query="query"); 13: required string ReqBody (api.body="body"); 14: optional string OptBody (api.body="body"); 15: required string ReqGoTag (go.tag="json:\"json\""); 16: optional string OptGoTag (go.tag="json:\"json\""); // gotag cover feature 17: required string QueryGoTag (apt.query="query", go.tag="query:\"queryTag\"") } ================================================ FILE: cmd/hz/thrift/thriftgo_util.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package thrift import ( "github.com/cloudwego/thriftgo/generator/golang" "github.com/cloudwego/thriftgo/generator/golang/styles" ) var thriftgoUtil *golang.CodeUtils var NameStyle = styles.NewNamingStyle("thriftgo") ================================================ FILE: cmd/hz/util/ast.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package util import ( "bytes" "fmt" "go/ast" "go/format" "go/parser" "go/token" "path/filepath" "golang.org/x/tools/go/ast/astutil" ) func AddImport(file, alias, impt string) ([]byte, error) { fset := token.NewFileSet() path, _ := filepath.Abs(file) f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) if err != nil { return nil, fmt.Errorf("can not parse ast for file: %s, err: %v", path, err) } return addImport(fset, f, alias, impt) } func AddImportForContent(fileContent []byte, alias, impt string) ([]byte, error) { fset := token.NewFileSet() f, err := parser.ParseFile(fset, "", fileContent, parser.ParseComments) if err != nil { return nil, fmt.Errorf("can not parse ast for file: %s, err: %v", fileContent, err) } return addImport(fset, f, alias, impt) } func addImport(fset *token.FileSet, f *ast.File, alias, impt string) ([]byte, error) { added := astutil.AddNamedImport(fset, f, alias, impt) if !added { return nil, fmt.Errorf("can not add import \"%s\" for file: %s", impt, f.Name.Name) } var output []byte buffer := bytes.NewBuffer(output) err := format.Node(buffer, fset, f) if err != nil { return nil, fmt.Errorf("can not add import for file: %s, err: %v", f.Name.Name, err) } return buffer.Bytes(), nil } ================================================ FILE: cmd/hz/util/ast_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package util import ( "bytes" "go/format" "go/parser" "go/token" "testing" "golang.org/x/tools/go/ast/astutil" ) func TestAddImport(t *testing.T) { inserts := [][]string{ { "ctx", "context", }, { "", "context", }, } files := [][]string{ { `package foo import ( "fmt" "time" ) `, `package foo import ( ctx "context" "fmt" "time" ) `, }, { `package foo import ( "fmt" "time" ) `, `package foo import ( "context" "fmt" "time" ) `, }, } for idx, file := range files { fset := token.NewFileSet() f, err := parser.ParseFile(fset, "", file[0], parser.ImportsOnly) if err != nil { t.Fatalf("can not parse ast for file") } astutil.AddNamedImport(fset, f, inserts[idx][0], inserts[idx][1]) var output []byte buffer := bytes.NewBuffer(output) err = format.Node(buffer, fset, f) if err != nil { t.Fatalf("can add import for file") } if buffer.String() != file[1] { t.Fatalf("insert import failed") } } } ================================================ FILE: cmd/hz/util/data.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package util import ( "errors" "fmt" "net/url" "path" "path/filepath" "reflect" "regexp" "strconv" "strings" "github.com/cloudwego/hertz/cmd/hz/util/logs" ) func CopyStringSlice(from, to *[]string) { n := len(*from) m := len(*to) if n > m { n = m } for i := 0; i < n; i++ { (*to)[i] = (*from)[i] } *to = (*to)[:n] } func CopyString2StringMap(from, to map[string]string) { for k := range to { delete(to, k) } for k, v := range from { to[k] = v } } func PackArgs(c interface{}) (res []string, err error) { t := reflect.TypeOf(c) v := reflect.ValueOf(c) if reflect.TypeOf(c).Kind() == reflect.Ptr { t = t.Elem() v = v.Elem() } if t.Kind() != reflect.Struct { return nil, errors.New("passed c must be struct or pointer of struct") } for i := 0; i < t.NumField(); i++ { f := t.Field(i) x := v.Field(i) n := f.Name if x.IsZero() { continue } switch x.Kind() { case reflect.Bool: if x.Bool() == false { continue } res = append(res, n+"="+fmt.Sprint(x.Bool())) case reflect.String: if x.String() == "" { continue } res = append(res, n+"="+x.String()) case reflect.Slice: if x.Len() == 0 { continue } ft := f.Type.Elem() if ft.Kind() != reflect.String { return nil, fmt.Errorf("slice field %v must be '[]string', err: %v", f.Name, err.Error()) } var ss []string for i := 0; i < x.Len(); i++ { ss = append(ss, x.Index(i).String()) } res = append(res, n+"="+strings.Join(ss, ";")) case reflect.Map: if x.Len() == 0 { continue } fk := f.Type.Key() if fk.Kind() != reflect.String { return nil, fmt.Errorf("map field %v must be 'map[string]string', err: %v", f.Name, err.Error()) } fv := f.Type.Elem() if fv.Kind() != reflect.String { return nil, fmt.Errorf("map field %v must be 'map[string]string', err: %v", f.Name, err.Error()) } var sk []string it := x.MapRange() for it.Next() { sk = append(sk, it.Key().String()+"="+it.Value().String()) } res = append(res, n+"="+strings.Join(sk, ";")) default: 
return nil, fmt.Errorf("unsupported field type: %+v, err: %v", f, err.Error()) } } return res, nil } func UnpackArgs(args []string, c interface{}) error { m, err := MapForm(args) if err != nil { return fmt.Errorf("unmarshal args failed, err: %v", err.Error()) } t := reflect.TypeOf(c).Elem() v := reflect.ValueOf(c).Elem() if t.Kind() != reflect.Struct { return errors.New("passed c must be struct or pointer of struct") } for i := 0; i < t.NumField(); i++ { f := t.Field(i) x := v.Field(i) n := f.Name values, ok := m[n] if !ok || len(values) == 0 || values[0] == "" { continue } switch x.Kind() { case reflect.Bool: if len(values) != 1 { return fmt.Errorf("field %s can't be assigned multi values: %v", n, values) } x.SetBool(values[0] == "true") case reflect.String: if len(values) != 1 { return fmt.Errorf("field %s can't be assigned multi values: %v", n, values) } x.SetString(values[0]) case reflect.Slice: if len(values) != 1 { return fmt.Errorf("field %s can't be assigned multi values: %v", n, values) } ss := strings.Split(values[0], ";") if x.Type().Elem().Kind() == reflect.Int { n := reflect.MakeSlice(x.Type(), len(ss), len(ss)) for i, s := range ss { val, err := strconv.ParseInt(s, 10, 64) if err != nil { return err } n.Index(i).SetInt(val) } x.Set(n) } else { for _, s := range ss { val := reflect.Append(x, reflect.ValueOf(s)) x.Set(val) } } case reflect.Map: if len(values) != 1 { return fmt.Errorf("field %s can't be assigned multi values: %v", n, values) } ss := strings.Split(values[0], ";") out := make(map[string]string, len(ss)) for _, s := range ss { sk := strings.SplitN(s, "=", 2) if len(sk) != 2 { return fmt.Errorf("map filed %v invalid key-value pair '%v'", n, s) } out[sk[0]] = sk[1] } x.Set(reflect.ValueOf(out)) default: return fmt.Errorf("field %s has unsupported type %+v", n, f.Type) } } return nil } func MapForm(input []string) (map[string][]string, error) { out := make(map[string][]string, len(input)) for _, str := range input { parts := 
strings.SplitN(str, "=", 2) if len(parts) != 2 { return nil, fmt.Errorf("invalid argument: '%s'", str) } key, val := parts[0], parts[1] out[key] = append(out[key], val) } return out, nil } func GetFirstKV(m map[string][]string) (string, []string) { for k, v := range m { return k, v } return "", nil } func ToCamelCase(name string) string { return CamelString(name) } func ToSnakeCase(name string) string { return SnakeString(name) } // unifyPath will convert "\" to "/" in path if the os is windows func unifyPath(path string) string { if IsWindows() { path = strings.ReplaceAll(path, "\\", "/") } return path } // BaseName get base name for path. ex: "github.com/p.s.m" => "p.s.m" func BaseName(include, subFixToTrim string) string { include = unifyPath(include) subFixToTrim = unifyPath(subFixToTrim) last := include if id := strings.LastIndex(last, "/"); id >= 0 && id < len(last)-1 { last = last[id+1:] } if !strings.HasSuffix(last, subFixToTrim) { return last } return last[:len(last)-len(subFixToTrim)] } func BaseNameAndTrim(include string) string { include = unifyPath(include) last := include if id := strings.LastIndex(last, "/"); id >= 0 && id < len(last)-1 { last = last[id+1:] } if id := strings.LastIndex(last, "."); id != -1 { last = last[:id] } return last } func SplitPackageName(pkg, subFixToTrim string) string { pkg = unifyPath(pkg) subFixToTrim = unifyPath(subFixToTrim) last := SplitPackage(pkg, subFixToTrim) if id := strings.LastIndex(last, "/"); id >= 0 && id < len(last)-1 { last = last[id+1:] } return last } func SplitPackage(pkg, subFixToTrim string) string { pkg = unifyPath(pkg) subFixToTrim = unifyPath(subFixToTrim) last := strings.TrimSuffix(pkg, subFixToTrim) if id := strings.LastIndex(last, "/"); id >= 0 && id < len(last)-1 { last = last[id+1:] } return strings.ReplaceAll(last, ".", "/") } // ImportToSanitizedPath converts import path to file system path and replaces dots with underscores in the last component. 
// For example: "github.com/example/v1.2" becomes "github.com/example/v1_2" on Unix systems. // NOTE: no idea about the background, it might caused go import package issues before? func ImportToSanitizedPath(path string) string { path = filepath.FromSlash(path) if i := strings.LastIndex(path, string(filepath.Separator)); i >= 0 && i < len(path)-1 && strings.Contains(path[i+1:], ".") { base := strings.ReplaceAll(path[i+1:], ".", "_") dir := path[:i] return dir + string(filepath.Separator) + base } return path } func ToVarName(paths []string) string { ps := strings.Join(paths, "__") input := []byte(url.PathEscape(ps)) out := make([]byte, 0, len(input)) for i := 0; i < len(input); i++ { c := input[i] if c == ':' || c == '*' { continue } if (c >= '0' && c <= '9' && i != 0) || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c == '_') { out = append(out, c) } else { out = append(out, '_') } } return string(out) } func SplitGoTags(input string) []string { out := make([]string, 0, 4) ns := len(input) flag := false prev := 0 i := 0 for i = 0; i < ns; i++ { c := input[i] if c == '"' { flag = !flag } if !flag && c == ' ' { if prev < i { out = append(out, input[prev:i]) } prev = i + 1 } } if i != 0 && prev < i { out = append(out, input[prev:i]) } return out } func SubPackage(mod, dir string) string { if dir == "" { return mod } return path.Join(mod, filepath.ToSlash(dir)) } func SubDir(root, subPkg string) string { return filepath.Join(root, filepath.FromSlash(subPkg)) } var ( uniquePackageName = map[string]bool{} uniqueMiddlewareName = map[string]bool{} uniqueHandlerPackageName = map[string]bool{} ) // GetPackageUniqueName can get a non-repeating variable name for package alias func GetPackageUniqueName(name string) (string, error) { name, err := getUniqueName(name, uniquePackageName) if err != nil { return "", fmt.Errorf("can not generate unique name for package '%s', err: %v", name, err) } return name, nil } // GetMiddlewareUniqueName can get a non-repeating variable 
name for middleware name func GetMiddlewareUniqueName(name string) (string, error) { name, err := getUniqueName(name, uniqueMiddlewareName) if err != nil { return "", fmt.Errorf("can not generate routing group for path '%s', err: %v", name, err) } return name, nil } func GetHandlerPackageUniqueName(name string) (string, error) { name, err := getUniqueName(name, uniqueHandlerPackageName) if err != nil { return "", fmt.Errorf("can not generate unique handler package name: '%s', err: %v", name, err) } return name, nil } // getUniqueName can get a non-repeating variable name func getUniqueName(name string, uniqueNameSet map[string]bool) (string, error) { uniqueName := name if _, exist := uniqueNameSet[uniqueName]; exist { for i := 0; i < 10000; i++ { uniqueName = uniqueName + fmt.Sprintf("%d", i) if _, exist := uniqueNameSet[uniqueName]; !exist { logs.Infof("There is a package name with the same name, change %s to %s", name, uniqueName) break } uniqueName = name if i == 9999 { return "", fmt.Errorf("there is too many same package for %s", name) } } } uniqueNameSet[uniqueName] = true return uniqueName, nil } var validFuncReg = regexp.MustCompile("[_0-9a-zA-Z]") // ToGoFuncName converts a string to a function naming style for go func ToGoFuncName(s string) string { ss := []byte(s) for i := range ss { if !validFuncReg.Match([]byte{s[i]}) { ss[i] = '_' } } return string(ss) } ================================================ FILE: cmd/hz/util/data_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package util import "testing" func TestUniqueName(t *testing.T) { type UniqueName struct { Name string ExpectedName string ActualName string } nameList := []UniqueName{ { Name: "aaa", ExpectedName: "aaa", }, { Name: "aaa", ExpectedName: "aaa0", }, { Name: "aaa0", ExpectedName: "aaa00", }, { Name: "aaa0", ExpectedName: "aaa01", }, { Name: "aaa00", ExpectedName: "aaa000", }, { Name: "aaa", ExpectedName: "aaa1", }, { Name: "aaa", ExpectedName: "aaa2", }, { Name: "aaa", ExpectedName: "aaa3", }, { Name: "aaa", ExpectedName: "aaa4", }, } for _, name := range nameList { name.ActualName, _ = getUniqueName(name.Name, uniquePackageName) if name.ActualName != name.ExpectedName { t.Errorf("%s name expected unique name '%s', actually get '%s'", name.Name, name.ExpectedName, name.ActualName) } } for _, name := range nameList { name.ActualName, _ = getUniqueName(name.Name, uniqueMiddlewareName) if name.ActualName != name.ExpectedName { t.Errorf("%s name expected unique name '%s', actually get '%s'", name.Name, name.ExpectedName, name.ActualName) } } } ================================================ FILE: cmd/hz/util/env.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package util import ( "bytes" "fmt" "go/build" "io/ioutil" "os" "os/exec" "path/filepath" "regexp" "strings" "github.com/cloudwego/hertz/cmd/hz/meta" ) func GetGOPATH() (gopath string, err error) { ps := filepath.SplitList(os.Getenv("GOPATH")) if len(ps) > 0 { gopath = ps[0] } if gopath == "" { cmd := exec.Command("go", "env", "GOPATH") var out bytes.Buffer cmd.Stderr = &out cmd.Stdout = &out if err := cmd.Run(); err == nil { gopath = strings.Trim(out.String(), " \t\n\r") } } if gopath == "" { ps := GetBuildGoPaths() if len(ps) > 0 { gopath = ps[0] } } isExist, err := PathExist(gopath) if !isExist { return "", err } return strings.Replace(gopath, "/", string(os.PathSeparator), -1), nil } // GetBuildGoPaths returns the list of Go path directories. func GetBuildGoPaths() []string { var all []string for _, p := range filepath.SplitList(build.Default.GOPATH) { if p == "" || p == build.Default.GOROOT { continue } if strings.HasPrefix(p, "~") { continue } all = append(all, p) } for k, v := range all { if strings.HasSuffix(v, "/") || strings.HasSuffix(v, string(os.PathSeparator)) { v = v[:len(v)-1] } all[k] = v } return all } var goModReg = regexp.MustCompile(`^\s*module\s+(\S+)\s*`) // SearchGoMod searches go.mod from the given directory (which must be an absolute path) to // the root directory. When the go.mod is found, its module name and path will be returned. 
func SearchGoMod(cwd string, recurse bool) (moduleName, path string, found bool) { for { path = filepath.Join(cwd, "go.mod") data, err := ioutil.ReadFile(path) if err == nil { for _, line := range strings.Split(string(data), "\n") { m := goModReg.FindStringSubmatch(line) if m != nil { return m[1], cwd, true } } return fmt.Sprintf("", path), path, true } if !os.IsNotExist(err) { return } if !recurse { break } cwd = filepath.Dir(cwd) // the root directory will return itself by using "filepath.Dir()"; to prevent dead loops, so jump out if cwd == filepath.Dir(cwd) { break } } return } func InitGoMod(module string) error { isExist, err := PathExist("go.mod") if err != nil { return err } if isExist { return nil } gg, err := exec.LookPath("go") if err != nil { return err } cmd := &exec.Cmd{ Path: gg, Args: []string{"go", "mod", "init", module}, Stdin: os.Stdin, Stdout: os.Stdout, Stderr: os.Stderr, } return cmd.Run() } func IsWindows() bool { return meta.SysType == meta.WindowsOS } ================================================ FILE: cmd/hz/util/fs.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package util import ( "os" "path/filepath" ) func PathExist(path string) (bool, error) { abPath, err := filepath.Abs(path) if err != nil { return false, err } _, err = os.Stat(abPath) if err != nil { return os.IsExist(err), nil } return true, nil } func RelativePath(path string) (string, error) { path, err := filepath.Abs(path) if err != nil { return "", err } cwd, err := os.Getwd() if err != nil { return "", err } ret, _ := filepath.Rel(cwd, path) return ret, nil } ================================================ FILE: cmd/hz/util/logs/api.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package logs func init() { defaultLogger = NewStdLogger(LevelInfo) } func SetLogger(logger Logger) { defaultLogger = logger } const ( LevelDebug = 1 + iota LevelInfo LevelWarn LevelError ) // TODO: merge with hertz logger package type Logger interface { Debugf(format string, v ...interface{}) Infof(format string, v ...interface{}) Warnf(format string, v ...interface{}) Errorf(format string, v ...interface{}) Flush() SetLevel(level int) error } var defaultLogger Logger func Errorf(format string, v ...interface{}) { defaultLogger.Errorf(format, v...) } func Warnf(format string, v ...interface{}) { defaultLogger.Warnf(format, v...) } func Infof(format string, v ...interface{}) { defaultLogger.Infof(format, v...) } func Debugf(format string, v ...interface{}) { defaultLogger.Debugf(format, v...) 
} func Error(format string, v ...interface{}) { defaultLogger.Errorf(format, v...) } func Warn(format string, v ...interface{}) { defaultLogger.Warnf(format, v...) } func Info(format string, v ...interface{}) { defaultLogger.Infof(format, v...) } func Debug(format string, v ...interface{}) { defaultLogger.Debugf(format, v...) } func Flush() { defaultLogger.Flush() } func SetLevel(level int) { defaultLogger.SetLevel(level) } ================================================ FILE: cmd/hz/util/logs/std.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package logs import ( "bytes" "errors" "fmt" "log" "os" ) type StdLogger struct { level int outLogger *log.Logger warnLogger *log.Logger errLogger *log.Logger out *bytes.Buffer warn *bytes.Buffer err *bytes.Buffer Defer bool ErrOnly bool } func NewStdLogger(level int) *StdLogger { out := bytes.NewBuffer(nil) warn := bytes.NewBuffer(nil) err := bytes.NewBuffer(nil) return &StdLogger{ level: level, outLogger: log.New(out, "[INFO]", log.Llongfile), warnLogger: log.New(warn, "[WARN]", log.Llongfile), errLogger: log.New(err, "[ERROR]", log.Llongfile), out: out, warn: warn, err: err, } } func (stdLogger *StdLogger) Debugf(format string, v ...interface{}) { if stdLogger.level > LevelDebug { return } stdLogger.outLogger.Output(3, fmt.Sprintf(format, v...)) if !stdLogger.Defer { stdLogger.FlushOut() } } func (stdLogger *StdLogger) Infof(format string, v ...interface{}) { if stdLogger.level > LevelInfo { return } stdLogger.outLogger.Output(3, fmt.Sprintf(format, v...)) if !stdLogger.Defer { stdLogger.FlushOut() } } func (stdLogger *StdLogger) Warnf(format string, v ...interface{}) { if stdLogger.level > LevelWarn { return } stdLogger.warnLogger.Output(3, fmt.Sprintf(format, v...)) if !stdLogger.Defer { stdLogger.FlushErr() } } func (stdLogger *StdLogger) Errorf(format string, v ...interface{}) { if stdLogger.level > LevelError { return } stdLogger.errLogger.Output(3, fmt.Sprintf(format, v...)) if !stdLogger.Defer { stdLogger.FlushErr() } } func (stdLogger *StdLogger) Flush() { stdLogger.FlushErr() if !stdLogger.ErrOnly { stdLogger.FlushOut() } } func (stdLogger *StdLogger) FlushOut() { os.Stderr.Write(stdLogger.out.Bytes()) stdLogger.out.Reset() } func (stdLogger *StdLogger) Err() string { return string(stdLogger.err.Bytes()) } func (stdLogger *StdLogger) Warn() string { return string(stdLogger.warn.Bytes()) } func (stdLogger *StdLogger) FlushErr() { os.Stderr.Write(stdLogger.err.Bytes()) stdLogger.err.Reset() } func (stdLogger *StdLogger) OutLines() []string { lines := 
bytes.Split(stdLogger.out.Bytes(), []byte("[INFO]")) var rets []string for _, line := range lines { rets = append(rets, string(line)) } return rets } func (stdLogger *StdLogger) Out() []byte { return stdLogger.out.Bytes() } func (stdLogger *StdLogger) SetLevel(level int) error { switch level { case LevelDebug, LevelInfo, LevelWarn, LevelError: break default: return errors.New("invalid log level") } stdLogger.level = level return nil } ================================================ FILE: cmd/hz/util/string.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package util import ( "reflect" "strings" "unicode/utf8" "unsafe" ) func Str2Bytes(in string) (out []byte) { op := (*reflect.SliceHeader)(unsafe.Pointer(&out)) ip := (*reflect.StringHeader)(unsafe.Pointer(&in)) op.Data = ip.Data op.Cap = ip.Len op.Len = ip.Len return } func Bytes2Str(in []byte) (out string) { op := (*reflect.StringHeader)(unsafe.Pointer(&out)) ip := (*reflect.SliceHeader)(unsafe.Pointer(&in)) op.Data = ip.Data op.Len = ip.Len return } // TrimLastChar can remove the last char for s func TrimLastChar(s string) string { r, size := utf8.DecodeLastRuneInString(s) if r == utf8.RuneError && (size == 0 || size == 1) { size = 0 } return s[:len(s)-size] } // AddSlashForComments can adjust the format of multi-line comments func AddSlashForComments(s string) string { s = strings.Replace(s, "\n", "\n//", -1) return s } // CamelString converts the string 's' to a camel string func CamelString(s string) string { data := make([]byte, 0, len(s)) j := false k := false num := len(s) - 1 for i := 0; i <= num; i++ { d := s[i] if k == false && d >= 'A' && d <= 'Z' { k = true } if d >= 'a' && d <= 'z' && (j || k == false) { d = d - 32 j = false k = true } if k && d == '_' && num > i && s[i+1] >= 'a' && s[i+1] <= 'z' { j = true continue } data = append(data, d) } return Bytes2Str(data[:]) } // SnakeString converts the string 's' to a snake string func SnakeString(s string) string { data := make([]byte, 0, len(s)*2) j := false for _, d := range Str2Bytes(s) { if d >= 'A' && d <= 'Z' { if j { data = append(data, '_') j = false } } else if d != '_' { j = true } data = append(data, d) } return strings.ToLower(Bytes2Str(data)) } ================================================ FILE: cmd/hz/util/tool_install.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package util import ( "fmt" "os" "os/exec" "path/filepath" "strings" "time" "github.com/cloudwego/hertz/cmd/hz/meta" "github.com/cloudwego/hertz/cmd/hz/util/logs" gv "github.com/hashicorp/go-version" ) const ThriftgoMiniVersion = "v0.2.0" // QueryVersion will query the version of the corresponding executable. func QueryVersion(exe string) (version string, err error) { var buf strings.Builder cmd := &exec.Cmd{ Path: exe, Args: []string{ exe, "--version", }, Stdin: os.Stdin, Stdout: &buf, Stderr: &buf, } err = cmd.Run() if err == nil { version = strings.Split(buf.String(), " ")[1] if strings.HasSuffix(version, "\n") { version = version[:len(version)-1] } } return } // ShouldUpdate will return "true" when current is lower than latest. func ShouldUpdate(current, latest string) bool { cv, err := gv.NewVersion(current) if err != nil { return false } lv, err := gv.NewVersion(latest) if err != nil { return false } return cv.Compare(lv) < 0 } // InstallAndCheckThriftgo will automatically install thriftgo and judge whether it is installed successfully. func InstallAndCheckThriftgo() error { exe, err := exec.LookPath("go") if err != nil { return fmt.Errorf("can not find tool 'go': %v", err) } var buf strings.Builder cmd := &exec.Cmd{ Path: exe, Args: []string{ exe, "install", "github.com/cloudwego/thriftgo@latest", }, Stdin: os.Stdin, Stdout: &buf, Stderr: &buf, } done := make(chan error) logs.Infof("installing thriftgo automatically") go func() { done <- cmd.Run() }() select { case err = <-done: if err != nil { return fmt.Errorf("can not install thriftgo, err: %v. 
Please install it manual, and make sure the version of thriftgo is greater than v0.2.0", cmd.Stderr) } case <-time.After(time.Second * 30): return fmt.Errorf("install thriftgo time out.Please install it manual, and make sure the version of thriftgo is greater than v0.2.0") } exist, err := CheckCompiler(meta.TpCompilerThrift) if err != nil { return fmt.Errorf("check %s exist failed, err: %v", meta.TpCompilerThrift, err) } if !exist { return fmt.Errorf("install thriftgo failed. Please install it manual, and make sure the version of thriftgo is greater than v0.2.0") } return nil } // CheckCompiler will check if the tool exists. func CheckCompiler(tool string) (bool, error) { path, err := exec.LookPath(tool) if err != nil { goPath, err := GetGOPATH() if err != nil { return false, fmt.Errorf("get 'GOPATH' failed for find %s : %v", tool, path) } path = filepath.Join(goPath, "bin", tool) } isExist, err := PathExist(path) if err != nil { return false, fmt.Errorf("can not check %s exist, err: %v", tool, err) } if !isExist { return false, nil } return true, nil } // CheckAndUpdateThriftgo checks the version of thriftgo and updates the tool to the latest version if its version is less than v0.2.0. 
func CheckAndUpdateThriftgo() error { path, err := exec.LookPath(meta.TpCompilerThrift) if err != nil { return fmt.Errorf("can not find %s", meta.TpCompilerThrift) } curVersion, err := QueryVersion(path) logs.Infof("current thriftgo version is %s", curVersion) if ShouldUpdate(curVersion, ThriftgoMiniVersion) { logs.Infof(" current thriftgo version is less than v0.2.0, so update thriftgo version") err = InstallAndCheckThriftgo() if err != nil { return fmt.Errorf("update thriftgo version failed, err: %v", err) } } return nil } ================================================ FILE: cmd/hz/util/tool_install_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package util import "testing" func TestQueryVersion(t *testing.T) { lowVersion := "v0.1.0" equalVersion := "v0.2.0" highVersion := "v0.3.0" if ShouldUpdate(lowVersion, ThriftgoMiniVersion) { } if ShouldUpdate(equalVersion, ThriftgoMiniVersion) { t.Fatal("should not be updated") } if ShouldUpdate(highVersion, ThriftgoMiniVersion) { t.Fatal("should not be updated") } } ================================================ FILE: examples/html_rendering/index.tmpl ================================================

{[{ .title }]}

================================================ FILE: examples/html_rendering/main.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package main import ( "context" "fmt" "html/template" "time" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app/server" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func formatAsDate(t time.Time) string { year, month, day := t.Date() return fmt.Sprintf("%d/%02d/%02d", year, month, day) } func main() { // set interval to 0 means using fs-watching mechanism. h := server.Default(server.WithAutoReloadRender(true, 0)) h.Delims("{[{", "}]}") h.SetFuncMap(template.FuncMap{ "formatAsDate": formatAsDate, }) h.LoadHTMLGlob("./examples/html_rendering/*") h.GET("/index", func(c context.Context, ctx *app.RequestContext) { ctx.HTML(consts.StatusOK, "index.tmpl", utils.H{ "title": "Main website", }) }) h.GET("/raw", func(c context.Context, ctx *app.RequestContext) { ctx.HTML(consts.StatusOK, "template.html", utils.H{ "now": time.Date(2017, 0o7, 0o1, 0, 0, 0, 0, time.UTC), }) }) h.Spin() } ================================================ FILE: examples/html_rendering/template.html ================================================

Date: {[{.now | formatAsDate}]}

================================================ FILE: examples/standard/main.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package main import ( "context" "fmt" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app/server" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol/consts" ) type Test struct { A string B string } func main() { h := server.Default() h.StaticFS("/", &app.FS{Root: "./", GenerateIndexPages: true}) h.GET("/ping", func(c context.Context, ctx *app.RequestContext) { ctx.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) h.GET("/json", func(c context.Context, ctx *app.RequestContext) { ctx.JSON(consts.StatusOK, &Test{ A: "aaa", B: "bbb", }) }) h.GET("/redirect", func(c context.Context, ctx *app.RequestContext) { ctx.Redirect(consts.StatusMovedPermanently, []byte("http://www.google.com/")) }) v1 := h.Group("/v1") { v1.GET("/hello/:name", func(c context.Context, ctx *app.RequestContext) { fmt.Fprintf(ctx, "Hi %s, this is the response from Hertz.\n", ctx.Param("name")) }) } h.Spin() } ================================================ FILE: go.mod ================================================ module github.com/cloudwego/hertz go 1.19 require ( github.com/bytedance/gopkg v0.1.3 github.com/bytedance/sonic v1.15.0 github.com/cloudwego/gopkg v0.1.11-0.20260303065100-1e5551ecf390 github.com/cloudwego/netpoll 
v0.7.3-0.20260305035010-81277e4f7b67 github.com/fsnotify/fsnotify v1.5.4 github.com/stretchr/testify v1.10.0 github.com/tidwall/gjson v1.14.4 golang.org/x/sync v0.8.0 golang.org/x/sys v0.24.0 google.golang.org/protobuf v1.34.1 ) require ( github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) ================================================ FILE: go.sum ================================================ github.com/bytedance/gopkg v0.1.1/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/cloudwego/gopkg v0.1.4/go.mod h1:FQuXsRWRsSqJLsMVd5SYzp8/Z1y5gXKnVvRrWUOsCMI= github.com/cloudwego/gopkg v0.1.11-0.20260303065100-1e5551ecf390 h1:DERt3Cue/q307RWCd+pPXvzuqmujrrzORgShkeU4Q0s= github.com/cloudwego/gopkg v0.1.11-0.20260303065100-1e5551ecf390/go.mod 
h1:wQv2rXOgrRCYdIrOce+xnAF7MA30CkofQZ3JHZOXY+8= github.com/cloudwego/netpoll v0.7.3-0.20260305035010-81277e4f7b67 h1:0dwPCnAMoeEupEKCAR4paadnmaq39MVdxvmEVlwcu3g= github.com/cloudwego/netpoll v0.7.3-0.20260305035010-81277e4f7b67/go.mod h1:PI+YrmyS7cIr0+SD4seJz3Eo3ckkXdu2ZVKBLhURLNU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 
h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto 
v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= ================================================ FILE: internal/bytesconv/bytesconv.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package bytesconv import ( "net/http" "time" "unsafe" "github.com/cloudwego/hertz/pkg/network" ) const ( upperhex = "0123456789ABCDEF" lowerhex = "0123456789abcdef" ) func LowercaseBytes(b []byte) { for i := 0; i < len(b); i++ { p := &b[i] *p = ToLowerTable[*p] } } // B2s converts byte slice to a string without memory allocation. // See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ . // // Note it may break if string and/or slice header will change // in the future go versions. 
func B2s(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } type sliceHeader struct { Data unsafe.Pointer Len int Cap int } // S2b converts string to a byte slice without memory allocation. // // Note it may break if string and/or slice header will change // in the future go versions. func S2b(s string) (b []byte) { *(*string)(unsafe.Pointer(&b)) = s (*sliceHeader)(unsafe.Pointer(&b)).Cap = len(s) return } func EncodedIntHexLen(n uint64) int { if n == 0 { return 1 } i := 0 for n > 0 { i++ n >>= 4 } return i } func AppendIntHex(b []byte, n uint64) []byte { if n == 0 { return append(b, '0') } var tmp [16]byte // 64 / 4 = 16 i := len(tmp) for n > 0 { i-- tmp[i] = lowerhex[n&0xf] n >>= 4 } return append(b, tmp[i:]...) } func ReadHexInt(r network.Reader) (int, error) { n := 0 i := 0 var k int for { buf, err := r.Peek(1) if err != nil { r.Skip(1) if i > 0 { return n, nil } return -1, err } c := buf[0] k = int(Hex2intTable[c]) if k == 16 { if i == 0 { r.Skip(1) return -1, errEmptyHexNum } return n, nil } if i >= maxHexIntChars { r.Skip(1) return -1, errTooLargeHexNum } r.Skip(1) n = (n << 4) | k i++ } } func ParseUintBuf(b []byte) (int, int, error) { n := len(b) if n == 0 { return -1, 0, errEmptyInt } v := 0 for i := 0; i < n; i++ { c := b[i] k := c - '0' if k > 9 { if i == 0 { return -1, i, errUnexpectedFirstChar } return v, i, nil } vNew := 10*v + int(k) // Test for overflow. if vNew < v { return -1, i, errTooLongInt } v = vNew } return v, n, nil } // AppendUint appends n to dst and returns the extended dst. func AppendUint(dst []byte, n int) []byte { if n < 0 { panic("BUG: int must be positive") } var b [20]byte buf := b[:] i := len(buf) var q int for n >= 10 { i-- q = n / 10 buf[i] = '0' + byte(n-q*10) n = q } i-- buf[i] = '0' + byte(n) dst = append(dst, buf[i:]...) return dst } // AppendHTTPDate appends HTTP-compliant representation of date // to dst and returns the extended dst. 
func AppendHTTPDate(dst []byte, date time.Time) []byte { return date.UTC().AppendFormat(dst, http.TimeFormat) } func AppendQuotedPath(dst, src []byte) []byte { // Fix issue in https://github.com/golang/go/issues/11202 if len(src) == 1 && src[0] == '*' { return append(dst, '*') } for _, c := range src { if QuotedPathShouldEscapeTable[int(c)] != 0 { dst = append(dst, '%', upperhex[c>>4], upperhex[c&15]) } else { dst = append(dst, c) } } return dst } // AppendQuotedArg appends url-encoded src to dst and returns appended dst. func AppendQuotedArg(dst, src []byte) []byte { for _, c := range src { switch { case c == ' ': dst = append(dst, '+') case QuotedArgShouldEscapeTable[int(c)] != 0: dst = append(dst, '%', upperhex[c>>4], upperhex[c&0xf]) default: dst = append(dst, c) } } return dst } // ParseHTTPDate parses HTTP-compliant (RFC1123) date. func ParseHTTPDate(date []byte) (time.Time, error) { return time.Parse(time.RFC1123, B2s(date)) } // ParseUint parses uint from buf. func ParseUint(buf []byte) (int, error) { v, n, err := ParseUintBuf(buf) if n != len(buf) { return -1, errUnexpectedTrailingChar } return v, err } ================================================ FILE: internal/bytesconv/bytesconv_32.go ================================================ //go:build !amd64 && !arm64 && !ppc64 /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package bytesconv

const (
	// maxHexIntChars is the maximum number of hex digits ReadHexInt accepts
	// before failing with errTooLargeHexNum. 7 digits (28 bits) always fit in
	// a positive 32-bit int, while an 8th digit could overflow it.
	//
	// NOTE(review): the build constraint only excludes amd64, arm64 and ppc64,
	// so this file is presumably also selected on other 64-bit targets
	// (e.g. riscv64, ppc64le, s390x), where 7 is merely conservative.
	maxHexIntChars = 7
)

================================================
FILE: internal/bytesconv/bytesconv_32_test.go
================================================
//go:build !amd64 && !arm64 && !ppc64

/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package bytesconv import ( "fmt" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestReadHexInt(t *testing.T) { t.Parallel() for _, v := range []struct { s string n int }{ //errTooLargeHexNum "too large hex number" //{"0123456789abcdef", -1}, {"0", 0}, {"fF", 0xff}, {"00abc", 0xabc}, {"7fffffff", 0x7fffffff}, {"000", 0}, {"1234ZZZ", 0x1234}, } { testReadHexInt(t, v.s, v.n) } } func TestParseUint(t *testing.T) { t.Parallel() for _, v := range []struct { s string i int }{ {"0", 0}, {"123", 123}, {"123456789", 123456789}, {"2147483647", 2147483647}, } { n, err := ParseUint(S2b(v.s)) if err != nil { t.Errorf("unexpected error: %v. s=%q n=%v", err, v.s, n) } assert.DeepEqual(t, n, v.i) } } func TestParseUintError(t *testing.T) { t.Parallel() for _, v := range []struct { s string }{ {""}, {"cloudwego123"}, {"1234.545"}, {"-2147483648"}, {"2147483648"}, {"4294967295"}, } { n, err := ParseUint(S2b(v.s)) if err == nil { t.Fatalf("Expecting error when parsing %q. obtained %d", v.s, n) } if n >= 0 { t.Fatalf("Unexpected n=%d when parsing %q. Expected negative num", n, v.s) } } } func TestAppendUint(t *testing.T) { t.Parallel() for _, s := range []struct { n int }{ {0}, {123}, {0x7fffffff}, } { expectedS := fmt.Sprintf("%d", s.n) s := AppendUint(nil, s.n) assert.DeepEqual(t, expectedS, B2s(s)) } } ================================================ FILE: internal/bytesconv/bytesconv_64.go ================================================ //go:build amd64 || arm64 || ppc64 /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package bytesconv

const (
	// maxHexIntChars is the maximum number of hex digits ReadHexInt accepts
	// before failing with errTooLargeHexNum on amd64, arm64 and ppc64 (this
	// file's build constraint), where int is 64 bits wide: 15 digits
	// (60 bits) always fit in a positive int, while a 16th could overflow it.
	maxHexIntChars = 15
)

================================================
FILE: internal/bytesconv/bytesconv_64_test.go
================================================
//go:build amd64 || arm64 || ppc64

/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package bytesconv import ( "fmt" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestReadHexInt(t *testing.T) { t.Parallel() for _, v := range []struct { s string n int }{ //errTooLargeHexNum "too large hex number" //{"0123456789abcdef", -1}, {"0", 0}, {"fF", 0xff}, {"00abc", 0xabc}, {"7fffffff", 0x7fffffff}, {"000", 0}, {"1234ZZZ", 0x1234}, {"7ffffffffffffff", 0x7ffffffffffffff}, } { testReadHexInt(t, v.s, v.n) } } func TestParseUint(t *testing.T) { t.Parallel() for _, v := range []struct { s string i int }{ {"0", 0}, {"123", 123}, {"1234567890", 1234567890}, {"123456789012345678", 123456789012345678}, {"9223372036854775807", 9223372036854775807}, } { n, err := ParseUint(S2b(v.s)) if err != nil { t.Errorf("unexpected error: %v. s=%q n=%v", err, v.s, n) } assert.DeepEqual(t, n, v.i) } } func TestParseUintError(t *testing.T) { t.Parallel() for _, v := range []struct { s string }{ {""}, {"cloudwego123"}, {"1234.545"}, {"-9223372036854775808"}, {"9223372036854775808"}, {"18446744073709551615"}, } { n, err := ParseUint(S2b(v.s)) if err == nil { t.Fatalf("Expecting error when parsing %q. obtained %d", v.s, n) } if n >= 0 { t.Fatalf("Unexpected n=%d when parsing %q. Expected negative num", n, v.s) } } } func TestAppendUint(t *testing.T) { t.Parallel() for _, s := range []struct { n int }{ {0}, {123}, {0x7fffffffffffffff}, } { expectedS := fmt.Sprintf("%d", s.n) s := AppendUint(nil, s.n) assert.DeepEqual(t, expectedS, B2s(s)) } } ================================================ FILE: internal/bytesconv/bytesconv_table.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package bytesconv // Code generated by go run bytesconv_table_gen.go; DO NOT EDIT. // See bytesconv_table_gen.go for more information about these tables. 
const ( Hex2intTable = "\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x00\x01\x02\x03\x04\x05\x06\a\b\t\x10\x10\x10\x10\x10\x10\x10\n\v\f\r\x0e\x0f\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\n\v\f\r\x0e\x0f\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10" ToLowerTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@abcdefghijklmnopqrstuvwxyz[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff" ToUpperTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f 
!\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`ABCDEFGHIJKLMNOPQRSTUVWXYZ{|}~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff" QuotedArgShouldEscapeTable = "\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" QuotedPathShouldEscapeTable = 
"\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x00\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" ValidCookieValueTable = 
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ValidHeaderFieldValueTable = 
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" ValidHeaderFieldNameTable = 
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01\x01\x01\x01\x01\x00\x00\x01\x01\x00\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ) ================================================ FILE: internal/bytesconv/bytesconv_table_gen.go ================================================ //go:build ignore /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package main import ( "bytes" "fmt" "go/format" "io/ioutil" "log" ) const ( toLower = 'a' - 'A' ) func main() { hex2intTable := func() [256]byte { var b [256]byte for i := 0; i < 256; i++ { c := byte(16) if i >= '0' && i <= '9' { c = byte(i) - '0' } else if i >= 'a' && i <= 'f' { c = byte(i) - 'a' + 10 } else if i >= 'A' && i <= 'F' { c = byte(i) - 'A' + 10 } b[i] = c } return b }() toLowerTable := func() [256]byte { var a [256]byte for i := 0; i < 256; i++ { c := byte(i) if c >= 'A' && c <= 'Z' { c += toLower } a[i] = c } return a }() toUpperTable := func() [256]byte { var a [256]byte for i := 0; i < 256; i++ { c := byte(i) if c >= 'a' && c <= 'z' { c -= toLower } a[i] = c } return a }() quotedArgShouldEscapeTable := func() [256]byte { // According to RFC 3986 §2.3 var a [256]byte for i := 0; i < 256; i++ { a[i] = 1 } // ALPHA for i := int('a'); i <= int('z'); i++ { a[i] = 0 } for i := int('A'); i <= int('Z'); i++ { a[i] = 0 } // DIGIT for i := int('0'); i <= int('9'); i++ { a[i] = 0 } // Unreserved characters for _, v := range `-_.~` { a[v] = 0 } return a }() quotedPathShouldEscapeTable := func() [256]byte { // The implementation here equal to net/url shouldEscape(s, encodePath) // // The RFC allows : @ & = + $ but saves / ; , for assigning // meaning to individual path segments. This package // only manipulates the path as a whole, so we allow those // last three as well. That leaves only ? to escape. 
a := quotedArgShouldEscapeTable for _, v := range `$&+,/:;=@` { a[v] = 0 } return a }() validCookieValueTable := func() [256]byte { // The implementation here is equal to net/http validCookieValueByte(b byte) // see https://datatracker.ietf.org/doc/html/rfc6265#section-4.1.1 var a [256]byte for i := 0; i < 256; i++ { a[i] = 0 } for i := 0x20; i < 0x7f; i++ { a[i] = 1 } a['"'] = 0 a[';'] = 0 a['\\'] = 0 return a }() validHeaderFieldValueTable := func() [256]byte { // The implementation here is equal to httpguts.ValidHeaderFieldValue var a [256]byte for i := 0; i < 256; i++ { a[i] = 1 } for i := 0; i < ' '; i++ { a[i] = 0 } // del CTL a[0x7f] = 0 // tab a['\t'] = 1 return a }() validHeaderFieldNameTable := func() [256]byte { // The implementation here is equal to httpguts ValidHeaderFieldName(string) // see https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 // // RFC 7230 says: // header-field = field-name ":" OWS field-value OWS // field-name = token // token = 1*tchar // tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." 
/ // "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA var a [256]byte for i := 0; i < 256; i++ { a[i] = 0 } a['!'] = 1 a['#'] = 1 a['$'] = 1 a['%'] = 1 a['&'] = 1 a['\''] = 1 a['*'] = 1 a['+'] = 1 a['-'] = 1 a['.'] = 1 a['^'] = 1 a['_'] = 1 a['`'] = 1 a['|'] = 1 a['~'] = 1 // ALPHA for i := int('a'); i <= int('z'); i++ { a[i] = 1 } for i := int('A'); i <= int('Z'); i++ { a[i] = 1 } // DIGIT for i := int('0'); i <= int('9'); i++ { a[i] = 1 } return a }() w := new(bytes.Buffer) w.WriteString(pre) fmt.Fprintf(w, "const (\n") fmt.Fprintf(w, "\tHex2intTable = %q\n", hex2intTable) fmt.Fprintf(w, "\tToLowerTable = %q\n", toLowerTable) fmt.Fprintf(w, "\tToUpperTable = %q\n", toUpperTable) fmt.Fprintf(w, "\tQuotedArgShouldEscapeTable = %q\n", quotedArgShouldEscapeTable) fmt.Fprintf(w, "\tQuotedPathShouldEscapeTable = %q\n", quotedPathShouldEscapeTable) fmt.Fprintf(w, "\tValidCookieValueTable = %q\n", validCookieValueTable) fmt.Fprintf(w, "\tValidHeaderFieldValueTable = %q\n", validHeaderFieldValueTable) fmt.Fprintf(w, "\tValidHeaderFieldNameTable = %q\n", validHeaderFieldNameTable) fmt.Fprintf(w, ")\n") source, err := format.Source(w.Bytes()) if err != nil { log.Fatal(err) } if err := ioutil.WriteFile("bytesconv_table.go", source, 0o660); err != nil { log.Fatal(err) } } const pre = `/* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package bytesconv // Code generated by go run bytesconv_table_gen.go; DO NOT EDIT. // See bytesconv_table_gen.go for more information about these tables. ` ================================================ FILE: internal/bytesconv/bytesconv_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package bytesconv import ( "net/url" "testing" "time" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" ) func TestAppendDate(t *testing.T) { t.Parallel() // GMT+8 shanghaiTimeZone := time.FixedZone("Asia/Shanghai", 8*60*60) for _, c := range []struct { name string date time.Time dateStr string }{ { name: "UTC", date: time.Date(2022, 6, 15, 11, 12, 13, 123, time.UTC), dateStr: "Wed, 15 Jun 2022 11:12:13 GMT", }, { name: "Asia/Shanghai", date: time.Date(2022, 6, 15, 3, 12, 45, 999, shanghaiTimeZone), dateStr: "Tue, 14 Jun 2022 19:12:45 GMT", }, } { t.Run(c.name, func(t *testing.T) { s := AppendHTTPDate(nil, c.date) assert.DeepEqual(t, c.dateStr, B2s(s)) }) } } func TestLowercaseBytes(t *testing.T) { t.Parallel() for _, v := range []struct { b1, b2 []byte }{ {[]byte("CLOUDWEGO-HERTZ"), []byte("cloudwego-hertz")}, {[]byte("CLOUDWEGO"), []byte("cloudwego")}, {[]byte("HERTZ"), []byte("hertz")}, } { LowercaseBytes(v.b1) assert.DeepEqual(t, v.b2, v.b1) } } // The test converts byte slice to a string without memory allocation. func TestB2s(t *testing.T) { t.Parallel() for _, v := range []struct { s string b []byte }{ {"cloudwego-hertz", []byte("cloudwego-hertz")}, {"cloudwego", []byte("cloudwego")}, {"hertz", []byte("hertz")}, } { assert.DeepEqual(t, v.s, B2s(v.b)) } } // The test converts string to a byte slice without memory allocation. 
func TestS2b(t *testing.T) { t.Parallel() for _, v := range []struct { s string b []byte }{ {"cloudwego-hertz", []byte("cloudwego-hertz")}, {"cloudwego", []byte("cloudwego")}, {"hertz", []byte("hertz")}, } { assert.DeepEqual(t, S2b(v.s), v.b) } } func TestAppendIntHex(t *testing.T) { testCases := []struct { b []byte n uint64 expected string }{ {[]byte{}, 0, "0"}, {[]byte{}, 1, "1"}, {[]byte{}, 10, "a"}, {[]byte{}, 15, "f"}, {[]byte{}, 16, "10"}, {[]byte{}, 255, "ff"}, {[]byte{}, 256, "100"}, {[]byte{}, 123456789, "75bcd15"}, {[]byte{}, 0xffffffffffffffff, "ffffffffffffffff"}, {[]byte("pre-"), 255, "pre-ff"}, {[]byte("start"), 0, "start0"}, } for _, tc := range testCases { result := AppendIntHex(tc.b, tc.n) if string(result) != tc.expected { t.Fatalf("AppendIntHex(%q, %d) = %q; want %q", tc.b, tc.n, result, tc.expected) } actualLen := EncodedIntHexLen(tc.n) expectedLen := len(result) - len(tc.b) if actualLen != expectedLen { t.Fatalf("EncodedIntHexLen(%d) = %d; want %d", tc.n, actualLen, expectedLen) } } } // common test function for 32bit and 64bit func testReadHexInt(t *testing.T, s string, expectedN int) { zr := mock.NewZeroCopyReader(s) n, err := ReadHexInt(zr) if err != nil { t.Errorf("unexpected error: %v. 
s=%q", err, s) } assert.DeepEqual(t, n, expectedN) } func TestAppendQuotedPath(t *testing.T) { t.Parallel() // Test all characters pathSegment := make([]byte, 256) for i := 0; i < 256; i++ { pathSegment[i] = byte(i) } for _, s := range []struct { path string }{ {"/"}, {"//"}, {"/foo/bar"}, {"*"}, {"/foo/" + B2s(pathSegment)}, } { u := url.URL{Path: s.path} expectedS := u.EscapedPath() res := B2s(AppendQuotedPath(nil, S2b(s.path))) assert.DeepEqual(t, expectedS, res) } } func TestAppendQuotedArg(t *testing.T) { t.Parallel() // Sync with url.QueryEscape allcases := make([]byte, 256) for i := 0; i < 256; i++ { allcases[i] = byte(i) } res := B2s(AppendQuotedArg(nil, allcases)) expect := url.QueryEscape(B2s(allcases)) assert.DeepEqual(t, expect, res) } func TestParseHTTPDate(t *testing.T) { t.Parallel() for _, v := range []struct { t string }{ {"Thu, 04 Feb 2010 21:00:57 PST"}, {"Mon, 02 Jan 2006 15:04:05 MST"}, } { t1, err := time.Parse(time.RFC1123, v.t) if err != nil { t.Fatalf("unexpected error: %v. t=%q", err, v.t) } t2, err := ParseHTTPDate(S2b(t1.Format(time.RFC1123))) if err != nil { t.Fatalf("unexpected error: %v. t=%q", err, v.t) } assert.DeepEqual(t, t1, t2) } } // For test only, but it will import golang.org/x/net/http. // So comment out all this code. Keep this for the full context. //func TestValidHeaderFieldValueTable(t *testing.T) { // t.Parallel() // // // Test all characters // allBytes := make([]byte, 0) // for i := 0; i < 256; i++ { // allBytes = append(allBytes, byte(i)) // } // for _, s := range allBytes { // ss := []byte{s} // expectedS := httpguts.ValidHeaderFieldValue(string(ss)) // res := func() bool { // return ValidHeaderFieldValueTable[s] != 0 // }() // // assert.DeepEqual(t, expectedS, res) // } //} // For test only, but it will import golang.org/x/net/http. // So comment out all this code. Keep this for the full context. 
//func TestValidHeaderFieldNameTable(t *testing.T) { // t.Parallel() // // // Test all characters // allBytes := make([]byte, 0) // for i := 0; i < 256; i++ { // allBytes = append(allBytes, byte(i)) // } // for _, s := range allBytes { // ss := []byte{s} // expectedS := httpguts.ValidHeaderFieldName(string(ss)) // res := func() bool { // return ValidHeaderFieldNameTable[s] != 0 // }() // // assert.DeepEqual(t, expectedS, res) // } //} ================================================ FILE: internal/bytesconv/bytesconv_timing_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package bytesconv import ( "testing" ) // For test only, but it will import golang.org/x/net/http. // So comment out all this code. Keep this for the full context. 
//func BenchmarkValidHeaderFiledValueTable(b *testing.B) { // // Test all characters // allBytes := make([]string, 0) // for i := 0; i < 256; i++ { // allBytes = append(allBytes, string([]byte{byte(i)})) // } // // for i := 0; i < b.N; i++ { // for _, s := range allBytes { // _ = httpguts.ValidHeaderFieldValue(s) // } // } //} func BenchmarkValidHeaderFiledValueTableHertz(b *testing.B) { // Test all characters allBytes := make([]byte, 0) for i := 0; i < 256; i++ { allBytes = append(allBytes, byte(i)) } for i := 0; i < b.N; i++ { for _, s := range allBytes { _ = func() bool { return ValidHeaderFieldValueTable[s] != 0 }() } } } ================================================ FILE: internal/bytesconv/doc.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // The files in bytesconv package are forked from fasthttp[github.com/valyala/fasthttp], // and we keep the original Copyright[Copyright 2015 fasthttp authors] and License of fasthttp for those files. // We also need to modify as we need, the modifications are Copyright of 2022 CloudWeGo Authors. // Thanks for fasthttp authors! 
Below is the source code information: // Repo: github.com/valyala/fasthttp // Forked Version: v1.36.0 package bytesconv ================================================ FILE: internal/bytesconv/errors.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package bytesconv import "errors" var ( errEmptyInt = errors.New("empty integer") errUnexpectedFirstChar = errors.New("unexpected first char found. Expecting 0-9") errUnexpectedTrailingChar = errors.New("unexpected trailing char found. Expecting 0-9") errTooLongInt = errors.New("too long int") errEmptyHexNum = errors.New("empty hex number") errTooLargeHexNum = errors.New("too large hex number") ) ================================================ FILE: internal/bytestr/bytes.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ // Package bytestr defines some common bytes package bytestr import ( "github.com/cloudwego/hertz/pkg/protocol/consts" ) var ( DefaultServerName = []byte("hertz") DefaultUserAgent = []byte("hertz") DefaultContentType = []byte("text/plain; charset=utf-8") ) var ( StrBackSlash = []byte("\\") StrSlash = []byte("/") StrSlashSlash = []byte("//") StrSlashDotDot = []byte("/..") StrSlashDotSlash = []byte("/./") StrSlashDotDotSlash = []byte("/../") StrCRLF = []byte("\r\n") StrHTTP = []byte("http") StrHTTPS = []byte("https") StrHTTP11 = []byte("HTTP/1.1") StrColon = []byte(":") StrStar = []byte("*") StrColonSlashSlash = []byte("://") StrColonSpace = []byte(": ") StrCommaSpace = []byte(", ") StrAt = []byte("@") StrSD = []byte("sd") StrResponseContinue = []byte("HTTP/1.1 100 Continue\r\n\r\n") StrGet = []byte(consts.MethodGet) StrHead = []byte(consts.MethodHead) StrPost = []byte(consts.MethodPost) StrPut = []byte(consts.MethodPut) StrDelete = []byte(consts.MethodDelete) StrConnect = []byte(consts.MethodConnect) StrOptions = []byte(consts.MethodOptions) StrTrace = []byte(consts.MethodTrace) StrPatch = []byte(consts.MethodPatch) StrExpect = []byte(consts.HeaderExpect) StrConnection = []byte(consts.HeaderConnection) StrContentLength = []byte(consts.HeaderContentLength) StrContentType = []byte(consts.HeaderContentType) StrDate = []byte(consts.HeaderDate) StrHost = []byte(consts.HeaderHost) StrServer = []byte(consts.HeaderServer) StrTransferEncoding = []byte(consts.HeaderTransferEncoding) StrUserAgent = []byte(consts.HeaderUserAgent) StrCookie = []byte(consts.HeaderCookie) StrLocation = []byte(consts.HeaderLocation) StrContentRange = []byte(consts.HeaderContentRange) StrContentEncoding = []byte(consts.HeaderContentEncoding) StrAcceptEncoding = []byte(consts.HeaderAcceptEncoding) StrSetCookie = []byte(consts.HeaderSetCookie) StrAuthorization = []byte(consts.HeaderAuthorization) StrRange = []byte(consts.HeaderRange) StrLastModified = []byte(consts.HeaderLastModified) 
StrAcceptRanges = []byte(consts.HeaderAcceptRanges) StrIfModifiedSince = []byte(consts.HeaderIfModifiedSince) StrTE = []byte(consts.HeaderTE) StrTrailer = []byte(consts.HeaderTrailer) StrMaxForwards = []byte(consts.HeaderMaxForwards) StrProxyConnection = []byte(consts.HeaderProxyConnection) StrProxyAuthenticate = []byte(consts.HeaderProxyAuthenticate) StrProxyAuthorization = []byte(consts.HeaderProxyAuthorization) StrWWWAuthenticate = []byte(consts.HeaderWWWAuthenticate) StrCookieExpires = []byte("expires") StrCookieDomain = []byte("domain") StrCookiePath = []byte("path") StrCookieHTTPOnly = []byte("HttpOnly") StrCookieSecure = []byte("secure") StrCookieMaxAge = []byte("max-age") StrCookieSameSite = []byte("SameSite") StrCookieSameSiteLax = []byte("Lax") StrCookieSameSiteStrict = []byte("Strict") StrCookieSameSiteNone = []byte("None") StrCookiePartitioned = []byte("Partitioned") StrClose = []byte("close") StrGzip = []byte("gzip") StrDeflate = []byte("deflate") StrKeepAlive = []byte("keep-alive") StrUpgrade = []byte("Upgrade") StrChunked = []byte("chunked") StrIdentity = []byte("identity") Str100Continue = []byte("100-continue") StrBoundary = []byte("boundary") StrBytes = []byte("bytes") StrTextSlash = []byte("text/") StrApplicationSlash = []byte("application/") StrBasicSpace = []byte("Basic ") // http2 StrClientPreface = []byte(consts.ClientPreface) ) var ( // content types MIMEPostForm = []byte("application/x-www-form-urlencoded") MIMEFormData = []byte("multipart/form-data") MIMETextEventStream = []byte("text/event-stream") // for server-sent events ) ================================================ FILE: internal/nocopy/nocopy.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package nocopy defines the NoCopy struct package nocopy // NoCopy defines the nocopy struct. // Embed this type into a struct, which mustn't be copied, // so `go vet` gives a warning if this struct is copied. // // See https://github.com/golang/go/issues/8005#issuecomment-190753527 for details. // and also: https://stackoverflow.com/questions/52494458/nocopy-minimal-example type NoCopy struct{} func (*NoCopy) Lock() {} func (*NoCopy) Unlock() {} ================================================ FILE: internal/stats/stats_util.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package stats import ( "github.com/cloudwego/hertz/pkg/common/tracer/stats" "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" ) // Record records the event to HTTPStats. 
func Record(ti traceinfo.TraceInfo, event stats.Event, err error) { if ti == nil { return } if err != nil { ti.Stats().Record(event, stats.StatusError, err.Error()) } else { ti.Stats().Record(event, stats.StatusInfo, "") } } // CalcEventCostUs calculates the duration between start and end and returns in microsecond. func CalcEventCostUs(start, end traceinfo.Event) uint64 { if start == nil || end == nil || start.IsNil() || end.IsNil() { return 0 } return uint64(end.Time().Sub(start.Time()).Microseconds()) } ================================================ FILE: internal/stats/stats_util_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

package stats

import (
	"testing"
	"time"

	"github.com/cloudwego/hertz/pkg/common/test/assert"
	"github.com/cloudwego/hertz/pkg/common/tracer/stats"
	"github.com/cloudwego/hertz/pkg/common/tracer/traceinfo"
)

// TestUtil exercises Record and CalcEventCostUs. Missing events give a zero
// cost, and no events are kept while no stats level has been set. After the
// level is set to stats.LevelBase, both events are kept and their cost is
// positive.
func TestUtil(t *testing.T) {
	// Missing events are reported as a zero cost.
	assert.Assert(t, CalcEventCostUs(nil, nil) == 0)

	ti := traceinfo.NewTraceInfo()

	// Fresh trace info with no stats level set: Record leaves no events behind.
	Record(ti, stats.HTTPStart, nil)
	Record(ti, stats.HTTPFinish, nil)
	st := ti.Stats()
	assert.Assert(t, st != nil)
	s, e := st.GetEvent(stats.HTTPStart), st.GetEvent(stats.HTTPFinish)
	assert.Assert(t, s == nil)
	assert.Assert(t, e == nil)

	// stats disabled: same as above, now with a delay between the two events.
	Record(ti, stats.HTTPStart, nil)
	time.Sleep(time.Millisecond)
	Record(ti, stats.HTTPFinish, nil)
	st = ti.Stats()
	assert.Assert(t, st != nil)
	s, e = st.GetEvent(stats.HTTPStart), st.GetEvent(stats.HTTPFinish)
	assert.Assert(t, s == nil)
	assert.Assert(t, e == nil)

	// stats enabled: after SetLevel(stats.LevelBase) both events are kept.
	st = ti.Stats()
	st.(interface{ SetLevel(stats.Level) }).SetLevel(stats.LevelBase)
	Record(ti, stats.HTTPStart, nil)
	time.Sleep(time.Millisecond)
	Record(ti, stats.HTTPFinish, nil)
	s, e = st.GetEvent(stats.HTTPStart), st.GetEvent(stats.HTTPFinish)
	assert.Assert(t, s != nil, s)
	assert.Assert(t, e != nil, e)
	// Relies on the 1ms sleep between the two events above to get a non-zero
	// microsecond cost.
	assert.Assert(t, CalcEventCostUs(s, e) > 0)
}

================================================
FILE: internal/stats/tracer.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package stats import ( "context" "runtime/debug" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/tracer" "github.com/cloudwego/hertz/pkg/common/tracer/stats" ) // Controller controls tracers. type Controller struct { tracers []tracer.Tracer } // Append appends a new tracer to the controller. func (ctl *Controller) Append(col tracer.Tracer) { ctl.tracers = append(ctl.tracers, col) } // DoStart starts the tracers. func (ctl *Controller) DoStart(ctx context.Context, c *app.RequestContext) context.Context { defer ctl.tryRecover() Record(c.GetTraceInfo(), stats.HTTPStart, nil) for _, col := range ctl.tracers { ctx = col.Start(ctx, c) } return ctx } // DoFinish calls the tracers in reversed order. func (ctl *Controller) DoFinish(ctx context.Context, c *app.RequestContext, err error) { defer ctl.tryRecover() Record(c.GetTraceInfo(), stats.HTTPFinish, err) if err != nil { c.GetTraceInfo().Stats().SetError(err) } // reverse the order for i := len(ctl.tracers) - 1; i >= 0; i-- { ctl.tracers[i].Finish(ctx, c) } } func (ctl *Controller) HasTracer() bool { return ctl != nil && len(ctl.tracers) > 0 } func (ctl *Controller) tryRecover() { if err := recover(); err != nil { hlog.SystemLogger().Warnf("Panic happened during tracer call. This doesn't affect the http call, but may lead to lack of monitor data such as metrics and logs: %s, %s", err, string(debug.Stack())) } } ================================================ FILE: internal/stats/tracer_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package stats import ( "context" "errors" "fmt" "testing" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" ) type mockTracer struct { order int stack *[]int panicAtStart bool panicAtFinish bool } func (mt *mockTracer) Start(ctx context.Context, c *app.RequestContext) context.Context { if mt.panicAtStart { panic(fmt.Sprintf("panicked at start: Tracer(%d)", mt.order)) } *mt.stack = append(*mt.stack, mt.order) return context.WithValue(ctx, mt, mt.order) } func (mt *mockTracer) Finish(ctx context.Context, c *app.RequestContext) { if mt.panicAtFinish { panic(fmt.Sprintf("panicked at finish: Tracer(%d)", mt.order)) } *mt.stack = append(*mt.stack, -mt.order) } func TestOrder(t *testing.T) { var c Controller var stack []int t1 := &mockTracer{order: 1, stack: &stack} t2 := &mockTracer{order: 2, stack: &stack} ctx := app.NewContext(16) c.Append(t1) c.Append(t2) ctx0 := context.Background() ctx1 := c.DoStart(ctx0, ctx) assert.Assert(t, ctx1 != ctx0) assert.Assert(t, len(stack) == 2 && stack[0] == 1 && stack[1] == 2, stack) c.DoFinish(ctx1, ctx, nil) assert.Assert(t, len(stack) == 4 && stack[2] == -2 && stack[3] == -1, stack) } func TestPanic(t *testing.T) { var c Controller var stack []int t1 := &mockTracer{order: 1, stack: &stack, panicAtStart: true, panicAtFinish: true} t2 := &mockTracer{order: 2, stack: &stack} ctx := app.NewContext(16) ctx.SetTraceInfo(traceinfo.NewTraceInfo()) c.Append(t1) c.Append(t2) ctx0 := context.Background() ctx1 := c.DoStart(ctx0, ctx) 
assert.Assert(t, ctx1 != ctx0) assert.Assert(t, len(stack) == 0) // t1's panic skips all subsequent Starts err := errors.New("some error") c.DoFinish(ctx1, ctx, err) assert.Assert(t, len(stack) == 1 && stack[0] == -2, stack) } ================================================ FILE: internal/tagexpr/LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2019 Bytedance Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: internal/tagexpr/README.md ================================================ # go-tagexpr originally from https://github.com/bytedance/go-tagexpr v2.9.2 ================================================ FILE: internal/tagexpr/example_test.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package tagexpr_test import ( "fmt" "github.com/cloudwego/hertz/internal/tagexpr" ) /* Example builds a VM for the "tagexpr" struct tag, runs it on a *T and prints the results of Eval (including named expressions such as "C@expr2" and "d@msg") and of EvalWithEnv, which supplies minVal through the env map; the printed lines are checked against the Output comment. */ func Example() { type T struct { A int `tagexpr:"$<0||$>=100"` B string `tagexpr:"len($)>1 && regexp('^\\w*$')"` C bool `tagexpr:"expr1:(f.g)$>0 && $; expr2:'C must be true when T.f.g>0'"` d []string `tagexpr:"@:len($)>0 && $[0]=='D'; msg:sprintf('invalid d: %v',$)"` e map[string]int `tagexpr:"len($)==$['len']"` e2 map[string]*int `tagexpr:"len($)==$['len']"` f struct { g int `tagexpr:"$"` } h int `tagexpr:"$>minVal"` } vm := tagexpr.New("tagexpr") t := &T{ A: 107, B: "abc", C: true, d: []string{"x", "y"}, e: map[string]int{"len": 1}, e2: map[string]*int{"len": new(int)}, f: struct { g int `tagexpr:"$"` }{1}, h: 10, } tagExpr, err := vm.Run(t) if err != nil { panic(err) } fmt.Println(tagExpr.Eval("A")) fmt.Println(tagExpr.Eval("B")) fmt.Println(tagExpr.Eval("C@expr1")) fmt.Println(tagExpr.Eval("C@expr2")) if !tagExpr.Eval("d").(bool) { fmt.Println(tagExpr.Eval("d@msg")) } fmt.Println(tagExpr.Eval("e")) fmt.Println(tagExpr.Eval("e2")) fmt.Println(tagExpr.Eval("f.g")) fmt.Println(tagExpr.EvalWithEnv("h", map[string]interface{}{"minVal": 9})) fmt.Println(tagExpr.EvalWithEnv("h", map[string]interface{}{"minVal": 11})) // Output: // true // true // true // C must be true when T.f.g>0 // invalid d: [x y] // true // false // 1 // true // false } ================================================ FILE: internal/tagexpr/expr.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. package tagexpr import ( "context" "fmt" ) /* variableKeyType is the type of variableKey, the context key under which runWithEnv stores the env map that variable operands are resolved from. */ type variableKeyType string const variableKey variableKeyType = "__ENV_KEY__" // Expr is a parsed tag expression; expr is the root group node of its operator tree. type Expr struct { expr ExprNode } // parseExpr parses the expression. func parseExpr(expr string) (*Expr, error) { e := newGroupExprNode() p := &Expr{ expr: e, } s := expr err := p.parseExprNode(&s, e) if err != nil { return nil, err } sortPriority(e) return p, nil } /* parseExprNode reads one operand from the front of *expr (a selector, a range key/value, a parenthesized group or a plain operand) and, if an operator follows, links both into the tree and recurses on the rest. When e is a group node the operand becomes the new operator's left operand; otherwise e is the previous operator: the operand becomes e's right operand and the new operator takes e's place with e as its left operand, building a left-associative chain that sortPriority later reorders. Parsed text is consumed from *expr; text after the last operand that is not an operator is left in *expr for the caller. It returns nil for empty input and a syntax error when no operand can be read. */ func (p *Expr) parseExprNode(expr *string, e ExprNode) error { trimLeftSpace(expr) if *expr == "" { return nil } operand := p.readSelectorExprNode(expr) if operand == nil { operand = p.readRangeKvExprNode(expr) if operand == nil { var subExprNode *string operand, subExprNode = readGroupExprNode(expr) if operand != nil { err := p.parseExprNode(subExprNode, operand) if err != nil { return err } } else { operand = p.parseOperand(expr) } } } if operand == nil { return fmt.Errorf("syntax error: %q", *expr) } trimLeftSpace(expr) operator := p.parseOperator(expr) if operator == nil { e.SetRightOperand(operand) operand.SetParent(e) return nil } if _, ok := e.(*groupExprNode); ok { operator.SetLeftOperand(operand) operand.SetParent(operator) e.SetRightOperand(operator) operator.SetParent(e) } else { operator.SetParent(e.Parent()) operator.Parent().SetRightOperand(operator) operator.SetLeftOperand(e) e.SetParent(operator) e.SetRightOperand(operand) operand.SetParent(e) } return p.parseExprNode(expr, operator) } /* parseOperand tries every parser in funcList (in map iteration order), then the string, number, bool, nil and variable readers, and returns the first node recognized, or nil. */ func (p *Expr) parseOperand(expr *string) (e ExprNode) { for _, fn := range funcList { if e = fn(p, expr); e != nil { return e } } if e = readStringExprNode(expr); e != nil { return e } if e = readDigitalExprNode(expr); e != nil { return e } if e = readBoolExprNode(expr); e != nil { return e } if e = readNilExprNode(expr); e != nil { return e } if e = readVariableExprNode(expr); e != nil { return e } return nil } /* parseOperator consumes and returns a two-byte operator (||, &&, ==, >=, <=, !=) or else a one-byte operator (+, -, *, slash, %, <, >) from the front of *expr. It returns nil without consuming anything when no operator is found or when *expr is shorter than two bytes, so a lone trailing operator is not recognized. */ func (*Expr) parseOperator(expr *string) (e
ExprNode) { s := *expr if len(s) < 2 { return nil } defer func() { if e != nil && *expr == s { *expr = (*expr)[2:] } }() a := s[:2] switch a { // case "<<": // case ">>": // case "&^": case "||": return newOrExprNode() case "&&": return newAndExprNode() case "==": return newEqualExprNode() case ">=": return newGreaterEqualExprNode() case "<=": return newLessEqualExprNode() case "!=": return newNotEqualExprNode() } /* deferred calls run last-in-first-out: for a one-byte operator this one trims first, so the earlier defer then sees *expr != s and does not trim again */ defer func() { if e != nil { *expr = (*expr)[1:] } }() switch a[0] { // case '&': // case '|': // case '^': case '+': return newAdditionExprNode() case '-': return newSubtractionExprNode() case '*': return newMultiplicationExprNode() case '/': return newDivisionExprNode() case '%': return newRemainderExprNode() case '<': return newLessExprNode() case '>': return newGreaterExprNode() } return nil } // run calculates the value of expression. func (p *Expr) run(field string, tagExpr *TagExpr) interface{} { return p.expr.Run(context.Background(), field, tagExpr) } /* runWithEnv is like run but stores env in the context under variableKey, from which variable operands read their values. */ func (p *Expr) runWithEnv(field string, tagExpr *TagExpr, env map[string]interface{}) interface{} { ctx := context.WithValue(context.Background(), variableKey, env) return p.expr.Run(ctx, field, tagExpr) } /** * Priority: * () !
bool float64 string nil * * / % * + - * < <= > >= * == != * && * || **/ /* sortPriority repeatedly rotates the tree below e until no operator binds tighter than its left operand, so evaluation follows the priority table above. */ func sortPriority(e ExprNode) { for subSortPriority(e.RightOperand(), false) { } } /* subSortPriority walks the subtree rooted at e in post-order (isLeft reports whether e is its parent's left operand) and rotates e below its left operand when e has the higher priority; it reports whether any rotation happened. */ func subSortPriority(e ExprNode, isLeft bool) bool { if e == nil { return false } leftChanged := subSortPriority(e.LeftOperand(), true) rightChanged := subSortPriority(e.RightOperand(), false) if getPriority(e) > getPriority(e.LeftOperand()) { leftOperandToParent(e, isLeft) return true } return leftChanged || rightChanged } /* leftOperandToParent moves e's left operand le into e's place under e's parent: e becomes le's right operand and le's old right operand becomes e's left operand. NOTE(review): assumes le has a non-nil right operand, otherwise it panics. */ func leftOperandToParent(e ExprNode, isLeft bool) { le := e.LeftOperand() if le == nil { return } p := e.Parent() le.SetParent(p) if p != nil { if isLeft { p.SetLeftOperand(le) } else { p.SetRightOperand(le) } } e.SetParent(le) e.SetLeftOperand(le.RightOperand()) le.RightOperand().SetParent(e) le.SetRightOperand(e) } /* getPriority returns how tightly e binds (higher binds tighter); every node that is not a binary operator, including nil, gets 7. */ func getPriority(e ExprNode) (i int) { // defer func() { // printf("expr:%T %d\n", e, i) // }() switch e.(type) { default: // () ! bool float64 string nil return 7 case *multiplicationExprNode, *divisionExprNode, *remainderExprNode: // * / % return 6 case *additionExprNode, *subtractionExprNode: // + - return 5 case *lessExprNode, *lessEqualExprNode, *greaterExprNode, *greaterEqualExprNode: // < <= > >= return 4 case *equalExprNode, *notEqualExprNode: // == != return 3 case *andExprNode: // && return 2 case *orExprNode: // || return 1 } } // ExprNode expression interface type ExprNode interface { SetParent(ExprNode) Parent() ExprNode LeftOperand() ExprNode RightOperand() ExprNode SetLeftOperand(ExprNode) SetRightOperand(ExprNode) String() string Run(context.Context, string, *TagExpr) interface{} } // var _ ExprNode = new(exprBackground) /* exprBackground stores the parent, left and right links shared by the expression node types, which embed it. */ type exprBackground struct { parent ExprNode leftOperand ExprNode rightOperand ExprNode } func (eb *exprBackground) SetParent(e ExprNode) { eb.parent = e } func (eb *exprBackground) Parent() ExprNode { return eb.parent } func (eb *exprBackground) LeftOperand() ExprNode { return eb.leftOperand } func (eb *exprBackground) RightOperand() ExprNode {
return eb.rightOperand } func (eb *exprBackground) SetLeftOperand(left ExprNode) { eb.leftOperand = left } func (eb *exprBackground) SetRightOperand(right ExprNode) { eb.rightOperand = right } /* Run is the fallback evaluation inherited by embedding nodes that do not define their own; it returns nil. */ func (*exprBackground) Run(context.Context, string, *TagExpr) interface{} { return nil } ================================================ FILE: internal/tagexpr/expr_test.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
package tagexpr import ( "math" "reflect" "testing" ) /* TestExpr parses each case with parseExpr, evaluates it with no field or TagExpr, and requires a deep-equal result (a NaN expectation matches a NaN result). */ func TestExpr(t *testing.T) { cases := []struct { expr string val interface{} }{ // Simple string {expr: "'a'", val: "a"}, {expr: "('a')", val: "a"}, // Simple digital {expr: " 10 ", val: 10.0}, {expr: "(10)", val: 10.0}, // Simple bool {expr: "true", val: true}, {expr: "!true", val: false}, {expr: "!!true", val: true}, {expr: "false", val: false}, {expr: "!false", val: true}, {expr: "!!false", val: false}, {expr: "(false)", val: false}, {expr: "(!false)", val: true}, {expr: "(!!false)", val: false}, {expr: "!!(!false)", val: true}, {expr: "!(!false)", val: false}, // Join string {expr: "'true '+('a')", val: "true a"}, {expr: "'a'+('b'+'c')+'d'", val: "abcd"}, // Arithmetic operator {expr: "1+7+2", val: 10.0}, {expr: "1+(7)+(2)", val: 10.0}, {expr: "1.1+ 2", val: 3.1}, {expr: "-1.1+4", val: 2.9}, {expr: "10-7-2", val: 1.0}, {expr: "20/2", val: 10.0}, {expr: "1/0", val: math.NaN()}, {expr: "20%2", val: 0.0}, {expr: "6 % 5", val: 1.0}, {expr: "20%7 %5", val: 1.0}, {expr: "1*2+7+2.2", val: 11.2}, {expr: "-20/2+1+2", val: -7.0}, {expr: "20/2+1-2-1", val: 8.0}, {expr: "30/(2+1)/5-2-1", val: -1.0}, {expr: "100/(( 2+8)*5 )-(1 +1- 0)", val: 0.0}, {expr: "(2*3)+(4*2)", val: 14.0}, {expr: "1+(2*(3+4))", val: 15.0}, {expr: "20%(7%5)", val: 0.0}, // Relational operator {expr: "50 == 5", val: false}, {expr: "'50'==50", val: true}, {expr: "'50'=='50'", val: true}, {expr: "'50' =='5' == true", val: false}, {expr: "50== 50 == false", val: false}, {expr: "50== 50 == true ==true==true", val: true}, {expr: "50 != 5", val: true}, {expr: "'50'!=50", val: false}, {expr: "'50'!= '50'", val: false}, {expr: "'50' !='5' != true", val: false}, {expr: "50!= 50 == false", val: true}, {expr: "50== 50 != true ==true!=true", val: true}, {expr: "50 > 5", val: true}, {expr: "50.1 > 50.1", val: false}, {expr: "3.2 > 2.1", val: true}, {expr: "'3.2' > '2.1'", val: true}, {expr: "'13.2'>'2.1'", val: false}, {expr: "3.2 >= 2.1", val: true}, {expr:
"2.1 >= 2.1", val: true}, {expr: "2.05 >= 2.1", val: false}, {expr: "'2.05'>='2.1'", val: false}, {expr: "'12.05'>='2.1'", val: false}, {expr: "50 < 5", val: false}, {expr: "50.1 < 50.1", val: false}, {expr: "3 <12.11", val: true}, {expr: "3.2 < 2.1", val: false}, {expr: "'3.2' < '2.1'", val: false}, {expr: "'13.2' < '2.1'", val: true}, {expr: "3.2 <= 2.1", val: false}, {expr: "2.1 <= 2.1", val: true}, {expr: "2.05 <= 2.1", val: true}, {expr: "'2.05'<='2.1'", val: true}, {expr: "'12.05'<='2.1'", val: true}, // Logical operator {expr: "!('13.2' < '2.1')", val: false}, {expr: "(3.2 <= 2.1) &&true", val: false}, {expr: "true&&(2.1<=2.1)", val: true}, {expr: "(2.05<=2.1)&&false", val: false}, {expr: "true&&!true&&false", val: false}, {expr: "true&&true&&true", val: true}, {expr: "true&&true&&false", val: false}, {expr: "false&&true&&true", val: false}, {expr: "true && false && true", val: false}, {expr: "true||false", val: true}, {expr: "false ||true", val: true}, {expr: "true&&true || false", val: true}, {expr: "true&&false || false", val: false}, {expr: "true && false || true ", val: true}, } for _, c := range cases { t.Log(c.expr) vm, err := parseExpr(c.expr) if err != nil { t.Fatal(err) } val := vm.run("", nil) if !reflect.DeepEqual(val, c.val) { if f, ok := c.val.(float64); ok && math.IsNaN(f) && math.IsNaN(val.(float64)) { continue } t.Fatalf("expr: %q, got: %v, expect: %v", c.expr, val, c.val) } } } /* TestExprWithEnv evaluates expressions whose variables a and b come from the env map passed to runWithEnv. */ func TestExprWithEnv(t *testing.T) { cases := []struct { expr string val interface{} }{ // env: a = 10, b = "string value", {expr: "a", val: 10.0}, {expr: "b", val: "string value"}, {expr: "a>10", val: false}, {expr: "a<11", val: true}, {expr: "a+1", val: 11.0}, {expr: "a==10", val: true}, } for _, c := range cases { t.Log(c.expr) vm, err := parseExpr(c.expr) if err != nil { t.Fatal(err) } val := vm.runWithEnv("", nil, map[string]interface{}{"a": 10, "b": "string value"}) if !reflect.DeepEqual(val, c.val) { if f, ok := c.val.(float64); ok && math.IsNaN(f) &&
math.IsNaN(val.(float64)) { continue } t.Fatalf("expr: %q, got: %v, expect: %v", c.expr, val, c.val) } } } /* TestPriority checks that mixed operators are evaluated by precedence rather than strictly left to right. */ func TestPriority(t *testing.T) { cases := []struct { expr string val interface{} }{ {expr: "false||true&&8==8", val: true}, {expr: "1+2>5-4", val: true}, {expr: "1+2*4/2", val: 5.0}, {expr: "(true||false)&&false||false", val: false}, {expr: "true||false&&false||false", val: true}, {expr: "true||1<0&&'a'!='a'||0!=0", val: true}, } for _, c := range cases { t.Log(c.expr) vm, err := parseExpr(c.expr) if err != nil { t.Fatal(err) } val := vm.run("", nil) if !reflect.DeepEqual(val, c.val) { if f, ok := c.val.(float64); ok && math.IsNaN(f) && math.IsNaN(val.(float64)) { continue } t.Fatalf("expr: %q, got: %v, expect: %v", c.expr, val, c.val) } } } /* TestBuiltInFunc exercises the built-in len, regexp and sprintf functions. */ func TestBuiltInFunc(t *testing.T) { cases := []struct { expr string val interface{} }{ {expr: "len('abc')", val: 3.0}, {expr: "len('abc')+2*2/len('cd')", val: 5.0}, {expr: "len(0)", val: 0.0}, {expr: "regexp('a\\d','a0')", val: true}, {expr: "regexp('^a\\d$','a0')", val: true}, {expr: "regexp('a\\d','a')", val: false}, {expr: "regexp('^a\\d$','a')", val: false}, {expr: "sprintf('test string: %s','a')", val: "test string: a"}, {expr: "sprintf('test string: %s','a'+'b')", val: "test string: ab"}, {expr: "sprintf('test string: %s,%v','a',1)", val: "test string: a,1"}, {expr: "sprintf('')+'a'", val: "a"}, {expr: "sprintf('%v',10+2*2)", val: "14"}, } for _, c := range cases { t.Log(c.expr) vm, err := parseExpr(c.expr) if err != nil { t.Fatal(err) } val := vm.run("", nil) if !reflect.DeepEqual(val, c.val) { if f, ok := c.val.(float64); ok && math.IsNaN(f) && math.IsNaN(val.(float64)) { continue } t.Fatalf("expr: %q, got: %v, expect: %v", c.expr, val, c.val) } } } /* TestSyntaxIncorrect requires parseExpr to reject each malformed expression. */ func TestSyntaxIncorrect(t *testing.T) { cases := []struct { incorrectExpr string }{ {incorrectExpr: "1 + + 'a'"}, {incorrectExpr: "regexp()"}, {incorrectExpr: "regexp('^'+'a','a')"}, {incorrectExpr: "regexp('^a','a','b')"}, {incorrectExpr: "sprintf()"},
{incorrectExpr: "sprintf(0)"}, {incorrectExpr: "sprintf('a'+'b')"}, } for _, c := range cases { _, err := parseExpr(c.incorrectExpr) if err == nil { t.Fatalf("expect syntax incorrect: %s", c.incorrectExpr) } else { t.Log(err) } } } ================================================ FILE: internal/tagexpr/handler.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package tagexpr import "reflect" // FieldHandler gives access to one struct field of a TagExpr: its selector, value, struct field and tag expressions. type FieldHandler struct { selector string field *fieldVM expr *TagExpr } /* newFieldHandler returns a FieldHandler for the field at fieldSelector within expr. */ func newFieldHandler(expr *TagExpr, fieldSelector string, field *fieldVM) *FieldHandler { return &FieldHandler{ selector: fieldSelector, field: field, expr: expr, } } // StringSelector returns the field selector of string type. func (f *FieldHandler) StringSelector() string { return f.selector } // FieldSelector returns the field selector of FieldSelector type. func (f *FieldHandler) FieldSelector() FieldSelector { return FieldSelector(f.selector) } // Value returns the field value. // NOTE: // // If initZero==true, initialize nil pointer to zero value func (f *FieldHandler) Value(initZero bool) reflect.Value { return f.field.reflectValueGetter(f.expr.ptr, initZero) } // EvalFuncs returns the tag expression eval functions.
func (f *FieldHandler) EvalFuncs() map[ExprSelector]func() interface{} { targetTagExpr, _ := f.expr.checkout(f.selector) evals := make(map[ExprSelector]func() interface{}, len(f.field.exprs)) for k, v := range f.field.exprs { /* copy the loop variables so each closure captures this iteration's expression and selector */ expr := v exprSelector := ExprSelector(k) evals[exprSelector] = func() interface{} { return expr.run(exprSelector.Name(), targetTagExpr) } } return evals } // StructField returns the field StructField object. func (f *FieldHandler) StructField() reflect.StructField { return f.field.structField } // ExprHandler gives access to one tag expression: its selector, path and evaluated value. type ExprHandler struct { base string path string selector string expr *TagExpr targetExpr *TagExpr } /* newExprHandler returns an ExprHandler for expression selector es: te holds the expression (looked up by es in Eval), tte is the TagExpr it is evaluated against, and base is the field name passed to run. */ func newExprHandler(te, tte *TagExpr, base, es string) *ExprHandler { return &ExprHandler{ base: base, selector: es, expr: te, targetExpr: tte, } } // TagExpr returns the *TagExpr. func (e *ExprHandler) TagExpr() *TagExpr { return e.expr } // StringSelector returns the expression selector of string type. func (e *ExprHandler) StringSelector() string { return e.selector } // ExprSelector returns the expression selector of ExprSelector type. func (e *ExprHandler) ExprSelector() ExprSelector { return ExprSelector(e.selector) } // Path returns the path description of the expression; it is built on first use from the target TagExpr's path and the selector, then cached. func (e *ExprHandler) Path() string { if e.path == "" { if e.targetExpr.path == "" { e.path = e.selector } else { e.path = e.targetExpr.path + FieldSeparator + e.selector } } return e.path } // Eval evaluates the value of the struct tag expression. // NOTE: // // result types: float64, string, bool, nil func (e *ExprHandler) Eval() interface{} { return e.expr.s.exprs[e.selector].run(e.base, e.targetExpr) } // EvalFloat evaluates the value of the struct tag expression. // NOTE: // // If the expression value type is not float64, return 0. func (e *ExprHandler) EvalFloat() float64 { r, _ := e.Eval().(float64) return r } // EvalString evaluates the value of the struct tag expression. // NOTE: // // If the expression value type is not string, return "".
func (e *ExprHandler) EvalString() string { r, _ := e.Eval().(string) return r } // EvalBool evaluates the value of the struct tag expression. // NOTE: // // If the expression value is not 0, '' or nil, return true. func (e *ExprHandler) EvalBool() bool { return FakeBool(e.Eval()) } ================================================ FILE: internal/tagexpr/selector.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package tagexpr import ( "strings" ) const ( // FieldSeparator in the expression selector, // the separator between field names FieldSeparator = "." // ExprNameSeparator in the expression selector, // the separator of the field name and expression name ExprNameSeparator = "@" // DefaultExprName the default name of single model expression DefaultExprName = ExprNameSeparator ) // FieldSelector field selector: field names joined by FieldSeparator, e.g. "f.g" type FieldSelector string // Name returns the current field name. func (f FieldSelector) Name() string { s := string(f) idx := strings.LastIndex(s, FieldSeparator) if idx == -1 { return s } return s[idx+1:] } // Split returns the path segments and the current field name; paths is nil when the selector contains no FieldSeparator. func (f FieldSelector) Split() (paths []string, name string) { s := string(f) a := strings.Split(s, FieldSeparator) idx := len(a) - 1 if idx > 0 { return a[:idx], a[idx] } return nil, s } // Parent returns the parent FieldSelector.
func (f FieldSelector) Parent() (string, bool) { s := string(f) i := strings.LastIndex(s, FieldSeparator) if i < 0 { return "", false } return s[:i], true } // String returns string type value. func (f FieldSelector) String() string { return string(f) } // ExprSelector expression selector: a field selector optionally followed by ExprNameSeparator and an expression name, e.g. "C@expr1" type ExprSelector string // Name returns the name of the expression, or DefaultExprName when the selector contains no ExprNameSeparator. func (e ExprSelector) Name() string { s := string(e) atIdx := strings.LastIndex(s, ExprNameSeparator) if atIdx == -1 { return DefaultExprName } return s[atIdx+1:] } // Field returns the field selector it belongs to (everything before the last ExprNameSeparator). func (e ExprSelector) Field() string { s := string(e) idx := strings.LastIndex(s, ExprNameSeparator) if idx != -1 { s = s[:idx] } return s } // ParentField returns the parent field selector it belongs to. func (e ExprSelector) ParentField() (string, bool) { return FieldSelector(e.Field()).Parent() } // Split returns the field selector and the expression name. func (e ExprSelector) Split() (field FieldSelector, name string) { s := string(e) atIdx := strings.LastIndex(s, ExprNameSeparator) if atIdx == -1 { return FieldSelector(s), DefaultExprName } return FieldSelector(s[:atIdx]), s[atIdx+1:] } // String returns string type value. func (e ExprSelector) String() string { return string(e) } ================================================ FILE: internal/tagexpr/selector_test.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. package tagexpr import ( "testing" ) /* TestExprSelector checks that ParentField of "F1.Index" is "F1". */ func TestExprSelector(t *testing.T) { es := ExprSelector("F1.Index") field, ok := es.ParentField() if !ok { t.Fatal("not ok") } if "F1" != field { t.Fatal(field) } } ================================================ FILE: internal/tagexpr/spec_func.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package tagexpr import ( "context" "fmt" "reflect" "regexp" "strings" ) // --------------------------- Custom function --------------------------- /* funcList maps a function name to the parser that recognizes calls to it; RegFunc and init fill it. */ var funcList = map[string]func(p *Expr, expr *string) ExprNode{} // MustRegFunc registers function expression. // NOTE: // // example: len($), regexp("\\d") or regexp("\\d",$); // If @force=true, allow overwriting an existing function registered under the same @funcName; // The go number types always are float64; // The go string types always are string; // Panic if there is an error. func MustRegFunc(funcName string, fn func(...interface{}) interface{}, force ...bool) { err := RegFunc(funcName, fn, force...) if err != nil { panic(err) } } // RegFunc registers function expression. // NOTE: // // example: len($), regexp("\\d") or regexp("\\d",$); // If @force=true, allow overwriting an existing function registered under the same @funcName; // The go number types always are float64; // The go string types always are string.
func RegFunc(funcName string, fn func(...interface{}) interface{}, force ...bool) error { if len(force) == 0 || !force[0] { _, ok := funcList[funcName] if ok { return fmt.Errorf("duplicate registration expression function: %s", funcName) } } funcList[funcName] = newFunc(funcName, fn) return nil } /* parseFuncSign parses a call of funcName at the front of *expr, after any prefixes recognized by getBoolAndSignOpposite, into one group operand per comma-separated argument; found reports success. NOTE(review): when funcName matched but the call is malformed, *expr is not restored to its original text. */ func (p *Expr) parseFuncSign(funcName string, expr *string) (boolOpposite *bool, signOpposite *bool, args []ExprNode, found bool) { prefix := funcName + "(" length := len(funcName) last, boolOpposite, signOpposite := getBoolAndSignOpposite(expr) if !strings.HasPrefix(last, prefix) { return } *expr = last[length:] lastStr := *expr subExprNode := readPairedSymbol(expr, '(', ')') if subExprNode == nil { return } *subExprNode = "," + *subExprNode for { if strings.HasPrefix(*subExprNode, ",") { *subExprNode = (*subExprNode)[1:] operand := newGroupExprNode() err := p.parseExprNode(trimLeftSpace(subExprNode), operand) if err != nil { *expr = lastStr return } sortPriority(operand) args = append(args, operand) } else { *expr = lastStr return } trimLeftSpace(subExprNode) if len(*subExprNode) == 0 { found = true return } } } /* newFunc returns the funcList parser for funcName: it recognizes a call via parseFuncSign and builds a funcExprNode that passes the evaluated arguments to fn. */ func newFunc(funcName string, fn func(...interface{}) interface{}) func(*Expr, *string) ExprNode { return func(p *Expr, expr *string) ExprNode { boolOpposite, signOpposite, args, found := p.parseFuncSign(funcName, expr) if !found { return nil } return &funcExprNode{ fn: fn, boolOpposite: boolOpposite, signOpposite: signOpposite, args: args, } } } /* funcExprNode is a call of a function registered with RegFunc. */ type funcExprNode struct { exprBackground args []ExprNode fn func(...interface{}) interface{} boolOpposite *bool signOpposite *bool } func (f *funcExprNode) String() string { return "func()" } /* Run evaluates every argument, calls fn with the results and post-processes the return value with realValue and the recorded opposite flags. */ func (f *funcExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { var args []interface{} if n := len(f.args); n > 0 { args = make([]interface{}, n) for k, v := range f.args { args[k] = v.Run(ctx, currField, tagExpr) } } return realValue(f.fn(args...), f.boolOpposite, f.signOpposite) } //
--------------------------- Built-in function --------------------------- /* init registers the built-in functions: regexp, sprintf and range use dedicated parsers, while len, mblen and in are plain functions registered with MustRegFunc. */ func init() { funcList["regexp"] = readRegexpFuncExprNode funcList["sprintf"] = readSprintfFuncExprNode funcList["range"] = readRangeFuncExprNode // len: Built-in function len, the length of struct field X MustRegFunc("len", func(args ...interface{}) (n interface{}) { if len(args) != 1 { return 0 } v := args[0] switch e := v.(type) { case string: return float64(len(e)) case float64, bool, nil: return 0 } defer func() { if recover() != nil { n = 0 } }() return float64(reflect.ValueOf(v).Len()) }, true) // mblen: get the length of string field X (character number) MustRegFunc("mblen", func(args ...interface{}) (n interface{}) { if len(args) != 1 { return 0 } v := args[0] switch e := v.(type) { case string: return float64(len([]rune(e))) case float64, bool, nil: return 0 } defer func() { if recover() != nil { n = 0 } }() return float64(reflect.ValueOf(v).Len()) }, true) // in: Check if the first parameter is one of the enumerated parameters; values of an uncomparable dynamic type (e.g. slices or maps) never match MustRegFunc("in", func(args ...interface{}) interface{} { switch len(args) { case 0: return true case 1: return false default: elem := args[0] set := args[1:] /* == on two interface values panics when both hold the same uncomparable dynamic type, so such pairs are treated as unequal instead of crashing the evaluation */ equal := func(a, b interface{}) (eq bool) { defer func() { if recover() != nil { eq = false } }() return a == b } for _, e := range set { if equal(elem, e) { return true } } return false } }, true) } /* regexpFuncExprNode matches a compiled pattern against its right operand. */ type regexpFuncExprNode struct { exprBackground re *regexp.Regexp boolOpposite bool } func (re *regexpFuncExprNode) String() string { return "regexp()" } /* readRegexpFuncExprNode parses regexp('pattern') or regexp('pattern', operand), using the current field $ as the operand when none is given; it returns nil when the pattern is not a quoted string, does not compile, or is followed by unexpected text. */ func readRegexpFuncExprNode(p *Expr, expr *string) ExprNode { last, boolOpposite, _ := getBoolAndSignOpposite(expr) if !strings.HasPrefix(last, "regexp(") { return nil } *expr = last[6:] lastStr := *expr subExprNode := readPairedSymbol(expr, '(', ')') if subExprNode == nil { return nil } s := readPairedSymbol(trimLeftSpace(subExprNode), '\'', '\'') if s == nil { *expr = lastStr return nil } rege, err := regexp.Compile(*s) if err != nil { *expr = lastStr return nil } operand := newGroupExprNode() trimLeftSpace(subExprNode) if strings.HasPrefix(*subExprNode, ",") {
*subExprNode = (*subExprNode)[1:] err = p.parseExprNode(trimLeftSpace(subExprNode), operand) if err != nil { *expr = lastStr return nil } } else { currFieldVal := "$" /* the error is ignored; parsing the constant "$" is presumably not expected to fail */ p.parseExprNode(&currFieldVal, operand) } trimLeftSpace(subExprNode) if *subExprNode != "" { *expr = lastStr return nil } e := &regexpFuncExprNode{ re: rege, } if boolOpposite != nil { e.boolOpposite = *boolOpposite } e.SetRightOperand(operand) return e } /* Run reports whether the operand, a string or a value of string kind, matches the pattern, negated when boolOpposite is set; any other operand yields false, even when negated. */ func (re *regexpFuncExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { param := re.rightOperand.Run(ctx, currField, tagExpr) switch v := param.(type) { case string: bol := re.re.MatchString(v) if re.boolOpposite { return !bol } return bol case float64, bool: return false } v := reflect.ValueOf(param) if v.Kind() == reflect.String { bol := re.re.MatchString(v.String()) if re.boolOpposite { return !bol } return bol } return false } /* sprintfFuncExprNode formats its evaluated arguments with fmt.Sprintf. */ type sprintfFuncExprNode struct { exprBackground format string args []ExprNode } func (se *sprintfFuncExprNode) String() string { return "sprintf()" } /* readSprintfFuncExprNode parses sprintf('format', arg, ...) where each argument is a full expression; unlike readRegexpFuncExprNode it does not look for a leading '!' itself. */ func readSprintfFuncExprNode(p *Expr, expr *string) ExprNode { if !strings.HasPrefix(*expr, "sprintf(") { return nil } *expr = (*expr)[7:] lastStr := *expr subExprNode := readPairedSymbol(expr, '(', ')') if subExprNode == nil { return nil } format := readPairedSymbol(trimLeftSpace(subExprNode), '\'', '\'') if format == nil { *expr = lastStr return nil } e := &sprintfFuncExprNode{ format: *format, } for { trimLeftSpace(subExprNode) if len(*subExprNode) == 0 { return e } if strings.HasPrefix(*subExprNode, ",") { *subExprNode = (*subExprNode)[1:] operand := newGroupExprNode() err := p.parseExprNode(trimLeftSpace(subExprNode), operand) if err != nil { *expr = lastStr return nil } sortPriority(operand) e.args = append(e.args, operand) } else { *expr = lastStr return nil } } } /* Run evaluates the arguments and returns fmt.Sprintf(format, args...). */ func (se *sprintfFuncExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { var args []interface{} if n := len(se.args); n > 0 { args =
make([]interface{}, n) for i, e := range se.args { args[i] = e.Run(ctx, currField, tagExpr) } } return fmt.Sprintf(se.format, args...) } ================================================ FILE: internal/tagexpr/spec_func_test.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package tagexpr_test import ( "reflect" "regexp" "testing" "github.com/cloudwego/hertz/internal/tagexpr" ) /* TestFunc registers a custom email function and checks it, together with the built-in mblen, through struct tags. */ func TestFunc(t *testing.T) { emailRegexp := regexp.MustCompile( "^([A-Za-z0-9_\\-\\.\u4e00-\u9fa5])+\\@([A-Za-z0-9_\\-\\.])+\\.([A-Za-z]{2,8})$", ) tagexpr.RegFunc("email", func(args ...interface{}) interface{} { if len(args) == 0 { return false } s, ok := args[0].(string) if !ok { return false } t.Log(s) return emailRegexp.MatchString(s) }) vm := tagexpr.New("te") type T struct { Email string `te:"email($)"` } cases := []struct { email string expect bool }{ {"", false}, {"henrylee2cn@gmail.com", true}, } obj := new(T) for _, c := range cases { obj.Email = c.email te := vm.MustRun(obj) got := te.EvalBool("Email") if got != c.expect { t.Fatalf("email: %s, expect: %v, but got: %v", c.email, c.expect, got) } } // test len type R struct { Str string `vd:"mblen($)<6"` } lenCases := []struct { str string expect bool }{ {"123", true}, {"一二三四五六七", false}, {"一二三四五", true}, } lenObj := new(R) vm = tagexpr.New("vd") for _, lenCase := range lenCases { lenObj.Str = lenCase.str te := vm.MustRun(lenObj) got
:= te.EvalBool("Str") if got != lenCase.expect { t.Fatalf("string: %v, expect: %v, but got: %v", lenCase.str, lenCase.expect, got) } } } /* TestRangeIn applies in() to every element of a slice field via range(). */ func TestRangeIn(t *testing.T) { vm := tagexpr.New("te") type S struct { F []string `te:"range($, in(#v, '', 'ttp', 'euttp'))"` } a := []string{"ttp", "", "euttp"} r := vm.MustRun(S{ F: a, // F: b, }) expect := []interface{}{true, true, true} actual := r.Eval("F") if !reflect.DeepEqual(expect, actual) { t.Fatal("not equal", expect, actual) } } ================================================ FILE: internal/tagexpr/spec_operand.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License.
package tagexpr import ( "context" "fmt" "reflect" "regexp" "strconv" "strings" ) // --------------------------- Operand --------------------------- type groupExprNode struct { exprBackground boolOpposite *bool signOpposite *bool } func newGroupExprNode() ExprNode { return &groupExprNode{} } func readGroupExprNode(expr *string) (grp ExprNode, subExprNode *string) { last, boolOpposite, signOpposite := getBoolAndSignOpposite(expr) sptr := readPairedSymbol(&last, '(', ')') if sptr == nil { return nil, nil } *expr = last e := &groupExprNode{boolOpposite: boolOpposite, signOpposite: signOpposite} return e, sptr } func (ge *groupExprNode) String() string { return "()" } func (ge *groupExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { if ge.rightOperand == nil { return nil } return realValue(ge.rightOperand.Run(ctx, currField, tagExpr), ge.boolOpposite, ge.signOpposite) } type boolExprNode struct { exprBackground val bool } func (be *boolExprNode) String() string { return fmt.Sprintf("%v", be.val) } var boolRegexp = regexp.MustCompile(`^!*(true|false)([\)\],\|&!= \t]{1}|$)`) func readBoolExprNode(expr *string) ExprNode { s := boolRegexp.FindString(*expr) if s == "" { return nil } last := s[len(s)-1] if last != 'e' { s = s[:len(s)-1] } *expr = (*expr)[len(s):] e := &boolExprNode{} if strings.Contains(s, "t") { e.val = (len(s)-4)&1 == 0 } else { e.val = (len(s)-5)&1 == 1 } return e } func (be *boolExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { return be.val } type stringExprNode struct { exprBackground val interface{} } func (se *stringExprNode) String() string { return fmt.Sprintf("%v", se.val) } func readStringExprNode(expr *string) ExprNode { last, boolOpposite, _ := getBoolAndSignOpposite(expr) sptr := readPairedSymbol(&last, '\'', '\'') if sptr == nil { return nil } *expr = last e := &stringExprNode{val: realValue(*sptr, boolOpposite, nil)} return e } func (se *stringExprNode) Run(ctx 
context.Context, currField string, tagExpr *TagExpr) interface{} { return se.val } type digitalExprNode struct { exprBackground val interface{} } func (de *digitalExprNode) String() string { return fmt.Sprintf("%v", de.val) } var digitalRegexp = regexp.MustCompile(`^[\+\-]?\d+(\.\d+)?([\)\],\+\-\*\/%><\|&!=\^ \t\\]|$)`) func readDigitalExprNode(expr *string) ExprNode { last, boolOpposite := getOpposite(expr, "!") s := digitalRegexp.FindString(last) if s == "" { return nil } if r := s[len(s)-1]; r < '0' || r > '9' { s = s[:len(s)-1] } *expr = last[len(s):] f64, _ := strconv.ParseFloat(s, 64) return &digitalExprNode{val: realValue(f64, boolOpposite, nil)} } func (de *digitalExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { return de.val } type nilExprNode struct { exprBackground val interface{} } func (ne *nilExprNode) String() string { return "" } var nilRegexp = regexp.MustCompile(`^nil([\)\],\|&!= \t]{1}|$)`) func readNilExprNode(expr *string) ExprNode { last, boolOpposite := getOpposite(expr, "!") s := nilRegexp.FindString(last) if s == "" { return nil } *expr = last[3:] return &nilExprNode{val: realValue(nil, boolOpposite, nil)} } func (ne *nilExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { return ne.val } type variableExprNode struct { exprBackground boolOpposite *bool val string } func (ve *variableExprNode) String() string { return fmt.Sprintf("%v", ve.val) } func (ve *variableExprNode) Run(ctx context.Context, variableName string, _ *TagExpr) interface{} { envObj := ctx.Value(variableKey) if envObj == nil { return nil } env := envObj.(map[string]interface{}) if len(env) == 0 { return nil } if value, ok := env[ve.val]; ok && value != nil { return realValue(value, ve.boolOpposite, nil) } else { return nil } } var variableRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*`) func readVariableExprNode(expr *string) ExprNode { last, boolOpposite := getOpposite(expr, "!") variable := 
variableRegex.FindString(last) if variable == "" { return nil } *expr = (*expr)[len(*expr)-len(last)+len(variable):] return &variableExprNode{ val: variable, boolOpposite: boolOpposite, } } func getBoolAndSignOpposite(expr *string) (last string, boolOpposite *bool, signOpposite *bool) { last, boolOpposite = getOpposite(expr, "!") last = strings.TrimLeft(last, "+") last, signOpposite = getOpposite(&last, "-") last = strings.TrimLeft(last, "+") return } func getOpposite(expr *string, cutset string) (string, *bool) { last := strings.TrimLeft(*expr, cutset) n := len(*expr) - len(last) if n == 0 { return last, nil } bol := n&1 == 1 return last, &bol } func toString(i interface{}, enforce bool) (string, bool) { switch vv := i.(type) { case string: return vv, true case nil: return "", false default: rv := dereferenceValue(reflect.ValueOf(i)) if rv.Kind() == reflect.String { return rv.String(), true } if enforce { if rv.IsValid() && rv.CanInterface() { return fmt.Sprint(rv.Interface()), true } else { return fmt.Sprint(i), true } } } return "", false } func toFloat64(i interface{}, tryParse bool) (float64, bool) { var v float64 ok := true switch t := i.(type) { case float64: v = t case float32: v = float64(t) case int: v = float64(t) case int8: v = float64(t) case int16: v = float64(t) case int32: v = float64(t) case int64: v = float64(t) case uint: v = float64(t) case uint8: v = float64(t) case uint16: v = float64(t) case uint32: v = float64(t) case uint64: v = float64(t) case nil: ok = false default: rv := dereferenceValue(reflect.ValueOf(t)) switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: v = float64(rv.Int()) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: v = float64(rv.Uint()) case reflect.Float32, reflect.Float64: v = rv.Float() default: if tryParse { if s, ok := toString(i, false); ok { var err error v, err = strconv.ParseFloat(s, 64) return v, err == nil } } ok = false } } return v, 
ok } func realValue(v interface{}, boolOpposite *bool, signOpposite *bool) interface{} { if boolOpposite != nil { bol := FakeBool(v) if *boolOpposite { return !bol } return bol } switch t := v.(type) { case float64, string: case float32: v = float64(t) case int: v = float64(t) case int8: v = float64(t) case int16: v = float64(t) case int32: v = float64(t) case int64: v = float64(t) case uint: v = float64(t) case uint8: v = float64(t) case uint16: v = float64(t) case uint32: v = float64(t) case uint64: v = float64(t) case []interface{}: for k, v := range t { t[k] = realValue(v, boolOpposite, signOpposite) } default: rv := dereferenceValue(reflect.ValueOf(v)) switch rv.Kind() { case reflect.String: v = rv.String() case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: v = float64(rv.Int()) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: v = float64(rv.Uint()) case reflect.Float32, reflect.Float64: v = rv.Float() } } if signOpposite != nil && *signOpposite { if f, ok := v.(float64); ok { v = -f } } return v } ================================================ FILE: internal/tagexpr/spec_operator.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package tagexpr import ( "context" "math" ) // --------------------------- Operator --------------------------- type additionExprNode struct{ exprBackground } func (ae *additionExprNode) String() string { return "+" } func newAdditionExprNode() ExprNode { return &additionExprNode{} } func (ae *additionExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { // positive number or Addition v0 := ae.leftOperand.Run(ctx, currField, tagExpr) v1 := ae.rightOperand.Run(ctx, currField, tagExpr) if s0, ok := toFloat64(v0, false); ok { s1, _ := toFloat64(v1, true) return s0 + s1 } if s0, ok := toString(v0, false); ok { s1, _ := toString(v1, true) return s0 + s1 } return v0 } type multiplicationExprNode struct{ exprBackground } func (ae *multiplicationExprNode) String() string { return "*" } func newMultiplicationExprNode() ExprNode { return &multiplicationExprNode{} } func (ae *multiplicationExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { v0, _ := toFloat64(ae.leftOperand.Run(ctx, currField, tagExpr), true) v1, _ := toFloat64(ae.rightOperand.Run(ctx, currField, tagExpr), true) return v0 * v1 } type divisionExprNode struct{ exprBackground } func (de *divisionExprNode) String() string { return "/" } func newDivisionExprNode() ExprNode { return &divisionExprNode{} } func (de *divisionExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { v1, _ := toFloat64(de.rightOperand.Run(ctx, currField, tagExpr), true) if v1 == 0 { return math.NaN() } v0, _ := toFloat64(de.leftOperand.Run(ctx, currField, tagExpr), true) return v0 / v1 } type subtractionExprNode struct{ exprBackground } func (de *subtractionExprNode) String() string { return "-" } func newSubtractionExprNode() ExprNode { return &subtractionExprNode{} } func (de *subtractionExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { v0, _ := toFloat64(de.leftOperand.Run(ctx, currField, tagExpr), true) v1, _ 
:= toFloat64(de.rightOperand.Run(ctx, currField, tagExpr), true) return v0 - v1 } type remainderExprNode struct{ exprBackground } func (re *remainderExprNode) String() string { return "%" } func newRemainderExprNode() ExprNode { return &remainderExprNode{} } func (re *remainderExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { v1, _ := toFloat64(re.rightOperand.Run(ctx, currField, tagExpr), true) if v1 == 0 { return math.NaN() } v0, _ := toFloat64(re.leftOperand.Run(ctx, currField, tagExpr), true) return float64(int64(v0) % int64(v1)) } type equalExprNode struct{ exprBackground } func (ee *equalExprNode) String() string { return "==" } func newEqualExprNode() ExprNode { return &equalExprNode{} } func (ee *equalExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { v0 := ee.leftOperand.Run(ctx, currField, tagExpr) v1 := ee.rightOperand.Run(ctx, currField, tagExpr) if v0 == v1 { return true } if s0, ok := toFloat64(v0, false); ok { if s1, ok := toFloat64(v1, true); ok { return s0 == s1 } } if s0, ok := toString(v0, false); ok { if s1, ok := toString(v1, true); ok { return s0 == s1 } return false } switch r := v0.(type) { case bool: r1, ok := v1.(bool) if ok { return r == r1 } case nil: return v1 == nil } return false } type notEqualExprNode struct{ equalExprNode } func (ne *notEqualExprNode) String() string { return "!=" } func newNotEqualExprNode() ExprNode { return ¬EqualExprNode{} } func (ne *notEqualExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { return !ne.equalExprNode.Run(ctx, currField, tagExpr).(bool) } type greaterExprNode struct{ exprBackground } func (ge *greaterExprNode) String() string { return ">" } func newGreaterExprNode() ExprNode { return &greaterExprNode{} } func (ge *greaterExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { v0 := ge.leftOperand.Run(ctx, currField, tagExpr) v1 := ge.rightOperand.Run(ctx, currField, 
tagExpr) if s0, ok := toFloat64(v0, false); ok { if s1, ok := toFloat64(v1, true); ok { return s0 > s1 } } if s0, ok := toString(v0, false); ok { if s1, ok := toString(v1, true); ok { return s0 > s1 } return false } return false } type greaterEqualExprNode struct{ exprBackground } func (ge *greaterEqualExprNode) String() string { return ">=" } func newGreaterEqualExprNode() ExprNode { return &greaterEqualExprNode{} } func (ge *greaterEqualExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { v0 := ge.leftOperand.Run(ctx, currField, tagExpr) v1 := ge.rightOperand.Run(ctx, currField, tagExpr) if s0, ok := toFloat64(v0, false); ok { if s1, ok := toFloat64(v1, true); ok { return s0 >= s1 } } if s0, ok := toString(v0, false); ok { if s1, ok := toString(v1, true); ok { return s0 >= s1 } return false } return false } type lessExprNode struct{ exprBackground } func (le *lessExprNode) String() string { return "<" } func newLessExprNode() ExprNode { return &lessExprNode{} } func (le *lessExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { v0 := le.leftOperand.Run(ctx, currField, tagExpr) v1 := le.rightOperand.Run(ctx, currField, tagExpr) if s0, ok := toFloat64(v0, false); ok { if s1, ok := toFloat64(v1, true); ok { return s0 < s1 } } if s0, ok := toString(v0, false); ok { if s1, ok := toString(v1, true); ok { return s0 < s1 } return false } return false } type lessEqualExprNode struct{ exprBackground } func (le *lessEqualExprNode) String() string { return "<=" } func newLessEqualExprNode() ExprNode { return &lessEqualExprNode{} } func (le *lessEqualExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { v0 := le.leftOperand.Run(ctx, currField, tagExpr) v1 := le.rightOperand.Run(ctx, currField, tagExpr) if s0, ok := toFloat64(v0, false); ok { if s1, ok := toFloat64(v1, true); ok { return s0 <= s1 } } if s0, ok := toString(v0, false); ok { if s1, ok := toString(v1, true); ok { return 
s0 <= s1 } return false } return false } type andExprNode struct{ exprBackground } func (ae *andExprNode) String() string { return "&&" } func newAndExprNode() ExprNode { return &andExprNode{} } func (ae *andExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { for _, e := range [2]ExprNode{ae.leftOperand, ae.rightOperand} { if !FakeBool(e.Run(ctx, currField, tagExpr)) { return false } } return true } type orExprNode struct{ exprBackground } func (oe *orExprNode) String() string { return "||" } func newOrExprNode() ExprNode { return &orExprNode{} } func (oe *orExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { for _, e := range [2]ExprNode{oe.leftOperand, oe.rightOperand} { if FakeBool(e.Run(ctx, currField, tagExpr)) { return true } } return false } ================================================ FILE: internal/tagexpr/spec_range.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package tagexpr

import (
	"context"
	"reflect"
	"regexp"
)

// rangeCtxKey is the context key type under which range() publishes the
// current iteration state to its element expression.
type rangeCtxKey string

const (
	rangeKey   rangeCtxKey = "#k" // current index (slice/array) or map key
	rangeValue rangeCtxKey = "#v" // current element, stored as a reflect.Value
	rangeLen   rangeCtxKey = "##" // number of elements being ranged over
)

// rangeKvExprNode is a reference to #k, #v or ## inside a range() element
// expression, with optional '!'/'-' prefixes.
type rangeKvExprNode struct {
	exprBackground
	ctxKey       rangeCtxKey
	boolOpposite *bool
	signOpposite *bool
}

func (re *rangeKvExprNode) String() string {
	return string(re.ctxKey)
}

// readRangeKvExprNode consumes a #k/#v/## reference from the front of *expr,
// or returns nil when *expr does not start with one.
func (p *Expr) readRangeKvExprNode(expr *string) ExprNode {
	name, boolOpposite, signOpposite, found := findRangeKv(expr)
	if !found {
		return nil
	}
	operand := &rangeKvExprNode{
		ctxKey:       rangeCtxKey(name),
		boolOpposite: boolOpposite,
		signOpposite: signOpposite,
	}
	return operand
}

// Groups: 1 = '!', '+', '-' prefixes; 2 = the reference (#k, #v or ##);
// 3 = the following delimiter (or end of input), which is not consumed.
var rangeKvRegexp = regexp.MustCompile(`^([\!\+\-]*)(#[kv#])([\)\[\],\+\-\*\/%><\|&!=\^ \t\\]|$)`)

// findRangeKv matches a range reference at the start of *expr. On success it
// advances *expr past the reference (leaving the delimiter) and decodes the
// prefixes with getBoolAndSignOpposite; on failure *expr is unchanged.
func findRangeKv(expr *string) (name string, boolOpposite, signOpposite *bool, found bool) {
	raw := *expr
	a := rangeKvRegexp.FindAllStringSubmatch(raw, -1)
	if len(a) != 1 {
		return
	}
	r := a[0]
	name = r[2]
	*expr = (*expr)[len(a[0][0])-len(r[3]):]
	prefix := r[1]
	if len(prefix) == 0 {
		found = true
		return
	}
	_, boolOpposite, signOpposite = getBoolAndSignOpposite(&prefix)
	found = true
	return
}

// Run reads the iteration state from ctx. reflect.Value entries (the map
// key and #v) are unwrapped first; invalid or non-interfaceable values
// yield nil.
func (re *rangeKvExprNode) Run(ctx context.Context, _ string, _ *TagExpr) interface{} {
	var v interface{}
	switch val := ctx.Value(re.ctxKey).(type) {
	case reflect.Value:
		if !val.IsValid() || !val.CanInterface() {
			return nil
		}
		v = val.Interface()
	default:
		v = val
	}
	return realValue(v, re.boolOpposite, re.signOpposite)
}

// rangeFuncExprNode is range(object, elemExpr): it evaluates elemExpr once
// per element of object and collects the results.
type rangeFuncExprNode struct {
	exprBackground
	object       ExprNode
	elemExprNode ExprNode
	boolOpposite *bool
	signOpposite *bool
}

func (e *rangeFuncExprNode) String() string {
	return "range()"
}

// readRangeFuncExprNode parses a range() call with exactly two arguments,
// for example:
//
//	range($, #v>10)
//	range($, in(#v, 'a', 'b'))
//
// It returns nil when *expr is not a range() call or the argument count is
// wrong.
func readRangeFuncExprNode(p *Expr, expr *string) ExprNode {
	boolOpposite, signOpposite, args, found := p.parseFuncSign("range", expr)
	if !found {
		return nil
	}
	if len(args) != 2 {
		return nil
	}
	return &rangeFuncExprNode{
		boolOpposite: boolOpposite,
		signOpposite: signOpposite,
		object:       args[0],
		elemExprNode: args[1],
	}
}

// Run returns a []interface{} with one normalized result per element. Array
// and slice elements are visited in index order; map entries follow
// reflect.Value.MapKeys, whose order is unspecified. Any other kind of object
// yields a nil slice.
func (e *rangeFuncExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} {
	var r []interface{}
	obj := e.object.Run(ctx, currField, tagExpr)
	objval := reflect.ValueOf(obj)
	switch objval.Kind() {
	case reflect.Array, reflect.Slice:
		count := objval.Len()
		r = make([]interface{}, count)
		// ## is shared by every iteration; #k/#v are layered per element.
		ctx = context.WithValue(ctx, rangeLen, count)
		for i := 0; i < count; i++ {
			r[i] = realValue(e.elemExprNode.Run(
				context.WithValue(
					context.WithValue(
						ctx,
						rangeKey, i,
					),
					rangeValue, objval.Index(i),
				),
				currField, tagExpr,
			), e.boolOpposite, e.signOpposite)
		}
	case reflect.Map:
		keys := objval.MapKeys()
		count := len(keys)
		r = make([]interface{}, count)
		ctx = context.WithValue(ctx, rangeLen, count)
		for i, key := range keys {
			r[i] = realValue(e.elemExprNode.Run(
				context.WithValue(
					context.WithValue(
						ctx,
						rangeKey, key,
					),
					rangeValue, objval.MapIndex(key),
				),
				currField, tagExpr,
			), e.boolOpposite, e.signOpposite)
		}
	default:
	}
	return r
}

================================================
FILE: internal/tagexpr/spec_range_test.go
================================================
// Copyright 2019 Bytedance Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tagexpr_test

import (
	"reflect"
	"testing"

	"github.com/cloudwego/hertz/internal/tagexpr"
)

// TestIssue12 checks nested range() expressions and the #k (index/key),
// #v (element) and ## (length) references over slices, maps and slices of
// maps, including an empty outer slice.
func TestIssue12(t *testing.T) {
	vm := tagexpr.New("te")
	type I int
	type S struct {
		F    []I              `te:"range($, '>'+sprintf('%v:%v', #k, #v+2+len($)))"`
		Fs   [][]I            `te:"range($, range(#v, '>'+sprintf('%v:%v', #k, #v+2+##)))"`
		M    map[string]I     `te:"range($, '>'+sprintf('%s:%v', #k, #v+2+##))"`
		MFs  []map[string][]I `te:"range($, range(#v, range(#v, '>'+sprintf('%v:%v', #k, #v+2+##))))"`
		MFs2 []map[string][]I `te:"range($, range(#v, range(#v, '>'+sprintf('%v:%v', #k, #v+2+##))))"`
	}
	a := []I{2, 3}
	r := vm.MustRun(S{
		F:    a,
		Fs:   [][]I{a},
		M:    map[string]I{"m0": 2, "m1": 3},
		MFs:  []map[string][]I{{"m": a}},
		MFs2: []map[string][]I{},
	})
	assertEqual(t, []interface{}{">0:6", ">1:7"}, r.Eval("F"))
	assertEqual(t, []interface{}{[]interface{}{">0:6", ">1:7"}}, r.Eval("Fs"))
	assertEqual(t, []interface{}{[]interface{}{[]interface{}{">0:6", ">1:7"}}}, r.Eval("MFs"))
	assertEqual(t, []interface{}{}, r.Eval("MFs2"))
	assertEqual(t, true, r.EvalBool("MFs2"))

	// Map iteration order is unspecified, so either ordering of the results
	// is accepted.
	got := r.Eval("M")
	if !reflect.DeepEqual([]interface{}{">m0:6", ">m1:7"}, got) &&
		!reflect.DeepEqual([]interface{}{">m1:7", ">m0:6"}, got) {
		t.Fatal(got)
	}
}

================================================
FILE: internal/tagexpr/spec_selector.go
================================================
// Copyright 2019 Bytedance Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tagexpr

import (
	"context"
	"fmt"
	"regexp"
	"strings"
)

// selectorExprNode is a field selector: "$" (the current field) or
// "(Field.Path)$" (another field), optionally followed by "[...]" subscripts
// and preceded by '!'/'-' prefixes.
type selectorExprNode struct {
	exprBackground
	field, name  string
	subExprs     []ExprNode
	boolOpposite *bool
	signOpposite *bool
}

func (se *selectorExprNode) String() string {
	return fmt.Sprintf("(%s)%s", se.field, se.name)
}

// readSelectorExprNode consumes a selector from the front of *expr. Each
// subscript is parsed as its own expression group. It returns nil when no
// selector is found or a subscript fails to parse.
func (p *Expr) readSelectorExprNode(expr *string) ExprNode {
	field, name, subSelector, boolOpposite, signOpposite, found := findSelector(expr)
	if !found {
		return nil
	}
	operand := &selectorExprNode{
		field:        field,
		name:         name,
		boolOpposite: boolOpposite,
		signOpposite: signOpposite,
	}
	operand.subExprs = make([]ExprNode, 0, len(subSelector))
	for _, s := range subSelector {
		grp := newGroupExprNode()
		err := p.parseExprNode(&s, grp)
		if err != nil {
			return nil
		}
		sortPriority(grp)
		operand.subExprs = append(operand.subExprs, grp)
	}
	return operand
}

// Groups: 1 = '!', '+', '-' prefixes; 2 = optional "(Field.Path)";
// 3 = "$"; 4 = the following delimiter (or end of input), which is not
// consumed.
var selectorRegexp = regexp.MustCompile(`^([\!\+\-]*)(\([ \t]*[A-Za-z_]+[A-Za-z0-9_\.]*[ \t]*\))?(\$)([\)\[\],\+\-\*\/%><\|&!=\^ \t\\]|$)`)

// findSelector matches a selector at the start of *expr and then consumes
// any "[...]" subscripts that follow it. An empty subscript or one starting
// with '[' aborts the match and restores *expr. On success *expr is left
// just after the last consumed token.
func findSelector(expr *string) (field string, name string, subSelector []string, boolOpposite, signOpposite *bool, found bool) {
	raw := *expr
	a := selectorRegexp.FindAllStringSubmatch(raw, -1)
	if len(a) != 1 {
		return
	}
	r := a[0]
	if s0 := r[2]; len(s0) > 0 {
		// Strip the parentheses and surrounding blanks from "(Field.Path)".
		field = strings.TrimSpace(s0[1 : len(s0)-1])
	}
	name = r[3]
	*expr = (*expr)[len(a[0][0])-len(r[4]):]
	for {
		sub := readPairedSymbol(expr, '[', ']')
		if sub == nil {
			break
		}
		if *sub == "" || (*sub)[0] == '[' {
			*expr = raw
			return "", "", nil, nil, nil, false
		}
		subSelector = append(subSelector, strings.TrimSpace(*sub))
	}
	prefix := r[1]
	if len(prefix) == 0 {
		found = true
		return
	}
	_, boolOpposite, signOpposite = getBoolAndSignOpposite(&prefix)
	found = true
	return
}

// Run evaluates the subscripts, resolves the selected field (the current
// field when none was named) through tagExpr.getValue, and normalizes the
// result with realValue.
func (se *selectorExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} {
	var subFields []interface{}
	if n := len(se.subExprs); n > 0 {
		subFields = make([]interface{}, n)
		for i, e := range se.subExprs {
			subFields[i] = e.Run(ctx, currField, tagExpr)
		}
	}
	field := se.field
	if field == "" {
		field = currField
	}
	v := tagExpr.getValue(field, subFields)
	return realValue(v, se.boolOpposite, se.signOpposite)
}

================================================
FILE: internal/tagexpr/spec_test.go
================================================
// Copyright 2019 Bytedance Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package tagexpr

import (
	"context"
	"reflect"
	"testing"
)

// TestReadPairedSymbol checks that readPairedSymbol returns the text between
// the delimiters (with escaped delimiters unescaped) and leaves the rest of
// the expression in place.
func TestReadPairedSymbol(t *testing.T) {
	cases := []struct {
		left, right             rune
		expr, val, lastExprNode string
	}{
		{left: '\'', right: '\'', expr: "'true '+'a'", val: "true ", lastExprNode: "+'a'"},
		{left: '(', right: ')', expr: "((0+1)/(2-1)*9)%2", val: "(0+1)/(2-1)*9", lastExprNode: "%2"},
		{left: '(', right: ')', expr: `(\)\(\))`, val: `)()`},
		{left: '\'', right: '\'', expr: `'\\'`, val: `\\`},
		{left: '\'', right: '\'', expr: `'\'\''`, val: `''`},
	}
	for _, c := range cases {
		t.Log(c.expr)
		expr := c.expr
		got := readPairedSymbol(&expr, c.left, c.right)
		if got == nil {
			t.Fatalf("expr: %q, got: %v, %q, want: %q, %q", c.expr, got, expr, c.val, c.lastExprNode)
		} else if *got != c.val || expr != c.lastExprNode {
			t.Fatalf("expr: %q, got: %q, %q, want: %q, %q", c.expr, *got, expr, c.val, c.lastExprNode)
		}
	}
}

// TestReadBoolExprNode checks boolean literals with '!' prefixes and that the
// trailing delimiter is left unconsumed.
func TestReadBoolExprNode(t *testing.T) {
	cases := []struct {
		expr         string
		val          bool
		lastExprNode string
	}{
		{expr: "false", val: false, lastExprNode: ""},
		{expr: "true", val: true, lastExprNode: ""},
		{expr: "true ", val: true, lastExprNode: " "},
		{expr: "!true&", val: false, lastExprNode: "&"},
		{expr: "!false|", val: true, lastExprNode: "|"},
		{expr: "!!!!false =", val: !!!!false, lastExprNode: " ="}, //nolint:staticcheck // SA4013: negating a boolean twice has no effect
	}
	for _, c := range cases {
		t.Log(c.expr)
		expr := c.expr
		e := readBoolExprNode(&expr)
		got := e.Run(context.TODO(), "", nil).(bool)
		if got != c.val || expr != c.lastExprNode {
			t.Fatalf("expr: %s, got: %v, %s, want: %v, %s", c.expr, got, expr, c.val, c.lastExprNode)
		}
	}
}

// TestReadDigitalExprNode checks numeric literals, including rejection of a
// number immediately followed by a letter ("1a").
func TestReadDigitalExprNode(t *testing.T) {
	cases := []struct {
		expr         string
		val          float64
		lastExprNode string
	}{
		{expr: "0.1 +1", val: 0.1, lastExprNode: " +1"},
		{expr: "-1\\1", val: -1, lastExprNode: "\\1"},
		{expr: "1a", val: 0, lastExprNode: ""},
		{expr: "1", val: 1, lastExprNode: ""},
		{expr: "1.1", val: 1.1, lastExprNode: ""},
		{expr: "1.1/", val: 1.1, lastExprNode: "/"},
	}
	for _, c := range cases {
		expr := c.expr
		e := readDigitalExprNode(&expr)
		// "1a" is not a valid number token: no node must be produced.
		if c.expr == "1a" {
			if e != nil {
				t.Fatalf("expr: %s, got:%v, want:%v", c.expr, e.Run(context.TODO(), "", nil), nil)
			}
			continue
		}
		got := e.Run(context.TODO(), "", nil).(float64)
		if got != c.val || expr != c.lastExprNode {
			t.Fatalf("expr: %s, got: %f, %s, want: %f, %s", c.expr, got, expr, c.val, c.lastExprNode)
		}
	}
}

// TestFindSelector covers prefixes, named fields, subscripts (valid and
// malformed) and what remains of the expression after a match. Cases with
// found == false expect the input to be left untouched.
func TestFindSelector(t *testing.T) {
	cases := []struct {
		expr         string
		field        string
		name         string
		subSelector  []string
		boolOpposite bool
		signOpposite bool
		found        bool
		last         string
	}{
		{expr: "$", name: "$", found: true},
		{expr: "!!$", name: "$", found: true},
		{expr: "!$", name: "$", boolOpposite: true, found: true},
		{expr: "+$", name: "$", found: true},
		{expr: "--$", name: "$", found: true},
		{expr: "-$", name: "$", signOpposite: true, found: true},
		{expr: "---$", name: "$", signOpposite: true, found: true},
		{expr: "()$", last: "()$"},
		{expr: "(0)$", last: "(0)$"},
		{expr: "(A)$", field: "A", name: "$", found: true},
		{expr: "+(A)$", field: "A", name: "$", found: true},
		{expr: "++(A)$", field: "A", name: "$", found: true},
		{expr: "!(A)$", field: "A", name: "$", boolOpposite: true, found: true},
		{expr: "-(A)$", field: "A", name: "$", signOpposite: true, found: true},
		{expr: "(A0)$", field: "A0", name: "$", found: true},
		{expr: "!!(A0)$", field: "A0", name: "$", found: true},
		{expr: "--(A0)$", field: "A0", name: "$", found: true},
		{expr: "(A0)$(A1)$", last: "(A0)$(A1)$"},
		{expr: "(A0)$ $(A1)$", field: "A0", name: "$", found: true, last: " $(A1)$"},
		{expr: "$a", last: "$a"},
		{expr: "$[1]['a']", name: "$", subSelector: []string{"1", "'a'"}, found: true},
		{expr: "$[1][]", last: "$[1][]"},
		{expr: "$[[]]", last: "$[[]]"},
		{expr: "$[[[]]]", last: "$[[[]]]"},
		{expr: "$[(A)$[1]]", name: "$", subSelector: []string{"(A)$[1]"}, found: true},
		{expr: "$>0&&$<10", name: "$", found: true, last: ">0&&$<10"},
	}
	for _, c := range cases {
		last := c.expr
		field, name, subSelector, boolOpposite, signOpposite, found := findSelector(&last)
		if found != c.found {
			t.Fatalf("%q found: got: %v, want: %v", c.expr, found, c.found)
		}
		if c.boolOpposite && (boolOpposite == nil || !*boolOpposite) {
			t.Fatalf("%q boolOpposite: got: %v, want: %v", c.expr, boolOpposite, c.boolOpposite)
		}
		if c.signOpposite && (signOpposite == nil || !*signOpposite) {
			t.Fatalf("%q signOpposite: got: %v, want: %v", c.expr, signOpposite, c.signOpposite)
		}
		if field != c.field {
			t.Fatalf("%q field: got: %q, want: %q", c.expr, field, c.field)
		}
		if name != c.name {
			t.Fatalf("%q name: got: %q, want: %q", c.expr, name, c.name)
		}
		if !reflect.DeepEqual(subSelector, c.subSelector) {
			t.Fatalf("%q subSelector: got: %v, want: %v", c.expr, subSelector, c.subSelector)
		}
		if last != c.last {
			t.Fatalf("%q last: got: %q, want: %q", c.expr, last, c.last)
		}
	}
}

================================================
FILE: internal/tagexpr/tagexpr.go
================================================
// Copyright 2019 Bytedance Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package tagexpr implements a Go struct tag expression syntax for field
// validation and similar uses.
package tagexpr

import (
	"errors"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"unsafe"
)

// Number, Null, Boolean and String are aliases for the unified value types
// that expression results are normalized to (for example, all numbers
// become float64).
type (
	Number  = float64
	Null    = interface{}
	Boolean = bool
	String  = string
)

// VM is a struct tag expression interpreter. It caches one structVM per
// struct type it has seen.
type VM struct {
	tagName   string
	structJar map[uintptr]*structVM // keyed by the type id from rvType/rtType
	rw        sync.RWMutex          // guards structJar
}

// structVM holds the tag expressions of one struct type. A non-nil err
// records a registration failure and is returned by every later Run on that
// type.
type structVM struct {
	vm                         *VM
	name                       string
	fields                     map[string]*fieldVM
	fieldSelectorList          []string
	fieldsWithIndirectStructVM []*fieldVM
	exprs                      map[string]*Expr
	exprSelectorList           []string
	ifaceTagExprGetters        []func(unsafe.Pointer, string, func(*TagExpr, error) error) error
	err                        error
}

// fieldVM holds the tag expressions and the accessors of one struct field.
// NOTE(review): the exact semantics of ptrDeep/getPtr/valueGetter are defined
// where fieldVM is built, outside this view.
type fieldVM struct {
	structField            reflect.StructField
	ptrDeep                int
	getPtr                 func(unsafe.Pointer) unsafe.Pointer
	elemType               reflect.Type
	elemKind               reflect.Kind
	valueGetter            func(unsafe.Pointer) interface{}
	reflectValueGetter     func(unsafe.Pointer, bool) reflect.Value
	exprs                  map[string]*Expr
	origin                 *structVM
	mapKeyStructVM         *structVM
	mapOrSliceElemStructVM *structVM
	mapOrSliceIfaceKinds   [2]bool // [value, key/index]
	fieldSelector          string
	tagOp                  string
}

// New creates a tag expression interpreter that uses tagName as the tag name.
// Only the first tagName is used.
// NOTE:
//
//	If no tagName is specified, no tag expression will be interpreted,
//	but the various fields can still be operated on.
func New(tagName ...string) *VM {
	if len(tagName) == 0 {
		tagName = append(tagName, "")
	}
	return &VM{
		tagName:   tagName[0],
		structJar: make(map[uintptr]*structVM, 256),
	}
}

// MustRun is like Run, but panics on error.
func (vm *VM) MustRun(structOrStructPtrOrReflectValue interface{}) *TagExpr {
	te, err := vm.Run(structOrStructPtrOrReflectValue)
	if err != nil {
		panic(err)
	}
	return te
}

var (
	unsupportedNil        = errors.New("unsupported data: nil")
	unsupportedCannotAddr = errors.New("unsupported data: can not addr")
)

// Run returns the tag expression handler of the @structPtrOrReflectValue.
// A reflect.Value argument is used directly; anything else is wrapped with
// reflect.ValueOf. Both are dereferenced with dereferenceValue.
// NOTE:
//
//	If the structure type has not been warmed up,
//	it will be slower when it is first called.
//
// Disable new -d=checkptr behaviour for Go 1.14
//
//go:nocheckptr
func (vm *VM) Run(structPtrOrReflectValue interface{}) (*TagExpr, error) {
	var v reflect.Value
	switch t := structPtrOrReflectValue.(type) {
	case reflect.Value:
		v = dereferenceValue(t)
	default:
		v = dereferenceValue(reflect.ValueOf(t))
	}
	if err := checkStructMapAddr(v); err != nil {
		return nil, err
	}
	ptr := rvPtr(v)
	if ptr == nil {
		return nil, unsupportedNil
	}
	tid := rvType(v)
	var err error
	// Fast path under the read lock; on a miss, re-check under the write lock
	// before registering so concurrent callers register a type only once.
	vm.rw.RLock()
	s, ok := vm.structJar[tid]
	vm.rw.RUnlock()
	if !ok {
		vm.rw.Lock()
		s, ok = vm.structJar[tid]
		if !ok {
			s, err = vm.registerStructLocked(v.Type())
			if err != nil {
				vm.rw.Unlock()
				return nil, err
			}
		}
		vm.rw.Unlock()
	}
	if s.err != nil {
		return nil, s.err
	}
	return s.newTagExpr(ptr, ""), nil
}

// RunAny returns the tag expression handler for the @v.
// NOTE:
//
//	The @v can be structured data such as struct, map, slice, array, interface, reflect.Value, etc.
//	If the structure type has not been warmed up,
//	it will be slower when it is first called.
func (vm *VM) RunAny(v interface{}, fn func(*TagExpr, error) error) error { vv, isReflectValue := v.(reflect.Value) if !isReflectValue { vv = reflect.ValueOf(v) } return vm.subRunAll(false, "", vv, fn) } // check type: struct{F map[T1]T2} func checkStructMapAddr(v reflect.Value) error { if !v.IsValid() || v.CanAddr() || v.NumField() != 1 || v.Field(0).Kind() != reflect.Map { return nil } return unsupportedCannotAddr } func (vm *VM) subRunAll(omitNil bool, tePath string, value reflect.Value, fn func(*TagExpr, error) error) error { rv := dereferenceInterfaceValue(value) if !rv.IsValid() { return nil } rt := dereferenceType(rv.Type()) rv = dereferenceValue(rv) switch rt.Kind() { case reflect.Struct: if len(tePath) == 0 { if err := checkStructMapAddr(rv); err != nil { return err } } ptr := rvPtr(rv) if ptr == nil { if omitNil { return nil } return fn(nil, unsupportedNil) } return fn(vm.subRun(tePath, rt, rvType(rv), ptr)) case reflect.Slice, reflect.Array: count := rv.Len() if count == 0 { return nil } switch dereferenceType(rv.Type().Elem()).Kind() { case reflect.Struct, reflect.Interface, reflect.Slice, reflect.Array, reflect.Map: for i := count - 1; i >= 0; i-- { err := vm.subRunAll(omitNil, tePath+"["+strconv.Itoa(i)+"]", rv.Index(i), fn) if err != nil { return err } } default: return nil } case reflect.Map: if rv.Len() == 0 { return nil } var canKey, canValue bool rt := rv.Type() switch dereferenceType(rt.Key()).Kind() { case reflect.Struct, reflect.Interface, reflect.Slice, reflect.Array, reflect.Map: canKey = true } switch dereferenceType(rt.Elem()).Kind() { case reflect.Struct, reflect.Interface, reflect.Slice, reflect.Array, reflect.Map: canValue = true } if !canKey && !canValue { return nil } for _, key := range rv.MapKeys() { if canKey { err := vm.subRunAll(omitNil, tePath+"{k}", key, fn) if err != nil { return err } } if canValue { err := vm.subRunAll(omitNil, tePath+"{v for k="+key.String()+"}", rv.MapIndex(key), fn) if err != nil { return err } } } } 
return nil } func (vm *VM) subRun(path string, t reflect.Type, tid uintptr, ptr unsafe.Pointer) (*TagExpr, error) { var err error vm.rw.RLock() s, ok := vm.structJar[tid] vm.rw.RUnlock() if !ok { vm.rw.Lock() s, ok = vm.structJar[tid] if !ok { s, err = vm.registerStructLocked(t) if err != nil { vm.rw.Unlock() return nil, err } } vm.rw.Unlock() } if s.err != nil { return nil, s.err } return s.newTagExpr(ptr, path), nil } func (vm *VM) registerStructLocked(structType reflect.Type) (*structVM, error) { structType, err := vm.getStructType(structType) if err != nil { return nil, err } tid := rtType(structType) s, had := vm.structJar[tid] if had { return s, s.err } s = vm.newStructVM() s.name = structType.String() vm.structJar[tid] = s numField := structType.NumField() var structField reflect.StructField var sub *structVM for i := 0; i < numField; i++ { structField = structType.Field(i) if f := structField; !f.IsExported() && strings.HasPrefix(f.Type.PkgPath(), "google.golang.org/protobuf") { // Skip unexported protobuf internal fields to prevent stack overflow. // Fields like `state protoimpl.MessageState` may contain cyclic references // after unmarshaling, causing infinite recursion during field processing. 
// See: https://github.com/cloudwego/hertz/issues/1410 // // This is a temporary solution, // we should fix cyclic references when calling Range continue } field, ok, err := s.newFieldVM(structField) if err != nil { s.err = err return nil, err } // skip omitted tag if !ok { continue } switch field.elemKind { default: field.setUnsupportedGetter() switch field.elemKind { case reflect.Struct: sub, err = vm.registerStructLocked(field.structField.Type) if err != nil { s.err = err return nil, err } s.mergeSubStructVM(field, sub) case reflect.Interface: s.setIfaceTagExprGetter(field) } case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: field.setFloatGetter() case reflect.String: field.setStringGetter() case reflect.Bool: field.setBoolGetter() case reflect.Array, reflect.Slice, reflect.Map: err = vm.registerIndirectStructLocked(field) if err != nil { s.err = err return nil, err } } } return s, nil } func (vm *VM) registerIndirectStructLocked(field *fieldVM) error { field.setLengthGetter() if field.tagOp == tagOmit { return nil } a := make([]reflect.Type, 1, 2) a[0] = derefType(field.elemType.Elem()) if field.elemKind == reflect.Map { a = append(a, derefType(field.elemType.Key())) } for i, t := range a { kind := t.Kind() switch kind { case reflect.Interface: field.mapOrSliceIfaceKinds[i] = true field.origin.fieldsWithIndirectStructVM = appendDistinct(field.origin.fieldsWithIndirectStructVM, field) case reflect.Slice, reflect.Array, reflect.Map: tt := t.Elem() checkMap := kind == reflect.Map F2: for { switch tt.Kind() { case reflect.Slice, reflect.Array, reflect.Map, reflect.Ptr: tt = tt.Elem() case reflect.Struct: _, err := vm.registerStructLocked(tt) if err != nil { return err } field.mapOrSliceIfaceKinds[i] = true field.origin.fieldsWithIndirectStructVM = appendDistinct(field.origin.fieldsWithIndirectStructVM, field) 
break F2 default: break F2 } } if checkMap { tt = t.Key() checkMap = false goto F2 } case reflect.Struct: s, err := vm.registerStructLocked(t) if err != nil { return err } if len(s.exprSelectorList) > 0 || len(s.ifaceTagExprGetters) > 0 || len(s.fieldsWithIndirectStructVM) > 0 { if i == 0 { field.mapOrSliceElemStructVM = s } else { field.mapKeyStructVM = s } field.origin.fieldsWithIndirectStructVM = appendDistinct(field.origin.fieldsWithIndirectStructVM, field) } } } return nil } func appendDistinct(a []*fieldVM, i *fieldVM) []*fieldVM { has := false for _, e := range a { if e == i { has = true break } } if !has { return append(a, i) } return a } func (vm *VM) newStructVM() *structVM { return &structVM{ vm: vm, fields: make(map[string]*fieldVM, 32), fieldSelectorList: make([]string, 0, 32), fieldsWithIndirectStructVM: make([]*fieldVM, 0, 32), exprs: make(map[string]*Expr, 64), exprSelectorList: make([]string, 0, 64), } } func (s *structVM) newFieldVM(structField reflect.StructField) (*fieldVM, bool, error) { tag := structField.Tag.Get(s.vm.tagName) if tag == tagOmit { return nil, false, nil } f := &fieldVM{ structField: structField, exprs: make(map[string]*Expr, 8), origin: s, fieldSelector: structField.Name, } err := f.parseExprs(tag) if err != nil { return nil, false, err } s.fields[f.fieldSelector] = f s.fieldSelectorList = append(s.fieldSelectorList, f.fieldSelector) t := structField.Type var ptrDeep int for t.Kind() == reflect.Ptr { t = t.Elem() ptrDeep++ } f.ptrDeep = ptrDeep offset := structField.Offset f.getPtr = func(ptr unsafe.Pointer) unsafe.Pointer { if ptr == nil { return nil } return unsafe.Pointer(uintptr(ptr) + offset) } f.elemType = t f.elemKind = t.Kind() f.reflectValueGetter = func(ptr unsafe.Pointer, initZero bool) reflect.Value { v := f.packRawFrom(ptr) if initZero { f.ensureInit(v) } return v } return f, true, nil } func (f *fieldVM) ensureInit(v reflect.Value) { if safeIsNil(v) && v.CanSet() { newField := reflect.New(f.elemType).Elem() for i 
:= 0; i < f.ptrDeep; i++ { if newField.CanAddr() { newField = newField.Addr() } else { newField2 := reflect.New(newField.Type()) newField2.Elem().Set(newField) newField = newField2 } } v.Set(newField) } } func (s *structVM) mergeSubStructVM(field *fieldVM, sub *structVM) { field.origin = sub fieldsWithIndirectStructVM := make(map[*fieldVM]struct{}, len(sub.fieldsWithIndirectStructVM)) for _, subField := range sub.fieldsWithIndirectStructVM { fieldsWithIndirectStructVM[subField] = struct{}{} } for _, k := range sub.fieldSelectorList { v := sub.fields[k] f := s.newChildField(field, v, true) if _, ok := fieldsWithIndirectStructVM[v]; ok { s.fieldsWithIndirectStructVM = append(s.fieldsWithIndirectStructVM, f) // TODO: maybe needed? // delete(fieldsWithIndirectStructVM, v) } } // TODO: maybe needed? // for v := range fieldsWithIndirectStructVM { // f := s.newChildField(field, v, false) // s.fieldsWithIndirectStructVM = append(s.fieldsWithIndirectStructVM, f) // } for _, _subFn := range sub.ifaceTagExprGetters { subFn := _subFn s.ifaceTagExprGetters = append(s.ifaceTagExprGetters, func(ptr unsafe.Pointer, pathPrefix string, fn func(*TagExpr, error) error) error { ptr = field.getElemPtr(ptr) if ptr == nil { return nil } var path string if pathPrefix == "" { path = field.fieldSelector } else { path = pathPrefix + FieldSeparator + field.fieldSelector } return subFn(ptr, path, fn) }) } } func (s *structVM) newChildField(parent *fieldVM, child *fieldVM, toBind bool) *fieldVM { f := &fieldVM{ structField: child.structField, exprs: make(map[string]*Expr, len(child.exprs)), ptrDeep: child.ptrDeep, elemType: child.elemType, elemKind: child.elemKind, origin: child.origin, mapKeyStructVM: child.mapKeyStructVM, mapOrSliceElemStructVM: child.mapOrSliceElemStructVM, mapOrSliceIfaceKinds: child.mapOrSliceIfaceKinds, fieldSelector: parent.fieldSelector + FieldSeparator + child.fieldSelector, } if parent.tagOp != tagOmit { f.tagOp = child.tagOp } else { f.tagOp = parent.tagOp } f.getPtr 
= func(ptr unsafe.Pointer) unsafe.Pointer { ptr = parent.getElemPtr(ptr) if ptr == nil { return nil } return child.getPtr(ptr) } if child.valueGetter != nil { if parent.ptrDeep == 0 { f.valueGetter = func(ptr unsafe.Pointer) interface{} { return child.valueGetter(parent.getPtr(ptr)) } f.reflectValueGetter = func(ptr unsafe.Pointer, initZero bool) reflect.Value { return child.reflectValueGetter(parent.getPtr(ptr), initZero) } } else { f.valueGetter = func(ptr unsafe.Pointer) interface{} { newField := reflect.NewAt(parent.structField.Type, parent.getPtr(ptr)) for i := 0; i < parent.ptrDeep; i++ { newField = newField.Elem() } if newField.IsNil() { return nil } return child.valueGetter(unsafe.Pointer(newField.Pointer())) } f.reflectValueGetter = func(ptr unsafe.Pointer, initZero bool) reflect.Value { newField := reflect.NewAt(parent.structField.Type, parent.getPtr(ptr)) if initZero { parent.ensureInit(newField.Elem()) } for i := 0; i < parent.ptrDeep; i++ { newField = newField.Elem() } if (newField == reflect.Value{}) || (!initZero && newField.IsNil()) { return reflect.Value{} } return child.reflectValueGetter(unsafe.Pointer(newField.Pointer()), initZero) } } } if toBind { s.fields[f.fieldSelector] = f s.fieldSelectorList = append(s.fieldSelectorList, f.fieldSelector) if parent.tagOp != tagOmit { for k, v := range child.exprs { selector := parent.fieldSelector + FieldSeparator + k f.exprs[selector] = v s.exprs[selector] = v s.exprSelectorList = append(s.exprSelectorList, selector) } } } return f } func (f *fieldVM) getElemPtr(ptr unsafe.Pointer) unsafe.Pointer { ptr = f.getPtr(ptr) for i := f.ptrDeep; ptr != nil && i > 0; i-- { ptr = ptrElem(ptr) } return ptr } func (f *fieldVM) packRawFrom(ptr unsafe.Pointer) reflect.Value { return reflect.NewAt(f.structField.Type, f.getPtr(ptr)).Elem() } func (f *fieldVM) packElemFrom(ptr unsafe.Pointer) reflect.Value { return reflect.NewAt(f.elemType, f.getElemPtr(ptr)).Elem() } func (s *structVM) setIfaceTagExprGetter(f *fieldVM) { 
if f.tagOp == tagOmit { return } s.ifaceTagExprGetters = append(s.ifaceTagExprGetters, func(ptr unsafe.Pointer, pathPrefix string, fn func(*TagExpr, error) error) error { v := f.packElemFrom(ptr) if !v.IsValid() || v.IsNil() { return nil } var path string if pathPrefix == "" { path = f.fieldSelector } else { path = pathPrefix + FieldSeparator + f.fieldSelector } return s.vm.subRunAll(f.tagOp == tagOmitNil, path, v, fn) }) } func (f *fieldVM) setFloatGetter() { if f.ptrDeep == 0 { f.valueGetter = func(ptr unsafe.Pointer) interface{} { ptr = f.getPtr(ptr) if ptr == nil { return nil } return getFloat64(f.elemKind, ptr) } } else { f.valueGetter = func(ptr unsafe.Pointer) interface{} { v := f.packElemFrom(ptr) if v.CanAddr() { return getFloat64(f.elemKind, unsafe.Pointer(v.UnsafeAddr())) } return nil } } } func (f *fieldVM) setBoolGetter() { if f.ptrDeep == 0 { f.valueGetter = func(ptr unsafe.Pointer) interface{} { ptr = f.getPtr(ptr) if ptr == nil { return nil } return *(*bool)(ptr) } } else { f.valueGetter = func(ptr unsafe.Pointer) interface{} { v := f.packElemFrom(ptr) if v.IsValid() { return v.Bool() } return nil } } } func (f *fieldVM) setStringGetter() { if f.ptrDeep == 0 { f.valueGetter = func(ptr unsafe.Pointer) interface{} { ptr = f.getPtr(ptr) if ptr == nil { return nil } return *(*string)(ptr) } } else { f.valueGetter = func(ptr unsafe.Pointer) interface{} { v := f.packElemFrom(ptr) if v.IsValid() { return v.String() } return nil } } } func (f *fieldVM) setLengthGetter() { f.valueGetter = func(ptr unsafe.Pointer) interface{} { v := f.packElemFrom(ptr) if v.IsValid() { return v.Interface() } return nil } } func (f *fieldVM) setUnsupportedGetter() { f.valueGetter = func(ptr unsafe.Pointer) interface{} { raw := f.packRawFrom(ptr) if safeIsNil(raw) { return nil } v := raw for i := 0; i < f.ptrDeep; i++ { v = v.Elem() } for v.Kind() == reflect.Interface { v = v.Elem() } return anyValueGetter(raw, v) } } func (vm *VM) getStructType(t reflect.Type) (reflect.Type, 
error) { structType := t for structType.Kind() == reflect.Ptr { structType = structType.Elem() } if structType.Kind() != reflect.Struct { return nil, fmt.Errorf("unsupported type: %s", t.String()) } return structType, nil } func (s *structVM) newTagExpr(ptr unsafe.Pointer, path string) *TagExpr { te := &TagExpr{ s: s, ptr: ptr, sub: make(map[string]*TagExpr, 8), path: strings.TrimPrefix(path, "."), } return te } // TagExpr struct tag expression evaluator type TagExpr struct { s *structVM ptr unsafe.Pointer sub map[string]*TagExpr path string } // EvalFloat evaluates the value of the struct tag expression by the selector expression. // NOTE: // // If the expression value type is not float64, return 0. func (t *TagExpr) EvalFloat(exprSelector string) float64 { r, _ := t.Eval(exprSelector).(float64) return r } // EvalString evaluates the value of the struct tag expression by the selector expression. // NOTE: // // If the expression value type is not string, return "". func (t *TagExpr) EvalString(exprSelector string) string { r, _ := t.Eval(exprSelector).(string) return r } // EvalBool evaluates the value of the struct tag expression by the selector expression. // NOTE: // // If the expression value is not 0, '' or nil, return true. func (t *TagExpr) EvalBool(exprSelector string) bool { return FakeBool(t.Eval(exprSelector)) } // FakeBool fakes any type as a boolean. 
func FakeBool(v interface{}) bool { switch r := v.(type) { case float64: return r != 0 case float32: return r != 0 case int: return r != 0 case int8: return r != 0 case int16: return r != 0 case int32: return r != 0 case int64: return r != 0 case uint: return r != 0 case uint8: return r != 0 case uint16: return r != 0 case uint32: return r != 0 case uint64: return r != 0 case string: return r != "" case bool: return r case nil, error: return false case []interface{}: bol := true for _, v := range r { bol = bol && FakeBool(v) } return bol default: // https://github.com/bytedance/go-tagexpr/blob/v2.9.2/tagexpr.go#L801 // the original implementation either returns false or panics for default case // we always return false for unsupported types to avoid introducing new behavior return false } } // Field returns the field handler specified by the selector. func (t *TagExpr) Field(fieldSelector string) (fh *FieldHandler, found bool) { f, ok := t.s.fields[fieldSelector] if !ok { return nil, false } return newFieldHandler(t, fieldSelector, f), true } // RangeFields loop through each field. // When fn returns false, interrupt traversal and return false. func (t *TagExpr) RangeFields(fn func(*FieldHandler) bool) bool { if list := t.s.fieldSelectorList; len(list) > 0 { fields := t.s.fields for _, fieldSelector := range list { if !fn(newFieldHandler(t, fieldSelector, fields[fieldSelector])) { return false } } } return true } // Eval evaluates the value of the struct tag expression by the selector expression. 
// NOTE: // // format: fieldName, fieldName.exprName, fieldName1.fieldName2.exprName1 // result types: float64, string, bool, nil func (t *TagExpr) Eval(exprSelector string) interface{} { expr, ok := t.s.exprs[exprSelector] if !ok { // Compatible with single mode or the expression with the name @ if strings.HasSuffix(exprSelector, ExprNameSeparator) { exprSelector = exprSelector[:len(exprSelector)-1] if strings.HasSuffix(exprSelector, ExprNameSeparator) { exprSelector = exprSelector[:len(exprSelector)-1] } expr, ok = t.s.exprs[exprSelector] } if !ok { return nil } } dir, base := splitFieldSelector(exprSelector) targetTagExpr, err := t.checkout(dir) if err != nil { return nil } return expr.run(base, targetTagExpr) } // EvalWithEnv evaluates the value with the given env // NOTE: // // format: fieldName, fieldName.exprName, fieldName1.fieldName2.exprName1 // result types: float64, string, bool, nil func (t *TagExpr) EvalWithEnv(exprSelector string, env map[string]interface{}) interface{} { expr, ok := t.s.exprs[exprSelector] if !ok { // Compatible with single mode or the expression with the name @ if strings.HasSuffix(exprSelector, ExprNameSeparator) { exprSelector = exprSelector[:len(exprSelector)-1] if strings.HasSuffix(exprSelector, ExprNameSeparator) { exprSelector = exprSelector[:len(exprSelector)-1] } expr, ok = t.s.exprs[exprSelector] } if !ok { return nil } } dir, base := splitFieldSelector(exprSelector) targetTagExpr, err := t.checkout(dir) if err != nil { return nil } return expr.runWithEnv(base, targetTagExpr, env) } // Range loop through each tag expression. // When fn returns false, interrupt traversal and return false. 
// NOTE: // // eval result types: float64, string, bool, nil func (t *TagExpr) Range(fn func(*ExprHandler) error) error { var err error if list := t.s.exprSelectorList; len(list) > 0 { for _, es := range list { dir, base := splitFieldSelector(es) targetTagExpr, err := t.checkout(dir) if err != nil { continue } err = fn(newExprHandler(t, targetTagExpr, base, es)) if err != nil { return err } } } ptr := t.ptr if list := t.s.fieldsWithIndirectStructVM; len(list) > 0 { for _, f := range list { v := f.packElemFrom(ptr) if !v.IsValid() { continue } omitNil := f.tagOp == tagOmitNil mapKeyStructVM := f.mapKeyStructVM mapOrSliceElemStructVM := f.mapOrSliceElemStructVM valueIface := f.mapOrSliceIfaceKinds[0] keyIface := f.mapOrSliceIfaceKinds[1] if f.elemKind == reflect.Map && (mapOrSliceElemStructVM != nil || mapKeyStructVM != nil || valueIface || keyIface) { keyPath := f.fieldSelector + "{k}" for _, key := range v.MapKeys() { if mapKeyStructVM != nil { p := rvPtr(derefValue(key)) if omitNil && p == nil { continue } err = mapKeyStructVM.newTagExpr(p, keyPath).Range(fn) if err != nil { return err } } else if keyIface { err = t.subRange(omitNil, keyPath, key, fn) if err != nil { return err } } if mapOrSliceElemStructVM != nil { p := rvPtr(derefValue(v.MapIndex(key))) if omitNil && p == nil { continue } err = mapOrSliceElemStructVM.newTagExpr(p, f.fieldSelector+"{v for k="+key.String()+"}").Range(fn) if err != nil { return err } } else if valueIface { err = t.subRange(omitNil, f.fieldSelector+"{v for k="+key.String()+"}", v.MapIndex(key), fn) if err != nil { return err } } } } else if mapOrSliceElemStructVM != nil || valueIface { // slice or array for i := v.Len() - 1; i >= 0; i-- { if mapOrSliceElemStructVM != nil { p := rvPtr(derefValue(v.Index(i))) if omitNil && p == nil { continue } err = mapOrSliceElemStructVM.newTagExpr(p, f.fieldSelector+"["+strconv.Itoa(i)+"]").Range(fn) if err != nil { return err } } else if valueIface { err = t.subRange(omitNil, 
f.fieldSelector+"["+strconv.Itoa(i)+"]", v.Index(i), fn) if err != nil { return err } } } } } } if list := t.s.ifaceTagExprGetters; len(list) > 0 { for _, getter := range list { err = getter(ptr, "", func(te *TagExpr, err error) error { if err != nil { return err } return te.Range(fn) }) if err != nil { return err } } } return nil } func (t *TagExpr) subRange(omitNil bool, path string, value reflect.Value, fn func(*ExprHandler) error) error { return t.s.vm.subRunAll(omitNil, path, value, func(te *TagExpr, err error) error { if err != nil { return err } return te.Range(fn) }) } var ( errFieldSelector = errors.New("field selector does not exist") errOmitNil = errors.New("omit nil") ) func (t *TagExpr) checkout(fs string) (*TagExpr, error) { if fs == "" { return t, nil } subTagExpr, ok := t.sub[fs] if ok { if subTagExpr == nil { return nil, errOmitNil } return subTagExpr, nil } f, ok := t.s.fields[fs] if !ok { return nil, errFieldSelector } ptr := f.getElemPtr(t.ptr) if f.tagOp == tagOmitNil && ptr == nil { t.sub[fs] = nil return nil, errOmitNil } subTagExpr = f.origin.newTagExpr(ptr, t.path) t.sub[fs] = subTagExpr return subTagExpr, nil } func (t *TagExpr) getValue(fieldSelector string, subFields []interface{}) (v interface{}) { f := t.s.fields[fieldSelector] if f == nil { return nil } if f.valueGetter == nil { return nil } v = f.valueGetter(t.ptr) if v == nil { return nil } if len(subFields) == 0 { return v } vv := reflect.ValueOf(v) var kind reflect.Kind for i, k := range subFields { kind = vv.Kind() for kind == reflect.Ptr || kind == reflect.Interface { vv = vv.Elem() kind = vv.Kind() } switch kind { case reflect.Slice, reflect.Array, reflect.String: if float, ok := k.(float64); ok { idx := int(float) if idx >= vv.Len() { return nil } vv = vv.Index(idx) } else { return nil } case reflect.Map: k := safeConvert(reflect.ValueOf(k), vv.Type().Key()) if !k.IsValid() { return nil } vv = vv.MapIndex(k) case reflect.Struct: if float, ok := k.(float64); ok { idx := 
int(float) if idx < 0 || idx >= vv.NumField() { return nil } vv = vv.Field(idx) } else if str, ok := k.(string); ok { vv = vv.FieldByName(str) } else { return nil } default: if i < len(subFields)-1 { return nil } } if !vv.IsValid() { return nil } } raw := vv for vv.Kind() == reflect.Ptr || vv.Kind() == reflect.Interface { vv = vv.Elem() } return anyValueGetter(raw, vv) } func safeConvert(v reflect.Value, t reflect.Type) reflect.Value { defer func() { recover() }() return v.Convert(t) } func splitFieldSelector(selector string) (dir, base string) { idx := strings.LastIndex(selector, ExprNameSeparator) if idx != -1 { selector = selector[:idx] } idx = strings.LastIndex(selector, FieldSeparator) if idx != -1 { return selector[:idx], selector[idx+1:] } return "", selector } func getFloat64(kind reflect.Kind, p unsafe.Pointer) interface{} { switch kind { case reflect.Float32: return float64(*(*float32)(p)) case reflect.Float64: return *(*float64)(p) case reflect.Int: return float64(*(*int)(p)) case reflect.Int8: return float64(*(*int8)(p)) case reflect.Int16: return float64(*(*int16)(p)) case reflect.Int32: return float64(*(*int32)(p)) case reflect.Int64: return float64(*(*int64)(p)) case reflect.Uint: return float64(*(*uint)(p)) case reflect.Uint8: return float64(*(*uint8)(p)) case reflect.Uint16: return float64(*(*uint16)(p)) case reflect.Uint32: return float64(*(*uint32)(p)) case reflect.Uint64: return float64(*(*uint64)(p)) case reflect.Uintptr: return float64(*(*uintptr)(p)) } return nil } func anyValueGetter(raw, elem reflect.Value) interface{} { if !elem.IsValid() || !raw.IsValid() { return nil } kind := elem.Kind() switch kind { case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: if elem.CanAddr() { return getFloat64(kind, unsafe.Pointer(elem.UnsafeAddr())) } switch kind { case reflect.Float32, reflect.Float64: 
return elem.Float() case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return float64(elem.Int()) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: return float64(elem.Uint()) } case reflect.String: return elem.String() case reflect.Bool: return elem.Bool() } if raw.CanInterface() { return raw.Interface() } return nil } func safeIsNil(v reflect.Value) bool { if !v.IsValid() { return true } switch v.Kind() { case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: return v.IsNil() } return false } //go:nocheckptr func ptrElem(ptr unsafe.Pointer) unsafe.Pointer { return unsafe.Pointer(*(*uintptr)(ptr)) } func derefType(t reflect.Type) reflect.Type { for t.Kind() == reflect.Ptr { t = t.Elem() } return t } func derefValue(v reflect.Value) reflect.Value { for v.Kind() == reflect.Ptr { v = v.Elem() } return v } ================================================ FILE: internal/tagexpr/tagexpr_test.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package tagexpr_test import ( "encoding/json" "errors" "fmt" "reflect" "strconv" "testing" "time" "github.com/cloudwego/hertz/internal/tagexpr" "google.golang.org/protobuf/types/known/structpb" ) func assertEqual(t *testing.T, v1, v2 interface{}, msgs ...interface{}) { t.Helper() if reflect.DeepEqual(v1, v2) { return } t.Fatal(fmt.Sprintf("not equal %v %v", v1, v2) + "\n" + fmt.Sprint(msgs...)) } func BenchmarkTagExpr(b *testing.B) { type T struct { a int `bench:"$%3"` } vm := tagexpr.New("bench") vm.MustRun(new(T)) // warm up b.ReportAllocs() b.ResetTimer() t := &T{10} for i := 0; i < b.N; i++ { tagExpr, err := vm.Run(t) if err != nil { b.FailNow() } if tagExpr.EvalFloat("a") != 1 { b.FailNow() } } } func BenchmarkReflect(b *testing.B) { type T struct { a int `remainder:"3"` } b.ReportAllocs() b.ResetTimer() t := &T{1} for i := 0; i < b.N; i++ { v := reflect.ValueOf(t).Elem() ft, ok := v.Type().FieldByName("a") if !ok { b.FailNow() } x, err := strconv.ParseInt(ft.Tag.Get("remainder"), 10, 64) if err != nil { b.FailNow() } fv := v.FieldByName("a") if fv.Int()%x != 1 { b.FailNow() } } } func Test(t *testing.T) { g := &struct { _ int h string `tagexpr:"$"` s []string m map[string][]string }{ h: "haha", s: []string{"1"}, m: map[string][]string{"0": {"2"}}, } d := "ddd" e := new(int) *e = 3 type iface interface{} cases := []struct { tagName string structure interface{} tests map[string]interface{} }{ { tagName: "tagexpr", structure: &struct { A int `tagexpr:"$>0&&$<10&&!''&&!!!0&&!nil&&$"` A2 int `tagexpr:"@:$>0&&$<10"` b string `tagexpr:"is:$=='test';msg:sprintf('expect: test, but got: %s',$)"` c float32 `tagexpr:"(A)$+$"` d *string `tagexpr:"$"` e **int `tagexpr:"$"` f *[3]int `tagexpr:"x:len($)"` g string `tagexpr:"x:!regexp('xxx',$);y:regexp('g\\d{3}$')"` h []string `tagexpr:"x:$[1];y:$[10]"` i map[string]int `tagexpr:"x:$['a'];y:$[0];z:$==nil"` i2 *map[string]int `tagexpr:"x:$['a'];y:$[0];z:$"` j, j2 iface `tagexpr:"@:$==1;y:$"` k *iface `tagexpr:"$==nil"` m 
*struct{ i int } `tagexpr:"@:$;x:$['a']['x']"` }{ A: 5.0, A2: 5.0, b: "x", c: 1, d: &d, e: &e, f: new([3]int), g: "g123", h: []string{"", "hehe"}, i: map[string]int{"a": 7}, j2: iface(1), m: &struct{ i int }{1}, }, tests: map[string]interface{}{ "A": true, "A2": true, "b@is": false, "b@msg": "expect: test, but got: x", "c": 6.0, "d": d, "e": float64(*e), "f@x": float64(3), "g@x": true, "g@y": true, "h@x": "hehe", "h@y": nil, "i@x": 7.0, "i@y": nil, "i@z": false, "i2@x": nil, "i2@y": nil, "i2@z": nil, "j": false, "j@y": nil, "j2": true, "j2@y": 1.0, "k": true, "m": &struct{ i int }{1}, "m@x": nil, }, }, { tagName: "tagexpr", structure: &struct { A int `tagexpr:"$>0&&$<10"` b string `tagexpr:"is:$=='test';msg:sprintf('expect: test, but got: %s',$)"` c struct { _ int d bool `tagexpr:"$"` } e *struct { _ int f bool `tagexpr:"$"` } g **struct { _ int h string `tagexpr:"$"` s []string m map[string][]string } `tagexpr:"$['h']"` i string `tagexpr:"(g.s)$[0]+(g.m)$['0'][0]==$"` j bool `tagexpr:"!$"` k int `tagexpr:"!$"` m *int `tagexpr:"$==nil"` n *bool `tagexpr:"$==nil"` p *string `tagexpr:"$"` }{ A: 5, b: "x", c: struct { _ int d bool `tagexpr:"$"` }{d: true}, e: &struct { _ int f bool `tagexpr:"$"` }{f: true}, g: &g, i: "12", }, tests: map[string]interface{}{ "A": true, "b@is": false, "b@msg": "expect: test, but got: x", "c.d": true, "e.f": true, "g": "haha", "g.h": "haha", "i": true, "j": true, "k": true, "m": true, "n": true, "p": nil, }, }, { tagName: "p", structure: &struct { q *struct { x int } `p:"(q.x)$"` }{}, tests: map[string]interface{}{ "q": nil, }, }, } for i, c := range cases { vm := tagexpr.New(c.tagName) // vm.WarmUp(c.structure) tagExpr, err := vm.Run(c.structure) if err != nil { t.Fatal(err) } for selector, value := range c.tests { val := tagExpr.Eval(selector) if !reflect.DeepEqual(val, value) { t.Fatalf("Eval Serial: %d, selector: %q, got: %v, expect: %v", i, selector, val, value) } } tagExpr.Range(func(eh *tagexpr.ExprHandler) error { es := 
eh.ExprSelector() t.Logf("Range selector: %s, field: %q exprName: %q", es, es.Field(), es.Name()) value := c.tests[es.String()] val := eh.Eval() if !reflect.DeepEqual(val, value) { t.Fatalf("Range NO: %d, selector: %q, got: %v, expect: %v", i, es, val, value) } return nil }) } } func TestFieldNotInit(t *testing.T) { g := &struct { _ int h string s []string m map[string][]string }{ h: "haha", s: []string{"1"}, m: map[string][]string{"0": {"2"}}, } structure := &struct { A int b string c struct { _ int d *bool `expr:"test:nil"` } e *struct { _ int f bool } g **struct { _ int h string s []string m map[string][]string } i string j bool k int m *int n *bool p *string }{ A: 5, b: "x", e: &struct { _ int f bool }{f: true}, g: &g, i: "12", } vm := tagexpr.New("expr") e, err := vm.Run(structure) if err != nil { t.Fatal(err) } cases := []struct { fieldSelector string value interface{} }{ {"A", structure.A}, {"b", structure.b}, {"c", structure.c}, {"c._", 0}, {"c.d", structure.c.d}, {"e", structure.e}, {"e._", 0}, {"e.f", structure.e.f}, {"g", structure.g}, {"g._", 0}, {"g.h", (*structure.g).h}, {"g.s", (*structure.g).s}, {"g.m", (*structure.g).m}, {"i", structure.i}, {"j", structure.j}, {"k", structure.k}, {"m", structure.m}, {"n", structure.n}, {"p", structure.p}, } for _, c := range cases { fh, _ := e.Field(c.fieldSelector) val := fh.Value(false).Interface() assertEqual(t, c.value, val, c.fieldSelector) } var i int e.RangeFields(func(fh *tagexpr.FieldHandler) bool { val := fh.Value(false).Interface() if fh.StringSelector() == "c.d" { if fh.EvalFuncs()["c.d@test"] == nil { t.Fatal("nil") } } assertEqual(t, cases[i].value, val, fh.StringSelector()) i++ return true }) var wall uint64 = 1024 unix := time.Unix(1549186325, int64(wall)) e, err = vm.Run(&unix) if err != nil { t.Fatal(err) } fh, _ := e.Field("wall") val := fh.Value(false).Interface() if !reflect.DeepEqual(val, wall) { t.Fatalf("Time.wall: got: %v(%[1]T), expect: %v(%[2]T)", val, wall) } } func TestFieldInitZero(t 
*testing.T) { g := &struct { _ int h string s []string m map[string][]string }{ h: "haha", s: []string{"1"}, m: map[string][]string{"0": {"2"}}, } structure := &struct { A int b string c struct { _ int d *bool } e *struct { _ int f bool } g **struct { _ int h string s []string m map[string][]string } g2 ****struct { _ int h string s []string m map[string][]string } i string j bool k int m *int n *bool p *string }{ A: 5, b: "x", e: &struct { _ int f bool }{f: true}, g: &g, i: "12", } vm := tagexpr.New("") e, err := vm.Run(structure) if err != nil { t.Fatal(err) } cases := []struct { fieldSelector string value interface{} }{ {"A", structure.A}, {"b", structure.b}, {"c", struct { _ int d *bool }{}}, {"c._", 0}, {"c.d", new(bool)}, {"e", structure.e}, {"e._", 0}, {"e.f", structure.e.f}, {"g", structure.g}, {"g._", 0}, {"g.h", (*structure.g).h}, {"g.s", (*structure.g).s}, {"g.m", (*structure.g).m}, {"g2.m", (map[string][]string)(nil)}, {"i", structure.i}, {"j", structure.j}, {"k", structure.k}, {"m", new(int)}, {"n", new(bool)}, {"p", new(string)}, } for _, c := range cases { fh, _ := e.Field(c.fieldSelector) val := fh.Value(true).Interface() assertEqual(t, c.value, val, c.fieldSelector) } } func TestOperator(t *testing.T) { type Tmp1 struct { A string `tagexpr:$=="1"||$=="2"||$="3"` //nolint:govet B []int `tagexpr:len($)>=10&&$[0]<10` //nolint:govet C interface{} } type Tmp2 struct { A *Tmp1 B interface{} } type Target struct { A int `tagexpr:"-$+$<=10"` B int `tagexpr:"+$-$<=10"` C int `tagexpr:"-$+(M)$*(N)$/$%(D.B)$[2]+$==1"` D *Tmp1 `tagexpr:"(D.A)$!=nil"` E string `tagexpr:"((D.A)$=='1'&&len($)>1)||((D.A)$=='2'&&len($)>2)||((D.A)$=='3'&&len($)>3)"` F map[string]int `tagexpr:"x:len($);y:$['a']>10&&$['b']>1"` G *map[string]int `tagexpr:"x:$['a']+(F)$['a']>20"` H []string `tagexpr:"len($)>=1&&len($)<10&&$[0]=='123'&&$[1]!='456'"` I interface{} `tagexpr:"$!=nil"` K *string `tagexpr:"len((D.A)$)+len($)<10&&len((D.A)$+$)<10"` L **string `tagexpr:"false"` M float64 
`tagexpr:"$/2>10&&$%2==0"` N *float64 `tagexpr:"($+$*$-$/$+1)/$==$+1"` O *[3]float64 `tagexpr:"$[0]>10&&$[0]<20||$[0]>20&&$[0]<30"` P *Tmp2 `tagexpr:"x:$!=nil;y:len((P.A.A)$)<=1&&(P.A.B)$[0]==1;z:$['A']['C']==nil;w:$['A']['B'][0]==1;r:$[0][1][2]==3;s1:$[2]==nil;s2:$[0][3]==nil;s3:(ZZ)$;s4:(P.B)$!=nil"` Q *Tmp2 `tagexpr:"s1:$['A']['B']!=nil;s2:(Q.A)$['B']!=nil;s3:$['A']['C']==nil;s4:(Q.A)$['C']==nil;s5:(Q.A)$['B'][0]==1;s6:$['X']['Z']==nil"` } k := "123456" n := float64(-12.5) o := [3]float64{15, 9, 9} cases := []struct { tagName string structure interface{} tests map[string]interface{} }{ { tagName: "tagexpr", structure: &Target{ A: 5, B: 10, C: -10, D: &Tmp1{A: "3", B: []int{1, 2, 3}}, E: "1234", F: map[string]int{"a": 11, "b": 9}, G: &map[string]int{"a": 11}, H: []string{"123", "45"}, I: struct{}{}, K: &k, L: nil, M: float64(30), N: &n, O: &o, P: &Tmp2{A: &Tmp1{A: "3", B: []int{1, 2, 3}}, B: struct{}{}}, Q: &Tmp2{A: &Tmp1{A: "3", B: []int{1, 2, 3}}, B: struct{}{}}, }, tests: map[string]interface{}{ "A": true, "B": true, "C": true, "D": true, "E": true, "F@x": float64(2), "F@y": true, "G@x": true, "H": true, "I": true, "K": true, "L": false, "M": true, "N": true, "O": true, "P@x": true, "P@y": true, "P@z": true, "P@w": true, "P@r": true, "P@s1": true, "P@s2": true, "P@s3": nil, "P@s4": true, "Q@s1": true, "Q@s2": true, "Q@s3": true, "Q@s4": true, "Q@s5": true, "Q@s6": true, }, }, } for i, c := range cases { vm := tagexpr.New(c.tagName) // vm.WarmUp(c.structure) tagExpr, err := vm.Run(c.structure) if err != nil { t.Fatal(err) } for selector, value := range c.tests { val := tagExpr.Eval(selector) if !reflect.DeepEqual(val, value) { t.Fatalf("Eval NO: %d, selector: %q, got: %v, expect: %v", i, selector, val, value) } } tagExpr.Range(func(eh *tagexpr.ExprHandler) error { es := eh.ExprSelector() t.Logf("Range selector: %s, field: %q exprName: %q", es, es.Field(), es.Name()) value := c.tests[es.String()] val := eh.Eval() if !reflect.DeepEqual(val, value) { 
t.Fatalf("Range NO: %d, selector: %q, got: %v, expect: %v", i, es, val, value) } return nil }) } } func TestStruct(t *testing.T) { type A struct { B struct { C struct { D struct { X string `vd:"$"` } } `vd:"@:$['D']['X']"` C2 string `vd:"@:(C)$['D']['X']"` C3 string `vd:"@:(C.D.X)$"` } } a := new(A) a.B.C.D.X = "xxx" vm := tagexpr.New("vd") expr := vm.MustRun(a) assertEqual(t, "xxx", expr.EvalString("B.C2")) assertEqual(t, "xxx", expr.EvalString("B.C3")) assertEqual(t, "xxx", expr.EvalString("B.C")) assertEqual(t, "xxx", expr.EvalString("B.C.D.X")) expr.Range(func(eh *tagexpr.ExprHandler) error { es := eh.ExprSelector() t.Logf("Range selector: %s, field: %q exprName: %q", es, es.Field(), es.Name()) if eh.Eval().(string) != "xxx" { t.FailNow() } return nil }) } func TestStruct2(t *testing.T) { type IframeBlock struct { XBlock struct { BlockType string `vd:"$"` } Props struct { Data struct { DataType string `vd:"$"` } } } b := new(IframeBlock) b.XBlock.BlockType = "BlockType" b.Props.Data.DataType = "DataType" vm := tagexpr.New("vd") expr := vm.MustRun(b) if expr.EvalString("XBlock.BlockType") != "BlockType" { t.Fatal(expr.EvalString("XBlock.BlockType")) } if expr.EvalString("Props.Data.DataType") != "DataType" { t.Fatal(expr.EvalString("Props.Data.DataType")) } } func TestStruct3(t *testing.T) { type Data struct { DataType string `vd:"$"` } type Prop struct { PropType string `vd:"$"` DD []*Data `vd:"$"` DD2 []*Data `vd:"$"` DataMap map[int]Data `vd:"$"` DataMap2 map[int]Data `vd:"$"` } type IframeBlock struct { XBlock struct { BlockType string `vd:"$"` } Props []Prop `vd:"$"` Props1 [2]Prop `vd:"$"` Props2 []Prop `vd:"$"` PropMap map[int]*Prop `vd:"$"` PropMap2 map[int]*Prop `vd:"$"` } b := new(IframeBlock) b.XBlock.BlockType = "BlockType" p1 := Prop{ PropType: "p1", DD: []*Data{ {"p1s1"}, {"p1s2"}, nil, }, DataMap: map[int]Data{ 1: {"p1m1"}, 2: {"p1m2"}, 0: {}, }, } b.Props = []Prop{p1} p2 := &Prop{ PropType: "p2", DD: []*Data{ {"p2s1"}, {"p2s2"}, nil, }, DataMap: 
map[int]Data{ 1: {"p2m1"}, 2: {"p2m2"}, 0: {}, }, } b.Props1 = [2]Prop{p1, {}} b.PropMap = map[int]*Prop{ 9: p2, } vm := tagexpr.New("vd") expr := vm.MustRun(b) if expr.EvalString("XBlock.BlockType") != "BlockType" { t.Fatal(expr.EvalString("XBlock.BlockType")) } err := expr.Range(func(eh *tagexpr.ExprHandler) error { es := eh.ExprSelector() t.Logf("Range selector: %s, field: %q exprName: %q, eval: %v", eh.Path(), es.Field(), es.Name(), eh.Eval()) return nil }) if err != nil { t.Fatal(err) } } func TestNilField(t *testing.T) { type P struct { X **struct { A *[]uint16 `tagexpr:"$"` } `tagexpr:"$"` Y **struct{} `tagexpr:"$"` } vm := tagexpr.New("tagexpr") te := vm.MustRun(&P{}) te.Range(func(eh *tagexpr.ExprHandler) error { r := eh.Eval() if r != nil { t.Fatal(eh.Path()) } return nil }) type G struct { // Nil1 *int `tagexpr:"nil!=$"` Nil2 *int `tagexpr:"$!=nil"` } g := &G{ // Nil1: new(int), Nil2: new(int), } vm.MustRun(g).Range(func(eh *tagexpr.ExprHandler) error { r, ok := eh.Eval().(bool) if !ok || !r { t.Fatal(eh.Path()) } return nil }) } func TestDeepNested(t *testing.T) { type testInner struct { Address string `tagexpr:"name:$"` } type struct1 struct { I *testInner A []*testInner X interface{} } type struct2 struct { S *struct1 } type Data struct { S1 *struct2 S2 *struct2 } data := &Data{ S1: &struct2{ S: &struct1{ I: &testInner{Address: "I:address"}, A: []*testInner{{Address: "A:address"}}, X: []*testInner{{Address: "X:address"}}, }, }, S2: &struct2{ S: &struct1{ A: []*testInner{nil}, }, }, } expectKey := [...]interface{}{"S1.S.I.Address@name", "S2.S.I.Address@name", "S1.S.A[0].Address@name", "S2.S.A[0].Address@name", "S1.S.X[0].Address@name"} expectValue := [...]interface{}{"I:address", nil, "A:address", nil, "X:address"} var i int vm := tagexpr.New("tagexpr") vm.MustRun(data).Range(func(eh *tagexpr.ExprHandler) error { assertEqual(t, expectKey[i], eh.Path()) assertEqual(t, expectValue[i], eh.Eval()) i++ t.Log(eh.Path(), eh.ExprSelector(), eh.Eval()) return 
nil }) assertEqual(t, 5, i) } func TestIssue3(t *testing.T) { type C struct { Id string Index int32 `vd:"$"` P *int `vd:"$!=nil"` } type A struct { F1 *C F2 *C } a := &A{ F1: &C{ Id: "test", Index: 1, P: new(int), }, } vm := tagexpr.New("vd") err := vm.MustRun(a).Range(func(eh *tagexpr.ExprHandler) error { switch eh.Path() { case "F1.Index": assertEqual(t, float64(1), eh.Eval(), eh.Path()) case "F2.Index": assertEqual(t, nil, eh.Eval(), eh.Path()) case "F1.P": assertEqual(t, true, eh.Eval(), eh.Path()) case "F2.P": assertEqual(t, false, eh.Eval(), eh.Path()) } return nil }) if err != nil { t.Fatal(err) } } func TestIssue4(t *testing.T) { type T struct { A *string `te:"len($)+mblen($)"` B *string `te:"len($)+mblen($)"` C *string `te:"len($)+mblen($)"` } c := "c" v := &T{ B: new(string), C: &c, } vm := tagexpr.New("te") err := vm.MustRun(v).Range(func(eh *tagexpr.ExprHandler) error { t.Logf("eval:%v, path:%s", eh.EvalFloat(), eh.Path()) return nil }) if err != nil { t.Fatal(err) } } func TestIssue5(t *testing.T) { type A struct { F1 int `vd:"true && $ <= 24*60*60"` // 1500 ok F2 int `vd:"$%60 == 0 && $ <= (24*60*60)"` // 1500 ok F3 int `vd:"$ <= 24*60*60"` // 1500 ok } a := &A{ F1: 1500, F2: 1500, F3: 1500, } vm := tagexpr.New("vd") err := vm.MustRun(a).Range(func(eh *tagexpr.ExprHandler) error { switch eh.Path() { case "F1": assertEqual(t, true, eh.Eval(), eh.Path()) case "F2": assertEqual(t, true, eh.Eval(), eh.Path()) case "F3": assertEqual(t, true, eh.Eval(), eh.Path()) } return nil }) if err != nil { t.Fatal(err) } } func TestHertzIssue1410(t *testing.T) { type HelloReq struct { Meta *structpb.Struct `protobuf:"bytes,2,opt,name=meta,proto3" form:"meta" json:"meta,omitempty"` } x := &HelloReq{} if err := json.Unmarshal([]byte(`{"meta": {"test": "value"}}`), x); err != nil { t.Fatal(err) } te := tagexpr.New("test").MustRun(x) if err := te.Range(func(eh *tagexpr.ExprHandler) error { return nil }); err != nil { t.Fatal(err) } } func TestFakeBool(t *testing.T) { // 
Numeric types - zero values should be false, non-zero should be true tests := []struct { input interface{} expected bool }{ // Float types {float64(0), false}, {float64(3.14), true}, {float32(0), false}, {float32(2.5), true}, // Integer types {int(0), false}, {int(42), true}, {int8(0), false}, {int8(127), true}, {int16(0), false}, {int16(32767), true}, {int32(0), false}, {int32(2147483647), true}, {int64(0), false}, {int64(9223372036854775807), true}, // Unsigned integer types {uint(0), false}, {uint(1), true}, {uint8(0), false}, {uint8(255), true}, {uint16(0), false}, {uint16(65535), true}, {uint32(0), false}, {uint32(4294967295), true}, {uint64(0), false}, {uint64(18446744073709551615), true}, // String type {"", false}, {"hello", true}, // Boolean type {true, true}, {false, false}, // Nil and error types {nil, false}, {errors.New("test"), false}, // Slice of interfaces - all elements must be truthy for true {[]interface{}{}, true}, // empty slice -> true {[]interface{}{1, "hello", true}, true}, // all truthy -> true {[]interface{}{1, "", true}, false}, // one falsy -> false {[]interface{}{0, "", false}, false}, // all falsy -> false {[]interface{}{nil, nil}, false}, // nil values are falsy -> false // Unsupported types should return false {struct{}{}, false}, {new(int), false}, {make(chan int), false}, {func() {}, false}, {map[string]int{}, false}, {[3]int{1, 2, 3}, false}, } for _, tt := range tests { result := tagexpr.FakeBool(tt.input) if result != tt.expected { t.Errorf("FakeBool(%v) = %v; want %v", tt.input, result, tt.expected) } } } ================================================ FILE: internal/tagexpr/tagparser.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package tagexpr import ( "fmt" "strings" "unicode" ) const ( tagOmit = "-" tagOmitNil = "?" ) func (f *fieldVM) parseExprs(tag string) error { switch tag { case tagOmit, tagOmitNil: f.tagOp = tag return nil } kvs, err := parseTag(tag) if err != nil { return err } exprSelectorPrefix := f.structField.Name for exprSelector, exprString := range kvs { expr, err := parseExpr(exprString) if err != nil { return err } if exprSelector == ExprNameSeparator { exprSelector = exprSelectorPrefix } else { exprSelector = exprSelectorPrefix + ExprNameSeparator + exprSelector } f.exprs[exprSelector] = expr f.origin.exprs[exprSelector] = expr f.origin.exprSelectorList = append(f.origin.exprSelectorList, exprSelector) } return nil } func parseTag(tag string) (map[string]string, error) { s := tag ptr := &s kvs := make(map[string]string) for { one, err := readOneExpr(ptr) if err != nil { return nil, err } if one == "" { return kvs, nil } key, val := splitExpr(one) if val == "" { return nil, fmt.Errorf("syntax error: %q expression string can not be empty", tag) } if _, ok := kvs[key]; ok { return nil, fmt.Errorf("syntax error: %q duplicate expression name %q", tag, key) } kvs[key] = val } } func splitExpr(one string) (key, val string) { one = strings.TrimSpace(one) if one == "" { return DefaultExprName, "" } var rs []rune for _, r := range one { if r == '@' || r == '_' || (r >= '0' && r <= '9') || (r >= 'A' && r <= 'Z') || (r >= 'a' && r <= 'z') { rs = append(rs, r) } else { break } } key = string(rs) val = strings.TrimSpace(one[len(key):]) if val == "" || val[0] != ':' { return 
DefaultExprName, one } val = val[1:] if key == "" { key = DefaultExprName } return key, val } func readOneExpr(tag *string) (string, error) { s := *(trimRightSpace(trimLeftSpace(tag))) s = strings.TrimLeft(s, ";") if s == "" { return "", nil } if s[len(s)-1] != ';' { s += ";" } a := strings.SplitAfter(strings.Replace(s, "\\'", "##", -1), ";") idx := -1 var patch int for _, v := range a { idx += len(v) count := strings.Count(v, "'") if (count+patch)%2 == 0 { *tag = s[idx+1:] return s[:idx], nil } if count > 0 { patch++ } } return "", fmt.Errorf("syntax error: %q unclosed single quote \"'\"", s) } func trimLeftSpace(p *string) *string { *p = strings.TrimLeftFunc(*p, unicode.IsSpace) return p } func trimRightSpace(p *string) *string { *p = strings.TrimRightFunc(*p, unicode.IsSpace) return p } func readPairedSymbol(p *string, left, right rune) *string { s := *p if len(s) == 0 || rune(s[0]) != left { return nil } s = s[1:] last1 := left var last2 rune var leftLevel, rightLevel int escapeIndexes := make(map[int]bool) var realEqual, escapeEqual bool for i, r := range s { if realEqual, escapeEqual = equalRune(right, r, last1, last2); realEqual { if leftLevel == rightLevel { *p = s[i+1:] sub := make([]rune, 0, i) for k, v := range s[:i] { if !escapeIndexes[k] { sub = append(sub, v) } } s = string(sub) return &s } rightLevel++ } else if escapeEqual { escapeIndexes[i-1] = true } else if realEqual, escapeEqual = equalRune(left, r, last1, last2); realEqual { leftLevel++ } else if escapeEqual { escapeIndexes[i-1] = true } last2 = last1 last1 = r } return nil } func equalRune(a, b, last1, last2 rune) (real, escape bool) { if a == b { real = last1 != '\\' || last2 == '\\' escape = last1 == '\\' && last2 != '\\' } return } ================================================ FILE: internal/tagexpr/tagparser_test.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package tagexpr import ( "reflect" "testing" ) func TestTagparser(t *testing.T) { cases := []struct { tag reflect.StructTag expect map[string]string fail bool }{ { tag: `tagexpr:"$>0"`, expect: map[string]string{ "@": "$>0", }, }, { tag: `tagexpr:"$>0;'xxx'"`, fail: true, }, { tag: `tagexpr:"$>0;b:sprintf('%[1]T; %[1]v',(X)$)"`, expect: map[string]string{ "@": `$>0`, "b": `sprintf('%[1]T; %[1]v',(X)$)`, }, }, { tag: `tagexpr:"a:$=='0;1;';b:sprintf('%[1]T; %[1]v',(X)$)"`, expect: map[string]string{ "a": `$=='0;1;'`, "b": `sprintf('%[1]T; %[1]v',(X)$)`, }, }, { tag: `tagexpr:"a:1;;b:2"`, expect: map[string]string{ "a": `1`, "b": `2`, }, }, { tag: `tagexpr:";a:1;;b:2;;;"`, expect: map[string]string{ "a": `1`, "b": `2`, }, }, { tag: `tagexpr:";a:'123\\'';;b:'1\\'23';c:'1\\'2\\'3';;"`, expect: map[string]string{ "a": `'123\''`, "b": `'1\'23'`, "c": `'1\'2\'3'`, }, }, { tag: `tagexpr:"email($)"`, expect: map[string]string{ "@": `email($)`, }, }, { tag: `tagexpr:"false"`, expect: map[string]string{ "@": `false`, }, }, } for _, c := range cases { r, e := parseTag(c.tag.Get("tagexpr")) if e != nil == c.fail { if !reflect.DeepEqual(c.expect, r) { t.Fatal(c.expect, r, c.tag) } } else { t.Fatalf("tag:%s kvs:%v, err:%v", c.tag, r, e) } if e != nil { t.Logf("tag:%q, errMsg:%v", c.tag, e) } } } ================================================ FILE: internal/tagexpr/utils.go ================================================ /* * Copyright 
2024 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package tagexpr

import (
	"reflect"
	"unsafe"
)

// init runs testhack at package load time, so an incompatible runtime
// layout is reported by a panic instead of silently producing wrong results.
func init() {
	testhack()
}

// dereferenceValue keeps calling Elem while v is a pointer or an interface,
// returning the first value of any other kind. A nil pointer or nil
// interface along the way yields the zero (invalid) reflect.Value, which
// ends the loop.
func dereferenceValue(v reflect.Value) reflect.Value {
	for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
		v = v.Elem()
	}
	return v
}

// dereferenceType strips every level of pointer indirection from t.
func dereferenceType(t reflect.Type) reflect.Type {
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return t
}

// dereferenceInterfaceValue unwraps interface values only (pointers are left
// as they are) until v holds a non-interface kind. A nil interface yields
// the zero reflect.Value.
func dereferenceInterfaceValue(v reflect.Value) reflect.Value {
	for v.Kind() == reflect.Interface {
		v = v.Elem()
	}
	return v
}

// rvtype mirrors the assumed layout of the leading words of reflect.Value:
// a type word followed by a data pointer. This depends on runtime internals
// rather than a public API; testhack verifies it at init.
type rvtype struct { // reflect.Value
	abiType uintptr
	ptr     unsafe.Pointer // data pointer
}

// rvPtr returns the data pointer held inside rv, read through rvtype.
func rvPtr(rv reflect.Value) unsafe.Pointer {
	return (*rvtype)(unsafe.Pointer(&rv)).ptr
}

// rvType returns the type word held inside rv, read through rvtype. testhack
// checks that values of the same type yield equal words and values of
// different types yield different words.
func rvType(rv reflect.Value) uintptr {
	return (*rvtype)(unsafe.Pointer(&rv)).abiType
}

// rtType returns the data word of the reflect.Type interface value
// (presumably the pointer to the runtime type descriptor; it relies on the
// two-word interface layout). testhack checks that identical types yield
// equal words and different types yield different words.
func rtType(rt reflect.Type) uintptr {
	type iface struct {
		tab  uintptr
		data uintptr
	}
	return (*iface)(unsafe.Pointer(&rt)).data
}

// testhack is a quick self-test of the unsafe helpers above, and it panics
// on any mismatch. It checks three things:
//   - rvPtr returns the pointed-to address both for a pointer value and for
//     its Elem, and differs for distinct objects;
//   - rvType is equal for equal types and different for different types,
//     both for pointers and for their Elem;
//   - rtType satisfies the same equality and inequality checks as rvType.
func testhack() {
	type T1 struct {
		a int
	}
	type T2 struct {
		a int
	}
	p0 := &T1{1}
	p1 := &T1{2}
	p2 := &T2{3}

	if rvPtr(reflect.ValueOf(p0)) != unsafe.Pointer(p0) ||
		rvPtr(reflect.ValueOf(p0).Elem()) != unsafe.Pointer(p0) ||
		rvPtr(reflect.ValueOf(p0)) == rvPtr(reflect.ValueOf(p1)) {
		panic("rvPtr() compatibility issue found")
	}

	if rvType(reflect.ValueOf(p0)) != rvType(reflect.ValueOf(p1)) ||
		rvType(reflect.ValueOf(p0)) == rvType(reflect.ValueOf(p2)) ||
		rvType(reflect.ValueOf(p0).Elem()) != rvType(reflect.ValueOf(p1).Elem()) ||
		rvType(reflect.ValueOf(p0).Elem()) == rvType(reflect.ValueOf(p2).Elem()) {
		panic("rvType() compatibility issue found")
	}

	if rtType(reflect.TypeOf(p0)) != rtType(reflect.TypeOf(p1)) ||
		rtType(reflect.TypeOf(p0)) == rtType(reflect.TypeOf(p2)) ||
		rtType(reflect.TypeOf(p0).Elem()) != rtType(reflect.TypeOf(p1).Elem()) ||
		rtType(reflect.TypeOf(p0).Elem()) == rtType(reflect.TypeOf(p2).Elem()) {
		panic("rtType() compatibility issue found")
	}
}

================================================
FILE: internal/tagexpr/validator/README.md
================================================
# validator

originally from https://github.com/bytedance/go-tagexpr v2.9.2

================================================
FILE: internal/tagexpr/validator/default.go
================================================
// Copyright 2019 Bytedance Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package validator

// defaultValidator is the package-level validator for the 'vd' tag. New
// already installs defaultErrorFactory, so the SetErrorFactory call here is
// redundant but harmless.
var defaultValidator = New("vd").SetErrorFactory(defaultErrorFactory)

// Default returns the default validator.
// NOTE:
//
//	The tag name is 'vd'
func Default() *Validator {
	return defaultValidator
}

// Validate uses the default validator to check whether the fields of value
// are valid.
// NOTE:
//
//	The tag name is 'vd'
//	If checkAll=true, all errors are collected instead of stopping at the first.
func Validate(value interface{}, checkAll ...bool) error {
	return defaultValidator.Validate(value, checkAll...)
}

// SetErrorFactory customizes the factory of validation error for the default validator.
// NOTE: // // The tag name is 'vd' func SetErrorFactory(errFactory func(fieldSelector, msg string) error) { defaultValidator.SetErrorFactory(errFactory) } ================================================ FILE: internal/tagexpr/validator/example_test.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package validator_test import ( "fmt" vd "github.com/cloudwego/hertz/internal/tagexpr/validator" ) func Example() { type InfoRequest struct { Name string `vd:"($!='Alice'||(Age)$==18) && regexp('\\w')"` Age int `vd:"$>0"` Email string `vd:"email($)"` Phone1 string `vd:"phone($)"` OtherPhones []string `vd:"range($, phone(#v,'CN'))"` *InfoRequest `vd:"?"` Info1 *InfoRequest `vd:"?"` Info2 *InfoRequest `vd:"-"` } info := &InfoRequest{ Name: "Alice", Age: 18, Email: "henrylee2cn@gmail.com", Phone1: "+8618812345678", OtherPhones: []string{"18812345679", "18812345680"}, } fmt.Println(vd.Validate(info)) type A struct { A int `vd:"$<0||$>=100"` Info interface{} } info.Email = "xxx" a := &A{A: 107, Info: info} fmt.Println(vd.Validate(a)) type B struct { B string `vd:"len($)>1 && regexp('^\\w*$')"` } b := &B{"abc"} fmt.Println(vd.Validate(b) == nil) type C struct { C bool `vd:"@:(S.A)$>0 && !$; msg:'C must be false when S.A>0'"` S *A } c := &C{C: true, S: a} fmt.Println(vd.Validate(c)) type D struct { d []string `vd:"@:len($)>0 && $[0]=='D'; msg:sprintf('invalid d: %v',$)"` } d := 
&D{d: []string{"x", "y"}} fmt.Println(vd.Validate(d)) type E struct { e map[string]int `vd:"len($)==$['len']"` } e := &E{map[string]int{"len": 2}} fmt.Println(vd.Validate(e)) // Customizes the factory of validation error. vd.SetErrorFactory(func(failPath, msg string) error { return fmt.Errorf(`{"succ":false, "error":"validation failed: %s"}`, failPath) }) type F struct { f struct { g int `vd:"$%3==0"` } } f := &F{} f.f.g = 10 fmt.Println(vd.Validate(f)) fmt.Println(vd.Validate(map[string]*F{"a": f})) fmt.Println(vd.Validate(map[string]map[string]*F{"a": {"b": f}})) fmt.Println(vd.Validate([]map[string]*F{{"a": f}})) fmt.Println(vd.Validate(struct { A []map[string]*F }{A: []map[string]*F{{"x": f}}})) fmt.Println(vd.Validate(map[*F]int{f: 1})) fmt.Println(vd.Validate([][1]*F{{f}})) fmt.Println(vd.Validate((*F)(nil))) fmt.Println(vd.Validate(map[string]*F{})) fmt.Println(vd.Validate(map[string]map[string]*F{})) fmt.Println(vd.Validate([]map[string]*F{})) fmt.Println(vd.Validate([]*F{})) // Output: // // email format is incorrect // true // C must be false when S.A>0 // invalid d: [x y] // invalid parameter: e // {"succ":false, "error":"validation failed: f.g"} // {"succ":false, "error":"validation failed: {v for k=a}.f.g"} // {"succ":false, "error":"validation failed: {v for k=a}{v for k=b}.f.g"} // {"succ":false, "error":"validation failed: [0]{v for k=a}.f.g"} // {"succ":false, "error":"validation failed: A[0]{v for k=x}.f.g"} // {"succ":false, "error":"validation failed: {k}.f.g"} // {"succ":false, "error":"validation failed: [0][0].f.g"} // unsupported data: nil // // // // } ================================================ FILE: internal/tagexpr/validator/func.go ================================================ // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package validator import ( "errors" "regexp" "github.com/cloudwego/hertz/internal/tagexpr" ) // ErrInvalidWithoutMsg verification error without error message. var ErrInvalidWithoutMsg = errors.New("") // MustRegFunc registers validator function expression. // NOTE: // // panic if exist error; // example: phone($) or phone($,'CN'); // If @force=true, allow to cover the existed same @funcName; // The go number types always are float64; // The go string types always are string. func MustRegFunc(funcName string, fn func(args ...interface{}) error, force ...bool) { err := RegFunc(funcName, fn, force...) if err != nil { panic(err) } } // RegFunc registers validator function expression. // NOTE: // // example: phone($) or phone($,'CN'); // If @force=true, allow to cover the existed same @funcName; // The go number types always are float64; // The go string types always are string. func RegFunc(funcName string, fn func(args ...interface{}) error, force ...bool) error { return tagexpr.RegFunc(funcName, func(args ...interface{}) interface{} { err := fn(args...) if err == nil { // nil defaults to false, so returns true return true } return err }, force...) 
} func init() { pattern := "^([A-Za-z0-9_\\-\\.\u4e00-\u9fa5])+\\@([A-Za-z0-9_\\-\\.])+\\.([A-Za-z]{2,8})$" emailRegexp := regexp.MustCompile(pattern) MustRegFunc("email", func(args ...interface{}) error { if len(args) != 1 { return errors.New("number of parameters of email function is not one") } s, ok := args[0].(string) if !ok { return errors.New("parameter of email function is not string type") } matched := emailRegexp.MatchString(s) if !matched { // return ErrInvalidWithoutMsg return errors.New("email format is incorrect") } return nil }, true) } // Phone validation always returns true. // // Removed github.com/nyaruka/phonenumbers dependency for the following reasons: // 1. The tagexpr validator package is deprecated // 2. The phonenumbers library has unresolved issues requiring upgrades // 3. The phonenumbers library is memory-heavy (loads many objects into memory even when unused) // // Since this validator is deprecated, we simply return true for all phone numbers // instead of maintaining complex validation logic. 
func validatePhone(numberToParse, region string) bool { return true } func init() { // phone: defaultRegion is 'CN' MustRegFunc("phone", func(args ...interface{}) error { var numberToParse, defaultRegion string var ok bool switch len(args) { default: return errors.New("the number of parameters of phone function is not one or two") case 2: defaultRegion, ok = args[1].(string) if !ok { return errors.New("the 2nd parameter of phone function is not string type") } fallthrough case 1: numberToParse, ok = args[0].(string) if !ok { return errors.New("the 1st parameter of phone function is not string type") } } if defaultRegion == "" { defaultRegion = "CN" } if !validatePhone(numberToParse, defaultRegion) { return errors.New("phone format is incorrect") } return nil }, true) } ================================================ FILE: internal/tagexpr/validator/validator.go ================================================ // Package validator is a powerful validator that supports struct tag expression. // // Copyright 2019 Bytedance Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package validator

import (
	"errors"
	"io"
	"reflect"
	"strings"
	_ "unsafe"

	"github.com/cloudwego/hertz/internal/tagexpr"
)

const (
	// MatchExprName is the name of the expression used for validation.
	MatchExprName = tagexpr.DefaultExprName
	// ErrMsgExprName is the name of the expression used to specify the
	// message returned when validation fails.
	ErrMsgExprName = "msg"
)

// Validator validates struct fields using tag expressions.
type Validator struct {
	vm         *tagexpr.VM
	errFactory func(failPath, msg string) error
}

// New creates a struct fields validator that reads expressions from the
// given tag name and reports failures through defaultErrorFactory.
func New(tagName string) *Validator {
	v := &Validator{
		vm:         tagexpr.New(tagName),
		errFactory: defaultErrorFactory,
	}
	return v
}

// VM returns the struct tag expression interpreter.
func (v *Validator) VM() *tagexpr.VM {
	return v.vm
}

// Validate checks whether the fields of value are valid.
//
// Only default (unnamed) expressions are checked; named ones such as "msg"
// are used as error messages. A failing expression produces an error from
// the validator's error factory, using its "msg" expression, or the error
// returned by a validator function, as the message. Failures whose parent
// field is nil are ignored. Errors reported by the interpreter for the
// value itself (for example unsupported data) are collected too.
//
// NOTE:
//
//	If checkAll=true, all errors are collected instead of stopping at the
//	first one; multiple errors are returned as one error whose message joins
//	the individual messages with "\t".
func (v *Validator) Validate(value interface{}, checkAll ...bool) error {
	var all bool
	if len(checkAll) > 0 {
		all = checkAll[0]
	}
	errs := make([]error, 0, 8)
	err := v.vm.RunAny(value, func(te *tagexpr.TagExpr, err error) error {
		if err != nil {
			errs = append(errs, err)
			if all {
				return nil
			}
			// io.EOF is used as a private "stop iterating" signal.
			return io.EOF
		}
		nilParentFields := make(map[string]bool, 16)
		err = te.Range(func(eh *tagexpr.ExprHandler) error {
			// Named expressions (e.g. "msg") are not validation rules.
			if strings.Contains(eh.StringSelector(), tagexpr.ExprNameSeparator) {
				return nil
			}
			r := eh.Eval()
			if r == nil {
				return nil
			}
			rerr, ok := r.(error)
			if !ok && tagexpr.FakeBool(r) {
				return nil
			}
			// Ignore this error if the value of the parent is nil.
			if pfs, ok := eh.ExprSelector().ParentField(); ok {
				if nilParentFields[pfs] {
					return nil
				}
				if fh, ok := eh.TagExpr().Field(pfs); ok {
					// Named parentVal so the receiver v stays visible below.
					parentVal := fh.Value(false)
					if !parentVal.IsValid() || (parentVal.Kind() == reflect.Ptr && parentVal.IsNil()) {
						nilParentFields[pfs] = true
						return nil
					}
				}
			}
			msg := eh.TagExpr().EvalString(eh.StringSelector() + tagexpr.ExprNameSeparator + ErrMsgExprName)
			if msg == "" && rerr != nil {
				msg = rerr.Error()
			}
			errs = append(errs, v.errFactory(eh.Path(), msg))
			if all {
				return nil
			}
			return io.EOF
		})
		if err != nil && !all {
			return err
		}
		return nil
	})
	if err != io.EOF && err != nil {
		return err
	}
	switch len(errs) {
	case 0:
		return nil
	case 1:
		return errs[0]
	default:
		msgs := make([]string, len(errs))
		for i, e := range errs {
			msgs[i] = e.Error()
		}
		return errors.New(strings.Join(msgs, "\t"))
	}
}

// SetErrorFactory customizes the factory of validation errors and returns v
// for chaining.
// NOTE:
//
//	If errFactory==nil, the default is used
func (v *Validator) SetErrorFactory(errFactory func(failPath, msg string) error) *Validator {
	if errFactory == nil {
		errFactory = defaultErrorFactory
	}
	v.errFactory = errFactory
	return v
}

// Error is the validation error produced by the default error factory.
type Error struct {
	FailPath, Msg string
}

// Error implements the error interface. It returns Msg when set, otherwise
// a generic message naming FailPath.
func (e *Error) Error() string {
	if e.Msg != "" {
		return e.Msg
	}
	return "invalid parameter: " + e.FailPath
}

// defaultErrorFactory builds an *Error from the failing path and message.
//
//go:nosplit
func defaultErrorFactory(failPath, msg string) error {
	return &Error{
		FailPath: failPath,
		Msg:      msg,
	}
}

================================================
FILE: internal/tagexpr/validator/validator_test.go
================================================
// Copyright 2019 Bytedance Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package validator_test import ( "encoding/json" "errors" "testing" vd "github.com/cloudwego/hertz/internal/tagexpr/validator" ) func assertEqualError(t *testing.T, err error, s string) { t.Helper() if err.Error() != s { t.Fatal("not equal", err, s) } } func assertNoError(t *testing.T, err error) { t.Helper() if err != nil { t.Fatal(err) } } func TestNil(t *testing.T) { type F struct { F struct { G int `vd:"$%3==1"` } } assertEqualError(t, vd.Validate((*F)(nil)), "unsupported data: nil") } func TestAll(t *testing.T) { type T struct { A string `vd:"email($)"` F struct { G int `vd:"$%3==1"` } } assertEqualError(t, vd.Validate(new(T), true), "email format is incorrect\tinvalid parameter: F.G") } func TestIssue1(t *testing.T) { type MailBox struct { Address *string `vd:"email($)"` Name *string } type EmailMsg struct { Recipients []*MailBox RecipientsCc []*MailBox RecipientsBcc []*MailBox Subject *string Content *string AttachmentIDList []string ReplyTo *string Params map[string]string FromEmailAddress *string FromEmailName *string } type EmailTaskInfo struct { Msg *EmailMsg StartTimeMS *int64 LogTag *string } type BatchCreateEmailTaskRequest struct { InfoList []*EmailTaskInfo } invalid := "invalid email" req := &BatchCreateEmailTaskRequest{ InfoList: []*EmailTaskInfo{ { Msg: &EmailMsg{ Recipients: []*MailBox{ { Address: &invalid, }, }, }, }, }, } assertEqualError(t, vd.Validate(req, false), "email format is incorrect") } func TestIssue2(t *testing.T) { type a struct { m map[string]interface{} } A := &a{ m: map[string]interface{}{ "1": 1, "2": nil, }, } v := vd.New("vd") assertNoError(t, v.Validate(A)) } func TestIssue3(t *testing.T) { type C struct { Id string Index int32 `vd:"$==1"` } type A struct { F1 *C F2 *C } a := &A{ F1: &C{ Id: "test", Index: 1, }, } v := vd.New("vd") assertNoError(t, v.Validate(a)) } func TestIssue4(t *testing.T) { type C struct { Index *int32 `vd:"@:$!=nil;msg:'index is nil'"` Index2 *int32 `vd:"$!=nil"` Index3 *int32 `vd:"$!=nil"` } type A 
struct { F1 *C F2 map[string]*C F3 []*C } v := vd.New("vd") a := &A{} assertNoError(t, v.Validate(a)) a = &A{F1: new(C)} assertEqualError(t, v.Validate(a), "index is nil") a = &A{F2: map[string]*C{"x": {Index: new(int32)}}} assertEqualError(t, v.Validate(a), "invalid parameter: F2{v for k=x}.Index2") a = &A{F3: []*C{{Index: new(int32)}}} assertEqualError(t, v.Validate(a), "invalid parameter: F3[0].Index2") type B struct { F1 *C `vd:"$!=nil"` F2 *C } b := &B{} assertEqualError(t, v.Validate(b), "invalid parameter: F1") type D struct { F1 *C F2 *C } type E struct { D []*D } b.F1 = new(C) e := &E{D: []*D{nil}} assertNoError(t, v.Validate(e)) } func TestIssue5(t *testing.T) { type SubSheet struct{} type CopySheet struct { Source *SubSheet `json:"source" vd:"$!=nil"` Destination *SubSheet `json:"destination" vd:"$!=nil"` } type UpdateSheetsRequest struct { CopySheet *CopySheet `json:"copySheet"` } type BatchUpdateSheetRequestArg struct { Requests []*UpdateSheetsRequest `json:"requests"` } b := `{"requests": [{}]}` var data BatchUpdateSheetRequestArg err := json.Unmarshal([]byte(b), &data) assertNoError(t, err) if len(data.Requests) != 1 { t.Fatal(len(data.Requests)) } if data.Requests[0].CopySheet != nil { t.Fatal(data.Requests[0].CopySheet) } v := vd.New("vd") assertNoError(t, v.Validate(&data)) } func TestIn(t *testing.T) { type S string type I int16 type T struct { X *int `vd:"$==nil || len($)>0"` A S `vd:"in($,'a','b','c')"` B I `vd:"in($,1,2.0,3)"` } v := vd.New("vd") data := &T{} err := v.Validate(data) assertEqualError(t, err, "invalid parameter: A") data.A = "b" err = v.Validate(data) assertEqualError(t, err, "invalid parameter: B") data.B = 2 err = v.Validate(data) assertNoError(t, err) type T2 struct { C string `vd:"in($)"` } data2 := &T2{} err = v.Validate(data2) assertEqualError(t, err, "invalid parameter: C") type T3 struct { C string `vd:"in($,1)"` } data3 := &T3{} err = v.Validate(data3) assertEqualError(t, err, "invalid parameter: C") } type ( Issue23A 
struct { B *Issue23B V int64 `vd:"$==0"` } Issue23B struct { A *Issue23A V int64 `vd:"$==0"` } ) func TestIssue23(t *testing.T) { data := &Issue23B{A: &Issue23A{B: new(Issue23B)}} err := vd.Validate(data, true) assertNoError(t, err) } func TestIssue24(t *testing.T) { type SubmitDoctorImportItem struct { Name string `form:"name,required" json:"name,required" query:"name,required"` Avatar *string `form:"avatar,omitempty" json:"avatar,omitempty" query:"avatar,omitempty"` Idcard string `form:"idcard,required" json:"idcard,required" query:"idcard,required" vd:"len($)==18"` IdcardPics []string `form:"idcard_pics,omitempty" json:"idcard_pics,omitempty" query:"idcard_pics,omitempty"` Hosp string `form:"hosp,required" json:"hosp,required" query:"hosp,required"` HospDept string `form:"hosp_dept,required" json:"hosp_dept,required" query:"hosp_dept,required"` HospProv *string `form:"hosp_prov,omitempty" json:"hosp_prov,omitempty" query:"hosp_prov,omitempty"` HospCity *string `form:"hosp_city,omitempty" json:"hosp_city,omitempty" query:"hosp_city,omitempty"` HospCounty *string `form:"hosp_county,omitempty" json:"hosp_county,omitempty" query:"hosp_county,omitempty"` ProTit string `form:"pro_tit,required" json:"pro_tit,required" query:"pro_tit,required"` ThTit *string `form:"th_tit,omitempty" json:"th_tit,omitempty" query:"th_tit,omitempty"` ServDepts *string `form:"serv_depts,omitempty" json:"serv_depts,omitempty" query:"serv_depts,omitempty"` TitCerts []string `form:"tit_certs,omitempty" json:"tit_certs,omitempty" query:"tit_certs,omitempty"` ThTitCerts []string `form:"th_tit_certs,omitempty" json:"th_tit_certs,omitempty" query:"th_tit_certs,omitempty"` PracCerts []string `form:"prac_certs,omitempty" json:"prac_certs,omitempty" query:"prac_certs,omitempty"` QualCerts []string `form:"qual_certs,omitempty" json:"qual_certs,omitempty" query:"qual_certs,omitempty"` PracCertNo string `form:"prac_cert_no,required" json:"prac_cert_no,required" query:"prac_cert_no,required" 
vd:"len($)==15"` Goodat *string `form:"goodat,omitempty" json:"goodat,omitempty" query:"goodat,omitempty"` Intro *string `form:"intro,omitempty" json:"intro,omitempty" query:"intro,omitempty"` Linkman string `form:"linkman,required" json:"linkman,required" query:"linkman,required" vd:"email($)"` Phone string `form:"phone,required" json:"phone,required" query:"phone,required" vd:"phone($,'CN')"` } type SubmitDoctorImportRequest struct { SubmitDoctorImport []*SubmitDoctorImportItem `form:"submit_doctor_import,required" json:"submit_doctor_import,required"` } data := &SubmitDoctorImportRequest{SubmitDoctorImport: []*SubmitDoctorImportItem{{}}} err := vd.Validate(data, true) assertEqualError(t, err, "invalid parameter: SubmitDoctorImport[0].Idcard\tinvalid parameter: SubmitDoctorImport[0].PracCertNo\temail format is incorrect") } func TestStructSliceMap(t *testing.T) { type F struct { f struct { g int `vd:"$%3==0"` } } f := &F{} f.f.g = 10 type S struct { A map[string]*F B []map[string]*F C map[string][]map[string]F // _ int } s := S{ A: map[string]*F{"x": f}, B: []map[string]*F{{"y": f}}, C: map[string][]map[string]F{"z": {{"zz": *f}}}, } err := vd.Validate(s, true) assertEqualError(t, err, "invalid parameter: A{v for k=x}.f.g\tinvalid parameter: B[0]{v for k=y}.f.g\tinvalid parameter: C{v for k=z}[0]{v for k=zz}.f.g") } func TestIssue30(t *testing.T) { type TStruct struct { TOk string `vd:"gt($,'0') && gt($, '1')" json:"t_ok"` // TFail string `vd:"gt($,'0')" json:"t_fail"` } vd.RegFunc("gt", func(args ...interface{}) error { return errors.New("force error") }) assertEqualError(t, vd.Validate(&TStruct{TOk: "1"}), "invalid parameter: TOk") // assertNoError(t, vd.Validate(&TStruct{TOk: "1", TFail: "1"})) } func TestIssue31(t *testing.T) { type TStruct struct { A []int32 `vd:"$ == nil || ($ != nil && range($, in(#v, 1, 2, 3))"` } assertEqualError(t, vd.Validate(&TStruct{A: []int32{1}}), "syntax error: \"($ != nil && range($, in(#v, 1, 2, 3))\"") assertEqualError(t, 
vd.Validate(&TStruct{A: []int32{1}}), "syntax error: \"($ != nil && range($, in(#v, 1, 2, 3))\"") assertEqualError(t, vd.Validate(&TStruct{A: []int32{1}}), "syntax error: \"($ != nil && range($, in(#v, 1, 2, 3))\"") } func TestRegexp(t *testing.T) { type TStruct struct { A string `vd:"regexp('(\\d+\\.){3}\\d+')"` } assertNoError(t, vd.Validate(&TStruct{A: "0.0.0.0"})) assertEqualError(t, vd.Validate(&TStruct{A: "0...0"}), "invalid parameter: A") assertEqualError(t, vd.Validate(&TStruct{A: "abc1"}), "invalid parameter: A") assertEqualError(t, vd.Validate(&TStruct{A: "0?0?0?0"}), "invalid parameter: A") } func TestRangeIn(t *testing.T) { type S struct { F []string `vd:"range($, in(#v, '', 'ttp', 'euttp'))"` } err := vd.Validate(S{ F: []string{"ttp", "", "euttp"}, }) assertNoError(t, err) err = vd.Validate(S{ F: []string{"ttp", "?", "euttp"}, }) assertEqualError(t, err, "invalid parameter: F") } ================================================ FILE: internal/test/mock/binder/binder.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package binder import ( "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) // Binder provides a mock implementation of the Binder interface for testing. type Binder struct { ValidateError error // Error to return from Validate method } // NewBinder creates a new mock binder. 
func NewBinder() *Binder { return &Binder{} } // NewBinderWithValidateError creates a new mock binder that returns the specified error from Validate. func NewBinderWithValidateError(err error) *Binder { return &Binder{ValidateError: err} } func (m *Binder) Name() string { return "test binder" } func (m *Binder) Bind(request *protocol.Request, i interface{}, params param.Params) error { return nil } func (m *Binder) BindAndValidate(request *protocol.Request, i interface{}, params param.Params) error { return nil } func (m *Binder) BindQuery(request *protocol.Request, i interface{}) error { return nil } func (m *Binder) BindHeader(request *protocol.Request, i interface{}) error { return nil } func (m *Binder) BindPath(request *protocol.Request, i interface{}, params param.Params) error { return nil } func (m *Binder) BindForm(request *protocol.Request, i interface{}) error { return nil } func (m *Binder) BindJSON(request *protocol.Request, i interface{}) error { return nil } func (m *Binder) BindProtobuf(request *protocol.Request, i interface{}) error { return nil } func (m *Binder) Validate(request *protocol.Request, i interface{}) error { if m.ValidateError != nil { return m.ValidateError } return nil } ================================================ FILE: internal/test/mock/binder/binder_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package binder import ( "errors" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestNewBinder(t *testing.T) { binder := NewBinder() assert.DeepEqual(t, "test binder", binder.Name()) assert.Nil(t, binder.ValidateError) // Test all binding methods return nil assert.Nil(t, binder.Bind(nil, nil, nil)) assert.Nil(t, binder.BindAndValidate(nil, nil, nil)) assert.Nil(t, binder.BindQuery(nil, nil)) assert.Nil(t, binder.BindHeader(nil, nil)) assert.Nil(t, binder.BindPath(nil, nil, nil)) assert.Nil(t, binder.BindForm(nil, nil)) assert.Nil(t, binder.BindJSON(nil, nil)) assert.Nil(t, binder.BindProtobuf(nil, nil)) assert.Nil(t, binder.Validate(nil, nil)) } func TestNewBinderWithValidateError(t *testing.T) { testErr := errors.New("test error") binder := NewBinderWithValidateError(testErr) assert.DeepEqual(t, testErr, binder.ValidateError) } func TestBinderValidate(t *testing.T) { // Test no error binder1 := NewBinder() assert.Nil(t, binder1.Validate(nil, nil)) // Test with error testErr := errors.New("validation failed") binder2 := NewBinderWithValidateError(testErr) assert.DeepEqual(t, testErr, binder2.Validate(nil, nil)) } ================================================ FILE: internal/testutils/testutils.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package testutils import ( "net" "testing" "time" ) type RouteEngine interface { IsRunning() bool } func WaitEngineRunning(e RouteEngine) { for i := 0; i < 100; i++ { if e.IsRunning() { return } time.Sleep(10 * time.Millisecond) } panic("not running") } // NewTestListener creates a TCP listener on a random available port. // It calls tb.Fatal if the listener cannot be created. // The caller is responsible for closing the listener (usually via defer). func NewTestListener(tb testing.TB) net.Listener { tb.Helper() ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { tb.Fatalf("failed to create test listener: %s", err) } return ln } ================================================ FILE: internal/testutils/testutils_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package testutils import ( "sync/atomic" "testing" "time" ) func TestNewTestListener(t *testing.T) { ln := NewTestListener(t) defer ln.Close() t.Log(ln.Addr()) } type routeEngine struct { Running atomic.Bool } func (e *routeEngine) IsRunning() bool { return e.Running.Load() } func TestWaitEngineRunning(t *testing.T) { e := &routeEngine{} go func() { time.Sleep(30 * time.Millisecond) e.Running.Store(true) }() WaitEngineRunning(e) } ================================================ FILE: licenses/LICENSE-echo.txt ================================================ BSD 3-Clause License Copyright (c) 2013, Julien Schmidt All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: licenses/LICENSE-fasthttp.txt ================================================ The MIT License (MIT) Copyright (c) 2021 LabStack Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: licenses/LICENSE-fsnotify ================================================ Copyright © 2012 The Go Authors. All rights reserved. Copyright © fsnotify Authors. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
================================================ FILE: licenses/LICENSE-gin.txt ================================================ The MIT License (MIT) Copyright (c) 2014 Manuel Martínez-Almeida Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: licenses/LICENSE-go-version.txt ================================================ Mozilla Public License, version 2.0 1. Definitions 1.1. “Contributor” means each individual or legal entity that creates, contributes to the creation of, or owns Covered Software. 1.2. “Contributor Version” means the combination of the Contributions of others (if any) used by a Contributor and that particular Contributor’s Contribution. 1.3. “Contribution” means Covered Software of a particular Contributor. 1.4. 
“Covered Software” means Source Code Form to which the initial Contributor has attached the notice in Exhibit A, the Executable Form of such Source Code Form, and Modifications of such Source Code Form, in each case including portions thereof. 1.5. “Incompatible With Secondary Licenses” means a. that the initial Contributor has attached the notice described in Exhibit B to the Covered Software; or b. that the Covered Software was made available under the terms of version 1.1 or earlier of the License, but not also under the terms of a Secondary License. 1.6. “Executable Form” means any form of the work other than Source Code Form. 1.7. “Larger Work” means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software. 1.8. “License” means this document. 1.9. “Licensable” means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently, any and all of the rights conveyed by this License. 1.10. “Modifications” means any of the following: a. any file in Source Code Form that results from an addition to, deletion from, or modification of the contents of Covered Software; or b. any new file in Source Code Form that contains any Covered Software. 1.11. “Patent Claims” of a Contributor means any patent claim(s), including without limitation, method, process, and apparatus claims, in any patent Licensable by such Contributor that would be infringed, but for the grant of the License, by the making, using, selling, offering for sale, having made, import, or transfer of either its Contributions or its Contributor Version. 1.12. “Secondary License” means either the GNU General Public License, Version 2.0, the GNU Lesser General Public License, Version 2.1, the GNU Affero General Public License, Version 3.0, or any later versions of those licenses. 1.13. “Source Code Form” means the form of the work preferred for making modifications. 1.14. 
“You” (or “Your”) means an individual or a legal entity exercising rights under this License. For legal entities, “You” includes any entity that controls, is controlled by, or is under common control with You. For purposes of this definition, “control” means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. 2. License Grants and Conditions 2.1. Grants Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: a. under intellectual property rights (other than patent or trademark) Licensable by such Contributor to use, reproduce, make available, modify, display, perform, distribute, and otherwise exploit its Contributions, either on an unmodified basis, with Modifications, or as part of a Larger Work; and b. under Patent Claims of such Contributor to make, use, sell, offer for sale, have made, import, and otherwise transfer either its Contributions or its Contributor Version. 2.2. Effective Date The licenses granted in Section 2.1 with respect to any Contribution become effective for each Contribution on the date the Contributor first distributes such Contribution. 2.3. Limitations on Grant Scope The licenses granted in this Section 2 are the only rights granted under this License. No additional rights or licenses will be implied from the distribution or licensing of Covered Software under this License. Notwithstanding Section 2.1(b) above, no patent license is granted by a Contributor: a. for any code that a Contributor has removed from Covered Software; or b. for infringements caused by: (i) Your and any other third party’s modifications of Covered Software, or (ii) the combination of its Contributions with other software (except as part of its Contributor Version); or c. 
under Patent Claims infringed by Covered Software in the absence of its Contributions. This License does not grant any rights in the trademarks, service marks, or logos of any Contributor (except as may be necessary to comply with the notice requirements in Section 3.4). 2.4. Subsequent Licenses No Contributor makes additional grants as a result of Your choice to distribute the Covered Software under a subsequent version of this License (see Section 10.2) or under the terms of a Secondary License (if permitted under the terms of Section 3.3). 2.5. Representation Each Contributor represents that the Contributor believes its Contributions are its original creation(s) or it has sufficient rights to grant the rights to its Contributions conveyed by this License. 2.6. Fair Use This License is not intended to limit any rights You have under applicable copyright doctrines of fair use, fair dealing, or other equivalents. 2.7. Conditions Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in Section 2.1. 3. Responsibilities 3.1. Distribution of Source Form All distribution of Covered Software in Source Code Form, including any Modifications that You create or to which You contribute, must be under the terms of this License. You must inform recipients that the Source Code Form of the Covered Software is governed by the terms of this License, and how they can obtain a copy of this License. You may not attempt to alter or restrict the recipients’ rights in the Source Code Form. 3.2. Distribution of Executable Form If You distribute Covered Software in Executable Form then: a. such Covered Software must also be made available in Source Code Form, as described in Section 3.1, and You must inform recipients of the Executable Form how they can obtain a copy of such Source Code Form by reasonable means in a timely manner, at a charge no more than the cost of distribution to the recipient; and b. 
You may distribute such Executable Form under the terms of this License, or sublicense it under different terms, provided that the license for the Executable Form does not attempt to limit or alter the recipients’ rights in the Source Code Form under this License. 3.3. Distribution of a Larger Work You may create and distribute a Larger Work under terms of Your choice, provided that You also comply with the requirements of this License for the Covered Software. If the Larger Work is a combination of Covered Software with a work governed by one or more Secondary Licenses, and the Covered Software is not Incompatible With Secondary Licenses, this License permits You to additionally distribute such Covered Software under the terms of such Secondary License(s), so that the recipient of the Larger Work may, at their option, further distribute the Covered Software under the terms of either this License or such Secondary License(s). 3.4. Notices You may not remove or alter the substance of any license notices (including copyright notices, patent notices, disclaimers of warranty, or limitations of liability) contained within the Source Code Form of the Covered Software, except that You may alter any license notices to the extent required to remedy known factual inaccuracies. 3.5. Application of Additional Terms You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, You may do so only on Your own behalf, and not on behalf of any Contributor. You must make it absolutely clear that any such warranty, support, indemnity, or liability obligation is offered by You alone, and You hereby agree to indemnify every Contributor for any liability incurred by such Contributor as a result of warranty, support, indemnity or liability terms You offer. You may include additional disclaimers of warranty and limitations of liability specific to any jurisdiction. 4. 
Inability to Comply Due to Statute or Regulation If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Software due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. Such description must be placed in a text file included with all distributions of the Covered Software under this License. Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it. 5. Termination 5.1. The rights granted under this License will terminate automatically if You fail to comply with any of its terms. However, if You become compliant, then the rights granted under this License from a particular Contributor are reinstated (a) provisionally, unless and until such Contributor explicitly and finally terminates Your grants, and (b) on an ongoing basis, if such Contributor fails to notify You of the non-compliance by some reasonable means prior to 60 days after You have come back into compliance. Moreover, Your grants from a particular Contributor are reinstated on an ongoing basis if such Contributor notifies You of the non-compliance by some reasonable means, this is the first time You have received notice of non-compliance with this License from such Contributor, and You become compliant prior to 30 days after Your receipt of the notice. 5.2. If You initiate litigation against any entity by asserting a patent infringement claim (excluding declaratory judgment actions, counter-claims, and cross-claims) alleging that a Contributor Version directly or indirectly infringes any patent, then the rights granted to You by any and all Contributors for the Covered Software under Section 2.1 of this License shall terminate. 5.3. 
In the event of termination under Sections 5.1 or 5.2 above, all end user license agreements (excluding distributors and resellers) which have been validly granted by You or Your distributors under this License prior to termination shall survive termination. 6. Disclaimer of Warranty Covered Software is provided under this License on an “as is” basis, without warranty of any kind, either expressed, implied, or statutory, including, without limitation, warranties that the Covered Software is free of defects, merchantable, fit for a particular purpose or non-infringing. The entire risk as to the quality and performance of the Covered Software is with You. Should any Covered Software prove defective in any respect, You (not any Contributor) assume the cost of any necessary servicing, repair, or correction. This disclaimer of warranty constitutes an essential part of this License. No use of any Covered Software is authorized under this License except under this disclaimer. 7. Limitation of Liability Under no circumstances and under no legal theory, whether tort (including negligence), contract, or otherwise, shall any Contributor, or anyone who distributes Covered Software as permitted above, be liable to You for any direct, indirect, special, incidental, or consequential damages of any character including, without limitation, damages for lost profits, loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses, even if such party shall have been informed of the possibility of such damages. This limitation of liability shall not apply to liability for death or personal injury resulting from such party’s negligence to the extent applicable law prohibits such limitation. Some jurisdictions do not allow the exclusion or limitation of incidental or consequential damages, so this exclusion and limitation may not apply to You. 8. 
Litigation Any litigation relating to this License may be brought only in the courts of a jurisdiction where the defendant maintains its principal place of business and such litigation shall be governed by laws of that jurisdiction, without reference to its conflict-of-law provisions. Nothing in this Section shall prevent a party’s ability to bring cross-claims or counter-claims. 9. Miscellaneous This License represents the complete agreement concerning the subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not be used to construe this License against a Contributor. 10. Versions of the License 10.1. New Versions Mozilla Foundation is the license steward. Except as provided in Section 10.3, no one other than the license steward has the right to modify or publish new versions of this License. Each version will be given a distinguishing version number. 10.2. Effect of New Versions You may distribute the Covered Software under the terms of the version of the License under which You originally received the Covered Software, or under the terms of any subsequent version published by the license steward. 10.3. Modified Versions If you create software not governed by this License, and you want to create a new license for such software, you may create and use a modified version of this License if you rename the license and remove any references to the name of the license steward (except to note that such modified license differs from this License). 10.4. Distributing Source Code Form that is Incompatible With Secondary Licenses If You choose to distribute Source Code Form that is Incompatible With Secondary Licenses under the terms of this version of the License, the notice described in Exhibit B of this License must be attached. 
Exhibit A - Source Code Form License Notice This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice. You may add additional accurate notices of copyright ownership. Exhibit B - “Incompatible With Secondary Licenses” Notice This Source Code Form is “Incompatible With Secondary Licenses”, as defined by the Mozilla Public License, v. 2.0. ================================================ FILE: licenses/LICENSE-protobuf.txt ================================================ Copyright (c) 2018 The Go Authors. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of Google Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: licenses/LICENSE-protoreflect.txt ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. 
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: licenses/LICENSE-sprig.txt ================================================ Copyright (C) 2013-2020 Masterminds Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: licenses/LICENSE-yaml.txt ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright {yyyy} {name of copyright owner} Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: pkg/app/client/client.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package client import ( "bytes" "context" "fmt" "io" "reflect" "strings" "sync" "time" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/internal/nocopy" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network/dialer" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/client" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1" "github.com/cloudwego/hertz/pkg/protocol/http1/factory" "github.com/cloudwego/hertz/pkg/protocol/suite" ) var ( errorInvalidURI = errors.NewPublic("invalid uri") errorLastMiddlewareExist = errors.NewPublic("last middleware already set") ) // Do performs the given http request and fills the given http response. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI.© // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // The function doesn't follow redirects. Use Get* for following redirects. // // Response is ignored if resp is nil. // // If MaxConnsPerHost is configured (> 0), ErrNoFreeConns is returned // when all connections to the requested host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func Do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error { return defaultClient.Do(ctx, req, resp) } // DoTimeout performs the given request and waits for response during // the given timeout duration. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. 
// // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // The function doesn't follow redirects. Use Get* for following redirects. // // Response is ignored if resp is nil. // // errTimeout is returned if the response wasn't returned during // the given timeout. // // If MaxConnsPerHost is configured (> 0), ErrNoFreeConns is returned // when all connections to the requested host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. // // Warning: DoTimeout does not terminate the request itself. The request will // continue in the background and the response will be discarded. // If requests take too long and the connection pool gets filled up please // try using a customized Client instance with a ReadTimeout config or set the request level read timeout like: // `req.SetOptions(config.WithReadTimeout(1 * time.Second))` func DoTimeout(ctx context.Context, req *protocol.Request, resp *protocol.Response, timeout time.Duration) error { return defaultClient.DoTimeout(ctx, req, resp, timeout) } // DoDeadline performs the given request and waits for response until // the given deadline. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // The function doesn't follow redirects. Use Get* for following redirects. // // Response is ignored if resp is nil. // // errTimeout is returned if the response wasn't returned until // the given deadline. // // If MaxConnsPerHost is configured (> 0), ErrNoFreeConns is returned // when all connections to the requested host are busy. 
// // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. // // Warning: DoDeadline does not terminate the request itself. The request will // continue in the background and the response will be discarded. // If requests take too long and the connection pool gets filled up please // try using a customized Client instance with a ReadTimeout config or set the request level read timeout like: // `req.SetOptions(config.WithReadTimeout(1 * time.Second))` func DoDeadline(ctx context.Context, req *protocol.Request, resp *protocol.Response, deadline time.Time) error { return defaultClient.DoDeadline(ctx, req, resp, deadline) } // DoRedirects performs the given http request and fills the given http response, // following up to maxRedirectsCount redirects. When the redirect count exceeds // maxRedirectsCount, ErrTooManyRedirects is returned. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // Response is ignored if resp is nil. // // If MaxConnsPerHost is configured (> 0), ErrNoFreeConns is returned // when all connections to the requested host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func DoRedirects(ctx context.Context, req *protocol.Request, resp *protocol.Response, maxRedirectsCount int) error { _, _, err := client.DoRequestFollowRedirects(ctx, req, resp, req.URI().String(), maxRedirectsCount, defaultClient) return err } // Get returns the status code and body of url. // // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. // // The function follows redirects. 
Use Do* for manually handling redirects. func Get(ctx context.Context, dst []byte, url string, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { return defaultClient.Get(ctx, dst, url, requestOptions...) } // GetTimeout returns the status code and body of url. // // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. // // The function follows redirects. Use Do* for manually handling redirects. // // errTimeout error is returned if url contents couldn't be fetched // during the given timeout. // // Warning: GetTimeout does not terminate the request itself. The request will // continue in the background and the response will be discarded. // If requests take too long and the connection pool gets filled up please // try using a customized Client instance with a ReadTimeout config or set the request level read timeout like: // `GetTimeout(ctx, dst, url, timeout, config.WithReadTimeout(1 * time.Second))` func GetTimeout(ctx context.Context, dst []byte, url string, timeout time.Duration, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { return defaultClient.GetTimeout(ctx, dst, url, timeout, requestOptions...) } // GetDeadline returns the status code and body of url. // // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. // // The function follows redirects. Use Do* for manually handling redirects. // // errTimeout error is returned if url contents couldn't be fetched // until the given deadline. // // Warning: GetDeadline does not terminate the request itself. The request will // continue in the background and the response will be discarded. 
// If requests take too long and the connection pool gets filled up please // try using a customized Client instance with a ReadTimeout config or set the request level read timeout like: // `GetDeadline(ctx, dst, url, timeout, config.WithReadTimeout(1 * time.Second))` func GetDeadline(ctx context.Context, dst []byte, url string, deadline time.Time, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { return defaultClient.GetDeadline(ctx, dst, url, deadline, requestOptions...) } // Post sends POST request to the given url with the given POST arguments. // // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. // // The function follows redirects. Use Do* for manually handling redirects. // // Empty POST body is sent if postArgs is nil. func Post(ctx context.Context, dst []byte, url string, postArgs *protocol.Args, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { return defaultClient.Post(ctx, dst, url, postArgs, requestOptions...) } var defaultClient, _ = NewClient(WithDialTimeout(consts.DefaultDialTimeout)) // Client implements http client. // // Copying Client by value is prohibited. Create new instance instead. // // It is safe calling Client methods from concurrently running goroutines. type Client struct { noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used options *config.ClientOptions // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. // // The proxy type is determined by the URL scheme. // "http" and "https" are supported. If the scheme is empty, // "http" is assumed. // // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy protocol.Proxy // RetryIfFunc sets the retry decision function. If nil, the client.DefaultRetryIf will be applied. 
RetryIfFunc client.RetryIfFunc clientFactory suite.ClientFactory mLock sync.Mutex m map[string]client.HostClient ms map[string]client.HostClient mws Middleware lastMiddleware Middleware } func (c *Client) GetOptions() *config.ClientOptions { return c.options } func (c *Client) SetRetryIfFunc(retryIf client.RetryIfFunc) { c.RetryIfFunc = retryIf } // Deprecated: use SetRetryIfFunc instead of SetRetryIf func (c *Client) SetRetryIf(fn func(request *protocol.Request) bool) { f := func(req *protocol.Request, resp *protocol.Response, err error) bool { return fn(req) } c.SetRetryIfFunc(f) } // SetProxy is used to set client proxy. // // Don't SetProxy twice for a client. // If you want to use another proxy, please create another client and set proxy to it. func (c *Client) SetProxy(p protocol.Proxy) { c.Proxy = p } // Get returns the status code and body of url. // // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. // // The function follows redirects. Use Do* for manually handling redirects. func (c *Client) Get(ctx context.Context, dst []byte, url string, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { return client.GetURL(ctx, dst, url, c, requestOptions...) } // GetTimeout returns the status code and body of url. // // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. // // The function follows redirects. Use Do* for manually handling redirects. // // errTimeout error is returned if url contents couldn't be fetched // during the given timeout. func (c *Client) GetTimeout(ctx context.Context, dst []byte, url string, timeout time.Duration, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { return client.GetURLTimeout(ctx, dst, url, timeout, c, requestOptions...) } // GetDeadline returns the status code and body of url. 
// // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. // // The function follows redirects. Use Do* for manually handling redirects. // // errTimeout error is returned if url contents couldn't be fetched // until the given deadline. func (c *Client) GetDeadline(ctx context.Context, dst []byte, url string, deadline time.Time, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { return client.GetURLDeadline(ctx, dst, url, deadline, c, requestOptions...) } // Post sends POST request to the given url with the given POST arguments. // // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. // // The function follows redirects. Use Do* for manually handling redirects. // // Empty POST body is sent if postArgs is nil. func (c *Client) Post(ctx context.Context, dst []byte, url string, postArgs *protocol.Args, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { return client.PostURL(ctx, dst, url, postArgs, c, requestOptions...) } // DoTimeout performs the given request and waits for response during // the given timeout duration. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // The function doesn't follow redirects. Use Get* for following redirects. // // Response is ignored if resp is nil. // // errTimeout is returned if the response wasn't returned during // the given timeout. // // If MaxConnsPerHost is configured (> 0), ErrNoFreeConns is returned // when all connections to the requested host are busy. 
// // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. // // Warning: DoTimeout does not terminate the request itself. The request will // continue in the background and the response will be discarded. // If requests take too long and the connection pool gets filled up please // try using a customized Client instance with a ReadTimeout config or set the request level read timeout like: // `req.SetOptions(config.WithReadTimeout(1 * time.Second))` func (c *Client) DoTimeout(ctx context.Context, req *protocol.Request, resp *protocol.Response, timeout time.Duration) error { return client.DoTimeout(ctx, req, resp, timeout, c) } // DoDeadline performs the given request and waits for response until // the given deadline. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // The function doesn't follow redirects. Use Get* for following redirects. // // Response is ignored if resp is nil. // // errTimeout is returned if the response wasn't returned until // the given deadline. // // If MaxConnsPerHost is configured (> 0), ErrNoFreeConns is returned // when all connections to the requested host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *Client) DoDeadline(ctx context.Context, req *protocol.Request, resp *protocol.Response, deadline time.Time) error { return client.DoDeadline(ctx, req, resp, deadline, c) } // DoRedirects performs the given http request and fills the given http response, // following up to maxRedirectsCount redirects. When the redirect count exceeds // maxRedirectsCount, ErrTooManyRedirects is returned. 
// // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // Response is ignored if resp is nil. // // If MaxConnsPerHost is configured (> 0), ErrNoFreeConns is returned // when all connections to the requested host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *Client) DoRedirects(ctx context.Context, req *protocol.Request, resp *protocol.Response, maxRedirectsCount int) error { _, _, err := client.DoRequestFollowRedirects(ctx, req, resp, req.URI().String(), maxRedirectsCount, c) return err } // Do performs the given http request and fills the given http response. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // Response is ignored if resp is nil. // // The function doesn't follow redirects. Use Get* for following redirects. // // If MaxConnsPerHost is configured (> 0), ErrNoFreeConns is returned // when all connections to the requested host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. 
func (c *Client) Do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error { if c.mws == nil { return c.do(ctx, req, resp) } if c.lastMiddleware != nil { return c.mws(c.lastMiddleware(c.do))(ctx, req, resp) } return c.mws(c.do)(ctx, req, resp) } func (c *Client) do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error { if !c.options.KeepAlive { req.Header.SetConnectionClose(true) } uri := req.URI() if uri == nil { return errorInvalidURI } var proxyURI *protocol.URI var err error if c.Proxy != nil { proxyURI, err = c.Proxy(req) if err != nil { return fmt.Errorf("proxy error=%w", err) } } isTLS := false scheme := uri.Scheme() if bytes.Equal(scheme, bytestr.StrHTTPS) { isTLS = true } else if !bytes.Equal(scheme, bytestr.StrHTTP) && !bytes.Equal(scheme, bytestr.StrSD) { return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme) } host := uri.Host() startCleaner := false c.mLock.Lock() m := c.m if isTLS { m = c.ms } h := string(host) hc := m[h] if hc == nil { if c.clientFactory == nil { // load http1 client by default c.clientFactory = factory.NewClientFactory(newHttp1OptionFromClient(c)) } hc, _ = c.clientFactory.NewHostClient() hc.SetDynamicConfig(&client.DynamicConfig{ Addr: utils.AddMissingPort(h, isTLS), ProxyURI: proxyURI, IsTLS: isTLS, }) // re-configure hook if c.options.HostClientConfigHook != nil { err = c.options.HostClientConfigHook(hc) if err != nil { c.mLock.Unlock() return err } } m[h] = hc if len(m) == 1 { startCleaner = true } } c.mLock.Unlock() if startCleaner { go c.cleaner(isTLS) } return hc.Do(ctx, req, resp) } // CloseIdleConnections closes any connections which were previously // connected from previous requests but are now sitting idle in a // "keep-alive" state. It does not interrupt any connections currently // in use. 
func (c *Client) CloseIdleConnections() { c.mLock.Lock() for _, v := range c.m { v.CloseIdleConnections() } for _, v := range c.ms { v.CloseIdleConnections() } c.mLock.Unlock() } func (c *Client) cleaner(isTLS bool) { for { time.Sleep(10 * time.Second) if c.cleanHostClients(isTLS) { break } } } func (c *Client) cleanHostClients(isTLS bool) bool { c.mLock.Lock() defer c.mLock.Unlock() m := c.m if isTLS { m = c.ms } for k, v := range m { if v.ShouldRemove() { delete(m, k) if f, ok := v.(io.Closer); ok { err := f.Close() if err != nil { hlog.Warnf("clean hostclient error, addr: %s, err: %s", k, err.Error()) } } } } return len(m) == 0 } func (c *Client) SetClientFactory(cf suite.ClientFactory) { c.clientFactory = cf } // GetDialerName returns the name of the dialer func (c *Client) GetDialerName() (dName string, err error) { defer func() { err := recover() if err != nil { dName = "unknown" } }() opt := c.GetOptions() if opt == nil || opt.Dialer == nil { return "", fmt.Errorf("abnormal process: there is no client options or dialer") } dName = reflect.TypeOf(opt.Dialer).String() dSlice := strings.Split(dName, ".") dName = dSlice[0] if dName[0] == '*' { dName = dName[1:] } return } // NewClient return a client with options func NewClient(opts ...config.ClientOption) (*Client, error) { opt := config.NewClientOptions(opts) if opt.Dialer == nil { opt.Dialer = dialer.DefaultDialer() } c := &Client{ options: opt, m: make(map[string]client.HostClient), ms: make(map[string]client.HostClient), } return c, nil } func (c *Client) Use(mws ...Middleware) { // Put the original middlewares to the first middlewares := make([]Middleware, 0, 1+len(mws)) if c.mws != nil { middlewares = append(middlewares, c.mws) } middlewares = append(middlewares, mws...) c.mws = chain(middlewares...) } // UseAsLast is used to add middleware to the end of the middleware chain. 
// // Will return an error if last middleware has been set before, to ensure all middleware has the change to work, // Please use `TakeOutLastMiddleware` to take out the already set middleware. // Chain the middleware after or before is both Okay - but remember to put it back. func (c *Client) UseAsLast(mw Middleware) error { if c.lastMiddleware != nil { return errorLastMiddlewareExist } c.lastMiddleware = mw return nil } // TakeOutLastMiddleware will return the set middleware and remove it from client. // // Remember to set it back after chain it with other middleware. func (c *Client) TakeOutLastMiddleware() Middleware { last := c.lastMiddleware c.lastMiddleware = nil return last } func newHttp1OptionFromClient(c *Client) *http1.ClientOptions { return &http1.ClientOptions{ Name: c.options.Name, NoDefaultUserAgentHeader: c.options.NoDefaultUserAgentHeader, Dialer: c.options.Dialer, DialTimeout: c.options.DialTimeout, DialDualStack: c.options.DialDualStack, TLSConfig: c.options.TLSConfig, MaxConns: c.options.MaxConnsPerHost, MaxConnDuration: c.options.MaxConnDuration, MaxIdleConnDuration: c.options.MaxIdleConnDuration, ReadTimeout: c.options.ReadTimeout, WriteTimeout: c.options.WriteTimeout, MaxResponseBodySize: c.options.MaxResponseBodySize, DisableHeaderNamesNormalizing: c.options.DisableHeaderNamesNormalizing, DisablePathNormalizing: c.options.DisablePathNormalizing, MaxConnWaitTimeout: c.options.MaxConnWaitTimeout, ResponseBodyStream: c.options.ResponseBodyStream, RetryConfig: c.options.RetryConfig, RetryIfFunc: c.RetryIfFunc, StateObserve: c.options.HostClientStateObserve, ObservationInterval: c.options.ObservationInterval, } } ================================================ FILE: pkg/app/client/client_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package client import ( "context" "crypto/tls" "encoding/base64" "errors" "fmt" "io" "io/ioutil" "net" "net/http" "net/http/httptest" "net/url" "os" "path" "path/filepath" "reflect" "regexp" "runtime" "strings" "sync" "sync/atomic" "testing" "time" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/dialer" "github.com/cloudwego/hertz/pkg/network/standard" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/client" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1" "github.com/cloudwego/hertz/pkg/protocol/http1/req" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" "github.com/cloudwego/hertz/pkg/route" ) var errTooManyRedirects = errors.New("too many redirects detected when doing the request") func assertNil(err error) { if err != nil { panic(err) } } func waitEngineRunning(e *route.Engine) { testutils.WaitEngineRunning(e) } func newTestOptions(t *testing.T) (*config.Options, net.Listener) { ln := testutils.NewTestListener(t) opt := config.NewOptions([]config.Option{}) opt.Listener = ln opt.Addr = ln.Addr().String() opt.Network = "tcp" return opt, ln } func fullURL(ln net.Listener, p string) string { return "http://" + path.Join(ln.Addr().String(), p) } func TestCloseIdleConnections(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) if _, _, err := c.Get(context.Background(), nil, "http://google.com"); err != nil 
{ t.Fatal(err) } connsLen := func() int { c.mLock.Lock() defer c.mLock.Unlock() if _, ok := c.m["google.com"]; !ok { return 0 } return c.m["google.com"].ConnectionCount() } if conns := connsLen(); conns > 1 { t.Errorf("expected 1 conns got %d", conns) } c.CloseIdleConnections() if conns := connsLen(); conns > 0 { t.Errorf("expected 0 conns got %d", conns) } c.cleanHostClients(false) func() { c.mLock.Lock() defer c.mLock.Unlock() if len(c.m) != 0 { t.Errorf("expected 0 conns got %d", len(c.m)) } }() } func TestCloseIdleTLSConnections(t *testing.T) { httpsServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("https response")) })) defer httpsServer.Close() c, _ := NewClient( WithTLSConfig(httpsServer.Client().Transport.(*http.Transport).TLSClientConfig), WithDialTimeout(1*time.Second), ) httpsReq, httpsResp := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(httpsReq) protocol.ReleaseResponse(httpsResp) }() httpsReq.SetRequestURI(httpsServer.URL) if err := c.Do(context.Background(), httpsReq, httpsResp); err != nil { t.Fatalf("HTTPS request failed: %v", err) } c.CloseIdleConnections() c.mLock.Lock() var totalConns int for _, hc := range c.ms { totalConns += hc.ConnectionCount() } c.mLock.Unlock() if totalConns > 0 { t.Errorf("expected 0 HTTPS idle connections after close, got %d", totalConns) } c.cleanHostClients(true) c.mLock.Lock() defer c.mLock.Unlock() if len(c.ms) != 0 { t.Errorf("expected 0 HTTPS host clients, got %d", len(c.ms)) } } func TestClientInvalidURI(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() requests := int64(0) engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { atomic.AddInt64(&requests, 1) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 
1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(res) }() req.Header.SetMethod(consts.MethodGet) req.SetRequestURI("http://example.com\r\n\r\nGET /\r\n\r\n") err := c.Do(context.Background(), req, res) if err == nil { t.Fatal("expected error (missing required Host header in request)") } if n := atomic.LoadInt64(&requests); n != 0 { t.Fatalf("0 requests expected, got %d", n) } } func TestClientGetWithBody(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { body := ctx.Request.Body() ctx.Write(body) //nolint:errcheck }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(res) }() req.Header.SetMethod(consts.MethodGet) req.SetRequestURI("http://example.com") req.SetBodyString("test") err := c.Do(context.Background(), req, res) if err != nil { t.Fatal(err) } if len(res.Body()) == 0 { t.Fatal("missing request body") } } func TestClientPostBodyStream(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { body := ctx.Request.Body() ctx.Write(body) //nolint:errcheck }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) cStream, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil)), WithResponseBodyStream(true)) args := &protocol.Args{} // There is some data in databuf and others is in bodystream, so we need // to let the data exceed the max bodysize of bodystream v := "" for i := 0; i < 10240; i++ { v += "b" } args.Add("a", v) _, 
body, err := cStream.Post(context.Background(), nil, "http://example.com", args) if err != nil { t.Fatal(err) } assert.DeepEqual(t, "a="+v, string(body)) } func TestClientURLAuth(t *testing.T) { cases := map[string]string{ "foo:bar@": "Basic Zm9vOmJhcg==", "foo:@": "Basic Zm9vOg==", ":@": "", "@": "", "": "", } ch := make(chan string, 1) opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.GET("/foo/bar", func(c context.Context, ctx *app.RequestContext) { ch <- string(ctx.Request.Header.Peek(consts.HeaderAuthorization)) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) for up, expected := range cases { req := protocol.AcquireRequest() req.Header.SetMethod(consts.MethodGet) req.SetRequestURI("http://" + up + "example.com/foo/bar") if err := c.Do(context.Background(), req, nil); err != nil { t.Fatal(err) } val := <-ch if val != expected { t.Fatalf("wrong %s header: %s expected %s", consts.HeaderAuthorization, val, expected) } } } func TestClientNilResp(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req := protocol.AcquireRequest() req.Header.SetMethod(consts.MethodGet) req.SetRequestURI("http://example.com") if err := c.Do(context.Background(), req, nil); err != nil { t.Fatal(err) } if err := c.DoTimeout(context.Background(), req, nil, time.Second); err != nil { t.Fatal(err) } } func TestClientParseConn(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx 
*app.RequestContext) { }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) opt.Addr = ln.Addr().String() c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(res) }() req.SetRequestURI("http://" + opt.Addr + "") if err := c.Do(context.Background(), req, res); err != nil { t.Fatal(err) } if res.RemoteAddr().Network() != opt.Network { t.Fatalf("req RemoteAddr parse network fail: %s, hope: %s", res.RemoteAddr().Network(), opt.Network) } if opt.Addr != res.RemoteAddr().String() { t.Fatalf("req RemoteAddr parse addr fail: %s, hope: %s", res.RemoteAddr().String(), opt.Addr) } if !regexp.MustCompile(`^127\.0\.0\.1:[0-9]{4,5}$`).MatchString(res.LocalAddr().String()) { t.Fatalf("res LocalAddr addr match fail: %s, hope match: %s", res.LocalAddr().String(), "^127.0.0.1:[0-9]{4,5}$") } } func TestClientPostArgs(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { body := ctx.Request.Body() if len(body) == 0 { return } ctx.Write(body) //nolint:errcheck }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(res) }() args := req.PostArgs() args.Add("addhttp2", "support") args.Add("fast", "http") req.Header.SetMethod(consts.MethodPost) req.SetRequestURI("http://make.hertz.great?again") err := c.Do(context.Background(), req, res) if err != nil { t.Fatal(err) } if len(res.Body()) == 0 { t.Fatal("cannot set args as body") } } func TestClientHeaderCase(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() 
engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { zw := ctx.GetWriter() zw.WriteBinary([]byte("HTTP/1.1 200 OK\r\n" + //nolint:errcheck "content-type: text/plain\r\n" + "transfer-encoding: chunked\r\n\r\n" + "24\r\nThis is the data in the first chunk \r\n" + "1B\r\nand this is the second one \r\n" + "0\r\n\r\n", )) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithDisableHeaderNamesNormalizing(true)) code, body, err := c.Get(context.Background(), nil, "http://example.com") if err != nil { t.Error(err) } else if code != 200 { t.Errorf("expected status code 200 got %d", code) } else if string(body) != "This is the data in the first chunk and this is the second one " { t.Errorf("wrong body: %q", body) } } func TestClientReadTimeout(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) readtimeout := 50 * time.Millisecond if runtime.GOOS == "windows" { // XXX: The windows CI instance powered by Github is quite unstable. // Increase readtimeout here for better testing stability. 
readtimeout = 2 * readtimeout } sleeptime := readtimeout + readtimeout/2 // must > readtimeout engine.GET("/normal", func(c context.Context, ctx *app.RequestContext) { ctx.String(201, "ok") }) engine.GET("/timeout", func(c context.Context, ctx *app.RequestContext) { time.Sleep(sleeptime) ctx.String(202, "timeout ok") }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ ReadTimeout: readtimeout, RetryConfig: &retry.Config{MaxAttemptTimes: 1}, Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, readtimeout, nil), }, Addr: opt.Addr, } req := protocol.AcquireRequest() res := protocol.AcquireResponse() req.SetRequestURI("http://example.com/normal") req.Header.SetMethod(consts.MethodGet) // Setting Connection: Close will make the connection be returned to the pool. req.SetConnectionClose() if err := c.Do(context.Background(), req, res); err != nil { t.Fatal(err) } req.Reset() req.SetRequestURI("http://example.com/timeout") req.Header.SetMethod(consts.MethodGet) req.SetConnectionClose() res.Reset() t0 := time.Now() err := c.Do(context.Background(), req, res) t1 := time.Now() if !errors.Is(err, errs.ErrTimeout) { if err == nil { t.Errorf("expected ErrTimeout got nil, req url: %s, read resp body: %s, status: %d", string(req.URI().FullURI()), string(res.Body()), res.StatusCode()) } else { if !strings.Contains(err.Error(), "timeout") { t.Errorf("expected ErrTimeout got %#v", err) } } } protocol.ReleaseRequest(req) protocol.ReleaseResponse(res) if d := t1.Sub(t0) - readtimeout; d > readtimeout/2 { t.Errorf("timeout more than expected: %v", d) } else { t.Log("latency", d) } } func TestClientDefaultUserAgent(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { ctx.Data(consts.StatusOK, "text/plain; charset=utf-8", ctx.UserAgent()) }) go engine.Run() defer func() { engine.Close() 
}() waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req := protocol.AcquireRequest() res := protocol.AcquireResponse() req.SetRequestURI("http://example.com") req.Header.SetMethod(consts.MethodGet) err := c.Do(context.Background(), req, res) if err != nil { t.Fatal(err) } if string(res.Body()) != string(bytestr.DefaultUserAgent) { t.Fatalf("User-Agent defers %q != %q", string(res.Body()), bytestr.DefaultUserAgent) } } func TestClientSetUserAgent(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { ctx.Data(consts.StatusOK, "text/plain; charset=utf-8", ctx.UserAgent()) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) userAgent := "I'm not hertz" c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithName(userAgent)) req := protocol.AcquireRequest() res := protocol.AcquireResponse() req.SetRequestURI("http://example.com") err := c.Do(context.Background(), req, res) if err != nil { t.Fatal(err) } if string(res.Body()) != userAgent { t.Fatalf("User-Agent defers %q != %q", string(res.Body()), userAgent) } } func TestClientNoUserAgent(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { ctx.Data(consts.StatusOK, "text/plain; charset=utf-8", ctx.UserAgent()) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithDialTimeout(1*time.Second), WithNoDefaultUserAgentHeader(true)) req := protocol.AcquireRequest() res := protocol.AcquireResponse() req.SetRequestURI("http://example.com") err := c.Do(context.Background(), req, res) if err != nil { t.Fatal(err) } if string(res.Body()) != "" 
{ t.Fatalf("User-Agent wrong %q != %q", string(res.Body()), "") } } func TestClientDoWithCustomHeaders(t *testing.T) { ch := make(chan error) uri := "/foo/bar/baz?a=b&cd=12" headers := map[string]string{ "Foo": "bar", "Host": "xxx.com", "Content-Type": "asdfsdf", "a-b-c-d-f": "", } body := "request body" opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.POST("/foo/bar/baz", func(c context.Context, ctx *app.RequestContext) { zw := ctx.GetWriter() if string(ctx.Request.Header.Method()) != consts.MethodPost { ch <- fmt.Errorf("unexpected request method: %q. Expecting %q", ctx.Request.Header.Method(), consts.MethodPost) return } reqURI := ctx.Request.RequestURI() if string(reqURI) != uri { ch <- fmt.Errorf("unexpected request uri: %q. Expecting %q", reqURI, uri) return } for k, v := range headers { hv := ctx.Request.Header.Peek(k) if string(hv) != v { ch <- fmt.Errorf("unexpected value for header %q: %q. Expecting %q", k, hv, v) return } } cl := ctx.Request.Header.ContentLength() if cl != len(body) { ch <- fmt.Errorf("unexpected content-length %d. Expecting %d", cl, len(body)) return } reqBody := ctx.Request.Body() if string(reqBody) != body { ch <- fmt.Errorf("unexpected request body: %q. Expecting %q", reqBody, body) return } var r protocol.Response if err := resp.Write(&r, zw); err != nil { ch <- fmt.Errorf("cannot send response: %s", err) return } if err := zw.Flush(); err != nil { ch <- fmt.Errorf("cannot flush response: %s", err) return } ch <- nil }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) // make sure that the client sends all the request headers and body. 
c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) var req protocol.Request req.Header.SetMethod(consts.MethodPost) req.SetRequestURI(uri) for k, v := range headers { req.Header.Set(k, v) } req.SetBodyString(body) var resp protocol.Response err := c.DoTimeout(context.Background(), &req, &resp, time.Second) if err != nil { t.Fatalf("error when doing request: %s", err) } select { case <-ch: case <-time.After(5 * time.Second): t.Fatalf("timeout") } } func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.Use(func(c context.Context, ctx *app.RequestContext) { uri := ctx.URI() uri.DisablePathNormalizing = true ctx.Response.Header.Set("received-uri", string(uri.FullURI())) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithDisablePathNormalizing(true)) urlWithEncodedPath := "http://example.com/encoded/Y%2BY%2FY%3D/stuff" var req protocol.Request req.SetRequestURI(urlWithEncodedPath) var resp protocol.Response for i := 0; i < 5; i++ { if err := c.DoTimeout(context.Background(), &req, &resp, time.Second); err != nil { t.Fatalf("unexpected error: %s", err) } hv := resp.Header.Peek("received-uri") if string(hv) != urlWithEncodedPath { t.Fatalf("request uri was normalized: %q. 
Expecting %q", hv, urlWithEncodedPath) } } } func TestHostClientPendingRequests(t *testing.T) { const concurrency = 10 doneCh := make(chan struct{}) readyCh := make(chan struct{}, concurrency) opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.GET("/baz", func(c context.Context, ctx *app.RequestContext) { readyCh <- struct{}{} <-doneCh }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil), }, Addr: "foobar", } pendingRequests := c.PendingRequests() if pendingRequests != 0 { t.Fatalf("non-zero pendingRequests: %d", pendingRequests) } resultCh := make(chan error, concurrency) for i := 0; i < concurrency; i++ { go func() { req := protocol.AcquireRequest() req.SetRequestURI("http://foobar/baz") req.Header.SetMethod(consts.MethodGet) resp := protocol.AcquireResponse() if err := c.DoTimeout(context.Background(), req, resp, 10*time.Second); err != nil { resultCh <- fmt.Errorf("unexpected error: %s", err) return } if resp.StatusCode() != consts.StatusOK { resultCh <- fmt.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), consts.StatusOK) return } resultCh <- nil }() } // wait until all the requests reach server for i := 0; i < concurrency; i++ { select { case <-readyCh: case <-time.After(time.Second): t.Fatalf("timeout") } } pendingRequests = c.PendingRequests() if pendingRequests != concurrency { t.Fatalf("unexpected pendingRequests: %d. Expecting %d", pendingRequests, concurrency) } // unblock request handlers on the server and wait until all the requests are finished. 
close(doneCh) for i := 0; i < concurrency; i++ { select { case err := <-resultCh: if err != nil { t.Fatalf("unexpected error: %s", err) } case <-time.After(time.Second): t.Fatalf("timeout") } } pendingRequests = c.PendingRequests() if pendingRequests != 0 { t.Fatalf("non-zero pendingRequests: %d", pendingRequests) } } func TestHostClientMaxConnsWithDeadline(t *testing.T) { var ( emptyBodyCount uint8 timeout = 50 * time.Millisecond wg sync.WaitGroup ) opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.POST("/baz", func(c context.Context, ctx *app.RequestContext) { if len(ctx.Request.Body()) == 0 { emptyBodyCount++ } ctx.WriteString("foo") //nolint:errcheck }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil), MaxConns: 1, }, Addr: "foobar", } for i := 0; i < 5; i++ { wg.Add(1) go func() { defer wg.Done() req := protocol.AcquireRequest() req.SetRequestURI("http://foobar/baz") req.Header.SetMethod(consts.MethodPost) req.SetBodyString("bar") resp := protocol.AcquireResponse() for { if err := c.DoDeadline(context.Background(), req, resp, time.Now().Add(timeout)); err != nil { if err.Error() == errs.ErrNoFreeConns.Error() { time.Sleep(10 * time.Millisecond) continue } t.Errorf("unexpected error: %s", err) } break } if resp.StatusCode() != consts.StatusOK { t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), consts.StatusOK) } body := resp.Body() if string(body) != "foo" { t.Errorf("unexpected body %q. 
Expecting %q", body, "abcd") } }() } wg.Wait() if emptyBodyCount > 0 { t.Fatalf("at least one request body was empty") } } func TestHostClientMaxConnDuration(t *testing.T) { connectionCloseCount := uint32(0) opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.GET("/bbb/cc", func(c context.Context, ctx *app.RequestContext) { ctx.WriteString("abcd") //nolint:errcheck if ctx.Request.ConnectionClose() { atomic.AddUint32(&connectionCloseCount, 1) } }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil), MaxConnDuration: 10 * time.Millisecond, }, Addr: "foobar", } for i := 0; i < 5; i++ { statusCode, body, err := c.Get(context.Background(), nil, "http://aaaa.com/bbb/cc") if err != nil { t.Fatalf("unexpected error: %s", err) } if statusCode != consts.StatusOK { t.Fatalf("unexpected status code %d. Expecting %d", statusCode, consts.StatusOK) } if string(body) != "abcd" { t.Fatalf("unexpected body %q. 
Expecting %q", body, "abcd") } time.Sleep(c.MaxConnDuration) } if atomic.LoadUint32(&connectionCloseCount) == 0 { t.Fatalf("expecting at least one 'Connection: close' request header") } } func TestHostClientMultipleAddrs(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.GET("/baz/aaa", func(c context.Context, ctx *app.RequestContext) { ctx.Write(ctx.Host()) //nolint:errcheck ctx.SetConnectionClose() }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) dialsCount := make(map[string]int) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, func(network, addr string, timeout time.Duration, tlsConfig *tls.Config) { dialsCount[addr]++ }), }, Addr: "foo,bar,baz", } for i := 0; i < 9; i++ { statusCode, body, err := c.Get(context.Background(), nil, "http://foobar/baz/aaa?bbb=ddd") if err != nil { t.Fatalf("unexpected error: %s", err) } if statusCode != consts.StatusOK { t.Fatalf("unexpected status code %d. Expecting %d", statusCode, consts.StatusOK) } if string(body) != "foobar" { t.Fatalf("unexpected body %q. Expecting %q", body, "foobar") } } if len(dialsCount) != 3 { t.Fatalf("unexpected dialsCount size %d. Expecting 3", len(dialsCount)) } for _, k := range []string{"foo", "bar", "baz"} { if dialsCount[k] != 3 { t.Fatalf("unexpected dialsCount for %q. 
Expecting 3", k) } } } func TestClientFollowRedirects(t *testing.T) { opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) handler := func(c context.Context, ctx *app.RequestContext) { switch string(ctx.Path()) { case "/foo": u := ctx.URI() u.Update("/xy?z=wer") ctx.Redirect(consts.StatusFound, u.FullURI()) case "/xy": u := ctx.URI() u.Update("/bar") ctx.Redirect(consts.StatusFound, u.FullURI()) default: ctx.SetContentType(consts.MIMETextPlain) ctx.Response.SetBody(ctx.Path()) } } engine.GET("/foo", handler) engine.GET("/xy", handler) engine.GET("/bar", handler) engine.GET("/aaab/sss", handler) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil), }, Addr: "xxx", } for i := 0; i < 10; i++ { statusCode, body, err := c.GetTimeout(context.Background(), nil, "http://xxx/foo", time.Second) if err != nil { t.Fatalf("unexpected error: %s", err) } if statusCode != consts.StatusOK { t.Fatalf("unexpected status code: %d", statusCode) } if string(body) != "/bar" { t.Fatalf("unexpected response %q. Expecting %q", body, "/bar") } } for i := 0; i < 10; i++ { statusCode, body, err := c.Get(context.Background(), nil, "http://xxx/aaab/sss") if err != nil { t.Fatalf("unexpected error: %s", err) } if statusCode != consts.StatusOK { t.Fatalf("unexpected status code: %d", statusCode) } if string(body) != "/aaab/sss" { t.Fatalf("unexpected response %q. 
Expecting %q", body, "/aaab/sss") } } for i := 0; i < 10; i++ { req := protocol.AcquireRequest() resp := protocol.AcquireResponse() req.SetRequestURI("http://xxx/foo") err := c.DoRedirects(context.Background(), req, resp, 16) if err != nil { t.Fatalf("unexpected error: %s", err) } if statusCode := resp.StatusCode(); statusCode != consts.StatusOK { t.Fatalf("unexpected status code: %d", statusCode) } if body := string(resp.Body()); body != "/bar" { t.Fatalf("unexpected response %q. Expecting %q", body, "/bar") } protocol.ReleaseRequest(req) protocol.ReleaseResponse(resp) } req := protocol.AcquireRequest() resp := protocol.AcquireResponse() req.SetRequestURI("http://xxx/foo") err := c.DoRedirects(context.Background(), req, resp, 0) if have, want := err, errTooManyRedirects; have.Error() != want.Error() { t.Fatalf("want error: %v, have %v", want, have) } protocol.ReleaseRequest(req) protocol.ReleaseResponse(resp) } func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) { var ( emptyBodyCount uint8 wg sync.WaitGroup ) opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.POST("/baz", func(c context.Context, ctx *app.RequestContext) { if len(ctx.Request.Body()) == 0 { emptyBodyCount++ } time.Sleep(5 * time.Millisecond) ctx.WriteString("foo") //nolint:errcheck }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil), MaxConns: 1, MaxConnWaitTimeout: 200 * time.Millisecond, }, Addr: "foobar", } for i := 0; i < 5; i++ { wg.Add(1) go func() { defer wg.Done() req := protocol.AcquireRequest() req.SetRequestURI("http://foobar/baz") req.Header.SetMethod(consts.MethodPost) req.SetBodyString("bar") resp := protocol.AcquireResponse() if err := c.Do(context.Background(), req, resp); err != nil { t.Errorf("unexpected error: %s", err) } if resp.StatusCode() != consts.StatusOK { 
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), consts.StatusOK) } body := resp.Body() if string(body) != "foo" { t.Errorf("unexpected body %q. Expecting %q", body, "abcd") } }() } wg.Wait() if c.WantConnectionCount() > 0 { t.Errorf("connsWait has %v items remaining", c.WantConnectionCount()) } if emptyBodyCount > 0 { t.Fatalf("at least one request body was empty") } } func TestHostClientMaxConnWaitTimeoutError(t *testing.T) { var ( emptyBodyCount uint8 wg sync.WaitGroup ) opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.POST("/baz", func(c context.Context, ctx *app.RequestContext) { if len(ctx.Request.Body()) == 0 { emptyBodyCount++ } time.Sleep(5 * time.Millisecond) ctx.WriteString("foo") //nolint:errcheck }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil), MaxConns: 1, MaxConnWaitTimeout: 10 * time.Millisecond, }, Addr: "foobar", } var errNoFreeConnsCount uint32 for i := 0; i < 5; i++ { wg.Add(1) go func() { defer wg.Done() req := protocol.AcquireRequest() req.SetRequestURI("http://foobar/baz") req.Header.SetMethod(consts.MethodPost) req.SetBodyString("bar") resp := protocol.AcquireResponse() if err := c.Do(context.Background(), req, resp); err != nil { if err.Error() != errs.ErrNoFreeConns.Error() { t.Errorf("unexpected error: %s. Expecting %s", err.Error(), errs.ErrNoFreeConns.Error()) } atomic.AddUint32(&errNoFreeConnsCount, 1) } else { if resp.StatusCode() != consts.StatusOK { t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), consts.StatusOK) } body := resp.Body() if string(body) != "foo" { t.Errorf("unexpected body %q. 
Expecting %q", body, "abcd") } } }() } wg.Wait() if c.WantConnectionCount() > 0 { t.Errorf("connsWait has %v items remaining", c.WantConnectionCount()) } if errNoFreeConnsCount == 0 { t.Errorf("unexpected errorCount: %d. Expecting > 0", errNoFreeConnsCount) } if emptyBodyCount > 0 { t.Fatalf("at least one request body was empty") } } func TestNewClient(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) engine.GET("/ping", func(c context.Context, ctx *app.RequestContext) { ctx.SetBodyString("pong") }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) client, err := NewClient(WithDialTimeout(2 * time.Second)) if err != nil { t.Fatal(err) return } status, resp, err := client.Get(context.Background(), nil, fullURL(ln, "/ping")) if err != nil { t.Fatal(err) return } if status != consts.StatusOK { t.Errorf("return http status=%v", status) } t.Logf("resp=%v\n", string(resp)) } func TestUseShortConnection(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) opt.Addr = ln.Addr().String() c, _ := NewClient(WithKeepAlive(false)) var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() if _, _, err := c.Get(context.Background(), nil, fullURL(ln, "")); err != nil { t.Error(err) return } }() } wg.Wait() connsLen := func() int { c.mLock.Lock() defer c.mLock.Unlock() if _, ok := c.m[opt.Addr]; !ok { return 0 } return c.m[opt.Addr].ConnectionCount() } if conns := connsLen(); conns > 0 { t.Errorf("expected 0 conns got %d", conns) } } func TestPostWithFormData(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) 
opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { var ans string ctx.PostArgs().VisitAll(func(key, value []byte) { ans = ans + string(key) + "=" + string(value) + "&" }) ans = strings.TrimRight(ans, "&") ctx.Data(consts.StatusOK, "text/plain; charset=utf-8", []byte(ans)) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(rsp) }() postParam := map[string][]string{ "a": {"c", "d", "e"}, "b": {"c"}, "c": {"f"}, } req.SetFormData(map[string]string{ "a": "c", "b": "c", }) req.SetFormDataFromValues(url.Values{ "a": []string{"d", "e"}, "c": []string{"f"}, }) req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodPost) err := client.Do(context.Background(), req, rsp) if err != nil { t.Error(err) } for k, v := range postParam { for _, kv := range v { if !strings.Contains(string(rsp.Body()), k+"="+kv) { t.Errorf("miss %v=%v", k, kv) } } } } func TestPostWithMultipartField(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { if string(ctx.FormValue("a")) != "1" { t.Errorf("field a want 1, got %v", string(ctx.FormValue("a"))) } if string(ctx.FormValue("b")) != "2" { t.Errorf("field b want 2, got %v", string(ctx.FormValue("b"))) } t.Log(req.GetHTTP1Request(&ctx.Request).String()) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(rsp) }() data := map[string]string{ "a": "1", "b": "2", } req.SetMethod(consts.MethodPost) req.SetRequestURI(fullURL(ln, "")) 
req.SetMultipartFormData(data) req.SetMultipartFormData(map[string]string{ "c": "3", }) err := client.DoTimeout(context.Background(), req, rsp, 1*time.Second) if err != nil { t.Error(err) } } func TestSetFiles(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { form, _ := ctx.MultipartForm() files := form.File["files"] // Upload the file to specific dst. for _, file := range files { ctx.SaveUploadedFile(file, filepath.Base(file.Filename)) } file1, _ := ctx.FormFile("file_1") ctx.SaveUploadedFile(file1, filepath.Base(file1.Filename)) file2, _ := ctx.FormFile("file_2") ctx.SaveUploadedFile(file2, filepath.Base(file2.Filename)) ctx.String(consts.StatusOK, fmt.Sprintf("%d files uploaded!", len(files)+2)) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(rsp) }() req.SetMethod(consts.MethodPost) req.SetRequestURI(fullURL(ln, "")) files := []string{"../../common/testdata/test.txt", "../../common/testdata/proto/test.proto", "../../common/testdata/test.png", "../../common/testdata/proto/test.pb.go"} defer func() { for _, file := range files { os.Remove(filepath.Base(file)) } }() req.SetFile("files", files[0]) req.SetFile("files", files[1]) req.SetFiles(map[string]string{ "file_1": files[2], "file_2": files[3], }) err := client.DoTimeout(context.Background(), req, rsp, 1*time.Second) if err != nil { t.Error(err) } } func TestSetMultipartFields(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { 
t.Log(req.GetHTTP1Request(&ctx.Request).String()) if string(ctx.FormValue("a")) != "1" { t.Errorf("field a want 1, got %v", string(ctx.FormValue("a"))) } if string(ctx.FormValue("b")) != "2" { t.Errorf("field b want 2, got %v", string(ctx.FormValue("b"))) } file1, _ := ctx.FormFile("file_1") ctx.SaveUploadedFile(file1, filepath.Base(file1.Filename)) file2, _ := ctx.FormFile("file_2") ctx.SaveUploadedFile(file2, filepath.Base(file2.Filename)) ctx.String(consts.StatusOK, fmt.Sprintf("%d files uploaded!", 2)) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) client, _ := NewClient(WithDialTimeout(50 * time.Millisecond)) req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(rsp) }() jsonStr1 := `{"input": {"name": "Uploaded document 1", "_filename" : ["file1.txt"]}}` jsonStr2 := `{"input": {"name": "Uploaded document 2", "_filename" : ["file2.txt"]}}` files := []string{"upload-file-1.json", "upload-file-2.json"} fields := []*protocol.MultipartField{ { Param: "file_1", FileName: files[0], ContentType: consts.MIMEApplicationJSON, Reader: strings.NewReader(jsonStr1), }, { Param: "file_2", FileName: files[1], ContentType: consts.MIMEApplicationJSON, Reader: strings.NewReader(jsonStr2), }, } defer func() { for _, file := range files { os.Remove(filepath.Base(file)) } }() req.SetMultipartFields(fields...) 
req.SetMultipartFormData(map[string]string{"a": "1", "b": "2"}) req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodPost) err := client.DoTimeout(context.Background(), req, rsp, 1*time.Second) if err != nil { t.Error(err) } } func TestClientReadResponseBodyStream(t *testing.T) { part1 := "abcdef" part2 := "ghij" ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(ctx context.Context, c *app.RequestContext) { c.String(consts.StatusOK, part1+part2) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) client, _ := NewClient(WithResponseBodyStream(true)) req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(resp) }() req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodPost) err := client.Do(context.Background(), req, resp) if err != nil { t.Errorf("client Do error=%v", err.Error()) } bodyStream := resp.BodyStream() if bodyStream == nil { t.Errorf("bodystream is nil") } // Read part1 body bytes p := make([]byte, len(part1)) r, err := bodyStream.Read(p) if err != nil { t.Errorf("read from bodystream error=%v", err.Error()) } if string(p) != part1 { t.Errorf("read len=%v, read content=%v; want len=%v, want content=%v", r, string(p), len(part1), part1) } left, _ := ioutil.ReadAll(bodyStream) if string(left) != part2 { t.Errorf("left len=%v, left content=%v; want len=%v, want content=%v", len(left), string(left), len(part2), part2) } } func TestWithBasicAuth(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { auth := ctx.GetHeader(consts.HeaderAuthorization) if len(auth) < 6 { ctx.SetStatusCode(consts.StatusUnauthorized) return } password, err := 
base64.StdEncoding.DecodeString(string(auth[6:])) if err != nil || string(password) != "myuser:basicauth" { ctx.SetStatusCode(consts.StatusUnauthorized) return } }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(rsp) }() // Success req.SetBasicAuth("myuser", "basicauth") req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodGet) err := client.Do(context.Background(), req, rsp) if err != nil { t.Error(err) } if rsp.StatusCode() == consts.StatusUnauthorized { t.Error("unexpected status code=401") } // Fail req.Reset() rsp.Reset() req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodGet) err = client.Do(context.Background(), req, rsp) if err != nil { t.Error(err) } if rsp.StatusCode() != consts.StatusUnauthorized { t.Errorf("unexpected status code: %v, expected 401", rsp.StatusCode()) } } func TestClientProxyWithStandardDialer(t *testing.T) { testCases := []struct{ httpsSite, httpsProxy bool }{ {false, false}, {false, true}, {true, false}, {true, true}, } for _, testCase := range testCases { httpsSite := testCase.httpsSite httpsProxy := testCase.httpsProxy t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) { siteCh := make(chan *http.Request, 1) h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { siteCh <- r }) proxyCh := make(chan *http.Request, 1) h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxyCh <- r if r.Method == "CONNECT" { hijacker, ok := w.(http.Hijacker) if !ok { t.Errorf("hijack not allowed") return } clientConn, _, err := hijacker.Hijack() if err != nil { t.Errorf("hijacking failed") return } res := &http.Response{ StatusCode: http.StatusOK, Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Header: make(http.Header), } targetConn, err := net.Dial("tcp", r.URL.Host) if 
err != nil { t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) return } if err := res.Write(clientConn); err != nil { t.Errorf("Writing 200 OK failed: %v", err) return } go io.Copy(targetConn, clientConn) go func() { io.Copy(clientConn, targetConn) targetConn.Close() }() } }) var ts *httptest.Server if httpsSite { ts = httptest.NewTLSServer(h1) } else { ts = httptest.NewServer(h1) } var proxyServer *httptest.Server if httpsProxy { proxyServer = httptest.NewTLSServer(h2) } else { proxyServer = httptest.NewServer(h2) } pu := protocol.ParseURI(proxyServer.URL) // If neither server is HTTPS or both are, then c may be derived from either. // If only one server is HTTPS, c must be derived from that server in order // to ensure that it is configured to use the fake root CA from testcert.go. dialer.SetDialer(standard.NewDialer()) var cOpt config.ClientOption if httpsProxy { cOpt = WithTLSConfig(proxyServer.Client().Transport.(*http.Transport).TLSClientConfig) } else if httpsSite { cOpt = WithTLSConfig(ts.Client().Transport.(*http.Transport).TLSClientConfig) } var c *Client if httpsProxy || httpsSite { c, _ = NewClient(cOpt) } else { c, _ = NewClient() } c.SetProxy(protocol.ProxyURI(pu)) req, rsp := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(rsp) }() req.SetRequestURI(ts.URL) req.SetMethod(consts.MethodHead) err := c.Do(context.Background(), req, rsp) if err != nil { t.Error(err) } var got *http.Request select { case got = <-proxyCh: case <-time.After(5 * time.Second): t.Fatal("timeout connecting to http proxy") } ts.Close() proxyServer.Close() if httpsSite { // First message should be a CONNECT to ask for a socket to the real server, if got.Method != "CONNECT" { t.Errorf("Wrong method for secure proxying: %q", got.Method) } gotHost := got.URL.Host pu, err := url.Parse(ts.URL) if err != nil { t.Fatal("Invalid site URL") } if wantHost := pu.Host; gotHost != wantHost { t.Errorf("Got CONNECT 
host %q, want %q", gotHost, wantHost) } // The next message on the channel should be from the site's server. next := <-siteCh if next.Method != "HEAD" { t.Errorf("Wrong method at destination: %s", next.Method) } if nextURL := next.URL.String(); nextURL != "/" { t.Errorf("Wrong URL at destination: %s", nextURL) } } else { if got.Method != "HEAD" { t.Errorf("Wrong method for destination: %q", got.Method) } gotURL := got.URL.String() wantURL := ts.URL + "/" if gotURL != wantURL { t.Errorf("Got URL %q, want %q", gotURL, wantURL) } } }) } } func TestClientProxyWithNetpollDialer(t *testing.T) { testCases := []struct{ httpsSite, httpsProxy bool }{ {false, false}, {true, false}, {false, true}, {true, true}, } for _, testCase := range testCases { httpsSite := testCase.httpsSite httpsProxy := testCase.httpsProxy t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) { siteCh := make(chan *http.Request, 1) h1 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { siteCh <- r }) proxyCh := make(chan *http.Request, 1) h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { proxyCh <- r }) var ts *httptest.Server if httpsSite { ts = httptest.NewTLSServer(h1) } else { ts = httptest.NewServer(h1) } var proxyServer *httptest.Server if httpsProxy { proxyServer = httptest.NewTLSServer(h2) } else { proxyServer = httptest.NewServer(h2) } pu := protocol.ParseURI(proxyServer.URL) // If neither server is HTTPS or both are, then c may be derived from either. // If only one server is HTTPS, c must be derived from that server in order // to ensure that it is configured to use the fake root CA from testcert.go. 
c, _ := NewClient() c.SetProxy(protocol.ProxyURI(pu)) req, rsp := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(rsp) }() req.SetRequestURI(ts.URL) req.SetMethod(consts.MethodHead) err := c.Do(context.Background(), req, rsp) if err != nil { t.Log(err) if !httpsSite && !httpsProxy { t.Fatal(err) } return } var got *http.Request select { case got = <-proxyCh: case <-time.After(5 * time.Second): t.Fatal("timeout connecting to http proxy") } ts.Close() proxyServer.Close() if got.Method != "HEAD" { t.Errorf("Wrong method for destination: %q", got.Method) } gotURL := got.URL.String() wantURL := ts.URL + "/" if gotURL != wantURL { t.Errorf("Got URL %q, want %q", gotURL, wantURL) } }) } } func TestClientMiddleware(t *testing.T) { client, _ := NewClient() mw0 := func(next Endpoint) Endpoint { return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { req.SetRequestURI("middleware0") return next(ctx, req, resp) } } mw1 := func(next Endpoint) Endpoint { return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { if string(req.RequestURI()) != "middleware0" { t.Errorf("Wrong request URI: %s, expected %v", req.RequestURI(), "middleware0") } req.SetRequestURI("middleware1") return next(ctx, req, resp) } } mw2 := func(next Endpoint) Endpoint { return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { if string(req.RequestURI()) != "middleware1" { t.Errorf("Wrong request URI: %s, expected %v", req.RequestURI(), "middleware1") } return nil } } client.Use(mw0) client.Use(mw1) client.Use(mw2) request, response := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(request) protocol.ReleaseResponse(response) }() err := client.Do(context.Background(), request, response) if err != nil { t.Errorf("unexpected error: %s", err.Error()) } } func 
TestClientLastMiddleware(t *testing.T) { client, _ := NewClient() mw0 := func(next Endpoint) Endpoint { return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { finalValue0 := ctx.Value("final0") assert.DeepEqual(t, "final3", finalValue0) finalValue1 := ctx.Value("final1") assert.DeepEqual(t, "final1", finalValue1) finalValue2 := ctx.Value("final2") assert.DeepEqual(t, "final2", finalValue2) return nil } } mw1 := func(next Endpoint) Endpoint { return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { //nolint:staticcheck // SA1029 no built-in type string as key ctx = context.WithValue(ctx, "final0", "final0") return next(ctx, req, resp) } } mw2 := func(next Endpoint) Endpoint { return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { //nolint:staticcheck // SA1029 no built-in type string as key ctx = context.WithValue(ctx, "final1", "final1") return next(ctx, req, resp) } } mw3 := func(next Endpoint) Endpoint { return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { //nolint:staticcheck // SA1029 no built-in type string as key ctx = context.WithValue(ctx, "final2", "final2") return next(ctx, req, resp) } } mw4 := func(next Endpoint) Endpoint { return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { //nolint:staticcheck // SA1029 no built-in type string as key ctx = context.WithValue(ctx, "final0", "final3") return next(ctx, req, resp) } } err := client.UseAsLast(mw0) assert.Nil(t, err) err = client.UseAsLast(func(endpoint Endpoint) Endpoint { return nil }) assert.DeepEqual(t, errorLastMiddlewareExist, err) client.Use(mw1) client.Use(mw2) client.Use(mw3) client.Use(mw4) request, response := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(request) protocol.ReleaseResponse(response) }() err = client.Do(context.Background(), request, 
response) if err != nil { t.Errorf("unexpected error: %s", err.Error()) } last := client.TakeOutLastMiddleware() assert.DeepEqual(t, reflect.ValueOf(last).Pointer(), reflect.ValueOf(mw0).Pointer()) last = client.TakeOutLastMiddleware() assert.Nil(t, last) } func TestClientReadResponseBodyStreamWithDoubleRequest(t *testing.T) { part1 := "" for i := 0; i < 8192; i++ { part1 += "a" } part2 := "ghij" ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(ctx context.Context, c *app.RequestContext) { c.String(consts.StatusOK, part1+part2) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) client, _ := NewClient(WithResponseBodyStream(true)) req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(resp) }() req.SetRequestURI(fullURL(ln, "")) req.SetMethod(consts.MethodPost) err := client.Do(context.Background(), req, resp) if err != nil { t.Errorf("client Do error=%v", err.Error()) } bodyStream := resp.BodyStream() if bodyStream == nil { t.Errorf("bodystream is nil") } // Read part1 body bytes p := make([]byte, len(part1)) r, err := bodyStream.Read(p) if err != nil { t.Errorf("read from bodystream error=%v", err.Error()) } if string(p) != part1 { t.Errorf("read len=%v, read content=%v; want len=%v, want content=%v", r, string(p), len(part1), part1) } // send another request and read all bodystream req1, resp1 := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req1) protocol.ReleaseResponse(resp1) }() req1.SetRequestURI(fullURL(ln, "")) req1.SetMethod(consts.MethodPost) err = client.Do(context.Background(), req1, resp1) if err != nil { t.Errorf("client Do error=%v", err.Error()) } bodyStream1 := resp1.BodyStream() if bodyStream1 == nil { t.Errorf("bodystream1 is nil") } data, _ := ioutil.ReadAll(bodyStream1) if 
string(data) != part1+part2 { t.Errorf("read len=%v, read content=%v; want len=%v, want content=%v", len(data), data, len(part1+part2), part1+part2) } // read left bodystream left, _ := ioutil.ReadAll(bodyStream) if string(left) != part2 { t.Errorf("left len=%v, left content=%v; want len=%v, want content=%v", len(left), string(left), len(part2), part2) } } func TestClientReadResponseBodyStreamWithConnectionClose(t *testing.T) { part1 := "" for i := 0; i < 8192; i++ { part1 += "a" } ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) engine.POST("/", func(ctx context.Context, c *app.RequestContext) { c.String(consts.StatusOK, part1) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) client, _ := NewClient(WithResponseBodyStream(true)) // first req req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req) protocol.ReleaseResponse(resp) }() req.SetConnectionClose() req.SetMethod(consts.MethodPost) req.SetRequestURI(fullURL(ln, "")) err := client.Do(context.Background(), req, resp) if err != nil { t.Fatalf("client Do error=%v", err.Error()) } assert.DeepEqual(t, part1, string(resp.Body())) // second req req1, resp1 := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(req1) protocol.ReleaseResponse(resp1) }() req1.SetConnectionClose() req1.SetMethod(consts.MethodPost) req1.SetRequestURI(fullURL(ln, "")) err = client.Do(context.Background(), req1, resp1) if err != nil { t.Fatalf("client Do error=%v", err.Error()) } assert.DeepEqual(t, part1, string(resp1.Body())) } type mockDialer struct { network.Dialer customDialerFunc func(network, address string, timeout time.Duration, tlsConfig *tls.Config) network string address string timeout time.Duration } func (m *mockDialer) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) 
(conn network.Conn, err error) { if m.customDialerFunc != nil { m.customDialerFunc(network, address, timeout, tlsConfig) } return m.Dialer.DialConnection(m.network, m.address, m.timeout, tlsConfig) } func TestClientRetry(t *testing.T) { client, err := NewClient( // Default dial function performs different in different os. So unit the performance of dial function. WithDialFunc(func(addr string) (network.Conn, error) { return nil, fmt.Errorf("dial tcp %s: i/o timeout", addr) }), WithRetryConfig( retry.WithMaxAttemptTimes(3), retry.WithInitDelay(100*time.Millisecond), retry.WithMaxDelay(10*time.Second), retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy)), ), ) client.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { return err != nil }) if err != nil { t.Fatal(err) return } startTime := time.Now().UnixNano() _, resp, err := client.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") if err != nil { // first delay 100+200ms , second delay 100+400ms if time.Duration(time.Now().UnixNano()-startTime) > 800*time.Millisecond && time.Duration(time.Now().UnixNano()-startTime) < 2*time.Second { t.Logf("Retry triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) } else { t.Fatal(err) } } client2, err := NewClient( WithDialFunc(func(addr string) (network.Conn, error) { return nil, fmt.Errorf("dial tcp %s: i/o timeout", addr) }), WithRetryConfig( retry.WithMaxAttemptTimes(2), retry.WithInitDelay(500*time.Millisecond), retry.WithMaxJitter(1*time.Second), retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, 
retry.BackOffDelayPolicy)), ), ) if err != nil { t.Fatal(err) return } client2.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { return err != nil }) startTime = time.Now().UnixNano() _, resp, err = client2.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") if err != nil { // delay max{500ms+rand([0,1))s,100ms}. Because if the MaxDelay is not set, we will use the default MaxDelay of 100ms if time.Duration(time.Now().UnixNano()-startTime) > 100*time.Millisecond && time.Duration(time.Now().UnixNano()-startTime) < 1100*time.Millisecond { t.Logf("Retry triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) } else { t.Fatal(err) } } client3, err := NewClient( WithDialFunc(func(addr string) (network.Conn, error) { return nil, fmt.Errorf("dial tcp %s: i/o timeout", addr) }), WithRetryConfig( retry.WithMaxAttemptTimes(2), retry.WithInitDelay(100*time.Millisecond), retry.WithMaxDelay(5*time.Second), retry.WithMaxJitter(1*time.Second), retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy, retry.RandomDelayPolicy)), ), ) if err != nil { t.Fatal(err) return } client3.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { return err != nil }) startTime = time.Now().UnixNano() _, resp, err = client3.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") if err != nil { // delay 100ms+200ms+rand([0,1))s if time.Duration(time.Now().UnixNano()-startTime) > 300*time.Millisecond && time.Duration(time.Now().UnixNano()-startTime) < 2300*time.Millisecond { t.Logf("Retry triggered : 
delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) } else { t.Fatal(err) } } client4, err := NewClient( WithDialFunc(func(addr string) (network.Conn, error) { return nil, fmt.Errorf("dial tcp %s: i/o timeout", addr) }), WithRetryConfig( retry.WithMaxAttemptTimes(2), retry.WithInitDelay(1*time.Second), retry.WithMaxDelay(10*time.Second), retry.WithMaxJitter(5*time.Second), retry.WithDelayPolicy(retry.CombineDelay(retry.FixedDelayPolicy, retry.BackOffDelayPolicy, retry.RandomDelayPolicy)), ), ) if err != nil { t.Fatal(err) return } /* If the retryIfFunc is not set , idempotent logic is used by default */ //client4.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { // return err != nil //}) startTime = time.Now().UnixNano() _, resp, err = client4.Get(context.Background(), nil, "http://127.0.0.1:1234/ping") if err != nil { if time.Duration(time.Now().UnixNano()-startTime) > 1*time.Second && time.Duration(time.Now().UnixNano()-startTime) < 9*time.Second { t.Logf("Retry triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) } else if time.Duration(time.Now().UnixNano()-startTime) < 1*time.Second { // Compatible without triggering retry t.Logf("Retry not triggered : delay=%dms\tresp=%v\terr=%v\n", time.Duration(time.Now().UnixNano()-startTime)/(1*time.Millisecond), string(resp), fmt.Sprintln(err)) } else { t.Fatal(err) } return } } func TestClientHostClientConfigHookError(t *testing.T) { client, _ := NewClient(WithHostClientConfigHook(func(hc interface{}) error { hct, ok := 
hc.(*http1.HostClient) assert.True(t, ok) assert.DeepEqual(t, "foo.bar:80", hct.Addr) return errors.New("hook return") })) req := protocol.AcquireRequest() req.SetMethod(consts.MethodGet) req.SetRequestURI("http://foo.bar/") resp := protocol.AcquireResponse() err := client.do(context.TODO(), req, resp) assert.DeepEqual(t, "hook return", err.Error()) } func TestClientHostClientConfigHook(t *testing.T) { client, _ := NewClient(WithHostClientConfigHook(func(hc interface{}) error { hct, ok := hc.(*http1.HostClient) assert.True(t, ok) assert.DeepEqual(t, "foo.bar:80", hct.Addr) hct.Addr = "FOO.BAR:443" return nil })) req := protocol.AcquireRequest() req.SetMethod(consts.MethodGet) req.SetRequestURI("http://foo.bar/") resp := protocol.AcquireResponse() client.do(context.Background(), req, resp) client.mLock.Lock() hc := client.m["foo.bar"] client.mLock.Unlock() hcr, ok := hc.(*http1.HostClient) assert.True(t, ok) assert.DeepEqual(t, "FOO.BAR:443", hcr.Addr) } func TestClientDialerName(t *testing.T) { client, _ := NewClient() dName, err := client.GetDialerName() if err != nil { t.Fatalf("unexpected error: %v", err) } // Depending on the operating system, // the default dialer has a different network library, either "netpoll" or "standard" if !(dName == "netpoll" || dName == "standard") { t.Errorf("expected 'netpoll', but get %s", dName) } client, _ = NewClient(WithDialer(&mockDialer{})) dName, err = client.GetDialerName() if err != nil { t.Fatalf("unexpected error: %v", err) } if dName != "client" { t.Errorf("expected 'standard', but get %s", dName) } client, _ = NewClient(WithDialer(standard.NewDialer())) dName, err = client.GetDialerName() if err != nil { t.Fatalf("unexpected error: %v", err) } if dName != "standard" { t.Errorf("expected 'standard', but get %s", dName) } client, _ = NewClient(WithDialer(&mockDialer{})) dName, err = client.GetDialerName() if err != nil { t.Fatalf("unexpected error: %v", err) } if dName != "client" { t.Errorf("expected 'client', but get 
%s", dName) } client.options.Dialer = nil dName, err = client.GetDialerName() if err == nil { t.Errorf("expected an err for abnormal process") } if dName != "" { t.Errorf("expected 'empty string', but get %s", dName) } } func TestClientDoWithDialFunc(t *testing.T) { ch := make(chan error, 1) uri := "/foo/bar/baz" body := "request body" opt, ln := newTestOptions(t) defer ln.Close() engine := route.NewEngine(opt) engine.POST("/foo/bar/baz", func(c context.Context, ctx *app.RequestContext) { if string(ctx.Request.Header.Method()) != consts.MethodPost { ch <- fmt.Errorf("unexpected request method: %q. Expecting %q", ctx.Request.Header.Method(), consts.MethodPost) return } reqURI := ctx.Request.RequestURI() if string(reqURI) != uri { ch <- fmt.Errorf("unexpected request uri: %q. Expecting %q", reqURI, uri) return } cl := ctx.Request.Header.ContentLength() if cl != len(body) { ch <- fmt.Errorf("unexpected content-length %d. Expecting %d", cl, len(body)) return } reqBody := ctx.Request.Body() if string(reqBody) != body { ch <- fmt.Errorf("unexpected request body: %q. 
Expecting %q", reqBody, body) return } ch <- nil }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithDialFunc(func(addr string) (network.Conn, error) { return dialer.DialConnection(opt.Network, opt.Addr, time.Second, nil) })) var req protocol.Request req.Header.SetMethod(consts.MethodPost) req.SetRequestURI(uri) req.SetHost("xxx.com") req.SetBodyString(body) var resp protocol.Response err := c.Do(context.Background(), &req, &resp) if err != nil { t.Fatalf("error when doing request: %s", err) } select { case err = <-ch: if err != nil { t.Fatalf("err = %s", err.Error()) } case <-time.After(5 * time.Second): t.Fatalf("timeout") } } func TestClientState(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) opt.Addr = ln.Addr().String() var wg sync.WaitGroup wg.Add(2) state := int32(0) client, _ := NewClient( WithMaxIdleConnDuration(75*time.Millisecond), WithConnStateObserve(func(hcs config.HostClientState) { switch atomic.LoadInt32(&state) { case int32(0): assert.DeepEqual(t, 1, hcs.ConnPoolState().TotalConnNum) assert.DeepEqual(t, 1, hcs.ConnPoolState().PoolConnNum) assert.DeepEqual(t, 0, hcs.ConnPoolState().MaxConns) assert.DeepEqual(t, opt.Addr, hcs.ConnPoolState().Addr) atomic.StoreInt32(&state, int32(1)) wg.Done() case int32(1): assert.DeepEqual(t, 0, hcs.ConnPoolState().TotalConnNum) assert.DeepEqual(t, 0, hcs.ConnPoolState().PoolConnNum) assert.DeepEqual(t, 0, hcs.ConnPoolState().MaxConns) assert.DeepEqual(t, opt.Addr, hcs.ConnPoolState().Addr) atomic.StoreInt32(&state, int32(2)) wg.Done() } }, 50*time.Millisecond)) client.Get(context.Background(), nil, "http://"+opt.Addr) wg.Wait() assert.DeepEqual(t, int32(2), atomic.LoadInt32(&state)) } func TestClientRetryErr(t *testing.T) { t.Run("200", func(t *testing.T) { ln := 
testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) var l sync.Mutex retryNum := 0 engine.GET("/ping", func(c context.Context, ctx *app.RequestContext) { l.Lock() defer l.Unlock() retryNum += 1 ctx.SetStatusCode(200) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3))) _, _, err := c.Get(context.Background(), nil, fullURL(ln, "/ping")) assert.Nil(t, err) l.Lock() assert.DeepEqual(t, 1, retryNum) l.Unlock() }) t.Run("502", func(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) var l sync.Mutex retryNum := 0 engine.GET("/ping", func(c context.Context, ctx *app.RequestContext) { l.Lock() defer l.Unlock() retryNum += 1 ctx.SetStatusCode(502) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3))) c.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { return resp.StatusCode() == 502 }) _, _, err := c.Get(context.Background(), nil, fullURL(ln, "/ping")) assert.Nil(t, err) l.Lock() assert.DeepEqual(t, 3, retryNum) l.Unlock() }) } type mockHostClient struct { shouldRemove bool closed bool } func (m *mockHostClient) Do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error { return nil } func (m *mockHostClient) SetDynamicConfig(dc *client.DynamicConfig) { } func (m *mockHostClient) CloseIdleConnections() { } func (m *mockHostClient) ShouldRemove() bool { return m.shouldRemove } func (m *mockHostClient) ConnectionCount() int { return 0 } func (m *mockHostClient) Close() error { m.closed = true return nil } func TestCleanHostClients(t *testing.T) { tests := []struct { name string isTLS bool initMap map[string]*mockHostClient expectedKeys []string 
shouldClose bool expectedRes bool }{ { name: "Remove item from c.m", isTLS: false, initMap: map[string]*mockHostClient{ "google.com": {shouldRemove: true}, }, expectedKeys: []string{}, shouldClose: true, expectedRes: true, }, { name: "Remove item from c.ms", isTLS: true, initMap: map[string]*mockHostClient{ "google.com": {shouldRemove: true}, }, expectedKeys: []string{}, shouldClose: true, expectedRes: true, }, { name: "Do not remove non-removable client", isTLS: false, initMap: map[string]*mockHostClient{ "google.com": {shouldRemove: false}, }, expectedKeys: []string{"google.com"}, shouldClose: false, expectedRes: false, }, { name: "Empty map", isTLS: false, initMap: map[string]*mockHostClient{}, expectedKeys: []string{}, shouldClose: false, expectedRes: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { cli := &Client{ mLock: sync.Mutex{}, m: map[string]client.HostClient{}, ms: map[string]client.HostClient{}, } if tt.isTLS { for k, v := range tt.initMap { cli.ms[k] = v } } else { for k, v := range tt.initMap { cli.m[k] = v } } result := cli.cleanHostClients(tt.isTLS) var resultMap map[string]client.HostClient if tt.isTLS { resultMap = cli.ms } else { resultMap = cli.m } var keys []string for k := range resultMap { keys = append(keys, k) } assert.Assert(t, len(tt.expectedKeys) == len(keys)) for _, v := range tt.initMap { if v.shouldRemove { assert.Assert(t, v.closed, "Expected client to be closed") } else { assert.Assert(t, !v.closed, "Client should not be closed") } } assert.Assert(t, tt.expectedRes == result) }) } } ================================================ FILE: pkg/app/client/client_unix_test.go ================================================ // Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris package client import ( "crypto/tls" "math/rand" "time" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/netpoll" "github.com/cloudwego/hertz/pkg/network/standard" ) func newMockDialerWithCustomFunc(network, address string, timeout time.Duration, f func(network, address string, timeout time.Duration, tlsConfig *tls.Config)) network.Dialer { dialer := standard.NewDialer() if rand.Intn(2) == 0 { dialer = netpoll.NewDialer() } return &mockDialer{ Dialer: dialer, customDialerFunc: f, network: network, address: address, timeout: timeout, } } ================================================ FILE: pkg/app/client/client_windows_test.go ================================================ // Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// //go:build windows package client import ( "crypto/tls" "time" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/standard" ) func newMockDialerWithCustomFunc(network, address string, timeout time.Duration, f func(network, address string, timeout time.Duration, tlsConfig *tls.Config)) network.Dialer { dialer := standard.NewDialer() return &mockDialer{ Dialer: dialer, customDialerFunc: f, network: network, address: address, timeout: timeout, } } ================================================ FILE: pkg/app/client/discovery/discovery.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package discovery import ( "context" "net" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/utils" ) type TargetInfo struct { Host string Tags map[string]string } type Resolver interface { // Target should return a description for the given target that is suitable for being a key for cache. Target(ctx context.Context, target *TargetInfo) string // Resolve returns a list of instances for the given description of a target. Resolve(ctx context.Context, desc string) (Result, error) // Name returns the name of the resolver. Name() string } // SynthesizedResolver synthesizes a Resolver using a resolve function. 
type SynthesizedResolver struct { TargetFunc func(ctx context.Context, target *TargetInfo) string ResolveFunc func(ctx context.Context, key string) (Result, error) NameFunc func() string } func (sr SynthesizedResolver) Target(ctx context.Context, target *TargetInfo) string { if sr.TargetFunc == nil { return "" } return sr.TargetFunc(ctx, target) } func (sr SynthesizedResolver) Resolve(ctx context.Context, key string) (Result, error) { return sr.ResolveFunc(ctx, key) } // Name implements the Resolver interface func (sr SynthesizedResolver) Name() string { if sr.NameFunc == nil { return "" } return sr.NameFunc() } // Instance contains information of an instance from the target service. type Instance interface { Address() net.Addr Weight() int Tag(key string) (value string, exist bool) } type instance struct { addr net.Addr weight int tags map[string]string } func (i *instance) Address() net.Addr { return i.addr } func (i *instance) Weight() int { if i.weight > 0 { return i.weight } return registry.DefaultWeight } func (i *instance) Tag(key string) (value string, exist bool) { value, exist = i.tags[key] return } // NewInstance creates an Instance using the given network, address and tags func NewInstance(network, address string, weight int, tags map[string]string) Instance { return &instance{ addr: utils.NewNetAddr(network, address), weight: weight, tags: tags, } } // Result contains the result of service discovery process. // the instance list can/should be cached and CacheKey can be used to map the instance list in cache. type Result struct { CacheKey string Instances []Instance } ================================================ FILE: pkg/app/client/discovery/discovery_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package discovery import ( "context" "testing" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestInstance(t *testing.T) { network := "192.168.1.1" address := "/hello" weight := 1 instance := NewInstance(network, address, weight, nil) assert.DeepEqual(t, network, instance.Address().Network()) assert.DeepEqual(t, address, instance.Address().String()) assert.DeepEqual(t, weight, instance.Weight()) val, ok := instance.Tag("name") assert.DeepEqual(t, "", val) assert.False(t, ok) instance2 := NewInstance("", "", 0, nil) assert.DeepEqual(t, registry.DefaultWeight, instance2.Weight()) } func TestSynthesizedResolver(t *testing.T) { targetFunc := func(ctx context.Context, target *TargetInfo) string { return "hello" } resolveFunc := func(ctx context.Context, key string) (Result, error) { return Result{CacheKey: "name"}, nil } nameFunc := func() string { return "raymonder" } resolver := SynthesizedResolver{ TargetFunc: targetFunc, ResolveFunc: resolveFunc, NameFunc: nameFunc, } assert.DeepEqual(t, "hello", resolver.Target(context.Background(), &TargetInfo{})) res, err := resolver.Resolve(context.Background(), "") assert.DeepEqual(t, "name", res.CacheKey) assert.Nil(t, err) assert.DeepEqual(t, "raymonder", resolver.Name()) resolver2 := SynthesizedResolver{ TargetFunc: nil, ResolveFunc: nil, NameFunc: nil, } assert.DeepEqual(t, "", resolver2.Target(context.Background(), &TargetInfo{})) assert.DeepEqual(t, "", resolver2.Name()) } ================================================ FILE: 
pkg/app/client/loadbalance/lbcache.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package loadbalance import ( "context" "fmt" "sync" "sync/atomic" "time" "github.com/cloudwego/hertz/pkg/app/client/discovery" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/protocol" "golang.org/x/sync/singleflight" ) type cacheResult struct { res atomic.Value // newest and previous discovery result expire int32 // 0 = normal, 1 = expire and collect next ticker serviceName string // service psm } var ( balancerFactories sync.Map // key: resolver name + load-balancer name balancerFactoriesSfg singleflight.Group ) func cacheKey(resolver, balancer string, opts Options) string { return fmt.Sprintf("%s|%s|{%s %s}", resolver, balancer, opts.RefreshInterval, opts.ExpireInterval) } type BalancerFactory struct { opts Options cache sync.Map // key -> LoadBalancer resolver discovery.Resolver balancer Loadbalancer sfg singleflight.Group } type Config struct { Resolver discovery.Resolver Balancer Loadbalancer LbOpts Options } // NewBalancerFactory get or create a balancer with given target. // If it has the same key(resolver.Target(target)), we will cache and reuse the Balance. 
func NewBalancerFactory(config Config) *BalancerFactory { config.LbOpts.Check() uniqueKey := cacheKey(config.Resolver.Name(), config.Balancer.Name(), config.LbOpts) val, ok := balancerFactories.Load(uniqueKey) if ok { return val.(*BalancerFactory) } val, _, _ = balancerFactoriesSfg.Do(uniqueKey, func() (interface{}, error) { b := &BalancerFactory{ opts: config.LbOpts, resolver: config.Resolver, balancer: config.Balancer, } go b.watcher() go b.refresh() balancerFactories.Store(uniqueKey, b) return b, nil }) return val.(*BalancerFactory) } // watch expired balancer func (b *BalancerFactory) watcher() { for range time.Tick(b.opts.ExpireInterval) { b.cache.Range(func(key, value interface{}) bool { cache := value.(*cacheResult) if atomic.CompareAndSwapInt32(&cache.expire, 0, 1) { // 1. set expire flag // 2. wait next ticker for collect, maybe the balancer is used again // (avoid being immediate delete the balancer which had been created recently) } else { b.cache.Delete(key) b.balancer.Delete(key.(string)) } return true }) } } // cache key with resolver name prefix avoid conflict for balancer func renameResultCacheKey(res *discovery.Result, resolverName string) { res.CacheKey = resolverName + ":" + res.CacheKey } // refresh is used to update service discovery information periodically. 
func (b *BalancerFactory) refresh() { for range time.Tick(b.opts.RefreshInterval) { b.cache.Range(func(key, value interface{}) bool { res, err := b.resolver.Resolve(context.Background(), key.(string)) if err != nil { hlog.SystemLogger().Warnf("resolver refresh failed, key=%s error=%s", key, err.Error()) return true } renameResultCacheKey(&res, b.resolver.Name()) cache := value.(*cacheResult) cache.res.Store(res) atomic.StoreInt32(&cache.expire, 0) b.balancer.Rebalance(res) return true }) } } func (b *BalancerFactory) GetInstance(ctx context.Context, req *protocol.Request) (discovery.Instance, error) { cacheRes, err := b.getCacheResult(ctx, req) if err != nil { return nil, err } atomic.StoreInt32(&cacheRes.expire, 0) ins := b.balancer.Pick(cacheRes.res.Load().(discovery.Result)) if ins == nil { hlog.SystemLogger().Errorf("null instance. serviceName: %s, options: %v", string(req.Host()), req.Options()) return nil, errors.NewPublic("instance not found") } return ins, nil } func (b *BalancerFactory) getCacheResult(ctx context.Context, req *protocol.Request) (*cacheResult, error) { target := b.resolver.Target(ctx, &discovery.TargetInfo{Host: string(req.Host()), Tags: req.Options().Tags()}) cr, existed := b.cache.Load(target) if existed { return cr.(*cacheResult), nil } cr, err, _ := b.sfg.Do(target, func() (interface{}, error) { cache := &cacheResult{ serviceName: string(req.Host()), } res, err := b.resolver.Resolve(ctx, target) if err != nil { return cache, err } renameResultCacheKey(&res, b.resolver.Name()) cache.res.Store(res) atomic.StoreInt32(&cache.expire, 0) b.balancer.Rebalance(res) b.cache.Store(target, cache) return cache, nil }) if err != nil { return nil, err } return cr.(*cacheResult), nil } ================================================ FILE: pkg/app/client/loadbalance/lbcache_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use 
this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package loadbalance import ( "context" "fmt" "strconv" "sync/atomic" "testing" "time" "github.com/cloudwego/hertz/pkg/app/client/discovery" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" ) func TestBuilder(t *testing.T) { ins := discovery.NewInstance("tcp", "127.0.0.1:8888", 10, map[string]string{"a": "b"}) r := &discovery.SynthesizedResolver{ ResolveFunc: func(ctx context.Context, key string) (discovery.Result, error) { return discovery.Result{CacheKey: key, Instances: []discovery.Instance{ins}}, nil }, TargetFunc: func(ctx context.Context, target *discovery.TargetInfo) string { return "mockRoute" }, NameFunc: func() string { return t.Name() }, } lb := mockLoadbalancer{ rebalanceFunc: nil, deleteFunc: nil, pickFunc: func(res discovery.Result) discovery.Instance { assert.Assert(t, res.CacheKey == t.Name()+":mockRoute", res.CacheKey) assert.Assert(t, len(res.Instances) == 1) assert.Assert(t, len(res.Instances) == 1) assert.Assert(t, res.Instances[0].Address().String() == "127.0.0.1:8888") return res.Instances[0] }, nameFunc: func() string { return "Synthesized" }, } NewBalancerFactory(Config{ Balancer: lb, LbOpts: DefaultLbOpts, Resolver: r, }) b, ok := balancerFactories.Load(cacheKey(t.Name(), "Synthesized", DefaultLbOpts)) assert.Assert(t, ok) assert.Assert(t, b != nil) req := &protocol.Request{} req.SetHost("hertz.api.test") ins1, err := b.(*BalancerFactory).GetInstance(context.TODO(), req) assert.Assert(t, err == nil) assert.Assert(t, ins1.Address().String() == 
"127.0.0.1:8888") assert.Assert(t, ins1.Weight() == 10) value, exists := ins1.Tag("a") assert.Assert(t, value == "b") assert.Assert(t, exists == true) } func TestBalancerCache(t *testing.T) { count := 10 inss := make([]discovery.Instance, 0, count) for i := 0; i < count; i++ { inss = append(inss, discovery.NewInstance("tcp", fmt.Sprint(i), 10, nil)) } r := &discovery.SynthesizedResolver{ TargetFunc: func(ctx context.Context, target *discovery.TargetInfo) string { return target.Host }, ResolveFunc: func(ctx context.Context, key string) (discovery.Result, error) { return discovery.Result{CacheKey: "svc", Instances: inss}, nil }, NameFunc: func() string { return t.Name() }, } lb := NewWeightedBalancer() for i := 0; i < count; i++ { blf := NewBalancerFactory(Config{ Balancer: lb, LbOpts: Options{}, Resolver: r, }) req := &protocol.Request{} req.SetHost("svc") for a := 0; a < count; a++ { addr, err := blf.GetInstance(context.TODO(), req) assert.Assert(t, err == nil, err) t.Logf("count: %d addr: %s\n", i, addr.Address().String()) } } } func TestBalancerRefresh(t *testing.T) { var ins atomic.Value ins.Store(discovery.NewInstance("tcp", "127.0.0.1:8888", 10, nil)) r := &discovery.SynthesizedResolver{ TargetFunc: func(ctx context.Context, target *discovery.TargetInfo) string { return target.Host }, ResolveFunc: func(ctx context.Context, key string) (discovery.Result, error) { return discovery.Result{CacheKey: "svc1", Instances: []discovery.Instance{ins.Load().(discovery.Instance)}}, nil }, NameFunc: func() string { return t.Name() }, } opts := DefaultLbOpts opts.RefreshInterval = 30 * time.Millisecond blf := NewBalancerFactory(Config{ Balancer: NewWeightedBalancer(), LbOpts: opts, Resolver: r, }) req := &protocol.Request{} req.SetHost("svc1") addr, err := blf.GetInstance(context.Background(), req) assert.Assert(t, err == nil, err) assert.Assert(t, addr.Address().String() == "127.0.0.1:8888") ins.Store(discovery.NewInstance("tcp", "127.0.0.1:8889", 10, nil)) addr, err = 
blf.GetInstance(context.Background(), req) assert.Assert(t, err == nil, err) assert.Assert(t, addr.Address().String() == "127.0.0.1:8888") time.Sleep(2 * opts.RefreshInterval) addr, err = blf.GetInstance(context.Background(), req) assert.Assert(t, err == nil, err) assert.Assert(t, addr.Address().String() == "127.0.0.1:8889") } func TestBalancerExpires(t *testing.T) { n := int32(1000) r := &discovery.SynthesizedResolver{ TargetFunc: func(ctx context.Context, target *discovery.TargetInfo) string { return target.Host }, ResolveFunc: func(ctx context.Context, key string) (discovery.Result, error) { ins := discovery.NewInstance("tcp", "127.0.0.1:"+strconv.Itoa(int(atomic.AddInt32(&n, 1))), 10, nil) return discovery.Result{CacheKey: "svc1", Instances: []discovery.Instance{ins}}, nil }, NameFunc: func() string { return t.Name() }, } opts := DefaultLbOpts opts.ExpireInterval = 30 * time.Millisecond blf := NewBalancerFactory(Config{ Balancer: NewWeightedBalancer(), LbOpts: opts, Resolver: r, }) req := &protocol.Request{} req.SetHost("svc1") addr1, err := blf.GetInstance(context.Background(), req) assert.Assert(t, err == nil, err) addr2, err := blf.GetInstance(context.Background(), req) assert.Assert(t, err == nil, err) assert.Assert(t, addr1.Address().String() == addr2.Address().String()) time.Sleep(3 * opts.ExpireInterval) addr3, err := blf.GetInstance(context.Background(), req) assert.Assert(t, err == nil, err) assert.Assert(t, addr3.Address().String() != addr2.Address().String()) } func TestCacheKey(t *testing.T) { uniqueKey := cacheKey("hello", "world", Options{RefreshInterval: 15 * time.Second, ExpireInterval: 5 * time.Minute}) assert.Assert(t, uniqueKey == "hello|world|{15s 5m0s}") } type mockLoadbalancer struct { rebalanceFunc func(ch discovery.Result) deleteFunc func(key string) pickFunc func(discovery.Result) discovery.Instance nameFunc func() string } // Rebalance implements the Loadbalancer interface. 
func (m mockLoadbalancer) Rebalance(ch discovery.Result) {
	// Forward to rebalanceFunc when the test supplied one.
	if m.rebalanceFunc != nil {
		m.rebalanceFunc(ch)
	}
}

// Delete implements the Loadbalancer interface; it forwards to deleteFunc when set.
func (m mockLoadbalancer) Delete(ch string) {
	if m.deleteFunc != nil {
		m.deleteFunc(ch)
	}
}

// Name implements the Loadbalancer interface; it returns nameFunc() when set, otherwise "".
func (m mockLoadbalancer) Name() string {
	if m.nameFunc != nil {
		return m.nameFunc()
	}
	return ""
}

// Pick implements the Loadbalancer interface; it returns pickFunc(d) when set, otherwise nil.
func (m mockLoadbalancer) Pick(d discovery.Result) discovery.Instance {
	if m.pickFunc != nil {
		return m.pickFunc(d)
	}
	return nil
}

================================================
FILE: pkg/app/client/loadbalance/loadbalance.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package loadbalance

import (
	"time"

	"github.com/cloudwego/hertz/pkg/app/client/discovery"
)

// Loadbalancer picks an instance for a given service discovery result.
type Loadbalancer interface {
	// Pick selects an instance from the discovery result. It may return nil
	// when no instance can be chosen.
	Pick(discovery.Result) discovery.Instance

	// Rebalance refreshes whatever the balancer caches for the given result
	// (the weighted-random balancer keys it by Result.CacheKey).
	Rebalance(discovery.Result)

	// Delete drops the state cached under the given key, e.g. when the
	// corresponding result has expired.
	Delete(string)

	// Name returns the name of the Loadbalancer.
	Name() string
}

// Default intervals; Options.Check substitutes them for non-positive values.
const (
	DefaultRefreshInterval = 5 * time.Second
	DefaultExpireInterval  = 15 * time.Second
)

// DefaultLbOpts is an Options value with both intervals set to their defaults.
var DefaultLbOpts = Options{
	RefreshInterval: DefaultRefreshInterval,
	ExpireInterval:  DefaultExpireInterval,
}

// Options configures the load-balancer cache.
type Options struct {
	// RefreshInterval is how often cached discovery results are re-resolved.
	RefreshInterval time.Duration

	// ExpireInterval is the period of the idle check: balancers that are not
	// used are removed to save resources.
	ExpireInterval time.Duration
}

// Check replaces non-positive intervals with their default values.
func (v *Options) Check() {
	if v.RefreshInterval <= 0 {
		v.RefreshInterval = DefaultRefreshInterval
	}
	if v.ExpireInterval <= 0 {
		v.ExpireInterval = DefaultExpireInterval
	}
}

================================================
FILE: pkg/app/client/loadbalance/weight_random.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package loadbalance

import (
	"sync"

	"github.com/bytedance/gopkg/lang/fastrand"
	"github.com/cloudwego/hertz/pkg/app/client/discovery"
	"github.com/cloudwego/hertz/pkg/common/hlog"
	"golang.org/x/sync/singleflight"
)

// weightedBalancer is the weighted-random Loadbalancer.
type weightedBalancer struct {
	cachedWeightInfo sync.Map           // Result.CacheKey -> *weightInfo
	sfg              singleflight.Group // merges concurrent lazy computations in Pick
}

// weightInfo is the selection table for one discovery result: entries[i] is
// the weight of instances[i] (only positive weights are kept) and weightSum is
// the total of those weights.
type weightInfo struct {
	instances []discovery.Instance
	entries   []int
	weightSum int
}

// NewWeightedBalancer creates a loadbalancer using weighted-random algorithm.
func NewWeightedBalancer() Loadbalancer { lb := &weightedBalancer{} return lb } func (wb *weightedBalancer) calcWeightInfo(e discovery.Result) *weightInfo { w := &weightInfo{ instances: make([]discovery.Instance, len(e.Instances)), weightSum: 0, entries: make([]int, len(e.Instances)), } var cnt int for idx := range e.Instances { weight := e.Instances[idx].Weight() if weight > 0 { w.instances[cnt] = e.Instances[idx] w.entries[cnt] = weight w.weightSum += weight cnt++ } else { hlog.SystemLogger().Warnf("Invalid weight=%d on instance address=%s", weight, e.Instances[idx].Address()) } } w.instances = w.instances[:cnt] return w } // Pick implements the Loadbalancer interface. func (wb *weightedBalancer) Pick(e discovery.Result) discovery.Instance { wi, ok := wb.cachedWeightInfo.Load(e.CacheKey) if !ok { wi, _, _ = wb.sfg.Do(e.CacheKey, func() (interface{}, error) { return wb.calcWeightInfo(e), nil }) wb.cachedWeightInfo.Store(e.CacheKey, wi) } w := wi.(*weightInfo) if w.weightSum <= 0 { return nil } weight := fastrand.Intn(w.weightSum) for i := 0; i < len(w.instances); i++ { weight -= w.entries[i] if weight < 0 { return w.instances[i] } } return nil } // Rebalance implements the Loadbalancer interface. func (wb *weightedBalancer) Rebalance(e discovery.Result) { wb.cachedWeightInfo.Store(e.CacheKey, wb.calcWeightInfo(e)) } // Delete implements the Loadbalancer interface. func (wb *weightedBalancer) Delete(cacheKey string) { wb.cachedWeightInfo.Delete(cacheKey) } func (wb *weightedBalancer) Name() string { return "weight_random" } ================================================ FILE: pkg/app/client/loadbalance/weight_random_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package loadbalance

import (
	"math"
	"testing"

	"github.com/cloudwego/hertz/pkg/app/client/discovery"
	"github.com/cloudwego/hertz/pkg/common/test/assert"
)

// TestWeightedBalancer covers empty results, a single instance, the weight
// distribution over many picks, and exclusion of non-positive weights.
func TestWeightedBalancer(t *testing.T) {
	balancer := NewWeightedBalancer()
	// nil
	ins := balancer.Pick(discovery.Result{})
	assert.DeepEqual(t, ins, nil)

	// empty instance
	e := discovery.Result{
		Instances: make([]discovery.Instance, 0),
		CacheKey:  "a",
	}
	balancer.Rebalance(e)
	ins = balancer.Pick(e)
	assert.DeepEqual(t, ins, nil)

	// one instance
	insList := []discovery.Instance{
		discovery.NewInstance("tcp", "127.0.0.1:8888", 20, nil),
	}
	e = discovery.Result{
		Instances: insList,
		CacheKey:  "b",
	}
	balancer.Rebalance(e)
	for i := 0; i < 100; i++ {
		ins = balancer.Pick(e)
		assert.DeepEqual(t, ins.Weight(), 20)
	}

	// multi instances, weightSum > 0
	insList = []discovery.Instance{
		discovery.NewInstance("tcp", "127.0.0.1:8881", 100, nil),
		discovery.NewInstance("tcp", "127.0.0.1:8882", 200, nil),
		discovery.NewInstance("tcp", "127.0.0.1:8883", 300, nil),
		discovery.NewInstance("tcp", "127.0.0.1:8884", 400, nil),
		discovery.NewInstance("tcp", "127.0.0.1:8885", 500, nil),
	}
	var weightSum int
	for _, ins := range insList {
		weight := ins.Weight()
		weightSum += weight
	}
	n := 1000000
	// pickedStat counts picks per weight; weights are unique in insList.
	pickedStat := map[int]int{}
	e = discovery.Result{
		Instances: insList,
		CacheKey:  "c",
	}
	balancer.Rebalance(e)
	for i := 0; i < n; i++ {
		ins = balancer.Pick(e)
		weight := ins.Weight()
		if pickedCnt, ok := pickedStat[weight]; ok {
			pickedStat[weight] = pickedCnt + 1
		} else {
			pickedStat[weight] = 1
		}
	}
	// Each instance should be picked within 5% of its expected share.
	for _, ins := range insList {
		weight := ins.Weight()
		expect := float64(weight) / float64(weightSum) * float64(n)
		actual := float64(pickedStat[weight])
		delta := math.Abs(expect - actual)
		assert.DeepEqual(t, true, delta/expect < 0.05)
	}

	// have instances that weight < 0
	insList = []discovery.Instance{
		discovery.NewInstance("tcp", "127.0.0.1:8881", 10, nil),
		discovery.NewInstance("tcp", "127.0.0.1:8882", -10, nil),
	}
	e = discovery.Result{
		Instances: insList,
		CacheKey:  "d",
	}
	balancer.Rebalance(e)
	for i := 0; i < 1000; i++ {
		ins = balancer.Pick(e)
		assert.DeepEqual(t, 10, ins.Weight())
	}
}

================================================
FILE: pkg/app/client/middleware.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package client

import (
	"context"

	"github.com/cloudwego/hertz/pkg/protocol"
)

// Endpoint represents one call to a remote endpoint.
type Endpoint func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error)

// Middleware wraps an Endpoint and returns the resulting Endpoint.
type Middleware func(Endpoint) Endpoint

// chain composes middlewares into one middleware; mws[0] is the outermost
// wrapper, so for chain(a, b)(next) the call order is a, b, next.
func chain(mws ...Middleware) Middleware { return func(next Endpoint) Endpoint { for i := len(mws) - 1; i >= 0; i-- { next = mws[i](next) } return next } } ================================================ FILE: pkg/app/client/middleware_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package client import ( "context" "testing" "github.com/cloudwego/hertz/pkg/protocol" ) var ( biz = "Biz" beforeMW0 = "BeforeMiddleware0" afterMW0 = "AfterMiddleware0" beforeMW1 = "BeforeMiddleware1" afterMW1 = "AfterMiddleware1" ) func invoke(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { req.BodyBuffer().WriteString(biz) return nil } func mockMW0(next Endpoint) Endpoint { return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { req.BodyBuffer().WriteString(beforeMW0) err = next(ctx, req, resp) if err != nil { return err } req.BodyBuffer().WriteString(afterMW0) return nil } } func mockMW1(next Endpoint) Endpoint { return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { req.BodyBuffer().WriteString(beforeMW1) err = next(ctx, req, resp) if err != nil { return err } req.BodyBuffer().WriteString(afterMW1) return nil } } func TestChain(t *testing.T) { mws := chain(mockMW0, mockMW1) req := protocol.AcquireRequest() mws(invoke)(context.Background(), req, nil) final := beforeMW0 + beforeMW1 + biz + 
afterMW1 + afterMW0 if req.BodyBuffer().String() != final { t.Errorf("unexpected %#v, expected %#v", req.BodyBuffer().String(), final) } } ================================================ FILE: pkg/app/client/option.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package client import ( "crypto/tls" "time" "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/dialer" "github.com/cloudwego/hertz/pkg/network/standard" "github.com/cloudwego/hertz/pkg/protocol/consts" ) // WithDialTimeout sets dial timeout. func WithDialTimeout(dialTimeout time.Duration) config.ClientOption { return config.ClientOption{F: func(o *config.ClientOptions) { o.DialTimeout = dialTimeout }} } // WithMaxConnsPerHost sets maximum number of connections per host which may be established. func WithMaxConnsPerHost(mc int) config.ClientOption { return config.ClientOption{F: func(o *config.ClientOptions) { o.MaxConnsPerHost = mc }} } // WithMaxIdleConnDuration sets max idle connection duration, idle keep-alive connections are closed after this duration. 
func WithMaxIdleConnDuration(t time.Duration) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.MaxIdleConnDuration = t
	}}
}

// WithMaxConnDuration sets max connection duration, keep-alive connections are closed after this duration.
func WithMaxConnDuration(t time.Duration) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.MaxConnDuration = t
	}}
}

// WithMaxConnWaitTimeout sets maximum duration for waiting for a free connection.
func WithMaxConnWaitTimeout(t time.Duration) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.MaxConnWaitTimeout = t
	}}
}

// WithKeepAlive determines whether to use keep-alive connections.
func WithKeepAlive(b bool) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.KeepAlive = b
	}}
}

// WithClientReadTimeout sets maximum duration for full response reading (including body).
func WithClientReadTimeout(t time.Duration) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.ReadTimeout = t
	}}
}

// WithTLSConfig sets tlsConfig to create a tls connection.
// It also replaces the client's Dialer with standard.NewDialer(), so a dialer
// set by an option applied before this one is discarded.
func WithTLSConfig(cfg *tls.Config) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.TLSConfig = cfg
		o.Dialer = standard.NewDialer()
	}}
}

// WithDialer sets the specific dialer.
func WithDialer(d network.Dialer) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.Dialer = d
	}}
}

// WithResponseBodyStream determines whether the response body is read as a stream.
func WithResponseBodyStream(b bool) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.ResponseBodyStream = b
	}}
}

// WithHostClientConfigHook is used to set the function hook for re-configure the host client.
func WithHostClientConfigHook(h func(hc interface{}) error) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.HostClientConfigHook = h
	}}
}

// WithDisableHeaderNamesNormalizing is used to set whether disable header names normalizing.
func WithDisableHeaderNamesNormalizing(disable bool) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.DisableHeaderNamesNormalizing = disable
	}}
}

// WithName sets client name which used in User-Agent Header.
func WithName(name string) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.Name = name
	}}
}

// WithNoDefaultUserAgentHeader sets whether no default User-Agent header.
func WithNoDefaultUserAgentHeader(isNoDefaultUserAgentHeader bool) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.NoDefaultUserAgentHeader = isNoDefaultUserAgentHeader
	}}
}

// WithDisablePathNormalizing sets whether disable path normalizing.
func WithDisablePathNormalizing(isDisablePathNormalizing bool) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.DisablePathNormalizing = isDisablePathNormalizing
	}}
}

// WithRetryConfig sets the retry configuration. It starts from the defaults
// (consts.DefaultMaxRetryTimes attempts, 1ms initial delay, 100ms max delay,
// 20ms max jitter, DefaultDelayPolicy) and applies opts on top of them.
// Every client configured with the returned option gets its own copy of the
// resulting retry.Config.
func WithRetryConfig(opts ...retry.Option) config.ClientOption {
	retryCfg := retry.Config{
		MaxAttemptTimes: consts.DefaultMaxRetryTimes,
		Delay:           1 * time.Millisecond,
		MaxDelay:        100 * time.Millisecond,
		MaxJitter:       20 * time.Millisecond,
		DelayPolicy:     retry.CombineDelay(retry.DefaultDelayPolicy),
	}
	retryCfg.Apply(opts)
	return config.ClientOption{F: func(o *config.ClientOptions) {
		// Copy so that clients built from the same option value never share
		// (and cannot mutate) a single *retry.Config.
		cfg := retryCfg
		o.RetryConfig = &cfg
	}}
}

// WithWriteTimeout sets write timeout.
func WithWriteTimeout(t time.Duration) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.WriteTimeout = t
	}}
}

// WithConnStateObserve sets the connection state observation function.
// The first param is used to set hostclient state func.
// The optional second param sets the observation interval (default 5 seconds);
// only its first value is used.
// Warning: do not start a goroutine in HostClientStateFunc.
func WithConnStateObserve(hs config.HostClientStateFunc, interval ...time.Duration) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		o.HostClientStateObserve = hs
		if len(interval) > 0 {
			o.ObservationInterval = interval[0]
		}
	}}
}

// WithDialFunc sets a function used to dial connections.
// The function is wrapped around dialers[0] when given, otherwise around
// dialer.DefaultDialer().
// Note: WithDialFunc overwrites the client's Dialer, replacing a custom dialer
// set by an option applied before it.
func WithDialFunc(f network.DialFunc, dialers ...network.Dialer) config.ClientOption {
	return config.ClientOption{F: func(o *config.ClientOptions) {
		d := dialer.DefaultDialer()
		if len(dialers) != 0 {
			d = dialers[0]
		}
		o.Dialer = newCustomDialerWithDialFunc(d, f)
	}}
}

// customDialer embeds a network.Dialer and, when dialFunc is set, uses it for
// DialConnection instead of the embedded dialer; all other Dialer methods go
// to the embedded dialer.
type customDialer struct {
	network.Dialer
	dialFunc network.DialFunc
}

// DialConnection dials through dialFunc when it is set. In that case only
// address is passed on: network, timeout and tlsConfig are ignored.
func (m *customDialer) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) {
	if m.dialFunc != nil {
		return m.dialFunc(address)
	}
	return m.Dialer.DialConnection(network, address, timeout, tlsConfig)
}

// newCustomDialerWithDialFunc wraps dialer so that dialFunc is used for DialConnection.
func newCustomDialerWithDialFunc(dialer network.Dialer, dialFunc network.DialFunc) network.Dialer {
	return &customDialer{
		Dialer:   dialer,
		dialFunc: dialFunc,
	}
}

================================================
FILE: pkg/app/client/option_test.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
 * limitations under the License.
 */

package client

import (
	"testing"
	"time"

	"github.com/cloudwego/hertz/pkg/protocol/consts"

	"github.com/cloudwego/hertz/pkg/app/client/retry"
	"github.com/cloudwego/hertz/pkg/common/config"
	"github.com/cloudwego/hertz/pkg/common/test/assert"
)

// TestClientOptions checks the client option defaults and that every option
// below is reflected in the resulting ClientOptions.
func TestClientOptions(t *testing.T) {
	// default
	opt := config.NewClientOptions(nil)
	assert.DeepEqual(t, 0, opt.MaxConnsPerHost)
	assert.DeepEqual(t, consts.DefaultDialTimeout, opt.DialTimeout)
	assert.DeepEqual(t, consts.DefaultMaxIdleConnDuration, opt.MaxIdleConnDuration)
	assert.DeepEqual(t, true, opt.KeepAlive)
	assert.DeepEqual(t, 5*time.Second, opt.ObservationInterval)

	// config
	opt = config.NewClientOptions([]config.ClientOption{
		WithDialTimeout(100 * time.Millisecond),
		WithMaxConnsPerHost(128),
		WithMaxIdleConnDuration(5 * time.Second),
		WithMaxConnDuration(10 * time.Second),
		WithMaxConnWaitTimeout(5 * time.Second),
		WithKeepAlive(false),
		WithClientReadTimeout(1 * time.Second),
		WithResponseBodyStream(true),
		WithRetryConfig(
			retry.WithMaxAttemptTimes(2),
			retry.WithInitDelay(100*time.Millisecond),
			retry.WithMaxDelay(5*time.Second),
			retry.WithMaxJitter(1*time.Second),
			retry.WithDelayPolicy(retry.CombineDelay(retry.DefaultDelayPolicy, retry.FixedDelayPolicy, retry.BackOffDelayPolicy)),
		),
		WithWriteTimeout(time.Second),
		WithConnStateObserve(nil, time.Second),
	})
	assert.DeepEqual(t, 100*time.Millisecond, opt.DialTimeout)
	assert.DeepEqual(t, 128, opt.MaxConnsPerHost)
	assert.DeepEqual(t, 5*time.Second, opt.MaxIdleConnDuration)
	assert.DeepEqual(t, 10*time.Second, opt.MaxConnDuration)
	assert.DeepEqual(t, 5*time.Second, opt.MaxConnWaitTimeout)
	assert.DeepEqual(t, false, opt.KeepAlive)
	assert.DeepEqual(t, 1*time.Second, opt.ReadTimeout)
	assert.DeepEqual(t, 1*time.Second, opt.WriteTimeout)
	assert.DeepEqual(t, true, opt.ResponseBodyStream)
	assert.DeepEqual(t, uint(2), opt.RetryConfig.MaxAttemptTimes)
	assert.DeepEqual(t, 100*time.Millisecond, opt.RetryConfig.Delay)
	assert.DeepEqual(t, 5*time.Second, opt.RetryConfig.MaxDelay)
	assert.DeepEqual(t, 1*time.Second, opt.RetryConfig.MaxJitter)
	assert.DeepEqual(t, 1*time.Second, opt.ObservationInterval)
	// Function values cannot be compared directly, so compare the delays the
	// configured policy and an equivalent one produce for the same attempts.
	for i := 0; i < 100; i++ {
		assert.DeepEqual(t, opt.RetryConfig.DelayPolicy(uint(i), nil, opt.RetryConfig), retry.CombineDelay(retry.DefaultDelayPolicy, retry.FixedDelayPolicy, retry.BackOffDelayPolicy)(uint(i), nil, opt.RetryConfig))
	}
}

================================================
FILE: pkg/app/client/retry/option.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package retry

import "time"

// Option is the only way to set fields of a retry Config; its F is applied
// to the Config by Config.Apply.
type Option struct {
	F func(o *Config)
}

// WithMaxAttemptTimes sets the maximum number of call attempts, including the first call.
func WithMaxAttemptTimes(maxAttemptTimes uint) Option {
	return Option{F: func(o *Config) {
		o.MaxAttemptTimes = maxAttemptTimes
	}}
}

// WithInitDelay sets the initial retry Delay.
func WithInitDelay(delay time.Duration) Option {
	return Option{F: func(o *Config) {
		o.Delay = delay
	}}
}

// WithMaxDelay sets MaxDelay, the upper limit of the retry delay.
func WithMaxDelay(maxDelay time.Duration) Option {
	return Option{F: func(o *Config) {
		o.MaxDelay = maxDelay
	}}
}

// WithDelayPolicy sets the DelayPolicy used to compute the delay of each retry.
func WithDelayPolicy(delayPolicy DelayPolicyFunc) Option { return Option{F: func(o *Config) { o.DelayPolicy = delayPolicy }} } // WithMaxJitter set MaxJitter. func WithMaxJitter(maxJitter time.Duration) Option { return Option{F: func(o *Config) { o.MaxJitter = maxJitter }} } ================================================ FILE: pkg/app/client/retry/retry.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package retry import ( "math" "time" "github.com/bytedance/gopkg/lang/fastrand" ) // Config All configurations related to retry type Config struct { // The maximum number of call attempt times, including the initial call MaxAttemptTimes uint // Initial retry delay time Delay time.Duration // Maximum retry delay time. When the retry time increases beyond this time, // this configuration will limit the upper limit of waiting time MaxDelay time.Duration // The maximum jitter time, which takes effect when the delay policy is configured as RandomDelay MaxJitter time.Duration // Delay strategy, which can combine multiple delay strategies. 
such as CombineDelay(BackOffDelayPolicy, RandomDelayPolicy) or BackOffDelayPolicy,etc DelayPolicy DelayPolicyFunc } func (o *Config) Apply(opts []Option) { for _, op := range opts { op.F(o) } } // DelayPolicyFunc signature of delay policy function // is called to return the delay of retry type DelayPolicyFunc func(attempts uint, err error, retryConfig *Config) time.Duration // DefaultDelayPolicy is a DelayPolicyFunc which keep 0 delay in all iterations func DefaultDelayPolicy(_ uint, _ error, _ *Config) time.Duration { return 0 * time.Millisecond } // FixedDelayPolicy is a DelayPolicyFunc which keeps delay the same through all iterations func FixedDelayPolicy(_ uint, _ error, retryConfig *Config) time.Duration { return retryConfig.Delay } // RandomDelayPolicy is a DelayPolicyFunc which picks a random delay up to RetryConfig.MaxJitter, if the retryConfig.MaxJitter less than or equal to 0, the final delay is 0 func RandomDelayPolicy(_ uint, _ error, retryConfig *Config) time.Duration { if retryConfig.MaxJitter <= 0 { return 0 * time.Millisecond } return time.Duration(fastrand.Int63n(int64(retryConfig.MaxJitter))) } // BackOffDelayPolicy is a DelayPolicyFunc which exponentially increases delay between consecutive retries, if the retryConfig.Delay less than or equal to 0, the final delay is 0 func BackOffDelayPolicy(attempts uint, _ error, retryConfig *Config) time.Duration { if retryConfig.Delay <= 0 { return 0 * time.Millisecond } // 1 << 63 would overflow signed int64 (time.Duration), thus 62. 
const max uint = 62 if attempts > max { attempts = max } return retryConfig.Delay << attempts } // CombineDelay return DelayPolicyFunc, which combines the optional DelayPolicyFunc into a new DelayPolicyFunc func CombineDelay(delays ...DelayPolicyFunc) DelayPolicyFunc { const maxInt64 = uint64(math.MaxInt64) return func(attempts uint, err error, config *Config) time.Duration { var total uint64 for _, delay := range delays { total += uint64(delay(attempts, err, config)) if total > maxInt64 { total = maxInt64 } } return time.Duration(total) } } // Delay generate the delay time required for the current retry config, if the retryConfig.DelayPolicy == nil, the final delay is 0 func Delay(attempts uint, err error, retryConfig *Config) time.Duration { if retryConfig.DelayPolicy == nil { return 0 * time.Millisecond } delayTime := retryConfig.DelayPolicy(attempts, err, retryConfig) if retryConfig.MaxDelay > 0 && delayTime > retryConfig.MaxDelay { delayTime = retryConfig.MaxDelay } return delayTime } ================================================ FILE: pkg/app/client/retry/retry_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package retry import ( "math" "testing" "time" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestApply(t *testing.T) { delayPolicyFunc := func(attempts uint, err error, retryConfig *Config) time.Duration { return time.Second } options := []Option{} options = append(options, WithMaxAttemptTimes(100), WithInitDelay(time.Second), WithMaxDelay(time.Second), WithDelayPolicy(delayPolicyFunc), WithMaxJitter(time.Second)) config := Config{} config.Apply(options) assert.DeepEqual(t, uint(100), config.MaxAttemptTimes) assert.DeepEqual(t, time.Second, config.Delay) assert.DeepEqual(t, time.Second, config.MaxDelay) assert.DeepEqual(t, time.Second, Delay(0, nil, &config)) assert.DeepEqual(t, time.Second, config.MaxJitter) } func TestPolicy(t *testing.T) { dur := DefaultDelayPolicy(0, nil, nil) assert.DeepEqual(t, 0*time.Millisecond, dur) config := Config{ Delay: time.Second, } dur = FixedDelayPolicy(0, nil, &config) assert.DeepEqual(t, time.Second, dur) dur = RandomDelayPolicy(0, nil, &config) assert.DeepEqual(t, 0*time.Millisecond, dur) config.MaxJitter = time.Second * 1 dur = RandomDelayPolicy(0, nil, &config) assert.NotEqual(t, time.Second*1, dur) dur = BackOffDelayPolicy(0, nil, &config) assert.DeepEqual(t, time.Second*1, dur) config.Delay = time.Duration(-1) dur = BackOffDelayPolicy(0, nil, &config) assert.DeepEqual(t, time.Second*0, dur) config.Delay = time.Duration(1) dur = BackOffDelayPolicy(63, nil, &config) durExp := config.Delay << 62 assert.DeepEqual(t, durExp, dur) dur = Delay(0, nil, &config) assert.DeepEqual(t, 0*time.Millisecond, dur) delayPolicyFunc := func(attempts uint, err error, retryConfig *Config) time.Duration { return time.Second } config.DelayPolicy = delayPolicyFunc config.MaxDelay = time.Second / 2 dur = Delay(0, nil, &config) assert.DeepEqual(t, config.MaxDelay, dur) delayPolicyFunc2 := func(attempts uint, err error, retryConfig *Config) time.Duration { return time.Duration(math.MaxInt64) } delayFunc := CombineDelay(delayPolicyFunc2, 
delayPolicyFunc) dur = delayFunc(0, nil, &config) assert.DeepEqual(t, time.Duration(math.MaxInt64), dur) } ================================================ FILE: pkg/app/context.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package app import ( "context" "fmt" "io" "mime/multipart" "net" "net/url" "os" "reflect" "strings" "sync" "time" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/render" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" rConsts "github.com/cloudwego/hertz/pkg/route/consts" "github.com/cloudwego/hertz/pkg/route/param" ) var zeroTCPAddr = &net.TCPAddr{ IP: net.IPv4zero, } type Handler interface { ServeHTTP(c context.Context, ctx *RequestContext) } type ClientIP func(ctx *RequestContext) string type ClientIPOptions struct { RemoteIPHeaders []string TrustedCIDRs []*net.IPNet } var defaultTrustedCIDRs = []*net.IPNet{ { // 0.0.0.0/0 (IPv4) IP: net.IP{0x0, 0x0, 0x0, 0x0}, Mask: net.IPMask{0x0, 0x0, 0x0, 0x0}, }, { // ::/0 (IPv6) IP: net.IP{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, Mask: net.IPMask{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, }, } var defaultClientIPOptions = ClientIPOptions{ RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, TrustedCIDRs: defaultTrustedCIDRs, } var loopbackIP = net.ParseIP("127.0.0.1") // ClientIPWithOption used to generate custom ClientIP function and set by engine.SetClientIPFunc func 
ClientIPWithOption(opts ClientIPOptions) ClientIP { return func(ctx *RequestContext) string { remoteIPStr := "" trustedProxy := false if addr := ctx.RemoteAddr(); strings.HasPrefix(addr.Network(), "unix") { // unix, unixgram, unixpacket is considered same as "127.0.0.1" remoteIPStr = addr.String() trustedProxy = isTrustedProxy(opts.TrustedCIDRs, loopbackIP) } else { h, _, err := net.SplitHostPort(strings.TrimSpace(addr.String())) if err != nil { return "" } remoteIPStr = h trustedProxy = isTrustedProxy(opts.TrustedCIDRs, net.ParseIP(h)) } if trustedProxy { for _, headerName := range opts.RemoteIPHeaders { ip, valid := validateHeader(opts.TrustedCIDRs, ctx.Request.Header.Get(headerName)) if valid { return ip } } } return remoteIPStr } } // isTrustedProxy will check whether the IP address is included in the trusted list according to trustedCIDRs func isTrustedProxy(trustedCIDRs []*net.IPNet, remoteIP net.IP) bool { if trustedCIDRs == nil || remoteIP == nil { return false } for _, cidr := range trustedCIDRs { if cidr.Contains(remoteIP) { return true } } return false } // validateHeader will parse X-Real-IP and X-Forwarded-For header and return the Initial client IP address or an untrusted IP address func validateHeader(trustedCIDRs []*net.IPNet, header string) (clientIP string, valid bool) { if header == "" { return "", false } items := strings.Split(header, ",") for i := len(items) - 1; i >= 0; i-- { ipStr := strings.TrimSpace(items[i]) ip := net.ParseIP(ipStr) if ip == nil { break } // X-Forwarded-For is appended by proxy // Check IPs in reverse order and stop when find untrusted proxy if (i == 0) || (!isTrustedProxy(trustedCIDRs, ip)) { return ipStr, true } } return "", false } var defaultClientIP = ClientIPWithOption(defaultClientIPOptions) // SetClientIPFunc sets ClientIP function implementation to get ClientIP. 
// Deprecated: Use engine.SetClientIPFunc instead of SetClientIPFunc func SetClientIPFunc(fn ClientIP) { defaultClientIP = fn } type FormValueFunc func(*RequestContext, string) []byte var defaultFormValue = func(ctx *RequestContext, key string) []byte { v := ctx.QueryArgs().Peek(key) if len(v) > 0 { return v } v = ctx.PostArgs().Peek(key) if len(v) > 0 { return v } mf, err := ctx.MultipartForm() if err == nil && mf.Value != nil { vv := mf.Value[key] if len(vv) > 0 { return []byte(vv[0]) } } return nil } type RequestContext struct { conn network.Conn Request protocol.Request Response protocol.Response // Errors is a list of errors attached to all the handlers/middlewares who used this context. Errors errors.ErrorChain Params param.Params handlers HandlersChain fullPath string index int8 HTMLRender render.HTMLRender // This mutex protect Keys map. mu sync.RWMutex // Keys is a key/value pair exclusively for the context of each request. Keys map[string]interface{} hijackHandler HijackHandler finishedMu sync.Mutex // finished means the request end. finished chan struct{} // traceInfo defines the trace information. traceInfo traceinfo.TraceInfo // enableTrace defines whether enable trace. enableTrace bool // clientIPFunc get client ip by use custom function. clientIPFunc ClientIP // clientIPFunc get form value by use custom function. formValueFunc FormValueFunc binder binding.Binder exiled bool } // Exile marks this RequestContext as not to be recycled. // Experimental features: Use with caution, it may have a slight impact on performance. func (ctx *RequestContext) Exile() { ctx.exiled = true } func (ctx *RequestContext) IsExiled() bool { return ctx.exiled } // Flush is the shortcut for ctx.Response.GetHijackWriter().Flush(). // Will return nil if the response writer is not hijacked. 
func (ctx *RequestContext) Flush() error { if ctx.Response.GetHijackWriter() == nil { return nil } return ctx.Response.GetHijackWriter().Flush() } func (ctx *RequestContext) SetClientIPFunc(f ClientIP) { ctx.clientIPFunc = f } func (ctx *RequestContext) SetFormValueFunc(f FormValueFunc) { ctx.formValueFunc = f } func (ctx *RequestContext) SetBinder(binder binding.Binder) { ctx.binder = binder } func (ctx *RequestContext) GetTraceInfo() traceinfo.TraceInfo { return ctx.traceInfo } func (ctx *RequestContext) SetTraceInfo(t traceinfo.TraceInfo) { ctx.traceInfo = t } func (ctx *RequestContext) IsEnableTrace() bool { return ctx.enableTrace } // SetEnableTrace sets whether enable trace. // // NOTE: biz handler must not modify this value, otherwise, it may panic. func (ctx *RequestContext) SetEnableTrace(enable bool) { ctx.enableTrace = enable } // NewContext make a pure RequestContext without any http request/response information // // Set the Request filed before use it for handlers func NewContext(maxParams uint16) *RequestContext { v := make(param.Params, 0, maxParams) ctx := &RequestContext{Params: v, index: -1} return ctx } // Loop fn for every k/v in Keys func (ctx *RequestContext) ForEachKey(fn func(k string, v interface{})) { ctx.mu.RLock() for key, val := range ctx.Keys { fn(key, val) } ctx.mu.RUnlock() } func (ctx *RequestContext) SetConn(c network.Conn) { ctx.conn = c } func (ctx *RequestContext) GetConn() network.Conn { return ctx.conn } func (ctx *RequestContext) SetHijackHandler(h HijackHandler) { ctx.hijackHandler = h } func (ctx *RequestContext) GetHijackHandler() HijackHandler { return ctx.hijackHandler } func (ctx *RequestContext) GetReader() network.Reader { return ctx.conn } func (ctx *RequestContext) GetWriter() network.Writer { return ctx.conn } func (ctx *RequestContext) GetIndex() int8 { return ctx.index } // SetIndex reset the handler's execution index // Disclaimer: You can loop yourself to deal with this, use wisely. 
func (ctx *RequestContext) SetIndex(index int8) { ctx.index = index } type HandlerFunc func(c context.Context, ctx *RequestContext) // HandlersChain defines a HandlerFunc array. type HandlersChain []HandlerFunc type HandlerNameOperator interface { SetHandlerName(handler HandlerFunc, name string) GetHandlerName(handler HandlerFunc) string } func SetHandlerNameOperator(o HandlerNameOperator) { inbuiltHandlerNameOperator = o } type inbuiltHandlerNameOperatorStruct struct { handlerNames map[uintptr]string } func (o *inbuiltHandlerNameOperatorStruct) SetHandlerName(handler HandlerFunc, name string) { o.handlerNames[getFuncAddr(handler)] = name } func (o *inbuiltHandlerNameOperatorStruct) GetHandlerName(handler HandlerFunc) string { return o.handlerNames[getFuncAddr(handler)] } type concurrentHandlerNameOperatorStruct struct { handlerNames map[uintptr]string lock sync.RWMutex } func (o *concurrentHandlerNameOperatorStruct) SetHandlerName(handler HandlerFunc, name string) { o.lock.Lock() defer o.lock.Unlock() o.handlerNames[getFuncAddr(handler)] = name } func (o *concurrentHandlerNameOperatorStruct) GetHandlerName(handler HandlerFunc) string { o.lock.RLock() defer o.lock.RUnlock() return o.handlerNames[getFuncAddr(handler)] } func SetConcurrentHandlerNameOperator() { SetHandlerNameOperator(&concurrentHandlerNameOperatorStruct{handlerNames: map[uintptr]string{}}) } func init() { inbuiltHandlerNameOperator = &inbuiltHandlerNameOperatorStruct{handlerNames: map[uintptr]string{}} } var inbuiltHandlerNameOperator HandlerNameOperator func SetHandlerName(handler HandlerFunc, name string) { inbuiltHandlerNameOperator.SetHandlerName(handler, name) } func GetHandlerName(handler HandlerFunc) string { return inbuiltHandlerNameOperator.GetHandlerName(handler) } func getFuncAddr(v interface{}) uintptr { return reflect.ValueOf(reflect.ValueOf(v)).Field(1).Pointer() } // HijackHandler must process the hijacked connection c. 
// // If KeepHijackedConns is disabled, which is by default, // the connection c is automatically closed after returning from HijackHandler. // // The connection c must not be used after returning from the handler, if KeepHijackedConns is disabled. // // When KeepHijackedConns enabled, hertz will not Close() the connection, // you must do it when you need it. You must not use c in any way after calling Close(). // // network.Connection provide two options of io: net.Conn and zero-copy read/write type HijackHandler func(c network.Conn) // Hijack registers the given handler for connection hijacking. // // The handler is called after returning from RequestHandler // and sending http response. The current connection is passed // to the handler. The connection is automatically closed after // returning from the handler. // // The server skips calling the handler in the following cases: // // - 'Connection: close' header exists in either request or response. // - Unexpected error during response writing to the connection. // // The server stops processing requests from hijacked connections. // // Server limits such as Concurrency, ReadTimeout, WriteTimeout, etc. // aren't applied to hijacked connections. // // The handler must not retain references to ctx members. // // Arbitrary 'Connection: Upgrade' protocols may be implemented // with HijackHandler. For instance, // // - WebSocket ( https://en.wikipedia.org/wiki/WebSocket ) // - HTTP/2.0 ( https://en.wikipedia.org/wiki/HTTP/2 ) func (ctx *RequestContext) Hijack(handler HijackHandler) { ctx.hijackHandler = handler } // Last returns the last handler of the handler chain. // // Generally speaking, the last handler is the main handler. 
func (c HandlersChain) Last() HandlerFunc { if length := len(c); length > 0 { return c[length-1] } return nil } func (ctx *RequestContext) Finished() <-chan struct{} { ctx.finishedMu.Lock() if ctx.finished == nil { ctx.finished = make(chan struct{}) } ch := ctx.finished ctx.finishedMu.Unlock() return ch } // GetRequest returns a copy of Request. func (ctx *RequestContext) GetRequest() (dst *protocol.Request) { dst = &protocol.Request{} ctx.Request.CopyTo(dst) return } // GetResponse returns a copy of Response. func (ctx *RequestContext) GetResponse() (dst *protocol.Response) { dst = &protocol.Response{} ctx.Response.CopyTo(dst) return } // Value returns the value associated with this context for key, or nil // if no value is associated with key. Successive calls to Value with // the same key returns the same result. // // In case the Key is reset after response, Value() return nil if ctx.Key is nil. func (ctx *RequestContext) Value(key interface{}) interface{} { // this ctx has been reset, return nil. if ctx.Keys == nil { return nil } if keyString, ok := key.(string); ok { val, _ := ctx.Get(keyString) return val } return nil } // Hijacked returns true after Hijack is called. func (ctx *RequestContext) Hijacked() bool { return ctx.hijackHandler != nil } // SetBodyStream sets response body stream and, optionally body size. // // bodyStream.Close() is called after finishing reading all body data // if it implements io.Closer. // // If bodySize is >= 0, then bodySize bytes must be provided by bodyStream // before returning io.EOF. // // If bodySize < 0, then bodyStream is read until io.EOF. // // See also SetBodyStreamWriter. func (ctx *RequestContext) SetBodyStream(bodyStream io.Reader, bodySize int) { ctx.Response.SetBodyStream(bodyStream, bodySize) } // Host returns requested host. // // The host is valid until returning from RequestHandler. 
func (ctx *RequestContext) Host() []byte {
	return ctx.Request.URI().Host()
}

// RemoteAddr returns client address for the given request.
//
// If no connection or no remote address is available, it returns zeroTCPAddr.
func (ctx *RequestContext) RemoteAddr() net.Addr {
	if ctx.conn != nil {
		if addr := ctx.conn.RemoteAddr(); addr != nil {
			return addr
		}
	}
	return zeroTCPAddr
}

// WriteString appends s to response body. It always reports len(s) bytes and a nil error.
func (ctx *RequestContext) WriteString(s string) (int, error) {
	ctx.Response.AppendBodyString(s)
	return len(s), nil
}

// SetContentType sets response Content-Type.
func (ctx *RequestContext) SetContentType(contentType string) {
	ctx.Response.Header.SetContentType(contentType)
}

// Path returns requested path.
//
// The path is valid until returning from RequestHandler.
func (ctx *RequestContext) Path() []byte {
	return ctx.Request.URI().Path()
}

// NotModified resets response and sets '304 Not Modified' response status code.
func (ctx *RequestContext) NotModified() {
	ctx.Response.Reset()
	ctx.SetStatusCode(consts.StatusNotModified)
}

// IfModifiedSince returns true if lastModified (truncated to whole seconds)
// is after the 'If-Modified-Since' value from the request header.
//
// It also returns true when the header is missing or cannot be parsed.
func (ctx *RequestContext) IfModifiedSince(lastModified time.Time) bool {
	raw := ctx.Request.Header.PeekIfModifiedSinceBytes()
	if len(raw) == 0 {
		return true
	}
	since, err := bytesconv.ParseHTTPDate(raw)
	if err != nil {
		return true
	}
	// HTTP dates carry second precision, so compare at that granularity.
	return since.Before(lastModified.Truncate(time.Second))
}

// URI returns requested uri.
//
// The uri is valid until returning from RequestHandler.
func (ctx *RequestContext) URI() *protocol.URI {
	return ctx.Request.URI()
}

// String renders a formatted string response with the given status code.
func (ctx *RequestContext) String(code int, format string, values ...interface{}) {
	body := render.String{Format: format, Data: values}
	ctx.Render(code, body)
}

// FullPath returns a matched route full path. For not found routes
// returns an empty string.
// // router.GET("/user/:id", func(c context.Context, ctx *app.RequestContext) { // ctx.FullPath() == "/user/:id" // true // }) func (ctx *RequestContext) FullPath() string { return ctx.fullPath } func (ctx *RequestContext) SetFullPath(p string) { ctx.fullPath = p } // SetStatusCode sets response status code. func (ctx *RequestContext) SetStatusCode(statusCode int) { ctx.Response.SetStatusCode(statusCode) } // Write writes p into response body. func (ctx *RequestContext) Write(p []byte) (int, error) { ctx.Response.AppendBody(p) return len(p), nil } // File writes the specified file into the body stream in an efficient way. func (ctx *RequestContext) File(filepath string) { ServeFile(ctx, filepath) } func (ctx *RequestContext) FileFromFS(filepath string, fs *FS) { defer func(old string) { ctx.Request.URI().SetPath(old) }(string(ctx.Request.URI().Path())) ctx.Request.URI().SetPath(filepath) fs.NewRequestHandler()(context.Background(), ctx) } // FileAttachment use an efficient way to write the file to body stream. // // When client download the file, it will rename the file as filename func (ctx *RequestContext) FileAttachment(filepath, filename string) { ctx.Response.Header.Set("content-disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename)) ServeFile(ctx, filepath) } // SetBodyString sets response body to the given value. func (ctx *RequestContext) SetBodyString(body string) { ctx.Response.SetBodyString(body) } // SetContentTypeBytes sets response Content-Type. // // It is safe modifying contentType buffer after function return. func (ctx *RequestContext) SetContentTypeBytes(contentType []byte) { ctx.Response.Header.SetContentTypeBytes(contentType) } // FormFile returns the first file for the provided form key. func (ctx *RequestContext) FormFile(name string) (*multipart.FileHeader, error) { return ctx.Request.FormFile(name) } // FormValue returns form value associated with the given key. 
// // The value is searched in the following places: // // - Query string. // - POST or PUT body. // // There are more fine-grained methods for obtaining form values: // // - QueryArgs for obtaining values from query string. // - PostArgs for obtaining values from POST or PUT body. // - MultipartForm for obtaining values from multipart form. // - FormFile for obtaining uploaded files. // // The returned value is valid until returning from RequestHandler. // Use engine.SetCustomFormValueFunc to change action of FormValue. func (ctx *RequestContext) FormValue(key string) []byte { if ctx.formValueFunc != nil { return ctx.formValueFunc(ctx, key) } return defaultFormValue(ctx, key) } func (ctx *RequestContext) multipartFormValue(key string) (string, bool) { mf, err := ctx.MultipartForm() if err == nil && mf.Value != nil { vv := mf.Value[key] if len(vv) > 0 { return vv[0], true } } return "", false } func (ctx *RequestContext) multipartFormValueArray(key string) ([]string, bool) { mf, err := ctx.MultipartForm() if err == nil && mf.Value != nil { vv := mf.Value[key] if len(vv) > 0 { return vv, true } } return nil, false } func (ctx *RequestContext) RequestBodyStream() io.Reader { return ctx.Request.BodyStream() } // MultipartForm returns request's multipart form. // // Returns errNoMultipartForm if request's content-type // isn't 'multipart/form-data'. // // All uploaded temporary files are automatically deleted after // returning from RequestHandler. Either move or copy uploaded files // into new place if you want retaining them. // // Use SaveMultipartFile function for permanently saving uploaded file. // // The returned form is valid until returning from RequestHandler. // // See also FormFile and FormValue. func (ctx *RequestContext) MultipartForm() (*multipart.Form, error) { return ctx.Request.MultipartForm() } // SaveUploadedFile uploads the form file to specific dst. 
func (ctx *RequestContext) SaveUploadedFile(file *multipart.FileHeader, dst string) error { src, err := file.Open() if err != nil { return err } defer src.Close() out, err := os.Create(dst) if err != nil { return err } defer out.Close() _, err = io.Copy(out, src) return err } // SetConnectionClose sets 'Connection: close' response header. func (ctx *RequestContext) SetConnectionClose() { ctx.Response.SetConnectionClose() } // IsGet returns true if request method is GET. func (ctx *RequestContext) IsGet() bool { return ctx.Request.Header.IsGet() } // IsHead returns true if request method is HEAD. func (ctx *RequestContext) IsHead() bool { return ctx.Request.Header.IsHead() } // IsPost returns true if request method is POST. func (ctx *RequestContext) IsPost() bool { return ctx.Request.Header.IsPost() } // Method return request method. // // Returned value is valid until returning from RequestHandler. func (ctx *RequestContext) Method() []byte { return ctx.Request.Header.Method() } // NotFound resets response and sets '404 Not Found' response status code. func (ctx *RequestContext) NotFound() { ctx.Response.Reset() ctx.SetStatusCode(consts.StatusNotFound) ctx.SetBodyString(consts.StatusMessage(consts.StatusNotFound)) } func (ctx *RequestContext) redirect(uri []byte, statusCode int) { ctx.Response.Header.SetCanonical(bytestr.StrLocation, uri) statusCode = getRedirectStatusCode(statusCode) ctx.Response.SetStatusCode(statusCode) } func getRedirectStatusCode(statusCode int) int { if statusCode == consts.StatusMovedPermanently || statusCode == consts.StatusFound || statusCode == consts.StatusSeeOther || statusCode == consts.StatusTemporaryRedirect || statusCode == consts.StatusPermanentRedirect { return statusCode } return consts.StatusFound } // Copy returns a copy of the current context that can be safely used outside // the request's scope. // // NOTE: If you want to pass requestContext to a goroutine, call this method // to get a copy of requestContext. 
func (ctx *RequestContext) Copy() *RequestContext { cp := &RequestContext{ conn: ctx.conn, Params: ctx.Params, } ctx.Request.CopyTo(&cp.Request) ctx.Response.CopyTo(&cp.Response) cp.index = rConsts.AbortIndex cp.handlers = nil cp.Keys = map[string]interface{}{} ctx.mu.RLock() for k, v := range ctx.Keys { cp.Keys[k] = v } ctx.mu.RUnlock() paramCopy := make([]param.Param, len(cp.Params)) copy(paramCopy, cp.Params) cp.Params = paramCopy cp.fullPath = ctx.fullPath cp.clientIPFunc = ctx.clientIPFunc cp.formValueFunc = ctx.formValueFunc cp.binder = ctx.binder return cp } // Next should be used only inside middleware. // It executes the pending handlers in the chain inside the calling handler. func (ctx *RequestContext) Next(c context.Context) { ctx.index++ for ctx.index < int8(len(ctx.handlers)) { ctx.handlers[ctx.index](c, ctx) ctx.index++ } } // Handler returns the main handler. func (ctx *RequestContext) Handler() HandlerFunc { return ctx.handlers.Last() } // Handlers returns the handler chain. func (ctx *RequestContext) Handlers() HandlersChain { return ctx.handlers } func (ctx *RequestContext) SetHandlers(hc HandlersChain) { ctx.handlers = hc } // HandlerName returns the main handler's name. // // For example if the handler is "handleGetUsers()", this function will return "main.handleGetUsers". func (ctx *RequestContext) HandlerName() string { return utils.NameOfFunction(ctx.handlers.Last()) } func (ctx *RequestContext) ResetWithoutConn() { ctx.Params = ctx.Params[0:0] ctx.Errors = ctx.Errors[0:0] ctx.handlers = nil ctx.index = -1 ctx.fullPath = "" ctx.Keys = nil if ctx.finished != nil { close(ctx.finished) ctx.finished = nil } ctx.Request.ResetWithoutConn() ctx.Response.Reset() if ctx.IsEnableTrace() { ctx.traceInfo.Reset() } } // Reset resets requestContext. // // NOTE: It is an internal function. You should not use it. func (ctx *RequestContext) Reset() { ctx.ResetWithoutConn() ctx.conn = nil } // Redirect returns an HTTP redirect to the specific location. 
// Note that this will not stop the current handler. // In other words, even if Redirect() is called, the remaining handlers will still be executed and cause unexpected result. // So it should call Abort to ensure the remaining handlers of this request will not be called. // // ctx.Abort() // return func (ctx *RequestContext) Redirect(statusCode int, uri []byte) { ctx.redirect(uri, statusCode) } // Header is an intelligent shortcut for ctx.Response.Header.Set(key, value). // It writes a header in the response. // If value == "", this method removes the header `ctx.Response.Header.Del(key)`. func (ctx *RequestContext) Header(key, value string) { if value == "" { ctx.Response.Header.Del(key) return } ctx.Response.Header.Set(key, value) } // Set is used to store a new key/value pair exclusively for this context. // It also lazy initializes c.Keys if it was not used previously. func (ctx *RequestContext) Set(key string, value interface{}) { ctx.mu.Lock() if ctx.Keys == nil { ctx.Keys = make(map[string]interface{}) } ctx.Keys[key] = value ctx.mu.Unlock() } // Get returns the value for the given key, ie: (value, true). // If the value does not exist it returns (nil, false) func (ctx *RequestContext) Get(key string) (value interface{}, exists bool) { ctx.mu.RLock() value, exists = ctx.Keys[key] ctx.mu.RUnlock() return } // MustGet returns the value for the given key if it exists, otherwise it panics. func (ctx *RequestContext) MustGet(key string) interface{} { if value, exists := ctx.Get(key); exists { return value } panic("Key \"" + key + "\" does not exist") } // GetString returns the value associated with the key as a string. Return "" when type is error. func (ctx *RequestContext) GetString(key string) (s string) { if val, ok := ctx.Get(key); ok && val != nil { s, _ = val.(string) } return } // GetBool returns the value associated with the key as a boolean. Return false when type is error. 
func (ctx *RequestContext) GetBool(key string) (b bool) { if val, ok := ctx.Get(key); ok && val != nil { b, _ = val.(bool) } return } // GetInt returns the value associated with the key as an integer. Return 0 when type is error. func (ctx *RequestContext) GetInt(key string) (i int) { if val, ok := ctx.Get(key); ok && val != nil { i, _ = val.(int) } return } // GetInt32 returns the value associated with the key as an integer. Return int32(0) when type is error. func (ctx *RequestContext) GetInt32(key string) (i32 int32) { if val, ok := ctx.Get(key); ok && val != nil { i32, _ = val.(int32) } return } // GetInt64 returns the value associated with the key as an integer. Return int64(0) when type is error. func (ctx *RequestContext) GetInt64(key string) (i64 int64) { if val, ok := ctx.Get(key); ok && val != nil { i64, _ = val.(int64) } return } // GetUint returns the value associated with the key as an unsigned integer. Return uint(0) when type is error. func (ctx *RequestContext) GetUint(key string) (ui uint) { if val, ok := ctx.Get(key); ok && val != nil { ui, _ = val.(uint) } return } // GetUint32 returns the value associated with the key as an unsigned integer. Return uint32(0) when type is error. func (ctx *RequestContext) GetUint32(key string) (ui32 uint32) { if val, ok := ctx.Get(key); ok && val != nil { ui32, _ = val.(uint32) } return } // GetUint64 returns the value associated with the key as an unsigned integer. Return uint64(0) when type is error. func (ctx *RequestContext) GetUint64(key string) (ui64 uint64) { if val, ok := ctx.Get(key); ok && val != nil { ui64, _ = val.(uint64) } return } // GetFloat32 returns the value associated with the key as a float32. Return float32(0.0) when type is error. func (ctx *RequestContext) GetFloat32(key string) (f32 float32) { if val, ok := ctx.Get(key); ok && val != nil { f32, _ = val.(float32) } return } // GetFloat64 returns the value associated with the key as a float64. Return 0.0 when type is error. 
func (ctx *RequestContext) GetFloat64(key string) (f64 float64) { if val, ok := ctx.Get(key); ok && val != nil { f64, _ = val.(float64) } return } // GetTime returns the value associated with the key as time. Return time.Time{} when type is error. func (ctx *RequestContext) GetTime(key string) (t time.Time) { if val, ok := ctx.Get(key); ok && val != nil { t, _ = val.(time.Time) } return } // GetDuration returns the value associated with the key as a duration. Return time.Duration(0) when type is error. func (ctx *RequestContext) GetDuration(key string) (d time.Duration) { if val, ok := ctx.Get(key); ok && val != nil { d, _ = val.(time.Duration) } return } // GetStringSlice returns the value associated with the key as a slice of strings. // // Return []string(nil) when type is error. func (ctx *RequestContext) GetStringSlice(key string) (ss []string) { if val, ok := ctx.Get(key); ok && val != nil { ss, _ = val.([]string) } return } // GetStringMap returns the value associated with the key as a map of interfaces. // // Return map[string]interface{}(nil) when type is error. func (ctx *RequestContext) GetStringMap(key string) (sm map[string]interface{}) { if val, ok := ctx.Get(key); ok && val != nil { sm, _ = val.(map[string]interface{}) } return } // GetStringMapString returns the value associated with the key as a map of strings. // // Return map[string]string(nil) when type is error. func (ctx *RequestContext) GetStringMapString(key string) (sms map[string]string) { if val, ok := ctx.Get(key); ok && val != nil { sms, _ = val.(map[string]string) } return } // GetStringMapStringSlice returns the value associated with the key as a map to a slice of strings. // // Return map[string][]string(nil) when type is error. func (ctx *RequestContext) GetStringMapStringSlice(key string) (smss map[string][]string) { if val, ok := ctx.Get(key); ok && val != nil { smss, _ = val.(map[string][]string) } return } // Param returns the value of the URL param. 
// It is a shortcut for c.Params.ByName(key) // // router.GET("/user/:id", func(c context.Context, ctx *app.RequestContext) { // // a GET request to /user/john // id := ctx.Param("id") // id == "john" // }) func (ctx *RequestContext) Param(key string) string { return ctx.Params.ByName(key) } // Abort prevents pending handlers from being called. // // Note that this will not stop the current handler. // Let's say you have an authorization middleware that validates that the current request is authorized. // If the authorization fails (ex: the password does not match), call Abort to ensure the remaining handlers // for this request are not called. func (ctx *RequestContext) Abort() { ctx.index = rConsts.AbortIndex } // AbortWithStatus calls `Abort()` and writes the headers with the specified status code. // // For example, a failed attempt to authenticate a request could use: context.AbortWithStatus(401). func (ctx *RequestContext) AbortWithStatus(code int) { ctx.SetStatusCode(code) ctx.Abort() } // AbortWithMsg sets response status code to the given value and sets response body // to the given message. // // Warning: this will reset the response headers and body already set! func (ctx *RequestContext) AbortWithMsg(msg string, statusCode int) { ctx.Response.Reset() ctx.SetStatusCode(statusCode) ctx.SetContentTypeBytes(bytestr.DefaultContentType) ctx.SetBodyString(msg) ctx.Abort() } // AbortWithStatusJSON calls `Abort()` and then `JSON` internally. // // This method stops the chain, writes the status code and return a JSON body. // It also sets the Content-Type as "application/json". func (ctx *RequestContext) AbortWithStatusJSON(code int, jsonObj interface{}) { ctx.Abort() ctx.JSON(code, jsonObj) } // Render writes the response headers and calls render.Render to render data. 
func (ctx *RequestContext) Render(code int, r render.Render) { ctx.SetStatusCode(code) if !bodyAllowedForStatus(code) { r.WriteContentType(&ctx.Response) return } if err := r.Render(&ctx.Response); err != nil { panic(err) } } // ProtoBuf serializes the given struct as ProtoBuf into the response body. func (ctx *RequestContext) ProtoBuf(code int, obj interface{}) { ctx.Render(code, render.ProtoBuf{Data: obj}) } // JSON serializes the given struct as JSON into the response body. // // It also sets the Content-Type as "application/json". func (ctx *RequestContext) JSON(code int, obj interface{}) { ctx.Render(code, render.JSONRender{Data: obj}) } // PureJSON serializes the given struct as JSON into the response body. // PureJSON, unlike JSON, does not replace special html characters with their unicode entities. func (ctx *RequestContext) PureJSON(code int, obj interface{}) { ctx.Render(code, render.PureJSON{Data: obj}) } // IndentedJSON serializes the given struct as pretty JSON (indented + endlines) into the response body. // It also sets the Content-Type as "application/json". func (ctx *RequestContext) IndentedJSON(code int, obj interface{}) { ctx.Render(code, render.IndentedJSON{Data: obj}) } // HTML renders the HTTP template specified by its file name. // // It also updates the HTTP code and sets the Content-Type as "text/html". // See http://golang.org/doc/articles/wiki/ func (ctx *RequestContext) HTML(code int, name string, obj interface{}) { instance := ctx.HTMLRender.Instance(name, obj) ctx.Render(code, instance) } // Data writes some data into the body stream and updates the HTTP code. func (ctx *RequestContext) Data(code int, contentType string, data []byte) { ctx.Render(code, render.Data{ ContentType: contentType, Data: data, }) } // XML serializes the given struct as XML into the response body. // // It also sets the Content-Type as "application/xml". 
func (ctx *RequestContext) XML(code int, obj interface{}) { ctx.Render(code, render.XML{Data: obj}) } // AbortWithError calls `AbortWithStatus()` and `Error()` internally. // // This method stops the chain, writes the status code and pushes the specified error to `c.Errors`. // See RequestContext.Error() for more details. func (ctx *RequestContext) AbortWithError(code int, err error) *errors.Error { ctx.AbortWithStatus(code) return ctx.Error(err) } // IsAborted returns true if the current context has aborted. func (ctx *RequestContext) IsAborted() bool { return ctx.index >= rConsts.AbortIndex } // Error attaches an error to the current context. The error is pushed to a list of errors. // // It's a good idea to call Error for each error that occurred during the resolution of a request. // A middleware can be used to collect all the errors and push them to a database together, // print a log, or append it in the HTTP response. // Error will panic if err is nil. func (ctx *RequestContext) Error(err error) *errors.Error { if err == nil { panic("err is nil") } parsedError, ok := err.(*errors.Error) if !ok { parsedError = &errors.Error{ Err: err, Type: errors.ErrorTypePrivate, } } ctx.Errors = append(ctx.Errors, parsedError) return parsedError } // ContentType returns the Content-Type header of the request. func (ctx *RequestContext) ContentType() []byte { return ctx.Request.Header.ContentType() } // Cookie returns the value of the request cookie key. func (ctx *RequestContext) Cookie(key string) []byte { return ctx.Request.Header.Cookie(key) } // SetCookie adds a Set-Cookie header to the Response's headers. // // Parameter introduce: // name and value is used to set cookie's name and value, eg. Set-Cookie: name=value // maxAge is use to set cookie's expiry date, eg. Set-Cookie: name=value; max-age=1 // path and domain is used to set the scope of a cookie, eg. Set-Cookie: name=value;domain=localhost; path=/; // secure and httpOnly is used to sent cookies securely; eg. 
Set-Cookie: name=value;HttpOnly; secure; // sameSite let servers specify whether/when cookies are sent with cross-site requests; eg. Set-Cookie: name=value;HttpOnly; secure; SameSite=Lax; // // For example: // 1. ctx.SetCookie("user", "hertz", 1, "/", "localhost",protocol.CookieSameSiteLaxMode, true, true) // add response header ---> Set-Cookie: user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=Lax; // 2. ctx.SetCookie("user", "hertz", 10, "/", "localhost",protocol.CookieSameSiteLaxMode, false, false) // add response header ---> Set-Cookie: user=hertz; max-age=10; domain=localhost; path=/; SameSite=Lax; // 3. ctx.SetCookie("", "hertz", 10, "/", "localhost",protocol.CookieSameSiteLaxMode, false, false) // add response header ---> Set-Cookie: hertz; max-age=10; domain=localhost; path=/; SameSite=Lax; // 4. ctx.SetCookie("user", "", 10, "/", "localhost",protocol.CookieSameSiteLaxMode, false, false) // add response header ---> Set-Cookie: user=; max-age=10; domain=localhost; path=/; SameSite=Lax; func (ctx *RequestContext) SetCookie(name, value string, maxAge int, path, domain string, sameSite protocol.CookieSameSite, secure, httpOnly bool) { ctx.setCookie(name, value, maxAge, path, domain, sameSite, secure, httpOnly, false) } func (ctx *RequestContext) setCookie(name, value string, maxAge int, path, domain string, sameSite protocol.CookieSameSite, secure, httpOnly, partitioned bool) { if path == "" { path = "/" } cookie := protocol.AcquireCookie() defer protocol.ReleaseCookie(cookie) cookie.SetKey(name) cookie.SetValue(url.QueryEscape(value)) cookie.SetMaxAge(maxAge) cookie.SetPath(path) cookie.SetDomain(domain) cookie.SetSecure(secure) cookie.SetHTTPOnly(httpOnly) cookie.SetSameSite(sameSite) cookie.SetPartitioned(partitioned) ctx.Response.Header.SetCookie(cookie) } // SetPartitionedCookie adds a partitioned cookie to the Response's headers. // Use protocol.CookieSameSiteNoneMode for cross-site cookies to work. 
// // Usage: ctx.SetPartitionedCookie("user", "name", 10, "/", "localhost", protocol.CookieSameSiteNoneMode, true, true) // // This adds the response header: Set-Cookie: user=name; Max-Age=10; Domain=localhost; Path=/; HttpOnly; Secure; SameSite=None; Partitioned func (ctx *RequestContext) SetPartitionedCookie(name, value string, maxAge int, path, domain string, sameSite protocol.CookieSameSite, secure, httpOnly bool) { ctx.setCookie(name, value, maxAge, path, domain, sameSite, secure, httpOnly, true) } // UserAgent returns the value of the request user_agent. func (ctx *RequestContext) UserAgent() []byte { return ctx.Request.Header.UserAgent() } // Status sets the HTTP response code. func (ctx *RequestContext) Status(code int) { ctx.SetStatusCode(code) } // GetHeader returns value from request headers. func (ctx *RequestContext) GetHeader(key string) []byte { return ctx.Request.Header.Peek(key) } // GetRawData returns body data. func (ctx *RequestContext) GetRawData() []byte { return ctx.Request.Body() } // Body returns body data func (ctx *RequestContext) Body() ([]byte, error) { return ctx.Request.BodyE() } // ClientIP attempts to parse the headers in the order of [X-Forwarded-For, X-Real-IP]. // It calls RemoteIP() under the hood. If it cannot satisfy the requirements, // use engine.SetClientIPFunc to inject your own implementation. func (ctx *RequestContext) ClientIP() string { if ctx.clientIPFunc != nil { return ctx.clientIPFunc(ctx) } return defaultClientIP(ctx) } // QueryArgs returns query arguments from RequestURI. // // It doesn't return POST'ed arguments - use PostArgs() for this. // Returned arguments are valid until returning from RequestHandler. // See also PostArgs, FormValue and FormFile. func (ctx *RequestContext) QueryArgs() *protocol.Args { return ctx.URI().QueryArgs() } // PostArgs returns POST arguments. // // It doesn't return query arguments from RequestURI - use QueryArgs for this. 
// Returned arguments are valid until returning from RequestHandler. // See also QueryArgs, FormValue and FormFile. func (ctx *RequestContext) PostArgs() *protocol.Args { return ctx.Request.PostArgs() } // Query returns the keyed url query value if it exists, otherwise it returns an empty string `("")`. // // For example: // // GET /path?id=1234&name=Manu&value= // c.Query("id") == "1234" // c.Query("name") == "Manu" // c.Query("value") == "" // c.Query("wtf") == "" func (ctx *RequestContext) Query(key string) string { value, _ := ctx.GetQuery(key) return value } // DefaultQuery returns the keyed url query value if it exists, // otherwise it returns the specified defaultValue string. func (ctx *RequestContext) DefaultQuery(key, defaultValue string) string { if value, ok := ctx.GetQuery(key); ok { return value } return defaultValue } // GetQuery returns the keyed url query value // // if it exists `(value, true)` (even when the value is an empty string) will be returned, // otherwise it returns `("", false)`. // For example: // // GET /?name=Manu&lastname= // ("Manu", true) == c.GetQuery("name") // ("", false) == c.GetQuery("id") // ("", true) == c.GetQuery("lastname") func (ctx *RequestContext) GetQuery(key string) (string, bool) { return ctx.QueryArgs().PeekExists(key) } // PostForm returns the specified key from a POST urlencoded form or multipart form // when it exists, otherwise it returns an empty string `("")`. func (ctx *RequestContext) PostForm(key string) string { value, _ := ctx.GetPostForm(key) return value } // PostFormArray returns the specified key from a POST urlencoded form or multipart form // when it exists, otherwise it returns an empty array `([])`. func (ctx *RequestContext) PostFormArray(key string) []string { values, _ := ctx.GetPostFormArray(key) return values } // DefaultPostForm returns the specified key from a POST urlencoded form or multipart form // when it exists, otherwise it returns the specified defaultValue string. 
// // See: PostForm() and GetPostForm() for further information. func (ctx *RequestContext) DefaultPostForm(key, defaultValue string) string { if value, ok := ctx.GetPostForm(key); ok { return value } return defaultValue } // GetPostForm is like PostForm(key). It returns the specified key from a POST urlencoded // form or multipart form when it exists `(value, true)` (even when the value is an empty string), // otherwise it returns ("", false). // // For example, during a PATCH request to update the user's email: // // email=mail@example.com --> ("mail@example.com", true) := GetPostForm("email") // set email to "mail@example.com" // email= --> ("", true) := GetPostForm("email") // set email to "" // --> ("", false) := GetPostForm("email") // do nothing with email func (ctx *RequestContext) GetPostForm(key string) (string, bool) { if v, exists := ctx.PostArgs().PeekExists(key); exists { return v, exists } return ctx.multipartFormValue(key) } // GetPostFormArray is like PostFormArray(key). It returns the specified key from a POST urlencoded // form or multipart form when it exists `([]string, true)` (even when the value is an empty string), // otherwise it returns ([]string(nil), false). // // For example, during a PATCH request to update the item's tags: // // tag=tag1 tag=tag2 tag=tag3 --> (["tag1", "tag2", "tag3"], true) := GetPostFormArray("tags") // set tags to ["tag1", "tag2", "tag3"] // tags= --> (nil, true) := GetPostFormArray("tags") // set tags to nil // --> (nil, false) := GetPostFormArray("tags") // do nothing with tags func (ctx *RequestContext) GetPostFormArray(key string) ([]string, bool) { vs := ctx.PostArgs().PeekAll(key) values := make([]string, len(vs)) for i, v := range vs { values[i] = string(v) } if len(values) == 0 { return ctx.multipartFormValueArray(key) } else { return values, true } } // bodyAllowedForStatus is a copy of http.bodyAllowedForStatus non-exported function. 
func bodyAllowedForStatus(status int) bool { switch { case status >= 100 && status <= 199: return false case status == consts.StatusNoContent: return false case status == consts.StatusNotModified: return false } return true } func (ctx *RequestContext) getBinder() binding.Binder { if ctx.binder != nil { return ctx.binder } return binding.DefaultBinder() } // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { bi := ctx.getBinder() if err := bi.Bind(&ctx.Request, obj, ctx.Params); err != nil { return err } return bi.Validate(&ctx.Request, obj) } // Bind binds data from *RequestContext to obj. // NOTE: obj should be a pointer. func (ctx *RequestContext) Bind(obj interface{}) error { return ctx.getBinder().Bind(&ctx.Request, obj, ctx.Params) } // Validate validates obj // NOTE: obj should be a pointer. func (ctx *RequestContext) Validate(obj interface{}) error { return ctx.getBinder().Validate(&ctx.Request, obj) } // BindQuery binds query parameters from *RequestContext to obj with 'query' tag. It will only use 'query' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindQuery(obj interface{}) error { return ctx.getBinder().BindQuery(&ctx.Request, obj) } // BindHeader binds header parameters from *RequestContext to obj with 'header' tag. It will only use 'header' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindHeader(obj interface{}) error { return ctx.getBinder().BindHeader(&ctx.Request, obj) } // BindPath binds router parameters from *RequestContext to obj with 'path' tag. It will only use 'path' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindPath(obj interface{}) error { return ctx.getBinder().BindPath(&ctx.Request, obj, ctx.Params) } // BindForm binds form parameters from *RequestContext to obj with 'form' tag. 
It will only use 'form' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindForm(obj interface{}) error { if len(ctx.Request.Body()) == 0 { return fmt.Errorf("missing form body") } return ctx.getBinder().BindForm(&ctx.Request, obj) } // BindJSON binds JSON body from *RequestContext. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindJSON(obj interface{}) error { return ctx.getBinder().BindJSON(&ctx.Request, obj) } // BindProtobuf binds protobuf body from *RequestContext. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindProtobuf(obj interface{}) error { return ctx.getBinder().BindProtobuf(&ctx.Request, obj) } // BindByContentType will select the binding type on the ContentType automatically. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindByContentType(obj interface{}) error { if ctx.Request.Header.IsGet() { return ctx.BindQuery(obj) } ct := utils.FilterContentType(bytesconv.B2s(ctx.Request.Header.ContentType())) switch strings.ToLower(ct) { case consts.MIMEApplicationJSON: return ctx.BindJSON(obj) case consts.MIMEPROTOBUF: return ctx.BindProtobuf(obj) case consts.MIMEApplicationHTMLForm, consts.MIMEMultipartPOSTForm: return ctx.BindForm(obj) default: return fmt.Errorf("unsupported bind content-type for '%s'", ct) } } // VisitAllQueryArgs calls f for each existing query arg. // // f must not retain references to key and value after returning. // Make key and/or value copies if you need storing them after returning. func (ctx *RequestContext) VisitAllQueryArgs(f func(key, value []byte)) { ctx.QueryArgs().VisitAll(f) } // VisitAllPostArgs calls f for each existing post arg. // // f must not retain references to key and value after returning. // Make key and/or value copies if you need storing them after returning. func (ctx *RequestContext) VisitAllPostArgs(f func(key, value []byte)) { ctx.Request.PostArgs().VisitAll(f) } // VisitAllHeaders calls f for each request header. 
// // f must not retain references to key and/or value after returning. // Copy key and/or value contents before returning if you need retaining them. // // To get the headers in order they were received use VisitAllInOrder. func (ctx *RequestContext) VisitAllHeaders(f func(key, value []byte)) { ctx.Request.Header.VisitAll(f) } // VisitAllCookie calls f for each request cookie. // // f must not retain references to key and/or value after returning. func (ctx *RequestContext) VisitAllCookie(f func(key, value []byte)) { ctx.Request.Header.VisitAllCookie(f) } ================================================ FILE: pkg/app/context_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package app import ( "bytes" "context" "encoding/xml" "errors" "fmt" "html/template" "io/ioutil" "net" "os" "reflect" "strings" "testing" "time" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/internal/test/mock/binder" "github.com/cloudwego/hertz/pkg/app/server/render" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/common/testdata/proto" "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/req" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" con "github.com/cloudwego/hertz/pkg/route/consts" "github.com/cloudwego/hertz/pkg/route/param" ) func TestProtobuf(t *testing.T) { ctx := NewContext(0) body := proto.TestStruct{Body: []byte("Hello World")} ctx.ProtoBuf(consts.StatusOK, &body) assert.DeepEqual(t, string(ctx.Response.Body()), "\n\vHello World") } func TestPureJson(t *testing.T) { ctx := NewContext(0) ctx.PureJSON(consts.StatusOK, utils.H{ "html": "Hello World", }) if string(ctx.Response.Body()) != "{\"html\":\"Hello World\"}\n" { t.Fatalf("unexpected purejson: %#v, expected: %#v", string(ctx.Response.Body()), "Hello World") } } func TestIndentedJSON(t *testing.T) { ctx := NewContext(0) ctx.IndentedJSON(consts.StatusOK, utils.H{ "foo": "bar", "html": "h1", }) if string(ctx.Response.Body()) != "{\n \"foo\": \"bar\",\n \"html\": \"h1\"\n}" { t.Fatalf("unexpected purejson: %#v, expected: %#v", string(ctx.Response.Body()), "{\n \"foo\": \"bar\",\n \"html\": \"\"\n}") } } func TestContext(t *testing.T) { reqContext := NewContext(0) reqContext.Set("testContextKey", "testValue") ctx := reqContext if ctx.Value("testContextKey") 
!= "testValue" { t.Fatalf("unexpected value: %#v, expected: %#v", ctx.Value("testContextKey"), "testValue") } } func TestValue(t *testing.T) { ctx := NewContext(0) v := ctx.Value("testContextKey") assert.Nil(t, v) ctx.Set("testContextKey", "testValue") v = ctx.Value("testContextKey") assert.DeepEqual(t, "testValue", v) } func TestContextNotModified(t *testing.T) { reqContext := NewContext(0) reqContext.Response.SetStatusCode(consts.StatusOK) if reqContext.Response.StatusCode() != consts.StatusOK { t.Fatalf("unexpected status code: %#v, expected: %#v", reqContext.Response.StatusCode(), consts.StatusOK) } reqContext.NotModified() if reqContext.Response.StatusCode() != consts.StatusNotModified { t.Fatalf("unexpected status code: %#v, expected: %#v", reqContext.Response.StatusCode(), consts.StatusNotModified) } } func TestIfModifiedSince(t *testing.T) { ctx := NewContext(0) var req protocol.Request req.Header.Set(string(bytestr.StrIfModifiedSince), "Mon, 02 Jan 2006 15:04:05 MST") req.CopyTo(&ctx.Request) if !ctx.IfModifiedSince(time.Now()) { t.Fatalf("ifModifiedSince error, expected false, but get true") } tt, _ := time.Parse(time.RFC3339, "2004-11-12T11:45:26.371Z") if ctx.IfModifiedSince(tt) { t.Fatalf("ifModifiedSince error, expected true, but get false") } } func TestWrite(t *testing.T) { ctx := NewContext(0) l, err := ctx.Write([]byte("test body")) if err != nil { t.Fatalf("unexpected error: %#v", err.Error()) } if l != 9 { t.Fatalf("unexpected len: %#v, expected: %#v", l, 9) } if string(ctx.Response.BodyBytes()) != "test body" { t.Fatalf("unexpected body: %#v, expected: %#v", string(ctx.Response.BodyBytes()), "test body") } } func TestSetConnectionClose(t *testing.T) { ctx := NewContext(0) ctx.SetConnectionClose() if !ctx.Response.Header.ConnectionClose() { t.Fatalf("expected close connection, but not") } } func TestNotFound(t *testing.T) { ctx := NewContext(0) ctx.NotFound() if ctx.Response.StatusCode() != consts.StatusNotFound || 
string(ctx.Response.BodyBytes()) != "Not Found" { t.Fatalf("unexpected status code or body") } } func TestRedirect(t *testing.T) { ctx := NewContext(0) ctx.Redirect(consts.StatusFound, []byte("/hello")) assert.DeepEqual(t, consts.StatusFound, ctx.Response.StatusCode()) ctx.redirect([]byte("/hello"), consts.StatusMovedPermanently) assert.DeepEqual(t, consts.StatusMovedPermanently, ctx.Response.StatusCode()) } func TestGetRedirectStatusCode(t *testing.T) { val := getRedirectStatusCode(consts.StatusMovedPermanently) assert.DeepEqual(t, consts.StatusMovedPermanently, val) val = getRedirectStatusCode(consts.StatusNotFound) assert.DeepEqual(t, consts.StatusFound, val) } func TestCookie(t *testing.T) { ctx := NewContext(0) ctx.Request.Header.SetCookie("cookie", "test cookie") if string(ctx.Cookie("cookie")) != "test cookie" { t.Fatalf("unexpected cookie: %#v, expected get: %#v", string(ctx.Cookie("cookie")), "test cookie") } } func TestUserAgent(t *testing.T) { ctx := NewContext(0) ctx.Request.Header.SetUserAgentBytes([]byte("user agent")) if string(ctx.UserAgent()) != "user agent" { t.Fatalf("unexpected user agent: %#v, expected get: %#v", string(ctx.UserAgent()), "user agent") } } func TestStatus(t *testing.T) { ctx := NewContext(0) ctx.Status(consts.StatusOK) if ctx.Response.StatusCode() != consts.StatusOK { t.Fatalf("expected get consts.StatusOK, but not") } } func TestPost(t *testing.T) { ctx := NewContext(0) ctx.Request.Header.SetMethod(consts.MethodPost) if !ctx.IsPost() { t.Fatalf("expected post method , but get: %#v", ctx.Method()) } if string(ctx.Method()) != consts.MethodPost { t.Fatalf("expected post method , but get: %#v", ctx.Method()) } } func TestGet(t *testing.T) { ctx := NewContext(0) ctx.Request.Header.SetMethod(consts.MethodPost) assert.False(t, ctx.IsGet()) ctx.Request.Header.SetMethod(consts.MethodGet) assert.True(t, ctx.IsGet()) } func TestCopy(t *testing.T) { t.Parallel() ctx := NewContext(0) ctx.fullPath = "full_path" 
ctx.Request.Header.Add("header_a", "header_value_a") ctx.Response.Header.Add("header_b", "header_value_b") ctx.Params = param.Params{ {Key: "key_a", Value: "value_a"}, {Key: "key_b", Value: "value_b"}, {Key: "key_c", Value: "value_b"}, {Key: "key_d", Value: "value_b"}, {Key: "key_e", Value: "value_b"}, {Key: "key_f", Value: "value_b"}, {Key: "key_g", Value: "value_b"}, {Key: "key_h", Value: "value_b"}, {Key: "key_i", Value: "value_b"}, } ctx.Set("map_key_a", "map_value_a") ctx.Set("map_key_b", "map_value_b") for i := 0; i <= 10000; i++ { c := ctx.Copy() go func(context *RequestContext) { str, _ := context.Params.Get("key_a") if str != "value_a" { t.Errorf("unexpected value: %#v, expected: %#v", str, "value_a") return } if c.fullPath != "full_path" { t.Errorf("unexpected value: %#v, expected: %#v", c.fullPath, "full_path") return } reqHeaderStr := context.Request.Header.Get("header_a") if reqHeaderStr != "header_value_a" { t.Errorf("unexpected value: %#v, expected: %#v", reqHeaderStr, "header_value_a") return } respHeaderStr := context.Response.Header.Get("header_b") if respHeaderStr != "header_value_b" { t.Errorf("unexpected value: %#v, expected: %#v", respHeaderStr, "header_value_b") return } iStr := ctx.Value("map_key_a") if iStr.(string) != "map_value_a" { t.Errorf("unexpected value: %#v, expected: %#v", iStr.(string), "map_value_a") return } context.Params = context.Params[0:0] context.Params = append(context.Params, param.Param{Key: "key_a", Value: "value_a_"}) context.Request.Header.Reset() context.Request.Header.Add("header_a", "header_value_a_") context.Response.Header.Reset() context.Response.Header.Add("header_b", "header_value_b_") context.Keys = nil context.Keys = make(map[string]interface{}) context.Set("header_value_a", "map_value_a_") }(c) } } func TestQuery(t *testing.T) { var r protocol.Request ctx := NewContext(0) s := "POST /foo?name=menu&value= HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3 
\r\nabc\r\n0\r\n\r\n" zr := mock.NewZeroCopyReader(s) err := req.Read(&r, zr) if err != nil { t.Fatalf("Unexpected error when reading chunked request: %s", err) } r.CopyTo(&ctx.Request) if ctx.Query("name") != "menu" { t.Fatalf("unexpected query: %#v, expected menu", ctx.Query("name")) } if ctx.DefaultQuery("name", "default value") != "menu" { t.Fatalf("unexpected query: %#v, expected menu", ctx.Query("name")) } if ctx.DefaultQuery("defaultQuery", "default value") != "default value" { t.Fatalf("unexpected query: %#v, expected `default value`", ctx.Query("defaultQuery")) } } func TestMethod(t *testing.T) { ctx := NewContext(0) ctx.Status(consts.StatusOK) if ctx.Response.StatusCode() != consts.StatusOK { t.Fatalf("expected get consts.StatusOK, but not") } } func makeCtxByReqString(t *testing.T, s string) *RequestContext { ctx := NewContext(0) mr := mock.NewZeroCopyReader(s) if err := req.Read(&ctx.Request, mr); err != nil { t.Fatalf("unexpected error: %s", err) } return ctx } func TestPostForm(t *testing.T) { t.Parallel() ctx := makeCtxByReqString(t, `POST /upload HTTP/1.1 Host: localhost:10000 Content-Length: 521 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="f1" value1 ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="fileaaa"; filename="TODO" Content-Type: application/octet-stream - SessionClient with referer and cookies support. - Client with requests' pipelining support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . 
------WebKitFormBoundaryJwfATyF8tmxSJnLg-- `) if ctx.PostForm("f1") != "value1" { t.Fatalf("PostForm get Multipart Form data failed") } if ctx.PostForm("fileaaa") != "" { t.Fatalf("PostForm should not get file") } ctx = makeCtxByReqString(t, `POST /upload HTTP/1.1 Host: localhost:10000 Content-Length: 11 Content-Type: application/x-www-form-urlencoded hello=world`) if ctx.PostForm("hello") != "world" { t.Fatalf("PostForm get form failed") } } func TestPostFormArray(t *testing.T) { t.Parallel() ctx := makeCtxByReqString(t, `POST /upload HTTP/1.1 Host: localhost:10000 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Length: 521 ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="tag" red ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="tag" green ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="tag" blue ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- `) assert.DeepEqual(t, []string{"red", "green", "blue"}, ctx.PostFormArray("tag")) ctx = makeCtxByReqString(t, `POST /upload HTTP/1.1 Host: localhost:10000 Content-Type: application/x-www-form-urlencoded; charset=UTF-8 Content-Length: 26 tag=red&tag=green&tag=blue `) assert.DeepEqual(t, []string{"red", "green", "blue"}, ctx.PostFormArray("tag")) } func TestDefaultPostForm(t *testing.T) { ctx := makeCtxByReqString(t, `POST /upload HTTP/1.1 Host: localhost:10000 Content-Length: 521 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="f1" value1 ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="fileaaa"; filename="TODO" Content-Type: application/octet-stream - SessionClient with referer and cookies support. - Client with requests' pipelining support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . - HTTP/2.0. 
See https://tools.ietf.org/html/rfc7540 . ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- `) val := ctx.DefaultPostForm("f1", "no val") assert.DeepEqual(t, "value1", val) val = ctx.DefaultPostForm("f99", "no val") assert.DeepEqual(t, "no val", val) } func TestRequestContext_FormFile(t *testing.T) { t.Parallel() s := `POST /upload HTTP/1.1 Host: localhost:10000 Content-Length: 521 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="f1" value1 ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="fileaaa"; filename="TODO" Content-Type: application/octet-stream - SessionClient with referer and cookies support. - Client with requests' pipelining support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- tailfoobar` mr := mock.NewZeroCopyReader(s) ctx := NewContext(0) if err := req.Read(&ctx.Request, mr); err != nil { t.Fatalf("unexpected error: %s", err) } tail, err := ioutil.ReadAll(mr) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(tail) != "tailfoobar" { t.Fatalf("unexpected tail %q. Expecting %q", tail, "tailfoobar") } f, err := ctx.MultipartForm() if err != nil { t.Fatalf("unexpected error: %s", err) } defer ctx.Request.RemoveMultipartFormFiles() // verify files if len(f.File) != 1 { t.Fatalf("unexpected number of file values in multipart form: %d. Expecting 1", len(f.File)) } for k, vv := range f.File { if k != "fileaaa" { t.Fatalf("unexpected file value name %q. Expecting %q", k, "fileaaa") } if len(vv) != 1 { t.Fatalf("unexpected number of file values %d. Expecting 1", len(vv)) } v := vv[0] if v.Filename != "TODO" { t.Fatalf("unexpected filename %q. 
Expecting %q", v.Filename, "TODO") } ct := v.Header.Get("Content-Type") if ct != consts.MIMEApplicationOctetStream { t.Fatalf("unexpected content-type %q. Expecting %q", ct, "application/octet-stream") } } err = ctx.SaveUploadedFile(f.File["fileaaa"][0], "TODO") assert.Nil(t, err) fileInfo, err := os.Stat("TODO") assert.Nil(t, err) assert.DeepEqual(t, "TODO", fileInfo.Name()) assert.DeepEqual(t, f.File["fileaaa"][0].Size, fileInfo.Size()) err = os.Remove("TODO") assert.Nil(t, err) ff, err := ctx.FormFile("fileaaa") if err != nil || ff == nil { t.Fatalf("unexpected error happened when ctx.FormFile()") } buf := make([]byte, ff.Size) fff, _ := ff.Open() fff.Read(buf) if !strings.Contains(string(buf), "- SessionClient") { t.Fatalf("unexpected file content. Expecting %q", "- SessionClient") } if !strings.Contains(string(buf), "rfc7540 .") { t.Fatalf("unexpected file content. Expecting %q", "rfc7540 .") } } func TestContextRenderFileFromFS(t *testing.T) { t.Parallel() ctx := NewContext(0) var req protocol.Request req.Header.SetMethod(consts.MethodGet) req.SetRequestURI("/some/path") req.CopyTo(&ctx.Request) ctx.FileFromFS("./fs.go", &FS{ Root: ".", IndexNames: nil, GenerateIndexPages: false, AcceptByteRange: true, }) assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) assert.True(t, strings.Contains(resp.GetHTTP1Response(&ctx.Response).String(), "func (fs *FS) initRequestHandler() {")) // when Go version <= 1.16, mime.TypeByExtension will return Content-Type='text/plain; charset=utf-8', // otherwise it will return Content-Type='text/x-go; charset=utf-8' assert.NotEqual(t, "", string(ctx.Response.Header.Peek("Content-Type"))) assert.DeepEqual(t, "/some/path", string(ctx.Request.URI().Path())) } func TestContextRenderFile(t *testing.T) { t.Parallel() ctx := NewContext(0) var req protocol.Request req.Header.SetMethod(consts.MethodGet) req.SetRequestURI("/") req.CopyTo(&ctx.Request) ctx.File("./fs.go") assert.DeepEqual(t, consts.StatusOK, 
ctx.Response.StatusCode()) assert.True(t, strings.Contains(resp.GetHTTP1Response(&ctx.Response).String(), "func (fs *FS) initRequestHandler() {")) // when Go version <= 1.16, mime.TypeByExtension will return Content-Type='text/plain; charset=utf-8', // otherwise it will return Content-Type='text/x-go; charset=utf-8' assert.NotEqual(t, "", string(ctx.Response.Header.Peek("Content-Type"))) } func TestContextRenderAttachment(t *testing.T) { t.Parallel() ctx := NewContext(0) var req protocol.Request req.Header.SetMethod(consts.MethodGet) req.SetRequestURI("/") req.CopyTo(&ctx.Request) newFilename := "new_filename.go" ctx.FileAttachment("./context.go", newFilename) assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) assert.True(t, strings.Contains(resp.GetHTTP1Response(&ctx.Response).String(), "func (ctx *RequestContext) FileAttachment(filepath, filename string) {")) assert.DeepEqual(t, fmt.Sprintf("attachment; filename=\"%s\"", newFilename), string(ctx.Response.Header.Peek("Content-Disposition"))) } func TestRequestContext_Header(t *testing.T) { c := NewContext(0) c.Header("header_key", "header_val") val := string(c.Response.Header.Peek("header_key")) if val != "header_val" { t.Fatalf("unexpected %q. Expecting %q", val, "header_val") } c.Response.Header.Del("header_key") val = string(c.Response.Header.Peek("header_key")) if val != "" { t.Fatalf("unexpected %q. Expecting %q", val, "") } c.Header("header_key1", "header_val1") c.Header("header_key1", "") val = string(c.Response.Header.Peek("header_key1")) if val != "" { t.Fatalf("unexpected %q. Expecting %q", val, "") } } func TestRequestContext_Keys(t *testing.T) { c := NewContext(0) rightVal := "123" c.Set("key", rightVal) val := c.GetString("key") if val != rightVal { t.Fatalf("unexpected %v. 
Expecting %v", val, rightVal) } } func testFunc(c context.Context, ctx *RequestContext) { ctx.Next(c) } func testFunc2(c context.Context, ctx *RequestContext) { ctx.Set("key", "123") } func TestRequestContext_Handler(t *testing.T) { c := NewContext(0) c.handlers = HandlersChain{testFunc, testFunc2} c.Handler()(context.Background(), c) val := c.GetString("key") if val != "123" { t.Fatalf("unexpected %v. Expecting %v", val, "123") } c.handlers = nil handler := c.Handler() assert.Nil(t, handler) } func TestRequestContext_Handlers(t *testing.T) { c := NewContext(0) hc := HandlersChain{testFunc, testFunc2} c.SetHandlers(hc) c.Handlers()[1](context.Background(), c) val := c.GetString("key") if val != "123" { t.Fatalf("unexpected %v. Expecting %v", val, "123") } } func TestRequestContext_HandlerName(t *testing.T) { c := NewContext(0) c.handlers = HandlersChain{testFunc, testFunc2} val := c.HandlerName() if val != "github.com/cloudwego/hertz/pkg/app.testFunc2" { t.Fatalf("unexpected %v. Expecting %v", val, "github.com/cloudwego/hertz.testFunc2") } } func TestNext(t *testing.T) { c := NewContext(0) a := 0 testFunc1 := func(c context.Context, ctx *RequestContext) { a = 1 } testFunc3 := func(c context.Context, ctx *RequestContext) { a = 3 } c.handlers = HandlersChain{testFunc1, testFunc3} c.Next(context.Background()) assert.True(t, c.index == 2) assert.DeepEqual(t, 3, a) } func TestContextError(t *testing.T) { c := NewContext(0) assert.Nil(t, c.Errors) firstErr := errors.New("first error") c.Error(firstErr) // nolint: errcheck assert.DeepEqual(t, 1, len(c.Errors)) assert.DeepEqual(t, "Error #01: first error\n", c.Errors.String()) secondErr := errors.New("second error") c.Error(&errs.Error{ // nolint: errcheck Err: secondErr, Meta: "some data 2", Type: errs.ErrorTypePublic, }) assert.DeepEqual(t, 2, len(c.Errors)) assert.DeepEqual(t, firstErr, c.Errors[0].Err) assert.Nil(t, c.Errors[0].Meta) assert.DeepEqual(t, errs.ErrorTypePrivate, c.Errors[0].Type) assert.DeepEqual(t, 
secondErr, c.Errors[1].Err) assert.DeepEqual(t, "some data 2", c.Errors[1].Meta) assert.DeepEqual(t, errs.ErrorTypePublic, c.Errors[1].Type) assert.DeepEqual(t, c.Errors.Last(), c.Errors[1]) defer func() { if recover() == nil { t.Error("didn't panic") } }() c.Error(nil) // nolint: errcheck } func TestContextAbortWithError(t *testing.T) { c := NewContext(0) c.AbortWithError(consts.StatusUnauthorized, errors.New("bad input")).SetMeta("some input") // nolint: errcheck assert.DeepEqual(t, consts.StatusUnauthorized, c.Response.StatusCode()) assert.DeepEqual(t, con.AbortIndex, c.index) assert.True(t, c.IsAborted()) } func TestRender(t *testing.T) { c := NewContext(0) c.Render(consts.StatusOK, &render.Data{ ContentType: consts.MIMEApplicationJSONUTF8, Data: []byte("{\"test\":1}"), }) assert.DeepEqual(t, consts.StatusOK, c.Response.StatusCode()) assert.True(t, strings.Contains(string(c.Response.Body()), "test")) c.Reset() c.Render(110, &render.Data{ ContentType: "application/json; charset=utf-8", Data: []byte("{\"test\":1}"), }) assert.DeepEqual(t, "application/json; charset=utf-8", string(c.Response.Header.ContentType())) assert.DeepEqual(t, "", string(c.Response.Body())) c.Reset() c.Render(consts.StatusNoContent, &render.Data{ ContentType: "application/json; charset=utf-8", Data: []byte("{\"test\":1}"), }) assert.DeepEqual(t, "application/json; charset=utf-8", string(c.Response.Header.ContentType())) assert.DeepEqual(t, "", string(c.Response.Body())) c.Reset() c.Render(consts.StatusNotModified, &render.Data{ ContentType: "application/json; charset=utf-8", Data: []byte("{\"test\":1}"), }) assert.DeepEqual(t, "application/json; charset=utf-8", string(c.Response.Header.ContentType())) assert.DeepEqual(t, "", string(c.Response.Body())) } func TestHTML(t *testing.T) { c := NewContext(0) tmpl := template.Must(template.New(""). Delims("{[{", "}]}"). Funcs(template.FuncMap{}). 
ParseFiles("../common/testdata/template/index.tmpl")) r := &render.HTMLProduction{Template: tmpl} c.HTMLRender = r c.HTML(consts.StatusOK, "index.tmpl", utils.H{"title": "Main website"}) assert.DeepEqual(t, []byte("text/html; charset=utf-8"), c.Response.Header.Peek("Content-Type")) assert.DeepEqual(t, []byte("

Main website

"), c.Response.Body()) } type xmlmap map[string]interface{} // Allows type H to be used with xml.Marshal func (h xmlmap) MarshalXML(e *xml.Encoder, start xml.StartElement) error { start.Name = xml.Name{ Space: "", Local: "map", } if err := e.EncodeToken(start); err != nil { return err } for key, value := range h { elem := xml.StartElement{ Name: xml.Name{Space: "", Local: key}, Attr: []xml.Attr{}, } if err := e.EncodeElement(value, elem); err != nil { return err } } return e.EncodeToken(xml.EndElement{Name: start.Name}) } func TestXML(t *testing.T) { c := NewContext(0) c.XML(consts.StatusOK, xmlmap{"foo": "bar"}) assert.DeepEqual(t, []byte("bar"), c.Response.Body()) assert.DeepEqual(t, []byte("application/xml; charset=utf-8"), c.Response.Header.Peek("Content-Type")) } func TestJSON(t *testing.T) { c := NewContext(0) c.JSON(consts.StatusOK, "test") assert.DeepEqual(t, consts.StatusOK, c.Response.StatusCode()) assert.True(t, strings.Contains(string(c.Response.Body()), "test")) } func TestDATA(t *testing.T) { c := NewContext(0) c.Data(consts.StatusOK, "application/json; charset=utf-8", []byte("{\"test\":1}")) assert.DeepEqual(t, consts.StatusOK, c.Response.StatusCode()) assert.True(t, strings.Contains(string(c.Response.Body()), "test")) } func TestContextReset(t *testing.T) { c := NewContext(0) c.index = 2 c.Params = param.Params{param.Param{}} c.Error(errors.New("test")) // nolint: errcheck c.Set("foo", "bar") c.Finished() c.Request.SetIsTLS(true) c.ResetWithoutConn() c.Request.URI() assert.DeepEqual(t, "https", string(c.Request.Scheme())) assert.False(t, c.IsAborted()) assert.DeepEqual(t, 0, len(c.Errors)) assert.Nil(t, c.Errors.Errors()) assert.Nil(t, c.Errors.ByType(errs.ErrorTypeAny)) assert.DeepEqual(t, 0, len(c.Params)) assert.DeepEqual(t, int8(-1), c.index) assert.Nil(t, c.finished) } func TestContextContentType(t *testing.T) { c := NewContext(0) c.Request.Header.Set("Content-Type", consts.MIMEApplicationJSONUTF8) assert.DeepEqual(t, 
consts.MIMEApplicationJSONUTF8, bytesconv.B2s(c.ContentType())) } type MockConn struct { *mock.Conn remote net.Addr } func (c *MockConn) RemoteAddr() net.Addr { return c.remote } func newContextClientIPTest() *RequestContext { c := NewContext(0) c.conn = &MockConn{ remote: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080}, } c.Request.Header.Set("X-Real-IP", " 10.10.10.10 ") c.Request.Header.Set("X-Forwarded-For", " 20.20.20.20, 30.30.30.30") return c } func TestClientIP(t *testing.T) { c := newContextClientIPTest() // default X-Forwarded-For and X-Real-IP behaviour assert.DeepEqual(t, "20.20.20.20", c.ClientIP()) c.Request.Header.DelBytes([]byte("X-Forwarded-For")) assert.DeepEqual(t, "10.10.10.10", c.ClientIP()) c.Request.Header.Set("X-Forwarded-For", "30.30.30.30 ") assert.DeepEqual(t, "30.30.30.30", c.ClientIP()) // No trusted CIDRS c = newContextClientIPTest() opts := ClientIPOptions{ RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, TrustedCIDRs: nil, } c.SetClientIPFunc(ClientIPWithOption(opts)) assert.DeepEqual(t, "127.0.0.1", c.ClientIP()) _, cidr, _ := net.ParseCIDR("30.30.30.30/32") opts = ClientIPOptions{ RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, TrustedCIDRs: []*net.IPNet{cidr}, } c.SetClientIPFunc(ClientIPWithOption(opts)) assert.DeepEqual(t, "127.0.0.1", c.ClientIP()) _, cidr, _ = net.ParseCIDR("127.0.0.1/32") opts = ClientIPOptions{ RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, TrustedCIDRs: []*net.IPNet{cidr}, } c.SetClientIPFunc(ClientIPWithOption(opts)) assert.DeepEqual(t, "30.30.30.30", c.ClientIP()) // UDS c.conn = &MockConn{remote: &net.UnixAddr{Net: "unix", Name: "/tmp/test.sock"}} assert.DeepEqual(t, "30.30.30.30", c.ClientIP()) // err: Addr not host:port c.conn = &MockConn{remote: &net.UnixAddr{Net: "tcp", Name: "/tmp/test.sock"}} assert.DeepEqual(t, "", c.ClientIP()) } func TestSetClientIPFunc(t *testing.T) { fn := func(ctx *RequestContext) string { return "" } SetClientIPFunc(fn) 
assert.DeepEqual(t, reflect.ValueOf(fn).Pointer(), reflect.ValueOf(defaultClientIP).Pointer()) } func TestGetQuery(t *testing.T) { c := NewContext(0) c.Request.SetRequestURI("http://aaa.com?a=1&b=") v, exists := c.GetQuery("b") assert.DeepEqual(t, "", v) assert.DeepEqual(t, true, exists) } func TestGetPostForm(t *testing.T) { c := NewContext(0) c.Request.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) c.Request.SetBodyString("a=1&b=") v, exists := c.GetPostForm("b") assert.DeepEqual(t, "", v) assert.DeepEqual(t, true, exists) } func TestGetPostFormArray(t *testing.T) { c := NewContext(0) c.Request.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) c.Request.SetBodyString("a=1&b=2&b=3") v, _ := c.GetPostFormArray("b") assert.DeepEqual(t, []string{"2", "3"}, v) } func TestRemoteAddr(t *testing.T) { c := NewContext(0) c.Request.SetRequestURI("http://aaa.com?a=1&b=") addr := c.RemoteAddr().String() assert.DeepEqual(t, "0.0.0.0:0", addr) } func TestRequestBodyStream(t *testing.T) { c := NewContext(0) s := "testRequestBodyStream" mr := bytes.NewBufferString(s) c.Request.SetBodyStream(mr, -1) data, err := ioutil.ReadAll(c.RequestBodyStream()) assert.Nil(t, err) assert.DeepEqual(t, "testRequestBodyStream", string(data)) } func TestContextIsAborted(t *testing.T) { ctx := NewContext(0) assert.False(t, ctx.IsAborted()) ctx.Abort() assert.True(t, ctx.IsAborted()) ctx.Next(context.Background()) assert.True(t, ctx.IsAborted()) ctx.index++ assert.True(t, ctx.IsAborted()) } func TestContextAbortWithStatus(t *testing.T) { c := NewContext(0) c.index = 4 c.AbortWithStatus(consts.StatusUnauthorized) assert.DeepEqual(t, con.AbortIndex, c.index) assert.DeepEqual(t, consts.StatusUnauthorized, c.Response.Header.StatusCode()) assert.True(t, c.IsAborted()) } type testJSONAbortMsg struct { Foo string `json:"foo"` Bar string `json:"bar"` } func TestContextAbortWithStatusJSON(t *testing.T) { c := NewContext(0) c.index = 4 in := new(testJSONAbortMsg) in.Bar = 
"barValue" in.Foo = "fooValue" c.AbortWithStatusJSON(consts.StatusUnsupportedMediaType, in) assert.DeepEqual(t, con.AbortIndex, c.index) assert.DeepEqual(t, consts.StatusUnsupportedMediaType, c.Response.Header.StatusCode()) assert.True(t, c.IsAborted()) contentType := c.Response.Header.Peek("Content-Type") assert.DeepEqual(t, consts.MIMEApplicationJSONUTF8, string(contentType)) jsonStringBody := c.Response.Body() assert.DeepEqual(t, "{\"foo\":\"fooValue\",\"bar\":\"barValue\"}", string(jsonStringBody)) } func TestRequestCtxFormValue(t *testing.T) { ctx := NewContext(0) ctx.Request.SetRequestURI("/foo/bar?baz=123&aaa=bbb") ctx.Request.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) ctx.Request.SetBodyString("qqq=port&mmm=sddd") v := ctx.FormValue("baz") if string(v) != "123" { t.Fatalf("unexpected value %q. Expecting %q", v, "123") } v = ctx.FormValue("mmm") if string(v) != "sddd" { t.Fatalf("unexpected value %q. Expecting %q", v, "sddd") } v = ctx.FormValue("aaaasdfsdf") if len(v) > 0 { t.Fatalf("unexpected value for unknown key %q", v) } ctx.Request.Reset() ctx.Request.SetFormData(map[string]string{ "a": "1", }) v = ctx.FormValue("a") if string(v) != "1" { t.Fatalf("unexpected value %q. Expecting %q", v, "1") } ctx.Request.Reset() s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="f" fff ------WebKitFormBoundaryJwfATyF8tmxSJnLg ` mr := bytes.NewBufferString(s) ctx.Request.SetBodyStream(mr, -1) ctx.Request.Header.SetContentLength(len(s)) ctx.Request.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) v = ctx.FormValue("f") if string(v) != "fff" { t.Fatalf("unexpected value %q. 
Expecting %q", v, "fff") } } func TestSetCustomFormValueFunc(t *testing.T) { ctx := NewContext(0) ctx.Request.SetRequestURI("/foo/bar?aaa=bbb") ctx.Request.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) ctx.Request.SetBodyString("aaa=port") ctx.SetFormValueFunc(func(ctx *RequestContext, key string) []byte { v := ctx.PostArgs().Peek(key) if len(v) > 0 { return v } mf, err := ctx.MultipartForm() if err == nil && mf.Value != nil { vv := mf.Value[key] if len(vv) > 0 { return []byte(vv[0]) } } v = ctx.QueryArgs().Peek(key) if len(v) > 0 { return v } return nil }) v := ctx.FormValue("aaa") if string(v) != "port" { t.Fatalf("unexpected value %q. Expecting %q", v, "port") } } func TestContextSetGet(t *testing.T) { c := &RequestContext{} c.Set("foo", "bar") value, err := c.Get("foo") assert.DeepEqual(t, "bar", value) assert.True(t, err) value, err = c.Get("foo2") assert.Nil(t, value) assert.False(t, err) assert.DeepEqual(t, "bar", c.MustGet("foo")) assert.Panic(t, func() { c.MustGet("no_exist") }) } func TestContextSetGetValues(t *testing.T) { c := &RequestContext{} c.Set("string", "this is a string") c.Set("int32", int32(-42)) c.Set("int64", int64(42424242424242)) c.Set("uint32", uint32(42)) c.Set("uint64", uint64(42424242424242)) c.Set("float32", float32(4.2)) c.Set("float64", 4.2) var a interface{} = 1 c.Set("intInterface", a) assert.DeepEqual(t, c.MustGet("string").(string), "this is a string") assert.DeepEqual(t, c.MustGet("int32").(int32), int32(-42)) assert.DeepEqual(t, c.MustGet("int64").(int64), int64(42424242424242)) assert.DeepEqual(t, c.MustGet("uint32").(uint32), uint32(42)) assert.DeepEqual(t, c.MustGet("uint64").(uint64), uint64(42424242424242)) assert.DeepEqual(t, c.MustGet("float32").(float32), float32(4.2)) assert.DeepEqual(t, c.MustGet("float64").(float64), 4.2) assert.DeepEqual(t, c.MustGet("intInterface").(int), 1) } func TestContextGetString(t *testing.T) { c := &RequestContext{} c.Set("string", "this is a string") 
assert.DeepEqual(t, "this is a string", c.GetString("string")) c.Set("bool", false) assert.DeepEqual(t, "", c.GetString("bool")) } func TestContextSetGetBool(t *testing.T) { c := &RequestContext{} c.Set("bool", true) assert.True(t, c.GetBool("bool")) c.Set("string", "this is a string") assert.False(t, c.GetBool("string")) } func TestContextGetInt(t *testing.T) { c := &RequestContext{} c.Set("int", 1) assert.DeepEqual(t, 1, c.GetInt("int")) c.Set("string", "this is a string") assert.DeepEqual(t, 0, c.GetInt("string")) } func TestContextGetInt32(t *testing.T) { c := &RequestContext{} c.Set("int32", int32(-42)) assert.DeepEqual(t, int32(-42), c.GetInt32("int32")) c.Set("string", "this is a string") assert.DeepEqual(t, int32(0), c.GetInt32("string")) } func TestContextGetInt64(t *testing.T) { c := &RequestContext{} c.Set("int64", int64(42424242424242)) assert.DeepEqual(t, int64(42424242424242), c.GetInt64("int64")) c.Set("string", "this is a string") assert.DeepEqual(t, int64(0), c.GetInt64("string")) } func TestContextGetUint(t *testing.T) { c := &RequestContext{} c.Set("uint", uint(1)) assert.DeepEqual(t, uint(1), c.GetUint("uint")) c.Set("string", "this is a string") assert.DeepEqual(t, uint(0), c.GetUint("string")) } func TestContextGetUint32(t *testing.T) { c := &RequestContext{} c.Set("uint32", uint32(42)) assert.DeepEqual(t, uint32(42), c.GetUint32("uint32")) c.Set("string", "this is a string") assert.DeepEqual(t, uint32(0), c.GetUint32("string")) } func TestContextGetUint64(t *testing.T) { c := &RequestContext{} c.Set("uint64", uint64(42424242424242)) assert.DeepEqual(t, uint64(42424242424242), c.GetUint64("uint64")) c.Set("string", "this is a string") assert.DeepEqual(t, uint64(0), c.GetUint64("string")) } func TestContextGetFloat32(t *testing.T) { c := &RequestContext{} c.Set("float32", float32(4.2)) assert.DeepEqual(t, float32(4.2), c.GetFloat32("float32")) c.Set("string", "this is a string") assert.DeepEqual(t, float32(0.0), c.GetFloat32("string")) } func 
TestContextGetFloat64(t *testing.T) { c := &RequestContext{} c.Set("float64", 4.2) assert.DeepEqual(t, 4.2, c.GetFloat64("float64")) c.Set("string", "this is a string") assert.DeepEqual(t, 0.0, c.GetFloat64("string")) } func TestContextGetTime(t *testing.T) { c := &RequestContext{} t1, _ := time.Parse("1/2/2006 15:04:05", "01/01/2017 12:00:00") c.Set("time", t1) assert.DeepEqual(t, t1, c.GetTime("time")) c.Set("string", "this is a string") assert.DeepEqual(t, time.Time{}, c.GetTime("string")) } func TestContextGetDuration(t *testing.T) { c := &RequestContext{} c.Set("duration", time.Second) assert.DeepEqual(t, time.Second, c.GetDuration("duration")) c.Set("string", "this is a string") assert.DeepEqual(t, time.Duration(0), c.GetDuration("string")) } func TestContextGetStringSlice(t *testing.T) { c := &RequestContext{} c.Set("slice", []string{"foo"}) assert.DeepEqual(t, []string{"foo"}, c.GetStringSlice("slice")) c.Set("string", "this is a string") var expected []string assert.DeepEqual(t, expected, c.GetStringSlice("string")) } func TestContextGetStringMap(t *testing.T) { c := &RequestContext{} m := make(map[string]interface{}) m["foo"] = 1 c.Set("map", m) assert.DeepEqual(t, m, c.GetStringMap("map")) assert.DeepEqual(t, 1, c.GetStringMap("map")["foo"]) c.Set("string", "this is a string") var expected map[string]interface{} assert.DeepEqual(t, expected, c.GetStringMap("string")) } func TestContextGetStringMapString(t *testing.T) { c := &RequestContext{} m := make(map[string]string) m["foo"] = "bar" c.Set("map", m) assert.DeepEqual(t, m, c.GetStringMapString("map")) assert.DeepEqual(t, "bar", c.GetStringMapString("map")["foo"]) c.Set("string", "this is a string") var expected map[string]string assert.DeepEqual(t, expected, c.GetStringMapString("string")) } func TestContextGetStringMapStringSlice(t *testing.T) { c := &RequestContext{} m := make(map[string][]string) m["foo"] = []string{"foo"} c.Set("map", m) assert.DeepEqual(t, m, c.GetStringMapStringSlice("map")) 
assert.DeepEqual(t, []string{"foo"}, c.GetStringMapStringSlice("map")["foo"]) c.Set("string", "this is a string") var expected map[string][]string assert.DeepEqual(t, expected, c.GetStringMapStringSlice("string")) } func TestContextTraceInfo(t *testing.T) { ctx := NewContext(0) traceIn := traceinfo.NewTraceInfo() ctx.SetTraceInfo(traceIn) traceOut := ctx.GetTraceInfo() assert.DeepEqual(t, traceIn, traceOut) } func TestEnableTrace(t *testing.T) { ctx := NewContext(0) ctx.SetEnableTrace(true) trace := ctx.IsEnableTrace() assert.True(t, trace) } func TestForEachKey(t *testing.T) { ctx := NewContext(0) ctx.Set("1", "2") handle := func(k string, v interface{}) { res := k + v.(string) assert.DeepEqual(t, res, "12") } ctx.ForEachKey(handle) val, ok := ctx.Get("1") assert.DeepEqual(t, val, "2") assert.True(t, ok) } func TestFlush(t *testing.T) { ctx := NewContext(0) err := ctx.Flush() assert.Nil(t, err) } func TestConn(t *testing.T) { ctx := NewContext(0) conn := mock.NewConn("") ctx.SetConn(conn) connRes := ctx.GetConn() val1 := reflect.ValueOf(conn).Pointer() val2 := reflect.ValueOf(connRes).Pointer() assert.DeepEqual(t, val1, val2) } func TestHijackHandler(t *testing.T) { ctx := NewContext(0) handle := func(c network.Conn) { c.SetReadTimeout(time.Duration(1) * time.Second) } ctx.SetHijackHandler(handle) handleRes := ctx.GetHijackHandler() val1 := reflect.ValueOf(handle).Pointer() val2 := reflect.ValueOf(handleRes).Pointer() assert.DeepEqual(t, val1, val2) } func TestGetReader(t *testing.T) { ctx := NewContext(0) conn := mock.NewConn("") ctx.SetConn(conn) connRes := ctx.GetReader() val1 := reflect.ValueOf(conn).Pointer() val2 := reflect.ValueOf(connRes).Pointer() assert.DeepEqual(t, val1, val2) } func TestGetWriter(t *testing.T) { ctx := NewContext(0) conn := mock.NewConn("") ctx.SetConn(conn) connRes := ctx.GetWriter() val1 := reflect.ValueOf(conn).Pointer() val2 := reflect.ValueOf(connRes).Pointer() assert.DeepEqual(t, val1, val2) } func TestIndex(t *testing.T) { ctx 
:= NewContext(0) ctx.ResetWithoutConn() exc := int8(-1) res := ctx.GetIndex() assert.DeepEqual(t, exc, res) ctx.SetIndex(int8(1)) res = ctx.GetIndex() exc = int8(1) assert.DeepEqual(t, exc, res) } func TestConcurrentHandlerName(t *testing.T) { SetConcurrentHandlerNameOperator() defer SetHandlerNameOperator(&inbuiltHandlerNameOperatorStruct{handlerNames: map[uintptr]string{}}) h := func(c context.Context, ctx *RequestContext) {} SetHandlerName(h, "test1") for i := 0; i < 50; i++ { go func() { name := GetHandlerName(h) assert.DeepEqual(t, "test1", name) }() } time.Sleep(50 * time.Millisecond) go func() { SetHandlerName(h, "test2") }() time.Sleep(50 * time.Millisecond) name := GetHandlerName(h) assert.DeepEqual(t, "test2", name) } func TestHandlerName(t *testing.T) { h := func(c context.Context, ctx *RequestContext) {} SetHandlerName(h, "test1") name := GetHandlerName(h) assert.DeepEqual(t, "test1", name) } func TestHijack(t *testing.T) { ctx := NewContext(0) h := func(c network.Conn) {} ctx.Hijack(h) assert.True(t, ctx.Hijacked()) } func TestFinished(t *testing.T) { ctx := NewContext(0) ctx.Finished() ch := make(chan struct{}) ctx.finished = ch chRes := ctx.Finished() send := func() { time.Sleep(time.Duration(1) * time.Millisecond) ch <- struct{}{} } go send() val := <-chRes assert.DeepEqual(t, struct{}{}, val) } func TestString(t *testing.T) { ctx := NewContext(0) ctx.String(consts.StatusOK, "ok") assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) } func TestFullPath(t *testing.T) { ctx := NewContext(0) str := "/hello" ctx.SetFullPath(str) val := ctx.FullPath() assert.DeepEqual(t, str, val) } func TestReset(t *testing.T) { ctx := NewContext(0) ctx.Reset() assert.DeepEqual(t, nil, ctx.conn) } // func TestParam(t *testing.T) { // ctx := NewContext(0) // val := ctx.Param("/user/john") // assert.DeepEqual(t, "john", val) // } func TestGetHeader(t *testing.T) { ctx := NewContext(0) ctx.Request.Header.SetContentTypeBytes([]byte(consts.MIMETextPlainUTF8)) val 
:= ctx.GetHeader("Content-Type")
	assert.DeepEqual(t, consts.MIMETextPlainUTF8, string(val))
}

// TestGetRawData checks that GetRawData and Body both expose the request body.
func TestGetRawData(t *testing.T) {
	ctx := NewContext(0)
	ctx.Request.SetBody([]byte("hello"))
	val := ctx.GetRawData()
	assert.DeepEqual(t, "hello", string(val))

	val2, err := ctx.Body()
	assert.DeepEqual(t, val, val2)
	assert.Nil(t, err)
}

// TestRequestContext_GetRequest checks that GetRequest exposes headers and body
// set on c.Request.
func TestRequestContext_GetRequest(t *testing.T) {
	c := &RequestContext{}
	c.Request.Header.Set("key1", "value1")
	c.Request.SetBody([]byte("test body"))

	req := c.GetRequest()
	if req.Header.Get("key1") != "value1" {
		t.Fatal("should have header: key1:value1")
	}
	if string(req.Body()) != "test body" {
		t.Fatal("should have body: test body")
	}
}

// TestRequestContext_GetResponse checks that GetResponse exposes headers and
// body set on c.Response.
func TestRequestContext_GetResponse(t *testing.T) {
	c := &RequestContext{}
	c.Response.Header.Set("key1", "value1")
	c.Response.SetBody([]byte("test body"))

	resp := c.GetResponse()
	if resp.Header.Get("key1") != "value1" {
		t.Fatal("should have header: key1:value1")
	}
	if string(resp.Body()) != "test body" {
		t.Fatal("should have body: test body")
	}
}

// TestBindAndValidate covers BindAndValidate as well as separate Bind and
// Validate calls against a `vd:"$>10"` rule on B.
func TestBindAndValidate(t *testing.T) {
	type Test struct {
		A string `query:"a"`
		B int    `query:"b" vd:"$>10"`
	}

	c := &RequestContext{}
	c.Request.SetRequestURI("/foo/bar?a=123&b=11")

	var req Test
	err := c.BindAndValidate(&req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	assert.DeepEqual(t, "123", req.A)
	assert.DeepEqual(t, 11, req.B)

	// b=9 violates the validation rule.
	c.Request.URI().Reset()
	c.Request.SetRequestURI("/foo/bar?a=123&b=9")
	req = Test{}
	err = c.BindAndValidate(&req)
	if err == nil {
		t.Fatalf("unexpected nil, expected an error")
	}

	// Bind alone must succeed; only Validate reports the violation.
	c.Request.URI().Reset()
	c.Request.SetRequestURI("/foo/bar?a=123&b=9")
	req = Test{}
	err = c.Bind(&req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	assert.DeepEqual(t, "123", req.A)
	assert.DeepEqual(t, 9, req.B)

	err = c.Validate(&req)
	if err == nil {
		t.Fatalf("unexpected nil, expected an error")
	}
}

// TestBindForm checks form binding from a urlencoded body and that an empty
// body is reported as an error.
func TestBindForm(t *testing.T) {
	type Test struct {
		A string
		B int
	}

	c := &RequestContext{}
	c.Request.SetRequestURI("/foo/bar?a=123&b=11")
	c.Request.SetBody([]byte("A=123&B=11"))
	c.Request.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded"))

	var req Test
	err := c.BindForm(&req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	assert.DeepEqual(t, "123", req.A)
	assert.DeepEqual(t, 11, req.B)

	c.Request.SetBody([]byte(""))
	err = c.BindForm(&req)
	if err == nil {
		t.Fatalf("expected error, but get nil")
	}
}

// TestSetBinder checks that a custom binder installed with SetBinder is used
// by every Bind*/Validate method.
func TestSetBinder(t *testing.T) {
	c := NewContext(0)
	c.SetBinder(binder.NewBinderWithValidateError(errors.New("test binder")))
	type T struct{}
	req := T{}
	err := c.Bind(&req)
	assert.Nil(t, err)
	err = c.Validate(&req)
	assert.NotNil(t, err)
	assert.DeepEqual(t, "test binder", err.Error())
	err = c.BindProtobuf(&req)
	assert.Nil(t, err)
	err = c.BindJSON(&req)
	assert.Nil(t, err)
	err = c.BindForm(&req)
	assert.NotNil(t, err)
	err = c.BindPath(&req)
	assert.Nil(t, err)
	err = c.BindQuery(&req)
	assert.Nil(t, err)
	err = c.BindHeader(&req)
	assert.Nil(t, err)
}

func TestRequestContext_SetCookie(t *testing.T) {
	c := NewContext(0)
	c.SetCookie("user", "hertz", 1, "/", "localhost", protocol.CookieSameSiteNoneMode, true, true)
	assert.DeepEqual(t, "user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=None", c.Response.Header.Get("Set-Cookie"))
}

func TestRequestContext_SetPartitionedCookie(t *testing.T) {
	c := NewContext(0)
	c.SetPartitionedCookie("user", "hertz", 1, "/", "localhost", protocol.CookieSameSiteNoneMode, true, true)
	assert.DeepEqual(t, "user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=None; Partitioned", c.Response.Header.Get("Set-Cookie"))
}

// TestRequestContext_SetCookiePathEmpty checks that an empty cookie path is
// rendered as "path=/".
func TestRequestContext_SetCookiePathEmpty(t *testing.T) {
	c := NewContext(0)
	c.SetCookie("user", "hertz", 1, "", "localhost", protocol.CookieSameSiteDisabled, true, true)
	assert.DeepEqual(t, "user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure", c.Response.Header.Get("Set-Cookie"))
}

// TestRequestContext_VisitAll exercises every VisitAll* helper on RequestContext.
func TestRequestContext_VisitAll(t *testing.T) {
	t.Run("VisitAllQueryArgs", func(t *testing.T) {
		c := NewContext(0)
		var s []string
		c.QueryArgs().Add("cloudwego", "hertz")
		c.QueryArgs().Add("hello", "world")
		c.VisitAllQueryArgs(func(key, value []byte) {
			s = append(s, string(key), string(value))
		})
		assert.DeepEqual(t, []string{"cloudwego", "hertz", "hello", "world"}, s)
	})

	t.Run("VisitAllPostArgs", func(t *testing.T) {
		c := NewContext(0)
		var s []string
		c.PostArgs().Add("cloudwego", "hertz")
		c.PostArgs().Add("hello", "world")
		c.VisitAllPostArgs(func(key, value []byte) {
			s = append(s, string(key), string(value))
		})
		assert.DeepEqual(t, []string{"cloudwego", "hertz", "hello", "world"}, s)
	})

	t.Run("VisitAllCookie", func(t *testing.T) {
		c := NewContext(0)
		var s []string
		c.Request.Header.Set("Cookie", "aaa=bbb;ccc=ddd")
		c.VisitAllCookie(func(key, value []byte) {
			s = append(s, string(key), string(value))
		})
		assert.DeepEqual(t, []string{"aaa", "bbb", "ccc", "ddd"}, s)
	})

	t.Run("VisitAllHeaders", func(t *testing.T) {
		c := NewContext(0)
		c.Request.Header.Set("xxx", "yyy")
		c.Request.Header.Set("xxx2", "yyy2")
		// Track which expected headers were visited, so the test also fails
		// when VisitAllHeaders silently skips one of them.
		seen := make(map[string]bool, 2)
		c.VisitAllHeaders(
			func(k, v []byte) {
				key := string(k)
				value := string(v)
				if key != "Xxx" && key != "Xxx2" {
					t.Fatalf("Unexpected %v. Expected %v", key, "Xxx or Xxx2")
				}
				if key == "Xxx" && value != "yyy" {
					t.Fatalf("Unexpected %v. Expected %v", value, "yyy")
				}
				if key == "Xxx2" && value != "yyy2" {
					t.Fatalf("Unexpected %v. Expected %v", value, "yyy2")
				}
				seen[key] = true
			})
		if !seen["Xxx"] || !seen["Xxx2"] {
			t.Fatalf("expected both Xxx and Xxx2 to be visited, got %v", seen)
		}
	})
}

// BenchmarkInbuiltHandlerNameOperator measures Set/GetHandlerName with the
// operator that is installed by default.
func BenchmarkInbuiltHandlerNameOperator(b *testing.B) {
	for n := 0; n < b.N; n++ {
		fn := func(c context.Context, ctx *RequestContext) {
		}
		SetHandlerName(fn, fmt.Sprintf("%d", n))
		GetHandlerName(fn)
	}
}

// BenchmarkConcurrentHandlerNameOperator measures Set/GetHandlerName with the
// concurrent operator, restoring an inbuilt operator afterwards.
func BenchmarkConcurrentHandlerNameOperator(b *testing.B) {
	SetConcurrentHandlerNameOperator()
	defer SetHandlerNameOperator(&inbuiltHandlerNameOperatorStruct{handlerNames: map[uintptr]string{}})
	for n := 0; n < b.N; n++ {
		fn := func(c context.Context, ctx *RequestContext) {
		}
		SetHandlerName(fn, fmt.Sprintf("%d", n))
		GetHandlerName(fn)
	}
}

================================================
FILE: pkg/app/fs.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package app import ( "bytes" "compress/gzip" "context" "fmt" "html" "io" "io/ioutil" "mime" "net/http" "os" "path/filepath" "sort" "strings" "sync" "time" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/internal/nocopy" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" "github.com/cloudwego/hertz/pkg/common/compress" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) var ( errDirIndexRequired = errors.NewPublic("directory index required") errNoCreatePermission = errors.NewPublic("no 'create file' permissions") rootFSOnce sync.Once rootFS = &FS{ Root: "/", GenerateIndexPages: true, Compress: true, AcceptByteRange: true, } rootFSHandler HandlerFunc strInvalidHost = []byte("invalid-host") ) // PathRewriteFunc must return new request path based on arbitrary ctx // info such as ctx.Path(). // // Path rewriter is used in FS for translating the current request // to the local filesystem path relative to FS.Root. // // The returned path must not contain '/../' substrings due to security reasons, // since such paths may refer files outside FS.Root. // // The returned path may refer to ctx members. For example, ctx.Path(). type PathRewriteFunc func(ctx *RequestContext) []byte // FS represents settings for request handler serving static files // from the local filesystem. // // It is prohibited copying FS values. Create new values instead. type FS struct { noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used // Path to the root directory to serve files from. Root string // List of index file names to try opening during directory access. // // For example: // // * index.html // * index.htm // * my-super-index.xml // // By default the list is empty. 
IndexNames []string // Index pages for directories without files matching IndexNames // are automatically generated if set. // // Directory index generation may be quite slow for directories // with many files (more than 1K), so it is discouraged enabling // index pages' generation for such directories. // // By default index pages aren't generated. GenerateIndexPages bool // Transparently compresses responses if set to true. // // The server tries minimizing CPU usage by caching compressed files. // It adds CompressedFileSuffix suffix to the original file name and // tries saving the resulting compressed file under the new file name. // So it is advisable to give the server write access to Root // and to all inner folders in order to minimize CPU usage when serving // compressed responses. // // Transparent compression is disabled by default. Compress bool // Enables byte range requests if set to true. // // Byte range requests are disabled by default. AcceptByteRange bool // Path rewriting function. // // By default request path is not modified. PathRewrite PathRewriteFunc // PathNotFound fires when file is not found in filesystem // this functions tries to replace "Cannot open requested path" // server response giving to the programmer the control of server flow. // // By default PathNotFound returns // "Cannot open requested path" PathNotFound HandlerFunc // Expiration duration for inactive file handlers. // // FSHandlerCacheDuration is used by default. CacheDuration time.Duration // Suffix to add to the name of cached compressed file. // // This value has sense only if Compress is set. // // FSCompressedFileSuffix is used by default. 
CompressedFileSuffix string once sync.Once h HandlerFunc } type byteRangeUpdater interface { UpdateByteRange(startPos, endPos int) error } type fsSmallFileReader struct { ff *fsFile startPos int endPos int } func (r *fsSmallFileReader) Close() error { ff := r.ff ff.decReadersCount() r.ff = nil r.startPos = 0 r.endPos = 0 ff.h.smallFileReaderPool.Put(r) return nil } func (r *fsSmallFileReader) UpdateByteRange(startPos, endPos int) error { r.startPos = startPos r.endPos = endPos + 1 return nil } func (r *fsSmallFileReader) Read(p []byte) (int, error) { tailLen := r.endPos - r.startPos if tailLen <= 0 { return 0, io.EOF } if len(p) > tailLen { p = p[:tailLen] } ff := r.ff if ff.f != nil { n, err := ff.f.ReadAt(p, int64(r.startPos)) r.startPos += n return n, err } n := copy(p, ff.dirIndex[r.startPos:]) r.startPos += n return n, nil } func (r *fsSmallFileReader) WriteTo(w io.Writer) (int64, error) { ff := r.ff var n int var err error if ff.f == nil { n, err = w.Write(ff.dirIndex[r.startPos:r.endPos]) return int64(n), err } if rf, ok := w.(io.ReaderFrom); ok { return rf.ReadFrom(r) } curPos := r.startPos bufv := utils.CopyBufPool.Get() buf := bufv.([]byte) for err == nil { tailLen := r.endPos - curPos if tailLen <= 0 { break } if len(buf) > tailLen { buf = buf[:tailLen] } n, err = ff.f.ReadAt(buf, int64(curPos)) nw, errw := w.Write(buf[:n]) curPos += nw if errw == nil && nw != n { panic("BUG: Write(p) returned (n, nil), where n != len(p)") } if err == nil { err = errw } } utils.CopyBufPool.Put(bufv) if err == io.EOF { err = nil } return int64(curPos - r.startPos), err } // ServeFile returns HTTP response containing compressed file contents // from the given path. // // HTTP response may contain uncompressed file contents in the following cases: // // - Missing 'Accept-Encoding: gzip' request header. // - No write access to directory containing the file. // // Directory contents is returned if path points to directory. 
//
// Use ServeFileUncompressed if you don't need to serve compressed file contents.
func ServeFile(ctx *RequestContext, path string) {
	rootFSOnce.Do(func() {
		rootFSHandler = rootFS.NewRequestHandler()
	})

	if len(path) == 0 || path[0] != '/' {
		// extend relative path to absolute path
		absPath, err := filepath.Abs(path)
		if err != nil {
			// filepath.Abs returns an empty string on failure, so log the
			// caller-supplied path rather than the (empty) result.
			hlog.SystemLogger().Errorf("Cannot resolve path=%q to absolute file error=%s", path, err)
			ctx.AbortWithMsg("Internal Server Error", consts.StatusInternalServerError)
			return
		}
		path = absPath
	}

	ctx.Request.SetRequestURI(path)
	rootFSHandler(context.Background(), ctx)
}

// NewRequestHandler returns new request handler with the given FS settings.
//
// The returned handler caches requested file handles
// for FS.CacheDuration.
// Make sure your program has enough 'max open files' limit aka
// 'ulimit -n' if FS.Root folder contains many files.
//
// Do not create multiple request handlers from a single FS instance -
// just reuse a single request handler.
func (fs *FS) NewRequestHandler() HandlerFunc {
	fs.once.Do(fs.initRequestHandler)
	return fs.h
}

// initRequestHandler resolves FS defaults into an fsHandler, starts the
// periodic cache cleaner and stores the handler in fs.h.
func (fs *FS) initRequestHandler() {
	root := fs.Root

	// serve files from the current working directory if root is empty
	if len(root) == 0 {
		root = "."
	}

	// strip trailing slashes from the root path
	for len(root) > 0 && root[len(root)-1] == '/' {
		root = root[:len(root)-1]
	}

	cacheDuration := fs.CacheDuration
	if cacheDuration <= 0 {
		cacheDuration = consts.FSHandlerCacheDuration
	}
	compressedFileSuffix := fs.CompressedFileSuffix
	if len(compressedFileSuffix) == 0 {
		compressedFileSuffix = consts.FSCompressedFileSuffix
	}

	h := &fsHandler{
		root:                 root,
		indexNames:           fs.IndexNames,
		pathRewrite:          fs.PathRewrite,
		generateIndexPages:   fs.GenerateIndexPages,
		compress:             fs.Compress,
		pathNotFound:         fs.PathNotFound,
		acceptByteRange:      fs.AcceptByteRange,
		cacheDuration:        cacheDuration,
		compressedFileSuffix: compressedFileSuffix,
		cache:                make(map[string]*fsFile),
		compressedCache:      make(map[string]*fsFile),
	}

	// NOTE(review): this goroutine loops forever; there is no way to stop it
	// once the handler has been created.
	go func() {
		var pendingFiles []*fsFile
		for {
			time.Sleep(cacheDuration / 2)
			pendingFiles = h.cleanCache(pendingFiles)
		}
	}()

	fs.h = h.handleRequest
}

// fsHandler holds the resolved FS settings plus the caches of opened files
// (plain and compressed), guarded by cacheLock.
type fsHandler struct {
	root                 string
	indexNames           []string
	pathRewrite          PathRewriteFunc
	pathNotFound         HandlerFunc
	generateIndexPages   bool
	compress             bool
	acceptByteRange      bool
	cacheDuration        time.Duration
	compressedFileSuffix string

	cache           map[string]*fsFile
	compressedCache map[string]*fsFile
	cacheLock       sync.Mutex

	smallFileReaderPool sync.Pool
}

// bigFileReader attempts to trigger sendfile
// for sending big files over the wire.
type bigFileReader struct {
	f  *os.File
	ff *fsFile
	r  io.Reader
	lr io.LimitedReader
}

// UpdateByteRange seeks to startPos and limits reads to the inclusive range
// [startPos, endPos].
func (r *bigFileReader) UpdateByteRange(startPos, endPos int) error {
	if _, err := r.f.Seek(int64(startPos), 0); err != nil {
		return err
	}
	r.r = &r.lr
	r.lr.R = r.f
	r.lr.N = int64(endPos - startPos + 1)
	return nil
}

func (r *bigFileReader) Read(p []byte) (int, error) {
	return r.r.Read(p)
}

func (r *bigFileReader) WriteTo(w io.Writer) (int64, error) {
	if rf, ok := w.(io.ReaderFrom); ok {
		// fast path. Sendfile must be triggered
		return rf.ReadFrom(r.r)
	}
	zw := network.NewWriter(w)
	// slow path
	return utils.CopyZeroAlloc(zw, r.r)
}

// Close rewinds the file and returns the reader to ff.bigFiles for reuse; if
// rewinding fails the file is closed instead. Either way the owning fsFile's
// reader count is decremented.
func (r *bigFileReader) Close() error {
	r.r = r.f
	n, err := r.f.Seek(0, 0)
	if err == nil {
		if n != 0 {
			panic("BUG: File.Seek(0,0) returned (non-zero, nil)")
		}

		ff := r.ff
		ff.bigFilesLock.Lock()
		ff.bigFiles = append(ff.bigFiles, r)
		ff.bigFilesLock.Unlock()
	} else {
		r.f.Close()
	}
	r.ff.decReadersCount()
	return err
}

// cleanCache evicts expired cache entries and releases files without active
// readers. Files still being read are returned so the next run can retry.
func (h *fsHandler) cleanCache(pendingFiles []*fsFile) []*fsFile {
	var filesToRelease []*fsFile

	h.cacheLock.Lock()

	// Close files which couldn't be closed before due to non-zero
	// readers count on the previous run.
	var remainingFiles []*fsFile
	for _, ff := range pendingFiles {
		if ff.readersCount > 0 {
			remainingFiles = append(remainingFiles, ff)
		} else {
			filesToRelease = append(filesToRelease, ff)
		}
	}
	pendingFiles = remainingFiles

	pendingFiles, filesToRelease = cleanCacheNolock(h.cache, pendingFiles, filesToRelease, h.cacheDuration)
	pendingFiles, filesToRelease = cleanCacheNolock(h.compressedCache, pendingFiles, filesToRelease, h.cacheDuration)

	h.cacheLock.Unlock()

	for _, ff := range filesToRelease {
		ff.Release()
	}

	return pendingFiles
}

// compressAndOpenFSFile opens filePath and, when it is worth compressing,
// produces (or reuses) its compressed sibling under a per-path file lock.
func (h *fsHandler) compressAndOpenFSFile(filePath string) (*fsFile, error) {
	f, err := os.Open(filePath)
	if err != nil {
		return nil, err
	}

	fileInfo, err := f.Stat()
	if err != nil {
		f.Close()
		return nil, fmt.Errorf("cannot obtain info for file %q: %s", filePath, err)
	}

	if fileInfo.IsDir() {
		f.Close()
		return nil, errDirIndexRequired
	}

	if strings.HasSuffix(filePath, h.compressedFileSuffix) ||
		fileInfo.Size() > consts.FsMaxCompressibleFileSize ||
		!isFileCompressible(f, consts.FsMinCompressRatio) {
		return h.newFSFile(f, fileInfo, false)
	}

	compressedFilePath := filePath + h.compressedFileSuffix
	absPath, err := filepath.Abs(compressedFilePath)
	if err != nil {
		f.Close()
		return nil, fmt.Errorf("cannot determine absolute path for %q: %s", compressedFilePath, err)
	}

	flock := getFileLock(absPath)
	flock.Lock()
	ff, err := h.compressFileNolock(f, fileInfo, filePath, compressedFilePath)
	flock.Unlock()

	return ff, err
}

// newCompressedFSFile opens an existing compressed file and wraps it in an fsFile.
func (h *fsHandler) newCompressedFSFile(filePath string) (*fsFile, error) {
	f, err := os.Open(filePath)
	if err != nil {
		return nil, fmt.Errorf("cannot open compressed file %q: %s", filePath, err)
	}
	fileInfo, err := f.Stat()
	if err != nil {
		f.Close()
		return nil, fmt.Errorf("cannot obtain info for compressed file %q: %s", filePath, err)
	}
	return h.newFSFile(f, fileInfo, true)
}

// compressFileNolock gzips f into compressedFilePath via a temporary file.
// The caller must hold the file lock for compressedFilePath. f is always closed.
func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePath, compressedFilePath string) (*fsFile, error) {
	// Attempt to open compressed file created by another concurrent
	// goroutine.
	// It is safe opening such a file, since the file creation
	// is guarded by file mutex - see getFileLock call.
	if _, err := os.Stat(compressedFilePath); err == nil {
		f.Close()
		return h.newCompressedFSFile(compressedFilePath)
	}

	// Create temporary file, so concurrent goroutines don't use
	// it until it is created.
	tmpFilePath := compressedFilePath + ".tmp"
	zf, err := os.Create(tmpFilePath)
	if err != nil {
		f.Close()
		if !os.IsPermission(err) {
			return nil, fmt.Errorf("cannot create temporary file %q: %s", tmpFilePath, err)
		}
		return nil, errNoCreatePermission
	}

	zw := compress.AcquireStacklessGzipWriter(zf, compress.CompressDefaultCompression)
	zrw := network.NewWriter(zw)
	_, err = utils.CopyZeroAlloc(zrw, f)
	if err1 := zw.Flush(); err == nil {
		err = err1
	}
	compress.ReleaseStacklessGzipWriter(zw, compress.CompressDefaultCompression)
	// A failed Close can mean the compressed data never fully reached the
	// file, so it must not be published as a valid archive.
	if err1 := zf.Close(); err == nil {
		err = err1
	}
	f.Close()

	if err != nil {
		// Best effort: don't leave a partially written temp file behind.
		os.Remove(tmpFilePath) //nolint:errcheck
		return nil, fmt.Errorf("error when compressing file %q to %q: %s", filePath, tmpFilePath, err)
	}
	if err = os.Chtimes(tmpFilePath, time.Now(), fileInfo.ModTime()); err != nil {
		os.Remove(tmpFilePath) //nolint:errcheck
		return nil, fmt.Errorf("cannot change modification time to %s for tmp file %q: %s",
			fileInfo.ModTime(), tmpFilePath, err)
	}
	if err = os.Rename(tmpFilePath, compressedFilePath); err != nil {
		os.Remove(tmpFilePath) //nolint:errcheck
		return nil, fmt.Errorf("cannot move compressed file from %q to %q: %s", tmpFilePath, compressedFilePath, err)
	}
	return h.newCompressedFSFile(compressedFilePath)
}

// openFSFile opens filePath (or its compressed sibling when mustCompress is
// set), recreating the compressed copy when it is missing or stale.
func (h *fsHandler) openFSFile(filePath string, mustCompress bool) (*fsFile, error) {
	filePathOriginal := filePath
	if mustCompress {
		filePath += h.compressedFileSuffix
	}

	f, err := os.Open(filePath)
	if err != nil {
		if mustCompress && os.IsNotExist(err) {
			return h.compressAndOpenFSFile(filePathOriginal)
		}
		return nil, err
	}

	fileInfo, err := f.Stat()
	if err != nil {
		f.Close()
		return nil, fmt.Errorf("cannot obtain info for file %q: %s", filePath, err)
	}

	if fileInfo.IsDir() {
		f.Close()
		if mustCompress {
			return nil, fmt.Errorf("directory with unexpected suffix found: %q. Suffix: %q",
				filePath, h.compressedFileSuffix)
		}
		return nil, errDirIndexRequired
	}

	if mustCompress {
		fileInfoOriginal, err := os.Stat(filePathOriginal)
		if err != nil {
			f.Close()
			return nil, fmt.Errorf("cannot obtain info for original file %q: %s", filePathOriginal, err)
		}

		// compressFileNolock copies the original's mtime onto the compressed
		// file, so any difference means the compressed copy is stale.
		if !fileInfoOriginal.ModTime().Equal(fileInfo.ModTime()) {
			// The compressed file became stale. Re-create it.
			f.Close()
			os.Remove(filePath)
			return h.compressAndOpenFSFile(filePathOriginal)
		}
	}

	return h.newFSFile(f, fileInfo, mustCompress)
}

// newFSFile wraps an opened file in an fsFile, detecting its content type.
// On error f is closed.
func (h *fsHandler) newFSFile(f *os.File, fileInfo os.FileInfo, compressed bool) (*fsFile, error) {
	n := fileInfo.Size()
	contentLength := int(n)
	if n != int64(contentLength) {
		f.Close()
		return nil, fmt.Errorf("too big file: %d bytes", n)
	}

	// detect content-type
	ext := fileExtension(fileInfo.Name(), compressed, h.compressedFileSuffix)
	contentType := mime.TypeByExtension(ext)
	if len(contentType) == 0 {
		data, err := readFileHeader(f, compressed)
		if err != nil {
			// The file is not handed to any fsFile, so close it here to
			// avoid leaking the descriptor.
			f.Close()
			return nil, fmt.Errorf("cannot read header of the file %q: %s", f.Name(), err)
		}
		contentType = http.DetectContentType(data)
	}

	lastModified := fileInfo.ModTime()
	ff := &fsFile{
		h:               h,
		f:               f,
		contentType:     contentType,
		contentLength:   contentLength,
		compressed:      compressed,
		lastModified:    lastModified,
		lastModifiedStr: bytesconv.AppendHTTPDate(make([]byte, 0, len(http.TimeFormat)), lastModified),

		t: time.Now(),
	}
	return ff, nil
}

// createDirIndex renders an HTML listing of dirPath (gzipped when
// mustCompress is set), hiding files that carry the compressed-file suffix.
func (h *fsHandler) createDirIndex(base *protocol.URI, dirPath string, mustCompress bool) (*fsFile, error) {
	w := &bytebufferpool.ByteBuffer{}

	basePathEscaped := html.EscapeString(string(base.Path()))
	fmt.Fprintf(w, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
	fmt.Fprintf(w, "<h1>%s</h1>", basePathEscaped)
	fmt.Fprintf(w, "<ul>")

	if len(basePathEscaped) > 1 {
		var parentURI protocol.URI
		base.CopyTo(&parentURI)
		parentURI.Update(string(base.Path()) + "/..")
		parentPathEscaped := html.EscapeString(string(parentURI.Path()))
		fmt.Fprintf(w, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
	}

	f, err := os.Open(dirPath)
	if err != nil {
		return nil, err
	}

	fileinfos, err := f.Readdir(0)
	f.Close()
	if err != nil {
		return nil, err
	}

	fm := make(map[string]os.FileInfo, len(fileinfos))
	filenames := make([]string, 0, len(fileinfos))
	for _, fi := range fileinfos {
		name := fi.Name()
		if strings.HasSuffix(name, h.compressedFileSuffix) {
			// Do not show compressed files on index page.
			continue
		}
		fm[name] = fi
		filenames = append(filenames, name)
	}

	var u protocol.URI
	base.CopyTo(&u)
	u.Update(string(u.Path()) + "/")

	sort.Strings(filenames)
	for _, name := range filenames {
		u.Update(name)
		pathEscaped := html.EscapeString(string(u.Path()))
		fi := fm[name]
		auxStr := "dir"
		className := "dir"
		if !fi.IsDir() {
			auxStr = fmt.Sprintf("file, %d bytes", fi.Size())
			className = "file"
		}
		fmt.Fprintf(w, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
			pathEscaped, className, html.EscapeString(name), auxStr, fsModTime(fi.ModTime()))
	}

	fmt.Fprintf(w, "</ul></body></html>")

	if mustCompress {
		var zbuf bytebufferpool.ByteBuffer
		zbuf.B = compress.AppendGzipBytesLevel(zbuf.B, w.B, compress.CompressDefaultCompression)
		w = &zbuf
	}

	dirIndex := w.B
	lastModified := time.Now()
	ff := &fsFile{
		h:               h,
		dirIndex:        dirIndex,
		contentType:     "text/html; charset=utf-8",
		contentLength:   len(dirIndex),
		compressed:      mustCompress,
		lastModified:    lastModified,
		lastModifiedStr: bytesconv.AppendHTTPDate(make([]byte, 0, len(http.TimeFormat)), lastModified),

		t: lastModified,
	}
	return ff, nil
}

// openIndexFile serves the first existing entry of h.indexNames inside
// dirPath, falling back to a generated listing when that is enabled.
func (h *fsHandler) openIndexFile(ctx *RequestContext, dirPath string, mustCompress bool) (*fsFile, error) {
	for _, indexName := range h.indexNames {
		indexFilePath := dirPath + "/" + indexName
		ff, err := h.openFSFile(indexFilePath, mustCompress)
		if err == nil {
			return ff, nil
		}
		if !os.IsNotExist(err) {
			return nil, fmt.Errorf("cannot open file %q: %s", indexFilePath, err)
		}
	}

	if !h.generateIndexPages {
		return nil, fmt.Errorf("cannot access directory without index page. Directory %q", dirPath)
	}

	return h.createDirIndex(ctx.URI(), dirPath, mustCompress)
}

// decReadersCount drops one reader reference under the handler's cache lock.
func (ff *fsFile) decReadersCount() {
	ff.h.cacheLock.Lock()
	defer ff.h.cacheLock.Unlock()
	ff.readersCount--
	if ff.readersCount < 0 {
		panic("BUG: negative fsFile.readersCount!")
	}
}

// bigFileReader reuses a pooled reader if one is available, otherwise it
// opens a fresh handle on the same file.
func (ff *fsFile) bigFileReader() (io.Reader, error) {
	if ff.f == nil {
		panic("BUG: ff.f must be non-nil in bigFileReader")
	}

	var r io.Reader

	ff.bigFilesLock.Lock()
	n := len(ff.bigFiles)
	if n > 0 {
		r = ff.bigFiles[n-1]
		ff.bigFiles = ff.bigFiles[:n-1]
	}
	ff.bigFilesLock.Unlock()

	if r != nil {
		return r, nil
	}

	f, err := os.Open(ff.f.Name())
	if err != nil {
		return nil, fmt.Errorf("cannot open already opened file: %s", err)
	}
	return &bigFileReader{
		f:  f,
		ff: ff,
		r:  f,
	}, nil
}

// NewReader returns a reader over the file's contents. On error the reader
// reference taken by the caller is released.
func (ff *fsFile) NewReader() (io.Reader, error) {
	if ff.isBig() {
		r, err := ff.bigFileReader()
		if err != nil {
			ff.decReadersCount()
		}
		return r, err
	}
	return ff.smallFileReader(), nil
}

func (ff *fsFile) smallFileReader() io.Reader {
	v := ff.h.smallFileReaderPool.Get()
	if v == nil {
		v = &fsSmallFileReader{}
	}
	r := v.(*fsSmallFileReader)
	r.ff = ff
	r.endPos = ff.contentLength
	if r.startPos > 0 {
		panic("BUG: fsSmallFileReader with non-nil startPos found in the pool")
	}
	return r
}

// handleRequest serves the file mapped from the request path, using the
// file caches, optional gzip compression and optional byte ranges.
func (h *fsHandler) handleRequest(c context.Context, ctx *RequestContext) {
	var path []byte
	if h.pathRewrite != nil {
		path = h.pathRewrite(ctx)
	} else {
		path = ctx.Path()
	}
	path = stripTrailingSlashes(path)

	if n := bytes.IndexByte(path, 0); n >= 0 {
		hlog.SystemLogger().Errorf("Cannot serve path with nil byte at position=%d, path=%q", n, path)
		ctx.AbortWithMsg("Are you a hacker?", consts.StatusBadRequest)
		return
	}
	if h.pathRewrite != nil {
		// There is no need to check for '/../' if path = ctx.Path(),
		// since ctx.Path must normalize and sanitize the path.

		if n := bytes.Index(path, bytestr.StrSlashDotDotSlash); n >= 0 {
			hlog.SystemLogger().Errorf("Cannot serve path with '/../' at position=%d due to security reasons, path=%q", n, path)
			ctx.AbortWithMsg("Internal Server Error", consts.StatusInternalServerError)
			return
		}
		// stripTrailingSlashes above turns "/../" into "/..", which the
		// check above cannot see; h.root + "/.." would leave the root.
		if bytes.HasSuffix(path, []byte("/..")) {
			hlog.SystemLogger().Errorf("Cannot serve path with trailing '/..' due to security reasons, path=%q", path)
			ctx.AbortWithMsg("Internal Server Error", consts.StatusInternalServerError)
			return
		}
	}

	mustCompress := false
	fileCache := h.cache
	byteRange := ctx.Request.Header.PeekRange()
	if len(byteRange) == 0 && h.compress && ctx.Request.Header.HasAcceptEncodingBytes(bytestr.StrGzip) {
		mustCompress = true
		fileCache = h.compressedCache
	}

	h.cacheLock.Lock()
	ff, ok := fileCache[string(path)]
	if ok {
		ff.readersCount++
	}
	h.cacheLock.Unlock()

	if !ok {
		pathStr := string(path)
		filePath := h.root + pathStr
		var err error
		ff, err = h.openFSFile(filePath, mustCompress)
		if mustCompress && err == errNoCreatePermission {
			hlog.SystemLogger().Errorf("Insufficient permissions for saving compressed file for path=%q. Serving uncompressed file. "+
				"Allow write access to the directory with this file in order to improve hertz performance", filePath)
			mustCompress = false
			ff, err = h.openFSFile(filePath, mustCompress)
		}
		if err == errDirIndexRequired {
			ff, err = h.openIndexFile(ctx, filePath, mustCompress)
			if err != nil {
				hlog.SystemLogger().Errorf("Cannot open dir index, path=%q, error=%s", filePath, err)
				ctx.AbortWithMsg("Directory index is forbidden", consts.StatusForbidden)
				return
			}
		} else if err != nil {
			hlog.SystemLogger().Errorf("Cannot open file=%q, error=%s", filePath, err)
			if h.pathNotFound == nil {
				ctx.AbortWithMsg("Cannot open requested path", consts.StatusNotFound)
			} else {
				ctx.SetStatusCode(consts.StatusNotFound)
				h.pathNotFound(c, ctx)
			}
			return
		}

		h.cacheLock.Lock()
		ff1, ok := fileCache[pathStr]
		if !ok {
			fileCache[pathStr] = ff
			ff.readersCount++
		} else {
			ff1.readersCount++
		}
		h.cacheLock.Unlock()

		if ok {
			// The file has been already opened by another
			// goroutine, so close the current file and use
			// the file opened by another goroutine instead.
			ff.Release()
			ff = ff1
		}
	}

	if !ctx.IfModifiedSince(ff.lastModified) {
		ff.decReadersCount()
		ctx.NotModified()
		return
	}

	r, err := ff.NewReader()
	if err != nil {
		hlog.SystemLogger().Errorf("Cannot obtain file reader for path=%q, error=%s", path, err)
		ctx.AbortWithMsg("Internal Server Error", consts.StatusInternalServerError)
		return
	}

	hdr := &ctx.Response.Header
	if ff.compressed {
		hdr.SetContentEncodingBytes(bytestr.StrGzip)
	}

	statusCode := consts.StatusOK
	contentLength := ff.contentLength
	if h.acceptByteRange {
		hdr.SetCanonical(bytestr.StrAcceptRanges, bytestr.StrBytes)
		if len(byteRange) > 0 {
			startPos, endPos, err := ParseByteRange(byteRange, contentLength)
			if err != nil {
				r.(io.Closer).Close()
				hlog.SystemLogger().Errorf("Cannot parse byte range %q for path=%q,error=%s", byteRange, path, err)
				ctx.AbortWithMsg("Range Not Satisfiable", consts.StatusRequestedRangeNotSatisfiable)
				return
			}

			if err = r.(byteRangeUpdater).UpdateByteRange(startPos, endPos); err != nil {
				r.(io.Closer).Close()
				hlog.SystemLogger().Errorf("Cannot seek byte range %q for path=%q, error=%s", byteRange, path, err)
				ctx.AbortWithMsg("Internal Server Error", consts.StatusInternalServerError)
				return
			}

			hdr.SetContentRange(startPos, endPos, contentLength)
			contentLength = endPos - startPos + 1
			statusCode = consts.StatusPartialContent
		}
	}

	hdr.SetCanonical(bytestr.StrLastModified, ff.lastModifiedStr)
	if !ctx.IsHead() {
		ctx.SetBodyStream(r, contentLength)
	} else {
		ctx.Response.ResetBody()
		ctx.Response.SkipBody = true
		ctx.Response.Header.SetContentLength(contentLength)
		if rc, ok := r.(io.Closer); ok {
			if err := rc.Close(); err != nil {
				hlog.SystemLogger().Errorf("Cannot close file reader: error=%s", err)
				ctx.AbortWithMsg("Internal Server Error", consts.StatusInternalServerError)
				return
			}
		}
	}
	hdr.SetNoDefaultContentType(true)
	if len(hdr.ContentType()) == 0 {
		ctx.SetContentType(ff.contentType)
	}
	ctx.SetStatusCode(statusCode)
}

// fsFile is a cached open file (or generated directory index) together with
// its response metadata and active reader count.
type fsFile struct {
	h             *fsHandler
	f             *os.File
	dirIndex      []byte
	contentType   string
	contentLength int
	compressed    bool

	lastModified    time.Time
	lastModifiedStr []byte

	t            time.Time
	readersCount int

	bigFiles     []*bigFileReader
	bigFilesLock sync.Mutex
}

// Release closes the underlying file and any pooled big-file readers.
func (ff *fsFile) Release() {
	if ff.f != nil {
		ff.f.Close()

		if ff.isBig() {
			ff.bigFilesLock.Lock()
			for _, r := range ff.bigFiles {
				r.f.Close()
			}
			ff.bigFilesLock.Unlock()
		}
	}
}

func (ff *fsFile) isBig() bool {
	return ff.contentLength > consts.MaxSmallFileSize && len(ff.dirIndex) == 0
}

// cleanCacheNolock removes entries older than cacheDuration from cache. The
// caller must hold the cache lock.
func cleanCacheNolock(cache map[string]*fsFile, pendingFiles, filesToRelease []*fsFile, cacheDuration time.Duration) ([]*fsFile, []*fsFile) {
	t := time.Now()
	for k, ff := range cache {
		if t.Sub(ff.t) > cacheDuration {
			if ff.readersCount > 0 {
				// There are pending readers on stale file handle,
				// so we cannot close it. Put it into pendingFiles
				// so it will be closed later.
				pendingFiles = append(pendingFiles, ff)
			} else {
				filesToRelease = append(filesToRelease, ff)
			}
			delete(cache, k)
		}
	}
	return pendingFiles, filesToRelease
}

func stripTrailingSlashes(path []byte) []byte {
	for len(path) > 0 && path[len(path)-1] == '/' {
		path = path[:len(path)-1]
	}
	return path
}

// isFileCompressible reports whether gzip shrinks the first 4KB of f below
// minCompressRatio of its size. f is rewound before returning.
func isFileCompressible(f *os.File, minCompressRatio float64) bool {
	// Try compressing the first 4kb of the file
	// and see if it can be compressed by more than
	// the given minCompressRatio.
	b := bytebufferpool.Get()
	// Return the buffer on every path, including the error path.
	defer bytebufferpool.Put(b)
	zw := compress.AcquireStacklessGzipWriter(b, compress.CompressDefaultCompression)
	lr := &io.LimitedReader{
		R: f,
		N: 4096,
	}
	zrw := network.NewWriter(zw)
	_, err := utils.CopyZeroAlloc(zrw, lr)
	compress.ReleaseStacklessGzipWriter(zw, compress.CompressDefaultCompression)
	f.Seek(0, 0) //nolint:errcheck
	if err != nil {
		return false
	}

	n := 4096 - lr.N
	zn := len(b.B)
	return float64(zn) < float64(n)*minCompressRatio
}

var (
	filesLockMap     = make(map[string]*sync.Mutex)
	filesLockMapLock sync.Mutex
)

// getFileLock returns the process-wide mutex associated with absPath.
// NOTE(review): entries are never removed from filesLockMap.
func getFileLock(absPath string) *sync.Mutex {
	filesLockMapLock.Lock()
	flock := filesLockMap[absPath]
	if flock == nil {
		flock = &sync.Mutex{}
		filesLockMap[absPath] = flock
	}
	filesLockMapLock.Unlock()
	return flock
}

// fileExtension returns the extension of path (including the dot), ignoring
// compressedFileSuffix when compressed is set.
func fileExtension(path string, compressed bool, compressedFileSuffix string) string {
	if compressed && strings.HasSuffix(path, compressedFileSuffix) {
		path = path[:len(path)-len(compressedFileSuffix)]
	}
	n := strings.LastIndexByte(path, '.')
	if n < 0 {
		return ""
	}
	return path[n:]
}

// readFileHeader returns up to the first 512 (decompressed) bytes of f and
// rewinds f.
func readFileHeader(f *os.File, compressed bool) ([]byte, error) {
	r := io.Reader(f)
	var zr *gzip.Reader
	if compressed {
		var err error
		if zr, err = compress.AcquireGzipReader(f); err != nil {
			return nil, err
		}
		r = zr
	}

	lr := &io.LimitedReader{
		R: r,
		N: 512,
	}
	data, err := ioutil.ReadAll(lr)
	// Return the pooled gzip reader before the Seek below can bail out.
	if zr != nil {
		compress.ReleaseGzipReader(zr)
	}
	if _, seekErr := f.Seek(0, 0); seekErr != nil {
		return nil, seekErr
	}
	return data, err
}

func fsModTime(t time.Time) time.Time {
	return t.In(time.UTC).Truncate(time.Second)
}

// ParseByteRange parses 'Range: bytes=...' header value.
//
// It follows https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 .
// The returned endPos is inclusive.
func ParseByteRange(byteRange []byte, contentLength int) (startPos, endPos int, err error) {
	b := byteRange
	if !bytes.HasPrefix(b, bytestr.StrBytes) {
		return 0, 0, fmt.Errorf("unsupported range units: %q. Expecting %q", byteRange, bytestr.StrBytes)
	}

	b = b[len(bytestr.StrBytes):]
	if len(b) == 0 || b[0] != '=' {
		return 0, 0, fmt.Errorf("missing byte range in %q", byteRange)
	}
	b = b[1:]

	n := bytes.IndexByte(b, '-')
	if n < 0 {
		return 0, 0, fmt.Errorf("missing the end position of byte range in %q", byteRange)
	}

	if n == 0 {
		// Suffix range "bytes=-N": the last N bytes of the content.
		suffixLen, err := bytesconv.ParseUint(b[n+1:])
		if err != nil {
			return 0, 0, err
		}
		// RFC 7233 section 2.1: a suffix range is only satisfiable when its
		// length is non-zero and the representation is not empty.
		if suffixLen == 0 || contentLength <= 0 {
			return 0, 0, fmt.Errorf("unsatisfiable suffix byte range %q for content length %d", byteRange, contentLength)
		}
		start := contentLength - suffixLen
		if start < 0 {
			start = 0
		}
		return start, contentLength - 1, nil
	}

	if startPos, err = bytesconv.ParseUint(b[:n]); err != nil {
		return 0, 0, err
	}
	if startPos >= contentLength {
		return 0, 0, fmt.Errorf("the start position of byte range cannot exceed %d. byte range %q", contentLength-1, byteRange)
	}

	b = b[n+1:]
	if len(b) == 0 {
		return startPos, contentLength - 1, nil
	}

	if endPos, err = bytesconv.ParseUint(b); err != nil {
		return 0, 0, err
	}
	if endPos >= contentLength {
		endPos = contentLength - 1
	}
	if endPos < startPos {
		return 0, 0, fmt.Errorf("the start position of byte range cannot exceed the end position. byte range %q", byteRange)
	}
	return startPos, endPos, nil
}

// NewVHostPathRewriter returns path rewriter, which strips slashesCount
// leading slashes from the path and prepends the path with request's host,
// thus simplifying virtual hosting for static files.
//
// Examples:
//
//   - host=foobar.com, slashesCount=0, original path="/foo/bar".
//     Resulting path: "/foobar.com/foo/bar"
//
//   - host=img.aaa.com, slashesCount=1, original path="/images/123/456.jpg"
//     Resulting path: "/img.aaa.com/123/456.jpg"
func NewVHostPathRewriter(slashesCount int) PathRewriteFunc {
	return func(ctx *RequestContext) []byte {
		path := stripLeadingSlashes(ctx.Path(), slashesCount)
		host := ctx.Host()
		// A host containing '/' could inject extra path segments.
		if bytes.IndexByte(host, '/') >= 0 {
			host = nil
		}
		if len(host) == 0 {
			host = strInvalidHost
		}
		b := bytebufferpool.Get()
		b.B = append(b.B, '/')
		b.B = append(b.B, host...)
		b.B = append(b.B, path...)
		ctx.URI().SetPathBytes(b.B)
		bytebufferpool.Put(b)

		return ctx.Path()
	}
}

// stripLeadingSlashes drops the first stripSlashes '/'-prefixed segments of
// path. path must start with '/'.
func stripLeadingSlashes(path []byte, stripSlashes int) []byte {
	for stripSlashes > 0 && len(path) > 0 {
		if path[0] != '/' {
			panic("BUG: path must start with slash")
		}
		n := bytes.IndexByte(path[1:], '/')
		if n < 0 {
			path = path[:0]
			break
		}
		path = path[n+1:]
		stripSlashes--
	}
	return path
}

// ServeFileUncompressed returns HTTP response containing file contents
// from the given path.
//
// Directory contents is returned if path points to directory.
//
// ServeFile may be used for saving network traffic when serving files
// with good compression ratio.
func ServeFileUncompressed(ctx *RequestContext, path string) {
	ctx.Request.Header.DelBytes(bytestr.StrAcceptEncoding)
	ServeFile(ctx, path)
}

// NewPathSlashesStripper returns path rewriter, which strips slashesCount
// leading slashes from the path.
//
// Examples:
//
//   - slashesCount = 0, original path: "/foo/bar", result: "/foo/bar"
//   - slashesCount = 1, original path: "/foo/bar", result: "/bar"
//   - slashesCount = 2, original path: "/foo/bar", result: ""
//
// The returned path rewriter may be used as FS.PathRewrite .
func NewPathSlashesStripper(slashesCount int) PathRewriteFunc {
	return func(ctx *RequestContext) []byte {
		return stripLeadingSlashes(ctx.Path(), slashesCount)
	}
}

================================================
FILE: pkg/app/fs_test.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package app import ( "bytes" "context" "fmt" "io" "io/ioutil" "math/rand" "os" "path" "testing" "time" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" ) func TestNewVHostPathRewriter(t *testing.T) { t.Parallel() var ctx RequestContext var req protocol.Request req.Header.SetHost("foobar.com") req.SetRequestURI("/foo/bar/baz") req.CopyTo(&ctx.Request) f := NewVHostPathRewriter(0) path := f(&ctx) expectedPath := "/foobar.com/foo/bar/baz" if string(path) != expectedPath { t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath) } ctx.Request.Reset() ctx.Request.SetRequestURI("https://aaa.bbb.cc/one/two/three/four?asdf=dsf") f = NewVHostPathRewriter(2) path = f(&ctx) expectedPath = "/aaa.bbb.cc/three/four" if string(path) != expectedPath { t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath) } } func TestNewVHostPathRewriterMaliciousHost(t *testing.T) { var ctx RequestContext var req protocol.Request req.Header.SetHost("/../../../etc/passwd") req.SetRequestURI("/foo/bar/baz") req.CopyTo(&ctx.Request) f := NewVHostPathRewriter(0) path := f(&ctx) expectedPath := "/invalid-host/foo/bar/baz" if string(path) != expectedPath { t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath) } } func testPathNotFound(t *testing.T, pathNotFoundFunc HandlerFunc) { var ctx RequestContext var req protocol.Request req.SetRequestURI("http//some.url/file") req.CopyTo(&ctx.Request) fs := &FS{ Root: "./", PathNotFound: pathNotFoundFunc, } fs.NewRequestHandler()(context.Background(), &ctx) if pathNotFoundFunc == nil { // different to ... if !bytes.Equal(ctx.Response.Body(), []byte("Cannot open requested path")) { t.Fatalf("response defers. Response: %q", ctx.Response.Body()) } } else { // Equals to ... 
if bytes.Equal(ctx.Response.Body(), []byte("Cannot open requested path")) { t.Fatalf("response defers. Response: %q", ctx.Response.Body()) } } } func TestPathNotFound(t *testing.T) { t.Parallel() testPathNotFound(t, nil) } func TestPathNotFoundFunc(t *testing.T) { t.Parallel() testPathNotFound(t, func(c context.Context, ctx *RequestContext) { ctx.WriteString("Not found hehe") //nolint:errcheck }) } func TestServeFileHead(t *testing.T) { t.Parallel() var ctx RequestContext var req protocol.Request req.Header.SetMethod(consts.MethodHead) req.SetRequestURI("http://foobar.com/baz") req.CopyTo(&ctx.Request) ServeFile(&ctx, "fs.go") var r protocol.Response r.SkipBody = true s := resp.GetHTTP1Response(&ctx.Response).String() zr := mock.NewZeroCopyReader(s) if err := resp.Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } ce := r.Header.ContentEncoding() if len(ce) > 0 { t.Fatalf("Unexpected 'Content-Encoding' %q", ce) } body := r.Body() if len(body) > 0 { t.Fatalf("unexpected response body %q. Expecting empty body", body) } expectedBody, err := getFileContents("/fs.go") if err != nil { t.Fatalf("unexpected error: %s", err) } contentLength := r.Header.ContentLength() if contentLength != len(expectedBody) { t.Fatalf("unexpected Content-Length: %d. expecting %d", contentLength, len(expectedBody)) } } func TestServeFileSmallNoReadFrom(t *testing.T) { t.Parallel() teststr := "hello, world!" 
tempdir, err := ioutil.TempDir("", "httpexpect") if err != nil { t.Fatal(err) } defer os.RemoveAll(tempdir) if err := ioutil.WriteFile( path.Join(tempdir, "hello"), []byte(teststr), 0o666); err != nil { t.Fatal(err) } var ctx RequestContext var req protocol.Request req.SetRequestURI("http://foobar.com/baz") req.CopyTo(&ctx.Request) ServeFile(&ctx, path.Join(tempdir, "hello")) reader, ok := ctx.Response.BodyStream().(*fsSmallFileReader) if !ok { t.Fatal("expected fsSmallFileReader") } buf := bytes.NewBuffer(nil) n, err := reader.WriteTo(pureWriter{buf}) if err != nil { t.Fatal(err) } if n != int64(len(teststr)) { t.Fatalf("expected %d bytes, got %d bytes", len(teststr), n) } body := buf.String() if body != teststr { t.Fatalf("expected '%s'", teststr) } data := make([]byte, len([]byte(teststr))) nn, err := reader.Read(data) assert.DeepEqual(t, len([]byte(teststr)), nn) assert.Nil(t, err) assert.DeepEqual(t, teststr, string(data)) assert.DeepEqual(t, reader.startPos, len([]byte(teststr))) nn, err = reader.Read(data) assert.DeepEqual(t, 0, nn) assert.DeepEqual(t, io.EOF, err) data1 := make([]byte, 2) reader.startPos = len([]byte(teststr)) - 1 nn, err = reader.Read(data1) assert.DeepEqual(t, []byte("!"), []byte{data1[0]}) assert.DeepEqual(t, 1, nn) assert.DeepEqual(t, nil, err) reader.startPos = 0 reader.ff.f = nil buf = bytes.NewBuffer(nil) reader.ff.dirIndex = make([]byte, len([]byte(teststr))) n, err = reader.WriteTo(pureWriter{buf}) assert.DeepEqual(t, int64(len(teststr)), n) assert.Nil(t, err) } type pureWriter struct { w io.Writer } func (pw pureWriter) Write(p []byte) (nn int, err error) { return pw.w.Write(p) } func TestServeFileCompressed(t *testing.T) { t.Parallel() var ctx RequestContext var req protocol.Request req.SetRequestURI("http://foobar.com/baz") req.Header.Set(consts.HeaderAcceptEncoding, "gzip") req.CopyTo(&ctx.Request) ServeFile(&ctx, "fs.go") var r protocol.Response s := resp.GetHTTP1Response(&ctx.Response).String() zr := mock.NewZeroCopyReader(s) 
if err := resp.Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } ce := r.Header.ContentEncoding() if string(ce) != "gzip" { t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "gzip") } body, err := r.BodyGunzip() if err != nil { t.Fatalf("unexpected error: %s", err) } expectedBody, err := getFileContents("/fs.go") if err != nil { t.Fatalf("unexpected error: %s", err) } if !bytes.Equal(body, expectedBody) { t.Fatalf("unexpected body %q. expecting %q", body, expectedBody) } } func TestServeFileUncompressed(t *testing.T) { t.Parallel() var ctx RequestContext var req protocol.Request req.SetRequestURI("http://foobar.com/baz") req.Header.Set(consts.HeaderAcceptEncoding, "gzip") req.CopyTo(&ctx.Request) ServeFileUncompressed(&ctx, "fs.go") var r protocol.Response s := resp.GetHTTP1Response(&ctx.Response).String() zr := mock.NewZeroCopyReader(s) if err := resp.Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } ce := r.Header.ContentEncoding() if len(ce) > 0 { t.Fatalf("Unexpected 'Content-Encoding' %q", ce) } body := r.Body() expectedBody, err := getFileContents("/fs.go") if err != nil { t.Fatalf("unexpected error: %s", err) } if !bytes.Equal(body, expectedBody) { t.Fatalf("unexpected body %q. 
expecting %q", body, expectedBody) } } func TestFSByteRangeConcurrent(t *testing.T) { t.Parallel() fs := &FS{ Root: ".", AcceptByteRange: true, } h := fs.NewRequestHandler() concurrency := 10 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { for j := 0; j < 5; j++ { testFSByteRange(t, h, "/fs.go") } ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-time.After(time.Second): t.Fatalf("timeout") case <-ch: } } } func TestFSByteRangeSingleThread(t *testing.T) { t.Parallel() fs := &FS{ Root: ".", AcceptByteRange: true, } h := fs.NewRequestHandler() testFSByteRange(t, h, "/fs.go") } func testFSByteRange(t *testing.T, h HandlerFunc, filePath string) { var ctx RequestContext req := &protocol.Request{} req.CopyTo(&ctx.Request) expectedBody, err := getFileContents(filePath) if err != nil { t.Fatalf("cannot read file %q: %s", filePath, err) } fileSize := len(expectedBody) startPos := rand.Intn(fileSize) endPos := rand.Intn(fileSize) if endPos < startPos { startPos, endPos = endPos, startPos } ctx.Request.SetRequestURI(filePath) ctx.Request.Header.SetByteRange(startPos, endPos) h(context.Background(), &ctx) var r protocol.Response s := resp.GetHTTP1Response(&ctx.Response).String() zr := mock.NewZeroCopyReader(s) if err := resp.Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s. filePath=%q", err, filePath) } if r.StatusCode() != consts.StatusPartialContent { t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", r.StatusCode(), consts.StatusPartialContent, filePath) } cr := r.Header.Peek(consts.HeaderContentRange) expectedCR := fmt.Sprintf("bytes %d-%d/%d", startPos, endPos, fileSize) if string(cr) != expectedCR { t.Fatalf("unexpected content-range %q. Expecting %q. filePath=%q", cr, expectedCR, filePath) } body := r.Body() bodySize := endPos - startPos + 1 if len(body) != bodySize { t.Fatalf("unexpected body size %d. Expecting %d. 
filePath=%q, startPos=%d, endPos=%d", len(body), bodySize, filePath, startPos, endPos) } expectedBody = expectedBody[startPos : endPos+1] if !bytes.Equal(body, expectedBody) { t.Fatalf("unexpected body %q. Expecting %q. filePath=%q, startPos=%d, endPos=%d", body, expectedBody, filePath, startPos, endPos) } } func getFileContents(path string) ([]byte, error) { path = "." + path f, err := os.Open(path) if err != nil { return nil, err } defer f.Close() return ioutil.ReadAll(f) } func TestParseByteRangeSuccess(t *testing.T) { t.Parallel() testParseByteRangeSuccess(t, "bytes=0-0", 1, 0, 0) testParseByteRangeSuccess(t, "bytes=1234-6789", 6790, 1234, 6789) testParseByteRangeSuccess(t, "bytes=123-", 456, 123, 455) testParseByteRangeSuccess(t, "bytes=-1", 1, 0, 0) testParseByteRangeSuccess(t, "bytes=-123", 456, 333, 455) // End position exceeding content-length. It should be updated to content-length-1. // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 testParseByteRangeSuccess(t, "bytes=1-2345", 234, 1, 233) testParseByteRangeSuccess(t, "bytes=0-2345", 2345, 0, 2344) // Start position overflow. Whole range must be returned. // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 testParseByteRangeSuccess(t, "bytes=-567", 56, 0, 55) } func testParseByteRangeSuccess(t *testing.T, v string, contentLength, startPos, endPos int) { startPos1, endPos1, err := ParseByteRange([]byte(v), contentLength) if err != nil { t.Fatalf("unexpected error: %s. v=%q, contentLength=%d", err, v, contentLength) } if startPos1 != startPos { t.Fatalf("unexpected startPos=%d. Expecting %d. v=%q, contentLength=%d", startPos1, startPos, v, contentLength) } if endPos1 != endPos { t.Fatalf("unexpected endPos=%d. Expectind %d. 
v=%q, contentLength=%d", endPos1, endPos, v, contentLength) } } func TestParseByteRangeError(t *testing.T) { t.Parallel() // invalid value testParseByteRangeError(t, "asdfasdfas", 1234) // invalid units testParseByteRangeError(t, "foobar=1-34", 600) // missing '-' testParseByteRangeError(t, "bytes=1234", 1235) // non-numeric range testParseByteRangeError(t, "bytes=foobar", 123) testParseByteRangeError(t, "bytes=1-foobar", 123) testParseByteRangeError(t, "bytes=df-344", 545) // multiple byte ranges testParseByteRangeError(t, "bytes=1-2,4-6", 123) // byte range exceeding contentLength testParseByteRangeError(t, "bytes=123-", 12) // startPos exceeding endPos testParseByteRangeError(t, "bytes=123-34", 1234) } func testParseByteRangeError(t *testing.T, v string, contentLength int) { _, _, err := ParseByteRange([]byte(v), contentLength) if err == nil { t.Fatalf("expecting error when parsing byte range %q", v) } } func TestFSCompressConcurrent(t *testing.T) { // This test can't run parallel as files in / might by changed by other tests. fs := &FS{ Root: ".", GenerateIndexPages: true, Compress: true, } h := fs.NewRequestHandler() concurrency := 4 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { for j := 0; j < 5; j++ { testFSCompress(t, h, "/fs.go") testFSCompress(t, h, "/") } ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout") } } } func TestFSCompressSingleThread(t *testing.T) { // This test can't run parallel as files in / might by changed by other tests. 
fs := &FS{ Root: ".", GenerateIndexPages: true, Compress: true, } h := fs.NewRequestHandler() testFSCompress(t, h, "/fs.go") testFSCompress(t, h, "/") } func testFSCompress(t *testing.T, h HandlerFunc, filePath string) { var ctx RequestContext req := &protocol.Request{} req.CopyTo(&ctx.Request) // request uncompressed file ctx.Request.Reset() ctx.Request.SetRequestURI(filePath) h(context.Background(), &ctx) var r protocol.Response s := resp.GetHTTP1Response(&ctx.Response).String() zr := mock.NewZeroCopyReader(s) if err := resp.Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s. filePath=%q", err, filePath) } if r.StatusCode() != consts.StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", r.StatusCode(), consts.StatusOK, filePath) } ce := r.Header.ContentEncoding() if string(ce) != "" { t.Fatalf("unexpected content-encoding %q. Expecting empty string. filePath=%q", ce, filePath) } body := string(r.Body()) // request compressed file ctx.Request.Reset() ctx.Request.SetRequestURI(filePath) ctx.Request.Header.Set(consts.HeaderAcceptEncoding, "gzip") h(context.Background(), &ctx) s = resp.GetHTTP1Response(&ctx.Response).String() zr = mock.NewZeroCopyReader(s) if err := resp.Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s. filePath=%q", err, filePath) } if r.StatusCode() != consts.StatusOK { t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", r.StatusCode(), consts.StatusOK, filePath) } ce = r.Header.ContentEncoding() if string(ce) != "gzip" { t.Fatalf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "gzip", filePath) } zbody, err := r.BodyGunzip() if err != nil { t.Fatalf("unexpected error when gunzipping response body: %s. filePath=%q", err, filePath) } if string(zbody) != body { t.Fatalf("unexpected body len=%d. Expected len=%d. 
FilePath=%q", len(zbody), len(body), filePath) } } func TestFileLock(t *testing.T) { t.Parallel() for i := 0; i < 10; i++ { filePath := fmt.Sprintf("foo/bar/%d.jpg", i) lock := getFileLock(filePath) lock.Lock() time.Sleep(time.Microsecond) lock.Unlock() // nolint:staticcheck } for i := 0; i < 10; i++ { filePath := fmt.Sprintf("foo/bar/%d.jpg", i) lock := getFileLock(filePath) lock.Lock() time.Sleep(time.Microsecond) lock.Unlock() // nolint:staticcheck } } func TestStripPathSlashes(t *testing.T) { t.Parallel() testStripPathSlashes(t, "", 0, "") testStripPathSlashes(t, "", 10, "") testStripPathSlashes(t, "/", 0, "") testStripPathSlashes(t, "/", 1, "") testStripPathSlashes(t, "/", 10, "") testStripPathSlashes(t, "/foo/bar/baz", 0, "/foo/bar/baz") testStripPathSlashes(t, "/foo/bar/baz", 1, "/bar/baz") testStripPathSlashes(t, "/foo/bar/baz", 2, "/baz") testStripPathSlashes(t, "/foo/bar/baz", 3, "") testStripPathSlashes(t, "/foo/bar/baz", 10, "") // trailing slash testStripPathSlashes(t, "/foo/bar/", 0, "/foo/bar") testStripPathSlashes(t, "/foo/bar/", 1, "/bar") testStripPathSlashes(t, "/foo/bar/", 2, "") testStripPathSlashes(t, "/foo/bar/", 3, "") } func testStripPathSlashes(t *testing.T, path string, stripSlashes int, expectedPath string) { s := stripLeadingSlashes([]byte(path), stripSlashes) s = stripTrailingSlashes(s) if string(s) != expectedPath { t.Fatalf("unexpected path after stripping %q with stripSlashes=%d: %q. 
Expecting %q", path, stripSlashes, s, expectedPath) } } func TestFileExtension(t *testing.T) { t.Parallel() testFileExtension(t, "foo.bar", false, "zzz", ".bar") testFileExtension(t, "foobar", false, "zzz", "") testFileExtension(t, "foo.bar.baz", false, "zzz", ".baz") testFileExtension(t, "", false, "zzz", "") testFileExtension(t, "/a/b/c.d/efg.jpg", false, ".zzz", ".jpg") testFileExtension(t, "foo.bar", true, ".zzz", ".bar") testFileExtension(t, "foobar.zzz", true, ".zzz", "") testFileExtension(t, "foo.bar.baz.hertz.gz", true, ".hertz.gz", ".baz") testFileExtension(t, "", true, ".zzz", "") testFileExtension(t, "/a/b/c.d/efg.jpg.xxx", true, ".xxx", ".jpg") } func testFileExtension(t *testing.T, path string, compressed bool, compressedFileSuffix, expectedExt string) { ext := fileExtension(path, compressed, compressedFileSuffix) if ext != expectedExt { t.Fatalf("unexpected file extension for file %q: %q. Expecting %q", path, ext, expectedExt) } } func TestServeFileContentType(t *testing.T) { t.Parallel() var ctx RequestContext var req protocol.Request req.Header.SetMethod(consts.MethodGet) req.SetRequestURI("http://foobar.com/baz") req.CopyTo(&ctx.Request) ServeFile(&ctx, "../common/testdata/test.png") var r protocol.Response s := resp.GetHTTP1Response(&ctx.Response).String() zr := mock.NewZeroCopyReader(s) if err := resp.Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } expected := []byte(consts.MIMEImagePNG) if !bytes.Equal(r.Header.ContentType(), expected) { t.Fatalf("Unexpected Content-Type, expected: %q got %q", expected, r.Header.ContentType()) } } func TestFileSmallUpdateByteRange(t *testing.T) { r := &fsSmallFileReader{} err := r.UpdateByteRange(1, 1) assert.Nil(t, err) assert.DeepEqual(t, 1, r.startPos) assert.DeepEqual(t, 2, r.endPos) } ================================================ FILE: pkg/app/middlewares/client/sd/discovery.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the 
Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sd import ( "context" "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/app/client/discovery" "github.com/cloudwego/hertz/pkg/app/client/loadbalance" "github.com/cloudwego/hertz/pkg/protocol" ) // Discovery will construct a middleware with BalancerFactory. func Discovery(resolver discovery.Resolver, opts ...ServiceDiscoveryOption) client.Middleware { options := &ServiceDiscoveryOptions{ Balancer: loadbalance.NewWeightedBalancer(), LbOpts: loadbalance.DefaultLbOpts, Resolver: resolver, } options.Apply(opts) lbConfig := loadbalance.Config{ Resolver: options.Resolver, Balancer: options.Balancer, LbOpts: options.LbOpts, } f := loadbalance.NewBalancerFactory(lbConfig) return func(next client.Endpoint) client.Endpoint { return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { if req.Options() != nil && req.Options().IsSD() { ins, err := f.GetInstance(ctx, req) if err != nil { return err } req.SetHost(ins.Address().String()) } return next(ctx, req, resp) } } } ================================================ FILE: pkg/app/middlewares/client/sd/discovery_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sd import ( "context" "testing" "github.com/cloudwego/hertz/pkg/app/client/discovery" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" ) func TestDiscovery(t *testing.T) { inss := []discovery.Instance{ discovery.NewInstance("tcp", "127.0.0.1:8888", 10, nil), discovery.NewInstance("tcp", "127.0.0.1:8889", 10, nil), } r := &discovery.SynthesizedResolver{ TargetFunc: func(ctx context.Context, target *discovery.TargetInfo) string { return target.Host }, ResolveFunc: func(ctx context.Context, key string) (discovery.Result, error) { return discovery.Result{CacheKey: "svc1", Instances: inss}, nil }, NameFunc: func() string { return t.Name() }, } mw := Discovery(r) checkMdw := func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) { t.Log(string(req.Host())) assert.Assert(t, string(req.Host()) == "127.0.0.1:8888" || string(req.Host()) == "127.0.0.1:8889") return nil } for i := 0; i < 10; i++ { req := &protocol.Request{} resp := &protocol.Response{} req.Options().Apply([]config.RequestOption{config.WithSD(true)}) req.SetRequestURI("http://service_name") _ = mw(checkMdw)(context.Background(), req, resp) } } ================================================ FILE: pkg/app/middlewares/client/sd/options.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sd import ( "context" "fmt" "net" "strings" "github.com/cloudwego/hertz/pkg/app/client/discovery" "github.com/cloudwego/hertz/pkg/app/client/loadbalance" "github.com/cloudwego/hertz/pkg/app/server/registry" ) // ServiceDiscoveryOptions service discovery option for client type ServiceDiscoveryOptions struct { // Resolver is used to client discovery Resolver discovery.Resolver // Balancer is used to client load balance Balancer loadbalance.Loadbalancer // LbOpts LoadBalance option LbOpts loadbalance.Options } func (o *ServiceDiscoveryOptions) Apply(opts []ServiceDiscoveryOption) { for _, op := range opts { op.F(o) } } type ServiceDiscoveryOption struct { F func(o *ServiceDiscoveryOptions) } // WithCustomizedAddrs specifies the target instance addresses when doing service discovery. 
// It overwrites the results from the Resolver func WithCustomizedAddrs(addrs ...string) ServiceDiscoveryOption { return ServiceDiscoveryOption{ F: func(o *ServiceDiscoveryOptions) { var ins []discovery.Instance for _, addr := range addrs { if _, err := net.ResolveTCPAddr("tcp", addr); err == nil { ins = append(ins, discovery.NewInstance("tcp", addr, registry.DefaultWeight, nil)) continue } if _, err := net.ResolveUnixAddr("unix", addr); err == nil { ins = append(ins, discovery.NewInstance("unix", addr, registry.DefaultWeight, nil)) continue } panic(fmt.Errorf("WithCustomizedAddrs: invalid '%s'", addr)) } if len(ins) == 0 { panic("WithCustomizedAddrs() requires at least one argument") } targets := strings.Join(addrs, ",") o.Resolver = &discovery.SynthesizedResolver{ ResolveFunc: func(ctx context.Context, key string) (discovery.Result, error) { return discovery.Result{ CacheKey: "fixed", Instances: ins, }, nil }, NameFunc: func() string { return targets }, TargetFunc: func(ctx context.Context, target *discovery.TargetInfo) string { return targets }, } }, } } // WithLoadBalanceOptions sets Loadbalancer and loadbalance options for hertz client func WithLoadBalanceOptions(lb loadbalance.Loadbalancer, options loadbalance.Options) ServiceDiscoveryOption { return ServiceDiscoveryOption{F: func(o *ServiceDiscoveryOptions) { o.LbOpts = options o.Balancer = lb }} } ================================================ FILE: pkg/app/middlewares/client/sd/options_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sd import ( "context" "testing" "github.com/cloudwego/hertz/pkg/app/client/loadbalance" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestWithCustomizedAddrs(t *testing.T) { var options []ServiceDiscoveryOption options = append(options, WithCustomizedAddrs("127.0.0.1:8080", "/tmp/unix_ss")) opts := &ServiceDiscoveryOptions{} opts.Apply(options) assert.Assert(t, opts.Resolver.Name() == "127.0.0.1:8080,/tmp/unix_ss") res, err := opts.Resolver.Resolve(context.Background(), "") assert.Assert(t, err == nil) assert.Assert(t, res.Instances[0].Address().String() == "127.0.0.1:8080") assert.Assert(t, res.Instances[1].Address().String() == "/tmp/unix_ss") } func TestWithLoadBalanceOptions(t *testing.T) { balance := loadbalance.NewWeightedBalancer() var options []ServiceDiscoveryOption options = append(options, WithLoadBalanceOptions(balance, loadbalance.DefaultLbOpts)) opts := &ServiceDiscoveryOptions{} opts.Apply(options) assert.Assert(t, opts.Balancer.Name() == "weight_random") } ================================================ FILE: pkg/app/middlewares/server/basic_auth/basic_auth.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */ package basic_auth import ( "context" "encoding/base64" "net/http" "strconv" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/app" ) // Accounts is an alias to map[string]string, construct with {"username":"password"} type Accounts map[string]string // pairs is an alias to map[string]string, which mean {"header":"username"} type pairs map[string]string func (p pairs) findValue(needle string) (v string, ok bool) { v, ok = p[needle] return } func constructPairs(accounts Accounts) pairs { length := len(accounts) p := make(pairs, length) for user, password := range accounts { value := "Basic " + base64.StdEncoding.EncodeToString(bytesconv.S2b(user+":"+password)) p[value] = user } return p } // BasicAuthForRealm returns a Basic HTTP Authorization middleware. It takes as arguments a map[string]string where // the key is the username and the value is the password, as well as the name of the Realm. // If the realm is empty, "Authorization Required" will be used by default. // (see http://tools.ietf.org/html/rfc2617#section-1.2) func BasicAuthForRealm(accounts Accounts, realm, userKey string) app.HandlerFunc { realm = "Basic realm=" + strconv.Quote(realm) p := constructPairs(accounts) return func(ctx context.Context, c *app.RequestContext) { // Search user in the slice of allowed credentials user, found := p.findValue(c.Request.Header.Get("Authorization")) if !found { // Credentials doesn't match, we return 401 and abort handlers chain. c.Header("WWW-Authenticate", realm) c.AbortWithStatus(http.StatusUnauthorized) return } // The user credentials was found, set user's id to key AuthUserKey in this context, the user's id can be read later using c.Set(userKey, user) } } // BasicAuth is a constructor of BasicAuth verifier to hertz middleware // It returns a Basic HTTP Authorization middleware. 
It takes as argument a map[string]string where // the key is the username and the value is the password. func BasicAuth(accounts Accounts) app.HandlerFunc { return BasicAuthForRealm(accounts, "Authorization Required", "user") } ================================================ FILE: pkg/app/middlewares/server/basic_auth/basic_auth_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */ package basic_auth import ( "context" "encoding/base64" "testing" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestPairs(t *testing.T) { t1 := Accounts{"test1": "value1"} t2 := Accounts{"test2": "value2"} p1 := constructPairs(t1) p2 := constructPairs(t2) u1, ok1 := p1.findValue("Basic dGVzdDE6dmFsdWUx") u2, ok2 := p2.findValue("Basic dGVzdDI6dmFsdWUy") _, ok3 := p1.findValue("bad header") _, ok4 := p2.findValue("bad header") assert.True(t, ok1) assert.DeepEqual(t, "test1", u1) assert.True(t, ok2) assert.DeepEqual(t, "test2", u2) assert.False(t, ok3) assert.False(t, ok4) } func TestBasicAuth(t *testing.T) { userName1 := "user1" password1 := "value1" userName2 := "user2" password2 := "value2" c1 := app.RequestContext{} encodeStr := "Basic " + base64.StdEncoding.EncodeToString(bytesconv.S2b(userName1+":"+password1)) c1.Request.Header.Add("Authorization", encodeStr) t1 := Accounts{userName1: password1} handler := BasicAuth(t1) handler(context.TODO(), &c1) user, ok := c1.Get("user") assert.DeepEqual(t, userName1, user) assert.True(t, ok) c2 := app.RequestContext{} encodeStr = "Basic " + base64.StdEncoding.EncodeToString(bytesconv.S2b(userName2+":"+password2)) c2.Request.Header.Add("Authorization", encodeStr) handler(context.TODO(), &c2) user, ok = c2.Get("user") assert.Nil(t, user) assert.False(t, ok) } ================================================ FILE: pkg/app/middlewares/server/basic_auth/doc.go ================================================ /* * Copyright 2022 CloudWeGo Authors 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // The files in basic_auth package are forked from gin[github.com/gin-gonic/gin], // and we keep the original Copyright[Copyright 2014 gin authors] and License of gin for those files. // We have modified them as needed; the modifications are Copyright 2022 CloudWeGo Authors. // Thanks to the gin authors! Below is the source code information: // Repo: github.com/gin-gonic/gin // Forked Version: v1.7.7 package basic_auth ================================================ FILE: pkg/app/middlewares/server/recovery/option.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
*/ package recovery import ( "context" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/protocol/consts" ) type ( options struct { recoveryHandler func(c context.Context, ctx *app.RequestContext, err interface{}, stack []byte) } Option func(o *options) ) func defaultRecoveryHandler(c context.Context, ctx *app.RequestContext, err interface{}, stack []byte) { hlog.SystemLogger().CtxErrorf(c, "[Recovery] err=%v\nstack=%s", err, stack) ctx.AbortWithStatus(consts.StatusInternalServerError) } func newOptions(opts ...Option) *options { cfg := &options{ recoveryHandler: defaultRecoveryHandler, } for _, opt := range opts { opt(cfg) } return cfg } func WithRecoveryHandler(f func(c context.Context, ctx *app.RequestContext, err interface{}, stack []byte)) Option { return func(o *options) { o.recoveryHandler = f } } ================================================ FILE: pkg/app/middlewares/server/recovery/option_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package recovery import ( "context" "fmt" "testing" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func TestDefaultOption(t *testing.T) { opts := newOptions() assert.DeepEqual(t, fmt.Sprintf("%p", defaultRecoveryHandler), fmt.Sprintf("%p", opts.recoveryHandler)) } func newRecoveryHandler(c context.Context, ctx *app.RequestContext, err interface{}, stack []byte) { hlog.SystemLogger().CtxErrorf(c, "[New Recovery] panic recovered:\n%s\n%s\n", err, stack) ctx.JSON(consts.StatusNotImplemented, utils.H{"msg": err.(string)}) } func TestOption(t *testing.T) { opts := newOptions(WithRecoveryHandler(newRecoveryHandler)) assert.DeepEqual(t, fmt.Sprintf("%p", newRecoveryHandler), fmt.Sprintf("%p", opts.recoveryHandler)) } ================================================ FILE: pkg/app/middlewares/server/recovery/recovery.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package recovery import ( "bytes" "context" "fmt" "io/ioutil" "runtime" "github.com/cloudwego/hertz/pkg/app" ) var ( dunno = []byte("???") centerDot = []byte("·") dot = []byte(".") slash = []byte("/") ) // Recovery returns a middleware that recovers from any panic. 
// By default, it will print the time, content, and stack information of the error and write a 500. // Overriding the Config configuration, you can customize the error printing logic. func Recovery(opts ...Option) app.HandlerFunc { cfg := newOptions(opts...) return func(c context.Context, ctx *app.RequestContext) { defer func() { if err := recover(); err != nil { stack := stack(3) cfg.recoveryHandler(c, ctx, err, stack) } }() ctx.Next(c) } } // stack returns a nicely formatted stack frame, skipping skip frames. func stack(skip int) []byte { buf := new(bytes.Buffer) // the returned data // As we loop, we open files and read them. These variables record the currently // loaded file. var lines [][]byte var lastFile string for i := skip; ; i++ { // Skip the expected number of frames pc, file, line, ok := runtime.Caller(i) if !ok { break } // Print this much at least. If we can't find the source, it won't show. fmt.Fprintf(buf, "%s:%d (0x%x)\n", file, line, pc) if file != lastFile { data, err := ioutil.ReadFile(file) if err != nil { continue } lines = bytes.Split(data, []byte{'\n'}) lastFile = file } fmt.Fprintf(buf, "\t%s: %s\n", function(pc), source(lines, line)) } return buf.Bytes() } // source returns a space-trimmed slice of the n'th line. func source(lines [][]byte, n int) []byte { n-- // in stack trace, lines are 1-indexed but our array is 0-indexed if n < 0 || n >= len(lines) { return dunno } return bytes.TrimSpace(lines[n]) } // function returns, if possible, the name of the function containing the PC. func function(pc uintptr) []byte { fn := runtime.FuncForPC(pc) if fn == nil { return dunno } name := []byte(fn.Name()) // The name includes the path name to the package, which is unnecessary // since the file name is already included. Plus, it has center dots. // That is, we see // runtime/debug.*T·ptrmethod // and want // *T.ptrmethod // Also the package path might contains dot (e.g. 
code.google.com/...), // so first eliminate the path prefix if lastSlash := bytes.LastIndex(name, slash); lastSlash >= 0 { name = name[lastSlash+1:] } if period := bytes.Index(name, dot); period >= 0 { name = name[period+1:] } name = bytes.Replace(name, centerDot, dot, -1) return name } ================================================ FILE: pkg/app/middlewares/server/recovery/recovery_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package recovery import ( "context" "fmt" "testing" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func TestRecovery(t *testing.T) { ctx := app.NewContext(0) var hc app.HandlersChain hc = append(hc, func(c context.Context, ctx *app.RequestContext) { fmt.Println("this is test") panic("test") }) ctx.SetHandlers(hc) Recovery()(context.Background(), ctx) if ctx.Response.StatusCode() != 500 { t.Fatalf("unexpected %v. Expecting %v", ctx.Response.StatusCode(), 500) } } func TestWithRecoveryHandler(t *testing.T) { ctx := app.NewContext(0) var hc app.HandlersChain hc = append(hc, func(c context.Context, ctx *app.RequestContext) { fmt.Println("this is test") panic("test") }) ctx.SetHandlers(hc) Recovery(WithRecoveryHandler(newRecoveryHandler))(context.Background(), ctx) if ctx.Response.StatusCode() != consts.StatusNotImplemented { t.Fatalf("unexpected %v. 
Expecting %v", ctx.Response.StatusCode(), 501) } assert.DeepEqual(t, "{\"msg\":\"test\"}", string(ctx.Response.Body())) } ================================================ FILE: pkg/app/server/binding/binder.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package binding import ( "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) type Binder interface { Name() string Bind(*protocol.Request, interface{}, param.Params) error BindQuery(*protocol.Request, interface{}) error BindHeader(*protocol.Request, interface{}) error BindPath(*protocol.Request, interface{}, param.Params) error BindForm(*protocol.Request, interface{}) error BindJSON(*protocol.Request, interface{}) error BindProtobuf(*protocol.Request, interface{}) error Validate(*protocol.Request, interface{}) error } ================================================ FILE: pkg/app/server/binding/binder_test.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package binding import ( "encoding" "encoding/json" "errors" "fmt" "mime/multipart" "net/url" "reflect" "testing" "time" "github.com/cloudwego/hertz/pkg/app/server/binding/testdata" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" req2 "github.com/cloudwego/hertz/pkg/protocol/http1/req" "github.com/cloudwego/hertz/pkg/route/param" "google.golang.org/protobuf/proto" ) type mockRequest struct { Req *protocol.Request } func newMockRequest() *mockRequest { return &mockRequest{ Req: &protocol.Request{}, } } func (m *mockRequest) SetRequestURI(uri string) *mockRequest { m.Req.SetRequestURI(uri) return m } func (m *mockRequest) SetFile(param, fileName string) *mockRequest { m.Req.SetFile(param, fileName) return m } func (m *mockRequest) SetHeader(key, value string) *mockRequest { m.Req.Header.Set(key, value) return m } func (m *mockRequest) SetHeaders(key, value string) *mockRequest { m.Req.Header.Set(key, value) return m } func (m *mockRequest) SetPostArg(key, value string) *mockRequest { m.Req.PostArgs().Add(key, value) return m } func (m *mockRequest) SetUrlEncodeContentType() *mockRequest { m.Req.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) return m } func (m *mockRequest) SetJSONContentType() *mockRequest { m.Req.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationJSON)) return m } func (m *mockRequest) SetProtobufContentType() *mockRequest { m.Req.Header.SetContentTypeBytes([]byte(consts.MIMEPROTOBUF)) return m } func (m *mockRequest) SetBody(data []byte) *mockRequest { m.Req.SetBody(data) m.Req.Header.SetContentLength(len(data)) return m } func TestBind_BaseType(t *testing.T) { type Req struct { Version int `path:"v"` ID int `query:"id"` Header string `header:"H"` Form string `form:"f"` } req := newMockRequest(). SetRequestURI("http://foobar.com?id=12"). 
SetHeaders("H", "header"). SetPostArg("f", "form"). SetUrlEncodeContentType() var params param.Params params = append(params, param.Param{ Key: "v", Value: "1", }) var result Req err := DefaultBinder().Bind(req.Req, &result, params) if err != nil { t.Error(err) } assert.DeepEqual(t, 1, result.Version) assert.DeepEqual(t, 12, result.ID) assert.DeepEqual(t, "header", result.Header) assert.DeepEqual(t, "form", result.Form) } func TestBind_SliceType(t *testing.T) { type Req struct { ID *[]int `query:"id"` Str [3]string `query:"str"` Byte []byte `query:"b"` HH []string `header:"h"` } IDs := []int{11, 12, 13} Strs := [3]string{"qwe", "asd", "zxc"} Bytes := []byte("123") Headers := []string{"header"} req := newMockRequest(). SetHeaders("H", Headers[0]). SetRequestURI(fmt.Sprintf("http://foobar.com?id=%d&id=%d&id=%d&str=%s&str=%s&str=%s&b=%d&b=%d&b=%d", IDs[0], IDs[1], IDs[2], Strs[0], Strs[1], Strs[2], Bytes[0], Bytes[1], Bytes[2])) var result Req err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 3, len(*result.ID)) for idx, val := range IDs { assert.DeepEqual(t, val, (*result.ID)[idx]) } assert.DeepEqual(t, 3, len(result.Str)) for idx, val := range Strs { assert.DeepEqual(t, val, result.Str[idx]) } assert.DeepEqual(t, 3, len(result.Byte)) for idx, val := range Bytes { assert.DeepEqual(t, val, result.Byte[idx]) } assert.DeepEqual(t, Headers, result.HH) } func TestBind_StructType(t *testing.T) { type FFF struct { F1 string `query:"F1"` } type TTT struct { T1 string `query:"F1"` T2 FFF } type Foo struct { F1 string `query:"F1"` F2 string `header:"f2"` F3 TTT } type Bar struct { B1 string `query:"B1"` B2 Foo `query:"B2"` } var result Bar req := newMockRequest().SetRequestURI("http://foobar.com?F1=f1&B1=b1").SetHeader("f2", "f2") err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "b1", result.B1) assert.DeepEqual(t, "f1", result.B2.F1) assert.DeepEqual(t, "f2", result.B2.F2) 
assert.DeepEqual(t, "f1", result.B2.F3.T1) assert.DeepEqual(t, "f1", result.B2.F3.T2.F1) } func TestBind_PointerType(t *testing.T) { type TT struct { T1 string `query:"F1"` } type Foo struct { F1 *TT `query:"F1"` F2 *******************string `query:"F1"` } type Bar struct { B1 ***string `query:"B1"` B2 ****Foo `query:"B2"` B3 []*string `query:"B3"` B4 [2]*int `query:"B4"` } result := Bar{} F1 := "f1" B1 := "b1" B2 := "b2" B3s := []string{"b31", "b32"} B4s := [2]int{0, 1} req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1=%s&B1=%s&B2=%s&B3=%s&B3=%s&B4=%d&B4=%d", F1, B1, B2, B3s[0], B3s[1], B4s[0], B4s[1])). SetHeader("f2", "f2") err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, B1, ***result.B1) assert.DeepEqual(t, F1, (*(****result.B2).F1).T1) assert.DeepEqual(t, F1, *******************(****result.B2).F2) assert.DeepEqual(t, len(B3s), len(result.B3)) for idx, val := range B3s { assert.DeepEqual(t, val, *result.B3[idx]) } assert.DeepEqual(t, len(B4s), len(result.B4)) for idx, val := range B4s { assert.DeepEqual(t, val, *result.B4[idx]) } } func TestBind_NestedStruct(t *testing.T) { type Foo struct { F1 string `query:"F1"` } type Bar struct { Foo Nested struct { N1 string `query:"F1"` } } result := Bar{} req := newMockRequest().SetRequestURI("http://foobar.com?F1=qwe") err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "qwe", result.Foo.F1) assert.DeepEqual(t, "qwe", result.Nested.N1) } func TestBind_SliceStruct(t *testing.T) { type Foo struct { F1 string `json:"f1"` } type Bar struct { B1 []Foo `query:"F1"` } result := Bar{} B1s := []string{"1", "2", "3"} req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}", B1s[0], B1s[1], B1s[2])) err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, len(result.B1), len(B1s)) for idx, val 
:= range B1s { assert.DeepEqual(t, B1s[idx], val) } } func TestBind_MapType(t *testing.T) { var result map[string]string req := newMockRequest(). SetJSONContentType(). SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, 2, len(result)) assert.DeepEqual(t, "j1", result["j1"]) assert.DeepEqual(t, "j2", result["j2"]) } func TestBind_MapFieldType(t *testing.T) { type Foo struct { F1 ***map[string]string `query:"f1" json:"f1"` } req := newMockRequest(). SetRequestURI("http://foobar.com?f1={\"f1\":\"f1\"}"). SetJSONContentType(). SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) result := Foo{} err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, 1, len(***result.F1)) assert.DeepEqual(t, "f1", (***result.F1)["f1"]) type Foo2 struct { F1 map[string]string `query:"f1" json:"f1"` } result2 := Foo2{} err = DefaultBinder().Bind(req.Req, &result2, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, 1, len(result2.F1)) assert.DeepEqual(t, "f1", result2.F1["f1"]) req = newMockRequest(). SetRequestURI("http://foobar.com?f1={\"f1\":\"f1\"") result2 = Foo2{} err = DefaultBinder().Bind(req.Req, &result2, nil) if err == nil { t.Error(err) } } func TestBind_UnexportedField(t *testing.T) { var s struct { A int `query:"a"` b int `query:"b"` } req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, 1, s.A) assert.DeepEqual(t, 0, s.b) } func TestBind_NoTagField(t *testing.T) { var s struct { A string B string C string } req := newMockRequest(). SetRequestURI("http://foobar.com?B=b1&C=c1"). 
SetHeader("A", "a2") var params param.Params params = append(params, param.Param{ Key: "B", Value: "b2", }) err := DefaultBinder().Bind(req.Req, &s, params) if err != nil { t.Fatal(err) } assert.DeepEqual(t, "a2", s.A) assert.DeepEqual(t, "b2", s.B) assert.DeepEqual(t, "c1", s.C) } func TestBind_ZeroValueBind(t *testing.T) { var s struct { A int `query:"a"` B float64 `query:"b"` } req := newMockRequest(). SetRequestURI("http://foobar.com?a=&b") bindConfig := &BindConfig{} bindConfig.LooseZeroMode = true binder := NewDefaultBinder(bindConfig) err := binder.Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, 0, s.A) assert.DeepEqual(t, float64(0), s.B) } func TestBind_DefaultValueBind(t *testing.T) { var s struct { A int `default:"15"` B float64 `query:"b" default:"17"` C []int `default:"[15]"` D []string `default:"['qwe','asd']"` F [2]string `default:"['qwe','asd','zxc']"` } req := newMockRequest(). SetRequestURI("http://foobar.com") err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, 15, s.A) assert.DeepEqual(t, float64(17), s.B) assert.DeepEqual(t, 15, s.C[0]) assert.DeepEqual(t, 2, len(s.D)) assert.DeepEqual(t, "qwe", s.D[0]) assert.DeepEqual(t, "asd", s.D[1]) assert.DeepEqual(t, 2, len(s.F)) assert.DeepEqual(t, "qwe", s.F[0]) assert.DeepEqual(t, "asd", s.F[1]) var s2 struct { F [2]string `default:"['qwe']"` } err = DefaultBinder().Bind(req.Req, &s2, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, 2, len(s2.F)) assert.DeepEqual(t, "qwe", s2.F[0]) assert.DeepEqual(t, "", s2.F[1]) var d struct { D [2]string `default:"qwe"` } err = DefaultBinder().Bind(req.Req, &d, nil) if err == nil { t.Fatal("expected err") } } func TestBind_RequiredBind(t *testing.T) { var s struct { A int `query:"a,required"` } req := newMockRequest(). 
SetRequestURI("http://foobar.com") err := DefaultBinder().Bind(req.Req, &s, nil) assert.DeepEqual(t, "'a' field is a 'required' parameter, but the request does not have this parameter", err.Error()) req = newMockRequest(). SetRequestURI("http://foobar.com"). SetHeader("A", "1") err = DefaultBinder().Bind(req.Req, &s, nil) if err == nil { t.Fatal("expected error") } var d struct { A int `query:"a,required" header:"A"` } err = DefaultBinder().Bind(req.Req, &d, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, 1, d.A) } func TestBind_TypedefType(t *testing.T) { type Foo string type Bar *int type T struct { T1 string `query:"a"` } type TT T var s struct { A Foo `query:"a"` B Bar `query:"b"` T1 TT } req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, Foo("1"), s.A) assert.DeepEqual(t, 2, *s.B) assert.DeepEqual(t, "1", s.T1.T1) } type EnumType int64 const ( EnumType_TWEET EnumType = 0 EnumType_RETWEET EnumType = 2 ) func (p EnumType) String() string { switch p { case EnumType_TWEET: return "TWEET" case EnumType_RETWEET: return "RETWEET" } return "" } func TestBind_EnumBind(t *testing.T) { var s struct { A EnumType `query:"a"` B EnumType `query:"b"` } req := newMockRequest(). 
SetRequestURI("http://foobar.com?a=0&b=2") err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } } type CustomizedDecode struct { A string } func TestBind_CustomizedTypeDecode(t *testing.T) { type Foo struct { F ***CustomizedDecode `query:"a"` } bindConfig := &BindConfig{} err := bindConfig.RegTypeUnmarshal(reflect.TypeOf(CustomizedDecode{}), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { q1 := req.URI().QueryArgs().Peek("a") if len(q1) == 0 { return reflect.Value{}, fmt.Errorf("can be nil") } val := CustomizedDecode{ A: string(q1), } return reflect.ValueOf(val), nil }) if err != nil { t.Fatal(err) } binder := NewDefaultBinder(bindConfig) req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") result := Foo{} err = binder.Bind(req.Req, &result, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, "1", (***result.F).A) type Bar struct { B *Foo } result2 := Bar{} err = binder.Bind(req.Req, &result2, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "1", (***(*result2.B).F).A) } func TestBind_CustomizedTypeDecodeForPanic(t *testing.T) { defer func() { if r := recover(); r == nil { t.Errorf("expect a panic, but get nil") } }() bindConfig := &BindConfig{} bindConfig.MustRegTypeUnmarshal(reflect.TypeOf(string("")), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { return reflect.Value{}, nil }) } func TestBind_JSON(t *testing.T) { type Req struct { J1 string `json:"j1"` J2 int `json:"j2" query:"j2"` // 1. json unmarshal 2. query binding cover J3 []byte `json:"j3"` J4 [2]string `json:"j4"` } J3s := []byte("12") J4s := [2]string{"qwe", "asd"} req := newMockRequest(). SetRequestURI("http://foobar.com?j2=13"). SetJSONContentType(). 
SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) var result Req err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "j1", result.J1) assert.DeepEqual(t, 13, result.J2) for idx, val := range J3s { assert.DeepEqual(t, val, result.J3[idx]) } for idx, val := range J4s { assert.DeepEqual(t, val, result.J4[idx]) } } func TestBind_ResetJSONUnmarshal(t *testing.T) { bindConfig := &BindConfig{} bindConfig.UseStdJSONUnmarshaler() binder := NewDefaultBinder(bindConfig) type Req struct { J1 string `json:"j1"` J2 int `json:"j2"` J3 []byte `json:"j3"` J4 [2]string `json:"j4"` } J3s := []byte("12") J4s := [2]string{"qwe", "asd"} req := newMockRequest(). SetJSONContentType(). SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) var result Req err := binder.Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "j1", result.J1) assert.DeepEqual(t, 12, result.J2) for idx, val := range J3s { assert.DeepEqual(t, val, result.J3[idx]) } for idx, val := range J4s { assert.DeepEqual(t, val, result.J4[idx]) } } func TestBind_FileBind(t *testing.T) { type Nest struct { N multipart.FileHeader `file_name:"d"` } var s struct { A *multipart.FileHeader `file_name:"a"` B *multipart.FileHeader `form:"b"` C multipart.FileHeader D **Nest `file_name:"d"` } fileName := "binder_test.go" req := newMockRequest(). SetRequestURI("http://foobar.com"). SetFile("a", fileName). SetFile("b", fileName). SetFile("C", fileName). 
SetFile("d", fileName) // to parse multipart files req2 := req2.GetHTTP1Request(req.Req) _ = req2.String() err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, fileName, s.A.Filename) assert.DeepEqual(t, fileName, s.B.Filename) assert.DeepEqual(t, fileName, s.C.Filename) assert.DeepEqual(t, fileName, (**s.D).N.Filename) } func TestBind_FileBindWithNoFile(t *testing.T) { var s struct { A *multipart.FileHeader `file_name:"a"` B *multipart.FileHeader `form:"b"` C *multipart.FileHeader } fileName := "binder_test.go" req := newMockRequest(). SetRequestURI("http://foobar.com"). SetFile("a", fileName). SetFile("b", fileName) // to parse multipart files req2 := req2.GetHTTP1Request(req.Req) _ = req2.String() err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatalf("unexpected err: %v", err) } assert.DeepEqual(t, fileName, s.A.Filename) assert.DeepEqual(t, fileName, s.B.Filename) if s.C != nil { t.Fatalf("expected a nil for s.C") } } func TestBind_FileSliceBind(t *testing.T) { type Nest struct { N *[]*multipart.FileHeader `form:"b"` } var s struct { A []multipart.FileHeader `form:"a"` B [3]multipart.FileHeader `form:"b"` C []*multipart.FileHeader `form:"b"` D Nest } fileName := "binder_test.go" req := newMockRequest(). SetRequestURI("http://foobar.com"). SetFile("a", fileName). SetFile("a", fileName). SetFile("a", fileName). SetFile("b", fileName). SetFile("b", fileName). 
SetFile("b", fileName) // to parse multipart files req2 := req2.GetHTTP1Request(req.Req) _ = req2.String() err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, 3, len(s.A)) for _, file := range s.A { assert.DeepEqual(t, fileName, file.Filename) } assert.DeepEqual(t, 3, len(s.B)) for _, file := range s.B { assert.DeepEqual(t, fileName, file.Filename) } assert.DeepEqual(t, 3, len(s.C)) for _, file := range s.C { assert.DeepEqual(t, fileName, file.Filename) } assert.DeepEqual(t, 3, len(*s.D.N)) for _, file := range *s.D.N { assert.DeepEqual(t, fileName, file.Filename) } } func TestBind_AnonymousField(t *testing.T) { type nest struct { n1 string `query:"n1"` // bind default value N2 ***string `query:"n2"` // bind n2 value string `query:"n3"` // bind default value } var s struct { s1 int `query:"s1"` // bind default value int `query:"s2"` // bind default value nest } req := newMockRequest(). SetRequestURI("http://foobar.com?s1=1&s2=2&n1=1&n2=2&n3=3") err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } assert.DeepEqual(t, 0, s.s1) assert.DeepEqual(t, 0, s.int) assert.DeepEqual(t, "", s.nest.n1) assert.DeepEqual(t, "2", ***s.nest.N2) assert.DeepEqual(t, "", s.nest.string) } func TestBind_IgnoreField(t *testing.T) { type Req struct { Version int `path:"-"` ID int `query:"-"` Header string `header:"-"` Form string `form:"-"` } req := newMockRequest(). SetRequestURI("http://foobar.com?ID=12"). SetHeaders("Header", "header"). SetPostArg("Form", "form"). 
SetUrlEncodeContentType() var params param.Params params = append(params, param.Param{ Key: "Version", Value: "1", }) var result Req err := DefaultBinder().Bind(req.Req, &result, params) if err != nil { t.Error(err) } assert.DeepEqual(t, 0, result.Version) assert.DeepEqual(t, 0, result.ID) assert.DeepEqual(t, "", result.Header) assert.DeepEqual(t, "", result.Form) } func TestBind_DefaultTag(t *testing.T) { type Req struct { Version int ID int Header string Form string } type Req2 struct { Version int ID int Header string Form string } req := newMockRequest(). SetRequestURI("http://foobar.com?ID=12"). SetHeaders("Header", "header"). SetPostArg("Form", "form"). SetUrlEncodeContentType() var params param.Params params = append(params, param.Param{ Key: "Version", Value: "1", }) var result Req err := DefaultBinder().Bind(req.Req, &result, params) if err != nil { t.Error(err) } assert.DeepEqual(t, 1, result.Version) assert.DeepEqual(t, 12, result.ID) assert.DeepEqual(t, "header", result.Header) assert.DeepEqual(t, "form", result.Form) bindConfig := &BindConfig{} bindConfig.DisableDefaultTag = true binder := NewDefaultBinder(bindConfig) result2 := Req2{} err = binder.Bind(req.Req, &result2, params) if err != nil { t.Error(err) } assert.DeepEqual(t, 0, result2.Version) assert.DeepEqual(t, 0, result2.ID) assert.DeepEqual(t, "", result2.Header) assert.DeepEqual(t, "", result2.Form) } func TestBind_StructFieldResolve(t *testing.T) { type Nested struct { A int `query:"a" json:"a"` B int `query:"b" json:"b"` } type Req struct { N Nested `query:"n"` } req := newMockRequest(). SetRequestURI("http://foobar.com?n={\"a\":1,\"b\":2}"). SetHeaders("Header", "header"). SetPostArg("Form", "form"). 
SetUrlEncodeContentType() var result Req bindConfig := &BindConfig{} bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) err := binder.Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 1, result.N.A) assert.DeepEqual(t, 2, result.N.B) req = newMockRequest(). SetRequestURI("http://foobar.com?n={\"a\":1,\"b\":2}&a=11&b=22"). SetHeaders("Header", "header"). SetPostArg("Form", "form"). SetUrlEncodeContentType() err = DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 11, result.N.A) assert.DeepEqual(t, 22, result.N.B) } func TestBind_JSONRequiredField(t *testing.T) { type Nested2 struct { C int `json:"c,required"` D int `json:"dd,required"` } type Nested struct { A int `json:"a,required"` B int `json:"b,required"` N2 Nested2 `json:"n2"` } type Req struct { N Nested `json:"n,required"` } bodyBytes := []byte(`{ "n": { "a": 1, "b": 2, "n2": { "dd": 4 } } }`) req := newMockRequest(). SetRequestURI("http://foobar.com?j2=13"). SetJSONContentType(). SetBody(bodyBytes) var result Req err := DefaultBinder().Bind(req.Req, &result, nil) if err == nil { t.Errorf("expected an error, but get nil") } assert.DeepEqual(t, "'c' field is a 'required' parameter, but the request body does not have this parameter 'n.n2.c'", err.Error()) assert.DeepEqual(t, 1, result.N.A) assert.DeepEqual(t, 2, result.N.B) assert.DeepEqual(t, 0, result.N.N2.C) assert.DeepEqual(t, 4, result.N.N2.D) bodyBytes = []byte(`{ "n": { "a": 1, "b": 2 } }`) req = newMockRequest(). SetRequestURI("http://foobar.com?j2=13"). SetJSONContentType(). 
SetBody(bodyBytes) var result2 Req err = DefaultBinder().Bind(req.Req, &result2, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 1, result2.N.A) assert.DeepEqual(t, 2, result2.N.B) assert.DeepEqual(t, 0, result2.N.N2.C) assert.DeepEqual(t, 0, result2.N.N2.D) } func TestValidate_MultipleValidate(t *testing.T) { type Test1 struct { A int `query:"a" vd:"$>10"` } req := newMockRequest(). SetRequestURI("http://foobar.com?a=9") var result Test1 err := BindAndValidate(req.Req, &result, nil) if err == nil { t.Fatalf("expected an error, but get nil") } } func TestBind_BindQuery(t *testing.T) { type Req struct { Q1 int `query:"q1"` Q2 int Q3 string Q4 string Q5 []int } req := newMockRequest(). SetRequestURI("http://foobar.com?q1=1&Q2=2&Q3=3&Q4=4&Q5=51&Q5=52") var result Req err := DefaultBinder().BindQuery(req.Req, &result) if err != nil { t.Error(err) } assert.DeepEqual(t, 1, result.Q1) assert.DeepEqual(t, 2, result.Q2) assert.DeepEqual(t, "3", result.Q3) assert.DeepEqual(t, "4", result.Q4) assert.DeepEqual(t, 51, result.Q5[0]) assert.DeepEqual(t, 52, result.Q5[1]) } func TestBind_LooseMode(t *testing.T) { bindConfig := &BindConfig{} bindConfig.LooseZeroMode = false binder := NewDefaultBinder(bindConfig) type Req struct { ID int `query:"id"` } req := newMockRequest(). SetRequestURI("http://foobar.com?id=") var result Req err := binder.Bind(req.Req, &result, nil) if err == nil { t.Fatal("expected err") } assert.DeepEqual(t, 0, result.ID) bindConfig.LooseZeroMode = true binder = NewDefaultBinder(bindConfig) var result2 Req err = binder.Bind(req.Req, &result2, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 0, result.ID) } func TestBind_NonStruct(t *testing.T) { req := newMockRequest(). 
SetRequestURI("http://foobar.com?id=1&id=2") var id interface{} err := DefaultBinder().Bind(req.Req, &id, nil) if err != nil { t.Error(err) } err = BindAndValidate(req.Req, &id, nil) if err != nil { t.Error(err) } } func TestBind_BindTag(t *testing.T) { type Req struct { Query string Header string Path string Form string } req := newMockRequest(). SetRequestURI("http://foobar.com?Query=query"). SetHeader("Header", "header"). SetPostArg("Form", "form") var params param.Params params = append(params, param.Param{ Key: "Path", Value: "path", }) result := Req{} // test query tag err := DefaultBinder().BindQuery(req.Req, &result) if err != nil { t.Error(err) } assert.DeepEqual(t, "query", result.Query) // test header tag result = Req{} err = DefaultBinder().BindHeader(req.Req, &result) if err != nil { t.Error(err) } assert.DeepEqual(t, "header", result.Header) // test form tag result = Req{} err = DefaultBinder().BindForm(req.Req, &result) if err != nil { t.Error(err) } assert.DeepEqual(t, "form", result.Form) // test path tag result = Req{} err = DefaultBinder().BindPath(req.Req, &result, params) if err != nil { t.Error(err) } assert.DeepEqual(t, "path", result.Path) // test json tag req = newMockRequest(). SetRequestURI("http://foobar.com"). SetJSONContentType(). SetBody([]byte("{\n \"Query\": \"query\",\n \"Path\": \"path\",\n \"Header\": \"header\",\n \"Form\": \"form\"\n}")) result = Req{} err = DefaultBinder().BindJSON(req.Req, &result) if err != nil { t.Error(err) } assert.DeepEqual(t, "form", result.Form) assert.DeepEqual(t, "query", result.Query) assert.DeepEqual(t, "header", result.Header) assert.DeepEqual(t, "path", result.Path) } func TestBind_BindAndValidate(t *testing.T) { type Req struct { ID int `query:"id" vd:"$>10"` } req := newMockRequest(). 
SetRequestURI("http://foobar.com?id=12") // test bindAndValidate var result Req err := BindAndValidate(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 12, result.ID) // test bind result = Req{} err = Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 12, result.ID) // test validate req = newMockRequest(). SetRequestURI("http://foobar.com?id=9") result = Req{} err = Bind(req.Req, &result, nil) if err != nil { t.Error(err) } err = Validate(result) if err == nil { t.Errorf("expect an error, but get nil") } assert.DeepEqual(t, 9, result.ID) } func TestBind_FastPath(t *testing.T) { type Req struct { ID int `query:"id" vd:"$>10"` } req := newMockRequest(). SetRequestURI("http://foobar.com?id=12") // test bindAndValidate var result Req err := BindAndValidate(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 12, result.ID) // execute multiple times, test cache for i := 0; i < 10; i++ { result = Req{} err := BindAndValidate(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 12, result.ID) } } func TestBind_NonPointer(t *testing.T) { type Req struct { ID int `query:"id" vd:"$>10"` } req := newMockRequest(). SetRequestURI("http://foobar.com?id=12") // test bindAndValidate var result Req err := BindAndValidate(req.Req, result, nil) if err == nil { t.Error("expect an error, but get nil") } err = Bind(req.Req, result, nil) if err == nil { t.Error("expect an error, but get nil") } } func TestBind_PreBind(t *testing.T) { type Req struct { Query string Header string Path string Form string } // test json tag req := newMockRequest(). SetRequestURI("http://foobar.com"). SetJSONContentType(). 
SetBody([]byte("\n \"Query\": \"query\",\n \"Path\": \"path\",\n \"Header\": \"header\",\n \"Form\": \"form\"\n}")) result := Req{} err := DefaultBinder().Bind(req.Req, &result, nil) if err == nil { t.Error("expect an error, but get nil") } err = BindAndValidate(req.Req, &result, nil) if err == nil { t.Error("expect an error, but get nil") } } func TestBind_BindProtobuf(t *testing.T) { data := testdata.HertzReq{Name: "hertz"} body, err := proto.Marshal(&data) if err != nil { t.Fatal(err) } req := newMockRequest(). SetRequestURI("http://foobar.com"). SetProtobufContentType(). SetBody(body) result := testdata.HertzReq{} err = BindAndValidate(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "hertz", result.Name) result = testdata.HertzReq{} err = DefaultBinder().BindProtobuf(req.Req, &result) if err != nil { t.Error(err) } assert.DeepEqual(t, "hertz", result.Name) } func TestBind_PointerStruct(t *testing.T) { bindConfig := &BindConfig{} bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) type Foo struct { F1 string `query:"F1"` } type Bar struct { B1 **Foo `query:"B1,required"` } query := make(url.Values) query.Add("B1", "{\n \"F1\": \"111\"\n}") var result Bar req := newMockRequest(). SetRequestURI(fmt.Sprintf("http://foobar.com?%s", query.Encode())) err := binder.Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "111", (**result.B1).F1) result = Bar{} req = newMockRequest(). SetRequestURI(fmt.Sprintf("http://foobar.com?%s&F1=222", query.Encode())) err = binder.Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "222", (**result.B1).F1) } func TestBind_StructRequired(t *testing.T) { bindConfig := &BindConfig{} bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) type Foo struct { F1 string `query:"F1"` } type Bar struct { B1 **Foo `query:"B1,required"` } var result Bar req := newMockRequest(). 
SetRequestURI("http://foobar.com") err := binder.Bind(req.Req, &result, nil) if err == nil { t.Error("expect an error, but get nil") } type Bar2 struct { B1 **Foo `query:"B1"` } var result2 Bar2 req = newMockRequest(). SetRequestURI("http://foobar.com") err = binder.Bind(req.Req, &result2, nil) if err != nil { t.Error(err) } } func TestBind_StructErrorToWarn(t *testing.T) { bindConfig := &BindConfig{} bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) type Foo struct { F1 string `query:"F1"` } type Bar struct { B1 **Foo `query:"B1,required"` } var result Bar req := newMockRequest(). SetRequestURI("http://foobar.com?B1=111&F1=222") err := binder.Bind(req.Req, &result, nil) // transfer 'unmarsahl err' to 'warn' if err != nil { t.Error(err) } assert.DeepEqual(t, "222", (**result.B1).F1) type Bar2 struct { B1 Foo `query:"B1,required"` } var result2 Bar2 err = binder.Bind(req.Req, &result2, nil) // transfer 'unmarsahl err' to 'warn' if err != nil { t.Error(err) } assert.DeepEqual(t, "222", result2.B1.F1) } func TestBind_DisallowUnknownFieldsConfig(t *testing.T) { bindConfig := &BindConfig{} bindConfig.EnableDecoderDisallowUnknownFields = true binder := NewDefaultBinder(bindConfig) type FooStructUseNumber struct { Foo interface{} `json:"foo"` } req := newMockRequest(). SetRequestURI("http://foobar.com"). SetJSONContentType(). SetBody([]byte(`{"foo": 123,"bar": "456"}`)) var result FooStructUseNumber err := binder.BindJSON(req.Req, &result) if err == nil { t.Errorf("expected an error, but get nil") } } func TestBind_UseNumberConfig(t *testing.T) { bindConfig := &BindConfig{} bindConfig.EnableDecoderUseNumber = true binder := NewDefaultBinder(bindConfig) type FooStructUseNumber struct { Foo interface{} `json:"foo"` } req := newMockRequest(). SetRequestURI("http://foobar.com"). SetJSONContentType(). 
SetBody([]byte(`{"foo": 123}`)) var result FooStructUseNumber err := binder.BindJSON(req.Req, &result) if err != nil { t.Error(err) } v, err := result.Foo.(json.Number).Int64() if err != nil { t.Error(err) } assert.DeepEqual(t, int64(123), v) } func TestBind_InterfaceType(t *testing.T) { type Bar struct { B1 interface{} `query:"B1"` } var result Bar query := make(url.Values) query.Add("B1", `{"B1":"111"}`) req := newMockRequest(). SetRequestURI(fmt.Sprintf("http://foobar.com?%s", query.Encode())) err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } type Bar2 struct { B2 *interface{} `query:"B1"` } var result2 Bar2 err = DefaultBinder().Bind(req.Req, &result2, nil) if err != nil { t.Error(err) } } func Test_BindHeaderNormalize(t *testing.T) { type Req struct { Header string `header:"h"` } req := newMockRequest(). SetRequestURI("http://foobar.com"). SetHeaders("h", "header") var result Req err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "header", result.Header) req = newMockRequest(). SetRequestURI("http://foobar.com"). SetHeaders("H", "header") err = DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "header", result.Header) type Req2 struct { Header string `header:"H"` } req2 := newMockRequest(). SetRequestURI("http://foobar.com"). SetHeaders("h", "header") var result2 Req2 err2 := DefaultBinder().Bind(req2.Req, &result2, nil) if err != nil { t.Error(err2) } assert.DeepEqual(t, "header", result2.Header) req2 = newMockRequest(). SetRequestURI("http://foobar.com"). SetHeaders("H", "header") err2 = DefaultBinder().Bind(req2.Req, &result2, nil) if err2 != nil { t.Error(err2) } assert.DeepEqual(t, "header", result2.Header) type Req3 struct { Header string `header:"h"` } // without normalize, the header key & tag key need to be consistent req3 := newMockRequest(). 
SetRequestURI("http://foobar.com") req3.Req.Header.DisableNormalizing() req3.SetHeaders("h", "header") var result3 Req3 err3 := DefaultBinder().Bind(req3.Req, &result3, nil) if err3 != nil { t.Error(err3) } assert.DeepEqual(t, "header", result3.Header) req3 = newMockRequest(). SetRequestURI("http://foobar.com") req3.Req.Header.DisableNormalizing() req3.SetHeaders("H", "header") result3 = Req3{} err3 = DefaultBinder().Bind(req3.Req, &result3, nil) if err3 != nil { t.Error(err3) } assert.DeepEqual(t, "", result3.Header) } type ValidateError struct { ErrType, FailField, Msg string } // Error implements error interface. func (e *ValidateError) Error() string { if e.Msg != "" { return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg } return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" } func Test_ValidatorErrorFactory(t *testing.T) { type TestBind struct { A string `query:"a,required"` } r := protocol.NewRequest("GET", "/foo", nil) r.SetRequestURI("/foo/bar?b=20") CustomValidateErrFunc := func(failField, msg string) error { err := ValidateError{ ErrType: "validateErr", FailField: "[validateFailField]: " + failField, Msg: "[validateErrMsg]: " + msg, } return &err } validateConfig := NewValidateConfig() validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) validator := NewValidator(validateConfig) var req TestBind err := Bind(r, &req, nil) if err == nil { t.Fatalf("unexpected nil, expected an error") } assert.DeepEqual(t, "'a' field is a 'required' parameter, but the request does not have this parameter", err.Error()) type TestValidate struct { B int `query:"b" vd:"$>100"` } var reqValidate TestValidate err = Bind(r, &reqValidate, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } err = validator.ValidateStruct(&reqValidate) if err == nil { t.Fatalf("unexpected nil, expected an error") } assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) } // Test_Issue964 used to the 
cover issue for time.Time func Test_Issue964(t *testing.T) { type CreateReq struct { StartAt *time.Time `json:"startAt"` } r := newMockRequest().SetBody([]byte("{\n \"startAt\": \"2006-01-02T15:04:05+07:00\"\n}")).SetJSONContentType() var req CreateReq err := BindAndValidate(r.Req, &req, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "2006-01-02 15:04:05 +0700 +0700", req.StartAt.String()) r = newMockRequest() req = CreateReq{} err = BindAndValidate(r.Req, &req, nil) if err != nil { t.Error(err) } if req.StartAt != nil { t.Error("expected nil") } } type reqSameType struct { Parent *reqSameType `json:"parent"` Children []reqSameType `json:"children"` Foo1 reqSameType2 `json:"foo1"` A string `json:"a"` } type reqSameType2 struct { Foo1 *reqSameType `json:"foo1"` } func TestBind_Issue1015(t *testing.T) { req := newMockRequest(). SetJSONContentType(). SetBody([]byte(`{"parent":{"parent":{}, "children":[{},{}], "foo1":{"foo1":{}}}, "children":[{},{}], "a":"asd"}`)) var result reqSameType err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.NotNil(t, result.Parent) assert.NotNil(t, result.Parent.Parent) assert.Nil(t, result.Parent.Parent.Parent) assert.NotNil(t, result.Parent.Children) assert.DeepEqual(t, 2, len(result.Parent.Children)) assert.NotNil(t, result.Parent.Foo1.Foo1) assert.DeepEqual(t, "", result.Parent.A) assert.DeepEqual(t, 2, len(result.Children)) assert.Nil(t, result.Foo1.Foo1) assert.DeepEqual(t, "asd", result.A) } func TestBind_JSONWithDefault(t *testing.T) { type Req struct { J1 string `json:"j1" default:"j1default"` } req := newMockRequest(). SetJSONContentType(). SetBody([]byte(`{"j1":"j1"}`)) var result Req err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "j1", result.J1) result = Req{} req = newMockRequest(). SetJSONContentType(). 
SetBody([]byte(`{"j2":"j2"}`)) err = DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "j1default", result.J1) } func TestBind_WithoutPreBindForTag(t *testing.T) { type BaseQuery struct { Action string `query:"Action" binding:"required"` Version string `query:"Version" binding:"required"` } req := newMockRequest(). SetJSONContentType(). SetRequestURI("http://foobar.com/?Action=action&Version=version"). SetBody([]byte(``)) var result BaseQuery err := DefaultBinder().BindQuery(req.Req, &result) if err != nil { t.Error(err) } assert.DeepEqual(t, "action", result.Action) assert.DeepEqual(t, "version", result.Version) } func TestBind_NormalizeContentType(t *testing.T) { type BaseQuery struct { Action string `json:"action" binding:"required"` Version string `json:"version" binding:"required"` } req := newMockRequest(). SetHeader("Content-Type", "ApplicAtion/json"). SetRequestURI("http://foobar.com/?Action=action&Version=version"). SetBody([]byte(`{"action":"action", "version":"version"}`)) var result BaseQuery err := DefaultBinder().BindQuery(req.Req, &result) if err != nil { t.Error(err) } assert.DeepEqual(t, "action", result.Action) assert.DeepEqual(t, "version", result.Version) } type TestEnumType int32 var _ encoding.TextUnmarshaler = (*TestEnumType)(nil) func (p *TestEnumType) UnmarshalText(v []byte) error { switch string(v) { case "one": *p = 1 case "two": *p = 2 default: return errors.New("invalid") } return nil } func TestBind_TextUnmarshaler(t *testing.T) { type Query struct { A TestEnumType `query:"a"` B TestEnumType `query:"b"` C *TestEnumType `query:"c"` D *TestEnumType `query:"d"` } q := &Query{} req := newMockRequest().SetRequestURI("http://example.com?a=1&b=one&c=2&d=two") err := DefaultBinder().BindQuery(req.Req, q) assert.Nil(t, err) assert.DeepEqual(t, TestEnumType(1), q.A) assert.DeepEqual(t, TestEnumType(1), q.B) assert.NotNil(t, q.C) assert.NotNil(t, q.D) assert.DeepEqual(t, TestEnumType(2), *q.C) 
assert.DeepEqual(t, TestEnumType(2), *q.D) } func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` ID int `query:"id"` Header string `header:"h"` Form string `form:"f"` } req := newMockRequest(). SetRequestURI("http://foobar.com?id=12"). SetHeaders("H", "header"). SetPostArg("f", "form"). SetUrlEncodeContentType() var params param.Params params = append(params, param.Param{ Key: "v", Value: "1", }) b.ResetTimer() for i := 0; i < b.N; i++ { var result Req err := DefaultBinder().Bind(req.Req, &result, params) if err != nil { b.Error(err) } if result.ID != 12 { b.Error("Id failed") } if result.Form != "form" { b.Error("form failed") } if result.Header != "header" { b.Error("header failed") } if result.Version != "1" { b.Error("path failed") } } } // TestBind_AnonymousFieldWithDefaultTag tests that default tag values don't override // JSON-provided values when using anonymous struct embedding with multiple tags func TestBind_AnonymousFieldWithDefaultTag(t *testing.T) { type PageInfo struct { Page int `json:"page" form:"page" query:"page" default:"1"` Limit int `json:"limit" form:"limit" query:"limit" default:"15"` } type Req struct { Keyword string `json:"keyword"` PageInfo } // Test 1: JSON values should override defaults req := protocol.NewRequest("POST", "/search", nil) req.SetBody([]byte(`{"keyword":"test","page":2,"limit":5}`)) req.Header.SetContentTypeBytes([]byte("application/json")) req.Header.SetContentLength(37) var r Req err := DefaultBinder().Bind(req, &r, nil) assert.Nil(t, err) assert.DeepEqual(t, "test", r.Keyword) assert.DeepEqual(t, 2, r.Page) // Should use JSON value, not default assert.DeepEqual(t, 5, r.Limit) // Should use JSON value, not default // Test 2: Empty JSON should use defaults req2 := protocol.NewRequest("POST", "/search", nil) req2.SetBody([]byte(`{"keyword":"test"}`)) req2.Header.SetContentTypeBytes([]byte("application/json")) req2.Header.SetContentLength(20) var r2 Req err2 := DefaultBinder().Bind(req2, 
&r2, nil) assert.Nil(t, err2) assert.DeepEqual(t, "test", r2.Keyword) assert.DeepEqual(t, 1, r2.Page) // Should use default value assert.DeepEqual(t, 15, r2.Limit) // Should use default value // Test 3: Query values should work req3 := protocol.NewRequest("POST", "/search?page=3&limit=4", nil) req3.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) var r3 Req err3 := DefaultBinder().Bind(req3, &r3, nil) assert.Nil(t, err3) assert.DeepEqual(t, 3, r3.Page) // Should use query value assert.DeepEqual(t, 4, r3.Limit) // Should use query value // Test 4: Nested anonymous structs (multi-level) type BaseInfo struct { Page int `json:"page" form:"page" default:"1"` } type ExtInfo struct { BaseInfo Limit int `json:"limit" form:"limit" default:"20"` } type Req2 struct { Keyword string `json:"keyword"` ExtInfo } req4 := protocol.NewRequest("POST", "/search", nil) req4.SetBody([]byte(`{"keyword":"nested","page":10,"limit":30}`)) req4.Header.SetContentTypeBytes([]byte("application/json")) req4.Header.SetContentLength(44) var r4 Req2 err4 := DefaultBinder().Bind(req4, &r4, nil) assert.Nil(t, err4) assert.DeepEqual(t, "nested", r4.Keyword) assert.DeepEqual(t, 10, r4.Page) // Should use JSON value from nested struct assert.DeepEqual(t, 30, r4.Limit) // Should use JSON value, not default } ================================================ FILE: pkg/app/server/binding/config.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and
 * limitations under the License.
 */

package binding

import (
	stdJson "encoding/json"
	"fmt"
	"reflect"
	"time"

	exprValidator "github.com/cloudwego/hertz/internal/tagexpr/validator"
	inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder"
	hJson "github.com/cloudwego/hertz/pkg/common/json"
	"github.com/cloudwego/hertz/pkg/protocol"
	"github.com/cloudwego/hertz/pkg/route/param"
)

// BindConfig contains options for the default bind behavior.
type BindConfig struct {
	// LooseZeroMode, if set to true, binds an empty-string request parameter
	// to the zero value of the target field.
	// NOTE:
	//   The default is false.
	//   Applies to these parameter sources: query/header/cookie/form.
	LooseZeroMode bool
	// DisableDefaultTag turns off the default tags that are otherwise added
	// to a field that has no binding tag.
	// If false, untagged fields are given default tags for more automated
	// binding, at the cost of some additional overhead.
	// NOTE:
	//   The default is false.
	DisableDefaultTag bool
	// DisableStructFieldResolve turns off decoding a struct-typed field from a
	// single request parameter.
	// If false, a struct-typed field gets a single
	// inDecoder.structTypeFieldTextDecoder, which decodes the parameter value
	// with json.Unmarshal. This is typically used to pass a JSON string in a
	// query parameter.
	// NOTE:
	//   The default is false.
	DisableStructFieldResolve bool
	// EnableDecoderUseNumber is used to call the UseNumber method on the JSON
	// Decoder instance. UseNumber causes the Decoder to unmarshal a number into an
	// interface{} as a Number instead of as a float64.
	// NOTE:
	//   The default is false.
	//   It is used for BindJSON().
	EnableDecoderUseNumber bool
	// EnableDecoderDisallowUnknownFields is used to call the DisallowUnknownFields method
	// on the JSON Decoder instance. DisallowUnknownFields causes the Decoder to
	// return an error when the destination is a struct and the input contains object
	// keys which do not match any non-ignored, exported fields in the destination.
	// NOTE:
	//   The default is false.
	//   It is used for BindJSON().
	EnableDecoderDisallowUnknownFields bool
	// TypeUnmarshalFuncs holds customized type unmarshalers, keyed by type.
	// NOTE:
	//   time.Time is registered by default via initTypeUnmarshal;
	//   NewBindConfig itself creates this map empty.
	TypeUnmarshalFuncs map[reflect.Type]inDecoder.CustomizeDecodeFunc
	// Validator is used to validate for BindAndValidate()
	//
	// Deprecated: use ValidatorFunc instead. You can create a ValidatorFunc
	// from a StructValidator using MakeValidatorFunc()
	Validator StructValidator
	// ValidatorFunc is used to validate structs with custom validation logic.
	// It replaces the deprecated Validator field and provides request context.
	// NOTE:
	//   The default is nil. If set, this takes precedence over the Validator field.
	//   The function signature allows access to the request for context-aware validation.
	ValidatorFunc func(req *protocol.Request, v any) error
}

// NewBindConfig returns a BindConfig with every boolean option set to false,
// an empty (non-nil) TypeUnmarshalFuncs map, and defaultValidate assigned to
// the deprecated Validator field. ValidatorFunc is left nil.
//
// NOTE(review): the time.Time decoder is not added here; initTypeUnmarshal
// registers it. Confirm that callers invoke it before binding.
func NewBindConfig() *BindConfig {
	return &BindConfig{
		LooseZeroMode:                      false,
		DisableDefaultTag:                  false,
		DisableStructFieldResolve:          false,
		EnableDecoderUseNumber:             false,
		EnableDecoderDisallowUnknownFields: false,
		TypeUnmarshalFuncs:                 make(map[reflect.Type]inDecoder.CustomizeDecodeFunc),
		Validator:                          defaultValidate,
	}
}

// RegTypeUnmarshal registers a customized unmarshaler (decode function) for type t.
func (config *BindConfig) RegTypeUnmarshal(t reflect.Type, fn inDecoder.CustomizeDecodeFunc) error { // check switch t.Kind() { case reflect.String, reflect.Bool, reflect.Float32, reflect.Float64, reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: return fmt.Errorf("registration type cannot be a basic type") case reflect.Ptr: return fmt.Errorf("registration type cannot be a pointer type") } if config.TypeUnmarshalFuncs == nil { config.TypeUnmarshalFuncs = make(map[reflect.Type]inDecoder.CustomizeDecodeFunc) } config.TypeUnmarshalFuncs[t] = fn return nil } // MustRegTypeUnmarshal registers customized type unmarshaler. It will panic if exist error. func (config *BindConfig) MustRegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params param.Params, text string) (reflect.Value, error)) { err := config.RegTypeUnmarshal(t, fn) if err != nil { panic(err) } } func (config *BindConfig) initTypeUnmarshal() { config.MustRegTypeUnmarshal(reflect.TypeOf(time.Time{}), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { if text == "" { return reflect.ValueOf(time.Time{}), nil } t, err := time.Parse(time.RFC3339, text) if err != nil { return reflect.Value{}, err } return reflect.ValueOf(t), nil }) } // UseThirdPartyJSONUnmarshaler uses third-party json library for binding // NOTE: // // UseThirdPartyJSONUnmarshaler will remain in effect once it has been called. func (config *BindConfig) UseThirdPartyJSONUnmarshaler(fn func(data []byte, v interface{}) error) { hJson.Unmarshal = fn } // UseStdJSONUnmarshaler uses encoding/json as json library // NOTE: // // The current version uses encoding/json by default. // UseStdJSONUnmarshaler will remain in effect once it has been called. 
func (config *BindConfig) UseStdJSONUnmarshaler() {
	// Route through UseThirdPartyJSONUnmarshaler so the package-level JSON
	// unmarshal hook is swapped in exactly one place.
	std := stdJson.Unmarshal
	config.UseThirdPartyJSONUnmarshaler(std)
}

// ValidateErrFactory builds the error reported when validation fails; it
// receives the selector of the failing field and an accompanying message.
//
// Deprecated: Use WithCustomValidatorFunc with a custom validation function instead.
type ValidateErrFactory func(fieldSelector, msg string) error

// ValidateConfig holds the settings for the built-in StructValidator: the
// struct tag name used for validation and an optional error factory.
//
// Deprecated: Use WithCustomValidatorFunc with a custom validation function instead.
type ValidateConfig struct {
	ValidateTag string
	ErrFactory  ValidateErrFactory
}

// NewValidateConfig returns a ValidateConfig with no tag and no error factory set.
//
// Deprecated: Use WithCustomValidatorFunc with a custom validation function instead.
func NewValidateConfig() *ValidateConfig {
	return new(ValidateConfig)
}

// MustRegValidateFunc registers fn as a validator expression function named funcName.
// NOTE:
//
//	If force=true, an already registered function with the same funcName is replaced.
//	MustRegValidateFunc will remain in effect once it has been called.
//
// Deprecated: Use WithCustomValidatorFunc with a custom validation function instead.
func (config *ValidateConfig) MustRegValidateFunc(funcName string, fn func(args ...interface{}) error, force ...bool) {
	// config itself is not modified; registration is delegated to the
	// tagexpr validator package.
	exprValidator.MustRegFunc(funcName, fn, force...)
}

// SetValidatorErrorFactory replaces the factory used to build validation errors.
//
// Deprecated: Use WithCustomValidatorFunc with a custom validation function instead.
func (config *ValidateConfig) SetValidatorErrorFactory(errFactory ValidateErrFactory) {
	config.ErrFactory = errFactory
}

// SetValidatorTag customizes the validation tag.
//
// Deprecated: Use WithCustomValidatorFunc with a custom validation function instead.
func (config *ValidateConfig) SetValidatorTag(tag string) { config.ValidateTag = tag } ================================================ FILE: pkg/app/server/binding/default.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. 
* * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package binding import ( "bytes" stdJson "encoding/json" "fmt" "io" "net/url" "reflect" "strings" "sync" "github.com/cloudwego/hertz/internal/bytesconv" exprValidator "github.com/cloudwego/hertz/internal/tagexpr/validator" inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/route/param" "google.golang.org/protobuf/proto" ) const ( queryTag = "query" headerTag = "header" formTag = "form" pathTag = "path" defaultValidateTag = "vd" ) type decoderInfo struct { decoder inDecoder.Decoder } var defaultBind = (NewDefaultBinder(nil).(*defaultBinder)) func DefaultBinder() Binder { return defaultBind } type defaultBinder struct { config *BindConfig decoderCache sync.Map queryDecoderCache sync.Map formDecoderCache sync.Map headerDecoderCache sync.Map pathDecoderCache sync.Map } func NewDefaultBinder(config *BindConfig) Binder { if config == nil { config = NewBindConfig() } config.initTypeUnmarshal() if config.Validator == nil { config.Validator = DefaultValidator() } // Initialize ValidatorFunc if not set, using the legacy Validator if config.ValidatorFunc == nil { config.ValidatorFunc = MakeValidatorFunc(config.Validator) } return &defaultBinder{ config: config, } } // BindAndValidate binds data from *protocol.Request to obj and validates them if needed. // NOTE: // // obj should be a pointer. func BindAndValidate(req *protocol.Request, obj interface{}, pathParams param.Params) error { if err := defaultBind.Bind(req, obj, pathParams); err != nil { return err } return defaultBind.Validate(req, obj) } // Bind binds data from *protocol.Request to obj. // NOTE: // // obj should be a pointer. 
func Bind(req *protocol.Request, obj interface{}, pathParams param.Params) error { return defaultBind.Bind(req, obj, pathParams) } // Validate validates obj with "vd" tag // NOTE: // // obj should be a pointer. // Validate should be called after Bind. func Validate(obj interface{}) error { return defaultBind.Validate(nil, obj) } func (b *defaultBinder) tagCache(tag string) *sync.Map { switch tag { case queryTag: return &b.queryDecoderCache case headerTag: return &b.headerDecoderCache case formTag: return &b.formDecoderCache case pathTag: return &b.pathDecoderCache default: return &b.decoderCache } } func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params param.Params, tag string) error { rv, typeID := valueAndTypeID(v) if err := checkPointer(rv); err != nil { return err } rt := dereferenceType(rv.Type()) if rt.Kind() != reflect.Struct { return b.bindNonStruct(req, v) } if len(tag) == 0 { err := b.preBindBody(req, v) if err != nil { return fmt.Errorf("bind body failed, err=%v", err) } } cache := b.tagCache(tag) cached, ok := cache.Load(typeID) if ok { // cached fieldDecoder, fast path decoder := cached.(decoderInfo) return decoder.decoder(req, params, rv.Elem()) } decodeConfig := &inDecoder.DecodeConfig{ LooseZeroMode: b.config.LooseZeroMode, DisableDefaultTag: b.config.DisableDefaultTag, DisableStructFieldResolve: b.config.DisableStructFieldResolve, EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs, } decoder, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig) if err != nil { return err } cache.Store(typeID, decoderInfo{decoder: decoder}) return decoder(req, params, rv.Elem()) } func (b *defaultBinder) BindQuery(req *protocol.Request, v interface{}) error { return b.bindTag(req, v, nil, queryTag) } func (b *defaultBinder) BindHeader(req *protocol.Request, v interface{}) error { return b.bindTag(req, 
v, nil, headerTag) } func (b *defaultBinder) BindPath(req *protocol.Request, v interface{}, params param.Params) error { return b.bindTag(req, v, params, pathTag) } func (b *defaultBinder) BindForm(req *protocol.Request, v interface{}) error { return b.bindTag(req, v, nil, formTag) } func (b *defaultBinder) BindJSON(req *protocol.Request, v interface{}) error { return b.decodeJSON(bytes.NewReader(req.Body()), v) } func (b *defaultBinder) decodeJSON(r io.Reader, obj interface{}) error { decoder := hJson.NewDecoder(r) if b.config.EnableDecoderUseNumber { decoder.UseNumber() } if b.config.EnableDecoderDisallowUnknownFields { decoder.DisallowUnknownFields() } return decoder.Decode(obj) } func (b *defaultBinder) BindProtobuf(req *protocol.Request, v interface{}) error { msg, ok := v.(proto.Message) if !ok { return fmt.Errorf("%s does not implement 'proto.Message'", v) } return proto.Unmarshal(req.Body(), msg) } func (b *defaultBinder) Name() string { return "hertz" } func (b *defaultBinder) Bind(req *protocol.Request, v interface{}, params param.Params) error { return b.bindTag(req, v, params, "") } func (b *defaultBinder) Validate(req *protocol.Request, v interface{}) error { return b.config.ValidatorFunc(req, v) } // best effort binding func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error { if req.Header.ContentLength() <= 0 { return nil } ct := bytesconv.B2s(req.Header.ContentType()) switch strings.ToLower(utils.FilterContentType(ct)) { case consts.MIMEApplicationJSON: return hJson.Unmarshal(req.Body(), v) case consts.MIMEPROTOBUF: msg, ok := v.(proto.Message) if !ok { return fmt.Errorf("%s can not implement 'proto.Message'", v) } return proto.Unmarshal(req.Body(), msg) default: return nil } } func (b *defaultBinder) bindNonStruct(req *protocol.Request, v interface{}) (err error) { ct := bytesconv.B2s(req.Header.ContentType()) switch strings.ToLower(utils.FilterContentType(ct)) { case consts.MIMEApplicationJSON: err = 
hJson.Unmarshal(req.Body(), v) case consts.MIMEPROTOBUF: msg, ok := v.(proto.Message) if !ok { return fmt.Errorf("%s can not implement 'proto.Message'", v) } err = proto.Unmarshal(req.Body(), msg) case consts.MIMEMultipartPOSTForm: form := make(url.Values) mf, err1 := req.MultipartForm() if err1 == nil && mf.Value != nil { for k, v := range mf.Value { for _, vv := range v { form.Add(k, vv) } } } b, _ := stdJson.Marshal(form) err = hJson.Unmarshal(b, v) case consts.MIMEApplicationHTMLForm: form := make(url.Values) req.PostArgs().VisitAll(func(formKey, value []byte) { form.Add(string(formKey), string(value)) }) b, _ := stdJson.Marshal(form) err = hJson.Unmarshal(b, v) default: // using query to decode query := make(url.Values) req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { query.Add(string(queryKey), string(value)) }) b, _ := stdJson.Marshal(query) err = hJson.Unmarshal(b, v) } return } var _ StructValidator = (*validator)(nil) type validator struct { validateTag string validate *exprValidator.Validator } // NewValidator creates a new StructValidator with the given configuration. // // Deprecated: Use WithCustomValidatorFunc with a custom validation function instead. // You can convert the returned StructValidator to a ValidatorFunc using MakeValidatorFunc(). func NewValidator(config *ValidateConfig) StructValidator { validateTag := defaultValidateTag if config != nil && len(config.ValidateTag) != 0 { validateTag = config.ValidateTag } vd := exprValidator.New(validateTag).SetErrorFactory(defaultValidateErrorFactory) if config != nil && config.ErrFactory != nil { vd.SetErrorFactory(config.ErrFactory) } return &validator{ validateTag: validateTag, validate: vd, } } // Error validate error type validateError struct { FailPath, Msg string } // Error implements error interface. 
func (e *validateError) Error() string { if e.Msg != "" { return e.Msg } return "invalid parameter: " + e.FailPath } func defaultValidateErrorFactory(failPath, msg string) error { return &validateError{ FailPath: failPath, Msg: msg, } } // ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. func (v *validator) ValidateStruct(obj interface{}) error { if obj == nil { return nil } return v.validate.Validate(obj) } // Engine returns the underlying validator func (v *validator) Engine() interface{} { return v.validate } func (v *validator) ValidateTag() string { return v.validateTag } var defaultValidate = NewValidator(NewValidateConfig()) // DefaultValidator returns the default StructValidator instance that uses tagexpr validation. // The validator uses the "vd" tag for validation expressions and provides comprehensive // struct field validation capabilities. // // Deprecated: Use WithCustomValidatorFunc with a custom validation function instead. // For migration: convert this StructValidator to a ValidatorFunc using MakeValidatorFunc(). // // Example migration: // // // Old way (deprecated) // validator := binding.DefaultValidator() // // // New way (recommended) // validatorFunc := binding.MakeValidatorFunc(binding.DefaultValidator()) // server.WithCustomValidatorFunc(validatorFunc) func DefaultValidator() StructValidator { return defaultValidate } ================================================ FILE: pkg/app/server/binding/internal/decoder/base_type_decoder.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder import ( "fmt" "reflect" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) type fieldInfo struct { index int parentIndex []int fieldName string tagInfos []TagInfo fieldType reflect.Type config *DecodeConfig } type baseTypeFieldTextDecoder struct { fieldInfo decoder TextDecoder } func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var err error var text string var exist bool var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { if tagInfo.Key == jsonTag { defaultValue = tagInfo.Default found := checkRequireJSON(req, tagInfo) if found { err = nil } else { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request body does not have this parameter '%s'", tagInfo.Value, tagInfo.JSONName) } if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { defaultValue = "" } } continue } text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if exist { err = nil break } if tagInfo.Required { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value) } } if err != nil { return err } if len(text) == 0 && len(defaultValue) != 0 { text = toDefaultValue(d.fieldType, defaultValue) } if !exist && len(text) == 0 { return nil } // get the non-nil value for the parent field reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) if field.Kind() == reflect.Ptr { t := field.Type() var ptrDepth int for t.Kind() == reflect.Ptr { t = t.Elem() ptrDepth++ } var vv reflect.Value vv, err := stringToValue(t, text, req, params, d.config) if err != nil { return err } field.Set(ReferenceValue(vv, ptrDepth)) return nil } // Non-pointer elems if field.CanAddr() { if tryTextUnmarshaler(field.Addr(), text) { 
return nil } } err = d.decoder.UnmarshalString(text, field, d.config.LooseZeroMode) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) } return nil } func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: tagInfos[idx].SliceGetter = pathSlice tagInfos[idx].Getter = path case formTag: tagInfos[idx].SliceGetter = postFormSlice tagInfos[idx].Getter = postForm case queryTag: tagInfos[idx].SliceGetter = querySlice tagInfos[idx].Getter = query case cookieTag: tagInfos[idx].SliceGetter = cookieSlice tagInfos[idx].Getter = cookie case headerTag: tagInfos[idx].SliceGetter = headerSlice tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: tagInfos[idx].SliceGetter = rawBodySlice tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing default: } } fieldType := field.Type for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() } textDecoder, err := SelectTextDecoder(fieldType) if err != nil { return nil, err } return []fieldDecoder{&baseTypeFieldTextDecoder{ fieldInfo: fieldInfo{ index: index, parentIndex: parentIdx, fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, config: config, }, decoder: textDecoder, }}, nil } ================================================ FILE: pkg/app/server/binding/internal/decoder/customized_type_decoder.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder import ( "reflect" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) type CustomizeDecodeFunc func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) type customizedFieldTextDecoder struct { fieldInfo decodeFunc CustomizeDecodeFunc } func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var text string var exist bool var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { if tagInfo.Key == jsonTag { defaultValue = tagInfo.Default if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { defaultValue = "" } } continue } text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if exist { break } } if !exist { return nil } if len(text) == 0 && len(defaultValue) != 0 { text = toDefaultValue(d.fieldType, defaultValue) } v, err := d.decodeFunc(req, params, text) if err != nil { return err } if !v.IsValid() { return nil } reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) if field.Kind() == reflect.Ptr { t := field.Type() var ptrDepth int for t.Kind() == reflect.Ptr { t = t.Elem() ptrDepth++ } field.Set(ReferenceValue(v, ptrDepth)) return nil } field.Set(v) return nil } func getCustomizedFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, decodeFunc CustomizeDecodeFunc, config *DecodeConfig) ([]fieldDecoder, error) { for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: tagInfos[idx].SliceGetter = pathSlice tagInfos[idx].Getter = path case formTag: tagInfos[idx].SliceGetter = postFormSlice tagInfos[idx].Getter = postForm case queryTag: tagInfos[idx].SliceGetter = querySlice tagInfos[idx].Getter = query case cookieTag: tagInfos[idx].SliceGetter = cookieSlice 
tagInfos[idx].Getter = cookie case headerTag: tagInfos[idx].SliceGetter = headerSlice tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: tagInfos[idx].SliceGetter = rawBodySlice tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing default: } } fieldType := field.Type for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() } return []fieldDecoder{&customizedFieldTextDecoder{ fieldInfo: fieldInfo{ index: index, parentIndex: parentIdx, fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, config: config, }, decodeFunc: decodeFunc, }}, nil } ================================================ FILE: pkg/app/server/binding/internal/decoder/decoder.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. 
* * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder import ( "fmt" "mime/multipart" "reflect" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) type fieldDecoder interface { Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error } type Decoder func(req *protocol.Request, params param.Params, rv reflect.Value) error type DecodeConfig struct { LooseZeroMode bool DisableDefaultTag bool DisableStructFieldResolve bool EnableDecoderUseNumber bool EnableDecoderDisallowUnknownFields bool TypeUnmarshalFuncs map[reflect.Type]CustomizeDecodeFunc } func GetReqDecoder(rt reflect.Type, byTag string, config *DecodeConfig) (Decoder, error) { var decoders []fieldDecoder el := rt.Elem() if el.Kind() != reflect.Struct { return nil, fmt.Errorf("unsupported \"%s\" type binding", rt.String()) } for i := 0; i < el.NumField(); i++ { if el.Field(i).PkgPath != "" && !el.Field(i).Anonymous { // ignore unexported field continue } dec, err := getFieldDecoder(parentInfos{[]reflect.Type{el}, []int{}, ""}, el.Field(i), i, byTag, config) if err != nil { return nil, err } if dec != nil { decoders = append(decoders, dec...) 
} } return func(req *protocol.Request, params param.Params, rv reflect.Value) error { for _, decoder := range decoders { err := decoder.Decode(req, params, rv) if err != nil { return err } } return nil }, nil } type parentInfos struct { Types []reflect.Type Indexes []int JSONName string } func getFieldDecoder(pInfo parentInfos, field reflect.StructField, index int, byTag string, config *DecodeConfig) ([]fieldDecoder, error) { for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } // skip anonymous definitions, like: // type A struct { // string // } if field.Type.Kind() != reflect.Struct && field.Anonymous { return nil, nil } // JSONName is like 'a.b.c' for 'required validate' fieldTagInfos, newParentJSONName := lookupFieldTags(field, pInfo.JSONName, config) if len(fieldTagInfos) == 0 && !config.DisableDefaultTag { fieldTagInfos, newParentJSONName = getDefaultFieldTags(field, pInfo.JSONName) } if len(byTag) != 0 { fieldTagInfos = getFieldTagInfoByTag(field, byTag) } // customized type decoder has the highest priority if customizedFunc, exist := config.TypeUnmarshalFuncs[field.Type]; exist { dec, err := getCustomizedFieldDecoder(field, index, fieldTagInfos, pInfo.Indexes, customizedFunc, config) return dec, err } // slice/array field decoder if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { dec, err := getSliceFieldDecoder(field, index, fieldTagInfos, pInfo.Indexes, config) return dec, err } // map filed decoder if field.Type.Kind() == reflect.Map { dec, err := getMapTypeTextDecoder(field, index, fieldTagInfos, pInfo.Indexes, config) return dec, err } // struct field will be resolved recursively if field.Type.Kind() == reflect.Struct { var decoders []fieldDecoder el := field.Type // todo: more built-in common struct binding, ex. time... 
switch el { case reflect.TypeOf(multipart.FileHeader{}): // file binding dec, err := getMultipartFileDecoder(field, index, fieldTagInfos, pInfo.Indexes, config) return dec, err } if !config.DisableStructFieldResolve { // decode struct type separately structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, pInfo.Indexes, config) if err != nil { return nil, err } if structFieldDecoder != nil { decoders = append(decoders, structFieldDecoder...) } } // prevent infinite recursion when struct field with the same name as a struct if hasSameType(pInfo.Types, el) { return decoders, nil } pIdx := pInfo.Indexes if !field.Anonymous { pInfo.JSONName = newParentJSONName } for i := 0; i < el.NumField(); i++ { if el.Field(i).PkgPath != "" && !el.Field(i).Anonymous { // ignore unexported field continue } var idxes []int if len(pInfo.Indexes) > 0 { idxes = append(idxes, pIdx...) } idxes = append(idxes, index) pInfo.Indexes = idxes pInfo.Types = append(pInfo.Types, el) dec, err := getFieldDecoder(pInfo, el.Field(i), i, byTag, config) if err != nil { return nil, err } if dec != nil { decoders = append(decoders, dec...) } } return decoders, nil } // base type decoder dec, err := getBaseTypeTextDecoder(field, index, fieldTagInfos, pInfo.Indexes, config) return dec, err } // hasSameType determine if the same type is present in the parent-child relationship func hasSameType(pts []reflect.Type, ft reflect.Type) bool { for _, pt := range pts { if reflect.DeepEqual(getElemType(pt), getElemType(ft)) { return true } } return false } ================================================ FILE: pkg/app/server/binding/internal/decoder/getter.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder import ( "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) type getter func(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) func path(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { if params != nil { ret, exist = params.Get(key) } if len(ret) == 0 && len(defaultValue) != 0 { ret = defaultValue[0] } return ret, exist } func postForm(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { if ret, exist = req.PostArgs().PeekExists(key); exist { return } mf, err := req.MultipartForm() if err == nil && mf.Value != nil { for k, v := range mf.Value { if k == key && len(v) > 0 { ret = v[0] } } } if len(ret) != 0 { return ret, true } if ret, exist = req.URI().QueryArgs().PeekExists(key); exist { return } if len(ret) == 0 && len(defaultValue) != 0 { ret = defaultValue[0] } return ret, false } func query(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { if ret, exist = req.URI().QueryArgs().PeekExists(key); exist { return } if len(ret) == 0 && len(defaultValue) != 0 { ret = defaultValue[0] } return } func cookie(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { if val := req.Header.Cookie(key); val != nil { ret = string(val) return ret, true } if len(ret) == 0 && len(defaultValue) != 0 { ret = defaultValue[0] } return ret, false } func header(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { if val := req.Header.Peek(key); val != nil { ret = string(val) return ret, true } if len(ret) == 0 && len(defaultValue) != 0 { ret = defaultValue[0] } return ret, false } func rawBody(req *protocol.Request, params param.Params, key 
string, defaultValue ...string) (ret string, exist bool) { exist = false if req.Header.ContentLength() > 0 { ret = string(req.Body()) exist = true } return } ================================================ FILE: pkg/app/server/binding/internal/decoder/gjson_required.go ================================================ // Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //go:build gjson || !(amd64 && (linux || windows || darwin)) package decoder import ( "strings" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/tidwall/gjson" ) func checkRequireJSON(req *protocol.Request, tagInfo TagInfo) bool { if !tagInfo.Required { return true } ct := bytesconv.B2s(req.Header.ContentType()) if !strings.EqualFold(utils.FilterContentType(ct), consts.MIMEApplicationJSON) { return false } result := gjson.GetBytes(req.Body(), tagInfo.JSONName) if !result.Exists() { idx := strings.LastIndex(tagInfo.JSONName, ".") // There should be a superior if it is empty, it will report 'true' for required if idx > 0 && !gjson.GetBytes(req.Body(), tagInfo.JSONName[:idx]).Exists() { return true } return false } return true } func keyExist(req *protocol.Request, tagInfo TagInfo) bool { ct := bytesconv.B2s(req.Header.ContentType()) if utils.FilterContentType(ct) != consts.MIMEApplicationJSON { return false } result := 
gjson.GetBytes(req.Body(), tagInfo.JSONName) return result.Exists() } ================================================ FILE: pkg/app/server/binding/internal/decoder/map_type_decoder.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder import ( "fmt" "reflect" "github.com/cloudwego/hertz/internal/bytesconv" hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) type mapTypeFieldTextDecoder struct { fieldInfo } func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var err error var text string var exist bool var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { if tagInfo.Key == jsonTag { defaultValue = tagInfo.Default found := checkRequireJSON(req, tagInfo) if found { err = nil } else { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value) } if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { defaultValue = "" } } continue } text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if exist { err = nil break } if tagInfo.Required { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value) } } if err != nil { return err } if len(text) == 0 && len(defaultValue) != 0 { text = toDefaultValue(d.fieldType, defaultValue) } if !exist && len(text) == 0 { return nil } reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) if field.Kind() == reflect.Ptr { t := field.Type() var ptrDepth int for t.Kind() == reflect.Ptr { t = t.Elem() ptrDepth++ } var vv reflect.Value vv, err := stringToValue(t, text, req, params, d.config) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) } field.Set(ReferenceValue(vv, ptrDepth)) return nil } err = hJson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, 
d.fieldType.Name(), err) } return nil } func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: tagInfos[idx].SliceGetter = pathSlice tagInfos[idx].Getter = path case formTag: tagInfos[idx].SliceGetter = postFormSlice tagInfos[idx].Getter = postForm case queryTag: tagInfos[idx].SliceGetter = querySlice tagInfos[idx].Getter = query case cookieTag: tagInfos[idx].SliceGetter = cookieSlice tagInfos[idx].Getter = cookie case headerTag: tagInfos[idx].SliceGetter = headerSlice tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: tagInfos[idx].SliceGetter = rawBodySlice tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing default: } } fieldType := field.Type for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() } return []fieldDecoder{&mapTypeFieldTextDecoder{ fieldInfo: fieldInfo{ index: index, parentIndex: parentIdx, fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, config: config, }, }}, nil } ================================================ FILE: pkg/app/server/binding/internal/decoder/multipart_file_decoder.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package decoder import ( "fmt" "reflect" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) type fileTypeDecoder struct { fieldInfo isRepeated bool } func (d *fileTypeDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { fieldValue := GetFieldValue(reqValue, d.parentIndex) field := fieldValue.Field(d.index) if d.isRepeated { return d.fileSliceDecode(req, params, reqValue) } var fileName string // file_name > form > fieldName for _, tagInfo := range d.tagInfos { if tagInfo.Key == fileNameTag { fileName = tagInfo.Value break } if tagInfo.Key == formTag { fileName = tagInfo.Value } } if len(fileName) == 0 { fileName = d.fieldName } file, err := req.FormFile(fileName) if err != nil { hlog.SystemLogger().Warnf("can not get file '%s' form request, reason: %v, so skip '%s' field binding", fileName, err, d.fieldName) return nil } if field.Kind() == reflect.Ptr { t := field.Type() var ptrDepth int for t.Kind() == reflect.Ptr { t = t.Elem() ptrDepth++ } v := reflect.New(t).Elem() v.Set(reflect.ValueOf(*file)) field.Set(ReferenceValue(v, ptrDepth)) return nil } // Non-pointer elems field.Set(reflect.ValueOf(*file)) return nil } func (d *fileTypeDecoder) fileSliceDecode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { fieldValue := GetFieldValue(reqValue, d.parentIndex) field := fieldValue.Field(d.index) // 如果没值,需要为其建一个值 if field.Kind() == reflect.Ptr { if field.IsNil() { nonNilVal, ptrDepth := GetNonNilReferenceValue(field) field.Set(ReferenceValue(nonNilVal, ptrDepth)) } } var parentPtrDepth int for field.Kind() == reflect.Ptr { field = field.Elem() parentPtrDepth++ } var fileName string // file_name > form > fieldName for _, tagInfo := range d.tagInfos { if tagInfo.Key == fileNameTag { fileName = tagInfo.Value break } if tagInfo.Key == formTag { fileName = tagInfo.Value } } if len(fileName) == 0 { fileName = d.fieldName } 
multipartForm, err := req.MultipartForm() if err != nil { hlog.SystemLogger().Warnf("can not get MultipartForm from request, reason: %v, so skip '%s' field binding", fileName, err, d.fieldName) return nil } files, exist := multipartForm.File[fileName] if !exist { hlog.SystemLogger().Warnf("the file '%s' is not existed in request, so skip '%s' field binding", fileName, d.fieldName) return nil } if field.Kind() == reflect.Array { if len(files) != field.Len() { return fmt.Errorf("the numbers(%d) of file '%s' does not match the length(%d) of %s", len(files), fileName, field.Len(), field.Type().String()) } } else { // slice need creating enough capacity field = reflect.MakeSlice(field.Type(), len(files), len(files)) } // handle multiple pointer var ptrDepth int t := d.fieldType.Elem() elemKind := t.Kind() for elemKind == reflect.Ptr { t = t.Elem() elemKind = t.Kind() ptrDepth++ } for idx, file := range files { v := reflect.New(t).Elem() v.Set(reflect.ValueOf(*file)) field.Index(idx).Set(ReferenceValue(v, ptrDepth)) } fieldValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth)) return nil } func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { fieldType := field.Type for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() } isRepeated := false if fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice { isRepeated = true } return []fieldDecoder{&fileTypeDecoder{ fieldInfo: fieldInfo{ index: index, parentIndex: parentIdx, fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, config: config, }, isRepeated: isRepeated, }}, nil } ================================================ FILE: pkg/app/server/binding/internal/decoder/reflect.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in 
compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder import ( "reflect" ) // ReferenceValue convert T to *T, the ptrDepth is the count of '*'. 
func ReferenceValue(v reflect.Value, ptrDepth int) reflect.Value { switch { case ptrDepth > 0: for ; ptrDepth > 0; ptrDepth-- { vv := reflect.New(v.Type()) vv.Elem().Set(v) v = vv } case ptrDepth < 0: for ; ptrDepth < 0 && v.Kind() == reflect.Ptr; ptrDepth++ { v = v.Elem() } } return v } func GetNonNilReferenceValue(v reflect.Value) (reflect.Value, int) { var ptrDepth int t := v.Type() elemKind := t.Kind() for elemKind == reflect.Ptr { t = t.Elem() elemKind = t.Kind() ptrDepth++ } val := reflect.New(t).Elem() return val, ptrDepth } func GetFieldValue(reqValue reflect.Value, parentIndex []int) reflect.Value { // reqValue -> (***bar)(nil) need new a default if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) reqValue = ReferenceValue(nonNilVal, ptrDepth) } for _, idx := range parentIndex { if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) } for reqValue.Kind() == reflect.Ptr { reqValue = reqValue.Elem() } reqValue = reqValue.Field(idx) } // It is possible that the parent struct is also a pointer, // so need to create a non-nil reflect.Value for it at runtime. for reqValue.Kind() == reflect.Ptr { if reqValue.IsNil() { nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) } reqValue = reqValue.Elem() } return reqValue } func getElemType(t reflect.Type) reflect.Type { for t.Kind() == reflect.Ptr { t = t.Elem() } return t } ================================================ FILE: pkg/app/server/binding/internal/decoder/slice_getter.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) type sliceGetter func(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) func pathSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { var value string if params != nil { value, _ = params.Get(key) } if len(value) == 0 && len(defaultValue) != 0 { value = defaultValue[0] } if len(value) != 0 { ret = append(ret, value) } return } func postFormSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { req.PostArgs().VisitAll(func(formKey, value []byte) { if bytesconv.B2s(formKey) == key { ret = append(ret, string(value)) } }) if len(ret) > 0 { return } mf, err := req.MultipartForm() if err == nil && mf.Value != nil { for k, v := range mf.Value { if k == key && len(v) > 0 { ret = append(ret, v...) } } } if len(ret) > 0 { return } if len(ret) == 0 && len(defaultValue) != 0 { ret = append(ret, defaultValue...) } return } func querySlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { if key == bytesconv.B2s(queryKey) { ret = append(ret, string(value)) } }) if len(ret) == 0 && len(defaultValue) != 0 { ret = append(ret, defaultValue...) } return } func cookieSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { req.Header.VisitAllCookie(func(cookieKey, value []byte) { if bytesconv.B2s(cookieKey) == key { ret = append(ret, string(value)) } }) if len(ret) == 0 && len(defaultValue) != 0 { ret = append(ret, defaultValue...) 
} return } func headerSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { ret = defaultValue if vv := req.Header.GetAll(key); len(vv) > 0 { ret = vv } return } func rawBodySlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { if req.Header.ContentLength() > 0 { ret = append(ret, string(req.Body())) } return } ================================================ FILE: pkg/app/server/binding/internal/decoder/slice_type_decoder.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. 
* * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder import ( "fmt" "mime/multipart" "reflect" "github.com/cloudwego/hertz/internal/bytesconv" hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) type sliceTypeFieldTextDecoder struct { fieldInfo isArray bool } func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var err error var texts []string var defaultValue string var bindRawBody bool var isDefault bool for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { if tagInfo.Key == jsonTag { defaultValue = tagInfo.Default found := checkRequireJSON(req, tagInfo) if found { err = nil } else { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value) } if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { // defaultValue = "" } } continue } if tagInfo.Key == rawBodyTag { bindRawBody = true } texts = tagInfo.SliceGetter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if len(texts) != 0 { err = nil break } if tagInfo.Required { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value) } } if err != nil { return err } if len(texts) == 0 && len(defaultValue) != 0 { defaultValue = 
toDefaultValue(d.fieldType, defaultValue) texts = append(texts, defaultValue) isDefault = true } if len(texts) == 0 { return nil } reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) // **[]**int if field.Kind() == reflect.Ptr { if field.IsNil() { nonNilVal, ptrDepth := GetNonNilReferenceValue(field) field.Set(ReferenceValue(nonNilVal, ptrDepth)) } } var parentPtrDepth int for field.Kind() == reflect.Ptr { field = field.Elem() parentPtrDepth++ } if d.isArray { if len(texts) != field.Len() && !isDefault { return fmt.Errorf("%q is not valid value for %s", texts, field.Type().String()) } } else { // slice need creating enough capacity field = reflect.MakeSlice(field.Type(), len(texts), len(texts)) } // raw_body && []byte binding if bindRawBody && field.Type().Elem().Kind() == reflect.Uint8 { reqValue.Field(d.index).Set(reflect.ValueOf(req.Body())) return nil } // handle internal multiple pointer, []**int var ptrDepth int t := d.fieldType.Elem() // d.fieldType is non-pointer type for the field elemKind := t.Kind() for elemKind == reflect.Ptr { t = t.Elem() elemKind = t.Kind() ptrDepth++ } if isDefault { err = hJson.Unmarshal(bytesconv.S2b(texts[0]), reqValue.Field(d.index).Addr().Interface()) if err != nil { return fmt.Errorf("using '%s' to unmarshal field '%s: %s' failed, %v", texts[0], d.fieldName, d.fieldType.String(), err) } return nil } for idx, text := range texts { var vv reflect.Value vv, err = stringToValue(t, text, req, params, d.config) if err != nil { break } field.Index(idx).Set(ReferenceValue(vv, ptrDepth)) } if err != nil { if !reqValue.Field(d.index).CanAddr() { return err } // text[0] can be a complete json content for []Type. 
err = hJson.Unmarshal(bytesconv.S2b(texts[0]), reqValue.Field(d.index).Addr().Interface()) if err != nil { return fmt.Errorf("using '%s' to unmarshal field '%s: %s' failed, %v", texts[0], d.fieldName, d.fieldType.String(), err) } } else { reqValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth)) } return nil } func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { if !(field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) { return nil, fmt.Errorf("unexpected type %s, expected slice or array", field.Type.String()) } isArray := false if field.Type.Kind() == reflect.Array { isArray = true } for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: tagInfos[idx].SliceGetter = pathSlice tagInfos[idx].Getter = path case formTag: tagInfos[idx].SliceGetter = postFormSlice tagInfos[idx].Getter = postForm case queryTag: tagInfos[idx].SliceGetter = querySlice tagInfos[idx].Getter = query case cookieTag: tagInfos[idx].SliceGetter = cookieSlice tagInfos[idx].Getter = cookie case headerTag: tagInfos[idx].SliceGetter = headerSlice tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: tagInfos[idx].SliceGetter = rawBodySlice tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing default: } } fieldType := field.Type for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() } // fieldType.Elem() is the type for array/slice elem t := getElemType(fieldType.Elem()) if t == reflect.TypeOf(multipart.FileHeader{}) { return getMultipartFileDecoder(field, index, tagInfos, parentIdx, config) } return []fieldDecoder{&sliceTypeFieldTextDecoder{ fieldInfo: fieldInfo{ index: index, parentIndex: parentIdx, fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, config: config, }, isArray: isArray, }}, nil } ================================================ FILE: 
pkg/app/server/binding/internal/decoder/sonic_required.go ================================================ // Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //go:build (linux || windows || darwin) && amd64 && !gjson package decoder import ( "strings" "github.com/bytedance/sonic" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func checkRequireJSON(req *protocol.Request, tagInfo TagInfo) bool { if !tagInfo.Required { return true } ct := bytesconv.B2s(req.Header.ContentType()) if !strings.EqualFold(utils.FilterContentType(ct), consts.MIMEApplicationJSON) { return false } node, _ := sonic.Get(req.Body(), stringSliceForInterface(tagInfo.JSONName)...) if !node.Exists() { idx := strings.LastIndex(tagInfo.JSONName, ".") if idx > 0 { // There should be a superior if it is empty, it will report 'true' for required node, _ := sonic.Get(req.Body(), stringSliceForInterface(tagInfo.JSONName[:idx])...) 
if !node.Exists() { return true } } return false } return true } func stringSliceForInterface(s string) (ret []interface{}) { x := strings.Split(s, ".") for _, val := range x { ret = append(ret, val) } return } func keyExist(req *protocol.Request, tagInfo TagInfo) bool { ct := bytesconv.B2s(req.Header.ContentType()) if utils.FilterContentType(ct) != consts.MIMEApplicationJSON { return false } node, _ := sonic.Get(req.Body(), stringSliceForInterface(tagInfo.JSONName)...) return node.Exists() } ================================================ FILE: pkg/app/server/binding/internal/decoder/struct_type_decoder.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package decoder import ( "fmt" "reflect" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/common/hlog" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) type structTypeFieldTextDecoder struct { fieldInfo } func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var err error var text string var exist bool var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { if tagInfo.Key == jsonTag { defaultValue = tagInfo.Default found := checkRequireJSON(req, tagInfo) if found { err = nil } else { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value) } if len(tagInfo.Default) != 0 && keyExist(req, tagInfo) { defaultValue = "" } } continue } text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if exist { err = nil break } if tagInfo.Required { err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", tagInfo.Value) } } if err != nil { return err } if len(text) == 0 && len(defaultValue) != 0 { text = toDefaultValue(d.fieldType, defaultValue) } if !exist && len(text) == 0 { return nil } reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) if field.Kind() == reflect.Ptr { t := field.Type() var ptrDepth int for t.Kind() == reflect.Ptr { t = t.Elem() ptrDepth++ } var vv reflect.Value vv, err := stringToValue(t, text, req, params, d.config) if err != nil { hlog.SystemLogger().Infof("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) return nil } field.Set(ReferenceValue(vv, ptrDepth)) return nil } err = hjson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) if err != nil { 
hlog.SystemLogger().Infof("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) } return nil } func getStructTypeFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: tagInfos[idx].SliceGetter = pathSlice tagInfos[idx].Getter = path case formTag: tagInfos[idx].SliceGetter = postFormSlice tagInfos[idx].Getter = postForm case queryTag: tagInfos[idx].SliceGetter = querySlice tagInfos[idx].Getter = query case cookieTag: tagInfos[idx].SliceGetter = cookieSlice tagInfos[idx].Getter = cookie case headerTag: tagInfos[idx].SliceGetter = headerSlice tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: tagInfos[idx].SliceGetter = rawBodySlice tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing default: } } fieldType := field.Type for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() } return []fieldDecoder{&structTypeFieldTextDecoder{ fieldInfo: fieldInfo{ index: index, parentIndex: parentIdx, fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, config: config, }, }}, nil } ================================================ FILE: pkg/app/server/binding/internal/decoder/tag.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package decoder import ( "reflect" "strings" ) const ( pathTag = "path" formTag = "form" queryTag = "query" cookieTag = "cookie" headerTag = "header" jsonTag = "json" rawBodyTag = "raw_body" fileNameTag = "file_name" ) const ( defaultTag = "default" ) const ( requiredTagOpt = "required" ) type TagInfo struct { Key string Value string JSONName string Required bool Skip bool Default string Options []string Getter getter SliceGetter sliceGetter } func head(str, sep string) (head, tail string) { idx := strings.Index(str, sep) if idx < 0 { return str, "" } return str[:idx], str[idx+len(sep):] } func lookupFieldTags(field reflect.StructField, parentJSONName string, config *DecodeConfig) ([]TagInfo, string) { var ret []string tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, rawBodyTag, fileNameTag} for _, tag := range tags { if _, ok := field.Tag.Lookup(tag); ok { ret = append(ret, tag) } } defaultVal := "" if val, ok := field.Tag.Lookup(defaultTag); ok { defaultVal = val } var tagInfos []TagInfo var newParentJSONName string for _, tag := range ret { tagContent := field.Tag.Get(tag) tagValue, opts := head(tagContent, ",") if len(tagValue) == 0 { tagValue = field.Name } skip := false jsonName := parentJSONName + "." + field.Name if tag == jsonTag { jsonName = parentJSONName + "." + tagValue } if tagValue == "-" { skip = true if tag == jsonTag { jsonName = parentJSONName + "." 
+ field.Name } } if jsonName != "" { jsonName = strings.TrimPrefix(jsonName, ".") newParentJSONName = jsonName } var options []string var opt string var required bool for len(opts) > 0 { opt, opts = head(opts, ",") options = append(options, opt) if opt == requiredTagOpt { required = true } } tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, JSONName: jsonName, Options: options, Required: required, Default: defaultVal, Skip: skip}) } if len(newParentJSONName) == 0 { newParentJSONName = strings.TrimPrefix(parentJSONName+"."+field.Name, ".") } return tagInfos, newParentJSONName } func getDefaultFieldTags(field reflect.StructField, parentJSONName string) (tagInfos []TagInfo, newParentJSONName string) { defaultVal := "" if val, ok := field.Tag.Lookup(defaultTag); ok { defaultVal = val } tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, fileNameTag} for _, tag := range tags { jsonName := strings.TrimPrefix(parentJSONName+"."+field.Name, ".") tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name, Default: defaultVal, JSONName: jsonName}) } newParentJSONName = strings.TrimPrefix(parentJSONName+"."+field.Name, ".") return } func getFieldTagInfoByTag(field reflect.StructField, tag string) []TagInfo { var tagInfos []TagInfo if content, ok := field.Tag.Lookup(tag); ok { tagValue, opts := head(content, ",") if len(tagValue) == 0 { tagValue = field.Name } skip := false if tagValue == "-" { skip = true } defaultVal := "" if val, ok := field.Tag.Lookup(defaultTag); ok { defaultVal = val } var options []string var opt string var required bool for len(opts) > 0 { opt, opts = head(opts, ",") options = append(options, opt) if opt == requiredTagOpt { required = true } } tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, Options: options, Required: required, Default: defaultVal, Skip: skip}) } else { tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name}) } return tagInfos } 
================================================ FILE: pkg/app/server/binding/internal/decoder/text_decoder.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder import ( "fmt" "reflect" "strconv" "github.com/cloudwego/hertz/internal/bytesconv" hJson "github.com/cloudwego/hertz/pkg/common/json" ) type TextDecoder interface { UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error } func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { switch rt.Kind() { case reflect.Bool: return &boolDecoder{}, nil case reflect.Uint8: return &uintDecoder{bitSize: 8}, nil case reflect.Uint16: return &uintDecoder{bitSize: 16}, nil case reflect.Uint32: return &uintDecoder{bitSize: 32}, nil case reflect.Uint64: return &uintDecoder{bitSize: 64}, nil case reflect.Uint: return &uintDecoder{}, nil case reflect.Int8: return &intDecoder{bitSize: 8}, nil case reflect.Int16: return &intDecoder{bitSize: 16}, nil case reflect.Int32: return &intDecoder{bitSize: 32}, nil case reflect.Int64: return &intDecoder{bitSize: 64}, nil case reflect.Int: return &intDecoder{}, nil case reflect.String: return &stringDecoder{}, nil case reflect.Float32: return &floatDecoder{bitSize: 32}, nil case reflect.Float64: return &floatDecoder{bitSize: 64}, nil case reflect.Interface: return &interfaceDecoder{}, nil } return nil, fmt.Errorf("unsupported type %s", rt.String()) } type boolDecoder struct{} func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { if s == "" && looseZeroMode { s = "false" } v, err := strconv.ParseBool(s) if err != nil { return err } fieldValue.SetBool(v) return nil } type floatDecoder struct { bitSize int } func (d *floatDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { if s == "" && looseZeroMode { s = "0.0" } v, err := strconv.ParseFloat(s, d.bitSize) if err != nil { return err } fieldValue.SetFloat(v) return nil } type intDecoder struct { bitSize int } func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error 
{ if s == "" && looseZeroMode { s = "0" } v, err := strconv.ParseInt(s, 10, d.bitSize) if err != nil { return err } fieldValue.SetInt(v) return nil } type stringDecoder struct{} func (d *stringDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { fieldValue.SetString(s) return nil } type uintDecoder struct { bitSize int } func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { if s == "" && looseZeroMode { s = "0" } v, err := strconv.ParseUint(s, 10, d.bitSize) if err != nil { return err } fieldValue.SetUint(v) return nil } type interfaceDecoder struct{} func (d *interfaceDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { if s == "" && looseZeroMode { s = "0" } return hJson.Unmarshal(bytesconv.S2b(s), fieldValue.Addr().Interface()) } ================================================ FILE: pkg/app/server/binding/internal/decoder/util.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package decoder import ( "encoding" "fmt" "reflect" "strings" "github.com/cloudwego/hertz/internal/bytesconv" hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) const ( specialChar = "\x07" ) // toDefaultValue will preprocess the default value and transfer it to be standard format func toDefaultValue(typ reflect.Type, defaultValue string) string { switch typ.Kind() { case reflect.Slice, reflect.Array, reflect.Map, reflect.Struct: // escape single quote and double quote, replace single quote with double quote defaultValue = strings.Replace(defaultValue, `"`, `\"`, -1) defaultValue = strings.Replace(defaultValue, `\'`, specialChar, -1) defaultValue = strings.Replace(defaultValue, `'`, `"`, -1) defaultValue = strings.Replace(defaultValue, specialChar, `'`, -1) } return defaultValue } // stringToValue is used to dynamically create reflect.Value for 'text' func stringToValue(elemType reflect.Type, text string, req *protocol.Request, params param.Params, config *DecodeConfig) (reflect.Value, error) { if customizedFunc, exist := config.TypeUnmarshalFuncs[elemType]; exist { val, err := customizedFunc(req, params, text) if err != nil { return reflect.Value{}, err } return val, nil } v := reflect.New(elemType) if tryTextUnmarshaler(v, text) { return v.Elem(), nil } switch elemType.Kind() { case reflect.Struct, reflect.Map: if err := hJson.Unmarshal(bytesconv.S2b(text), v.Interface()); err != nil { return reflect.Value{}, err } return v.Elem(), nil case reflect.Array, reflect.Slice: // do nothing return v.Elem(), nil default: decoder, err := SelectTextDecoder(elemType) if err != nil { return reflect.Value{}, err } v = v.Elem() err = decoder.UnmarshalString(text, v, config.LooseZeroMode) if err != nil { return reflect.Value{}, fmt.Errorf("unable to decode '%s' as %s: %w", text, elemType.String(), err) } return v, nil } } func tryTextUnmarshaler(v reflect.Value, s string) bool { enc, ok 
:= v.Interface().(encoding.TextUnmarshaler) if ok { if err := enc.UnmarshalText(bytesconv.S2b(s)); err == nil { return true } } return false } ================================================ FILE: pkg/app/server/binding/internal/decoder/util_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package decoder import ( "encoding" "errors" "reflect" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) type testTextUnmarshaler struct { Value string } func (t *testTextUnmarshaler) UnmarshalText(text []byte) error { t.Value = string(text) return nil } var _ encoding.TextUnmarshaler = (*testTextUnmarshaler)(nil) func TestStringToValue(t *testing.T) { tests := []struct { name string elemType reflect.Type text string config *DecodeConfig expectValue interface{} expectError bool }{ { name: "string type", elemType: reflect.TypeOf(""), text: "test string", expectValue: "test string", }, { name: "int type", elemType: reflect.TypeOf(0), text: "42", expectValue: 42, }, { name: "bool type", elemType: reflect.TypeOf(false), text: "true", expectValue: true, }, { name: "float type", elemType: reflect.TypeOf(0.0), text: "3.14", expectValue: 3.14, }, { name: "text unmarshaler", elemType: reflect.TypeOf(testTextUnmarshaler{}), text: "custom text", expectValue: testTextUnmarshaler{Value: "custom text"}, }, { name: 
"invalid int", elemType: reflect.TypeOf(0), text: "not an int", expectError: true, }, { name: "struct type", elemType: reflect.TypeOf(struct{ Name string }{}), text: `{"Name":"test"}`, expectValue: struct{ Name string }{Name: "test"}, }, { name: "struct type err", elemType: reflect.TypeOf(struct{ Name string }{}), text: `{"Name":1}`, expectError: true, }, { name: "list type", elemType: reflect.TypeOf([]int{}), expectValue: *new([]int), }, { name: "map type", elemType: reflect.TypeOf(map[string]interface{}{}), text: `{"key":"value"}`, expectValue: map[string]interface{}{"key": "value"}, }, { name: "unsupported type", elemType: reflect.TypeOf(complex64(0)), expectError: true, }, { name: "custom type unmarshal func", elemType: reflect.TypeOf(testTextUnmarshaler{}), text: "custom func", config: &DecodeConfig{ TypeUnmarshalFuncs: map[reflect.Type]CustomizeDecodeFunc{ reflect.TypeOf(testTextUnmarshaler{}): func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { return reflect.ValueOf(testTextUnmarshaler{Value: "from custom func"}), nil }, }, }, expectValue: testTextUnmarshaler{Value: "from custom func"}, }, { name: "custom type unmarshal func err", elemType: reflect.TypeOf(testTextUnmarshaler{}), config: &DecodeConfig{ TypeUnmarshalFuncs: map[reflect.Type]CustomizeDecodeFunc{ reflect.TypeOf(testTextUnmarshaler{}): func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { return reflect.Value{}, errors.New("err") }, }, }, expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := &protocol.Request{} params := param.Params{} config := tt.config if config == nil { config = &DecodeConfig{} } val, err := stringToValue(tt.elemType, tt.text, req, params, config) if tt.expectError { assert.NotNil(t, err) return } assert.Nil(t, err) assert.DeepEqual(t, tt.expectValue, val.Interface()) }) } } func TestTryTextUnmarshaler(t *testing.T) { tests := []struct { name string value interface{} 
text string expected bool }{ { name: "text unmarshaler", value: &testTextUnmarshaler{}, text: "test text", expected: true, }, { name: "non text unmarshaler", value: &struct{}{}, text: "test text", expected: false, }, { name: "nil value", value: nil, text: "test text", expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var v reflect.Value if tt.value != nil { v = reflect.ValueOf(tt.value) } else { v = reflect.ValueOf(&tt.value).Elem() } result := tryTextUnmarshaler(v, tt.text) assert.DeepEqual(t, tt.expected, result) if tt.expected && tt.value != nil { // Verify the value was actually set unmarshaler := tt.value.(*testTextUnmarshaler) assert.DeepEqual(t, tt.text, unmarshaler.Value) } }) } } ================================================ FILE: pkg/app/server/binding/reflect.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* MIT License * * Copyright (c) 2019-present Fenny and Contributors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in all * copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package binding import ( "fmt" "reflect" "unsafe" ) func valueAndTypeID(v interface{}) (reflect.Value, uintptr) { header := (*emptyInterface)(unsafe.Pointer(&v)) rv := reflect.ValueOf(v) return rv, header.typeID } type emptyInterface struct { typeID uintptr dataPtr unsafe.Pointer } func checkPointer(rv reflect.Value) error { if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("receiver must be a non-nil pointer") } return nil } // dereferenceType recursively dereferences pointer types to get the underlying type. 
func dereferenceType(t reflect.Type) reflect.Type { for t.Kind() == reflect.Pointer { t = t.Elem() } return t } ================================================ FILE: pkg/app/server/binding/reflect_internal_test.go ================================================ // Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // package binding import ( "reflect" "testing" "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" "github.com/cloudwego/hertz/pkg/common/test/assert" ) type foo2 struct { F1 string } type fooq struct { F1 **string } func Test_ReferenceValue(t *testing.T) { foo1 := foo2{F1: "f1"} foo1Val := reflect.ValueOf(foo1) foo1PointerVal := decoder.ReferenceValue(foo1Val, 5) assert.DeepEqual(t, "f1", foo1.F1) assert.DeepEqual(t, "f1", foo1Val.Field(0).Interface().(string)) if foo1PointerVal.Kind() != reflect.Ptr { t.Errorf("expect a pointer, but get nil") } assert.DeepEqual(t, "*****binding.foo2", foo1PointerVal.Type().String()) deFoo1PointerVal := decoder.ReferenceValue(foo1PointerVal, -5) if deFoo1PointerVal.Kind() == reflect.Ptr { t.Errorf("expect a non-pointer, but get a pointer") } assert.DeepEqual(t, "f1", deFoo1PointerVal.Field(0).Interface().(string)) } func Test_GetNonNilReferenceValue(t *testing.T) { foo1 := (****foo)(nil) foo1Val := reflect.ValueOf(foo1) foo1ValNonNil, ptrDepth := decoder.GetNonNilReferenceValue(foo1Val) if !foo1ValNonNil.IsValid() { t.Errorf("expect a valid value, but get nil") } if 
!foo1ValNonNil.CanSet() { t.Errorf("expect can set value, but not") } foo1ReferPointer := decoder.ReferenceValue(foo1ValNonNil, ptrDepth) if foo1ReferPointer.Kind() != reflect.Ptr { t.Errorf("expect a pointer, but get nil") } } func Test_GetFieldValue(t *testing.T) { type bar struct { B1 **fooq } bar1 := (***bar)(nil) parentIdx := []int{0} idx := 0 bar1Val := reflect.ValueOf(bar1) parentFieldVal := decoder.GetFieldValue(bar1Val, parentIdx) if parentFieldVal.Kind() == reflect.Ptr { t.Errorf("expect a non-pointer, but get a pointer") } if !parentFieldVal.CanSet() { t.Errorf("expect can set value, but not") } fooFieldVal := parentFieldVal.Field(idx) assert.DeepEqual(t, "**string", fooFieldVal.Type().String()) if !fooFieldVal.CanSet() { t.Errorf("expect can set value, but not") } } ================================================ FILE: pkg/app/server/binding/reflect_test.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package binding import ( "reflect" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) type foo struct { f1 string } func TestReflect_TypeID(t *testing.T) { _, intType := valueAndTypeID(int(1)) _, uintType := valueAndTypeID(uint(1)) _, shouldBeIntType := valueAndTypeID(int(1)) assert.DeepEqual(t, intType, shouldBeIntType) assert.NotEqual(t, intType, uintType) foo1 := foo{f1: "1"} foo2 := foo{f1: "2"} _, foo1Type := valueAndTypeID(foo1) _, foo2Type := valueAndTypeID(foo2) _, foo2PointerType := valueAndTypeID(&foo2) _, foo1PointerType := valueAndTypeID(&foo1) assert.DeepEqual(t, foo1Type, foo2Type) assert.NotEqual(t, foo1Type, foo2PointerType) assert.DeepEqual(t, foo1PointerType, foo2PointerType) } func TestReflect_CheckPointer(t *testing.T) { foo1 := foo{} foo1Val := reflect.ValueOf(foo1) err := checkPointer(foo1Val) if err == nil { t.Errorf("expect an err, but get nil") } foo2 := &foo{} foo2Val := reflect.ValueOf(foo2) err = checkPointer(foo2Val) if err != nil { t.Error(err) } foo3 := (*foo)(nil) foo3Val := reflect.ValueOf(foo3) err = checkPointer(foo3Val) if err == nil { t.Errorf("expect an err, but get nil") } } func TestReflect_DereferenceType(t *testing.T) { var foo1 ****foo foo1Val := reflect.ValueOf(foo1) rt := dereferenceType(foo1Val.Type()) assert.DeepEqual(t, "foo", rt.Name()) var foo2 foo foo2Val := reflect.ValueOf(foo2) rt2 := dereferenceType(foo2Val.Type()) assert.DeepEqual(t, "foo", rt2.Name()) } ================================================ FILE: pkg/app/server/binding/tagexpr_bind_test.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * MIT License * * Copyright 2019 Bytedance Inc. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package binding import ( "bytes" "encoding/json" "io" "io/ioutil" "mime/multipart" "net/http" "net/url" "strings" "testing" "time" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/route/param" "google.golang.org/protobuf/proto" ) func TestRawBody(t *testing.T) { type Recv struct { S []byte `raw_body:""` F **string `raw_body:""` } bodyBytes := []byte("raw_body.............") req := newRequest("", nil, nil, bytes.NewReader(bodyBytes)) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { if err != nil { t.Error(err) } } assert.DeepEqual(t, bodyBytes, recv.S) assert.DeepEqual(t, string(bodyBytes), **recv.F) } func TestQueryString(t *testing.T) { type metric string type count int32 type Recv struct { X **struct { A []string `query:"a"` B string `query:"b"` C *[]string `query:"c,required"` D *string `query:"d"` E *[]***int `query:"e"` F metric `query:"f"` G []count `query:"g"` } Y string `query:"y,required"` Z *string `query:"z"` } req := newRequest("http://localhost:8080/?a=a1&a=a2&b=b1&c=c1&c=c2&d=d1&d=d&f=qps&g=1002&g=1003&e=&e=2&y=y1", nil, nil, nil) recv := new(Recv) bindConfig := &BindConfig{} bindConfig.LooseZeroMode = true binder := NewDefaultBinder(bindConfig) err := binder.Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 0, ***(*(**recv.X).E)[0]) assert.DeepEqual(t, 2, ***(*(**recv.X).E)[1]) assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) assert.DeepEqual(t, "b1", (**recv.X).B) assert.DeepEqual(t, []string{"c1", "c2"}, *(**recv.X).C) assert.DeepEqual(t, "d1", *(**recv.X).D) assert.DeepEqual(t, metric("qps"), (**recv.X).F) assert.DeepEqual(t, []count{1002, 1003}, (**recv.X).G) assert.DeepEqual(t, "y1", recv.Y) assert.DeepEqual(t, (*string)(nil), recv.Z) } func TestGetBody(t *testing.T) { type Recv struct { X **struct { E string `json:"e,required" 
query:"e,required"` } } req := newRequest("http://localhost:8080/", nil, nil, nil) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err == nil { t.Fatalf("expected an error, but get nil") } assert.DeepEqual(t, err.Error(), "'e' field is a 'required' parameter, but the request body does not have this parameter 'X.e'") } func TestQueryNum(t *testing.T) { type Recv struct { X **struct { A []int `query:"a"` B int32 `query:"b"` C *[]uint16 `query:"c,required"` D *float32 `query:"d"` } Y bool `query:"y,required"` Z *int64 `query:"z"` } req := newRequest("http://localhost:8080/?a=11&a=12&b=21&c=31&c=32&d=41&d=42&y=true", nil, nil, nil) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { if err != nil { t.Error(err) } } assert.DeepEqual(t, []int{11, 12}, (**recv.X).A) assert.DeepEqual(t, int32(21), (**recv.X).B) assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) assert.DeepEqual(t, float32(41), *(**recv.X).D) assert.DeepEqual(t, true, recv.Y) assert.DeepEqual(t, (*int64)(nil), recv.Z) } func TestHeaderString(t *testing.T) { type Recv struct { X **struct { A []string `header:"X-A"` B string `header:"X-B"` C *[]string `header:"X-C,required"` D *string `header:"X-D"` } Y string `header:"X-Y,required"` Z *string `header:"X-Z"` } header := make(http.Header) header.Add("X-A", "a1") header.Add("X-A", "a2") header.Add("X-B", "b1") header.Add("X-C", "c1") header.Add("X-C", "c2") header.Add("X-D", "d1") header.Add("X-D", "d2") header.Add("X-Y", "y1") req := newRequest("", header, nil, nil) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { if err != nil { t.Error(err) } } assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) assert.DeepEqual(t, "b1", (**recv.X).B) assert.DeepEqual(t, []string{"c1", "c2"}, *(**recv.X).C) assert.DeepEqual(t, "d1", *(**recv.X).D) assert.DeepEqual(t, "y1", recv.Y) assert.DeepEqual(t, (*string)(nil), recv.Z) } func TestHeaderNum(t *testing.T) { type Recv struct { X 
**struct { A []int `header:"X-A"` B int32 `header:"X-B"` C *[]uint16 `header:"X-C,required"` D *float32 `header:"X-D"` } Y bool `header:"X-Y,required"` Z *int64 `header:"X-Z"` } header := make(http.Header) header.Add("X-A", "11") header.Add("X-A", "12") header.Add("X-B", "21") header.Add("X-C", "31") header.Add("X-C", "32") header.Add("X-D", "41") header.Add("X-D", "42") header.Add("X-Y", "true") req := newRequest("", header, nil, nil) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, []int{11, 12}, (**recv.X).A) assert.DeepEqual(t, int32(21), (**recv.X).B) assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) assert.DeepEqual(t, float32(41), *(**recv.X).D) assert.DeepEqual(t, true, recv.Y) assert.DeepEqual(t, (*int64)(nil), recv.Z) } func TestCookieString(t *testing.T) { type Recv struct { X **struct { A []string `cookie:"a"` B string `cookie:"b"` C *[]string `cookie:"c,required"` D *string `cookie:"d"` } Y string `cookie:"y,required"` Z *string `cookie:"z"` } cookies := []*http.Cookie{ {Name: "a", Value: "a1"}, {Name: "a", Value: "a2"}, {Name: "b", Value: "b1"}, {Name: "c", Value: "c1"}, {Name: "c", Value: "c2"}, {Name: "d", Value: "d1"}, {Name: "d", Value: "d2"}, {Name: "y", Value: "y1"}, } req := newRequest("", nil, cookies, nil) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, []string{"a2"}, (**recv.X).A) assert.DeepEqual(t, "b1", (**recv.X).B) assert.DeepEqual(t, []string{"c2"}, *(**recv.X).C) assert.DeepEqual(t, "d2", *(**recv.X).D) assert.DeepEqual(t, "y1", recv.Y) assert.DeepEqual(t, (*string)(nil), recv.Z) } func TestCookieNum(t *testing.T) { type Recv struct { X **struct { A []int `cookie:"a"` B int32 `cookie:"b"` C *[]uint16 `cookie:"c,required"` D *float32 `cookie:"d"` } Y bool `cookie:"y,required"` Z *int64 `cookie:"z"` } cookies := []*http.Cookie{ {Name: "a", Value: "11"}, {Name: "b", Value: "21"}, {Name: "c", Value: 
"31"}, {Name: "d", Value: "41"}, {Name: "y", Value: "t"}, } req := newRequest("", nil, cookies, nil) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, []int{11}, (**recv.X).A) assert.DeepEqual(t, int32(21), (**recv.X).B) assert.DeepEqual(t, &[]uint16{31}, (**recv.X).C) assert.DeepEqual(t, float32(41), *(**recv.X).D) assert.DeepEqual(t, true, recv.Y) assert.DeepEqual(t, (*int64)(nil), recv.Z) } func TestFormString(t *testing.T) { type Recv struct { X **struct { A []string `form:"a"` B string `form:"b"` C *[]string `form:"c,required"` D *string `form:"d"` } Y string `form:"y,required"` Z *string `form:"z"` F *multipart.FileHeader `form:"F1"` F1 multipart.FileHeader Fs []multipart.FileHeader `form:"F1"` Fs1 []*multipart.FileHeader `form:"F1"` } values := make(url.Values) values.Add("a", "a1") values.Add("a", "a2") values.Add("b", "b1") values.Add("c", "c1") values.Add("c", "c2") values.Add("d", "d1") values.Add("d", "d2") values.Add("y", "y1") for i, f := range []files{{ "F1": []file{ newFile("txt", strings.NewReader("0123")), }, }} { contentType, bodyReader := newFormBody2(values, f) header := make(http.Header) header.Set("Content-Type", contentType) req := newRequest("", header, nil, bodyReader) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) assert.DeepEqual(t, "b1", (**recv.X).B) assert.DeepEqual(t, []string{"c1", "c2"}, *(**recv.X).C) assert.DeepEqual(t, "d1", *(**recv.X).D) assert.DeepEqual(t, "y1", recv.Y) assert.DeepEqual(t, (*string)(nil), recv.Z) t.Logf("[%d] F: %#v", i, recv.F) t.Logf("[%d] F1: %#v", i, recv.F1) t.Logf("[%d] Fs: %#v", i, recv.Fs) t.Logf("[%d] Fs1: %#v", i, recv.Fs1) if len(recv.Fs1) > 0 { t.Logf("[%d] Fs1[0]: %#v", i, recv.Fs1[0]) } } } func TestFormNum(t *testing.T) { type Recv struct { X **struct { A []int `form:"a"` B int32 `form:"b"` C *[]uint16 `form:"c,required"` D 
*float32 `form:"d"` } Y bool `form:"y,required"` Z *int64 `form:"z"` } values := make(url.Values) values.Add("a", "11") values.Add("a", "12") values.Add("b", "-21") values.Add("c", "31") values.Add("c", "32") values.Add("d", "41") values.Add("d", "42") values.Add("y", "1") for _, f := range []files{nil, { "f1": []file{ newFile("txt", strings.NewReader("f11 text.")), }, }} { contentType, bodyReader := newFormBody2(values, f) header := make(http.Header) header.Set("Content-Type", contentType) req := newRequest("", header, nil, bodyReader) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, []int{11, 12}, (**recv.X).A) assert.DeepEqual(t, int32(-21), (**recv.X).B) assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) assert.DeepEqual(t, float32(41), *(**recv.X).D) assert.DeepEqual(t, true, recv.Y) assert.DeepEqual(t, (*int64)(nil), recv.Z) } } func TestJSON(t *testing.T) { type metric string type count int32 type ZS struct { Z *int64 } type Recv struct { X **struct { A []string `json:"a"` B int32 `json:""` C *[]uint16 `json:",required"` D *float32 `json:"d"` E metric `json:"e"` F count `json:"f"` M map[string]string `json:"m"` } Y string `json:"y,required"` ZS } bodyReader := strings.NewReader(`{ "X": { "a": ["a1","a2"], "B": 21, "C": [31,32], "d": 41, "e": "qps", "f": 100, "m": {"a":"x"} }, "Z": 6 }`) header := make(http.Header) header.Set("Content-Type", consts.MIMEApplicationJSON) req := newRequest("", header, nil, bodyReader) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err == nil { t.Error("expected an error, but get nil") } assert.DeepEqual(t, err.Error(), "'y' field is a 'required' parameter, but the request body does not have this parameter 'y'") assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) assert.DeepEqual(t, int32(21), (**recv.X).B) assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) assert.DeepEqual(t, float32(41), *(**recv.X).D) assert.DeepEqual(t, 
metric("qps"), (**recv.X).E) assert.DeepEqual(t, count(100), (**recv.X).F) assert.DeepEqual(t, map[string]string{"a": "x"}, (**recv.X).M) assert.DeepEqual(t, "", recv.Y) assert.DeepEqual(t, (int64)(6), *recv.Z) } func TestNonstruct(t *testing.T) { bodyReader := strings.NewReader(`{ "X": { "a": ["a1","a2"], "B": 21, "C": [31,32], "d": 41, "e": "qps", "f": 100 }, "Z": 6 }`) header := make(http.Header) header.Set("Content-Type", "application/json") req := newRequest("", header, nil, bodyReader) var recv interface{} err := DefaultBinder().Bind(req.Req, &recv, nil) if err != nil { t.Error(err) } b, err := json.Marshal(recv) if err != nil { t.Error(err) } t.Logf("%s", b) bodyReader = strings.NewReader("b=334ddddd&token=mymocktoken&type=url_verification") header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") req = newRequest("", header, nil, bodyReader) recv = nil err = DefaultBinder().Bind(req.Req, &recv, nil) if err != nil { t.Error(err) } b, err = json.Marshal(recv) if err != nil { t.Error(err) } t.Logf("%s", b) } func TestPath(t *testing.T) { type Recv struct { X **struct { A []string `path:"a"` B int32 `path:"b"` C *[]uint16 `path:"c,required"` D *float32 `path:"d"` } Y string `path:"y,required"` Z *int64 } req := newRequest("", nil, nil, nil) recv := new(Recv) params := param.Params{ { Key: "a", Value: "a1", }, { Key: "b", Value: "-21", }, { Key: "c", Value: "31", }, { Key: "d", Value: "41", }, { Key: "y", Value: "y1", }, { Key: "name", Value: "henrylee2cn", }, } err := DefaultBinder().Bind(req.Req, recv, params) if err != nil { t.Error(err) } assert.DeepEqual(t, []string{"a1"}, (**recv.X).A) assert.DeepEqual(t, int32(-21), (**recv.X).B) assert.DeepEqual(t, &[]uint16{31}, (**recv.X).C) assert.DeepEqual(t, float32(41), *(**recv.X).D) assert.DeepEqual(t, "y1", recv.Y) assert.DeepEqual(t, (*int64)(nil), recv.Z) } func TestDefault(t *testing.T) { type S struct { SS string `json:"ss"` } type Recv struct { X **struct { A []string `path:"a" 
json:"a"` B int32 `path:"b" default:"32"` C bool `json:"c" default:"true"` D *float32 `default:"123.4"` E *[]string `default:"['a','b','c','d,e,f']"` F map[string]string `default:"{'a':'\"\\'1','\"b':'c','c':'2'}"` G map[string]int64 `default:"{'a':1,'b':2,'c':3}"` H map[string]float64 `default:"{'a':0.1,'b':1.2,'c':2.3}"` I map[string]float64 `default:"{'\"a\"':0.1,'b':1.2,'c':2.3}"` Empty string `default:""` Null string `default:""` CommaSpace string `default:",a:c "` Dash string `default:"-"` // InvalidInt int `default:"abc"` // InvalidMap map[string]string `default:"abc"` } Y string `json:"y" default:"y1"` Z int64 W string `json:"w"` V []int64 `json:"v" default:"[1,2,3]"` U []float32 `json:"u" default:"[1.1,2,3]"` T *string `json:"t" default:"t1"` S S `default:"{'ss':'test'}"` O *S `default:"{'ss':'test2'}"` Complex map[string][]map[string][]int64 `default:"{'a':[{'aa':[1,2,3], 'bb':[4,5]}],'b':[{}]}"` } bodyReader := strings.NewReader(`{ "X": { "a": ["a1","a2"] }, "Z": 6 }`) // var nilMap map[string]string header := make(http.Header) header.Set("Content-Type", consts.MIMEApplicationJSON) req := newRequest("", header, nil, bodyReader) recv := new(Recv) param2 := param.Params{ { Key: "e", Value: "123", }, } err := DefaultBinder().Bind(req.Req, recv, param2) if err != nil { t.Error(err) } assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) assert.DeepEqual(t, int32(32), (**recv.X).B) assert.DeepEqual(t, true, (**recv.X).C) assert.DeepEqual(t, float32(123.4), *(**recv.X).D) assert.DeepEqual(t, []string{"a", "b", "c", "d,e,f"}, *(**recv.X).E) assert.DeepEqual(t, map[string]string{"a": "\"'1", "\"b": "c", "c": "2"}, (**recv.X).F) assert.DeepEqual(t, map[string]int64{"a": 1, "b": 2, "c": 3}, (**recv.X).G) assert.DeepEqual(t, map[string]float64{"a": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).H) assert.DeepEqual(t, map[string]float64{"\"a\"": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).I) assert.DeepEqual(t, "", (**recv.X).Empty) assert.DeepEqual(t, "", (**recv.X).Null) 
assert.DeepEqual(t, ",a:c ", (**recv.X).CommaSpace) assert.DeepEqual(t, "-", (**recv.X).Dash) // assert.DeepEqual(t, 0, (**recv.X).InvalidInt) // assert.DeepEqual(t, nilMap, (**recv.X).InvalidMap) assert.DeepEqual(t, "y1", recv.Y) assert.DeepEqual(t, "t1", *recv.T) assert.DeepEqual(t, int64(6), recv.Z) assert.DeepEqual(t, []int64{1, 2, 3}, recv.V) assert.DeepEqual(t, []float32{1.1, 2, 3}, recv.U) assert.DeepEqual(t, S{SS: "test"}, recv.S) assert.DeepEqual(t, &S{SS: "test2"}, recv.O) assert.DeepEqual(t, map[string][]map[string][]int64{"a": {{"aa": {1, 2, 3}, "bb": []int64{4, 5}}}, "b": {map[string][]int64{}}}, recv.Complex) } func TestAuto(t *testing.T) { type Recv struct { A string B string C string D string `query:"D,required" form:"D,required"` E string `cookie:"e" json:"e"` } query := make(url.Values) query.Add("A", "a") query.Add("B", "b") query.Add("C", "c") query.Add("D", "d-from-query") contentType, bodyReader, err := newJSONBody(map[string]string{"e": "e-from-jsonbody"}) if err != nil { t.Error(err) } header := make(http.Header) header.Set("Content-Type", contentType) req := newRequest("http://localhost/?"+query.Encode(), header, []*http.Cookie{ {Name: "e", Value: "e-from-cookie"}, }, bodyReader) recv := new(Recv) err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "a", recv.A) assert.DeepEqual(t, "b", recv.B) assert.DeepEqual(t, "c", recv.C) assert.DeepEqual(t, "d-from-query", recv.D) assert.DeepEqual(t, "e-from-cookie", recv.E) query = make(url.Values) query.Add("D", "d-from-query") form := make(url.Values) form.Add("B", "b") form.Add("C", "c") form.Add("D", "d-from-form") contentType, bodyReader = newFormBody2(form, nil) header = make(http.Header) header.Set("Content-Type", contentType) req = newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) recv = new(Recv) err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "", recv.A) assert.DeepEqual(t, 
"b", recv.B) assert.DeepEqual(t, "c", recv.C) assert.DeepEqual(t, "d-from-form", recv.D) } func TestTypeUnmarshal(t *testing.T) { type Recv struct { A time.Time `form:"t1"` B *time.Time `query:"t2"` C []time.Time `query:"t2"` } query := make(url.Values) query.Add("t2", "2019-09-04T14:05:24+08:00") query.Add("t2", "2019-09-04T18:05:24+08:00") form := make(url.Values) form.Add("t1", "2019-09-03T18:05:24+08:00") contentType, bodyReader := newFormBody2(form, nil) header := make(http.Header) header.Set("Content-Type", contentType) req := newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } t1, err := time.Parse(time.RFC3339, "2019-09-03T18:05:24+08:00") if err != nil { t.Error(err) } assert.DeepEqual(t, t1, recv.A) t21, err := time.Parse(time.RFC3339, "2019-09-04T14:05:24+08:00") if err != nil { t.Error(err) } assert.DeepEqual(t, t21, *recv.B) t22, err := time.Parse(time.RFC3339, "2019-09-04T18:05:24+08:00") if err != nil { t.Error(err) } assert.DeepEqual(t, []time.Time{t21, t22}, recv.C) t.Logf("%v", recv) } // test: required func TestOption(t *testing.T) { type Recv struct { X *struct { C int `json:"c,required"` D int `json:"d"` } `json:"X"` Y string `json:"y"` } header := make(http.Header) header.Set("Content-Type", consts.MIMEApplicationJSON) bodyReader := strings.NewReader(`{ "X": { "c": 21, "d": 41 }, "y": "y1" }`) req := newRequest("", header, nil, bodyReader) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 21, recv.X.C) assert.DeepEqual(t, 41, recv.X.D) assert.DeepEqual(t, "y1", recv.Y) bodyReader = strings.NewReader(`{ "X": { }, "y": "y1" }`) req = newRequest("", header, nil, bodyReader) recv = new(Recv) err = DefaultBinder().Bind(req.Req, recv, nil) assert.DeepEqual(t, err.Error(), "'c' field is a 'required' parameter, but the request body does not have this parameter 'X.c'") 
assert.DeepEqual(t, 0, recv.X.C) assert.DeepEqual(t, 0, recv.X.D) assert.DeepEqual(t, "y1", recv.Y) bodyReader = strings.NewReader(`{ "y": "y1" }`) req = newRequest("", header, nil, bodyReader) recv = new(Recv) err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.True(t, recv.X == nil) assert.DeepEqual(t, "y1", recv.Y) type Recv2 struct { X *struct { C int `json:"c,required"` D int `json:"d"` } `json:"X,required"` Y string `json:"y"` } bodyReader = strings.NewReader(`{ "y": "y1" }`) req = newRequest("", header, nil, bodyReader) recv2 := new(Recv2) bindConfig := &BindConfig{} bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) err = binder.Bind(req.Req, recv2, nil) assert.DeepEqual(t, err.Error(), "'X' field is a 'required' parameter, but the request does not have this parameter") assert.True(t, recv2.X == nil) assert.DeepEqual(t, "y1", recv2.Y) } func newRequest(u string, header http.Header, cookies []*http.Cookie, bodyReader io.Reader) *mockRequest { if header == nil { header = make(http.Header) } method := "GET" var body []byte if bodyReader != nil { body, _ = ioutil.ReadAll(bodyReader) method = "POST" } if u == "" { u = "http://localhost" } req := newMockRequest() req.SetRequestURI(u) for k, v := range header { for _, val := range v { req.Req.Header.Add(k, val) } } if len(body) != 0 { req.SetBody(body) req.Req.Header.SetContentLength(len(body)) } req.Req.SetMethod(method) for _, c := range cookies { req.Req.Header.SetCookie(c.Name, c.Value) } return req } func TestQueryStringIssue(t *testing.T) { type Timestamp struct { time.Time } type Recv struct { Name *string `query:"name"` T *Timestamp `query:"t"` } req := newRequest("http://localhost:8080/?name=test", nil, nil, nil) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "test", *recv.Name) // DIFF: the type with customized decoder must be a non-nil value // assert.DeepEqual(t, 
(*Timestamp)(nil), recv.T) } func TestQueryTypes(t *testing.T) { type metric string type count int32 type metrics []string type filter struct { Col1 string } type Recv struct { A metric B count C *count D metrics `query:"D,required" form:"D,required"` E metric `cookie:"e" json:"e"` F filter `json:"f"` } query := make(url.Values) query.Add("A", "qps") query.Add("B", "123") query.Add("C", "321") query.Add("D", "dau") query.Add("D", "dnu") contentType, bodyReader, err := newJSONBody( map[string]interface{}{ "e": "e-from-jsonbody", "f": filter{Col1: "abc"}, }, ) if err != nil { t.Error(err) } header := make(http.Header) header.Set("Content-Type", contentType) req := newRequest("http://localhost/?"+query.Encode(), header, []*http.Cookie{ {Name: "e", Value: "e-from-cookie"}, }, bodyReader) recv := new(Recv) err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, metric("qps"), recv.A) assert.DeepEqual(t, count(123), recv.B) assert.DeepEqual(t, count(321), *recv.C) assert.DeepEqual(t, metrics{"dau", "dnu"}, recv.D) assert.DeepEqual(t, metric("e-from-cookie"), recv.E) assert.DeepEqual(t, filter{Col1: "abc"}, recv.F) } func TestNoTagIssue(t *testing.T) { type x int type T struct { x x2 x a int B int } req := newRequest("http://localhost:8080/?x=11&x2=12&a=1&B=2", nil, nil, nil) recv := new(T) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, x(0), recv.x) assert.DeepEqual(t, x(0), recv.x2) assert.DeepEqual(t, 0, recv.a) assert.DeepEqual(t, 2, recv.B) } func TestRegTypeUnmarshal(t *testing.T) { type Q struct { A int B string } type T struct { Q Q `query:"q"` Qs []*Q `query:"qs"` Qs2 ***[]*Q `query:"qs"` } values := url.Values{} b, err := json.Marshal(Q{A: 2, B: "y"}) if err != nil { t.Error(err) } values.Add("q", string(b)) bs, _ := json.Marshal([]Q{{A: 1, B: "x"}, {A: 2, B: "y"}}) values.Add("qs", string(bs)) req := newRequest("http://localhost:8080/?"+values.Encode(), nil, nil, nil) 
recv := new(T) bindConfig := &BindConfig{} bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) err = binder.Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, 2, recv.Q.A) assert.DeepEqual(t, "y", recv.Q.B) assert.DeepEqual(t, 1, recv.Qs[0].A) assert.DeepEqual(t, "x", recv.Qs[0].B) assert.DeepEqual(t, 2, recv.Qs[1].A) assert.DeepEqual(t, "y", recv.Qs[1].B) assert.DeepEqual(t, 1, (***recv.Qs2)[0].A) assert.DeepEqual(t, "x", (***recv.Qs2)[0].B) assert.DeepEqual(t, 2, (***recv.Qs2)[1].A) assert.DeepEqual(t, "y", (***recv.Qs2)[1].B) } func TestPathnameBUG(t *testing.T) { type Currency struct { CurrencyName *string `form:"currency_name,required" json:"currency_name,required" protobuf:"bytes,1,req,name=currency_name,json=currencyName" query:"currency_name,required"` CurrencySymbol *string `form:"currency_symbol,required" json:"currency_symbol,required" protobuf:"bytes,2,req,name=currency_symbol,json=currencySymbol" query:"currency_symbol,required"` SymbolPosition *int32 `form:"symbol_position,required" json:"symbol_position,required" protobuf:"varint,3,req,name=symbol_position,json=symbolPosition" query:"symbol_position,required"` DecimalPlaces *int32 `form:"decimal_places,required" json:"decimal_places,required" protobuf:"varint,4,req,name=decimal_places,json=decimalPlaces" query:"decimal_places,required"` // 56x56 DecimalSymbol *string `form:"decimal_symbol,required" json:"decimal_symbol,required" protobuf:"bytes,5,req,name=decimal_symbol,json=decimalSymbol" query:"decimal_symbol,required"` Separator *string `form:"separator,required" json:"separator,required" protobuf:"bytes,6,req,name=separator" query:"separator,required"` SeparatorIndex *string `form:"separator_index,required" json:"separator_index,required" protobuf:"bytes,7,req,name=separator_index,json=separatorIndex" query:"separator_index,required"` Between *string `form:"between,required" json:"between,required" protobuf:"bytes,8,req,name=between" 
query:"between,required"` MinPrice *string `form:"min_price" json:"min_price,omitempty" protobuf:"bytes,9,opt,name=min_price,json=minPrice" query:"min_price"` MaxPrice *string `form:"max_price" json:"max_price,omitempty" protobuf:"bytes,10,opt,name=max_price,json=maxPrice" query:"max_price"` } type CurrencyData struct { Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` Currency *Currency `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` } type ExchangeCurrencyRequest struct { PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,1,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` Currency *CurrencyData `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` Version *int32 `json:"version,omitempty" path:"version" protobuf:"varint,100,opt,name=version"` } z := new(ExchangeCurrencyRequest) z.Currency = new(CurrencyData) z.Currency.Currency = new(Currency) z.PromotionRegion = proto.String("?") z.Version = proto.Int32(-32) z.Currency.Amount = proto.String("?") z.Currency.Currency.CurrencyName = proto.String("?") z.Currency.Currency.CurrencySymbol = proto.String("?") z.Currency.Currency.SymbolPosition = proto.Int32(-32) z.Currency.Currency.DecimalPlaces = proto.Int32(-32) z.Currency.Currency.DecimalSymbol = proto.String("?") z.Currency.Currency.Separator = proto.String("?") z.Currency.Currency.Between = proto.String("?") z.Currency.Currency.MinPrice = proto.String("?") z.Currency.Currency.MaxPrice = proto.String("?") b, err := json.MarshalIndent(z, "", " ") if err != nil { t.Error(err) } header := make(http.Header) header.Set("Content-Type", "application/json;charset=utf-8") req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) recv := new(ExchangeCurrencyRequest) err = 
DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } } // test: required func TestPathnameBUG2(t *testing.T) { type CurrencyData struct { Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` Name *string `form:"name,required" json:"name,required" protobuf:"bytes,2,req,name=name" query:"name,required"` Symbol *string `form:"symbol" json:"symbol,omitempty" protobuf:"bytes,3,opt,name=symbol" query:"symbol"` } type TimeRange struct { StartTime *int64 `form:"start_time,required" json:"start_time,required" protobuf:"varint,1,req,name=start_time,json=startTime" query:"start_time,required"` EndTime *int64 `form:"end_time,required" json:"end_time,required" protobuf:"varint,2,req,name=end_time,json=endTime" query:"end_time,required"` } type CreateFreeShippingRequest struct { PromotionName *string `form:"promotion_name,required" json:"promotion_name,required" protobuf:"bytes,1,req,name=promotion_name,json=promotionName" query:"promotion_name,required"` PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,2,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` TimeRange *TimeRange `form:"time_range,required" json:"time_range,required" protobuf:"bytes,3,req,name=time_range,json=timeRange" query:"time_range,required"` PromotionBudget *CurrencyData `form:"promotion_budget,required" json:"promotion_budget,required" protobuf:"bytes,4,req,name=promotion_budget,json=promotionBudget" query:"promotion_budget,required"` Loaded_SellerIds []string `form:"loaded_Seller_ids" json:"loaded_Seller_ids,omitempty" protobuf:"bytes,5,rep,name=loaded_Seller_ids,json=loadedSellerIds" query:"loaded_Seller_ids"` Version *int32 `json:"version,omitempty" path:"version" protobuf:"varint,100,opt,name=version"` } // z := &CreateFreeShippingRequest{} // v := ameda.InitSampleValue(reflect.TypeOf(z), 
10).Interface().(*CreateFreeShippingRequest) // b, err := json.MarshalIndent(v, "", " ") // t.Log(string(b)) b := []byte(`{ "promotion_name": "mu", "promotion_region": "ID", "time_range": { "start_time": 1616420139, "end_time": 1616520139 }, "promotion_budget": { "amount":"10000000", "name":"USD", "symbol":"$" }, "loaded_Seller_ids": [ "7493989780026655762","11111","111212121" ] }`) v := new(CreateFreeShippingRequest) err := json.Unmarshal(b, v) if err != nil { t.Error(err) } header := make(http.Header) header.Set("Content-Type", "application/json;charset=utf-8") req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) recv := new(CreateFreeShippingRequest) err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, v, recv) } func TestRequiredBUG(t *testing.T) { type Currency struct { // currencyName *string `form:"currency_name,required" json:"currency_name,required" protobuf:"bytes,1,req,name=currency_name,json=currencyName" query:"currency_name,required"` CurrencySymbol *string `form:"currency_symbol,required" json:"currency_symbol,required" protobuf:"bytes,2,req,name=currency_symbol,json=currencySymbol" query:"currency_symbol,required"` } type CurrencyData struct { Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` Slice []*Currency `form:"slice,required" json:"slice,required" protobuf:"bytes,2,req,name=slice" query:"slice,required"` Map map[string]*Currency `form:"map,required" json:"map,required" protobuf:"bytes,2,req,name=map" query:"map,required"` } type ExchangeCurrencyRequest struct { PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,1,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` Currency *CurrencyData `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` } z := 
&ExchangeCurrencyRequest{} b := []byte(`{ "promotion_region": "?", "currency": { "amount": "?", "slice": [ { "currency_symbol": "?" } ], "map": { "?": { "currency_name": "?" } } } }`) json.Unmarshal(b, z) header := make(http.Header) header.Set("Content-Type", "application/json;charset=utf-8") req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) recv := new(ExchangeCurrencyRequest) err := DefaultBinder().Bind(req.Req, recv, nil) // no need for validate if err != nil { t.Error(err) } assert.DeepEqual(t, z, recv) } func TestIssue25(t *testing.T) { type Recv struct { A string } header := make(http.Header) header.Set("A", "from header") cookies := []*http.Cookie{ {Name: "A", Value: "from cookie"}, } req := newRequest("/1", header, cookies, nil) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } // assert.DeepEqual(t, "from cookie", recv.A) header2 := make(http.Header) header2.Set("A", "from header") cookies2 := []*http.Cookie{} req2 := newRequest("/2", header2, cookies2, nil) recv2 := new(Recv) err2 := DefaultBinder().Bind(req2.Req, recv2, nil) if err2 != nil { t.Error(err2) } assert.DeepEqual(t, "from header", recv2.A) } func TestIssue26(t *testing.T) { type Recv struct { Type string `json:"type,required" vd:"($=='update_target_threshold' && (TargetThreshold)$!='-1') || ($=='update_status' && (Status)$!='-1')"` RuleName string `json:"rule_name,required" vd:"regexp('^rule[0-9]+$')"` TargetThreshold string `json:"target_threshold" vd:"regexp('^-?[0-9]+(\\.[0-9]+)?$')"` Status string `json:"status" vd:"$=='0' || $=='1'"` Operator string `json:"operator,required" vd:"len($)>0"` } b := []byte(`{ "status": "1", "adv": "11520", "target_deep_external_action": "39", "package": "test.bytedance.com", "previous_target_threshold": "0.6", "deep_external_action": "675", "rule_name": "rule2", "deep_bid_type": "54", "modify_time": "2021-08-24:14:35:20", "aid": "111", "operator": "yanghaoze", "external_action": "76", 
"target_threshold": "0.1", "type": "update_status" }`) recv := new(Recv) err := json.Unmarshal(b, recv) if err != nil { t.Error(err) } header := make(http.Header) header.Set("Content-Type", consts.MIMEApplicationJSON) header.Set("A", "from header") cookies := []*http.Cookie{ {Name: "A", Value: "from cookie"}, } req := newRequest("/1", header, cookies, bytes.NewReader(b)) recv2 := new(Recv) err = DefaultBinder().Bind(req.Req, recv2, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, recv, recv2) } // BUGFIX: after 'json unmarshal', the default value will change it func TestDefault2(t *testing.T) { type Recv struct { X **struct { Dash string `default:"xxxx"` } } bodyReader := strings.NewReader(`{ "X": { "Dash": "hello Dash" } }`) header := make(http.Header) header.Set("Content-Type", consts.MIMEApplicationJSON) req := newRequest("", header, nil, bodyReader) recv := new(Recv) err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } assert.DeepEqual(t, "hello Dash", (**recv.X).Dash) } type ( files map[string][]file file interface { Name() string Read(p []byte) (n int, err error) } ) func newFormBody2(values url.Values, files files) (contentType string, bodyReader io.Reader) { if len(files) == 0 { return "application/x-www-form-urlencoded", strings.NewReader(values.Encode()) } pr, pw := io.Pipe() bodyWriter := multipart.NewWriter(pw) var fileWriter io.Writer buf := make([]byte, 32*1024) go func() { for fieldName, postfiles := range files { for _, file := range postfiles { fileWriter, _ = bodyWriter.CreateFormFile(fieldName, file.Name()) io.CopyBuffer(fileWriter, file, buf) } } for k, v := range values { for _, vv := range v { bodyWriter.WriteField(k, vv) } } bodyWriter.Close() pw.Close() }() return bodyWriter.FormDataContentType(), pr } func newFile(name string, bodyReader io.Reader) file { return &fileReader{name, bodyReader} } // fileReader file name and bytes. 
type fileReader struct { name string bodyReader io.Reader } func (f *fileReader) Name() string { return f.name } func (f *fileReader) Read(p []byte) (int, error) { return f.bodyReader.Read(p) } func newJSONBody(v interface{}) (contentType string, bodyReader io.Reader, err error) { b, err := json.Marshal(v) if err != nil { return } return "application/json;charset=utf-8", bytes.NewReader(b), nil } ================================================ FILE: pkg/app/server/binding/testdata/hello.pb.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.30.0 // protoc v3.21.12 // source: hello.proto package testdata import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" ) const ( // Verify that this generated code is sufficiently up-to-date. _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) // Verify that runtime/protoimpl is sufficiently up-to-date. 
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) type HertzReq struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"` } func (x *HertzReq) Reset() { *x = HertzReq{} if protoimpl.UnsafeEnabled { mi := &file_hello_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } func (x *HertzReq) String() string { return protoimpl.X.MessageStringOf(x) } func (*HertzReq) ProtoMessage() {} func (x *HertzReq) ProtoReflect() protoreflect.Message { mi := &file_hello_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) } return ms } return mi.MessageOf(x) } // Deprecated: Use HertzReq.ProtoReflect.Descriptor instead. func (*HertzReq) Descriptor() ([]byte, []int) { return file_hello_proto_rawDescGZIP(), []int{0} } func (x *HertzReq) GetName() string { if x != nil { return x.Name } return "" } var File_hello_proto protoreflect.FileDescriptor var file_hello_proto_rawDesc = []byte{ 0x0a, 0x0b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x68, 0x65, 0x72, 0x74, 0x7a, 0x22, 0x1e, 0x0a, 0x08, 0x48, 0x65, 0x72, 0x74, 0x7a, 0x52, 0x65, 0x71, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x0d, 0x5a, 0x0b, 0x68, 0x65, 0x72, 0x74, 0x7a, 0x2f, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( file_hello_proto_rawDescOnce sync.Once file_hello_proto_rawDescData = file_hello_proto_rawDesc ) func file_hello_proto_rawDescGZIP() []byte { file_hello_proto_rawDescOnce.Do(func() { file_hello_proto_rawDescData = protoimpl.X.CompressGZIP(file_hello_proto_rawDescData) }) return file_hello_proto_rawDescData } var file_hello_proto_msgTypes = 
make([]protoimpl.MessageInfo, 1) var file_hello_proto_goTypes = []interface{}{ (*HertzReq)(nil), // 0: hertz.HertzReq } var file_hello_proto_depIdxs = []int32{ 0, // [0:0] is the sub-list for method output_type 0, // [0:0] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name } func init() { file_hello_proto_init() } func file_hello_proto_init() { if File_hello_proto != nil { return } if !protoimpl.UnsafeEnabled { file_hello_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*HertzReq); i { case 0: return &v.state case 1: return &v.sizeCache case 2: return &v.unknownFields default: return nil } } } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_hello_proto_rawDesc, NumEnums: 0, NumMessages: 1, NumExtensions: 0, NumServices: 0, }, GoTypes: file_hello_proto_goTypes, DependencyIndexes: file_hello_proto_depIdxs, MessageInfos: file_hello_proto_msgTypes, }.Build() File_hello_proto = out.File file_hello_proto_rawDesc = nil file_hello_proto_goTypes = nil file_hello_proto_depIdxs = nil } ================================================ FILE: pkg/app/server/binding/testdata/hello.proto ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ syntax = "proto3"; package hertz; option go_package = "hertz/hello"; message HertzReq { string Name = 1; } ================================================ FILE: pkg/app/server/binding/validator.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
* * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2023 CloudWeGo Authors */ package binding import ( "reflect" "sync" "github.com/cloudwego/hertz/pkg/protocol" ) // ValidatorFunc defines a validation function that can access request context. // It takes a request and the object to validate, returning an error if validation fails. type ValidatorFunc func(*protocol.Request, interface{}) error // StructValidator defines the interface for struct validation. // // Deprecated: Use ValidatorFunc in BindConfig instead. You can create a ValidatorFunc // from a StructValidator using MakeValidatorFunc(). type StructValidator interface { ValidateStruct(interface{}) error Engine() interface{} ValidateTag() string } // hasValidateTagCache caches whether a type has validation tags to avoid // redundant reflection-based tag analysis on repeated validations var hasValidateTagCache sync.Map // MakeValidatorFunc creates a validation function from a StructValidator. // It optimizes validation by caching tag analysis results and skipping // validation entirely for types that don't have validation tags. func MakeValidatorFunc(s StructValidator) ValidatorFunc { if s == nil { return nil } return func(_ *protocol.Request, v any) error { rv, typeID := valueAndTypeID(v) c, ok := hasValidateTagCache.Load(typeID) if ok { if !c.(bool) { return nil } return s.ValidateStruct(rv) } tag := s.ValidateTag() if tag == "" { tag = defaultValidateTag } hasTag := containsStructTag(rv.Type(), tag, nil) hasValidateTagCache.Store(typeID, hasTag) if !hasTag { return nil } return s.ValidateStruct(rv) } } // containsStructTag recursively checks if a struct type contains any field with the specified tag. // It uses a checking map to prevent infinite recursion in self-referential struct types. 
func containsStructTag(rt reflect.Type, tag string, checking map[reflect.Type]bool) bool { rt = dereferenceType(rt) if rt.Kind() != reflect.Struct { return false } if checking == nil { checking = map[reflect.Type]bool{} } checking[rt] = true for i := 0; i < rt.NumField(); i++ { f := rt.Field(i) _, ok := f.Tag.Lookup(tag) if ok { return true } ft := dereferenceType(f.Type) if checking[ft] { continue } if containsStructTag(ft, tag, checking) { return true } } return false } ================================================ FILE: pkg/app/server/binding/validator_test.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package binding import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func Test_ValidateStruct(t *testing.T) { type User struct { Age int `vd:"$>=0&&$<=130"` } user := &User{ Age: 135, } err := DefaultValidator().ValidateStruct(user) if err == nil { t.Fatalf("expected an error, but got nil") } } func Test_ValidateTag(t *testing.T) { type User struct { Age int `query:"age" vt:"$>=0&&$<=130"` } user := &User{ Age: 135, } validateConfig := NewValidateConfig() validateConfig.ValidateTag = "vt" vd := NewValidator(validateConfig) err := vd.ValidateStruct(user) if err == nil { t.Fatalf("expected an error, but got nil") } bindConfig := NewBindConfig() bindConfig.Validator = vd binder := NewDefaultBinder(bindConfig) user = &User{} req := newMockRequest(). 
SetRequestURI("http://foobar.com?age=135"). SetHeaders("h", "header") err = binder.Bind(req.Req, user, nil) assert.Nil(t, err) err = binder.Validate(req.Req, user) assert.NotNil(t, err) } ================================================ FILE: pkg/app/server/hertz.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package server import ( "context" "os" "os/signal" "syscall" "time" "github.com/cloudwego/hertz/pkg/app/middlewares/server/recovery" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/route" ) // Hertz is the core struct of hertz. type Hertz struct { *route.Engine signalWaiter func(err chan error) error } // New creates a hertz instance without any default config. func New(opts ...config.Option) *Hertz { options := config.NewOptions(opts) h := &Hertz{ Engine: route.NewEngine(options), } return h } // Default creates a hertz instance with default middlewares. func Default(opts ...config.Option) *Hertz { h := New(opts...) h.Use(recovery.Recovery()) return h } // Spin runs the server until catching os.Signal or error returned by h.Run(). 
func (h *Hertz) Spin() { errCh := make(chan error) h.initOnRunHooks(errCh) go func() { errCh <- h.Run() }() signalWaiter := waitSignal if h.signalWaiter != nil { signalWaiter = h.signalWaiter } if err := signalWaiter(errCh); err != nil { hlog.SystemLogger().Errorf("Receive close signal: error=%v", err) if err := h.Engine.Close(); err != nil { hlog.SystemLogger().Errorf("Close error=%v", err) } return } if err := h.Shutdown(context.Background()); err != nil { hlog.SystemLogger().Errorf("Shutdown error=%v", err) } } // SetCustomSignalWaiter sets the signal waiter function. // If Default one is not met the requirement, set this function to customize. // Hertz will exit immediately if f returns an error, otherwise it will exit gracefully. func (h *Hertz) SetCustomSignalWaiter(f func(err chan error) error) { h.signalWaiter = f } // Default implementation for signal waiter. // SIGHUP|SIGINT|SIGTERM triggers graceful shutdown. func waitSignal(errCh chan error) error { signalToNotify := []os.Signal{syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM} if signal.Ignored(syscall.SIGHUP) { signalToNotify = []os.Signal{syscall.SIGINT, syscall.SIGTERM} } signals := make(chan os.Signal, 1) signal.Notify(signals, signalToNotify...) 
select { case sig := <-signals: switch sig { case syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM: hlog.SystemLogger().Infof("Received signal: %s\n", sig) // graceful shutdown return nil } case err := <-errCh: // error occurs, exit immediately return err } return nil } func (h *Hertz) initOnRunHooks(errChan chan error) { // add register func to runHooks opt := h.GetOptions() h.OnRun = append(h.OnRun, func(ctx context.Context) error { go func() { // delay register 1s time.Sleep(1 * time.Second) if err := opt.Registry.Register(opt.RegistryInfo); err != nil { hlog.SystemLogger().Errorf("Register error=%v", err) // pass err to errChan errChan <- err } }() return nil }) } ================================================ FILE: pkg/app/server/hertz_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package server import ( "bytes" "context" "errors" "fmt" "html/template" "io" "io/ioutil" "net" "net/http" "path" "strings" "sync" "sync/atomic" "testing" "time" "github.com/cloudwego/hertz/internal/test/mock/binder" "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/app" c "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/standard" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/req" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" ) type routeEngine interface { IsRunning() bool } func waitEngineRunning(e routeEngine) { testutils.WaitEngineRunning(e) } func fullURL(ln net.Listener, p string) string { return "http://" + path.Join(ln.Addr().String(), p) } func TestHertz_Run(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() hertz := Default(WithListener(ln)) hertz.GET("/test", func(c context.Context, ctx *app.RequestContext) { time.Sleep(time.Second) path := ctx.Request.URI().PathOriginal() ctx.SetBodyString(string(path)) }) testint := uint32(0) hertz.Engine.OnShutdown = append(hertz.OnShutdown, func(ctx context.Context) { atomic.StoreUint32(&testint, 1) }) assert.Assert(t, len(hertz.Handlers) == 1) go hertz.Spin() waitEngineRunning(hertz) hertz.Close() time.Sleep(10 * time.Millisecond) // Close will not call OnShutdown assert.DeepEqual(t, uint32(0), atomic.LoadUint32(&testint)) } func TestHertz_GracefulShutdown(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() handling := make(chan 
struct{}) closing := make(chan struct{}) engine := New(WithListener(ln)) engine.GET("/test", func(c context.Context, ctx *app.RequestContext) { close(handling) <-closing path := ctx.Request.URI().PathOriginal() ctx.SetBodyString(string(path)) }) engine.GET("/test2", func(c context.Context, ctx *app.RequestContext) {}) testint := uint32(0) testint2 := uint32(0) testint3 := uint32(0) engine.Engine.OnShutdown = append(engine.OnShutdown, func(ctx context.Context) { atomic.StoreUint32(&testint, 1) }) engine.Engine.OnShutdown = append(engine.OnShutdown, func(ctx context.Context) { atomic.StoreUint32(&testint2, 2) }) engine.Engine.OnShutdown = append(engine.OnShutdown, func(ctx context.Context) { atomic.StoreUint32(&testint3, 3) }) go engine.Spin() waitEngineRunning(engine) hc := http.Client{Timeout: time.Second} var err error var resp *http.Response ch := make(chan struct{}) ch2 := make(chan struct{}) go func() { ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for range ticker.C { t.Logf("[%v]begin listening\n", time.Now()) _, err2 := hc.Get(fullURL(ln, "/test2")) if err2 != nil { t.Logf("[%v]listening closed: %v", time.Now(), err2) ch2 <- struct{}{} break } } }() go func() { t.Logf("[%v]begin request\n", time.Now()) resp, err = http.Get(fullURL(ln, "/test")) t.Logf("[%v]end request\n", time.Now()) ch <- struct{}{} }() <-handling start := time.Now() ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) t.Logf("[%v]begin shutdown\n", start) engine.Shutdown(ctx) end := time.Now() t.Logf("[%v]end shutdown\n", end) close(closing) <-ch assert.Nil(t, err) assert.NotNil(t, resp) assert.DeepEqual(t, true, resp.Close) assert.DeepEqual(t, uint32(1), atomic.LoadUint32(&testint)) assert.DeepEqual(t, uint32(2), atomic.LoadUint32(&testint2)) assert.DeepEqual(t, uint32(3), atomic.LoadUint32(&testint3)) <-ch2 cancel() } func TestLoadHTMLGlob(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() engine := 
New(WithMaxRequestBodySize(15), WithListener(ln)) engine.Delims("{[{", "}]}") engine.LoadHTMLGlob("../../common/testdata/template/index.tmpl") engine.GET("/index", func(c context.Context, ctx *app.RequestContext) { ctx.HTML(consts.StatusOK, "index.tmpl", utils.H{ "title": "Main website", }) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) resp, _ := http.Get(fullURL(ln, "/index")) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b := make([]byte, 100) n, _ := resp.Body.Read(b) const expected = `

Main website

` assert.DeepEqual(t, expected, string(b[0:n])) } func TestLoadHTMLFiles(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() engine := New(WithMaxRequestBodySize(15), WithListener(ln)) engine.Delims("{[{", "}]}") engine.SetFuncMap(template.FuncMap{ "formatAsDate": formatAsDate, }) engine.LoadHTMLFiles("../../common/testdata/template/htmltemplate.html", "../../common/testdata/template/index.tmpl") engine.GET("/raw", func(c context.Context, ctx *app.RequestContext) { ctx.HTML(consts.StatusOK, "htmltemplate.html", map[string]interface{}{ "now": time.Date(2017, 0o7, 0o1, 0, 0, 0, 0, time.UTC), }) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) resp, _ := http.Get(fullURL(ln, "/raw")) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b := make([]byte, 100) n, _ := resp.Body.Read(b) assert.DeepEqual(t, "

Date: 2017/07/01

", string(b[0:n])) } func formatAsDate(t time.Time) string { year, month, day := t.Date() return fmt.Sprintf("%d/%02d/%02d", year, month, day) } // copied from router var ( default400Body = []byte("Bad Request") requiredHostBody = []byte("missing required Host header") ) func TestServer_Use(t *testing.T) { router := New() router.Use(func(c context.Context, ctx *app.RequestContext) {}) assert.DeepEqual(t, 1, len(router.Handlers)) router.Use(func(c context.Context, ctx *app.RequestContext) {}) assert.DeepEqual(t, 2, len(router.Handlers)) } func Test_getServerName(t *testing.T) { engine := New() assert.DeepEqual(t, []byte("hertz"), engine.GetServerName()) ss := New() ss.Name = "test_name" assert.DeepEqual(t, []byte("test_name"), ss.GetServerName()) } func TestServer_Run(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() hertz := New(WithListener(ln)) hertz.GET("/test", func(c context.Context, ctx *app.RequestContext) { path := ctx.Request.URI().PathOriginal() ctx.SetBodyString(string(path)) }) hertz.POST("/redirect", func(c context.Context, ctx *app.RequestContext) { ctx.Redirect(consts.StatusMovedPermanently, []byte(fullURL(ln, "/test"))) }) go hertz.Run() waitEngineRunning(hertz) resp, err := http.Get(fullURL(ln, "/test")) assert.Nil(t, err) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b := make([]byte, 5) resp.Body.Read(b) assert.DeepEqual(t, "/test", string(b)) resp, err = http.Get(fullURL(ln, "/foo")) assert.Nil(t, err) assert.DeepEqual(t, consts.StatusNotFound, resp.StatusCode) resp, err = http.Post(fullURL(ln, "/redirect"), "", nil) assert.Nil(t, err) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b = make([]byte, 5) resp.Body.Read(b) assert.DeepEqual(t, "/test", string(b)) ctx, cancel := context.WithTimeout(context.Background(), 0) defer cancel() _ = hertz.Shutdown(ctx) } func TestNotAbsolutePath(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() engine := New(WithListener(ln)) engine.POST("/", func(c 
context.Context, ctx *app.RequestContext) { ctx.Write(ctx.Request.Body()) }) engine.POST("/a", func(c context.Context, ctx *app.RequestContext) { ctx.Write(ctx.Request.Body()) }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) ctx := app.NewContext(0) if err := req.Read(&ctx.Request, zr); err != nil { t.Fatalf("unexpected error: %s", err) } engine.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) assert.DeepEqual(t, ctx.Request.Body(), ctx.Response.Body()) s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr = mock.NewZeroCopyReader(s) ctx = app.NewContext(0) if err := req.Read(&ctx.Request, zr); err != nil { t.Fatalf("unexpected error: %s", err) } engine.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) assert.DeepEqual(t, ctx.Request.Body(), ctx.Response.Body()) } func TestNotAbsolutePathWithRawPath(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() engine := New(WithListener(ln), WithUseRawPath(true)) const ( MiddlewareKey = "middleware_key" MiddlewareValue = "middleware_value" ) engine.Use(func(c context.Context, ctx *app.RequestContext) { ctx.Response.Header.Set(MiddlewareKey, MiddlewareValue) }) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { }) engine.POST("/a", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) ctx := app.NewContext(0) if err := req.Read(&ctx.Request, zr); err != nil { t.Fatalf("unexpected error: %s", err) } engine.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, 
consts.StatusBadRequest, ctx.Response.StatusCode()) assert.DeepEqual(t, default400Body, ctx.Response.Body()) gh := ctx.Response.Header.Get(MiddlewareKey) assert.DeepEqual(t, MiddlewareValue, gh) s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr = mock.NewZeroCopyReader(s) ctx = app.NewContext(0) if err := req.Read(&ctx.Request, zr); err != nil { t.Fatalf("unexpected error: %s", err) } engine.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) assert.DeepEqual(t, default400Body, ctx.Response.Body()) gh = ctx.Response.Header.Get(MiddlewareKey) assert.DeepEqual(t, MiddlewareValue, gh) } func TestNotValidHost(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() engine := New(WithListener(ln)) const ( MiddlewareKey = "middleware_key" MiddlewareValue = "middleware_value" ) engine.Use(func(c context.Context, ctx *app.RequestContext) { ctx.Response.Header.Set(MiddlewareKey, MiddlewareValue) }) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { }) engine.POST("/a", func(c context.Context, ctx *app.RequestContext) { }) s := "POST ?a=b HTTP/1.1\r\nHost: \r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) ctx := app.NewContext(0) if err := req.Read(&ctx.Request, zr); err != nil { t.Fatalf("unexpected error: %s", err) } engine.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) assert.DeepEqual(t, requiredHostBody, ctx.Response.Body()) gh := ctx.Response.Header.Get(MiddlewareKey) assert.DeepEqual(t, MiddlewareValue, gh) s = "POST a?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr = mock.NewZeroCopyReader(s) ctx = app.NewContext(0) if err := req.Read(&ctx.Request, zr); err != nil { t.Fatalf("unexpected error: %s", err) } engine.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, 
consts.StatusBadRequest, ctx.Response.StatusCode()) assert.DeepEqual(t, requiredHostBody, ctx.Response.Body()) gh = ctx.Response.Header.Get(MiddlewareKey) assert.DeepEqual(t, MiddlewareValue, gh) } func TestWithBasePath(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() engine := New(WithBasePath("/hertz"), WithListener(ln)) engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) var r http.Request r.ParseForm() r.Form.Add("xxxxxx", "xxx") body := strings.NewReader(r.Form.Encode()) resp, err := http.Post(fullURL(ln, "/hertz/test"), "application/x-www-form-urlencoded", body) assert.Nil(t, err) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) } func TestNotEnoughBodySize(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() engine := New(WithMaxRequestBodySize(5), WithListener(ln)) engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) var r http.Request r.ParseForm() r.Form.Add("xxxxxx", "xxx") body := strings.NewReader(r.Form.Encode()) resp, err := http.Post(fullURL(ln, "/test"), "application/x-www-form-urlencoded", body) assert.Nil(t, err) assert.DeepEqual(t, 413, resp.StatusCode) bodyBytes, _ := ioutil.ReadAll(resp.Body) assert.DeepEqual(t, "Request Entity Too Large", string(bodyBytes)) } func TestEnoughBodySize(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() engine := New(WithMaxRequestBodySize(15), WithListener(ln)) engine.POST("/test", func(c context.Context, ctx *app.RequestContext) { }) go engine.Run() defer func() { engine.Close() }() waitEngineRunning(engine) var r http.Request r.ParseForm() r.Form.Add("xxxxxx", "xxx") body := strings.NewReader(r.Form.Encode()) resp, _ := http.Post(fullURL(ln, "/test"), "application/x-www-form-urlencoded", body) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) } func 
TestRequestCtxHijack(t *testing.T) { hijackStartCh := make(chan struct{}) hijackStopCh := make(chan struct{}) engine := New() engine.Init() engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { if ctx.Hijacked() { t.Error("connection mustn't be hijacked") } ctx.Hijack(func(c network.Conn) { <-hijackStartCh b := make([]byte, 1) // ping-pong echo via hijacked conn for { n, err := c.Read(b) if n != 1 { if err == io.EOF { close(hijackStopCh) return } if err != nil { t.Errorf("unexpected error: %s", err) } t.Errorf("unexpected number of bytes read: %d. Expecting 1", n) } if _, err = c.Write(b); err != nil { t.Errorf("unexpected error when writing data: %s", err) } } }) if !ctx.Hijacked() { t.Error("connection must be hijacked") } ctx.Data(consts.StatusOK, "foo/bar", []byte("hijack it!")) }) hijackedString := "foobar baz hijacked!!!" c := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n" + hijackedString) ch := make(chan error) go func() { ch <- engine.Serve(context.Background(), c) }() time.Sleep(100 * time.Millisecond) close(hijackStartCh) if err := <-ch; err != nil { if !errors.Is(err, errs.ErrHijacked) { t.Fatalf("Unexpected error from serveConn: %s", err) } } verifyResponse(t, c.WriterRecorder(), consts.StatusOK, "foo/bar", "hijack it!") select { case <-hijackStopCh: case <-time.After(100 * time.Millisecond): t.Fatal("timeout") } zw := c.WriterRecorder() data, err := zw.ReadBinary(zw.Len()) if err != nil { t.Fatalf("Unexpected error when reading remaining data: %s", err) } if string(data) != hijackedString { t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, hijackedString) } } func verifyResponse(t *testing.T, zr network.Reader, expectedStatusCode int, expectedContentType, expectedBody string) { var r protocol.Response if err := resp.Read(&r, zr); err != nil { t.Fatalf("Unexpected error when parsing response: %s", err) } if !bytes.Equal(r.Body(), []byte(expectedBody)) { t.Fatalf("Unexpected body %q. 
Expected %q", r.Body(), []byte(expectedBody)) } verifyResponseHeader(t, &r.Header, expectedStatusCode, len(r.Body()), expectedContentType, "") } func verifyResponseHeader(t *testing.T, h *protocol.ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType, expectedContentEncoding string) { if h.StatusCode() != expectedStatusCode { t.Fatalf("Unexpected status code %d. Expected %d", h.StatusCode(), expectedStatusCode) } if h.ContentLength() != expectedContentLength { t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength(), expectedContentLength) } if string(h.ContentType()) != expectedContentType { t.Fatalf("Unexpected content type %q. Expected %q", h.ContentType(), expectedContentType) } if string(h.ContentEncoding()) != expectedContentEncoding { t.Fatalf("Unexpected content encoding %q. Expected %q", h.ContentEncoding(), expectedContentEncoding) } } func TestParamInconsist(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() mapS := sync.Map{} h := New(WithListener(ln)) h.GET("/:label", func(c context.Context, ctx *app.RequestContext) { label := ctx.Param("label") x, _ := mapS.LoadOrStore(label, label) labelString := x.(string) if label != labelString { t.Errorf("unexpected label: %s, expected return label: %s", label, labelString) } }) go h.Run() waitEngineRunning(h) client, _ := c.NewClient() wg := sync.WaitGroup{} tr := func() { defer wg.Done() for i := 0; i < 500; i++ { client.Get(context.Background(), nil, fullURL(ln, "/test1")) } } ti := func() { defer wg.Done() for i := 0; i < 500; i++ { client.Get(context.Background(), nil, fullURL(ln, "/test2")) } } for i := 0; i < 30; i++ { go tr() go ti() wg.Add(2) } wg.Wait() } func TestDuplicateReleaseBodyStream(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() h := New(WithStreamBody(true), WithListener(ln)) h.POST("/test", func(ctx context.Context, c *app.RequestContext) { stream := c.RequestBodyStream() c.Response.SetBodyStream(stream, -1) }) 
go h.Spin() waitEngineRunning(h) client, _ := c.NewClient(c.WithMaxConnsPerHost(1000000), c.WithDialTimeout(time.Minute)) bodyBytes := make([]byte, 102388) index := 0 letterBytes := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" for i := 0; i < 102388; i++ { bodyBytes[i] = letterBytes[index] if i%1969 == 0 && i != 0 { index = index + 1 } } body := string(bodyBytes) wg := sync.WaitGroup{} testFunc := func() { defer wg.Done() r := protocol.NewRequest("POST", fullURL(ln, "/test"), nil) r.SetBodyString(body) resp := protocol.AcquireResponse() err := client.Do(context.Background(), r, resp) if err != nil { t.Errorf("unexpected error: %s", err.Error()) } if body != string(resp.Body()) { t.Errorf("unequal body") } } for i := 0; i < 10; i++ { wg.Add(1) go testFunc() } wg.Wait() } func TestServiceRegisterFailed(t *testing.T) { t.Parallel() // slow test, make it parallel ln := testutils.NewTestListener(t) defer ln.Close() mockRegErr := errors.New("mock register error") var rCount int32 var drCount int32 mockRegistry := MockRegistry{ RegisterFunc: func(info *registry.Info) error { atomic.AddInt32(&rCount, 1) return mockRegErr }, DeregisterFunc: func(info *registry.Info) error { atomic.AddInt32(&drCount, 1) return nil }, } var opts []config.Option opts = append(opts, WithRegistry(mockRegistry, nil)) opts = append(opts, WithListener(ln)) srv := New(opts...) 
srv.Spin() assert.Assert(t, atomic.LoadInt32(&rCount) == 1) } func TestServiceDeregisterFailed(t *testing.T) { t.Parallel() // slow test, make it parallel ln := testutils.NewTestListener(t) defer ln.Close() mockDeregErr := errors.New("mock deregister error") var wg sync.WaitGroup wg.Add(2) // RegisterFunc && DeregisterFunc var rCount int32 var drCount int32 mockRegistry := MockRegistry{ RegisterFunc: func(info *registry.Info) error { defer wg.Done() atomic.AddInt32(&rCount, 1) return nil }, DeregisterFunc: func(info *registry.Info) error { defer wg.Done() atomic.AddInt32(&drCount, 1) return mockDeregErr }, } var opts []config.Option opts = append(opts, WithRegistry(mockRegistry, nil)) opts = append(opts, WithListener(ln)) srv := New(opts...) go srv.Spin() waitEngineRunning(srv) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() _ = srv.Shutdown(ctx) wg.Wait() assert.Assert(t, atomic.LoadInt32(&rCount) == 1) assert.Assert(t, atomic.LoadInt32(&drCount) == 1) } func TestServiceRegistryInfo(t *testing.T) { t.Parallel() // slow test, make it parallel ln := testutils.NewTestListener(t) defer ln.Close() registryInfo := ®istry.Info{ Weight: 100, Tags: map[string]string{"aa": "bb"}, ServiceName: "hertz.api.test", } checkInfo := func(info *registry.Info) { assert.Assert(t, info.Weight == registryInfo.Weight) assert.Assert(t, info.ServiceName == "hertz.api.test") assert.Assert(t, len(info.Tags) == len(registryInfo.Tags), info.Tags) assert.Assert(t, info.Tags["aa"] == registryInfo.Tags["aa"], info.Tags) } var wg sync.WaitGroup wg.Add(2) // RegisterFunc && DeregisterFunc var rCount int32 var drCount int32 mockRegistry := MockRegistry{ RegisterFunc: func(info *registry.Info) error { defer wg.Done() checkInfo(info) atomic.AddInt32(&rCount, 1) return nil }, DeregisterFunc: func(info *registry.Info) error { defer wg.Done() checkInfo(info) atomic.AddInt32(&drCount, 1) return nil }, } var opts []config.Option opts = append(opts, 
WithRegistry(mockRegistry, registryInfo)) opts = append(opts, WithListener(ln)) srv := New(opts...) go srv.Spin() waitEngineRunning(srv) ctx, cancel := context.WithTimeout(context.Background(), 0) defer cancel() _ = srv.Shutdown(ctx) wg.Wait() assert.Assert(t, atomic.LoadInt32(&rCount) == 1) assert.Assert(t, atomic.LoadInt32(&drCount) == 1) } func TestServiceRegistryNoInitInfo(t *testing.T) { t.Parallel() // slow test, make it parallel ln := testutils.NewTestListener(t) defer ln.Close() checkInfo := func(info *registry.Info) { assert.Assert(t, info == nil) } var wg sync.WaitGroup wg.Add(2) // RegisterFunc && DeregisterFunc var rCount int32 var drCount int32 mockRegistry := MockRegistry{ RegisterFunc: func(info *registry.Info) error { defer wg.Done() checkInfo(info) atomic.AddInt32(&rCount, 1) return nil }, DeregisterFunc: func(info *registry.Info) error { defer wg.Done() checkInfo(info) atomic.AddInt32(&drCount, 1) return nil }, } var opts []config.Option opts = append(opts, WithRegistry(mockRegistry, nil)) opts = append(opts, WithListener(ln)) srv := New(opts...) 
go srv.Spin() waitEngineRunning(srv) ctx, cancel := context.WithTimeout(context.Background(), 0) defer cancel() _ = srv.Shutdown(ctx) wg.Wait() assert.Assert(t, atomic.LoadInt32(&rCount) == 1) assert.Assert(t, atomic.LoadInt32(&drCount) == 1) } type testTracer struct{} func (t testTracer) Start(ctx context.Context, c *app.RequestContext) context.Context { value := 0 if v := ctx.Value("testKey"); v != nil { value = v.(int) value++ } //nolint:staticcheck // SA1029 no built-in type string as key return context.WithValue(ctx, "testKey", value) } func (t testTracer) Finish(ctx context.Context, c *app.RequestContext) {} func TestReuseCtx(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() h := New(WithTracer(testTracer{}), WithListener(ln)) h.GET("/ping", func(ctx context.Context, c *app.RequestContext) { assert.DeepEqual(t, 0, ctx.Value("testKey").(int)) }) go h.Spin() waitEngineRunning(h) for i := 0; i < 1000; i++ { _, _, err := c.Get(context.Background(), nil, fullURL(ln, "/ping")) assert.Nil(t, err) } } func TestOnprepare(t *testing.T) { ln1 := testutils.NewTestListener(t) defer ln1.Close() h1 := New( WithListener(ln1), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { b, err := conn.Peek(3) assert.Nil(t, err) assert.DeepEqual(t, string(b), "GET") conn.Close() return ctx })) h1.GET("/ping", func(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) go h1.Spin() waitEngineRunning(h1) _, _, err := c.Get(context.Background(), nil, fullURL(ln1, "/ping")) assert.DeepEqual(t, "the server closed connection before returning the first response byte. 
Make sure the server returns 'Connection: close' response header before closing the connection", err.Error()) ln2 := testutils.NewTestListener(t) defer ln2.Close() h2 := New( WithOnAccept(func(conn net.Conn) context.Context { conn.Close() return context.Background() }), WithListener(ln2)) h2.GET("/ping", func(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) go h2.Spin() waitEngineRunning(h2) _, _, err = c.Get(context.Background(), nil, fullURL(ln2, "/ping")) if err == nil { t.Fatalf("err should not be nil") } ln3 := testutils.NewTestListener(t) defer ln3.Close() var h3 *Hertz h3 = New( WithOnAccept(func(conn net.Conn) context.Context { assert.DeepEqual(t, conn.LocalAddr().String(), ln3.Addr().String()) return context.Background() }), WithListener(ln3), WithTransport(standard.NewTransporter)) h3.GET("/ping", func(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) go h3.Spin() waitEngineRunning(h3) c.Get(context.Background(), nil, fullURL(ln3, "/ping")) } type lockBuffer struct { sync.Mutex b bytes.Buffer } func (l *lockBuffer) Write(p []byte) (int, error) { l.Lock() defer l.Unlock() return l.b.Write(p) } func (l *lockBuffer) String() string { l.Lock() defer l.Unlock() return l.b.String() } func TestHertzDisableHeaderNamesNormalizing(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() h := New( WithListener(ln), WithDisableHeaderNamesNormalizing(true), ) headerName := "CASE-senSITive-HEAder-NAME" headerValue := "foobar-baz" succeed := false h.GET("/test", func(c context.Context, ctx *app.RequestContext) { ctx.VisitAllHeaders(func(key, value []byte) { if string(key) == headerName && string(value) == headerValue { succeed = true return } }) if !succeed { t.Fatalf("DisableHeaderNamesNormalizing failed") } else { ctx.Header(headerName, headerValue) } }) go h.Spin() waitEngineRunning(h) cli, _ := c.NewClient(c.WithDisableHeaderNamesNormalizing(true)) r := 
protocol.NewRequest("GET", fullURL(ln, "/test"), nil) r.Header.DisableNormalizing() r.Header.Set(headerName, headerValue) res := protocol.AcquireResponse() err := cli.Do(context.Background(), r, res) assert.Nil(t, err) assert.DeepEqual(t, headerValue, res.Header.Get(headerName)) } func TestBindConfig(t *testing.T) { type Req struct { A int `query:"a"` } bindConfig := binding.NewBindConfig() bindConfig.LooseZeroMode = true ln := testutils.NewTestListener(t) defer ln.Close() h := New( WithListener(ln), WithBindConfig(bindConfig)) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) if err != nil { t.Fatal("unexpected error") } }) go h.Spin() waitEngineRunning(h) hc := http.Client{Timeout: time.Second} _, err := hc.Get(fullURL(ln, "/bind?a=")) assert.Nil(t, err) bindConfig = binding.NewBindConfig() bindConfig.LooseZeroMode = false ln2 := testutils.NewTestListener(t) defer ln2.Close() h2 := New( WithListener(ln2), WithBindConfig(bindConfig)) h2.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) if err == nil { t.Fatal("expect an error") } }) go h2.Spin() waitEngineRunning(h2) _, err = hc.Get(fullURL(ln2, "/bind?a=")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } func TestCustomBinder(t *testing.T) { type Req struct { A int `query:"a"` } ln := testutils.NewTestListener(t) defer ln.Close() h := New( WithListener(ln), WithCustomBinder(binder.NewBinderWithValidateError(errors.New("test binder")))) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) if err == nil { t.Fatal("expect an error") } assert.DeepEqual(t, "test binder", err.Error()) }) go h.Spin() waitEngineRunning(h) hc := http.Client{Timeout: time.Second} _, err := hc.Get(fullURL(ln, "/bind?a=")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } func TestValidateConfigRegValidateFunc(t *testing.T) { type Req struct { A int 
`query:"a" vd:"f($)"` } validateConfig := &binding.ValidateConfig{} validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { return fmt.Errorf("test validator") }) ln := testutils.NewTestListener(t) defer ln.Close() h := New(WithListener(ln)) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) if err == nil { t.Fatal("expect an error") } assert.DeepEqual(t, "test validator", err.Error()) }) go h.Spin() waitEngineRunning(h) hc := http.Client{Timeout: time.Second} _, err := hc.Get(fullURL(ln, "/bind?a=2")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } func TestCustomValidator(t *testing.T) { type Req struct { A int `query:"a" vd:"f($)"` } ln := testutils.NewTestListener(t) defer ln.Close() h := New( WithListener(ln), WithCustomValidatorFunc(func(_ *protocol.Request, _ interface{}) error { return errors.New("test mock validator") })) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) if err == nil { t.Fatal("expect an error") } assert.DeepEqual(t, "test mock validator", err.Error()) }) go h.Spin() time.Sleep(100 * time.Millisecond) hc := http.Client{Timeout: time.Second} _, err := hc.Get(fullURL(ln, "/bind?a=2")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } type ValidateError struct { ErrType, FailField, Msg string } // Error implements error interface. 
func (e *ValidateError) Error() string { if e.Msg != "" { return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg } return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" } func TestValidateConfigSetSetErrorFactory(t *testing.T) { type TestValidate struct { B int `query:"b" vd:"$>100"` } CustomValidateErrFunc := func(failField, msg string) error { err := ValidateError{ ErrType: "validateErr", FailField: "[validateFailField]: " + failField, Msg: "[validateErrMsg]: " + msg, } return &err } validateConfig := binding.NewValidateConfig() validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) ln := testutils.NewTestListener(t) defer ln.Close() h := New( WithListener(ln), WithValidateConfig(validateConfig)) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req TestValidate err := ctx.BindAndValidate(&req) if err == nil { t.Fatal("expect an error") } assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) }) go h.Spin() waitEngineRunning(h) hc := http.Client{Timeout: time.Second} _, err := hc.Get(fullURL(ln, "/bind?b=1")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } func TestValidateConfigAndBindConfig(t *testing.T) { type Req struct { A int `query:"a" vt:"$>=0&&$<=130"` } validateConfig := binding.NewValidateConfig() validateConfig.ValidateTag = "vt" ln := testutils.NewTestListener(t) defer ln.Close() h := New( WithListener(ln), WithValidateConfig(validateConfig)) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) if err == nil { t.Fatal("expect an error") } t.Log(err) }) go h.Spin() waitEngineRunning(h) hc := http.Client{Timeout: time.Second} _, err := hc.Get(fullURL(ln, "/bind?a=135")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } func TestWithDisableDefaultDate(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() h := New( WithListener(ln), WithDisableDefaultDate(true), 
) h.GET("/", func(_ context.Context, c *app.RequestContext) {}) go h.Spin() waitEngineRunning(h) hc := http.Client{Timeout: time.Second} r, _ := hc.Get(fullURL(ln, "")) //nolint:errcheck assert.DeepEqual(t, "", r.Header.Get("Date")) } func TestWithDisableDefaultContentType(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() h := New( WithListener(ln), WithDisableDefaultContentType(true), ) h.GET("/", func(_ context.Context, c *app.RequestContext) {}) go h.Spin() waitEngineRunning(h) hc := http.Client{Timeout: time.Second} r, _ := hc.Get(fullURL(ln, "")) //nolint:errcheck assert.DeepEqual(t, "", r.Header.Get("Content-Type")) } func TestServerReturns413And431OnSizeLimits(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() h := Default(WithListener(ln), WithMaxHeaderBytes(500), WithMaxRequestBodySize(1000)) h.GET("/test", func(c context.Context, ctx *app.RequestContext) { ctx.String(consts.StatusOK, "success") }) h.POST("/test", func(c context.Context, ctx *app.RequestContext) { ctx.String(consts.StatusOK, "success") }) go h.Spin() waitEngineRunning(h) defer h.Shutdown(context.Background()) addr := ln.Addr().String() client := &http.Client{Timeout: 2 * time.Second} // Test 431 - Request Header Fields Too Large req, _ := http.NewRequest("GET", fmt.Sprintf("http://%s/test", addr), nil) req.Header.Set("Large-Header", strings.Repeat("x", 501)) // Exceeds 500 byte limit resp, err := client.Do(req) assert.Nil(t, err) resp.Body.Close() // If we get a response, it should be 431 assert.DeepEqual(t, resp.StatusCode, 431) // Test 413 - Request Entity Too Large largeBody := strings.NewReader(strings.Repeat("x", 1001)) // Exceeds 1000 byte limit req2, _ := http.NewRequest("POST", fmt.Sprintf("http://%s/test", addr), largeBody) resp2, err2 := client.Do(req2) assert.Nil(t, err2) resp2.Body.Close() // Should return 413 assert.DeepEqual(t, resp2.StatusCode, 413) } ================================================ FILE: pkg/app/server/hertz_unix_test.go 
================================================ // Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris package server import ( "context" "net" "net/http" "os" "os/exec" "strconv" "sync/atomic" "syscall" "testing" "time" "golang.org/x/sys/unix" "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/app" c "github.com/cloudwego/hertz/pkg/app/client" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/standard" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func TestReusePorts(t *testing.T) { cfg := &net.ListenConfig{Control: func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEADDR, 1) syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1) }) }} ha := New(WithHostPorts("localhost:10093"), WithListenConfig(cfg), WithTransport(standard.NewTransporter)) hb := New(WithHostPorts("localhost:10093"), WithListenConfig(cfg), WithTransport(standard.NewTransporter)) hc := New(WithHostPorts("localhost:10093"), WithListenConfig(cfg)) hd := New(WithHostPorts("localhost:10093"), WithListenConfig(cfg)) ha.GET("/ping", func(c context.Context, ctx *app.RequestContext) { ctx.JSON(consts.StatusOK, 
utils.H{"ping": "pong"}) }) hc.GET("/ping", func(c context.Context, ctx *app.RequestContext) { ctx.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) hd.GET("/ping", func(c context.Context, ctx *app.RequestContext) { ctx.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) hb.GET("/ping", func(c context.Context, ctx *app.RequestContext) { ctx.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) go ha.Run() go hb.Run() go hc.Run() go hd.Run() waitEngineRunning(ha) waitEngineRunning(hb) waitEngineRunning(hc) waitEngineRunning(hd) client, _ := c.NewClient() for i := 0; i < 1000; i++ { statusCode, body, err := client.Get(context.Background(), nil, "http://localhost:10093/ping") assert.Nil(t, err) assert.DeepEqual(t, consts.StatusOK, statusCode) assert.DeepEqual(t, "{\"ping\":\"pong\"}", string(body)) } } func TestHertz_Spin(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() engine := New(WithListener(ln)) engine.GET("/test", func(c context.Context, ctx *app.RequestContext) { time.Sleep(40 * time.Millisecond) path := ctx.Request.URI().PathOriginal() ctx.SetBodyString(string(path)) }) engine.GET("/test2", func(c context.Context, ctx *app.RequestContext) {}) testint := uint32(0) engine.Engine.OnShutdown = append(engine.OnShutdown, func(ctx context.Context) { atomic.StoreUint32(&testint, 1) }) go engine.Spin() waitEngineRunning(engine) hc := http.Client{Timeout: time.Second} var err error var resp *http.Response ch := make(chan struct{}) ch2 := make(chan struct{}) go func() { ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for range ticker.C { _, err := hc.Get(fullURL(ln, "/test2")) t.Logf("[%v]begin listening\n", time.Now()) if err != nil { t.Logf("[%v]listening closed: %v", time.Now(), err) ch2 <- struct{}{} break } } }() go func() { t.Logf("[%v]begin request\n", time.Now()) resp, err = http.Get(fullURL(ln, "/test")) t.Logf("[%v]end request\n", time.Now()) ch <- struct{}{} }() time.Sleep(20 * time.Millisecond) pid := 
strconv.Itoa(os.Getpid()) cmd := exec.Command("kill", "-SIGHUP", pid) t.Logf("[%v]begin SIGHUP\n", time.Now()) if err := cmd.Run(); err != nil { t.Fatal(err) } t.Logf("[%v]end SIGHUP\n", time.Now()) <-ch assert.Nil(t, err) assert.NotNil(t, resp) <-ch2 assert.DeepEqual(t, uint32(1), atomic.LoadUint32(&testint)) } func TestWithSenseClientDisconnection(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() var closeFlag int32 h := New(WithListener(ln), WithSenseClientDisconnection(true)) h.GET("/ping", func(c context.Context, ctx *app.RequestContext) { assert.DeepEqual(t, "aa", string(ctx.Host())) ch := make(chan struct{}) select { case <-c.Done(): atomic.StoreInt32(&closeFlag, 1) assert.DeepEqual(t, context.Canceled, c.Err()) case <-ch: } }) go h.Spin() waitEngineRunning(h) con, err := net.Dial("tcp", ln.Addr().String()) assert.Nil(t, err) _, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n")) assert.Nil(t, err) time.Sleep(20 * time.Millisecond) assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0)) assert.Nil(t, con.Close()) time.Sleep(20 * time.Millisecond) assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1)) } func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() var closeFlag int32 h := New(WithListener(ln), WithSenseClientDisconnection(true), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { return ctx })) h.GET("/ping", func(c context.Context, ctx *app.RequestContext) { assert.DeepEqual(t, "aa", string(ctx.Host())) ch := make(chan struct{}) select { case <-c.Done(): atomic.StoreInt32(&closeFlag, 1) assert.DeepEqual(t, context.Canceled, c.Err()) case <-ch: } }) go h.Spin() waitEngineRunning(h) con, err := net.Dial("tcp", ln.Addr().String()) assert.Nil(t, err) _, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n")) assert.Nil(t, err) time.Sleep(20 * time.Millisecond) assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), 
int32(0)) assert.Nil(t, con.Close()) time.Sleep(20 * time.Millisecond) assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1)) } ================================================ FILE: pkg/app/server/mocks_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package server import ( "github.com/cloudwego/hertz/pkg/app/server/registry" ) var _ registry.Registry = (*MockRegistry)(nil) // MockRegistry is the mock implementation of registry.Registry interface. type MockRegistry struct { RegisterFunc func(info *registry.Info) error DeregisterFunc func(info *registry.Info) error } // Register is the mock implementation of registry.Registry interface. func (m MockRegistry) Register(info *registry.Info) error { if m.RegisterFunc != nil { return m.RegisterFunc(info) } return nil } // Deregister is the mock implementation of registry.Registry interface. func (m MockRegistry) Deregister(info *registry.Info) error { if m.DeregisterFunc != nil { return m.DeregisterFunc(info) } return nil } ================================================ FILE: pkg/app/server/option.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package server import ( "context" "crypto/tls" "net" "strings" "time" "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/tracer" "github.com/cloudwego/hertz/pkg/common/tracer/stats" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/standard" ) // WithKeepAliveTimeout sets keep-alive timeout. // // In most cases, there is no need to care about this option. func WithKeepAliveTimeout(t time.Duration) config.Option { return config.Option{F: func(o *config.Options) { o.KeepAliveTimeout = t }} } // WithReadTimeout sets read timeout. // // Close the connection when read request timeout. func WithReadTimeout(t time.Duration) config.Option { return config.Option{F: func(o *config.Options) { o.ReadTimeout = t }} } // WithWriteTimeout sets write timeout. // // Connection will be closed when write request timeout. func WithWriteTimeout(t time.Duration) config.Option { return config.Option{F: func(o *config.Options) { o.WriteTimeout = t }} } // WithIdleTimeout sets idle timeout. // // Close the connection when the successive request timeout (in keepalive mode). // Set this to protect server from misbehavior clients. func WithIdleTimeout(t time.Duration) config.Option { return config.Option{F: func(o *config.Options) { o.IdleTimeout = t }} } // WithRedirectTrailingSlash sets redirectTrailingSlash. 
// // Enables automatic redirection if the current route can't be matched but a // handler for the path with (without) the trailing slash exists. // For example if /foo/ is requested but a route only exists for /foo, the // client is redirected to /foo with http status code 301 for GET requests // and 307 for all other request methods. func WithRedirectTrailingSlash(b bool) config.Option { return config.Option{F: func(o *config.Options) { o.RedirectTrailingSlash = b }} } // WithRedirectFixedPath sets redirectFixedPath. // // If enabled, the router tries to fix the current request path, if no // handle is registered for it. // First superfluous path elements like ../ or // are removed. // Afterwards the router does a case-insensitive lookup of the cleaned path. // If a handle can be found for this route, the router makes a redirection // to the corrected path with status code 301 for GET requests and 308 for // all other request methods. // For example /FOO and /..//Foo could be redirected to /foo. // RedirectTrailingSlash is independent of this option. func WithRedirectFixedPath(b bool) config.Option { return config.Option{F: func(o *config.Options) { o.RedirectFixedPath = b }} } // WithHandleMethodNotAllowed sets handleMethodNotAllowed. // // If enabled, the router checks if another method is allowed for the // current route, if the current request can not be routed. // If this is the case, the request is answered with 'Method Not Allowed' // and HTTP status code 405. // If no other Method is allowed, the request is delegated to the NotFound // handler. func WithHandleMethodNotAllowed(b bool) config.Option { return config.Option{F: func(o *config.Options) { o.HandleMethodNotAllowed = b }} } // WithUseRawPath sets useRawPath. // // If enabled, the url.RawPath will be used to find parameters. func WithUseRawPath(b bool) config.Option { return config.Option{F: func(o *config.Options) { o.UseRawPath = b }} } // WithRemoveExtraSlash sets removeExtraSlash. 
// // RemoveExtraSlash a parameter can be parsed from the URL even with extra slashes. // If UseRawPath is false (by default), the RemoveExtraSlash effectively is true, // as url.Path gonna be used, which is already cleaned. func WithRemoveExtraSlash(b bool) config.Option { return config.Option{F: func(o *config.Options) { o.RemoveExtraSlash = b }} } // WithUnescapePathValues sets unescapePathValues. // // If true, the path value will be unescaped. // If UseRawPath is false (by default), the UnescapePathValues effectively is true, // as url.Path gonna be used, which is already unescaped. func WithUnescapePathValues(b bool) config.Option { return config.Option{F: func(o *config.Options) { o.UnescapePathValues = b }} } // WithDisablePreParseMultipartForm sets disablePreParseMultipartForm. // // This option is useful for servers that desire to treat // multipart form data as a binary blob, or choose when to parse the data. // Server pre parses multipart form data by default. func WithDisablePreParseMultipartForm(b bool) config.Option { return config.Option{F: func(o *config.Options) { o.DisablePreParseMultipartForm = b }} } // WithHostPorts sets listening address. func WithHostPorts(hp string) config.Option { return config.Option{F: func(o *config.Options) { o.Addr = hp }} } // WithListener sets the listener to use. // // If set, the server will use this listener instead of creating a new one. // This is useful for custom listener implementations or testing. // Note: This will update Network and Addr based on the listener's address, // and reset ListenConfig since it's not needed when a listener is provided. // // WARNING: Custom net.Listener implementations may not be supported by cloudwego/netpoll. 
// If your custom listener doesn't support netpoll, you need to explicitly set the transporter to the standard library: // // WithListener(customListener), WithTransport(standard.NewTransporter) func WithListener(ln net.Listener) config.Option { return config.Option{F: func(o *config.Options) { o.Listener = ln o.Network = ln.Addr().Network() o.Addr = ln.Addr().String() o.ListenConfig = nil }} } // WithBasePath sets basePath.Must be "/" prefix and suffix,If not the default concatenate "/" func WithBasePath(basePath string) config.Option { return config.Option{F: func(o *config.Options) { // Must be "/" prefix and suffix,If not the default concatenate "/" if !strings.HasPrefix(basePath, "/") { basePath = "/" + basePath } if !strings.HasSuffix(basePath, "/") { basePath = basePath + "/" } o.BasePath = basePath }} } // WithMaxRequestBodySize sets the limitation of request body size. Unit: byte // // Body buffer which larger than this size will be put back into buffer poll. func WithMaxRequestBodySize(bs int) config.Option { return config.Option{F: func(o *config.Options) { o.MaxRequestBodySize = bs }} } // WithMaxHeaderBytes sets the limitation of request header size. Unit: byte // // If the header size exceeds this value, an ErrHeaderTooLarge error will be returned // and the server will respond with HTTP 431 Request Header Fields Too Large. // // Default: 1MB (1 << 20 bytes) func WithMaxHeaderBytes(size int) config.Option { return config.Option{F: func(o *config.Options) { o.MaxHeaderBytes = size }} } // WithMaxKeepBodySize sets max size of request/response body to keep when recycled. Unit: byte // // Body buffer which larger than this size will be put back into buffer poll. // Note: If memory pressure is high, try setting the value to 0. func WithMaxKeepBodySize(bs int) config.Option { return config.Option{F: func(o *config.Options) { o.MaxKeepBodySize = bs }} } // WithGetOnly sets whether accept GET request only. 
Default: false func WithGetOnly(isOnly bool) config.Option { return config.Option{F: func(o *config.Options) { o.GetOnly = isOnly }} } // WithKeepAlive sets Whether use long connection. Default: true func WithKeepAlive(b bool) config.Option { return config.Option{F: func(o *config.Options) { o.DisableKeepalive = !b }} } // WithStreamBody determines whether read body in stream or not. // // StreamRequestBody enables streaming request body, // and calls the handler sooner when given body is // larger than the current limit. func WithStreamBody(b bool) config.Option { return config.Option{F: func(o *config.Options) { o.StreamRequestBody = b }} } // WithNetwork sets network. Support "tcp", "udp", "unix"(unix domain socket). func WithNetwork(nw string) config.Option { return config.Option{F: func(o *config.Options) { o.Network = nw }} } // WithExitWaitTime sets timeout for graceful shutdown. // // The server may exit ahead after all connections closed. // All responses after shutdown will be added 'Connection: close' header. func WithExitWaitTime(timeout time.Duration) config.Option { return config.Option{F: func(o *config.Options) { o.ExitWaitTimeout = timeout }} } // WithTLS sets TLS config to start a tls server. // // NOTE: If a tls server is started, it won't accept non-tls request. func WithTLS(cfg *tls.Config) config.Option { return config.Option{F: func(o *config.Options) { // If there is no explicit transporter, change it to standard one. Netpoll do not support tls yet. if o.TransporterNewer == nil { o.TransporterNewer = standard.NewTransporter } o.TLS = cfg }} } // WithListenConfig sets listener config. func WithListenConfig(l *net.ListenConfig) config.Option { return config.Option{F: func(o *config.Options) { o.ListenConfig = l }} } // WithTransport sets which network library to use. 
func WithTransport(transporter func(options *config.Options) network.Transporter) config.Option {
	apply := func(opts *config.Options) { opts.TransporterNewer = transporter }
	return config.Option{F: apply}
}

// WithAltTransport sets which network library to use as an alternative
// transporter (it must be implemented by the specific transporter).
func WithAltTransport(transporter func(options *config.Options) network.Transporter) config.Option {
	apply := func(opts *config.Options) { opts.AltTransporterNewer = transporter }
	return config.Option{F: apply}
}

// WithH2C sets whether to enable H2C.
func WithH2C(enable bool) config.Option {
	apply := func(opts *config.Options) { opts.H2C = enable }
	return config.Option{F: apply}
}

// WithReadBufferSize sets the size of each read buffer node in standard transport.
// NOTE: this cannot limit the header size.
func WithReadBufferSize(size int) config.Option {
	apply := func(opts *config.Options) { opts.ReadBufferSize = size }
	return config.Option{F: apply}
}

// WithALPN sets whether to enable ALPN.
func WithALPN(enable bool) config.Option {
	apply := func(opts *config.Options) { opts.ALPN = enable }
	return config.Option{F: apply}
}

// WithTracer adds a tracer to the server; each call appends another tracer.
func WithTracer(t tracer.Tracer) config.Option {
	apply := func(opts *config.Options) { opts.Tracers = append(opts.Tracers, t) }
	return config.Option{F: apply}
}

// WithTraceLevel sets the trace level.
func WithTraceLevel(level stats.Level) config.Option {
	apply := func(opts *config.Options) { opts.TraceLevel = level }
	return config.Option{F: apply}
}

// WithRegistry sets the registry and the registry's info.
func WithRegistry(r registry.Registry, info *registry.Info) config.Option {
	apply := func(opts *config.Options) {
		opts.RegistryInfo = info
		opts.Registry = r
	}
	return config.Option{F: apply}
}

// WithAutoReloadRender configures automatic reloading of the render.
// If auto reload render is enabled:
//  1. interval == 0 reloads the render according to the file watch mechanism (recommended).
//  2. interval > 0 reloads the render once every interval.
func WithAutoReloadRender(b bool, interval time.Duration) config.Option { return config.Option{F: func(o *config.Options) { o.AutoReloadRender = b o.AutoReloadInterval = interval }} } // WithDisablePrintRoute sets whether disable debugPrintRoute // If we don't set it, it will default to false func WithDisablePrintRoute(b bool) config.Option { return config.Option{F: func(o *config.Options) { o.DisablePrintRoute = b }} } // WithOnAccept sets the callback function when a new connection is accepted but cannot // receive data in netpoll. In go net, it will be called before converting tls connection func WithOnAccept(fn func(conn net.Conn) context.Context) config.Option { return config.Option{F: func(o *config.Options) { o.OnAccept = fn }} } // WithOnConnect sets the onConnect function. It can received data from connection in netpoll. // In go net, it will be called after converting tls connection. func WithOnConnect(fn func(ctx context.Context, conn network.Conn) context.Context) config.Option { return config.Option{F: func(o *config.Options) { o.OnConnect = fn }} } // WithBindConfig sets bind config. func WithBindConfig(bc *binding.BindConfig) config.Option { return config.Option{F: func(o *config.Options) { o.BindConfig = bc }} } // WithValidateConfig sets validate config. // // Deprecated: Use WithCustomValidatorFunc with a custom validation function instead. func WithValidateConfig(vc *binding.ValidateConfig) config.Option { return config.Option{F: func(o *config.Options) { o.ValidateConfig = vc }} } // WithCustomBinder sets customized Binder. // // Priority: CustomBinder has the highest priority and will override any BindConfig // and CustomValidator settings when present. If CustomBinder is set, both the default // binder initialization and validator initialization are completely bypassed. // // Priority order (highest to lowest): // 1. CustomBinder (this option) - completely overrides all binding and validation logic // 2. 
CustomValidator/WithCustomValidatorFunc - sets custom validation for default binder // 3. BindConfig - configures the default binder behavior // 4. ValidateConfig - legacy validation configuration (deprecated) // 5. Default binding and validation behavior // // Note: When CustomBinder is used, validation logic must be implemented within the // custom binder's BindAndValidate method, as CustomValidator is ignored. func WithCustomBinder(b binding.Binder) config.Option { return config.Option{F: func(o *config.Options) { o.CustomBinder = b }} } // WithCustomValidator sets customized StructValidator. // // Deprecated: Use WithCustomValidatorFunc instead. func WithCustomValidator(b binding.StructValidator) config.Option { return WithCustomValidatorFunc(binding.MakeValidatorFunc(b)) } // WithCustomValidatorFunc sets customized validator function. func WithCustomValidatorFunc(vf binding.ValidatorFunc) config.Option { return config.Option{F: func(o *config.Options) { o.CustomValidator = vf }} } // WithDisableHeaderNamesNormalizing is used to set whether disable header names normalizing. func WithDisableHeaderNamesNormalizing(disable bool) config.Option { return config.Option{F: func(o *config.Options) { o.DisableHeaderNamesNormalizing = disable }} } func WithDisableDefaultDate(disable bool) config.Option { return config.Option{F: func(o *config.Options) { o.NoDefaultDate = disable }} } func WithDisableDefaultContentType(disable bool) config.Option { return config.Option{F: func(o *config.Options) { o.NoDefaultContentType = disable }} } // WithSenseClientDisconnection sets the ability to sense client disconnections. // If we don't set it, it will default to false. // There are two issues to note when using this option: // 1. Warning: It only applies to netpoll. // 2. After opening, the context.Context in the request will be cancelled. 
// // Example: // server.Default( // server.WithSenseClientDisconnection(true), // ) func WithSenseClientDisconnection(b bool) config.Option { return config.Option{F: func(o *config.Options) { o.SenseClientDisconnection = b }} } ================================================ FILE: pkg/app/server/option_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package server import ( "context" "net" "reflect" "syscall" "testing" "time" "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/tracer/stats" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" ) func TestOptions(t *testing.T) { info := ®istry.Info{ ServiceName: "hertz.test.api", Weight: 10, Addr: utils.NewNetAddr("tcp", ":8888"), } cfg := &net.ListenConfig{Control: func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) {}) }} transporter := func(options *config.Options) network.Transporter { return &mockTransporter{} } opt := config.NewOptions([]config.Option{ WithReadTimeout(time.Second), WithWriteTimeout(time.Second), WithIdleTimeout(time.Second), WithKeepAliveTimeout(time.Second), WithRedirectTrailingSlash(false), WithRedirectFixedPath(true), WithHandleMethodNotAllowed(true), 
WithUseRawPath(true), WithRemoveExtraSlash(true), WithUnescapePathValues(false), WithDisablePreParseMultipartForm(true), WithStreamBody(false), WithHostPorts(":8888"), WithBasePath("/"), WithMaxRequestBodySize(2), WithDisablePrintRoute(true), WithSenseClientDisconnection(true), WithNetwork("unix"), WithExitWaitTime(time.Second), WithMaxKeepBodySize(500), WithGetOnly(true), WithKeepAlive(false), WithTLS(nil), WithH2C(true), WithReadBufferSize(100), WithALPN(true), WithTraceLevel(stats.LevelDisabled), WithRegistry(nil, info), WithAutoReloadRender(true, 5*time.Second), WithListenConfig(cfg), WithAltTransport(transporter), WithDisableHeaderNamesNormalizing(true), WithMaxHeaderBytes(1024), }) assert.DeepEqual(t, opt.ReadTimeout, time.Second) assert.DeepEqual(t, opt.WriteTimeout, time.Second) assert.DeepEqual(t, opt.IdleTimeout, time.Second) assert.DeepEqual(t, opt.KeepAliveTimeout, time.Second) assert.DeepEqual(t, opt.RedirectTrailingSlash, false) assert.DeepEqual(t, opt.RedirectFixedPath, true) assert.DeepEqual(t, opt.HandleMethodNotAllowed, true) assert.DeepEqual(t, opt.UseRawPath, true) assert.DeepEqual(t, opt.RemoveExtraSlash, true) assert.DeepEqual(t, opt.UnescapePathValues, false) assert.DeepEqual(t, opt.DisablePreParseMultipartForm, true) assert.DeepEqual(t, opt.StreamRequestBody, false) assert.DeepEqual(t, opt.Addr, ":8888") assert.DeepEqual(t, opt.BasePath, "/") assert.DeepEqual(t, opt.MaxRequestBodySize, 2) assert.DeepEqual(t, opt.DisablePrintRoute, true) assert.DeepEqual(t, opt.SenseClientDisconnection, true) assert.DeepEqual(t, opt.Network, "unix") assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second) assert.DeepEqual(t, opt.MaxKeepBodySize, 500) assert.DeepEqual(t, opt.GetOnly, true) assert.DeepEqual(t, opt.DisableKeepalive, true) assert.DeepEqual(t, opt.H2C, true) assert.DeepEqual(t, opt.ReadBufferSize, 100) assert.DeepEqual(t, opt.ALPN, true) assert.DeepEqual(t, opt.TraceLevel, stats.LevelDisabled) assert.DeepEqual(t, opt.RegistryInfo, info) 
assert.DeepEqual(t, opt.Registry, nil) assert.DeepEqual(t, opt.AutoReloadRender, true) assert.DeepEqual(t, opt.AutoReloadInterval, 5*time.Second) assert.DeepEqual(t, opt.ListenConfig, cfg) assert.Assert(t, reflect.TypeOf(opt.AltTransporterNewer) == reflect.TypeOf(transporter)) assert.DeepEqual(t, opt.DisableHeaderNamesNormalizing, true) assert.DeepEqual(t, opt.MaxHeaderBytes, 1024) } func TestDefaultOptions(t *testing.T) { opt := config.NewOptions([]config.Option{}) assert.DeepEqual(t, opt.ReadTimeout, time.Minute*3) assert.DeepEqual(t, opt.IdleTimeout, time.Minute*3) assert.DeepEqual(t, opt.KeepAliveTimeout, time.Minute) assert.DeepEqual(t, opt.RedirectTrailingSlash, true) assert.DeepEqual(t, opt.RedirectFixedPath, false) assert.DeepEqual(t, opt.HandleMethodNotAllowed, false) assert.DeepEqual(t, opt.UseRawPath, false) assert.DeepEqual(t, opt.RemoveExtraSlash, false) assert.DeepEqual(t, opt.UnescapePathValues, true) assert.DeepEqual(t, opt.DisablePreParseMultipartForm, false) assert.DeepEqual(t, opt.StreamRequestBody, false) assert.DeepEqual(t, opt.Addr, ":8888") assert.DeepEqual(t, opt.BasePath, "/") assert.DeepEqual(t, opt.MaxRequestBodySize, 4*1024*1024) assert.DeepEqual(t, opt.GetOnly, false) assert.DeepEqual(t, opt.DisableKeepalive, false) assert.DeepEqual(t, opt.DisablePrintRoute, false) assert.DeepEqual(t, opt.SenseClientDisconnection, false) assert.DeepEqual(t, opt.Network, "tcp") assert.DeepEqual(t, opt.ExitWaitTimeout, time.Second*5) assert.DeepEqual(t, opt.MaxKeepBodySize, 4*1024*1024) assert.DeepEqual(t, opt.H2C, false) assert.DeepEqual(t, opt.ReadBufferSize, 4096) assert.DeepEqual(t, opt.ALPN, false) assert.DeepEqual(t, opt.Registry, registry.NoopRegistry) assert.DeepEqual(t, opt.AutoReloadRender, false) assert.Assert(t, opt.RegistryInfo == nil) assert.DeepEqual(t, opt.AutoReloadRender, false) assert.DeepEqual(t, opt.AutoReloadInterval, time.Duration(0)) assert.DeepEqual(t, opt.DisableHeaderNamesNormalizing, false) assert.DeepEqual(t, 
opt.MaxHeaderBytes, 1<<20) } func TestWithListener(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() cfg := &net.ListenConfig{} opt := config.NewOptions([]config.Option{ WithHostPorts("127.0.0.1:8888"), WithNetwork("udp"), WithListenConfig(cfg), WithListener(ln), }) // Listener should be set assert.DeepEqual(t, opt.Listener, ln) // Network and Addr should be updated from listener assert.DeepEqual(t, opt.Network, ln.Addr().Network()) assert.DeepEqual(t, opt.Addr, ln.Addr().String()) // ListenConfig should be reset assert.Assert(t, opt.ListenConfig == nil) } type mockTransporter struct{} func (m *mockTransporter) ListenAndServe(onData network.OnData) (err error) { panic("implement me") } func (m *mockTransporter) Close() error { panic("implement me") } func (m *mockTransporter) Shutdown(ctx context.Context) error { panic("implement me") } ================================================ FILE: pkg/app/server/registry/registry.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package registry import "net" const ( DefaultWeight = 10 ) // Registry is extension interface of service registry. type Registry interface { Register(info *Info) error Deregister(info *Info) error } // Info is used for registry. // The fields are just suggested, which is used depends on design. 
type Info struct { // ServiceName will be set in hertz by default ServiceName string // Addr will be set in hertz by default Addr net.Addr // Weight will be set in hertz by default Weight int // extend other infos with Tags. Tags map[string]string } // NoopRegistry is an empty implement of Registry var NoopRegistry Registry = &noopRegistry{} // NoopRegistry type noopRegistry struct{} func (e noopRegistry) Register(*Info) error { return nil } func (e noopRegistry) Deregister(*Info) error { return nil } ================================================ FILE: pkg/app/server/registry/registry_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package registry import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestNoopRegistry(t *testing.T) { reg := noopRegistry{} assert.Nil(t, reg.Deregister(&Info{})) assert.Nil(t, reg.Register(&Info{})) } ================================================ FILE: pkg/app/server/render/data.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package render import "github.com/cloudwego/hertz/pkg/protocol" // Data contains ContentType and bytes data. type Data struct { ContentType string Data []byte } // Render (Data) writes data with custom ContentType. 
func (r Data) Render(resp *protocol.Response) (err error) { r.WriteContentType(resp) resp.AppendBody(r.Data) return } // WriteContentType (Data) writes custom ContentType. func (r Data) WriteContentType(resp *protocol.Response) { writeContentType(resp, r.ContentType) } ================================================ FILE: pkg/app/server/render/doc.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // The files in render package are forked from gin[github.com/gin-gonic/gin], // and we keep the original Copyright[Copyright 2014 gin authors] and License of gin for those files. // We also need to modify as we need, the modifications are Copyright of 2022 CloudWeGo Authors. // Thanks for gin authors! Below is the source code information: // Repo: github.com/gin-gonic/gin // Forked Version: v1.7.7 package render ================================================ FILE: pkg/app/server/render/html.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package render import ( "html/template" "log" "sync" "time" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/protocol" "github.com/fsnotify/fsnotify" ) // Delims represents a set of Left and Right delimiters for HTML template rendering. type Delims struct { // Left delimiter, defaults to {{. 
Left string // Right delimiter, defaults to }}. Right string } // HTMLRender interface is to be implemented by HTMLProduction and HTMLDebug. type HTMLRender interface { // Instance returns an HTML instance. Instance(string, interface{}) Render Close() error } // HTMLProduction contains template reference and its delims. type HTMLProduction struct { Template *template.Template } // HTML contains template reference and its name with given interface object. type HTML struct { Template *template.Template Name string Data interface{} } var htmlContentType = "text/html; charset=utf-8" // Instance (HTMLProduction) returns an HTML instance which it realizes Render interface. func (r HTMLProduction) Instance(name string, data interface{}) Render { return HTML{ Template: r.Template, Name: name, Data: data, } } func (r HTMLProduction) Close() error { return nil } // Render (HTML) executes template and writes its result with custom ContentType for response. func (r HTML) Render(resp *protocol.Response) error { r.WriteContentType(resp) if r.Name == "" { return r.Template.Execute(resp.BodyWriter(), r.Data) } return r.Template.ExecuteTemplate(resp.BodyWriter(), r.Name, r.Data) } // WriteContentType (HTML) writes HTML ContentType. func (r HTML) WriteContentType(resp *protocol.Response) { writeContentType(resp, htmlContentType) } type HTMLDebug struct { sync.Once Template *template.Template RefreshInterval time.Duration Files []string FuncMap template.FuncMap Delims Delims reloadCh chan struct{} watcher *fsnotify.Watcher } func (h *HTMLDebug) Instance(name string, data interface{}) Render { h.Do(func() { h.startChecker() }) select { case <-h.reloadCh: h.reload() default: } return HTML{ Template: h.Template, Name: name, Data: data, } } func (h *HTMLDebug) Close() error { if h.watcher == nil { return nil } return h.watcher.Close() } func (h *HTMLDebug) reload() { h.Template = template.Must(template.New(""). Delims(h.Delims.Left, h.Delims.Right). Funcs(h.FuncMap). 
ParseFiles(h.Files...)) } func (h *HTMLDebug) startChecker() { h.reloadCh = make(chan struct{}) if h.RefreshInterval > 0 { go func() { hlog.SystemLogger().Debugf("[HTMLDebug] HTML template reloader started with interval %v", h.RefreshInterval) for range time.Tick(h.RefreshInterval) { hlog.SystemLogger().Debugf("[HTMLDebug] triggering HTML template reloader") h.reloadCh <- struct{}{} hlog.SystemLogger().Debugf("[HTMLDebug] HTML template has been reloaded, next reload in %v", h.RefreshInterval) } }() return } watcher, err := fsnotify.NewWatcher() if err != nil { log.Fatal(err) } h.watcher = watcher for _, f := range h.Files { err := watcher.Add(f) hlog.SystemLogger().Debugf("[HTMLDebug] watching file: %s", f) if err != nil { hlog.SystemLogger().Errorf("[HTMLDebug] add watching file: %s, error happened: %v", f, err) } } go func() { hlog.SystemLogger().Debugf("[HTMLDebug] HTML template reloader started with file watcher") for { select { case event, ok := <-watcher.Events: if !ok { return } if event.Op&fsnotify.Write == fsnotify.Write { hlog.SystemLogger().Debugf("[HTMLDebug] modified file: %s, html render template will be reloaded at the next rendering", event.Name) h.reloadCh <- struct{}{} hlog.SystemLogger().Debugf("[HTMLDebug] HTML template has been reloaded") } case err, ok := <-watcher.Errors: if !ok { return } hlog.SystemLogger().Errorf("error happened when watching the rendering files: %v", err) } } }() } ================================================ FILE: pkg/app/server/render/html_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package render import ( "html/template" "io/ioutil" "os" "testing" "time" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" ) func TestHTMLDebug_StartChecker_timer(t *testing.T) { render := &HTMLDebug{ RefreshInterval: time.Second, Delims: Delims{Left: "{[{", Right: "}]}"}, FuncMap: template.FuncMap{}, Files: []string{"../../../common/testdata/template/index.tmpl"}, } select { case <-render.reloadCh: t.Fatalf("should not be triggered") default: } render.startChecker() select { case <-time.After(render.RefreshInterval + 500*time.Millisecond): t.Fatalf("should be triggered in 1.5 second") case <-render.reloadCh: render.reload() } } func TestHTMLDebug_StartChecker_fs_watcher(t *testing.T) { f, _ := ioutil.TempFile("./", "test.tmpl") defer func() { f.Close() os.Remove(f.Name()) }() render := &HTMLDebug{Files: []string{f.Name()}} select { case <-render.reloadCh: t.Fatalf("should not be triggered") default: } render.startChecker() f.Write([]byte("hello")) f.Sync() select { case <-time.After(50 * time.Millisecond): t.Fatalf("should be triggered immediately") case <-render.reloadCh: } select { case <-render.reloadCh: t.Fatalf("should not be triggered") default: } } func TestRenderHTML(t *testing.T) { resp := &protocol.Response{} tmpl := template.Must(template.New(""). Delims("{[{", "}]}"). Funcs(template.FuncMap{}). 
ParseFiles("../../../common/testdata/template/index.tmpl")) r := &HTMLProduction{Template: tmpl} html := r.Instance("index.tmpl", utils.H{ "title": "Main website", }) err := r.Close() assert.Nil(t, err) html.WriteContentType(resp) assert.DeepEqual(t, []byte("text/html; charset=utf-8"), resp.Header.Peek("Content-Type")) err = html.Render(resp) assert.Nil(t, err) assert.DeepEqual(t, []byte("text/html; charset=utf-8"), resp.Header.Peek("Content-Type")) assert.DeepEqual(t, []byte("

Main website

"), resp.Body()) respDebug := &protocol.Response{} rDebug := &HTMLDebug{ Template: tmpl, Delims: Delims{Left: "{[{", Right: "}]}"}, FuncMap: template.FuncMap{}, Files: []string{"../../../common/testdata/template/index.tmpl"}, } htmlDebug := rDebug.Instance("index.tmpl", utils.H{ "title": "Main website", }) err = rDebug.Close() assert.Nil(t, err) htmlDebug.WriteContentType(respDebug) assert.DeepEqual(t, []byte("text/html; charset=utf-8"), respDebug.Header.Peek("Content-Type")) err = htmlDebug.Render(respDebug) assert.Nil(t, err) assert.DeepEqual(t, []byte("text/html; charset=utf-8"), respDebug.Header.Peek("Content-Type")) assert.DeepEqual(t, []byte("

Main website

"), respDebug.Body()) } ================================================ FILE: pkg/app/server/render/json.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package render import ( "bytes" "encoding/json" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" ) // JSONMarshaler customize json.Marshal as you like type JSONMarshaler func(v interface{}) ([]byte, error) var jsonMarshalFunc JSONMarshaler func init() { ResetJSONMarshal(hjson.Marshal) } func ResetJSONMarshal(fn JSONMarshaler) { jsonMarshalFunc = fn } func ResetStdJSONMarshal() { ResetJSONMarshal(json.Marshal) } // JSONRender JSON contains the given interface object. type JSONRender struct { Data interface{} } var jsonContentType = "application/json; charset=utf-8" // Render (JSON) writes data with custom ContentType. func (r JSONRender) Render(resp *protocol.Response) error { writeContentType(resp, jsonContentType) jsonBytes, err := jsonMarshalFunc(r.Data) if err != nil { return err } resp.AppendBody(jsonBytes) return nil } // WriteContentType (JSON) writes JSON ContentType. func (r JSONRender) WriteContentType(resp *protocol.Response) { writeContentType(resp, jsonContentType) } // PureJSON contains the given interface object. type PureJSON struct { Data interface{} } // Render (JSON) writes data with custom ContentType. func (r PureJSON) Render(resp *protocol.Response) (err error) { r.WriteContentType(resp) buffer := new(bytes.Buffer) encoder := json.NewEncoder(buffer) encoder.SetEscapeHTML(false) err = encoder.Encode(r.Data) if err != nil { return } resp.AppendBody(buffer.Bytes()) return } // WriteContentType (JSON) writes JSON ContentType. func (r PureJSON) WriteContentType(resp *protocol.Response) { writeContentType(resp, jsonContentType) } // IndentedJSON contains the given interface object. type IndentedJSON struct { Data interface{} } // Render (IndentedJSON) marshals the given interface object and writes it with custom ContentType. 
func (r IndentedJSON) Render(resp *protocol.Response) (err error) { writeContentType(resp, jsonContentType) jsonBytes, err := jsonMarshalFunc(r.Data) if err != nil { return err } var buf bytes.Buffer err = json.Indent(&buf, jsonBytes, "", " ") if err != nil { return err } resp.AppendBody(buf.Bytes()) return nil } // WriteContentType (JSON) writes JSON ContentType. func (r IndentedJSON) WriteContentType(resp *protocol.Response) { writeContentType(resp, jsonContentType) } ================================================ FILE: pkg/app/server/render/json_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. 
* * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package render import ( "strings" "testing" ) func Test_ResetStdJSONMarshal(t *testing.T) { table := map[string]string{ "testA": "hello", "B": "world", } ResetStdJSONMarshal() jsonBytes, err := jsonMarshalFunc(table) if err != nil { t.Fatal(err) } if !strings.Contains(string(jsonBytes), "\"B\":\"world\"") || !strings.Contains(string(jsonBytes), "\"testA\":\"hello\"") { t.Fatal("marshal struct is not equal to the string") } } func Test_DefaultJSONMarshal(t *testing.T) { table := map[string]string{ "testA": "hello", "B": "world", } jsonBytes, err := jsonMarshalFunc(table) if err != nil { t.Fatal(err) } if !strings.Contains(string(jsonBytes), "\"B\":\"world\"") || !strings.Contains(string(jsonBytes), "\"testA\":\"hello\"") { t.Fatal("marshal struct is not equal to the string") } } ================================================ FILE: pkg/app/server/render/protobuf.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package render import ( "github.com/cloudwego/hertz/pkg/protocol" "google.golang.org/protobuf/proto" ) // ProtoBuf contains the given interface object. 
type ProtoBuf struct { Data interface{} } var protobufContentType = "application/x-protobuf" // Render (ProtoBuf) marshals the given interface object and writes data with custom ContentType. func (r ProtoBuf) Render(resp *protocol.Response) error { r.WriteContentType(resp) bytes, err := proto.Marshal(r.Data.(proto.Message)) if err != nil { return err } resp.AppendBody(bytes) return nil } // WriteContentType (ProtoBuf) writes ProtoBuf ContentType. func (r ProtoBuf) WriteContentType(resp *protocol.Response) { writeContentType(resp, protobufContentType) } ================================================ FILE: pkg/app/server/render/render.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. 
* * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package render import "github.com/cloudwego/hertz/pkg/protocol" // Render interface is to be implemented by JSON, XML, HTML, YAML and so on. type Render interface { // Render writes data with custom ContentType. // Do not panic inside, RequestContext will handle it. Render(resp *protocol.Response) error // WriteContentType writes custom ContentType. WriteContentType(resp *protocol.Response) } var ( _ Render = JSONRender{} _ Render = String{} _ Render = Data{} ) func writeContentType(resp *protocol.Response, value string) { resp.Header.SetContentType(value) } ================================================ FILE: pkg/app/server/render/render_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package render import ( "encoding/xml" "testing" "github.com/bytedance/sonic" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/testdata/proto" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) type xmlmap map[string]interface{} // Allows type H to be used with xml.Marshal func (h xmlmap) MarshalXML(e *xml.Encoder, start xml.StartElement) error { start.Name = xml.Name{ Space: "", Local: "map", } if err := e.EncodeToken(start); err != nil { return err } for key, value := range h { elem := xml.StartElement{ Name: xml.Name{Space: "", Local: key}, Attr: []xml.Attr{}, } if err := e.EncodeElement(value, elem); err != nil { return err } } return e.EncodeToken(xml.EndElement{Name: start.Name}) } func TestRenderJSON(t *testing.T) { resp := &protocol.Response{} data := map[string]interface{}{ "foo": "bar", "html": "", } (JSONRender{data}).WriteContentType(resp) assert.DeepEqual(t, []byte(consts.MIMEApplicationJSONUTF8), resp.Header.Peek("Content-Type")) err := (JSONRender{data}).Render(resp) assert.Nil(t, err) assert.DeepEqual(t, []byte("{\"foo\":\"bar\",\"html\":\"\\u003cb\\u003e\"}"), resp.Body()) assert.DeepEqual(t, []byte(consts.MIMEApplicationJSONUTF8), resp.Header.Peek("Content-Type")) } func TestRenderJSONError(t *testing.T) { resp := &protocol.Response{} data := make(chan int) err := (JSONRender{data}).Render(resp) // json: unsupported type: chan int assert.NotNil(t, err) } func TestRenderPureJSON(t *testing.T) { resp := &protocol.Response{} data := map[string]interface{}{ "foo": "bar", "html": "", } (PureJSON{data}).WriteContentType(resp) assert.DeepEqual(t, []byte(consts.MIMEApplicationJSONUTF8), resp.Header.Peek("Content-Type")) err := (PureJSON{data}).Render(resp) assert.Nil(t, err) assert.DeepEqual(t, []byte("{\"foo\":\"bar\",\"html\":\"\"}\n"), resp.Body()) assert.DeepEqual(t, []byte(consts.MIMEApplicationJSONUTF8), resp.Header.Peek("Content-Type")) } func 
TestRenderPureJSONError(t *testing.T) { resp := &protocol.Response{} data := make(chan int) err := (PureJSON{data}).Render(resp) // json: unsupported type: chan int assert.NotNil(t, err) } func TestRenderProtobuf(t *testing.T) { resp := &protocol.Response{} data := proto.TestStruct{Body: []byte("Hello World")} (ProtoBuf{&data}).WriteContentType(resp) assert.DeepEqual(t, []byte("application/x-protobuf"), resp.Header.Peek("Content-Type")) err := (ProtoBuf{&data}).Render(resp) assert.Nil(t, err) assert.DeepEqual(t, []byte("\n\vHello World"), resp.Body()) assert.DeepEqual(t, []byte("application/x-protobuf"), resp.Header.Peek("Content-Type")) } func TestRenderProtobufError(t *testing.T) { resp := &protocol.Response{} data := proto.Test{} err := (ProtoBuf{&data}).Render(resp) assert.NotNil(t, err) } func TestRenderString(t *testing.T) { resp := &protocol.Response{} (String{ Format: "hello %s %d", Data: []interface{}{}, }).WriteContentType(resp) assert.DeepEqual(t, []byte(consts.MIMETextPlainUTF8), resp.Header.Peek("Content-Type")) err := (String{ Format: "hola %s %d", Data: []interface{}{"manu", 2}, }).Render(resp) assert.Nil(t, err) assert.DeepEqual(t, []byte("hola manu 2"), resp.Body()) assert.DeepEqual(t, []byte(consts.MIMETextPlainUTF8), resp.Header.Peek("Content-Type")) } func TestRenderStringLenZero(t *testing.T) { resp := &protocol.Response{} err := (String{ Format: "hola %s %d", Data: []interface{}{}, }).Render(resp) assert.Nil(t, err) assert.DeepEqual(t, []byte("hola %s %d"), resp.Body()) assert.DeepEqual(t, []byte(consts.MIMETextPlainUTF8), resp.Header.Peek("Content-Type")) } func TestRenderData(t *testing.T) { resp := &protocol.Response{} data := []byte("#!PNG some raw data") err := (Data{ ContentType: "image/png", Data: data, }).Render(resp) assert.Nil(t, err) assert.DeepEqual(t, []byte("#!PNG some raw data"), resp.Body()) assert.DeepEqual(t, []byte(consts.MIMEImagePNG), resp.Header.Peek("Content-Type")) } func TestRenderXML(t *testing.T) { resp := 
&protocol.Response{} data := xmlmap{ "foo": "bar", } (XML{data}).WriteContentType(resp) assert.DeepEqual(t, []byte(consts.MIMEApplicationXMLUTF8), resp.Header.Peek("Content-Type")) err := (XML{data}).Render(resp) assert.Nil(t, err) assert.DeepEqual(t, []byte("bar"), resp.Body()) assert.DeepEqual(t, []byte(consts.MIMEApplicationXMLUTF8), resp.Header.Peek("Content-Type")) } func TestRenderXMLError(t *testing.T) { resp := &protocol.Response{} data := make(chan int) err := (XML{data}).Render(resp) assert.NotNil(t, err) } func TestRenderIndentedJSON(t *testing.T) { data := map[string]interface{}{ "foo": "bar", "html": "h1", } t.Run("TestHeader", func(t *testing.T) { resp := &protocol.Response{} (IndentedJSON{data}).WriteContentType(resp) assert.DeepEqual(t, []byte(consts.MIMEApplicationJSONUTF8), resp.Header.Peek("Content-Type")) }) t.Run("TestBody", func(t *testing.T) { ResetStdJSONMarshal() resp := &protocol.Response{} err := (IndentedJSON{data}).Render(resp) assert.Nil(t, err) assert.DeepEqual(t, []byte("{\n \"foo\": \"bar\",\n \"html\": \"h1\"\n}"), resp.Body()) assert.DeepEqual(t, []byte(consts.MIMEApplicationJSONUTF8), resp.Header.Peek("Content-Type")) ResetJSONMarshal(sonic.Marshal) }) t.Run("TestError", func(t *testing.T) { resp := &protocol.Response{} ch := make(chan int) err := (IndentedJSON{ch}).Render(resp) assert.NotNil(t, err) }) } ================================================ FILE: pkg/app/server/render/text.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package render import ( "fmt" "github.com/cloudwego/hertz/pkg/protocol" ) // String contains the given interface object slice and its format. type String struct { Format string Data []interface{} } var plainContentType = "text/plain; charset=utf-8" // Render (String) writes data with custom ContentType. func (r String) Render(resp *protocol.Response) error { writeContentType(resp, plainContentType) output := r.Format if len(r.Data) > 0 { output = fmt.Sprintf(r.Format, r.Data...) } resp.AppendBodyString(output) return nil } // WriteContentType (String) writes Plain ContentType. 
func (r String) WriteContentType(resp *protocol.Response) { writeContentType(resp, plainContentType) } ================================================ FILE: pkg/app/server/render/xml.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
* This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package render import ( "encoding/xml" "github.com/cloudwego/hertz/pkg/protocol" ) // XML contains the given interface object. type XML struct { Data interface{} } var xmlContentType = "application/xml; charset=utf-8" // Render (XML) encodes the given interface object and writes data with custom ContentType. func (r XML) Render(resp *protocol.Response) error { writeContentType(resp, xmlContentType) xmlBytes, err := xml.Marshal(r.Data) if err != nil { return err } resp.AppendBody(xmlBytes) return nil } // WriteContentType (XML) writes XML ContentType for response. func (r XML) WriteContentType(w *protocol.Response) { writeContentType(w, xmlContentType) } ================================================ FILE: pkg/app/server/server_bench_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package server import ( "bufio" "context" "net" "testing" "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/network/standard" ) func BenchmarkServerHelloWorld(b *testing.B) { ln := testutils.NewTestListener(b) defer ln.Close() h := Default(WithListener(ln), WithTransport(standard.NewTransporter)) h.GET("/hello", func(c context.Context, ctx *app.RequestContext) { ctx.SetBodyString("hello world") }) go h.Run() waitEngineRunning(h) defer h.Engine.Close() addr := ln.Addr().String() // Pre-create connection pool with keep-alive const poolSize = 10 connPool := make([]net.Conn, poolSize) readerPool := make([]*bufio.Reader, poolSize) for i := 0; i < poolSize; i++ { conn, err := net.Dial("tcp", addr) if err != nil { b.Fatalf("failed to dial: %s", err) } connPool[i] = conn readerPool[i] = bufio.NewReader(conn) defer conn.Close() } request := []byte("GET /hello HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n") b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { conn := connPool[i%poolSize] reader := readerPool[i%poolSize] _, err := conn.Write(request) if err != nil { b.Fatalf("write error: %s", err) } _, err = reader.Peek(1) if err != nil { b.Fatal(err) } _, err = reader.Discard(reader.Buffered()) if err != nil { b.Fatal(err) } } } ================================================ FILE: pkg/common/adaptor/handler.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package adaptor import ( "bufio" "context" "errors" "net" "net/http" "runtime" "sync" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" ) // HertzHandler converts a http.Handler to an app.HandlerFunc. func HertzHandler(h http.Handler) app.HandlerFunc { return func(ctx context.Context, rc *app.RequestContext) { // creating http.Request r := &rc.Request req, err := http.NewRequestWithContext(ctx, methodstr(r.Method()), r.URI().String(), nil) if err != nil { rc.SetStatusCode(consts.StatusInternalServerError) rc.SetBodyString(err.Error()) return } // request Body & ContentLength stream := r.IsBodyStream() if stream && r.HasMultipartForm() && len(r.BodyBytes()) > 0 { // in this case, r.MultipartForm() called before this handler then we can not rely on stream body. // coz both r.Body() and r.BodyStream() will return nothing due to EOF of stream body. // we fix it by calling r.CloseBodyStream(), then r.Body() correctly returns bytes. // FIXME: This should be taken care of by `protocol.Request` r.CloseBodyStream() stream = false } if stream { // use BodyStream if possible to avoid OOM issue. 
req.Body = reader2closer(r.BodyStream()) req.ContentLength = int64(r.Header.ContentLength()) } else { // fast path for default cases with StreamRequestBody=false b := r.Body() req.Body = newBytesRWCloser(b) req.ContentLength = int64(len(b)) } // request Header r.Header.VisitAll(func(k, v []byte) { req.Header[string(k)] = append(req.Header[string(k)], string(v)) }) // request other properties if s := r.Header.GetProtocol(); s != "" { req.Proto = s req.ProtoMajor, req.ProtoMinor, _ = parseHTTPVersion(s) } req.Close = r.ConnectionClose() req.RemoteAddr = rc.RemoteAddr().String() req.RequestURI = string(r.RequestURI()) if tlsconn, ok := rc.GetConn().(network.ConnTLSer); ok { state := tlsconn.ConnectionState() req.TLS = &state } // creating http.ResponseWriter // // coz it's server response // no need to copy anything from hertz Response w := &httpResponseWriter{rc: rc} h.ServeHTTP(w, req) if w.hijacked != nil { // wait for hijacked conn to close before returning, // otherwise either hertz will close the conn // or netpoll may reuse the conn for next request. 
<-w.hijacked } } } type httpResponseWriter struct { rc *app.RequestContext header http.Header err error wroteHeader bool skipBody bool hijacked chan struct{} // != nil if hijacked } var errConnHijacked = errors.New("hertz net/http adaptor: conn hijacked") func (p *httpResponseWriter) Header() http.Header { if p.header != nil { return p.header } p.header = make(map[string][]string) return p.header } // Write implements http.ResponseWriter.Write func (p *httpResponseWriter) Write(b []byte) (n int, err error) { if p.hijacked != nil { return 0, errConnHijacked } if !p.wroteHeader { p.WriteHeader(consts.StatusOK) } if p.err != nil { return 0, p.err } if p.skipBody { return len(b), nil } n, p.err = p.rc.Response.GetHijackWriter().Write(b) return n, p.err } // WriteHeader implements http.ResponseWriter.WriteHeader func (p *httpResponseWriter) WriteHeader(statusCode int) { if p.wroteHeader || p.hijacked != nil { return } p.wroteHeader = true r := &p.rc.Response // reset and check if user updates Content-Length // if we have no Content-Length, can only use chunked Transfer-Encoding r.Header.InitContentLengthWithValue(-1) for k, vv := range p.header { for _, v := range vv { r.Header.Add(k, v) } } w := p.rc.GetWriter() r.Header.SetStatusCode(statusCode) p.skipBody = r.Header.MustSkipContentLength() || string(p.rc.Request.Method()) == consts.MethodHead if p.skipBody { // set Content-Length: 0 r.Header.SetCanonical(bytestr.StrContentLength, []byte("0")) // skip all further writes, // must be set for hertz request loop or it would write header and body after handler returns r.HijackWriter(noopWriter{}) p.err = resp.WriteHeader(&r.Header, w) } else if r.Header.ContentLength() < 0 { // For chunked encoding, write headers immediately cw := resp.NewChunkedBodyWriter(r, w) r.HijackWriter(cw) type chunkedBodyWriter interface { WriteHeader() error } p.err = cw.(chunkedBodyWriter).WriteHeader() } else { // use Writer directly instead of keep buffering data in resp.BodyBuffer() // you 
never know how much data would be written to response r.HijackWriter(writer2writerExt(w)) p.err = resp.WriteHeader(&r.Header, w) } } var _ http.Flusher = (*httpResponseWriter)(nil) // Flush implements http.Flusher and captures any flush errors func (p *httpResponseWriter) Flush() { if p.err == nil { p.err = p.rc.GetWriter().Flush() } } var _ http.Hijacker = (*httpResponseWriter)(nil) // Hijack implements http.Hijacker func (p *httpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if p.hijacked != nil { return nil, nil, errConnHijacked } if p.err != nil { return nil, nil, p.err } // If headers were already written, flush the buffer to avoid losing // any pending data before hijacking the connection if p.wroteHeader { if p.err = p.rc.GetWriter().Flush(); p.err != nil { return nil, nil, p.err } } conn := newHijackedConn(p.rc.GetConn()) p.hijacked = conn.closeCh // reset timeout if any _ = conn.SetReadTimeout(0) _ = conn.SetWriteTimeout(0) // make sure after handler returns: // * hertz won't reuse the conn // * hertz won't write any extra bytes to underlying conn p.rc.Response.Header.SetConnectionClose(true) p.rc.Response.HijackWriter(noopHijackWriter{}) return conn, bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), nil } type noopHijackWriter struct{} var _ network.ExtWriter = noopHijackWriter{} func (noopHijackWriter) Write(b []byte) (int, error) { return 0, errConnHijacked } func (noopHijackWriter) Flush() error { return errConnHijacked } func (noopHijackWriter) Finalize() error { return nil } type noopWriter struct{} var _ network.ExtWriter = noopWriter{} func (noopWriter) Write(b []byte) (int, error) { return len(b), nil } func (noopWriter) Flush() error { return nil } func (noopWriter) Finalize() error { return nil } type hijackedConn struct { network.Conn closeOnce sync.Once closeCh chan struct{} } func newHijackedConn(conn network.Conn) *hijackedConn { c := &hijackedConn{Conn: conn, closeCh: make(chan struct{})} 
runtime.SetFinalizer(c, hijackedConnFinalizer) return c } func hijackedConnFinalizer(c *hijackedConn) { _ = c.Close() } func (c *hijackedConn) Close() error { runtime.SetFinalizer(c, nil) err := c.Conn.Close() c.closeOnce.Do(func() { close(c.closeCh) }) return err } ================================================ FILE: pkg/common/adaptor/handler_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adaptor import ( "bytes" "context" "embed" "fmt" "io" "mime/multipart" "net" "net/http" "runtime" "strings" "sync" "testing" "time" "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/route" ) //go:embed * var adaptorFiles embed.FS func runEngine(onCreate func(*route.Engine)) (string, *route.Engine) { ln := testutils.NewTestListener(&testing.T{}) opt := config.NewOptions(nil) opt.Listener = ln engine := route.NewEngine(opt) onCreate(engine) go engine.Run() testutils.WaitEngineRunning(engine) return ln.Addr().String(), engine } func TestHertzHandler_BodyStream(t *testing.T) { var wg sync.WaitGroup wg.Add(1) h := HertzHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer wg.Done() b := make([]byte, 100) for i := 0; i < 3; i++ { // reading chunked data n, err := r.Body.Read(b) assert.Nil(t, err) 
assert.Assert(t, n == 5, n) assert.Assert(t, string(b[:n]) == "hello") } n, err := r.Body.Read(b) assert.Assert(t, err == io.EOF) assert.Assert(t, n == 0) })) addr, e := runEngine(func(e *route.Engine) { e.GetOptions().StreamRequestBody = true e.POST("/test", h) }) defer e.Close() r, w := io.Pipe() // for sending chunked data req, err := http.NewRequest("POST", "http://"+addr+"/test", r) assert.Nil(t, err) cli := &http.Client{} go cli.Do(req) for i := 0; i < 3; i++ { w.Write([]byte("hello")) time.Sleep(50 * time.Millisecond) } w.Close() wg.Wait() } func TestHertzHandler_Chunked(t *testing.T) { h := HertzHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { f := w.(http.Flusher) w.Header().Set("Transfer-Encoding", "chunked") for i := 0; i < 5; i++ { chunk := fmt.Sprintf("data:%d", i) _, err := w.Write([]byte(chunk)) assert.Nil(t, err) f.Flush() time.Sleep(20 * time.Millisecond) } })) addr, e := runEngine(func(e *route.Engine) { e.GET("/test", h) }) defer e.Close() resp, err := http.Get("http://" + addr + "/test") assert.Nil(t, err) defer resp.Body.Close() assert.Assert(t, len(resp.TransferEncoding) == 1 && resp.TransferEncoding[0] == "chunked") for i := 0; i < 5; i++ { b := make([]byte, 10) n, err := resp.Body.Read(b) assert.Nil(t, err) assert.Assert(t, string(b[:n]) == fmt.Sprintf("data:%d", i)) } } func TestHertzHandler_WriteHeader(t *testing.T) { h := HertzHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Length", "0") w.WriteHeader(500) w.(http.Flusher).Flush() time.Sleep(time.Second) // Simulate long-running handler })) addr, e := runEngine(func(e *route.Engine) { e.GET("/test", h) }) defer e.Close() conn, err := net.Dial("tcp", addr) assert.Nil(t, err) defer conn.Close() _, err = conn.Write([]byte("GET /test HTTP/1.1\r\nHost: example.com\r\n\r\n")) assert.Nil(t, err) // Set a short read deadline to verify headers arrive quickly after WriteHeader() conn.SetReadDeadline(time.Now().Add(50 * 
time.Millisecond)) b := make([]byte, 200) n, err := conn.Read(b) assert.Nil(t, err) assert.Assert(t, strings.HasPrefix(string(b[:n]), "HTTP/1.1 500 ")) } func TestHertzHandler_Hijack(t *testing.T) { h := HertzHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { conn, rw, err := w.(http.Hijacker).Hijack() assert.Nil(t, err) _, _, err = w.(http.Hijacker).Hijack() // hijacked assert.NotNil(t, err) w.WriteHeader(500) // hijacked, noop go func() { defer conn.Close() time.Sleep(50 * time.Millisecond) _, err = w.Write([]byte("hello")) assert.Assert(t, err == errConnHijacked) _, err = rw.Write([]byte("hello")) assert.Nil(t, err) err = rw.Flush() assert.Nil(t, err) b := make([]byte, 10) n, err := rw.Read(b) assert.Nil(t, err) assert.Assert(t, string(b[:n]) == "world") }() })) addr, e := runEngine(func(e *route.Engine) { e.GET("/test", h) }) defer e.Close() conn, err := net.Dial("tcp", addr) assert.Nil(t, err) defer conn.Close() conn.Write([]byte("GET /test HTTP/1.1\r\nHost: example.com\r\n\r\n")) b := make([]byte, 100) n, err := conn.Read(b) assert.Nil(t, err) assert.Assert(t, string(b[:n]) == "hello", string(b[:n])) _, err = conn.Write([]byte("world")) assert.Nil(t, err) n, err = conn.Read(b) // Keep-Alive will not work if hijacked assert.Assert(t, err == io.EOF) assert.Assert(t, n == 0) } // TestHertzHandler_HijackGC tests that hijacked conn is closed by GC finalizer // when user forgets to call Close() func TestHertzHandler_HijackGC(t *testing.T) { h := HertzHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _, err := w.(http.Hijacker).Hijack() assert.Nil(t, err) // intentionally not closing conn, let GC handle it runtime.GC() runtime.GC() // make sure the net.Conn is closed by GC })) addr, e := runEngine(func(e *route.Engine) { e.GET("/test", h) }) defer e.Close() conn, err := net.Dial("tcp", addr) assert.Nil(t, err) defer conn.Close() conn.Write([]byte("GET /test HTTP/1.1\r\nHost: example.com\r\n\r\n")) b := make([]byte, 100) 
n, err := conn.Read(b) // conn should be closed by finalizer assert.Assert(t, err == io.EOF, err) assert.Assert(t, n == 0) } // TestHertzHandler_WriteHeader_Hijack verifies that headers are properly flushed // before hijacking the connection. This test ensures that when WriteHeader is called // before Hijack, the headers are correctly sent and the connection can be taken over. func TestHertzHandler_WriteHeader_Hijack(t *testing.T) { h := HertzHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Custom-Header", "test-value") w.WriteHeader(200) conn, rw, err := w.(http.Hijacker).Hijack() assert.Nil(t, err) defer conn.Close() _, err = rw.WriteString("hijacked response body") assert.Nil(t, err) assert.Nil(t, rw.Flush()) })) addr, e := runEngine(func(e *route.Engine) { e.GET("/test", h) }) defer e.Close() conn, err := net.Dial("tcp", addr) assert.Nil(t, err) defer conn.Close() _, err = conn.Write([]byte("GET /test HTTP/1.1\r\nHost: example.com\r\n\r\n")) assert.Nil(t, err) // Wait briefly for server to process and send response time.Sleep(50 * time.Millisecond) conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) b := make([]byte, 1024) n, err := conn.Read(b) assert.Nil(t, err) response := string(b[:n]) t.Logf("Response: %q", response) assert.Assert(t, strings.Contains(response, "HTTP/1.1 200 OK"), response) assert.Assert(t, strings.Contains(response, "X-Custom-Header: test-value"), response) assert.Assert(t, strings.Contains(response, "hijacked response body"), response) } func TestHertzHandler_FSEmbed(t *testing.T) { addr, e := runEngine(func(e *route.Engine) { h := HertzHandler(http.FileServer(http.FS(adaptorFiles))) e.GET("/*filepath", h) e.HEAD("/*filepath", h) }) defer e.Close() resp, err := http.Get("http://" + addr + "/handler_test.go") assert.Nil(t, err) expect := "hello, I'm handler_test.go" b, err := io.ReadAll(resp.Body) s := string(b) assert.Nil(t, err) assert.Assert(t, strings.Contains(s, expect), s) } func 
TestHertzHandler_Multipart(t *testing.T) { kvs := map[string]string{ "name": "Alice", "email": "alice@example.com", } h := HertzHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { for k, expectv := range kvs { v := r.FormValue(k) assert.Assert(t, v == expectv, v) } w.WriteHeader(204) w.Write([]byte("hello")) })) addr, e := runEngine(func(e *route.Engine) { opts := e.GetOptions() opts.StreamRequestBody = true opts.DisablePreParseMultipartForm = true e.POST("/test", func(ctx context.Context, rc *app.RequestContext) { _, err := rc.MultipartForm() // call rc.MultipartForm before HertzHandler assert.Nil(t, err) h(ctx, rc) }) }) defer e.Close() body, ct := createMultipartBody(kvs) req, err := http.NewRequest("POST", "http://"+addr+"/test", bytes.NewReader(body.Bytes())) assert.Nil(t, err) req.Header.Set("Content-Type", ct) client := &http.Client{} resp, err := client.Do(req) assert.Nil(t, err) assert.Assert(t, resp.StatusCode == 204, resp.StatusCode) resp.Body.Close() } func createMultipartBody(kvs map[string]string) (*bytes.Buffer, string) { buf := &bytes.Buffer{} w := multipart.NewWriter(buf) for k, v := range kvs { _ = w.WriteField(k, v) } _ = w.Close() return buf, w.FormDataContentType() } func TestNoopHijackWriter(t *testing.T) { writer := noopHijackWriter{} // Test Write method n, err := writer.Write([]byte("test")) assert.Assert(t, n == 0, n) assert.Assert(t, err == errConnHijacked, err) // Test Flush method err = writer.Flush() assert.Assert(t, err == errConnHijacked, err) // Test Finalize method err = writer.Finalize() assert.Nil(t, err) } func TestNoopWriter(t *testing.T) { writer := noopWriter{} // Test Write method testData := []byte("test data") n, err := writer.Write(testData) assert.Assert(t, n == len(testData), n) assert.Nil(t, err) // Test Flush method err = writer.Flush() assert.Nil(t, err) // Test Finalize method err = writer.Finalize() assert.Nil(t, err) } ================================================ FILE: 
pkg/common/adaptor/request.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adaptor import ( "bytes" "net/http" "github.com/cloudwego/hertz/pkg/protocol" ) // GetCompatRequest only support basic function of Request, not for all. // // Deprecated: use HertzHandler instead func GetCompatRequest(req *protocol.Request) (*http.Request, error) { r, err := http.NewRequest(string(req.Method()), req.URI().String(), bytes.NewReader(req.Body())) if err != nil { return r, err } h := make(map[string][]string) req.Header.VisitAll(func(k, v []byte) { h[string(k)] = append(h[string(k)], string(v)) }) r.Header = h return r, nil } // CopyToHertzRequest copy uri, host, method, protocol, header, but share body reader from http.Request to protocol.Request. 
func CopyToHertzRequest(req *http.Request, hreq *protocol.Request) error { hreq.Header.SetRequestURI(req.RequestURI) hreq.Header.SetHost(req.Host) hreq.Header.SetMethod(req.Method) hreq.Header.SetProtocol(req.Proto) for k, v := range req.Header { for _, vv := range v { hreq.Header.Add(k, vv) } } if req.Body != nil { hreq.SetBodyStream(req.Body, hreq.Header.ContentLength()) } return nil } ================================================ FILE: pkg/common/adaptor/request_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package adaptor import ( "context" "io/ioutil" "net" "net/http" "net/url" "path" "strings" "testing" "time" "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app/server" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func fullURL(ln net.Listener, p string) string { return "http://" + path.Join(ln.Addr().String(), p) } func TestCompatResponse_WriteHeader(t *testing.T) { var testHeader http.Header var testBody string testStatusCode := 299 testCookieValue := "cookie" testHeader = make(map[string][]string) testHeader["Key1"] = []string{"value1"} testHeader["Key2"] = []string{"value2", "value22"} testHeader["Key3"] = []string{"value3", "value33", "value333"} testHeader[consts.HeaderSetCookie] = []string{testCookieValue} testBody = "test body" ln := testutils.NewTestListener(t) defer ln.Close() h := server.New(server.WithListener(ln)) h.POST("/test1", func(c context.Context, ctx *app.RequestContext) { req, _ := GetCompatRequest(&ctx.Request) resp := GetCompatResponseWriter(&ctx.Response) handlerAndCheck(t, resp, req, testHeader, testBody, testStatusCode) }) h.POST("/test2", func(c context.Context, ctx *app.RequestContext) { req, _ := GetCompatRequest(&ctx.Request) resp := GetCompatResponseWriter(&ctx.Response) handlerAndCheck(t, resp, req, testHeader, testBody) }) go h.Spin() time.Sleep(100 * time.Millisecond) testUrl1 := fullURL(ln, "/test1") testUrl2 := fullURL(ln, "/test2") makeACall(t, http.MethodPost, testUrl1, testHeader, testBody, testStatusCode, []byte(testCookieValue)) makeACall(t, http.MethodPost, testUrl2, testHeader, testBody, consts.StatusOK, []byte(testCookieValue)) } func makeACall(t *testing.T, method, url string, header http.Header, body string, expectStatusCode int, expectCookieValue []byte) { client := http.Client{} req, _ := http.NewRequest(method, url, strings.NewReader(body)) 
req.Header = header resp, err := client.Do(req) if err != nil { t.Fatalf("make a call error: %s", err) } respHeader := resp.Header for k, v := range header { for i := 0; i < len(v); i++ { if respHeader[k][i] != v[i] { t.Fatalf("Header error: want %s=%s, got %s=%s", respHeader[k], respHeader[k][i], respHeader[k], v[i]) } } } b, err := ioutil.ReadAll(resp.Body) if err != nil { t.Fatalf("Read body error: %s", err) } assert.DeepEqual(t, body, string(b)) assert.DeepEqual(t, expectStatusCode, resp.StatusCode) // Parse out the cookie to verify it is correct cookie := protocol.Cookie{} _ = cookie.Parse(header[consts.HeaderSetCookie][0]) assert.DeepEqual(t, expectCookieValue, cookie.Value()) } // handlerAndCheck is designed to handle the program and check the header // // "..." is used in the type of statusCode, which is a syntactic sugar in Go. // In this way, the statusCode can be made an optional parameter, // and there is no need to pass in some meaningless numbers to judge some special cases. 
func handlerAndCheck(t *testing.T, writer http.ResponseWriter, request *http.Request, wantHeader http.Header, wantBody string, statusCode ...int) { reqHeader := request.Header for k, v := range wantHeader { if reqHeader[k] == nil { t.Fatalf("Header error: want %s=%s, got %s=nil", reqHeader[k], reqHeader[k][0], reqHeader[k]) } if reqHeader[k][0] != v[0] { t.Fatalf("Header error: want %s=%s, got %s=%s", reqHeader[k], reqHeader[k][0], reqHeader[k], v[0]) } } body, err := ioutil.ReadAll(request.Body) if err != nil { t.Fatalf("Read body error: %s", err) } assert.DeepEqual(t, wantBody, string(body)) respHeader := writer.Header() for k, v := range reqHeader { respHeader[k] = v } // When the incoming status code is nil, the execution of this code is skipped // and the status code is set to 200 if statusCode != nil { writer.WriteHeader(statusCode[0]) } _, err = writer.Write([]byte("test")) if err != nil { t.Fatalf("Write body error: %s", err) } _, err = writer.Write([]byte(" body")) if err != nil { t.Fatalf("Write body error: %s", err) } } func TestCopyToHertzRequest(t *testing.T) { req := http.Request{ Method: "GET", RequestURI: "/test", URL: &url.URL{ Scheme: "http", Host: "test.com", }, Proto: "HTTP/1.1", Header: http.Header{}, } req.Header.Set("key1", "value1") req.Header.Add("key2", "value2") req.Header.Add("key2", "value22") hertzReq := protocol.Request{} err := CopyToHertzRequest(&req, &hertzReq) assert.Nil(t, err) assert.DeepEqual(t, req.Method, string(hertzReq.Method())) assert.DeepEqual(t, req.RequestURI, string(hertzReq.Path())) assert.DeepEqual(t, req.Proto, hertzReq.Header.GetProtocol()) assert.DeepEqual(t, req.Header.Get("key1"), hertzReq.Header.Get("key1")) valueSlice := make([]string, 0, 2) hertzReq.Header.VisitAllCustomHeader(func(key, value []byte) { if strings.ToLower(string(key)) == "key2" { valueSlice = append(valueSlice, string(value)) } }) assert.DeepEqual(t, req.Header.Values("key2"), valueSlice) assert.DeepEqual(t, 3, hertzReq.Header.Len()) } 
================================================ FILE: pkg/common/adaptor/response.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package adaptor import ( "net/http" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) type compatResponse struct { h *protocol.Response header http.Header writeHeader bool } func (c *compatResponse) Header() http.Header { if c.header != nil { return c.header } c.header = make(map[string][]string) return c.header } func (c *compatResponse) Write(b []byte) (int, error) { if !c.writeHeader { c.WriteHeader(consts.StatusOK) } return c.h.BodyWriter().Write(b) } func (c *compatResponse) WriteHeader(statusCode int) { if !c.writeHeader { for k, v := range c.header { for _, vv := range v { if k == consts.HeaderContentLength { continue } if k == consts.HeaderSetCookie { cookie := protocol.AcquireCookie() _ = cookie.Parse(vv) c.h.Header.SetCookie(cookie) continue } c.h.Header.Add(k, vv) } } c.writeHeader = true } c.h.Header.SetStatusCode(statusCode) } // GetCompatResponseWriter only support basic function of ResponseWriter, not for all. 
// // Deprecated: use HertzHandler instead func GetCompatResponseWriter(resp *protocol.Response) http.ResponseWriter { c := &compatResponse{ h: resp, } c.h.Header.SetNoDefaultContentType(true) h := make(map[string][]string) tmpKey := make([][]byte, 0, c.h.Header.Len()) c.h.Header.VisitAll(func(k, v []byte) { h[string(k)] = append(h[string(k)], string(v)) tmpKey = append(tmpKey, k) }) for _, k := range tmpKey { c.h.Header.DelBytes(k) } c.header = h return c } ================================================ FILE: pkg/common/adaptor/utils.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package adaptor import ( "bytes" "fmt" "io" "strconv" "strings" "unsafe" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol/consts" ) // methodstr tries to return consts without allocation func methodstr(m []byte) string { if len(m) == 0 { return "GET" } switch m[0] { case 'G': if string(m) == consts.MethodGet { return consts.MethodGet } case 'P': switch string(m) { case consts.MethodPost: return consts.MethodPost case consts.MethodPut: return consts.MethodPut case consts.MethodPatch: return consts.MethodPatch } case 'H': if string(m) == consts.MethodHead { return consts.MethodHead } case 'D': if string(m) == consts.MethodDelete { return consts.MethodDelete } } return string(m) } type bytesRWCloser struct { bytes.Reader } // newBytesRWCloser adds noop Close method to bytes.Reader without allocation func newBytesRWCloser(b []byte) io.ReadCloser { rd := bytes.NewReader(b) return (*bytesRWCloser)(unsafe.Pointer(rd)) } // Close implements the [io.Closer] interface. 
func (bytesRWCloser) Close() error { return nil } func writer2writerExt(w network.Writer) network.ExtWriter { return extWriter{w} } type extWriter struct { network.Writer } func (w extWriter) Write(b []byte) (int, error) { buf, err := w.Writer.Malloc(len(b)) if err != nil { return 0, err } return copy(buf, b), nil } func (w extWriter) Finalize() error { return w.Flush() } func reader2closer(r io.Reader) io.ReadCloser { rc, ok := r.(io.ReadCloser) if ok { return rc } return io.NopCloser(r) } func parseHTTPVersion(s string) (major, minor int, _ error) { v := strings.TrimPrefix(s, "HTTP/") if len(v) == len(s) { return 1, 1, fmt.Errorf("invalid http version: %q", s) } switch v { case "1.0": return 1, 0, nil case "1.1": return 1, 1, nil default: a, b, ok := strings.Cut(v, ".") if ok { major, err1 := strconv.Atoi(a) minor, err2 := strconv.Atoi(b) if err1 == nil && err2 == nil { return major, minor, nil } } } return 1, 1, fmt.Errorf("invalid http version: %q", s) } ================================================ FILE: pkg/common/adaptor/utils_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package adaptor import ( "runtime" "testing" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestMethodStr(t *testing.T) { var m0, m1 runtime.MemStats assertEqual := func(got, expect string) { if got != expect { t.Helper() t.Fatal(got) } } runtime.ReadMemStats(&m0) for i := 0; i < 1000; i++ { assertEqual(methodstr([]byte(nil)), "GET") assertEqual(methodstr(bytestr.StrGet), "GET") assertEqual(methodstr(bytestr.StrHead), "HEAD") assertEqual(methodstr(bytestr.StrPost), "POST") assertEqual(methodstr(bytestr.StrPut), "PUT") assertEqual(methodstr(bytestr.StrDelete), "DELETE") assertEqual(methodstr(bytestr.StrPatch), "PATCH") } runtime.ReadMemStats(&m1) // should be zero, but in case of other background task running diff := m1.Mallocs - m0.Mallocs assert.Assert(t, diff < 50, diff) } func TestBytesRWCloser(t *testing.T) { // Test data testData := []byte("test bytesRWCloser") // Create a new bytesRWCloser rc := newBytesRWCloser(testData) // Read from the reader buf := make([]byte, len(testData)) n, err := rc.Read(buf) // Verify read was successful assert.Nil(t, err) assert.DeepEqual(t, n, len(testData)) assert.DeepEqual(t, buf, testData) // Test that Close returns nil err = rc.Close() assert.Nil(t, err) } func TestParseHTTPVersion(t *testing.T) { // Test HTTP/1.0 major, minor, err := parseHTTPVersion("HTTP/1.0") assert.Nil(t, err) assert.DeepEqual(t, major, 1) assert.DeepEqual(t, minor, 0) // Test HTTP/1.1 major, minor, err = parseHTTPVersion("HTTP/1.1") assert.Nil(t, err) assert.DeepEqual(t, major, 1) assert.DeepEqual(t, minor, 1) // Test HTTP/2.0 major, minor, err = parseHTTPVersion("HTTP/2.0") assert.Nil(t, err) assert.DeepEqual(t, major, 2) assert.DeepEqual(t, minor, 0) // Test HTTP/3.1 major, minor, err = parseHTTPVersion("HTTP/3.1") assert.Nil(t, err) assert.DeepEqual(t, major, 3) assert.DeepEqual(t, minor, 1) // Test missing HTTP prefix major, minor, err = parseHTTPVersion("1.1") assert.NotNil(t, err) 
assert.DeepEqual(t, major, 1) assert.DeepEqual(t, minor, 1) // Test missing dot separator major, minor, err = parseHTTPVersion("HTTP/11") assert.NotNil(t, err) assert.DeepEqual(t, major, 1) assert.DeepEqual(t, minor, 1) // Test empty string major, minor, err = parseHTTPVersion("") assert.NotNil(t, err) assert.DeepEqual(t, major, 1) assert.DeepEqual(t, minor, 1) } ================================================ FILE: pkg/common/bytebufferpool/bytebuffer.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. 
*
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package bytebufferpool

import "io"

// ByteBuffer provides byte buffer, which can be used for minimizing
// memory allocations.
//
// ByteBuffer may be used with functions appending data to the given []byte
// slice. See example code for details.
//
// Use Get for obtaining an empty byte buffer.
type ByteBuffer struct {
	// B is a byte buffer to use in append-like workloads.
	// See example code for details.
	B []byte
}

// Len returns the size of the byte buffer.
func (b *ByteBuffer) Len() int {
	return len(b.B)
}

// Cap returns the capacity of the byte buffer, i.e. cap(b.B).
func (b *ByteBuffer) Cap() int {
	return cap(b.B)
}

// ReadFrom implements io.ReaderFrom.
//
// The function appends all the data read from r to b.
// The returned count covers only the newly appended bytes, not the content
// b already held. io.EOF from r ends the read and is reported as a nil error;
// any other error is returned together with the bytes read so far, which
// remain appended to b.
func (b *ByteBuffer) ReadFrom(r io.Reader) (int64, error) {
	p := b.B
	nStart := int64(len(p))
	nMax := int64(cap(p))
	n := nStart
	if nMax == 0 {
		// Zero-capacity buffer: start with 64 bytes of room.
		nMax = 64
		p = make([]byte, nMax)
	} else {
		// Read straight into the spare capacity after the existing content.
		p = p[:nMax]
	}
	for {
		if n == nMax {
			// Buffer full: double the capacity and carry over everything
			// (old content plus data read so far).
			nMax *= 2
			bNew := make([]byte, nMax)
			copy(bNew, p)
			p = bNew
		}
		nn, err := r.Read(p[n:])
		n += int64(nn)
		if err != nil {
			// Keep whatever was read, then report only the appended length.
			b.B = p[:n]
			n -= nStart
			if err == io.EOF {
				return n, nil
			}
			return n, err
		}
	}
}

// WriteTo implements io.WriterTo.
//
// It writes b.B to w in a single Write call and returns that call's result.
func (b *ByteBuffer) WriteTo(w io.Writer) (int64, error) {
	n, err := w.Write(b.B)
	return int64(n), err
}

// Bytes returns b.B, i.e. all the bytes accumulated in the buffer.
//
// The purpose of this function is bytes.Buffer compatibility.
func (b *ByteBuffer) Bytes() []byte { return b.B } // Write implements io.Writer - it appends p to ByteBuffer.B func (b *ByteBuffer) Write(p []byte) (int, error) { b.B = append(b.B, p...) return len(p), nil } // WriteByte appends the byte c to the buffer. // // The purpose of this function is bytes.Buffer compatibility. // // The function always returns nil. func (b *ByteBuffer) WriteByte(c byte) error { b.B = append(b.B, c) return nil } // WriteString appends s to ByteBuffer.B. func (b *ByteBuffer) WriteString(s string) (int, error) { b.B = append(b.B, s...) return len(s), nil } // Set sets ByteBuffer.B to p. func (b *ByteBuffer) Set(p []byte) { b.B = append(b.B[:0], p...) } // SetString sets ByteBuffer.B to s. func (b *ByteBuffer) SetString(s string) { b.B = append(b.B[:0], s...) } // String returns string representation of ByteBuffer.B. func (b *ByteBuffer) String() string { return string(b.B) } // Reset makes ByteBuffer.B empty. func (b *ByteBuffer) Reset() { b.B = b.B[:0] } ================================================ FILE: pkg/common/bytebufferpool/bytebuffer_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package bytebufferpool import ( "bytes" "fmt" "io" "testing" "time" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestByteBufferReadFrom(t *testing.T) { prefix := "foobar" expectedS := "asadfsdafsadfasdfisdsdfa" prefixLen := int64(len(prefix)) expectedN := int64(len(expectedS)) var bb ByteBuffer bb.WriteString(prefix) rf := (io.ReaderFrom)(&bb) for i := 0; i < 20; i++ { r := bytes.NewBufferString(expectedS) n, err := rf.ReadFrom(r) if n != expectedN { t.Fatalf("unexpected n=%d. Expecting %d. 
iteration %d", n, expectedN, i) } if err != nil { t.Fatalf("unexpected error: %s", err) } bbLen := int64(bb.Len()) expectedLen := prefixLen + int64(i+1)*expectedN if bbLen != expectedLen { t.Fatalf("unexpected byteBuffer length: %d. Expecting %d", bbLen, expectedLen) } assert.True(t, bb.Cap() >= int(expectedLen)) for j := 0; j < i; j++ { start := prefixLen + int64(j)*expectedN b := bb.B[start : start+expectedN] if string(b) != expectedS { t.Fatalf("unexpected byteBuffer contents: %q. Expecting %q", b, expectedS) } } } } func TestByteBufferWriteTo(t *testing.T) { expectedS := "foobarbaz" var bb ByteBuffer bb.WriteString(expectedS[:3]) bb.WriteString(expectedS[3:]) wt := (io.WriterTo)(&bb) var w bytes.Buffer for i := 0; i < 10; i++ { n, err := wt.WriteTo(&w) if n != int64(len(expectedS)) { t.Fatalf("unexpected n returned from WriteTo: %d. Expecting %d", n, len(expectedS)) } if err != nil { t.Fatalf("unexpected error: %s", err) } s := w.String() if s != expectedS { t.Fatalf("unexpected string written %q. Expecting %q", s, expectedS) } w.Reset() assert.True(t, bb.Cap() >= len(expectedS)) } } func TestByteBufferGetPutSerial(t *testing.T) { testByteBufferGetPut(t) } func TestByteBufferGetPutConcurrent(t *testing.T) { concurrency := 10 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { testByteBufferGetPut(t) ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout!") } } } func testByteBufferGetPut(t *testing.T) { for i := 0; i < 10; i++ { expectedS := fmt.Sprintf("num %d", i) b := Get() b.B = append(b.B, "num "...) b.B = append(b.B, fmt.Sprintf("%d", i)...) if string(b.B) != expectedS { t.Fatalf("unexpected result: %q. Expecting %q", b.B, expectedS) } Put(b) } } func testByteBufferGetString(t *testing.T) { for i := 0; i < 10; i++ { expectedS := fmt.Sprintf("num %d", i) b := Get() b.SetString(expectedS) if b.String() != expectedS { t.Fatalf("unexpected result: %q. 
Expecting %q", b.B, expectedS) } Put(b) } } func TestByteBufferGetStringSerial(t *testing.T) { testByteBufferGetString(t) } func TestByteBufferGetStringConcurrent(t *testing.T) { concurrency := 10 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { go func() { testByteBufferGetString(t) ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-ch: case <-time.After(time.Second): t.Fatalf("timeout!") } } } ================================================ FILE: pkg/common/bytebufferpool/doc.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // The files in bytebufferpool package are forked from fasthttp[github.com/valyala/fasthttp], // and we keep the original Copyright[Copyright 2015 fasthttp authors] and License of fasthttp for those files. // We also need to modify as we need, the modifications are Copyright of 2022 CloudWeGo Authors. // Thanks for fasthttp authors! Below is the source code information: // Repo: github.com/valyala/fasthttp // Forked Version: v1.36.0 package bytebufferpool ================================================ FILE: pkg/common/bytebufferpool/pool.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package bytebufferpool import ( "sort" "sync" "sync/atomic" ) const ( minBitSize = 6 // 2**6=64 is a CPU cache line size steps = 20 minSize = 1 << minBitSize maxSize = 1 << (minBitSize + steps - 1) calibrateCallsThreshold = 42000 maxPercentile = 0.95 ) // Pool represents byte buffer pool. // // Distinct pools may be used for distinct types of byte buffers. // Properly determined byte buffer types with their own pools may help reducing // memory waste. type Pool struct { calls [steps]uint64 calibrating uint64 defaultSize uint64 maxSize uint64 pool sync.Pool } var defaultPool Pool // Get returns an empty byte buffer from the pool. // // Got byte buffer may be returned to the pool via Put call. // This reduces the number of memory allocations required for byte buffer // management. func Get() *ByteBuffer { return defaultPool.Get() } // Get returns new byte buffer with zero length. // // The byte buffer may be returned to the pool via Put after the use // in order to minimize GC overhead. func (p *Pool) Get() *ByteBuffer { v := p.pool.Get() if v != nil { return v.(*ByteBuffer) } return &ByteBuffer{ B: make([]byte, 0, atomic.LoadUint64(&p.defaultSize)), } } // Put returns byte buffer to the pool. // // ByteBuffer.B mustn't be touched after returning it to the pool. // Otherwise data races will occur. func Put(b *ByteBuffer) { defaultPool.Put(b) } // Put releases byte buffer obtained via Get to the pool. // // The buffer mustn't be accessed after returning to the pool. 
func (p *Pool) Put(b *ByteBuffer) {
	// Count one more release in b's size class; once a class exceeds
	// calibrateCallsThreshold releases, recompute the pool's sizes.
	idx := index(len(b.B))

	if atomic.AddUint64(&p.calls[idx], 1) > calibrateCallsThreshold {
		p.calibrate()
	}

	// This local shadows the package-level maxSize constant. It is 0 until the
	// first calibration; afterwards, buffers whose capacity exceeds it are
	// dropped instead of being pooled.
	maxSize := int(atomic.LoadUint64(&p.maxSize))
	if maxSize == 0 || cap(b.B) <= maxSize {
		b.Reset()
		p.pool.Put(b)
	}
}

// calibrate recomputes defaultSize and maxSize from the per-size-class
// release counters, then resets those counters to zero.
//
// defaultSize becomes the size of the most frequently released class.
// maxSize becomes the largest size among the most frequent classes that
// together account for roughly maxPercentile of all releases.
func (p *Pool) calibrate() {
	// Only one goroutine calibrates at a time; concurrent callers return.
	if !atomic.CompareAndSwapUint64(&p.calibrating, 0, 1) {
		return
	}

	a := make(callSizes, 0, steps)
	var callsSum uint64
	for i := uint64(0); i < steps; i++ {
		calls := atomic.SwapUint64(&p.calls[i], 0)
		callsSum += calls
		a = append(a, callSize{
			calls: calls,
			size:  minSize << i,
		})
	}
	// Sort classes by call count, most frequent first (see callSizes.Less).
	sort.Sort(a)

	defaultSize := a[0].size
	maxSize := defaultSize

	maxSum := uint64(float64(callsSum) * maxPercentile)
	callsSum = 0
	for i := 0; i < steps; i++ {
		if callsSum > maxSum {
			break
		}
		callsSum += a[i].calls
		size := a[i].size
		if size > maxSize {
			maxSize = size
		}
	}

	atomic.StoreUint64(&p.defaultSize, defaultSize)
	atomic.StoreUint64(&p.maxSize, maxSize)

	atomic.StoreUint64(&p.calibrating, 0)
}

// callSize pairs a size class with the number of releases counted for it.
type callSize struct {
	calls uint64
	size  uint64
}

// callSizes implements sort.Interface, ordering by calls in descending order.
type callSizes []callSize

func (ci callSizes) Len() int {
	return len(ci)
}

func (ci callSizes) Less(i, j int) bool {
	return ci[i].calls > ci[j].calls
}

func (ci callSizes) Swap(i, j int) {
	ci[i], ci[j] = ci[j], ci[i]
}

// index maps a buffer length n to its size class. Class 0 covers
// n <= minSize, and each following class covers the next doubling. The
// result is capped at steps-1.
func index(n int) int {
	n--
	n >>= minBitSize
	idx := 0
	for n > 0 {
		n >>= 1
		idx++
	}
	if idx >= steps {
		idx = steps - 1
	}
	return idx
}

================================================
FILE: pkg/common/bytebufferpool/pool_test.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package bytebufferpool import ( "math/rand" "testing" "time" ) func TestIndex(t *testing.T) { testIndex(t, 0, 0) testIndex(t, 1, 0) testIndex(t, minSize-1, 0) testIndex(t, minSize, 0) testIndex(t, minSize+1, 1) testIndex(t, 2*minSize-1, 1) testIndex(t, 2*minSize, 1) testIndex(t, 2*minSize+1, 2) testIndex(t, maxSize-1, steps-1) testIndex(t, maxSize, steps-1) testIndex(t, maxSize+1, steps-1) } func testIndex(t *testing.T, n, expectedIdx int) { idx := index(n) if idx != expectedIdx { t.Fatalf("unexpected idx for n=%d: %d. 
Expecting %d", n, idx, expectedIdx) } } func TestPoolCalibrate(t *testing.T) { for i := 0; i < steps*calibrateCallsThreshold; i++ { n := 1004 if i%15 == 0 { n = rand.Intn(15234) } testGetPut(t, n) } } func TestPoolVariousSizesSerial(t *testing.T) { testPoolVariousSizes(t) } func TestPoolVariousSizesConcurrent(t *testing.T) { concurrency := 5 ch := make(chan struct{}) for i := 0; i < concurrency; i++ { go func() { testPoolVariousSizes(t) ch <- struct{}{} }() } for i := 0; i < concurrency; i++ { select { case <-ch: case <-time.After(10 * time.Second): t.Fatalf("timeout") } } } func testPoolVariousSizes(t *testing.T) { for i := 0; i < steps+1; i++ { n := 1 << uint32(i) testGetPut(t, n) testGetPut(t, n+1) testGetPut(t, n-1) for j := 0; j < 10; j++ { testGetPut(t, j+n) } } } func testGetPut(t *testing.T, n int) { bb := Get() if len(bb.B) > 0 { t.Fatalf("non-empty byte buffer returned from acquire") } bb.B = allocNBytes(bb.B, n) Put(bb) } func allocNBytes(dst []byte, n int) []byte { diff := n - cap(dst) if diff <= 0 { return dst[:n] } return append(dst, make([]byte, diff)...) } ================================================ FILE: pkg/common/compress/compress.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package compress import ( "bytes" "compress/gzip" "fmt" "io" "sync" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" "github.com/cloudwego/hertz/pkg/common/stackless" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" ) const CompressDefaultCompression = 6 // flate.DefaultCompression var gzipReaderPool sync.Pool var ( stacklessGzipWriterPoolMap = newCompressWriterPoolMap() realGzipWriterPoolMap = newCompressWriterPoolMap() ) func newCompressWriterPoolMap() []*sync.Pool { // Initialize pools for all the compression levels defined // in https://golang.org/pkg/compress/flate/#pkg-constants . 
// Compression levels are normalized with normalizeCompressLevel, // so the fit [0..11]. var m []*sync.Pool for i := 0; i < 12; i++ { m = append(m, &sync.Pool{}) } return m } type compressCtx struct { w io.Writer p []byte level int } // AppendGunzipBytes appends gunzipped src to dst and returns the resulting dst. func AppendGunzipBytes(dst, src []byte) ([]byte, error) { w := &byteSliceWriter{dst} _, err := WriteGunzip(w, src) return w.b, err } type byteSliceWriter struct { b []byte } func (w *byteSliceWriter) Write(p []byte) (int, error) { w.b = append(w.b, p...) return len(p), nil } // WriteGunzip writes gunzipped p to w and returns the number of uncompressed // bytes written to w. func WriteGunzip(w io.Writer, p []byte) (int, error) { r := &byteSliceReader{p} zr, err := AcquireGzipReader(r) if err != nil { return 0, err } zw := network.NewWriter(w) n, err := utils.CopyZeroAlloc(zw, zr) ReleaseGzipReader(zr) nn := int(n) if int64(nn) != n { return 0, fmt.Errorf("too much data gunzipped: %d", n) } return nn, err } type byteSliceReader struct { b []byte } func (r *byteSliceReader) Read(p []byte) (int, error) { if len(r.b) == 0 { return 0, io.EOF } n := copy(p, r.b) r.b = r.b[n:] return n, nil } func AcquireGzipReader(r io.Reader) (*gzip.Reader, error) { v := gzipReaderPool.Get() if v == nil { return gzip.NewReader(r) } zr := v.(*gzip.Reader) if err := zr.Reset(r); err != nil { return nil, err } return zr, nil } func ReleaseGzipReader(zr *gzip.Reader) { zr.Close() gzipReaderPool.Put(zr) } // AppendGzipBytes appends gzipped src to dst and returns the resulting dst. func AppendGzipBytes(dst, src []byte) []byte { return AppendGzipBytesLevel(dst, src, CompressDefaultCompression) } // AppendGzipBytesLevel appends gzipped src to dst using the given // compression level and returns the resulting dst. 
// // Supported compression levels are: // // - CompressNoCompression // - CompressBestSpeed // - CompressBestCompression // - CompressDefaultCompression // - CompressHuffmanOnly func AppendGzipBytesLevel(dst, src []byte, level int) []byte { w := &byteSliceWriter{dst} WriteGzipLevel(w, src, level) //nolint:errcheck return w.b } var stacklessWriteGzip = stackless.NewFunc(nonblockingWriteGzip) func nonblockingWriteGzip(ctxv interface{}) { ctx := ctxv.(*compressCtx) zw := acquireRealGzipWriter(ctx.w, ctx.level) _, err := zw.Write(ctx.p) if err != nil { panic(fmt.Sprintf("BUG: gzip.Writer.Write for len(p)=%d returned unexpected error: %s", len(ctx.p), err)) } releaseRealGzipWriter(zw, ctx.level) } func releaseRealGzipWriter(zw *gzip.Writer, level int) { zw.Close() nLevel := normalizeCompressLevel(level) p := realGzipWriterPoolMap[nLevel] p.Put(zw) } func acquireRealGzipWriter(w io.Writer, level int) *gzip.Writer { nLevel := normalizeCompressLevel(level) p := realGzipWriterPoolMap[nLevel] v := p.Get() if v == nil { zw, err := gzip.NewWriterLevel(w, level) if err != nil { panic(fmt.Sprintf("BUG: unexpected error from gzip.NewWriterLevel(%d): %s", level, err)) } return zw } zw := v.(*gzip.Writer) zw.Reset(w) return zw } // normalizes compression level into [0..11], so it could be used as an index // in *PoolMap. func normalizeCompressLevel(level int) int { // -2 is the lowest compression level - CompressHuffmanOnly // 9 is the highest compression level - CompressBestCompression if level < -2 || level > 9 { level = CompressDefaultCompression } return level + 2 } // WriteGzipLevel writes gzipped p to w using the given compression level // and returns the number of compressed bytes written to w. 
// // Supported compression levels are: // // - CompressNoCompression // - CompressBestSpeed // - CompressBestCompression // - CompressDefaultCompression // - CompressHuffmanOnly func WriteGzipLevel(w io.Writer, p []byte, level int) (int, error) { switch w.(type) { case *byteSliceWriter, *bytes.Buffer, *bytebufferpool.ByteBuffer: // These writers don't block, so we can just use stacklessWriteGzip ctx := &compressCtx{ w: w, p: p, level: level, } stacklessWriteGzip(ctx) return len(p), nil default: zw := AcquireStacklessGzipWriter(w, level) n, err := zw.Write(p) ReleaseStacklessGzipWriter(zw, level) return n, err } } func AcquireStacklessGzipWriter(w io.Writer, level int) stackless.Writer { nLevel := normalizeCompressLevel(level) p := stacklessGzipWriterPoolMap[nLevel] v := p.Get() if v == nil { return stackless.NewWriter(w, func(w io.Writer) stackless.Writer { return acquireRealGzipWriter(w, level) }) } sw := v.(stackless.Writer) sw.Reset(w) return sw } func ReleaseStacklessGzipWriter(sw stackless.Writer, level int) { sw.Close() nLevel := normalizeCompressLevel(level) p := stacklessGzipWriterPoolMap[nLevel] p.Put(sw) } ================================================ FILE: pkg/common/compress/compress_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package compress import ( "io" "testing" ) func TestCompressNewCompressWriterPoolMap(t *testing.T) { pool := newCompressWriterPoolMap() if len(pool) != 12 { t.Fatalf("Unexpected number for WriterPoolMap: %d. Expecting 12", len(pool)) } } func TestCompressAppendGunzipBytes(t *testing.T) { dst1 := []byte("") // src unzip -> "hello". The src must the string that has been gunzipped. 
src1 := []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 202, 72, 205, 201, 201, 7, 0, 0, 0, 255, 255} expectedRes1 := "hello" res1, err1 := AppendGunzipBytes(dst1, src1) // gzip will wrap io.EOF to io.ErrUnexpectedEOF // just ignore in this case if err1 != io.ErrUnexpectedEOF { t.Fatalf("Unexpected error: %s", err1) } if string(res1) != expectedRes1 { t.Fatalf("Unexpected : %s. Expecting : %s", res1, expectedRes1) } dst2 := []byte("!!!") src2 := []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 202, 72, 205, 201, 201, 7, 0, 0, 0, 255, 255} expectedRes2 := "!!!hello" res2, err2 := AppendGunzipBytes(dst2, src2) if err2 != io.ErrUnexpectedEOF { t.Fatalf("Unexpected error: %s", err2) } if string(res2) != expectedRes2 { t.Fatalf("Unexpected : %s. Expecting : %s", res2, expectedRes2) } dst3 := []byte("!!!") src3 := []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 0, 0, 0, 255, 255} expectedRes3 := "!!!" res3, err3 := AppendGunzipBytes(dst3, src3) if err3 != io.ErrUnexpectedEOF { t.Fatalf("Unexpected error: %s", err3) } if string(res3) != expectedRes3 { t.Fatalf("Unexpected : %s. Expecting : %s", res3, expectedRes3) } } func TestCompressAppendGzipBytesLevel(t *testing.T) { // test the byteSliceWriter case for WriteGzipLevel dst1 := []byte("") src1 := []byte("hello") res1 := AppendGzipBytesLevel(dst1, src1, 5) expectedRes1 := []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 202, 72, 205, 201, 201, 7, 4, 0, 0, 255, 255, 134, 166, 16, 54, 5, 0, 0, 0} if string(res1) != string(expectedRes1) { t.Fatalf("Unexpected : %s. Expecting : %s", res1, expectedRes1) } } func TestCompressWriteGzipLevel(t *testing.T) { // test default case for WriteGzipLevel var w defaultByteWriter p := []byte("hello") expectedW := []byte{31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 202, 72, 205, 201, 201, 7, 4, 0, 0, 255, 255, 134, 166, 16, 54, 5, 0, 0, 0} num, err := WriteGzipLevel(&w, p, 5) if string(expectedW) != string(w.b) { t.Fatalf("Unexpected : %s. 
Expecting: %s.", w.b, expectedW) } if num != len(p) { t.Fatalf("Unexpected number of compressed bytes: %d", num) } if err != nil { t.Fatalf("Unexpected error: %s", err) } } type defaultByteWriter struct { b []byte } func (w *defaultByteWriter) Write(p []byte) (int, error) { w.b = append(w.b, p...) return len(p), nil } ================================================ FILE: pkg/common/compress/doc.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // The files in compress package are forked from fasthttp[github.com/valyala/fasthttp], // and we keep the original Copyright[Copyright 2015 fasthttp authors] and License of fasthttp for those files. // We also need to modify as we need, the modifications are Copyright of 2022 CloudWeGo Authors. // Thanks for fasthttp authors! Below is the source code information: // Repo: github.com/valyala/fasthttp // Forked Version: v1.36.0 package compress ================================================ FILE: pkg/common/config/client_option.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package config

import (
	"crypto/tls"
	"time"

	"github.com/cloudwego/hertz/pkg/app/client/retry"
	"github.com/cloudwego/hertz/pkg/network"
	"github.com/cloudwego/hertz/pkg/protocol/consts"
)

// ConnPoolState describes the connection pool of a single host client, as
// reported by HostClientState.ConnPoolState.
type ConnPoolState struct {
	// Number of connections in the pool. These connections are idle.
	PoolConnNum int
	// Total number of connections.
	TotalConnNum int
	// Number of pending connections.
	WaitConnNum int
	// Address of the host client.
	Addr string
	// Maximum number of connections, <= 0 means no limit.
	MaxConns int
}

// HostClientState exposes the connection pool state of a host client.
type HostClientState interface {
	ConnPoolState() ConnPoolState
}

// HostClientStateFunc is a callback that receives a HostClientState
// (see ClientOptions.HostClientStateObserve).
type HostClientStateFunc func(HostClientState)

// ClientOption is the only struct that can be used to set ClientOptions.
// F mutates the options in place when applied.
type ClientOption struct {
	F func(o *ClientOptions)
}

// ClientOptions holds client configuration. Use NewClientOptions to obtain an
// instance with defaults filled in.
type ClientOptions struct {
	// Timeout for establishing a connection to the server.
	// Defaults to consts.DefaultDialTimeout.
	DialTimeout time.Duration
	// Maximum number of connections for each host, <= 0 means no limit.
	MaxConnsPerHost int

	// Presumably the maximum time an idle connection is kept before being
	// closed — confirm in the host client. Defaults to
	// consts.DefaultMaxIdleConnDuration.
	MaxIdleConnDuration time.Duration

	// Presumably the maximum lifetime of a connection — confirm in the host
	// client. Left at zero by NewClientOptions.
	MaxConnDuration time.Duration

	// Presumably how long a request may wait for a free connection once
	// MaxConnsPerHost is reached — confirm in the host client.
	MaxConnWaitTimeout time.Duration

	// Presumably enables connection reuse (keep-alive) — confirm.
	// Defaults to true.
	KeepAlive bool

	// Read timeout used by the client; presumably bounds reading a
	// response — confirm.
	ReadTimeout time.Duration

	// TLS configuration for client connections.
	TLSConfig *tls.Config

	// Presumably, when true, the response body is streamed instead of
	// being fully read into memory — confirm.
	ResponseBodyStream bool

	// Client name. Used in User-Agent request header.
	//
	// Default client name is used if not set.
	Name string

	// NoDefaultUserAgentHeader when set to true, causes the default
	// User-Agent header to be excluded from the Request.
	NoDefaultUserAgentHeader bool

	// Dialer is the custom dialer used to establish connection.
	// Default Dialer is used if not set.
	Dialer network.Dialer

	// Attempt to connect to both ipv4 and ipv6 addresses if set to true.
	//
	// This option is used only if default TCP dialer is used,
	// i.e. if Dialer is blank.
	//
	// By default client connects only to ipv4 addresses,
	// since unfortunately ipv6 remains broken in many networks worldwide :)
	DialDualStack bool

	// Maximum duration for full request writing (including body).
	//
	// By default request write timeout is unlimited.
	WriteTimeout time.Duration

	// Maximum response body size.
	//
	// The client returns ErrBodyTooLarge if this limit is greater than 0
	// and response body is greater than the limit.
	//
	// By default response body size is unlimited.
	MaxResponseBodySize int

	// Header names are passed as-is without normalization
	// if this option is set.
	//
	// Disabled header names' normalization may be useful only for proxying
	// responses to other clients expecting case-sensitive header names.
	//
	// By default request and response header names are normalized, i.e.
	// The first letter and the first letters following dashes
	// are uppercased, while all the other letters are lowercased.
	// Examples:
	//
	//     * HOST -> Host
	//     * content-type -> Content-Type
	//     * cONTENT-lenGTH -> Content-Length
	DisableHeaderNamesNormalizing bool

	// Path values are sent as-is without normalization
	//
	// Disabled path normalization may be useful for proxying incoming requests
	// to servers that are expecting paths to be forwarded as-is.
	//
	// By default path values are normalized, i.e.
	// extra slashes are removed, special characters are encoded.
	DisablePathNormalizing bool

	// All configuration related to retries.
	RetryConfig *retry.Config

	// Callback that receives host client state; presumably invoked every
	// ObservationInterval — confirm where it is scheduled.
	HostClientStateObserve HostClientStateFunc

	// Interval between HostClientStateObserve executions.
	// Defaults to 5 seconds.
	ObservationInterval time.Duration

	// Callback hook for re-configuring host client
	// If an error is returned, the request will be terminated.
	HostClientConfigHook func(hc interface{}) error
}

// NewClientOptions returns ClientOptions with default values
// (DialTimeout, MaxIdleConnDuration, KeepAlive=true, ObservationInterval=5s)
// and then applies opts in order.
func NewClientOptions(opts []ClientOption) *ClientOptions {
	options := &ClientOptions{
		DialTimeout:         consts.DefaultDialTimeout,
		MaxIdleConnDuration: consts.DefaultMaxIdleConnDuration,
		KeepAlive:           true,
		ObservationInterval: time.Second * 5,
	}

	options.Apply(opts)

	return options
}

// Apply runs each option's F against o in order; later options override
// values set by earlier ones.
func (o *ClientOptions) Apply(opts []ClientOption) {
	for _, op := range opts {
		op.F(o)
	}
}

================================================
FILE: pkg/common/config/client_option_test.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package config import ( "testing" "time" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol/consts" ) // TestDefaultClientOptions test client options with default values func TestDefaultClientOptions(t *testing.T) { options := NewClientOptions([]ClientOption{}) assert.DeepEqual(t, consts.DefaultDialTimeout, options.DialTimeout) assert.DeepEqual(t, 0, options.MaxConnsPerHost) assert.DeepEqual(t, consts.DefaultMaxIdleConnDuration, options.MaxIdleConnDuration) assert.DeepEqual(t, true, options.KeepAlive) } // TestCustomClientOptions test client options with custom values func TestCustomClientOptions(t *testing.T) { options := NewClientOptions([]ClientOption{}) options.Apply([]ClientOption{ { F: func(o *ClientOptions) { o.DialTimeout = 2 * time.Second }, }, }) assert.DeepEqual(t, 2*time.Second, options.DialTimeout) } ================================================ FILE: pkg/common/config/option.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package config import ( "context" "crypto/tls" "net" "time" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/network" ) // Option is the only struct that can be used to set Options. 
type Option struct {
	F func(o *Options)
}

// Default values used by NewOptions.
const (
	defaultKeepAliveTimeout   = 1 * time.Minute
	defaultReadTimeout        = 3 * time.Minute
	defaultAddr               = ":8888"
	defaultNetwork            = "tcp"
	defaultBasePath           = "/"
	defaultMaxRequestBodySize = 4 << 20 // 4MB
	defaultMaxHeaderBytes     = 1 << 20 // 1MB
	defaultWaitExitTimeout    = time.Second * 5
	defaultReadBufferSize     = 4 * 1024
)

// Options holds the server configuration. Use NewOptions to get an
// instance with defaults; the comments inside NewOptions describe each
// default.
type Options struct {
	KeepAliveTimeout             time.Duration
	ReadTimeout                  time.Duration
	WriteTimeout                 time.Duration
	IdleTimeout                  time.Duration
	RedirectTrailingSlash        bool
	MaxRequestBodySize           int
	MaxHeaderBytes               int
	MaxKeepBodySize              int
	GetOnly                      bool
	DisableKeepalive             bool
	RedirectFixedPath            bool
	HandleMethodNotAllowed       bool
	UseRawPath                   bool
	RemoveExtraSlash             bool
	UnescapePathValues           bool
	DisablePreParseMultipartForm bool
	NoDefaultDate                bool
	NoDefaultContentType         bool
	StreamRequestBody            bool
	NoDefaultServerHeader        bool
	DisablePrintRoute            bool
	SenseClientDisconnection     bool
	Network                      string
	Addr                         string
	BasePath                     string
	ExitWaitTimeout              time.Duration
	TLS                          *tls.Config
	H2C                          bool
	ReadBufferSize               int
	ALPN                         bool
	Tracers                      []interface{}
	TraceLevel                   interface{}
	Listener                     net.Listener
	ListenConfig                 *net.ListenConfig
	BindConfig                   interface{}
	CustomBinder                 interface{}
	CustomValidator              interface{}
	// Deprecated: Use CustomValidator with a ValidatorFunc instead
	ValidateConfig interface{}

	// TransporterNewer is the function to create a transporter.
	TransporterNewer func(opt *Options) network.Transporter
	// AltTransporterNewer is a second transporter constructor with the same
	// signature; NOTE(review): presumably used for an alternative transport
	// alongside TransporterNewer — confirm at the call site.
	AltTransporterNewer func(opt *Options) network.Transporter

	// In netpoll library, OnAccept is called after connection accepted
	// but before adding it to epoll. OnConnect is called after adding it to epoll.
	// The difference is that onConnect can get data but OnAccept cannot.
	// If you'd like to check whether the peer IP is in the blacklist, you can use OnAccept.
	// In go net, OnAccept is executed after connection accepted but before establishing
	// tls connection. OnConnect is executed after establishing tls connection.
	OnAccept  func(conn net.Conn) context.Context
	OnConnect func(ctx context.Context, conn network.Conn) context.Context

	// Registry is used for service registry.
	Registry registry.Registry
	// RegistryInfo is base info used for service registry.
	RegistryInfo *registry.Info

	// Enables the automatic HTML template reloading mechanism.
	AutoReloadRender bool
	// If AutoReloadInterval is set to 0 (default), the HTML template is
	// reloaded on file change events; otherwise it is reloaded every
	// AutoReloadInterval.
	AutoReloadInterval time.Duration

	// Header names are passed as-is without normalization
	// if this option is set.
	//
	// Disabled header names' normalization may be useful only for proxying
	// responses to other clients expecting case-sensitive header names.
	//
	// By default, request and response header names are normalized, i.e.
	// The first letter and the first letters following dashes
	// are uppercased, while all the other letters are lowercased.
	// Examples:
	//
	//     * HOST -> Host
	//     * content-type -> Content-Type
	//     * cONTENT-lenGTH -> Content-Length
	DisableHeaderNamesNormalizing bool
}

// Apply runs each option's F against o in order; later options override
// values set by earlier ones.
func (o *Options) Apply(opts []Option) {
	for _, op := range opts {
		op.F(o)
	}
}

// NewOptions returns Options populated with the defaults below and then
// applies opts in order.
func NewOptions(opts []Option) *Options {
	options := &Options{
		// Keep-alive timeout. When idle connection exceeds this time,
		// server will send keep-alive packets to ensure it's a validated
		// connection.
		//
		// NOTE: Usually there is no need to care about this value, just
		// care about IdleTimeout.
		KeepAliveTimeout: defaultKeepAliveTimeout,

		// the timeout of reading from low-level library
		ReadTimeout: defaultReadTimeout,

		// When there is no request during the idleTimeout, the connection
		// will be closed by server.
		// Defaults to the same value as ReadTimeout (defaultReadTimeout).
		// Zero means no timeout.
		IdleTimeout: defaultReadTimeout,

		// Enables automatic redirection if the current route can't be matched but a
		// handler for the path with (without) the trailing slash exists.
		// For example if /foo/ is requested but a route only exists for /foo, the
		// client is redirected to /foo with http status code 301 for GET requests
		// and 308 for all other request methods.
		RedirectTrailingSlash: true,

		// If enabled, the router tries to fix the current request path, if no
		// handle is registered for it.
		// First superfluous path elements like ../ or // are removed.
		// Afterwards the router does a case-insensitive lookup of the cleaned path.
		// If a handle can be found for this route, the router makes a redirection
		// to the corrected path with status code 301 for GET requests and 308 for
		// all other request methods.
		// For example /FOO and /..//Foo could be redirected to /foo.
		// RedirectTrailingSlash is independent of this option.
		RedirectFixedPath: false,

		// If enabled, the router checks if another method is allowed for the
		// current route, if the current request can not be routed.
		// If this is the case, the request is answered with 'Method Not Allowed'
		// and HTTP status code 405.
		// If no other Method is allowed, the request is delegated to the NotFound
		// handler.
		HandleMethodNotAllowed: false,

		// If enabled, the url.RawPath will be used to find parameters.
		UseRawPath: false,

		// If enabled, a parameter can be parsed from the URL even with extra slashes.
		RemoveExtraSlash: false,

		// If true, the path value will be unescaped.
		// If UseRawPath is false (by default), the UnescapePathValues effectively is true,
		// as url.Path gonna be used, which is already unescaped.
		UnescapePathValues: true,

		// Disabled (false) by default. Judging by the name, setting it to true
		// turns off eager parsing of multipart form bodies — NOTE(review):
		// confirm against the server implementation.
		// NOTE(review): this spot previously carried fasthttp's ContinueHandler
		// doc ("Expect: 100-continue" handling); Options has no ContinueHandler
		// field, so that text does not describe this option.
		DisablePreParseMultipartForm: false,

		// When set to true, causes the default Content-Type header to be excluded from the response.
		NoDefaultContentType: false,

		// When set to true, causes the default date header to be excluded from the response.
		NoDefaultDate: false,

		// Routes info printing is not disabled by default
		// Disabled when set to True
		DisablePrintRoute: false,

		// The ability to sense client disconnection is disabled by default
		SenseClientDisconnection: false,

		// "tcp", "udp", "unix"(unix domain socket)
		Network: defaultNetwork,

		// listen address
		Addr: defaultAddr,

		// basePath
		BasePath: defaultBasePath,

		// Define the max request body size. If the body Size exceeds this value,
		// an error will be returned
		MaxRequestBodySize: defaultMaxRequestBodySize,

		// Define the max request header size. If the header size exceeds this value,
		// an error will be returned
		MaxHeaderBytes: defaultMaxHeaderBytes,

		// Maximum body buffer size kept when a Request/Response is reset.
		// If the body size exceeds this value, the buffer is put back into a
		// sync.Pool instead of being held by the Request/Response directly.
		// Defaults to the same value as MaxRequestBodySize.
		MaxKeepBodySize: defaultMaxRequestBodySize,

		// only accept GET request
		GetOnly: false,

		DisableKeepalive: false,

		// request body stream switch
		StreamRequestBody: false,

		NoDefaultServerHeader: false,

		// graceful shutdown wait time
		ExitWaitTimeout: defaultWaitExitTimeout,

		// tls config
		TLS: nil,

		// Set init read buffer size. Usually there is no need to set it.
		ReadBufferSize: defaultReadBufferSize,

		// ALPN switch
		ALPN: false,

		// H2C switch
		H2C: false,

		// tracers
		Tracers: []interface{}{},

		// trace level. The default is an unset placeholder (a pointer to a nil
		// interface{}); NOTE(review): presumably the server resolves it to
		// LevelDetailed — confirm.
		TraceLevel: new(interface{}),

		Registry: registry.NoopRegistry,

		// Disabled header names' normalization, default false
		DisableHeaderNamesNormalizing: false,
	}
	options.Apply(opts)
	return options
}

================================================
FILE: pkg/common/config/option_test.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package config import ( "testing" "time" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/test/assert" ) // TestDefaultOptions test options with default values func TestDefaultOptions(t *testing.T) { options := NewOptions([]Option{}) assert.DeepEqual(t, defaultKeepAliveTimeout, options.KeepAliveTimeout) assert.DeepEqual(t, defaultReadTimeout, options.ReadTimeout) assert.DeepEqual(t, defaultReadTimeout, options.IdleTimeout) assert.DeepEqual(t, time.Duration(0), options.WriteTimeout) assert.True(t, options.RedirectTrailingSlash) assert.True(t, options.RedirectTrailingSlash) assert.False(t, options.HandleMethodNotAllowed) assert.False(t, options.UseRawPath) assert.False(t, options.RemoveExtraSlash) assert.True(t, options.UnescapePathValues) assert.False(t, options.DisablePreParseMultipartForm) assert.False(t, options.SenseClientDisconnection) assert.DeepEqual(t, defaultNetwork, options.Network) assert.DeepEqual(t, defaultAddr, options.Addr) assert.DeepEqual(t, defaultMaxRequestBodySize, options.MaxRequestBodySize) assert.False(t, options.GetOnly) assert.False(t, options.DisableKeepalive) assert.False(t, options.NoDefaultServerHeader) assert.DeepEqual(t, defaultWaitExitTimeout, options.ExitWaitTimeout) assert.Nil(t, options.TLS) assert.DeepEqual(t, defaultReadBufferSize, options.ReadBufferSize) assert.False(t, options.ALPN) assert.False(t, options.H2C) assert.DeepEqual(t, []interface{}{}, options.Tracers) assert.DeepEqual(t, new(interface{}), options.TraceLevel) assert.DeepEqual(t, registry.NoopRegistry, options.Registry) assert.Nil(t, options.BindConfig) assert.Nil(t, options.ValidateConfig) assert.Nil(t, options.CustomBinder) assert.Nil(t, options.CustomValidator) assert.DeepEqual(t, false, options.DisableHeaderNamesNormalizing) } // TestApplyCustomOptions test apply options with custom values after init func TestApplyCustomOptions(t *testing.T) { options := NewOptions([]Option{}) options.Apply([]Option{ {F: func(o 
*Options) { o.Network = "unix" }}, }) assert.DeepEqual(t, "unix", options.Network) } ================================================ FILE: pkg/common/config/request_option.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package config import "time" var preDefinedOpts []RequestOption type RequestOptions struct { tags map[string]string isSD bool dialTimeout time.Duration readTimeout time.Duration writeTimeout time.Duration // Request timeout. Usually set by DoDeadline or DoTimeout // if <= 0, means not set requestTimeout time.Duration start time.Time } // RequestOption is the only struct to set request-level options. type RequestOption struct { F func(o *RequestOptions) } // NewRequestOptions create a *RequestOptions according to the given opts. func NewRequestOptions(opts []RequestOption) *RequestOptions { options := &RequestOptions{ tags: make(map[string]string), isSD: false, } if preDefinedOpts != nil { options.Apply(preDefinedOpts) } options.Apply(opts) return options } // WithTag set tag in RequestOptions. func WithTag(k, v string) RequestOption { return RequestOption{F: func(o *RequestOptions) { o.tags[k] = v }} } // WithSD set isSD in RequestOptions. func WithSD(b bool) RequestOption { return RequestOption{F: func(o *RequestOptions) { o.isSD = b }} } // WithDialTimeout sets dial timeout. // // This is the request level configuration. 
It has a higher // priority than the client level configuration // Note: it won't take effect in the case of the number of // connections in the connection pool exceeds the maximum // number of connections and needs to establish a connection // while waiting. func WithDialTimeout(t time.Duration) RequestOption { return RequestOption{F: func(o *RequestOptions) { o.dialTimeout = t }} } // WithReadTimeout sets read timeout. // // This is the request level configuration. It has a higher // priority than the client level configuration func WithReadTimeout(t time.Duration) RequestOption { return RequestOption{F: func(o *RequestOptions) { o.readTimeout = t }} } // WithWriteTimeout sets write timeout. // // This is the request level configuration. It has a higher // priority than the client level configuration func WithWriteTimeout(t time.Duration) RequestOption { return RequestOption{F: func(o *RequestOptions) { o.writeTimeout = t }} } // WithRequestTimeout sets whole request timeout. If it reaches timeout, // the client will return. // // This is the request level configuration. func WithRequestTimeout(t time.Duration) RequestOption { return RequestOption{F: func(o *RequestOptions) { o.requestTimeout = t }} } func (o *RequestOptions) Apply(opts []RequestOption) { for _, op := range opts { op.F(o) } } func (o *RequestOptions) Tag(k string) string { return o.tags[k] } func (o *RequestOptions) Tags() map[string]string { return o.tags } func (o *RequestOptions) IsSD() bool { return o.isSD } func (o *RequestOptions) DialTimeout() time.Duration { return o.dialTimeout } func (o *RequestOptions) ReadTimeout() time.Duration { return o.readTimeout } func (o *RequestOptions) WriteTimeout() time.Duration { return o.writeTimeout } func (o *RequestOptions) RequestTimeout() time.Duration { return o.requestTimeout } // StartRequest records the start time of the request. // // Note: Users should not call this method. 
func (o *RequestOptions) StartRequest() { if o.requestTimeout > 0 { o.start = time.Now() } } func (o *RequestOptions) StartTime() time.Time { return o.start } func (o *RequestOptions) CopyTo(dst *RequestOptions) { if dst.tags == nil { dst.tags = make(map[string]string) } for k, v := range o.tags { dst.tags[k] = v } dst.isSD = o.isSD dst.readTimeout = o.readTimeout dst.writeTimeout = o.writeTimeout dst.dialTimeout = o.dialTimeout dst.requestTimeout = o.requestTimeout dst.start = o.start } // SetPreDefinedOpts Pre define some RequestOption here func SetPreDefinedOpts(opts ...RequestOption) { preDefinedOpts = nil preDefinedOpts = append(preDefinedOpts, opts...) } ================================================ FILE: pkg/common/config/request_option_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package config import ( "testing" "time" "github.com/cloudwego/hertz/pkg/common/test/assert" ) // TestRequestOptions test request options with custom values func TestRequestOptions(t *testing.T) { opt := NewRequestOptions([]RequestOption{ WithTag("a", "b"), WithTag("c", "d"), WithTag("e", "f"), WithSD(true), WithDialTimeout(time.Second), WithReadTimeout(time.Second), WithWriteTimeout(time.Second), }) assert.DeepEqual(t, "b", opt.Tag("a")) assert.DeepEqual(t, "d", opt.Tag("c")) assert.DeepEqual(t, "f", opt.Tag("e")) assert.DeepEqual(t, time.Second, opt.DialTimeout()) assert.DeepEqual(t, time.Second, opt.ReadTimeout()) assert.DeepEqual(t, time.Second, opt.WriteTimeout()) assert.True(t, opt.IsSD()) } // TestRequestOptionsWithDefaultOpts test request options with default values func TestRequestOptionsWithDefaultOpts(t *testing.T) { SetPreDefinedOpts(WithTag("pre-defined", "blablabla"), WithTag("a", "default-value"), WithSD(true)) opt := NewRequestOptions([]RequestOption{ WithTag("a", "b"), WithSD(false), }) assert.DeepEqual(t, "b", opt.Tag("a")) assert.DeepEqual(t, "blablabla", opt.Tag("pre-defined")) assert.DeepEqual(t, map[string]string{ "a": "b", "pre-defined": "blablabla", }, opt.Tags()) assert.False(t, opt.IsSD()) SetPreDefinedOpts() assert.Nil(t, preDefinedOpts) assert.DeepEqual(t, time.Duration(0), opt.WriteTimeout()) assert.DeepEqual(t, time.Duration(0), opt.ReadTimeout()) assert.DeepEqual(t, time.Duration(0), opt.DialTimeout()) } // TestRequestOptions_CopyTo test request options copy to another one func TestRequestOptions_CopyTo(t *testing.T) { opt := NewRequestOptions([]RequestOption{ WithTag("a", "b"), WithSD(false), }) var copyOpt RequestOptions opt.CopyTo(©Opt) assert.DeepEqual(t, opt.Tags(), copyOpt.Tags()) assert.DeepEqual(t, opt.IsSD(), copyOpt.IsSD()) } ================================================ FILE: pkg/common/errors/errors.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache 
License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */ package errors import ( "errors" "fmt" "reflect" "strings" ) var ( // These errors are the base error, which are used for checking in errors.Is() ErrNeedMore = errors.New("need more data") ErrChunkedStream = errors.New("chunked stream") ErrBodyTooLarge = errors.New("body size exceeds the given limit") ErrHeaderTooLarge = errors.New("header size exceeds the given limit") ErrHijacked = errors.New("connection has been hijacked") ErrTimeout = errors.New("timeout") ErrIdleTimeout = errors.New("idle timeout") ErrNothingRead = errors.New("nothing read") ErrShortConnection = errors.New("short connection") ErrConnectionClosed = errors.New("connection closed") ErrNotSupportProtocol = errors.New("not support protocol") ErrNoMultipartForm = errors.New("request has no multipart/form-data Content-Type") ErrBadPoolConn = errors.New("connection is closed by peer while being in the connection pool") // ErrNoFreeConns is returned by the HTTP client when MaxConnsPerHost (or // HostClient.MaxConns) is set to a positive value and all connections to // the target host are in use. // If MaxConnWaitTimeout is also set, the client waits for a free connection // up to that duration before returning this error. // // Before v0.10.3, MaxConnsPerHost defaulted to consts.DefaultMaxConnsPerHost // (512), so this error could be returned even when no limit was intended. // Since v0.10.3, MaxConnsPerHost defaults to 0 (no limit), meaning this // error is only returned when MaxConnsPerHost is explicitly configured. ErrNoFreeConns = errors.New("no free connections available to host") ) // ErrorType is an unsigned 64-bit error code as defined in the hertz spec. type ErrorType uint64 type Error struct { Err error Type ErrorType Meta interface{} } const ( // ErrorTypeBind is used when Context.Bind() fails. ErrorTypeBind ErrorType = 1 << iota // ErrorTypeRender is used when Context.Render() fails. 
ErrorTypeRender // ErrorTypePrivate indicates a private error. ErrorTypePrivate // ErrorTypePublic indicates a public error. ErrorTypePublic // ErrorTypeAny indicates any other error. ErrorTypeAny ) type ErrorChain []*Error var _ error = (*Error)(nil) // SetType sets the error's type. func (msg *Error) SetType(flags ErrorType) *Error { msg.Type = flags return msg } // AbortWithMsg implements the error interface. func (msg *Error) Error() string { return msg.Err.Error() } func (a ErrorChain) String() string { if len(a) == 0 { return "" } var buffer strings.Builder for i, msg := range a { fmt.Fprintf(&buffer, "Error #%02d: %s\n", i+1, msg.Err) if msg.Meta != nil { fmt.Fprintf(&buffer, " Meta: %v\n", msg.Meta) } } return buffer.String() } func (msg *Error) Unwrap() error { return msg.Err } // SetMeta sets the error's meta data. func (msg *Error) SetMeta(data interface{}) *Error { msg.Meta = data return msg } // IsType judges one error. func (msg *Error) IsType(flags ErrorType) bool { return (msg.Type & flags) > 0 } // JSON creates a properly formatted JSON func (msg *Error) JSON() interface{} { jsonData := make(map[string]interface{}) if msg.Meta != nil { value := reflect.ValueOf(msg.Meta) switch value.Kind() { case reflect.Struct: return msg.Meta case reflect.Map: for _, key := range value.MapKeys() { jsonData[key.String()] = value.MapIndex(key).Interface() } default: jsonData["meta"] = msg.Meta } } if _, ok := jsonData["error"]; !ok { jsonData["error"] = msg.Error() } return jsonData } // Errors returns an array will all the error messages. // Example: // // c.Error(errors.New("first")) // c.Error(errors.New("second")) // c.Error(errors.New("third")) // c.Errors.Errors() // == []string{"first", "second", "third"} func (a ErrorChain) Errors() []string { if len(a) == 0 { return nil } errorStrings := make([]string, len(a)) for i, err := range a { errorStrings[i] = err.Error() } return errorStrings } // ByType returns a readonly copy filtered the byte. 
// ie ByType(hertz.ErrorTypePublic) returns a slice of errors with type=ErrorTypePublic. func (a ErrorChain) ByType(typ ErrorType) ErrorChain { if len(a) == 0 { return nil } if typ == ErrorTypeAny { return a } var result ErrorChain for _, msg := range a { if msg.IsType(typ) { result = append(result, msg) } } return result } // Last returns the last error in the slice. It returns nil if the array is empty. // Shortcut for errors[len(errors)-1]. func (a ErrorChain) Last() *Error { if length := len(a); length > 0 { return a[length-1] } return nil } func (a ErrorChain) JSON() interface{} { switch length := len(a); length { case 0: return nil case 1: return a.Last().JSON() default: jsonData := make([]interface{}, length) for i, err := range a { jsonData[i] = err.JSON() } return jsonData } } func New(err error, t ErrorType, meta interface{}) *Error { return &Error{ Err: err, Type: t, Meta: meta, } } // shortcut for creating a public *Error from string func NewPublic(err string) *Error { return New(errors.New(err), ErrorTypePublic, nil) } func NewPrivate(err string) *Error { return New(errors.New(err), ErrorTypePrivate, nil) } func Newf(t ErrorType, meta interface{}, format string, v ...interface{}) *Error { return New(fmt.Errorf(format, v...), t, meta) } func NewPublicf(format string, v ...interface{}) *Error { return New(fmt.Errorf(format, v...), ErrorTypePublic, nil) } func NewPrivatef(format string, v ...interface{}) *Error { return New(fmt.Errorf(format, v...), ErrorTypePrivate, nil) } ================================================ FILE: pkg/common/errors/errors_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */ package errors import ( "errors" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestError(t *testing.T) { baseError := errors.New("test error") err := &Error{ Err: baseError, Type: ErrorTypePrivate, } assert.DeepEqual(t, err.Error(), baseError.Error()) assert.DeepEqual(t, map[string]interface{}{"error": baseError.Error()}, err.JSON()) assert.DeepEqual(t, err.SetType(ErrorTypePublic), err) assert.DeepEqual(t, ErrorTypePublic, err.Type) assert.DeepEqual(t, err.SetMeta("some data"), err) assert.DeepEqual(t, "some data", err.Meta) assert.DeepEqual(t, map[string]interface{}{ "error": baseError.Error(), "meta": "some data", }, err.JSON()) err.SetMeta(map[string]interface{}{ // nolint: errcheck "status": "200", "data": "some data", }) assert.DeepEqual(t, map[string]interface{}{ "error": baseError.Error(), "status": "200", "data": "some data", }, err.JSON()) err.SetMeta(map[string]interface{}{ // nolint: errcheck "error": "custom error", "status": "200", "data": "some data", }) assert.DeepEqual(t, map[string]interface{}{ "error": "custom error", "status": "200", "data": "some data", }, err.JSON()) type customError struct { status string data string } err.SetMeta(customError{status: "200", data: "other data"}) // nolint: errcheck assert.DeepEqual(t, customError{status: "200", data: "other data"}, err.JSON()) } func TestErrorSlice(t *testing.T) { errs := ErrorChain{ {Err: errors.New("first"), Type: ErrorTypePrivate}, {Err: errors.New("second"), Type: ErrorTypePrivate, Meta: "some data"}, {Err: errors.New("third"), Type: ErrorTypePublic, Meta: map[string]interface{}{"status": "400"}}, } assert.DeepEqual(t, errs, errs.ByType(ErrorTypeAny)) assert.DeepEqual(t, "third", errs.Last().Error()) assert.DeepEqual(t, []string{"first", "second", "third"}, errs.Errors()) assert.DeepEqual(t, []string{"third"}, errs.ByType(ErrorTypePublic).Errors()) assert.DeepEqual(t, []string{"first", "second"}, 
errs.ByType(ErrorTypePrivate).Errors()) assert.DeepEqual(t, []string{"first", "second", "third"}, errs.ByType(ErrorTypePublic|ErrorTypePrivate).Errors()) assert.DeepEqual(t, "", errs.ByType(ErrorTypeBind).String()) assert.DeepEqual(t, `Error #01: first Error #02: second Meta: some data Error #03: third Meta: map[status:400] `, errs.String()) assert.DeepEqual(t, []interface{}{ map[string]interface{}{"error": "first"}, map[string]interface{}{"error": "second", "meta": "some data"}, map[string]interface{}{"error": "third", "status": "400"}, }, errs.JSON()) errs = ErrorChain{ {Err: errors.New("first"), Type: ErrorTypePrivate}, } assert.DeepEqual(t, map[string]interface{}{"error": "first"}, errs.JSON()) errs = ErrorChain{} assert.DeepEqual(t, true, errs.Last() == nil) assert.Nil(t, errs.JSON()) assert.DeepEqual(t, "", errs.String()) } func TestErrorFormat(t *testing.T) { err := Newf(ErrorTypeAny, nil, "caused by %s", "reason") assert.DeepEqual(t, New(errors.New("caused by reason"), ErrorTypeAny, nil), err) publicErr := NewPublicf("caused by %s", "reason") assert.DeepEqual(t, New(errors.New("caused by reason"), ErrorTypePublic, nil), publicErr) privateErr := NewPrivatef("caused by %s", "reason") assert.DeepEqual(t, New(errors.New("caused by reason"), ErrorTypePrivate, nil), privateErr) } ================================================ FILE: pkg/common/hlog/consts.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ package hlog const ( systemLogPrefix = "HERTZ: " EngineErrorFormat = "Error=%s, remoteAddr=%s" ) ================================================ FILE: pkg/common/hlog/default.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package hlog import ( "context" "fmt" "io" "log" "os" ) // Fatal calls the default logger's Fatal method and then os.Exit(1). func Fatal(v ...interface{}) { logger.Fatal(v...) } // Error calls the default logger's Error method. func Error(v ...interface{}) { logger.Error(v...) } // Warn calls the default logger's Warn method. func Warn(v ...interface{}) { logger.Warn(v...) } // Notice calls the default logger's Notice method. func Notice(v ...interface{}) { logger.Notice(v...) } // Info calls the default logger's Info method. func Info(v ...interface{}) { logger.Info(v...) } // Debug calls the default logger's Debug method. func Debug(v ...interface{}) { logger.Debug(v...) } // Trace calls the default logger's Trace method. func Trace(v ...interface{}) { logger.Trace(v...) } // Fatalf calls the default logger's Fatalf method and then os.Exit(1). func Fatalf(format string, v ...interface{}) { logger.Fatalf(format, v...) } // Errorf calls the default logger's Errorf method. func Errorf(format string, v ...interface{}) { logger.Errorf(format, v...) 
} // Warnf calls the default logger's Warnf method. func Warnf(format string, v ...interface{}) { logger.Warnf(format, v...) } // Noticef calls the default logger's Noticef method. func Noticef(format string, v ...interface{}) { logger.Noticef(format, v...) } // Infof calls the default logger's Infof method. func Infof(format string, v ...interface{}) { logger.Infof(format, v...) } // Debugf calls the default logger's Debugf method. func Debugf(format string, v ...interface{}) { logger.Debugf(format, v...) } // Tracef calls the default logger's Tracef method. func Tracef(format string, v ...interface{}) { logger.Tracef(format, v...) } // CtxFatalf calls the default logger's CtxFatalf method and then os.Exit(1). func CtxFatalf(ctx context.Context, format string, v ...interface{}) { logger.CtxFatalf(ctx, format, v...) } // CtxErrorf calls the default logger's CtxErrorf method. func CtxErrorf(ctx context.Context, format string, v ...interface{}) { logger.CtxErrorf(ctx, format, v...) } // CtxWarnf calls the default logger's CtxWarnf method. func CtxWarnf(ctx context.Context, format string, v ...interface{}) { logger.CtxWarnf(ctx, format, v...) } // CtxNoticef calls the default logger's CtxNoticef method. func CtxNoticef(ctx context.Context, format string, v ...interface{}) { logger.CtxNoticef(ctx, format, v...) } // CtxInfof calls the default logger's CtxInfof method. func CtxInfof(ctx context.Context, format string, v ...interface{}) { logger.CtxInfof(ctx, format, v...) } // CtxDebugf calls the default logger's CtxDebugf method. func CtxDebugf(ctx context.Context, format string, v ...interface{}) { logger.CtxDebugf(ctx, format, v...) } // CtxTracef calls the default logger's CtxTracef method. func CtxTracef(ctx context.Context, format string, v ...interface{}) { logger.CtxTracef(ctx, format, v...) 
} type defaultLogger struct { stdlog *log.Logger level Level depth int } func (ll *defaultLogger) SetOutput(w io.Writer) { ll.stdlog.SetOutput(w) } func (ll *defaultLogger) SetLevel(lv Level) { ll.level = lv } func (ll *defaultLogger) logf(lv Level, format *string, v ...interface{}) { if ll.level > lv { return } msg := lv.toString() if format != nil { if len(v) > 0 { msg += fmt.Sprintf(*format, v...) } else { msg += *format } } else { msg += fmt.Sprint(v...) } ll.stdlog.Output(ll.depth, msg) if lv == LevelFatal { os.Exit(1) } } func (ll *defaultLogger) Fatal(v ...interface{}) { ll.logf(LevelFatal, nil, v...) } func (ll *defaultLogger) Error(v ...interface{}) { ll.logf(LevelError, nil, v...) } func (ll *defaultLogger) Warn(v ...interface{}) { ll.logf(LevelWarn, nil, v...) } func (ll *defaultLogger) Notice(v ...interface{}) { ll.logf(LevelNotice, nil, v...) } func (ll *defaultLogger) Info(v ...interface{}) { ll.logf(LevelInfo, nil, v...) } func (ll *defaultLogger) Debug(v ...interface{}) { ll.logf(LevelDebug, nil, v...) } func (ll *defaultLogger) Trace(v ...interface{}) { ll.logf(LevelTrace, nil, v...) } func (ll *defaultLogger) Fatalf(format string, v ...interface{}) { ll.logf(LevelFatal, &format, v...) } func (ll *defaultLogger) Errorf(format string, v ...interface{}) { ll.logf(LevelError, &format, v...) } func (ll *defaultLogger) Warnf(format string, v ...interface{}) { ll.logf(LevelWarn, &format, v...) } func (ll *defaultLogger) Noticef(format string, v ...interface{}) { ll.logf(LevelNotice, &format, v...) } func (ll *defaultLogger) Infof(format string, v ...interface{}) { ll.logf(LevelInfo, &format, v...) } func (ll *defaultLogger) Debugf(format string, v ...interface{}) { ll.logf(LevelDebug, &format, v...) } func (ll *defaultLogger) Tracef(format string, v ...interface{}) { ll.logf(LevelTrace, &format, v...) } func (ll *defaultLogger) CtxFatalf(ctx context.Context, format string, v ...interface{}) { ll.logf(LevelFatal, &format, v...) 
} func (ll *defaultLogger) CtxErrorf(ctx context.Context, format string, v ...interface{}) { ll.logf(LevelError, &format, v...) } func (ll *defaultLogger) CtxWarnf(ctx context.Context, format string, v ...interface{}) { ll.logf(LevelWarn, &format, v...) } func (ll *defaultLogger) CtxNoticef(ctx context.Context, format string, v ...interface{}) { ll.logf(LevelNotice, &format, v...) } func (ll *defaultLogger) CtxInfof(ctx context.Context, format string, v ...interface{}) { ll.logf(LevelInfo, &format, v...) } func (ll *defaultLogger) CtxDebugf(ctx context.Context, format string, v ...interface{}) { ll.logf(LevelDebug, &format, v...) } func (ll *defaultLogger) CtxTracef(ctx context.Context, format string, v ...interface{}) { ll.logf(LevelTrace, &format, v...) } ================================================ FILE: pkg/common/hlog/default_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package hlog import ( "context" "log" "os" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func initTestLogger() { logger = &defaultLogger{ stdlog: log.New(os.Stderr, "", 0), depth: 4, } } type byteSliceWriter struct { b []byte } func (w *byteSliceWriter) Write(p []byte) (int, error) { w.b = append(w.b, p...) 
return len(p), nil } func TestDefaultLogger(t *testing.T) { initTestLogger() var w byteSliceWriter SetOutput(&w) Trace("trace work") Debug("received work order") Info("starting work") Notice("something happens in work") Warn("work may fail") Error("work failed") assert.DeepEqual(t, "[Trace] trace work\n"+ "[Debug] received work order\n"+ "[Info] starting work\n"+ "[Notice] something happens in work\n"+ "[Warn] work may fail\n"+ "[Error] work failed\n", string(w.b)) } func TestDefaultFormatLogger(t *testing.T) { initTestLogger() var w byteSliceWriter SetOutput(&w) work := "work" Tracef("trace %s", work) Debugf("received %s order", work) Infof("starting %s", work) Noticef("something happens in %s", work) Warnf("%s may fail", work) Errorf("%s failed", work) assert.DeepEqual(t, "[Trace] trace work\n"+ "[Debug] received work order\n"+ "[Info] starting work\n"+ "[Notice] something happens in work\n"+ "[Warn] work may fail\n"+ "[Error] work failed\n", string(w.b)) } func TestCtxLogger(t *testing.T) { initTestLogger() var w byteSliceWriter SetOutput(&w) ctx := context.Background() work := "work" CtxTracef(ctx, "trace %s", work) CtxDebugf(ctx, "received %s order", work) CtxInfof(ctx, "starting %s", work) CtxNoticef(ctx, "something happens in %s", work) CtxWarnf(ctx, "%s may fail", work) CtxErrorf(ctx, "%s failed", work) assert.DeepEqual(t, "[Trace] trace work\n"+ "[Debug] received work order\n"+ "[Info] starting work\n"+ "[Notice] something happens in work\n"+ "[Warn] work may fail\n"+ "[Error] work failed\n", string(w.b)) } func TestFormatLoggerWithEscapedCharacters(t *testing.T) { initTestLogger() var w byteSliceWriter SetOutput(&w) Infof("http://localhost:8080/ping?f=http://localhost:3000/hello?c=%E5%A4%A7hi%E5%93%A6%E5%95%8A%E8%AF%B4%E5%BE%97%E5%A5%BD") assert.DeepEqual(t, "[Info] http://localhost:8080/ping?f=http://localhost:3000/hello?c=%E5%A4%A7hi%E5%93%A6%E5%95%8A%E8%AF%B4%E5%BE%97%E5%A5%BD\n", string(w.b)) } func TestSetLevel(t *testing.T) { setLogger := 
&defaultLogger{ stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), depth: 4, } setLogger.SetLevel(LevelTrace) assert.DeepEqual(t, LevelTrace, setLogger.level) assert.DeepEqual(t, LevelTrace.toString(), setLogger.level.toString()) setLogger.SetLevel(LevelDebug) assert.DeepEqual(t, LevelDebug, setLogger.level) assert.DeepEqual(t, LevelDebug.toString(), setLogger.level.toString()) setLogger.SetLevel(LevelInfo) assert.DeepEqual(t, LevelInfo, setLogger.level) assert.DeepEqual(t, LevelInfo.toString(), setLogger.level.toString()) setLogger.SetLevel(LevelNotice) assert.DeepEqual(t, LevelNotice, setLogger.level) assert.DeepEqual(t, LevelNotice.toString(), setLogger.level.toString()) setLogger.SetLevel(LevelWarn) assert.DeepEqual(t, LevelWarn, setLogger.level) assert.DeepEqual(t, LevelWarn.toString(), setLogger.level.toString()) setLogger.SetLevel(LevelError) assert.DeepEqual(t, LevelError, setLogger.level) assert.DeepEqual(t, LevelError.toString(), setLogger.level.toString()) setLogger.SetLevel(LevelFatal) assert.DeepEqual(t, LevelFatal, setLogger.level) assert.DeepEqual(t, LevelFatal.toString(), setLogger.level.toString()) setLogger.SetLevel(7) assert.DeepEqual(t, 7, int(setLogger.level)) assert.DeepEqual(t, "[?7] ", setLogger.level.toString()) } ================================================ FILE: pkg/common/hlog/hlog.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package hlog import ( "io" "log" "os" ) var ( // Provide default logger for users to use logger FullLogger = &defaultLogger{ stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), depth: 4, } // Provide system logger for print system log sysLogger FullLogger = &systemLogger{ &defaultLogger{ stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), depth: 4, }, systemLogPrefix, } ) // SetOutput sets the output of default logger and system logger. By default, it is stderr. func SetOutput(w io.Writer) { logger.SetOutput(w) sysLogger.SetOutput(w) } // SetLevel sets the level of logs below which logs will not be output. // The default logger and system logger level is LevelTrace. // Note that this method is not concurrent-safe. func SetLevel(lv Level) { logger.SetLevel(lv) sysLogger.SetLevel(lv) } // DefaultLogger return the default logger for hertz. func DefaultLogger() FullLogger { return logger } // SystemLogger return the system logger for hertz to print system log. // This function is not recommended for users to use. func SystemLogger() FullLogger { return sysLogger } // SetSystemLogger sets the system logger. // Note that this method is not concurrent-safe and must not be called // This function is not recommended for users to use. func SetSystemLogger(v FullLogger) { sysLogger = &systemLogger{v, systemLogPrefix} } // SetLogger sets the default logger and the system logger. // Note that this method is not concurrent-safe and must not be called // after the use of DefaultLogger and global functions in this package. func SetLogger(v FullLogger) { logger = v SetSystemLogger(v) } ================================================ FILE: pkg/common/hlog/hlog_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package hlog import ( "log" "os" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestDefaultAndSysLogger(t *testing.T) { defaultLog := DefaultLogger() systemLog := SystemLogger() assert.DeepEqual(t, logger, defaultLog) assert.DeepEqual(t, sysLogger, systemLog) assert.NotEqual(t, logger, systemLog) assert.NotEqual(t, sysLogger, defaultLog) } func TestSetLogger(t *testing.T) { setLog := &defaultLogger{ stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), depth: 6, } setSysLog := &systemLogger{ setLog, systemLogPrefix, } assert.NotEqual(t, logger, setLog) assert.NotEqual(t, sysLogger, setSysLog) SetLogger(setLog) assert.DeepEqual(t, logger, setLog) assert.DeepEqual(t, sysLogger, setSysLog) } ================================================ FILE: pkg/common/hlog/log.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package hlog import ( "context" "fmt" "io" ) // FormatLogger is a logger interface that output logs with a format. 
type FormatLogger interface { Tracef(format string, v ...interface{}) Debugf(format string, v ...interface{}) Infof(format string, v ...interface{}) Noticef(format string, v ...interface{}) Warnf(format string, v ...interface{}) Errorf(format string, v ...interface{}) Fatalf(format string, v ...interface{}) } // Logger is a logger interface that provides logging function with levels. type Logger interface { Trace(v ...interface{}) Debug(v ...interface{}) Info(v ...interface{}) Notice(v ...interface{}) Warn(v ...interface{}) Error(v ...interface{}) Fatal(v ...interface{}) } // CtxLogger is a logger interface that accepts a context argument and output // logs with a format. type CtxLogger interface { CtxTracef(ctx context.Context, format string, v ...interface{}) CtxDebugf(ctx context.Context, format string, v ...interface{}) CtxInfof(ctx context.Context, format string, v ...interface{}) CtxNoticef(ctx context.Context, format string, v ...interface{}) CtxWarnf(ctx context.Context, format string, v ...interface{}) CtxErrorf(ctx context.Context, format string, v ...interface{}) CtxFatalf(ctx context.Context, format string, v ...interface{}) } // Control provides methods to config a logger. type Control interface { SetLevel(Level) SetOutput(io.Writer) } // FullLogger is the combination of Logger, FormatLogger, CtxLogger and Control. type FullLogger interface { Logger FormatLogger CtxLogger Control } // Level defines the priority of a log message. // When a logger is configured with a level, any log message with a lower // log level (smaller by integer comparison) will not be output. type Level int // The levels of logs. 
const ( LevelTrace Level = iota LevelDebug LevelInfo LevelNotice LevelWarn LevelError LevelFatal ) var strs = []string{ "[Trace] ", "[Debug] ", "[Info] ", "[Notice] ", "[Warn] ", "[Error] ", "[Fatal] ", } func (lv Level) toString() string { if lv >= LevelTrace && lv <= LevelFatal { return strs[lv] } return fmt.Sprintf("[?%d] ", lv) } ================================================ FILE: pkg/common/hlog/system.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package hlog import ( "context" "io" "strings" "sync" ) var silentMode = false // SetSilentMode is used to mute engine error log, // for example: error when reading request headers. // If true, hertz engine will mute it. func SetSilentMode(s bool) { silentMode = s } var builderPool = sync.Pool{New: func() interface{} { return &strings.Builder{} // nolint:SA6002 }} type systemLogger struct { logger FullLogger prefix string } func (ll *systemLogger) SetOutput(w io.Writer) { ll.logger.SetOutput(w) } func (ll *systemLogger) SetLevel(lv Level) { ll.logger.SetLevel(lv) } func (ll *systemLogger) Fatal(v ...interface{}) { v = append([]interface{}{ll.prefix}, v...) ll.logger.Fatal(v...) } func (ll *systemLogger) Error(v ...interface{}) { v = append([]interface{}{ll.prefix}, v...) ll.logger.Error(v...) } func (ll *systemLogger) Warn(v ...interface{}) { v = append([]interface{}{ll.prefix}, v...) ll.logger.Warn(v...) 
} func (ll *systemLogger) Notice(v ...interface{}) { v = append([]interface{}{ll.prefix}, v...) ll.logger.Notice(v...) } func (ll *systemLogger) Info(v ...interface{}) { v = append([]interface{}{ll.prefix}, v...) ll.logger.Info(v...) } func (ll *systemLogger) Debug(v ...interface{}) { v = append([]interface{}{ll.prefix}, v...) ll.logger.Debug(v...) } func (ll *systemLogger) Trace(v ...interface{}) { v = append([]interface{}{ll.prefix}, v...) ll.logger.Trace(v...) } func (ll *systemLogger) Fatalf(format string, v ...interface{}) { ll.logger.Fatalf(ll.addPrefix(format), v...) } func (ll *systemLogger) Errorf(format string, v ...interface{}) { if silentMode && format == EngineErrorFormat { return } ll.logger.Errorf(ll.addPrefix(format), v...) } func (ll *systemLogger) Warnf(format string, v ...interface{}) { ll.logger.Warnf(ll.addPrefix(format), v...) } func (ll *systemLogger) Noticef(format string, v ...interface{}) { ll.logger.Noticef(ll.addPrefix(format), v...) } func (ll *systemLogger) Infof(format string, v ...interface{}) { ll.logger.Infof(ll.addPrefix(format), v...) } func (ll *systemLogger) Debugf(format string, v ...interface{}) { ll.logger.Debugf(ll.addPrefix(format), v...) } func (ll *systemLogger) Tracef(format string, v ...interface{}) { ll.logger.Tracef(ll.addPrefix(format), v...) } func (ll *systemLogger) CtxFatalf(ctx context.Context, format string, v ...interface{}) { ll.logger.CtxFatalf(ctx, ll.addPrefix(format), v...) } func (ll *systemLogger) CtxErrorf(ctx context.Context, format string, v ...interface{}) { ll.logger.CtxErrorf(ctx, ll.addPrefix(format), v...) } func (ll *systemLogger) CtxWarnf(ctx context.Context, format string, v ...interface{}) { ll.logger.CtxWarnf(ctx, ll.addPrefix(format), v...) } func (ll *systemLogger) CtxNoticef(ctx context.Context, format string, v ...interface{}) { ll.logger.CtxNoticef(ctx, ll.addPrefix(format), v...) 
} func (ll *systemLogger) CtxInfof(ctx context.Context, format string, v ...interface{}) { ll.logger.CtxInfof(ctx, ll.addPrefix(format), v...) } func (ll *systemLogger) CtxDebugf(ctx context.Context, format string, v ...interface{}) { ll.logger.CtxDebugf(ctx, ll.addPrefix(format), v...) } func (ll *systemLogger) CtxTracef(ctx context.Context, format string, v ...interface{}) { ll.logger.CtxTracef(ctx, ll.addPrefix(format), v...) } func (ll *systemLogger) addPrefix(format string) string { builder := builderPool.Get().(*strings.Builder) builder.Grow(len(format) + len(ll.prefix)) builder.WriteString(ll.prefix) builder.WriteString(format) s := builder.String() builder.Reset() builderPool.Put(builder) // nolint:SA6002 return s } ================================================ FILE: pkg/common/hlog/system_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package hlog import ( "context" "log" "os" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func initTestSysLogger() { sysLogger = &systemLogger{ &defaultLogger{ stdlog: log.New(os.Stderr, "", 0), depth: 4, }, systemLogPrefix, } } func TestSysLogger(t *testing.T) { initTestSysLogger() var w byteSliceWriter SetOutput(&w) sysLogger.Trace("trace work") sysLogger.Debug("received work order") sysLogger.Info("starting work") sysLogger.Notice("something happens in work") sysLogger.Warn("work may fail") sysLogger.Error("work failed") assert.DeepEqual(t, "[Trace] HERTZ: trace work\n"+ "[Debug] HERTZ: received work order\n"+ "[Info] HERTZ: starting work\n"+ "[Notice] HERTZ: something happens in work\n"+ "[Warn] HERTZ: work may fail\n"+ "[Error] HERTZ: work failed\n", string(w.b)) } func TestSysFormatLogger(t *testing.T) { initTestSysLogger() var w byteSliceWriter SetOutput(&w) work := "work" sysLogger.Tracef("trace %s", work) sysLogger.Debugf("received %s order", work) sysLogger.Infof("starting %s", work) sysLogger.Noticef("something happens in %s", work) sysLogger.Warnf("%s may fail", work) sysLogger.Errorf("%s failed", work) assert.DeepEqual(t, "[Trace] HERTZ: trace work\n"+ "[Debug] HERTZ: received work order\n"+ "[Info] HERTZ: starting work\n"+ "[Notice] HERTZ: something happens in work\n"+ "[Warn] HERTZ: work may fail\n"+ "[Error] HERTZ: work failed\n", string(w.b)) } func TestSysCtxLogger(t *testing.T) { initTestSysLogger() var w byteSliceWriter SetOutput(&w) ctx := context.Background() work := "work" sysLogger.CtxTracef(ctx, "trace %s", work) sysLogger.CtxDebugf(ctx, "received %s order", work) sysLogger.CtxInfof(ctx, "starting %s", work) sysLogger.CtxNoticef(ctx, "something happens in %s", work) sysLogger.CtxWarnf(ctx, "%s may fail", work) sysLogger.CtxErrorf(ctx, "%s failed", work) assert.DeepEqual(t, "[Trace] HERTZ: trace work\n"+ "[Debug] HERTZ: received work order\n"+ "[Info] HERTZ: starting work\n"+ "[Notice] HERTZ: something happens in work\n"+ 
"[Warn] HERTZ: work may fail\n"+ "[Error] HERTZ: work failed\n", string(w.b)) } ================================================ FILE: pkg/common/json/sonic.go ================================================ // Copyright 2022 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //go:build (amd64 || arm64) && !stdjson package json import "github.com/bytedance/sonic" // Name is the name of the effective json package. const Name = "sonic" var ( json = sonic.ConfigStd // Marshal is sonic implementation exported by hertz which is used by rendering. Marshal = json.Marshal // Unmarshal is sonic implementation exported by hertz which is used by binding. Unmarshal = json.Unmarshal // MarshalIndent is sonic implementation exported by hertz. MarshalIndent = json.MarshalIndent // NewDecoder is sonic implementation exported by hertz. NewDecoder = json.NewDecoder // NewEncoder is sonic implementation exported by hertz. NewEncoder = json.NewEncoder ) ================================================ FILE: pkg/common/json/std.go ================================================ // Copyright 2022 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

//go:build !(amd64 || arm64) || stdjson

package json

import "encoding/json"

// Name is the name of the effective json package.
const Name = "encoding/json"

// This file is selected by the build constraint above. On architectures other
// than amd64/arm64, or when built with the stdjson tag, the exported helpers
// below alias encoding/json instead of sonic.
var (
	// Marshal is the standard-library implementation exported by hertz; rendering uses it.
	Marshal = json.Marshal
	// Unmarshal is the standard-library implementation exported by hertz; binding uses it.
	Unmarshal = json.Unmarshal
	// MarshalIndent is the standard-library implementation exported by hertz.
	MarshalIndent = json.MarshalIndent
	// NewDecoder is the standard-library implementation exported by hertz.
	NewDecoder = json.NewDecoder
	// NewEncoder is the standard-library implementation exported by hertz.
	NewEncoder = json.NewEncoder
)

================================================
FILE: pkg/common/stackless/doc.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Package stackless provides wrappers (NewFunc, NewWriter) that run
// stack-hungry functions and writers on dedicated worker goroutines, saving
// stack space on the calling goroutines.
//
// The files in the stackless package are forked from fasthttp [github.com/valyala/fasthttp],
// and we keep the original Copyright [Copyright 2015 fasthttp authors] and License of fasthttp for those files.
// We have also modified them as needed; the modifications are Copyright 2022 CloudWeGo Authors.
// Thanks to the fasthttp authors! Below is the source code information:
// Repo: github.com/valyala/fasthttp
// Forked Version: v1.36.0
package stackless

================================================
FILE: pkg/common/stackless/func.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package stackless import ( "runtime" "sync" ) // NewFunc returns stackless wrapper for the function f. // // Unlike f, the returned stackless wrapper doesn't use stack space // on the goroutine that calls it. // The wrapper may save a lot of stack space if the following conditions // are met: // // - f doesn't contain blocking calls on network, I/O or channels; // - f uses a lot of stack space; // - the wrapper is called from high number of concurrent goroutines. // // The stackless wrapper returns false if the call cannot be processed // at the moment due to high load. func NewFunc(f func(ctx interface{})) func(ctx interface{}) bool { if f == nil { panic("BUG: f cannot be nil") } funcWorkCh := make(chan *funcWork, runtime.GOMAXPROCS(-1)*2048) onceInit := func() { n := runtime.GOMAXPROCS(-1) for i := 0; i < n; i++ { go funcWorker(funcWorkCh, f) } } var once sync.Once return func(ctx interface{}) bool { once.Do(onceInit) fw := getFuncWork() fw.ctx = ctx select { case funcWorkCh <- fw: default: putFuncWork(fw) return false } <-fw.done putFuncWork(fw) return true } } func funcWorker(funcWorkCh <-chan *funcWork, f func(ctx interface{})) { for fw := range funcWorkCh { f(fw.ctx) fw.done <- struct{}{} } } func getFuncWork() *funcWork { v := funcWorkPool.Get() if v == nil { v = &funcWork{ done: make(chan struct{}, 1), } } return v.(*funcWork) } func putFuncWork(fw *funcWork) { fw.ctx = nil funcWorkPool.Put(fw) } var funcWorkPool sync.Pool type funcWork struct { ctx interface{} done chan struct{} } ================================================ FILE: 
pkg/common/stackless/func_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package stackless import ( "fmt" "sync/atomic" "testing" "time" ) func TestNewFuncSimple(t *testing.T) { var n uint64 f := NewFunc(func(ctx interface{}) { atomic.AddUint64(&n, uint64(ctx.(int))) }) iterations := 4 * 1024 for i := 0; i < iterations; i++ { if !f(2) { t.Fatalf("f mustn't return false") } } if n != uint64(2*iterations) { t.Fatalf("Unexpected n: %d. Expecting %d", n, 2*iterations) } } func TestNewFuncMulti(t *testing.T) { var n1, n2 uint64 f1 := NewFunc(func(ctx interface{}) { atomic.AddUint64(&n1, uint64(ctx.(int))) }) f2 := NewFunc(func(ctx interface{}) { atomic.AddUint64(&n2, uint64(ctx.(int))) }) iterations := 4 * 1024 f1Done := make(chan error, 1) go func() { var err error for i := 0; i < iterations; i++ { if !f1(3) { err = fmt.Errorf("f1 mustn't return false") break } } f1Done <- err }() f2Done := make(chan error, 1) go func() { var err error for i := 0; i < iterations; i++ { if !f2(5) { err = fmt.Errorf("f2 mustn't return false") break } } f2Done <- err }() select { case err := <-f1Done: if err != nil { t.Fatalf("unexpected error: %s", err) } case <-time.After(time.Second): t.Fatalf("timeout") } select { case err := <-f2Done: if err != nil { t.Fatalf("unexpected error: %s", err) } case <-time.After(time.Second): t.Fatalf("timeout") } if n1 != uint64(3*iterations) { t.Fatalf("unexpected n1: %d. Expecting %d", n1, 3*iterations) } if n2 != uint64(5*iterations) { t.Fatalf("unexpected n2: %d. Expecting %d", n2, 5*iterations) } } ================================================ FILE: pkg/common/stackless/func_timing_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package stackless import ( "sync/atomic" "testing" ) func BenchmarkFuncOverhead(b *testing.B) { var n uint64 f := NewFunc(func(ctx interface{}) { atomic.AddUint64(&n, *(ctx.(*uint64))) }) b.RunParallel(func(pb *testing.PB) { x := uint64(1) for pb.Next() { if !f(&x) { b.Fatalf("f mustn't return false") } } }) if n != uint64(b.N) { b.Fatalf("unexpected n: %d. Expecting %d", n, b.N) } } func BenchmarkFuncPure(b *testing.B) { var n uint64 f := func(x *uint64) { atomic.AddUint64(&n, *x) } b.RunParallel(func(pb *testing.PB) { x := uint64(1) for pb.Next() { f(&x) } }) if n != uint64(b.N) { b.Fatalf("unexpected n: %d. Expecting %d", n, b.N) } } ================================================ FILE: pkg/common/stackless/writer.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package stackless import ( "fmt" "io" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" "github.com/cloudwego/hertz/pkg/common/errors" ) // Writer is an interface stackless writer must conform to. // // The interface contains common subset for Writers from compress/* packages. type Writer interface { Write(p []byte) (int, error) Flush() error Close() error Reset(w io.Writer) } // NewWriterFunc must return new writer that will be wrapped into // stackless writer. type NewWriterFunc func(w io.Writer) Writer // NewWriter creates a stackless writer around a writer returned // from newWriter. // // The returned writer writes data to dstW. 
// // Writers that use a lot of stack space may be wrapped into stackless writer, // thus saving stack space for high number of concurrently running goroutines. func NewWriter(dstW io.Writer, newWriter NewWriterFunc) Writer { w := &writer{ dstW: dstW, } w.zw = newWriter(&w.xw) return w } type writer struct { dstW io.Writer zw Writer xw xWriter err error n int p []byte op op } type op int const ( opWrite op = iota opFlush opClose opReset ) func (w *writer) Write(p []byte) (int, error) { w.p = p err := w.do(opWrite) w.p = nil return w.n, err } func (w *writer) Flush() error { return w.do(opFlush) } func (w *writer) Close() error { return w.do(opClose) } func (w *writer) Reset(dstW io.Writer) { w.xw.Reset() w.do(opReset) //nolint:errcheck w.dstW = dstW } func (w *writer) do(op op) error { w.op = op if !stacklessWriterFunc(w) { return errHighLoad } err := w.err if err != nil { return err } if w.xw.bb != nil && len(w.xw.bb.B) > 0 { _, err = w.dstW.Write(w.xw.bb.B) } w.xw.Reset() return err } var errHighLoad = errors.NewPublic("cannot compress data due to high load") var stacklessWriterFunc = NewFunc(writerFunc) func writerFunc(ctx interface{}) { w := ctx.(*writer) switch w.op { case opWrite: w.n, w.err = w.zw.Write(w.p) case opFlush: w.err = w.zw.Flush() case opClose: w.err = w.zw.Close() case opReset: w.zw.Reset(&w.xw) w.err = nil default: panic(fmt.Sprintf("BUG: unexpected op: %d", w.op)) } } type xWriter struct { bb *bytebufferpool.ByteBuffer } func (w *xWriter) Write(p []byte) (int, error) { if w.bb == nil { w.bb = bufferPool.Get() } return w.bb.Write(p) } func (w *xWriter) Reset() { if w.bb != nil { bufferPool.Put(w.bb) w.bb = nil } } var bufferPool bytebufferpool.Pool ================================================ FILE: pkg/common/stackless/writer_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with 
the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package stackless import ( "bytes" "compress/flate" "compress/gzip" "fmt" "io" "io/ioutil" "testing" "time" ) func TestCompressFlateSerial(t *testing.T) { if err := testCompressFlate(); err != nil { t.Fatalf("unexpected error: %s", err) } } func TestCompressFlateConcurrent(t *testing.T) { if err := testConcurrent(testCompressFlate, 10); err != nil { t.Fatalf("unexpected error: %s", err) } } func testCompressFlate() error { return testWriter(func(w io.Writer) Writer { zw, err := flate.NewWriter(w, flate.DefaultCompression) if err != nil { panic(fmt.Sprintf("BUG: unexpected error: %s", err)) } return zw }, func(r io.Reader) io.Reader { return flate.NewReader(r) }) } func TestCompressGzipSerial(t *testing.T) { if err := testCompressGzip(); err != nil { t.Fatalf("unexpected error: %s", err) } } func TestCompressGzipConcurrent(t *testing.T) { if err := testConcurrent(testCompressGzip, 10); err != nil { t.Fatalf("unexpected error: %s", err) } } func testCompressGzip() error { return testWriter(func(w io.Writer) Writer { return gzip.NewWriter(w) }, func(r io.Reader) io.Reader { zr, err := gzip.NewReader(r) if err != nil { panic(fmt.Sprintf("BUG: cannot create gzip reader: %s", err)) } return zr }) } func testWriter(newWriter NewWriterFunc, newReader func(io.Reader) io.Reader) error { dstW := &bytes.Buffer{} w := NewWriter(dstW, newWriter) for i := 0; i < 5; i++ { if err := testWriterReuse(w, dstW, newReader); err != nil { return fmt.Errorf("unexpected error when re-using writer on iteration %d: %s", i, err) } dstW = &bytes.Buffer{} w.Reset(dstW) } return nil } func testWriterReuse(w Writer, r io.Reader, newReader func(io.Reader) io.Reader) error { wantW := &bytes.Buffer{} mw := io.MultiWriter(w, wantW) for i := 0; i < 30; i++ { fmt.Fprintf(mw, "foobar %d\n", i) if i%13 == 0 { if err := w.Flush(); err != nil { return fmt.Errorf("error on flush: %s", err) } } } w.Close() zr := newReader(r) data, err := ioutil.ReadAll(zr) if err != nil { return fmt.Errorf("unexpected 
error: %s, data=%q", err, data) } wantData := wantW.Bytes() if !bytes.Equal(data, wantData) { return fmt.Errorf("unexpected data: %q. Expecting %q", data, wantData) } return nil } func testConcurrent(testFunc func() error, concurrency int) error { ch := make(chan error, concurrency) for i := 0; i < concurrency; i++ { go func() { ch <- testFunc() }() } for i := 0; i < concurrency; i++ { select { case err := <-ch: if err != nil { return fmt.Errorf("unexpected error on goroutine %d: %s", i, err) } case <-time.After(time.Second): return fmt.Errorf("timeout on goroutine %d", i) } } return nil } ================================================ FILE: pkg/common/test/assert/assert.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package assert import ( "reflect" ) type testingT interface { Helper() Fatal(args ...any) Fatalf(format string, args ...any) } // Assert . func Assert(t testingT, cond bool, val ...interface{}) { t.Helper() if !cond { if len(val) > 0 { val = append([]interface{}{"assertion failed:"}, val...) t.Fatal(val...) } else { t.Fatal("assertion failed") } } } // Assertf . func Assertf(t testingT, cond bool, format string, val ...interface{}) { t.Helper() if !cond { t.Fatalf(format, val...) } } // DeepEqual . 
func DeepEqual(t testingT, expected, actual interface{}) { t.Helper() if !reflect.DeepEqual(actual, expected) { t.Fatalf("assertion failed, unexpected: %v, expected: %v", actual, expected) } } func isNil(rv reflect.Value) bool { switch rv.Kind() { case reflect.Chan, reflect.Func, reflect.Map, reflect.Pointer, reflect.UnsafePointer, reflect.Interface, reflect.Slice: if rv.IsNil() { return true } } return false } func Nil(t testingT, data interface{}) { t.Helper() if data == nil || isNil(reflect.ValueOf(data)) { return } t.Fatalf("assertion failed, unexpected: %v, expected: nil", data) } func NotNil(t testingT, data interface{}) { t.Helper() if data == nil || isNil(reflect.ValueOf(data)) { t.Fatalf("assertion failed, unexpected: %v, expected: not nil", data) } } // NotEqual . func NotEqual(t testingT, expected, actual interface{}) { t.Helper() if expected == nil || actual == nil { if expected == actual { t.Fatalf("assertion failed: %v == %v", actual, expected) } } if reflect.DeepEqual(actual, expected) { t.Fatalf("assertion failed: %v == %v", actual, expected) } } func True(t testingT, obj interface{}) { t.Helper() DeepEqual(t, true, obj) } func False(t testingT, obj interface{}) { t.Helper() DeepEqual(t, false, obj) } // Panic . func Panic(t testingT, fn func()) { t.Helper() defer func() { if err := recover(); err == nil { t.Fatal("assertion failed: did not panic") } }() fn() } // NotPanic . func NotPanic(t testingT, fn func()) { t.Helper() defer func() { if err := recover(); err != nil { t.Fatal("assertion failed: panicked") } }() fn() } ================================================ FILE: pkg/common/test/assert/assert_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package assert import ( "fmt" "strings" "testing" ) type mockTestingT struct { fatalstr string } func (mockTestingT) Helper() {} func (m *mockTestingT) Reset() { m.fatalstr = "" } func (m *mockTestingT) Fatal(args ...any) { m.fatalstr = fmt.Sprintln(args...) } func (m *mockTestingT) Fatalf(fm string, args ...any) { m.fatalstr = fmt.Sprintf(fm, args...) } func (m *mockTestingT) String() string { return m.fatalstr } func (m *mockTestingT) Expect(t *testing.T, s string) { t.Helper() got := strings.TrimSpace(m.fatalstr) if got != s { t.Fatalf("got: %q expect: %q", got, s) } m.Reset() } func TestAssert(t *testing.T) { m := &mockTestingT{} Assert(m, true) m.Expect(t, "") Assert(m, false) m.Expect(t, "assertion failed") Assert(m, false, "hello") m.Expect(t, "assertion failed: hello") Assertf(m, true, "hello %s", "world") m.Expect(t, "") Assertf(m, false, "hello %s", "world") m.Expect(t, "hello world") } func TestNil(t *testing.T) { m := &mockTestingT{} Nil(m, nil) m.Expect(t, "") Nil(m, (*testing.T)(nil)) m.Expect(t, "") Nil(m, 1) m.Expect(t, "assertion failed, unexpected: 1, expected: nil") Nil(m, "hello") m.Expect(t, "assertion failed, unexpected: hello, expected: nil") NotNil(m, 1) m.Expect(t, "") NotNil(m, "hello") m.Expect(t, "") NotNil(m, struct { hello string }{}) m.Expect(t, "") NotNil(m, (*testing.T)(nil)) m.Expect(t, `assertion failed, unexpected: , expected: not nil`) } func TestDeepEqual(t *testing.T) { m := &mockTestingT{} DeepEqual(m, 1, 1) m.Expect(t, "") DeepEqual(m, 1, 2) m.Expect(t, `assertion failed, unexpected: 2, expected: 1`) } func TestNotEqual(t 
*testing.T) { m := &mockTestingT{} NotEqual(m, 1, 2) m.Expect(t, "") NotEqual(m, nil, nil) m.Expect(t, `assertion failed: == `) } func TestTrueFalse(t *testing.T) { m := &mockTestingT{} True(m, true) m.Expect(t, "") False(m, false) m.Expect(t, "") } func TestPanic(t *testing.T) { m := &mockTestingT{} Panic(m, func() { panic("hello") }) m.Expect(t, "") Panic(m, func() { }) m.Expect(t, `assertion failed: did not panic`) NotPanic(m, func() { }) m.Expect(t, "") NotPanic(m, func() { panic("hello") }) m.Expect(t, `assertion failed: panicked`) } ================================================ FILE: pkg/common/test/mock/body_data.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package mock import "fmt" func CreateFixedBody(bodySize int) []byte { var b []byte for i := 0; i < bodySize; i++ { b = append(b, byte(i%10)+'0') } return b } func CreateChunkedBody(body []byte, trailer map[string]string, hasTrailer bool) []byte { var b []byte chunkSize := 1 for len(body) > 0 { if chunkSize > len(body) { chunkSize = len(body) } b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...) b = append(b, body[:chunkSize]...) b = append(b, []byte("\r\n")...) body = body[chunkSize:] chunkSize++ } if hasTrailer { b = append(b, "0\r\n"...) for k, v := range trailer { b = append(b, k...) b = append(b, ": "...) b = append(b, v...) b = append(b, "\r\n"...) } b = append(b, "\r\n"...) 
} return b } ================================================ FILE: pkg/common/test/mock/body_data_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package mock import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestGenerateCreateFixedBody(t *testing.T) { bodySize := 10 resFixedBody := "0123456789" b := CreateFixedBody(bodySize) if string(b) != resFixedBody { t.Fatalf("Unexpected %s. Expecting %s.", b, resFixedBody) } nilFixedBody := CreateFixedBody(0) if nilFixedBody != nil { t.Fatalf("Unexpected %s. Expecting a nil", nilFixedBody) } } func TestGenerateCreateChunkedBody(t *testing.T) { bodySize := 10 b := CreateFixedBody(bodySize) trailer := map[string]string{"Foo": "chunked shit"} expectCb := "1\r\n0\r\n2\r\n12\r\n3\r\n345\r\n4\r\n6789\r\n0\r\nFoo: chunked shit\r\n\r\n" cb := CreateChunkedBody(b, trailer, true) assert.DeepEqual(t, expectCb, string(cb)) } ================================================ FILE: pkg/common/test/mock/network.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package mock import ( "bytes" "crypto/tls" "io" "net" "strings" "time" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/netpoll" ) var ( ErrReadTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "read timeout") ErrWriteTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "write timeout") ) type Conn struct { zr network.Reader zw network.ReadWriter wroteLen int rtimeout time.Duration wtimeout time.Duration } type Recorder interface { network.Reader WroteLen() int } func (m *Conn) SetWriteTimeout(t time.Duration) error { // TODO implement me return nil } type SlowReadConn struct { *Conn } func (m *SlowReadConn) SetWriteTimeout(t time.Duration) error { return nil } func (m *SlowReadConn) SetReadTimeout(t time.Duration) error { m.Conn.rtimeout = t return nil } func SlowReadDialer(addr string) (network.Conn, error) { return NewSlowReadConn(""), nil } func SlowWriteDialer(addr string) (network.Conn, error) { return NewSlowWriteConn(""), nil } func (m *Conn) ReadBinary(n int) (p []byte, err error) { return m.zr.(netpoll.Reader).ReadBinary(n) } func (m *Conn) Read(b []byte) (n int, err error) { return netpoll.NewIOReader(m.zr.(netpoll.Reader)).Read(b) } func (m *Conn) Write(b []byte) (n int, err error) { return netpoll.NewIOWriter(m.zw.(netpoll.ReadWriter)).Write(b) } func (m *Conn) Release() error { return nil } func (m *Conn) Peek(i int) ([]byte, error) { b, err := m.zr.Peek(i) if err != nil || len(b) != i { if m.rtimeout <= 0 { // simulate timeout forever select {} } time.Sleep(m.rtimeout) 
return nil, errs.ErrTimeout } return b, err } func (m *Conn) Skip(n int) error { return m.zr.Skip(n) } func (m *Conn) ReadByte() (byte, error) { return m.zr.ReadByte() } func (m *Conn) Len() int { return m.zr.Len() } func (m *Conn) Malloc(n int) (buf []byte, err error) { m.wroteLen += n return m.zw.Malloc(n) } func (m *Conn) WriteBinary(b []byte) (n int, err error) { n, err = m.zw.WriteBinary(b) m.wroteLen += n return n, err } func (m *Conn) Flush() error { return m.zw.Flush() } func (m *Conn) WriterRecorder() Recorder { return &recorder{c: m, Reader: m.zw} } func (m *Conn) GetReadTimeout() time.Duration { return m.rtimeout } func (m *Conn) GetWriteTimeout() time.Duration { return m.wtimeout } type recorder struct { c *Conn network.Reader } func (r *recorder) WroteLen() int { return r.c.wroteLen } func (m *SlowReadConn) Peek(i int) ([]byte, error) { b, err := m.zr.Peek(i) if m.rtimeout > 0 { time.Sleep(m.rtimeout) } else { time.Sleep(100 * time.Millisecond) } if err != nil || len(b) != i { return nil, ErrReadTimeout } return b, err } func NewConn(source string) *Conn { zr := netpoll.NewReader(strings.NewReader(source)) zw := netpoll.NewReadWriter(&bytes.Buffer{}) return &Conn{ zr: zr, zw: zw, } } type BrokenConn struct { *Conn } func (o *BrokenConn) Peek(i int) ([]byte, error) { return nil, io.ErrUnexpectedEOF } func (o *BrokenConn) Read(b []byte) (n int, err error) { return 0, io.ErrUnexpectedEOF } func (o *BrokenConn) Flush() error { return errs.ErrConnectionClosed } func NewBrokenConn(source string) *BrokenConn { return &BrokenConn{Conn: NewConn(source)} } type OneTimeConn struct { isRead bool isFlushed bool contentLength int *Conn } func (o *OneTimeConn) Peek(n int) ([]byte, error) { if o.isRead { return nil, io.EOF } return o.Conn.Peek(n) } func (o *OneTimeConn) Skip(n int) error { if o.isRead { return io.EOF } o.contentLength -= n if o.contentLength == 0 { o.isRead = true } return o.Conn.Skip(n) } func (o *OneTimeConn) Flush() error { if o.isFlushed { return 
errs.ErrConnectionClosed } o.isFlushed = true return o.Conn.Flush() } func NewOneTimeConn(source string) *OneTimeConn { return &OneTimeConn{isRead: false, isFlushed: false, Conn: NewConn(source), contentLength: len(source)} } func NewSlowReadConn(source string) *SlowReadConn { return &SlowReadConn{Conn: NewConn(source)} } type ErrorReadConn struct { *Conn errorToReturn error } func NewErrorReadConn(err error) *ErrorReadConn { return &ErrorReadConn{ Conn: NewConn(""), errorToReturn: err, } } func (er *ErrorReadConn) Peek(n int) ([]byte, error) { return nil, er.errorToReturn } type SlowWriteConn struct { *Conn writeTimeout time.Duration } func (m *SlowWriteConn) SetWriteTimeout(t time.Duration) error { m.writeTimeout = t return nil } func NewSlowWriteConn(source string) *SlowWriteConn { return &SlowWriteConn{NewConn(source), 0} } func (m *SlowWriteConn) Flush() error { err := m.zw.Flush() if err == nil { time.Sleep(m.writeTimeout) return ErrWriteTimeout } return err } func (m *Conn) Close() error { return nil } func (m *Conn) LocalAddr() net.Addr { return nil } func (m *Conn) RemoteAddr() net.Addr { return nil } func (m *Conn) SetDeadline(t time.Time) error { m.rtimeout = -time.Since(t) m.wtimeout = m.rtimeout return nil } func (m *Conn) SetReadDeadline(t time.Time) error { m.rtimeout = -time.Since(t) return nil } func (m *Conn) SetWriteDeadline(t time.Time) error { panic("implement me") } func (m *Conn) Reader() network.Reader { return m.zr } func (m *Conn) Writer() network.Writer { return m.zw } func (m *Conn) IsActive() bool { panic("implement me") } func (m *Conn) SetIdleTimeout(timeout time.Duration) error { return nil } func (m *Conn) SetReadTimeout(t time.Duration) error { m.rtimeout = t return nil } func (m *Conn) SetOnRequest(on netpoll.OnRequest) error { panic("implement me") } func (m *Conn) AddCloseCallback(callback netpoll.CloseCallback) error { panic("implement me") } type StreamConn struct { HasReleased bool Data []byte } func NewStreamConn() 
*StreamConn { return &StreamConn{ Data: make([]byte, 1<<15, 1<<16), } } func (m *StreamConn) Peek(n int) ([]byte, error) { if len(m.Data) >= n { return m.Data[:n], nil } if n == 1 { m.Data = m.Data[:cap(m.Data)] return m.Data[:1], nil } return nil, errs.NewPublic("not enough data") } func (m *StreamConn) Skip(n int) error { if len(m.Data) >= n { m.Data = m.Data[n:] return nil } return errs.NewPublic("not enough data") } func (m *StreamConn) Release() error { m.HasReleased = true return nil } func (m *StreamConn) Len() int { return len(m.Data) } func (m *StreamConn) ReadByte() (byte, error) { panic("implement me") } func (m *StreamConn) ReadBinary(n int) (p []byte, err error) { panic("implement me") } func DialerFun(addr string) (network.Conn, error) { return NewConn(""), nil } type MockWriter struct { w network.Writer MockMalloc func(n int) (buf []byte, err error) MockWriteBinary func(b []byte) (n int, err error) MockFlush func() error } func NewMockWriter(w network.Writer) *MockWriter { return &MockWriter{w: w} } func (m *MockWriter) Malloc(n int) (buf []byte, err error) { if m.MockMalloc != nil { return m.MockMalloc(n) } return m.w.Malloc(n) } func (m *MockWriter) WriteBinary(b []byte) (n int, err error) { if m.MockWriteBinary != nil { return m.MockWriteBinary(b) } return m.w.WriteBinary(b) } func (m *MockWriter) Flush() error { if m.MockFlush != nil { return m.MockFlush() } return m.w.Flush() } type TLSConn struct { network.Conn HandshakeErr error } var _ network.ConnTLSer = (*TLSConn)(nil) func (c *TLSConn) Handshake() error { return c.HandshakeErr } func (c *TLSConn) ConnectionState() tls.ConnectionState { return tls.ConnectionState{} } func NewTLSConn(conn network.Conn) *TLSConn { return &TLSConn{Conn: conn} } ================================================ FILE: pkg/common/test/mock/network_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you 
may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package mock

import (
	"context"
	"io"
	"testing"
	"time"

	errs "github.com/cloudwego/hertz/pkg/common/errors"
	"github.com/cloudwego/hertz/pkg/common/test/assert"
	"github.com/cloudwego/netpoll"
)

// TestConn exercises Conn's read path, write path (via WriterRecorder) and the
// methods that deliberately panic. Assertions within each subtest depend on
// the order of operations on the shared connection.
func TestConn(t *testing.T) {
	t.Run("TestReader", func(t *testing.T) {
		s1 := "abcdef4343"
		conn1 := NewConn(s1)
		assert.Nil(t, conn1.SetWriteTimeout(1))
		err := conn1.SetReadDeadline(time.Now().Add(time.Millisecond * 100))
		assert.DeepEqual(t, nil, err)
		err = conn1.SetReadTimeout(time.Millisecond * 100)
		assert.DeepEqual(t, nil, err)
		assert.DeepEqual(t, time.Millisecond*100, conn1.GetReadTimeout())

		// Peek Skip Read
		b, _ := conn1.Peek(1)
		assert.DeepEqual(t, []byte{'a'}, b)
		conn1.Skip(1)
		readByte, _ := conn1.ReadByte()
		assert.DeepEqual(t, byte('b'), readByte)

		p := make([]byte, 100)
		n, err := conn1.Read(p)
		assert.DeepEqual(t, nil, err)
		assert.DeepEqual(t, s1[2:], string(p[:n]))

		// The source is fully consumed, so Peek sleeps for the 100ms read
		// timeout and then reports errs.ErrTimeout.
		_, err = conn1.Peek(1)
		assert.DeepEqual(t, errs.ErrTimeout, err)

		conn2 := NewConn(s1)
		p, _ = conn2.ReadBinary(len(s1))
		assert.DeepEqual(t, s1, string(p))
		assert.DeepEqual(t, 0, conn2.Len())
		// Reader
		assert.DeepEqual(t, conn2.zr, conn2.Reader())
	})

	t.Run("TestReadWriter", func(t *testing.T) {
		s1 := "abcdef4343"
		conn := NewConn(s1)
		p, err := conn.ReadBinary(len(s1))
		assert.DeepEqual(t, nil, err)
		assert.DeepEqual(t, s1, string(p))

		wr := conn.WriterRecorder()
		s2 := "efghljk"
		// WriteBinary
		n, err := conn.WriteBinary([]byte(s2))
		assert.DeepEqual(t, nil, err)
		assert.DeepEqual(t, len(s2), n)
		assert.DeepEqual(t, len(s2), wr.WroteLen())

		// Flush
		// Written data is not visible to the recorder until Flush is called.
		p, _ = wr.ReadBinary(len(s2))
		assert.DeepEqual(t, len(p), 0)
		conn.Flush()
		p, _ = wr.ReadBinary(len(s2))
		assert.DeepEqual(t, s2, string(p))

		// Write
		s3 := "foobarbaz"
		n, err = conn.Write([]byte(s3))
		assert.DeepEqual(t, nil, err)
		assert.DeepEqual(t, len(s3), n)
		p, _ = wr.ReadBinary(len(s3))
		assert.DeepEqual(t, s3, string(p))

		// Malloc
		buf, _ := conn.Malloc(10)
		assert.DeepEqual(t, 10, len(buf))
		// Writer
		assert.DeepEqual(t, conn.zw, conn.Writer())

		_, err = DialerFun("")
		assert.DeepEqual(t, nil, err)
	})

	t.Run("TestNotImplement", func(t *testing.T) {
		conn := NewConn("")
		t1 := time.Now().Add(time.Millisecond)
		du1 := time.Second
		assert.DeepEqual(t, nil, conn.Release())
		assert.DeepEqual(t, nil, conn.Close())
		assert.DeepEqual(t, nil, conn.LocalAddr())
		assert.DeepEqual(t, nil, conn.RemoteAddr())
		assert.DeepEqual(t, nil, conn.SetIdleTimeout(du1))
		// These methods are stubs that panic with "implement me".
		assert.Panic(t, func() {
			conn.SetWriteDeadline(t1)
		})
		assert.Panic(t, func() {
			conn.IsActive()
		})
		assert.Panic(t, func() {
			conn.SetOnRequest(func(ctx context.Context, connection netpoll.Connection) error {
				return nil
			})
		})
		assert.Panic(t, func() {
			conn.AddCloseCallback(func(connection netpoll.Connection) error {
				return nil
			})
		})
	})
}

// TestSlowConn checks that the slow connections report their timeout errors.
func TestSlowConn(t *testing.T) {
	t.Run("TestSlowReadConn", func(t *testing.T) {
		s1 := "abcdefg"
		conn := NewSlowReadConn(s1)
		assert.Nil(t, conn.SetWriteTimeout(1))
		assert.Nil(t, conn.SetReadTimeout(1))
		assert.DeepEqual(t, time.Duration(1), conn.GetReadTimeout())

		b, err := conn.Peek(4)
		assert.DeepEqual(t, nil, err)
		assert.DeepEqual(t, s1[:4], string(b))
		conn.Skip(len(s1))
		// Nothing left to peek: SlowReadConn.Peek returns ErrReadTimeout.
		_, err = conn.Peek(1)
		assert.DeepEqual(t, ErrReadTimeout, err)
		_, err = SlowReadDialer("")
		assert.DeepEqual(t, nil, err)
	})

	t.Run("TestSlowWriteConn", func(t *testing.T) {
		conn, err := SlowWriteDialer("")
		assert.DeepEqual(t, nil, err)
		conn.SetWriteTimeout(time.Millisecond * 100)
		// A successful underlying flush is still reported as ErrWriteTimeout.
		err = conn.Flush()
		assert.DeepEqual(t, ErrWriteTimeout, err)
	})
}

// TestStreamConn checks StreamConn's Peek/Skip bookkeeping and Release flag.
func TestStreamConn(t *testing.T) {
	t.Run("TestStreamConn", func(t *testing.T) {
		conn := NewStreamConn()
		_, err := conn.Peek(10)
		assert.DeepEqual(t, nil, err)
		conn.Skip(conn.Len())
		assert.DeepEqual(t, 0, conn.Len())
		_, err = conn.Peek(10)
		assert.DeepEqual(t, "not enough data", err.Error())
		// Peek(1) on exhausted data re-extends Data to its capacity.
		_, err = conn.Peek(1)
		assert.DeepEqual(t, nil, err)
		assert.DeepEqual(t, cap(conn.Data), conn.Len())
		err = conn.Skip(conn.Len() + 1)
		assert.DeepEqual(t, "not enough data", err.Error())
		err = conn.Release()
		assert.DeepEqual(t, nil, err)
		assert.DeepEqual(t, true, conn.HasReleased)
	})

	t.Run("TestNotImplement", func(t *testing.T) {
		conn := NewStreamConn()
		assert.Panic(t, func() {
			conn.ReadByte()
		})
		assert.Panic(t, func() {
			conn.ReadBinary(10)
		})
	})
}

// TestBrokenConn_Flush: writes succeed but Flush always fails.
func TestBrokenConn_Flush(t *testing.T) {
	conn := NewBrokenConn("")
	n, err := conn.Writer().WriteBinary([]byte("Foo"))
	assert.DeepEqual(t, 3, n)
	assert.Nil(t, err)
	assert.DeepEqual(t, errs.ErrConnectionClosed, conn.Flush())
}

// TestBrokenConn_Peek: Peek fails even though the source has data.
func TestBrokenConn_Peek(t *testing.T) {
	conn := NewBrokenConn("Foo")
	buf, err := conn.Peek(3)
	assert.Nil(t, buf)
	assert.DeepEqual(t, io.ErrUnexpectedEOF, err)
}

// TestOneTimeConn_Flush: the first Flush succeeds, the second fails.
func TestOneTimeConn_Flush(t *testing.T) {
	conn := NewOneTimeConn("")
	n, err := conn.Writer().WriteBinary([]byte("Foo"))
	assert.DeepEqual(t, 3, n)
	assert.Nil(t, err)
	assert.Nil(t, conn.Flush())
	n, err = conn.Writer().WriteBinary([]byte("Bar"))
	assert.DeepEqual(t, 3, n)
	assert.Nil(t, err)
	assert.DeepEqual(t, errs.ErrConnectionClosed, conn.Flush())
}

// TestOneTimeConn_Skip: once the whole source has been skipped, Peek and Skip
// return io.EOF.
func TestOneTimeConn_Skip(t *testing.T) {
	conn := NewOneTimeConn("FooBar")
	buf, err := conn.Peek(3)
	assert.DeepEqual(t, "Foo", string(buf))
	assert.Nil(t, err)
	assert.Nil(t, conn.Skip(3))
	assert.DeepEqual(t, 3, conn.contentLength)

	buf, err = conn.Peek(3)
	assert.DeepEqual(t, "Bar", string(buf))
	assert.Nil(t, err)
	assert.Nil(t, conn.Skip(3))
	assert.DeepEqual(t, 0, conn.contentLength)

	buf, err = conn.Peek(3)
	assert.DeepEqual(t, 0, len(buf))
	assert.DeepEqual(t, io.EOF, err)
	assert.DeepEqual(t, io.EOF, conn.Skip(3))
	assert.DeepEqual(t, 0, conn.contentLength)
}
================================================ FILE: pkg/common/test/mock/reader.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package mock import ( "bufio" "bytes" "io" ) // ZeroCopyReader is used to create ZeroCopyReader for testing. // // NOTE: In principle, ut should use the zcReader created by netpoll.NewReader() for mock testing, // but because zcReader does not implement the io.Reader interface, the test requirements of // io.Reader involved are replaced with MockZeroCopyReader type ZeroCopyReader struct { *bufio.Reader } func (m ZeroCopyReader) Peek(n int) ([]byte, error) { b, err := m.Reader.Peek(n) // if n is bigger than the buffer in m.Reader, // it will only return bufio.ErrBufferFull even if the underline reader return io.EOF. // so we make another Peek to get the real error. 
// for more info: https://github.com/golang/go/issues/50569 if err == bufio.ErrBufferFull && len(b) == 0 { return m.Reader.Peek(1) } return b, err } func (m ZeroCopyReader) Skip(n int) (err error) { _, err = m.Reader.Discard(n) return } func (m ZeroCopyReader) Release() (err error) { return nil } func (m ZeroCopyReader) Len() (length int) { return m.Reader.Buffered() } func (m ZeroCopyReader) ReadBinary(n int) (p []byte, err error) { panic("implement me") } func NewZeroCopyReader(r string) ZeroCopyReader { br := bufio.NewReaderSize(bytes.NewBufferString(r), len(r)) return ZeroCopyReader{br} } func NewLimitReader(r *bytes.Buffer) io.LimitedReader { return io.LimitedReader{ R: r, N: int64(r.Len()), } } type EOFReader struct{} func (e *EOFReader) Peek(n int) ([]byte, error) { return []byte{}, io.EOF } func (e *EOFReader) Skip(n int) error { return nil } func (e *EOFReader) Release() error { return nil } func (e *EOFReader) Len() int { return 0 } func (e *EOFReader) ReadByte() (byte, error) { return ' ', io.EOF } func (e *EOFReader) ReadBinary(n int) (p []byte, err error) { return p, io.EOF } func (e *EOFReader) Read(p []byte) (n int, err error) { return 0, io.EOF } ================================================ FILE: pkg/common/test/mock/reader_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package mock import ( "bufio" "io" "testing" "testing/iotest" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestZeroCopyReader(t *testing.T) { // raw r := "abcdef4343" zr := NewZeroCopyReader(r) rs := readBytes(zr.Reader) assert.DeepEqual(t, rs, r) // peek zr = NewZeroCopyReader(r) s, err := zr.Peek(1) assert.DeepEqual(t, nil, err) assert.DeepEqual(t, "a", string(s)) s, err = zr.Peek(4) assert.DeepEqual(t, nil, err) assert.DeepEqual(t, "abcd", string(s)) // https://github.com/golang/go/issues/50569 ezr := NewZeroCopyReader("") s, err = ezr.Peek(32) assert.DeepEqual(t, io.EOF, err) assert.DeepEqual(t, "", string(s)) // skip err = zr.Skip(1) assert.DeepEqual(t, nil, err) s, err = zr.Peek(4) assert.DeepEqual(t, nil, err) assert.DeepEqual(t, "bcde", string(s)) // len assert.DeepEqual(t, len(r)-1, zr.Len()) assert.DeepEqual(t, nil, ezr.Release()) assert.Panic(t, func() { // not implement zr.ReadBinary(10) }) } func TestEOFReader(t *testing.T) { r := &EOFReader{} s, err := r.Peek(1) assert.DeepEqual(t, io.EOF, err) assert.DeepEqual(t, "", string(s)) assert.DeepEqual(t, nil, r.Skip(1)) assert.DeepEqual(t, 0, r.Len()) _, err = r.ReadByte() assert.DeepEqual(t, io.EOF, err) _, err = r.ReadBinary(10) assert.DeepEqual(t, io.EOF, err) _, err = r.Read(s) assert.DeepEqual(t, io.EOF, err) assert.DeepEqual(t, nil, r.Release()) } func readBytes(buf *bufio.Reader) string { var b [1000]byte nb := 0 for { c, err := buf.ReadByte() if err == io.EOF { break } if err == nil { b[nb] = c nb++ } else if err != iotest.ErrTimeout { panic("Data: " + err.Error()) } } return string(b[0:nb]) } ================================================ FILE: pkg/common/test/mock/writer.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package mock import "bytes" type ExtWriter struct { tmp []byte Buf *bytes.Buffer IsFinal *bool } func (m *ExtWriter) Write(p []byte) (n int, err error) { m.tmp = p return len(p), nil } func (m *ExtWriter) Flush() error { _, err := m.Buf.Write(m.tmp) return err } func (m *ExtWriter) Finalize() error { if !*m.IsFinal { *m.IsFinal = true } return nil } func (m *ExtWriter) SetBody(body []byte) { m.Buf.Reset() m.tmp = body } ================================================ FILE: pkg/common/test/mock/writer_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package mock import ( "bytes" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestExtWriter(t *testing.T) { b1 := []byte("abcdef4343") buf := new(bytes.Buffer) isFinal := false w := &ExtWriter{ Buf: buf, IsFinal: &isFinal, } // write n, err := w.Write(b1) assert.DeepEqual(t, nil, err) assert.DeepEqual(t, len(b1), n) // flush err = w.Flush() assert.DeepEqual(t, nil, err) assert.DeepEqual(t, b1, w.Buf.Bytes()) // setbody b2 := []byte("abc") w.SetBody(b2) err = w.Flush() assert.DeepEqual(t, nil, err) assert.DeepEqual(t, b2, w.Buf.Bytes()) w.Finalize() assert.DeepEqual(t, true, *(w.IsFinal)) } ================================================ FILE: pkg/common/testdata/conf/p_s_m.yaml ================================================ Develop: ServicePort: "6789" DebugPort: "6790" EnablePprof: true LogLevel: "debug" LogInterval: "hour" EnableMetrics: false ConsoleLog: true AgentLog: true FileLog: true Product: ServicePort: "6789" DebugPort: "6790" EnablePprof: true LogLevel: "info" LogInterval: "hour" EnableMetrics: true ConsoleLog: false AgentLog: true FileLog: true ================================================ FILE: pkg/common/testdata/proto/test.pb.go ================================================ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.28.0 // protoc v3.19.3 // source: test.proto package proto import ( reflect "reflect" sync "sync" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( // Verify that this generated code is sufficiently up-to-date. _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) // Verify that runtime/protoimpl is sufficiently up-to-date. _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) type FOO int32 const ( FOO_X FOO = 17 ) // Enum value maps for FOO. 
var (
	FOO_name = map[int32]string{
		17: "X",
	}
	FOO_value = map[string]int32{
		"X": 17,
	}
)

func (x FOO) Enum() *FOO {
	p := new(FOO)
	*p = x
	return p
}

func (x FOO) String() string {
	return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}

func (FOO) Descriptor() protoreflect.EnumDescriptor {
	return file_test_proto_enumTypes[0].Descriptor()
}

func (FOO) Type() protoreflect.EnumType {
	return &file_test_proto_enumTypes[0]
}

func (x FOO) Number() protoreflect.EnumNumber {
	return protoreflect.EnumNumber(x)
}

// Deprecated: Do not use.
func (x *FOO) UnmarshalJSON(b []byte) error {
	num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b)
	if err != nil {
		return err
	}
	*x = FOO(num)
	return nil
}

// Deprecated: Use FOO.Descriptor instead.
func (FOO) EnumDescriptor() ([]byte, []int) {
	return file_test_proto_rawDescGZIP(), []int{0}
}

type Test struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Label         *string             `protobuf:"bytes,1,req,name=label" json:"label,omitempty"`
	Type          *int32              `protobuf:"varint,2,opt,name=type,def=77" json:"type,omitempty"`
	Reps          []int64             `protobuf:"varint,3,rep,name=reps" json:"reps,omitempty"`
	Optionalgroup *Test_OptionalGroup `protobuf:"group,4,opt,name=OptionalGroup,json=optionalgroup" json:"optionalgroup,omitempty"`
}

// Default values for Test fields.
const (
	Default_Test_Type = int32(77)
)

func (x *Test) Reset() {
	*x = Test{}
	if protoimpl.UnsafeEnabled {
		mi := &file_test_proto_msgTypes[0]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

func (x *Test) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Test) ProtoMessage() {}

func (x *Test) ProtoReflect() protoreflect.Message {
	mi := &file_test_proto_msgTypes[0]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Test.ProtoReflect.Descriptor instead.
func (*Test) Descriptor() ([]byte, []int) {
	return file_test_proto_rawDescGZIP(), []int{0}
}

func (x *Test) GetLabel() string {
	if x != nil && x.Label != nil {
		return *x.Label
	}
	return ""
}

func (x *Test) GetType() int32 {
	if x != nil && x.Type != nil {
		return *x.Type
	}
	return Default_Test_Type
}

func (x *Test) GetReps() []int64 {
	if x != nil {
		return x.Reps
	}
	return nil
}

func (x *Test) GetOptionalgroup() *Test_OptionalGroup {
	if x != nil {
		return x.Optionalgroup
	}
	return nil
}

type TestStruct struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	Body []byte `protobuf:"bytes,1,opt,name=body" json:"body,omitempty"`
}

func (x *TestStruct) Reset() {
	*x = TestStruct{}
	if protoimpl.UnsafeEnabled {
		mi := &file_test_proto_msgTypes[1]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

func (x *TestStruct) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*TestStruct) ProtoMessage() {}

func (x *TestStruct) ProtoReflect() protoreflect.Message {
	mi := &file_test_proto_msgTypes[1]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use TestStruct.ProtoReflect.Descriptor instead.
func (*TestStruct) Descriptor() ([]byte, []int) {
	return file_test_proto_rawDescGZIP(), []int{1}
}

func (x *TestStruct) GetBody() []byte {
	if x != nil {
		return x.Body
	}
	return nil
}

type Test_OptionalGroup struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	RequiredField *string `protobuf:"bytes,5,req,name=RequiredField" json:"RequiredField,omitempty"`
}

func (x *Test_OptionalGroup) Reset() {
	*x = Test_OptionalGroup{}
	if protoimpl.UnsafeEnabled {
		mi := &file_test_proto_msgTypes[2]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

func (x *Test_OptionalGroup) String() string {
	return protoimpl.X.MessageStringOf(x)
}

func (*Test_OptionalGroup) ProtoMessage() {}

func (x *Test_OptionalGroup) ProtoReflect() protoreflect.Message {
	mi := &file_test_proto_msgTypes[2]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Test_OptionalGroup.ProtoReflect.Descriptor instead.
func (*Test_OptionalGroup) Descriptor() ([]byte, []int) {
	return file_test_proto_rawDescGZIP(), []int{0, 0}
}

func (x *Test_OptionalGroup) GetRequiredField() string {
	if x != nil && x.RequiredField != nil {
		return *x.RequiredField
	}
	return ""
}

var File_test_proto protoreflect.FileDescriptor

// NOTE(review): the trailing option bytes of this serialized descriptor
// (0x42, 0x10, 0x5a, 0x0e, ...) encode go_package "./protoexample", whereas
// test.proto below declares option go_package = "./proto". This generated
// file may predate the current test.proto; regenerate with protoc-gen-go
// rather than editing by hand.
var file_test_proto_rawDesc = []byte{
	0x0a, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x70, 0x72,
	0x6f, 0x74, 0x6f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x22, 0xc7, 0x01, 0x0a, 0x04, 0x54,
	0x65, 0x73, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x02,
	0x28, 0x09, 0x52, 0x05, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x12, 0x16, 0x0a, 0x04, 0x74, 0x79, 0x70,
	0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x3a, 0x02, 0x37, 0x37, 0x52, 0x04, 0x74, 0x79, 0x70,
	0x65, 0x12, 0x12, 0x0a, 0x04, 0x72, 0x65, 0x70, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x03, 0x52,
	0x04, 0x72, 0x65, 0x70, 0x73, 0x12, 0x46, 0x0a, 0x0d, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61,
	0x6c, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0a, 0x32, 0x20, 0x2e, 0x70,
	0x72, 0x6f, 0x74, 0x6f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x54, 0x65, 0x73, 0x74,
	0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x0d,
	0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x1a, 0x35, 0x0a,
	0x0d, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x24,
	0x0a, 0x0d, 0x52, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x18,
	0x05, 0x20, 0x02, 0x28, 0x09, 0x52, 0x0d, 0x52, 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x64, 0x46,
	0x69, 0x65, 0x6c, 0x64, 0x22, 0x20, 0x0a, 0x0a, 0x54, 0x65, 0x73, 0x74, 0x53, 0x74, 0x72, 0x75,
	0x63, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c,
	0x52, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x2a, 0x0c, 0x0a, 0x03, 0x46, 0x4f, 0x4f, 0x12, 0x05, 0x0a,
	0x01, 0x58, 0x10, 0x11, 0x42, 0x10, 0x5a, 0x0e, 0x2e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x65,
	0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65,
}

var (
	file_test_proto_rawDescOnce sync.Once
	file_test_proto_rawDescData = file_test_proto_rawDesc
)

func file_test_proto_rawDescGZIP() []byte {
	file_test_proto_rawDescOnce.Do(func() {
		file_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_test_proto_rawDescData)
	})
	return file_test_proto_rawDescData
}

var (
	file_test_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
	file_test_proto_msgTypes  = make([]protoimpl.MessageInfo, 3)
	file_test_proto_goTypes   = []interface{}{
		(FOO)(0),                   // 0: proto.FOO
		(*Test)(nil),               // 1: proto.Test
		(*TestStruct)(nil),         // 2: proto.TestStruct
		(*Test_OptionalGroup)(nil), // 3: proto.Test.OptionalGroup
	}
)

var file_test_proto_depIdxs = []int32{
	3, // 0: proto.Test.optionalgroup:type_name -> proto.Test.OptionalGroup
	1, // [1:1] is the sub-list for method output_type
	1, // [1:1] is the sub-list for method input_type
	1, // [1:1] is the sub-list for extension type_name
	1, // [1:1] is the sub-list for extension extendee
	0, // [0:1] is the sub-list for field type_name
}

func init() { file_test_proto_init() }
func file_test_proto_init() {
	if File_test_proto != nil {
		return
	}
	if !protoimpl.UnsafeEnabled {
		file_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*Test); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_test_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*TestStruct); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_test_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*Test_OptionalGroup); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
	}
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: file_test_proto_rawDesc,
			NumEnums:      1,
			NumMessages:   3,
			NumExtensions: 0,
			NumServices:   0,
		},
		GoTypes:           file_test_proto_goTypes,
		DependencyIndexes: file_test_proto_depIdxs,
		EnumInfos:         file_test_proto_enumTypes,
		MessageInfos:      file_test_proto_msgTypes,
	}.Build()
	File_test_proto = out.File
	file_test_proto_rawDesc = nil
	file_test_proto_goTypes = nil
	file_test_proto_depIdxs = nil
}

================================================
FILE: pkg/common/testdata/proto/test.proto
================================================
syntax = "proto2";
package protoexample;

option go_package = "./proto";

enum FOO {X=17;};

message Test {
  required string label = 1;
  optional int32 type = 2[default=77];
  repeated int64 reps = 3;
  optional group OptionalGroup = 4{
    required string RequiredField = 5;
  }
}

message TestStruct {
  optional bytes body = 1;
}

================================================
FILE: pkg/common/testdata/template/htmltemplate.html
================================================

Date: {[{.now | formatAsDate}]}

================================================ FILE: pkg/common/testdata/template/index.tmpl ================================================

{[{ .title }]}

================================================ FILE: pkg/common/testdata/test.txt ================================================ hello world! ================================================ FILE: pkg/common/timer/doc.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // The files in timer package are forked from fasthttp[github.com/valyala/fasthttp], // and we keep the original Copyright[Copyright 2015 fasthttp authors] and License of fasthttp for those files. // We also need to modify as we need, the modifications are Copyright of 2022 CloudWeGo Authors. // Thanks for fasthttp authors! Below is the source code information: // Repo: github.com/valyala/fasthttp // Forked Version: v1.36.0 package timer ================================================ FILE: pkg/common/timer/timer.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package timer import ( "sync" "time" ) func initTimer(t *time.Timer, timeout time.Duration) *time.Timer { if t == nil { return time.NewTimer(timeout) } if t.Reset(timeout) { panic("BUG: active timer trapped into initTimer()") } return t } func stopTimer(t *time.Timer) { if !t.Stop() { // Collect possibly added time from the channel // if timer has been stopped and nobody collected its value. select { case <-t.C: default: } } } // AcquireTimer returns a time.Timer from the pool and updates it to // send the current time on its channel after at least timeout. 
// // The returned Timer may be returned to the pool with ReleaseTimer // when no longer needed. This allows reducing GC load. func AcquireTimer(timeout time.Duration) *time.Timer { v := timerPool.Get() if v == nil { return time.NewTimer(timeout) } t := v.(*time.Timer) initTimer(t, timeout) return t } // ReleaseTimer returns the time.Timer acquired via AcquireTimer to the pool // and prevents the Timer from firing. // // Do not access the released time.Timer or read from its channel otherwise // data races may occur. func ReleaseTimer(t *time.Timer) { stopTimer(t) timerPool.Put(t) } var timerPool sync.Pool ================================================ FILE: pkg/common/timer/timer_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package timer import ( "testing" "time" ) // test initTimer function func TestTimerInitTimer(t *testing.T) { // test nil Timer var nilTimer *time.Timer resNilTime := initTimer(nilTimer, 2*time.Second) if resNilTime == nil { t.Fatalf("Unexpected a nil. 
Expecting a Timer.") } // test the panic panicTimer := time.NewTimer(1 * time.Second) resPanicTimer := wrapInitTimer(panicTimer, 2*time.Second) if resPanicTimer != -1 { t.Fatalf("Expecting a panic for Timer, but nil") } // sleep enough time to test next timer time.Sleep(3 * time.Second) } func wrapInitTimer(t *time.Timer, timeout time.Duration) (ret int) { defer func() { if err := recover(); err != nil { ret = -1 } }() res := initTimer(t, timeout) if res != nil { ret = 1 } return ret } func TestTimerStopTimer(t *testing.T) { normalTimer := time.NewTimer(3 * time.Second) stopTimer(normalTimer) if normalTimer.Stop() { t.Fatalf("Expecting timer stopped, but it doesn't") } } func TestTimerAcquireTimer(t *testing.T) { normalTimer := AcquireTimer(2 * time.Second) if normalTimer == nil { t.Fatalf("Unexpected nil, expecting a timer") } ReleaseTimer(normalTimer) } func TestTimerReleaseTimer(t *testing.T) { normalTimer := AcquireTimer(2 * time.Second) ReleaseTimer(normalTimer) if normalTimer.Stop() { t.Fatalf("Expecting the timer is released.") } } ================================================ FILE: pkg/common/tracer/stats/event.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package stats import ( "sync" "sync/atomic" "github.com/cloudwego/hertz/pkg/common/errors" ) // EventIndex indicates a unique event. type EventIndex int // Level sets the record level. type Level int // Event levels. 
// Levels are ordered LevelDisabled < LevelBase < LevelDetailed (iota order).
// traceinfo's httpStats.Record drops an event whose Level() is greater than
// the collector's configured level.
const (
	LevelDisabled Level = iota
	LevelBase
	LevelDetailed
)

// Event is used to indicate a specific event.
type Event interface {
	Index() EventIndex
	Level() Level
}

// event is the value-type Event returned by newEvent.
type event struct {
	idx   EventIndex
	level Level
}

// Index implements the Event interface.
func (e event) Index() EventIndex {
	return e.idx
}

// Level implements the Event interface.
func (e event) Level() Level {
	return e.level
}

// Indices of the predefined events. Index 0 is consumed by the blank
// identifier, so predefined events start at 1. The order here fixes each
// event's index, so entries must not be reordered.
const (
	_ EventIndex = iota
	serverHandleStart
	serverHandleFinish
	httpStart
	httpFinish
	readHeaderStart
	readHeaderFinish
	readBodyStart
	readBodyFinish
	writeStart
	writeFinish
	// predefinedEventNum is one past the last predefined index. It seeds
	// maxEventNum and therefore the first index given to a user-defined event.
	predefinedEventNum
)

// Predefined events.
var (
	HTTPStart  = newEvent(httpStart, LevelBase)
	HTTPFinish = newEvent(httpFinish, LevelBase)

	ServerHandleStart  = newEvent(serverHandleStart, LevelDetailed)
	ServerHandleFinish = newEvent(serverHandleFinish, LevelDetailed)
	ReadHeaderStart    = newEvent(readHeaderStart, LevelDetailed)
	ReadHeaderFinish   = newEvent(readHeaderFinish, LevelDetailed)
	ReadBodyStart      = newEvent(readBodyStart, LevelDetailed)
	ReadBodyFinish     = newEvent(readBodyFinish, LevelDetailed)
	WriteStart         = newEvent(writeStart, LevelDetailed)
	WriteFinish        = newEvent(writeFinish, LevelDetailed)
)

// errors
var (
	ErrNotAllowed = errors.NewPublic("event definition is not allowed after initialization")
	ErrDuplicated = errors.NewPublic("event name is already defined")
)

var (
	// lock guards userDefined and maxEventNum.
	lock sync.RWMutex
	// inited becomes 1 once FinishInitialization runs; accessed atomically.
	inited int32
	// userDefined maps names passed to DefineNewEvent to their events.
	userDefined = make(map[string]Event)
	// maxEventNum counts the event indices handed out so far, including the
	// reserved index 0.
	maxEventNum = int(predefinedEventNum)
)

// FinishInitialization freezes all events defined and prevents further definitions to be added.
// After it is called, DefineNewEvent returns ErrNotAllowed.
func FinishInitialization() {
	atomic.StoreInt32(&inited, 1)
}

// DefineNewEvent allows user to add event definitions during program initialization.
func DefineNewEvent(name string, level Level) (Event, error) { if atomic.LoadInt32(&inited) == 1 { return nil, ErrNotAllowed } lock.Lock() defer lock.Unlock() evt, exist := userDefined[name] if exist { return evt, ErrDuplicated } userDefined[name] = newEvent(EventIndex(maxEventNum), level) maxEventNum++ return userDefined[name], nil } // MaxEventNum returns the number of event defined. func MaxEventNum() int { lock.RLock() defer lock.RUnlock() return maxEventNum } func newEvent(idx EventIndex, level Level) Event { return event{ idx: idx, level: level, } } ================================================ FILE: pkg/common/tracer/stats/event_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package stats import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestDefineNewEvent(t *testing.T) { num0 := MaxEventNum() event1, err1 := DefineNewEvent("myevent", LevelDetailed) num1 := MaxEventNum() assert.Assert(t, err1 == nil) assert.Assert(t, event1 != nil) assert.Assert(t, num1 == num0+1) assert.Assert(t, event1.Level() == LevelDetailed) event2, err2 := DefineNewEvent("myevent", LevelBase) num2 := MaxEventNum() assert.Assert(t, err2 == ErrDuplicated) assert.Assert(t, event2 == event1) assert.Assert(t, num2 == num1) assert.Assert(t, event2.Level() == LevelDetailed) FinishInitialization() event3, err3 := DefineNewEvent("another", LevelDetailed) num3 := MaxEventNum() assert.Assert(t, err3 == ErrNotAllowed) assert.Assert(t, event3 == nil) assert.Assert(t, num3 == num1) } ================================================ FILE: pkg/common/tracer/stats/status.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package stats // Status is used to indicate the status of an event. type Status int8 // Predefined status. 
const ( StatusInfo Status = 1 StatusWarn Status = 2 StatusError Status = 3 ) ================================================ FILE: pkg/common/tracer/traceinfo/httpstats.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package traceinfo import ( "sync" "time" "github.com/cloudwego/hertz/pkg/common/tracer/stats" ) var _ HTTPStats = (*httpStats)(nil) var ( eventPool sync.Pool once sync.Once maxEventNum int ) type event struct { event stats.Event status stats.Status info string time time.Time } // Event implements the Event interface. func (e *event) Event() stats.Event { return e.event } // Status implements the Event interface. func (e *event) Status() stats.Status { return e.status } // Info implements the Event interface. func (e *event) Info() string { return e.info } // Time implements the Event interface. func (e *event) Time() time.Time { return e.time } // IsNil implements the Event interface. func (e *event) IsNil() bool { return e == nil } func newEvent() interface{} { return &event{} } func (e *event) zero() { e.event = nil e.status = 0 e.info = "" e.time = time.Time{} } // Recycle reuses the event. func (e *event) Recycle() { e.zero() eventPool.Put(e) } type httpStats struct { sync.RWMutex level stats.Level eventMap []Event sendSize int recvSize int err error panicErr interface{} } func init() { eventPool.New = newEvent } // Record implements the HTTPStats interface. 
func (h *httpStats) Record(e stats.Event, status stats.Status, info string) { if e.Level() > h.level { return } eve := eventPool.Get().(*event) eve.event = e eve.status = status eve.info = info eve.time = time.Now() idx := e.Index() h.Lock() h.eventMap[idx] = eve h.Unlock() } // SendSize implements the HTTPStats interface. func (h *httpStats) SendSize() int { return h.sendSize } // RecvSize implements the HTTPStats interface. func (h *httpStats) RecvSize() int { return h.recvSize } // Error implements the HTTPStats interface. func (h *httpStats) Error() error { return h.err } // Panicked implements the HTTPStats interface. func (h *httpStats) Panicked() (bool, interface{}) { return h.panicErr != nil, h.panicErr } // GetEvent implements the HTTPStats interface. func (h *httpStats) GetEvent(e stats.Event) Event { idx := e.Index() h.RLock() evt := h.eventMap[idx] h.RUnlock() if evt == nil || evt.IsNil() { return nil } return evt } // Level implements the HTTPStats interface. func (h *httpStats) Level() stats.Level { return h.level } // SetSendSize sets send size. func (h *httpStats) SetSendSize(size int) { h.sendSize = size } // SetRecvSize sets recv size. func (h *httpStats) SetRecvSize(size int) { h.recvSize = size } // SetError sets error. func (h *httpStats) SetError(err error) { h.err = err } // SetPanicked sets if panicked. func (h *httpStats) SetPanicked(x interface{}) { h.panicErr = x } // SetLevel sets the level. func (h *httpStats) SetLevel(level stats.Level) { h.level = level } // Reset resets the stats. func (h *httpStats) Reset() { h.err = nil h.panicErr = nil h.recvSize = 0 h.sendSize = 0 for i := range h.eventMap { if h.eventMap[i] != nil { h.eventMap[i].(*event).Recycle() h.eventMap[i] = nil } } } // ImmutableView restricts the httpStats into a read-only traceinfo.HTTPStats. func (h *httpStats) ImmutableView() HTTPStats { return h } // NewHTTPStats creates a new HTTPStats. 
func NewHTTPStats() HTTPStats { once.Do(func() { stats.FinishInitialization() maxEventNum = stats.MaxEventNum() }) return &httpStats{ eventMap: make([]Event, maxEventNum), } } ================================================ FILE: pkg/common/tracer/traceinfo/interface.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package traceinfo import ( "time" "github.com/cloudwego/hertz/pkg/common/tracer/stats" ) // HTTPStats is used to collect statistics about the HTTP. type HTTPStats interface { Record(event stats.Event, status stats.Status, info string) GetEvent(event stats.Event) Event SendSize() int SetSendSize(size int) RecvSize() int SetRecvSize(size int) Error() error SetError(err error) Panicked() (bool, interface{}) SetPanicked(x interface{}) Level() stats.Level SetLevel(level stats.Level) Reset() } // Event is the abstraction of an event happened at a specific time. type Event interface { Event() stats.Event Status() stats.Status Info() string Time() time.Time IsNil() bool } // TraceInfo contains the trace message in Hertz. type TraceInfo interface { Stats() HTTPStats Reset() } ================================================ FILE: pkg/common/tracer/traceinfo/traceinfo.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package traceinfo type traceInfo struct { stats HTTPStats } // Stats implements the HTTPInfo interface. func (r *traceInfo) Stats() HTTPStats { return r.stats } // Reset reuses the traceInfo. func (r *traceInfo) Reset() { r.stats.Reset() } // NewTraceInfo creates a new traceInfoImpl using the given information. func NewTraceInfo() TraceInfo { return &traceInfo{stats: NewHTTPStats()} } ================================================ FILE: pkg/common/tracer/tracer.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package tracer import ( "context" "github.com/cloudwego/hertz/pkg/app" ) // Tracer is executed at the start and finish of an HTTP. 
type Tracer interface { Start(ctx context.Context, c *app.RequestContext) context.Context Finish(ctx context.Context, c *app.RequestContext) } type Controller interface { Append(col Tracer) DoStart(ctx context.Context, c *app.RequestContext) context.Context DoFinish(ctx context.Context, c *app.RequestContext, err error) HasTracer() bool } ================================================ FILE: pkg/common/ut/context.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ut import ( "io" "io/ioutil" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route" ) // CreateUtRequestContext returns an app.RequestContext for testing purposes func CreateUtRequestContext(method, url string, body *Body, headers ...Header) *app.RequestContext { engine := route.NewEngine(config.NewOptions([]config.Option{})) return createUtRequestContext(engine, method, url, body, headers...) 
} func createUtRequestContext(engine *route.Engine, method, url string, body *Body, headers ...Header) *app.RequestContext { ctx := engine.NewContext() var r *protocol.Request if body != nil && body.Body != nil { r = protocol.NewRequest(method, url, body.Body) r.CopyTo(&ctx.Request) if engine.IsStreamRequestBody() || body.Len == -1 { ctx.Request.SetBodyStream(body.Body, body.Len) } else { buf, err := ioutil.ReadAll(&io.LimitedReader{R: body.Body, N: int64(body.Len)}) ctx.Request.SetBody(buf) if err != nil && err != io.EOF { panic(err) } } } else { r = protocol.NewRequest(method, url, nil) r.CopyTo(&ctx.Request) } for _, v := range headers { if ctx.Request.Header.Get(v.Key) != "" { ctx.Request.Header.Add(v.Key, v.Value) } else { ctx.Request.Header.Set(v.Key, v.Value) } } return ctx } ================================================ FILE: pkg/common/ut/context_test.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package ut import ( "bytes" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestCreateUtRequestContext(t *testing.T) { body := "1" method := "PUT" path := "/hey/dy" headerKey := "Connection" headerValue := "close" ctx := CreateUtRequestContext(method, path, &Body{bytes.NewBufferString(body), len(body)}, Header{headerKey, headerValue}) assert.DeepEqual(t, method, string(ctx.Method())) assert.DeepEqual(t, path, string(ctx.Path())) body1, err := ctx.Body() assert.DeepEqual(t, nil, err) assert.DeepEqual(t, body, string(body1)) assert.DeepEqual(t, headerValue, string(ctx.GetHeader(headerKey))) } ================================================ FILE: pkg/common/ut/request.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Package ut provides a convenient way to write unit test for the business logic. 
package ut import ( "context" "io" "github.com/cloudwego/hertz/pkg/route" ) // Header is a key-value pair indicating one http header type Header struct { Key string Value string } // Body is for setting Request.Body type Body struct { Body io.Reader Len int } // PerformRequest send a constructed request to given engine without network transporting // // # Url can be a standard relative URI or a simple absolute path // // If engine.streamRequestBody is true, it sets body as bodyStream // if not, it sets body as bodyBytes // // ResponseRecorder returned are flushed, which means its StatusCode is always set (default 200) // // See ./request_test.go for more examples func PerformRequest(engine *route.Engine, method, url string, body *Body, headers ...Header) *ResponseRecorder { ctx := createUtRequestContext(engine, method, url, body, headers...) engine.ServeHTTP(context.Background(), ctx) w := NewRecorder() h := w.Header() ctx.Response.Header.CopyTo(h) w.WriteHeader(ctx.Response.StatusCode()) w.Write(ctx.Response.Body()) w.Flush() return w } ================================================ FILE: pkg/common/ut/request_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package ut import ( "bytes" "context" "fmt" "testing" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/route" ) func newTestEngine() *route.Engine { opt := config.NewOptions([]config.Option{}) return route.NewEngine(opt) } func TestPerformRequest(t *testing.T) { router := newTestEngine() router.PUT("/hey/:user", func(ctx context.Context, c *app.RequestContext) { user := c.Param("user") if string(c.Request.Body()) == "1" { assert.DeepEqual(t, "close", c.Request.Header.Get("Connection")) c.Response.SetConnectionClose() c.JSON(consts.StatusCreated, map[string]string{"hi": user}) } else if string(c.Request.Body()) == "" { c.AbortWithMsg("unauthorized", consts.StatusUnauthorized) } else { assert.DeepEqual(t, "PUT /hey/dy HTTP/1.1\r\nContent-Type: application/x-www-form-urlencoded\r\nTransfer-Encoding: chunked\r\n\r\n", string(c.Request.Header.Header())) c.String(consts.StatusAccepted, "body:%v", string(c.Request.Body())) } }) router.GET("/her/header", func(ctx context.Context, c *app.RequestContext) { assert.DeepEqual(t, "application/json", string(c.GetHeader("Content-Type"))) assert.DeepEqual(t, 1, c.Request.Header.ContentLength()) assert.DeepEqual(t, "a", c.Request.Header.Get("dummy")) }) // valid user w := PerformRequest(router, "PUT", "/hey/dy", &Body{bytes.NewBufferString("1"), 1}, Header{"Connection", "close"}) resp := w.Result() assert.DeepEqual(t, consts.StatusCreated, resp.StatusCode()) assert.DeepEqual(t, "{\"hi\":\"dy\"}", string(resp.Body())) assert.DeepEqual(t, "application/json; charset=utf-8", string(resp.Header.ContentType())) assert.DeepEqual(t, true, resp.Header.ConnectionClose()) // unauthorized user w = PerformRequest(router, "PUT", "/hey/dy", nil) _ = w.Result() resp = w.Result() assert.DeepEqual(t, consts.StatusUnauthorized, resp.StatusCode()) assert.DeepEqual(t, "unauthorized", 
string(resp.Body())) assert.DeepEqual(t, "text/plain; charset=utf-8", string(resp.Header.ContentType())) assert.DeepEqual(t, 12, resp.Header.ContentLength()) // special header PerformRequest(router, "GET", "/hey/header", nil, Header{"content-type", "application/json"}, Header{"content-length", "1"}, Header{"dummy", "a"}, Header{"dummy", "b"}, ) // not found w = PerformRequest(router, "GET", "/hey", nil) resp = w.Result() assert.DeepEqual(t, consts.StatusNotFound, resp.StatusCode()) // fake body w = PerformRequest(router, "GET", "/hey", nil) _, err := w.WriteString(", faker") resp = w.Result() assert.Nil(t, err) assert.DeepEqual(t, consts.StatusNotFound, resp.StatusCode()) assert.DeepEqual(t, "Not Found, faker", string(resp.Body())) // chunked body body := bytes.NewReader(createChunkedBody([]byte("hello world!"))) w = PerformRequest(router, "PUT", "/hey/dy", &Body{body, -1}) resp = w.Result() assert.DeepEqual(t, consts.StatusAccepted, resp.StatusCode()) assert.DeepEqual(t, "body:1\r\nh\r\n2\r\nel\r\n3\r\nlo \r\n4\r\nworl\r\n2\r\nd!\r\n0\r\n\r\n", string(resp.Body())) } func createChunkedBody(body []byte) []byte { var b []byte chunkSize := 1 for len(body) > 0 { if chunkSize > len(body) { chunkSize = len(body) } b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...) b = append(b, body[:chunkSize]...) b = append(b, []byte("\r\n")...) body = body[chunkSize:] chunkSize++ } return append(b, []byte("0\r\n\r\n")...) } ================================================ FILE: pkg/common/ut/response.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ut import ( "bytes" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) // ResponseRecorder records handler's response for later test type ResponseRecorder struct { // Code is the HTTP response code set by WriteHeader. // // Note that if a Handler never calls WriteHeader or Write, // this might end up being 0, rather than the implicit // http.StatusOK. To get the implicit value, use the Result // method. Code int // header contains the headers explicitly set by the Handler. // It is an internal detail. header *protocol.ResponseHeader // Body is the buffer to which the Handler's Write calls are sent. // If nil, the Writes are silently discarded. Body *bytes.Buffer // Flushed is whether the Handler called Flush. Flushed bool result *protocol.Response // cache of Result's return value wroteHeader bool } // NewRecorder returns an initialized ResponseRecorder. func NewRecorder() *ResponseRecorder { return &ResponseRecorder{ header: new(protocol.ResponseHeader), Body: new(bytes.Buffer), Code: consts.StatusOK, } } // Header returns the response headers to mutate within a handler. // To test the headers that were written after a handler completes, // use the Result method and see the returned Response value's Header. func (rw *ResponseRecorder) Header() *protocol.ResponseHeader { m := rw.header if m == nil { m = new(protocol.ResponseHeader) rw.header = m } return m } // Write implements io.Writer. The data in buf is written to // rw.Body, if not nil. 
func (rw *ResponseRecorder) Write(buf []byte) (int, error) { if !rw.wroteHeader { rw.WriteHeader(consts.StatusOK) } if rw.Body != nil { rw.Body.Write(buf) } return len(buf), nil } // WriteString implements io.StringWriter. The data in str is written // to rw.Body, if not nil. func (rw *ResponseRecorder) WriteString(str string) (int, error) { if !rw.wroteHeader { rw.WriteHeader(consts.StatusOK) } if rw.Body != nil { rw.Body.WriteString(str) } return len(str), nil } // WriteHeader sends an HTTP response header with the provided // status code. func (rw *ResponseRecorder) WriteHeader(code int) { if rw.wroteHeader { return } if rw.header == nil { rw.header = new(protocol.ResponseHeader) } rw.header.SetStatusCode(code) rw.Code = code rw.wroteHeader = true } // Flush implements http.Flusher. To test whether Flush was // called, see rw.Flushed. func (rw *ResponseRecorder) Flush() { if !rw.wroteHeader { rw.WriteHeader(consts.StatusOK) } rw.Flushed = true } // Result returns the response generated by the handler. // // The returned Response will have at least its StatusCode, // Header, Body, and optionally Trailer populated. // More fields may be populated in the future, so callers should // not DeepEqual the result in tests. // // The Response.Header is a snapshot of the headers at the time of the // first write call, or at the time of this call, if the handler never // did a write. // // The Response.Body is guaranteed to be non-nil and Body.Read call is // guaranteed to not return any error other than io.EOF. // // Result must only be called after the handler has finished running. 
func (rw *ResponseRecorder) Result() *protocol.Response { if rw.result != nil { return rw.result } res := new(protocol.Response) h := rw.Header() h.CopyTo(&res.Header) if rw.Body != nil { b := rw.Body.Bytes() res.SetBody(b) res.Header.SetContentLength(len(b)) } rw.result = res return res } ================================================ FILE: pkg/common/ut/response_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package ut import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func TestResult(t *testing.T) { r := new(ResponseRecorder) ret := r.Result() assert.DeepEqual(t, consts.StatusOK, ret.StatusCode()) } func TestFlush(t *testing.T) { r := new(ResponseRecorder) r.Flush() ret := r.Result() assert.DeepEqual(t, consts.StatusOK, ret.StatusCode()) } func TestWriterHeader(t *testing.T) { r := NewRecorder() r.WriteHeader(consts.StatusCreated) r.WriteHeader(consts.StatusOK) ret := r.Result() assert.DeepEqual(t, consts.StatusCreated, ret.StatusCode()) } func TestWriteString(t *testing.T) { r := NewRecorder() r.WriteString("hello") ret := r.Result() assert.DeepEqual(t, "hello", string(ret.Body())) } func TestWrite(t *testing.T) { r := NewRecorder() r.Write([]byte("hello")) ret := r.Result() assert.DeepEqual(t, "hello", string(ret.Body())) } ================================================ FILE: pkg/common/utils/bufpool.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package utils import "sync" var CopyBufPool = sync.Pool{ New: func() interface{} { return make([]byte, 4096) }, } ================================================ FILE: pkg/common/utils/chunk.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package utils import ( "bytes" "fmt" "io" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/network" ) var errBrokenChunk = errors.NewPublic("cannot find crlf at the end of chunk") func ParseChunkSize(r network.Reader) (int, error) { n, err := bytesconv.ReadHexInt(r) if err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } return -1, err } for { c, err := r.ReadByte() if err != nil { return -1, errors.NewPublic(fmt.Sprintf("cannot read '\r' char at the end of chunk size: %s", err)) } // Skip any trailing whitespace after chunk size. if c == ' ' { continue } if c != '\r' { return -1, errors.NewPublic( fmt.Sprintf("unexpected char %q at the end of chunk size. Expected %q", c, '\r'), ) } break } c, err := r.ReadByte() if err != nil { return -1, errors.NewPublic(fmt.Sprintf("cannot read '\n' char at the end of chunk size: %s", err)) } if c != '\n' { return -1, errors.NewPublic( fmt.Sprintf("unexpected char %q at the end of chunk size. 
Expected %q", c, '\n'), ) } return n, nil } // SkipCRLF will only skip the next CRLF("\r\n"), otherwise, error will be returned. func SkipCRLF(reader network.Reader) error { p, err := reader.Peek(len(bytestr.StrCRLF)) if err != nil { return err } if !bytes.Equal(p, bytestr.StrCRLF) { return errBrokenChunk } reader.Skip(len(p)) // nolint: errcheck return nil } ================================================ FILE: pkg/common/utils/chunk_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package utils import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" ) func TestChunkParseChunkSizeGetCorrect(t *testing.T) { // iterate the hexMap, and judge the difference between dec and ParseChunkSize hexMap := map[int]string{0: "0", 10: "a", 100: "64", 1000: "3e8"} for dec, hex := range hexMap { chunkSizeBody := hex + "\r\n" zr := mock.NewZeroCopyReader(chunkSizeBody) chunkSize, err := ParseChunkSize(zr) assert.DeepEqual(t, nil, err) assert.DeepEqual(t, chunkSize, dec) } } func TestChunkParseChunkSizeGetError(t *testing.T) { // test err from -----n, err := bytesconv.ReadHexInt(r)----- chunkSizeBody := "" zr := mock.NewZeroCopyReader(chunkSizeBody) chunkSize, err := ParseChunkSize(zr) assert.NotNil(t, err) assert.DeepEqual(t, -1, chunkSize) // test err from -----c, err := r.ReadByte()----- chunkSizeBody = "0" zr = mock.NewZeroCopyReader(chunkSizeBody) chunkSize, err = ParseChunkSize(zr) assert.NotNil(t, err) assert.DeepEqual(t, -1, chunkSize) // test err from -----c, err := r.ReadByte()----- chunkSizeBody = "0" + "\r" zr = mock.NewZeroCopyReader(chunkSizeBody) chunkSize, err = ParseChunkSize(zr) assert.NotNil(t, err) assert.DeepEqual(t, -1, chunkSize) // test err from -----c, err := r.ReadByte()----- chunkSizeBody = "0" + "\r" + "\r" zr = mock.NewZeroCopyReader(chunkSizeBody) chunkSize, err = ParseChunkSize(zr) assert.NotNil(t, err) assert.DeepEqual(t, -1, chunkSize) } func TestChunkParseChunkSizeCorrectWhiteSpace(t *testing.T) { // test the whitespace whiteSpace := "" for i := 0; i < 10; i++ { whiteSpace += " " chunkSizeBody := "0" + whiteSpace + "\r\n" zr := mock.NewZeroCopyReader(chunkSizeBody) chunkSize, err := ParseChunkSize(zr) assert.DeepEqual(t, nil, err) assert.DeepEqual(t, 0, chunkSize) } } func TestChunkParseChunkSizeNonCRLF(t *testing.T) { // test non-"\r\n" chunkSizeBody := "0" + "\n\r" zr := mock.NewZeroCopyReader(chunkSizeBody) chunkSize, err := ParseChunkSize(zr) 
assert.DeepEqual(t, true, err != nil) assert.DeepEqual(t, -1, chunkSize) } func TestChunkReadTrueCRLF(t *testing.T) { CRLF := "\r\n" zr := mock.NewZeroCopyReader(CRLF) err := SkipCRLF(zr) assert.DeepEqual(t, nil, err) } func TestChunkReadFalseCRLF(t *testing.T) { CRLF := "\n\r" zr := mock.NewZeroCopyReader(CRLF) err := SkipCRLF(zr) assert.DeepEqual(t, errBrokenChunk, err) } ================================================ FILE: pkg/common/utils/env.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package utils import ( "os" "strconv" "strings" "github.com/cloudwego/hertz/pkg/common/errors" ) // Get bool from env func GetBoolFromEnv(key string) (bool, error) { value, isExist := os.LookupEnv(key) if !isExist { return false, errors.NewPublic("env not exist") } value = strings.TrimSpace(value) return strconv.ParseBool(value) } ================================================ FILE: pkg/common/utils/ioutil.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package utils import ( "io" "github.com/cloudwego/hertz/pkg/network" ) func CopyBuffer(dst network.Writer, src io.Reader, buf []byte) (written int64, err error) { if buf != nil && len(buf) == 0 { panic("empty buffer in io.CopyBuffer") } return copyBuffer(dst, src, buf) } // copyBuffer is the actual implementation of Copy and CopyBuffer. // If buf is nil, one is allocated. func copyBuffer(dst network.Writer, src io.Reader, buf []byte) (written int64, err error) { if wt, ok := src.(io.WriterTo); ok { if w, ok := dst.(io.Writer); ok { return wt.WriteTo(w) } } // Sendfile impl if rf, ok := dst.(io.ReaderFrom); ok { return rf.ReadFrom(src) } if buf == nil { size := 32 * 1024 if l, ok := src.(*io.LimitedReader); ok && int64(size) > l.N { if l.N < 1 { size = 1 } else { size = int(l.N) } } buf = make([]byte, size) } for { nr, er := src.Read(buf) if nr > 0 { nw, eb := dst.WriteBinary(buf[:nr]) if eb != nil { err = eb return } if nw > 0 { written += int64(nw) } if nr != nw { err = io.ErrShortWrite return } if err = dst.Flush(); err != nil { return } } if er != nil { if er != io.EOF { err = er } break } } return } func CopyZeroAlloc(w network.Writer, r io.Reader) (int64, error) { vbuf := CopyBufPool.Get() buf := vbuf.([]byte) n, err := CopyBuffer(w, r, buf) CopyBufPool.Put(vbuf) return n, err } ================================================ FILE: pkg/common/utils/ioutil_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in 
compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package utils import ( "bytes" "io" "testing" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/network" ) type writeReadTest interface { Write(p []byte) (n int, err error) Malloc(n int) (buf []byte, err error) WriteBinary(b []byte) (n int, err error) Flush() error } type readerTest interface { ReadFrom(r io.Reader) (n int64, err error) Malloc(n int) (buf []byte, err error) WriteBinary(b []byte) (n int, err error) Flush() error } type testWriter struct { w io.Writer } func (t testWriter) Write(p []byte) (n int, err error) { return } func (t testWriter) Malloc(n int) (buf []byte, err error) { return } func (t testWriter) WriteBinary(b []byte) (n int, err error) { return } func (t testWriter) Flush() error { return nil } type testReader struct { r io.ReaderFrom } func (t testReader) ReadFrom(r io.Reader) (n int64, err error) { return } func (t testReader) Malloc(n int) (buf []byte, err error) { return } func (t testReader) WriteBinary(b []byte) (n int, err error) { return } func (t testReader) Flush() error { return nil } func newTestWriter(w io.Writer) writeReadTest { return &testWriter{ w: w, } } func newTestReaderForm(r io.ReaderFrom) readerTest { return &testReader{ r: r, } } func TestIoutilCopyBuffer(t *testing.T) { var writeBuffer bytes.Buffer str := string("hertz is very good!!!") src := bytes.NewBufferString(str) dst := network.NewWriter(&writeBuffer) var buf []byte // src.Len() will change, when use src.read(p []byte) srcLen := int64(src.Len()) 
written, err := CopyBuffer(dst, src, buf) assert.DeepEqual(t, written, srcLen) assert.DeepEqual(t, err, nil) assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) // Test when no data is readable writeBuffer.Reset() emptySrc := bytes.NewBufferString("") written, err = CopyBuffer(dst, emptySrc, buf) assert.DeepEqual(t, written, int64(0)) assert.Nil(t, err) assert.DeepEqual(t, []byte(""), writeBuffer.Bytes()) // Test a LimitedReader writeBuffer.Reset() limit := int64(5) limitedSrc := io.LimitedReader{R: bytes.NewBufferString(str), N: limit} written, err = CopyBuffer(dst, &limitedSrc, buf) assert.DeepEqual(t, written, limit) assert.Nil(t, err) assert.DeepEqual(t, []byte(str[:limit]), writeBuffer.Bytes()) } func TestIoutilCopyBufferWithIoWriter(t *testing.T) { var writeBuffer bytes.Buffer str := "hertz is very good!!!" var buf []byte src := bytes.NewBuffer([]byte(str)) ioWriter := newTestWriter(&writeBuffer) // to show example about -----w, ok := dst.(io.Writer)----- _, ok := ioWriter.(io.Writer) assert.DeepEqual(t, true, ok) written, err := CopyBuffer(ioWriter, src, buf) assert.DeepEqual(t, written, int64(0)) assert.NotNil(t, err) assert.DeepEqual(t, []byte(nil), writeBuffer.Bytes()) } func TestIoutilCopyBufferWithIoReaderFrom(t *testing.T) { var writeBuffer bytes.Buffer str := "hertz is very good!!!" var buf []byte src := bytes.NewBufferString(str) ioReaderFrom := newTestReaderForm(&writeBuffer) // to show example about -----rf, ok := dst.(io.ReaderFrom)----- _, ok := ioReaderFrom.(io.Writer) assert.DeepEqual(t, false, ok) _, ok = ioReaderFrom.(io.ReaderFrom) assert.DeepEqual(t, true, ok) written, err := CopyBuffer(ioReaderFrom, src, buf) assert.DeepEqual(t, written, int64(0)) assert.Nil(t, err) assert.DeepEqual(t, []byte(nil), writeBuffer.Bytes()) } func TestIoutilCopyBufferWithPanic(t *testing.T) { var writeBuffer bytes.Buffer str := "hertz is very good!!!" 
var buf []byte defer func() { if r := recover(); r != nil { assert.DeepEqual(t, "empty buffer in io.CopyBuffer", r) } }() src := bytes.NewBufferString(str) dst := network.NewWriter(&writeBuffer) buf = make([]byte, 0) _, _ = CopyBuffer(dst, src, buf) } func TestIoutilCopyBufferWithNilBuffer(t *testing.T) { var writeBuffer bytes.Buffer str := string("hertz is very good!!!") src := bytes.NewBufferString(str) dst := network.NewWriter(&writeBuffer) // src.Len() will change, when use src.read(p []byte) srcLen := int64(src.Len()) written, err := CopyBuffer(dst, src, nil) assert.DeepEqual(t, written, srcLen) assert.Nil(t, err) assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) } func TestIoutilCopyBufferWithNilBufferAndIoLimitedReader(t *testing.T) { var writeBuffer bytes.Buffer str := "hertz is very good!!!" src := bytes.NewBufferString(str) reader := mock.NewLimitReader(src) dst := network.NewWriter(&writeBuffer) srcLen := int64(src.Len()) written, err := CopyBuffer(dst, &reader, nil) assert.DeepEqual(t, written, srcLen) assert.Nil(t, err) assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) // test l.N < 1 writeBuffer.Reset() str = "" src = bytes.NewBufferString(str) reader = mock.NewLimitReader(src) dst = network.NewWriter(&writeBuffer) srcLen = int64(src.Len()) written, err = CopyBuffer(dst, &reader, nil) assert.DeepEqual(t, written, srcLen) assert.Nil(t, err) assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) } func TestIoutilCopyZeroAlloc(t *testing.T) { var writeBuffer bytes.Buffer str := "hertz is very good!!!" 
src := bytes.NewBufferString(str) dst := network.NewWriter(&writeBuffer) srcLen := int64(src.Len()) written, err := CopyZeroAlloc(dst, src) assert.DeepEqual(t, written, srcLen) assert.DeepEqual(t, err, nil) assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) // Test when no data is readable writeBuffer.Reset() emptySrc := bytes.NewBufferString("") written, err = CopyZeroAlloc(dst, emptySrc) assert.DeepEqual(t, written, int64(0)) assert.Nil(t, err) assert.DeepEqual(t, []byte(""), writeBuffer.Bytes()) } func TestIoutilCopyBufferWithEmptyBuffer(t *testing.T) { var writeBuffer bytes.Buffer str := "hertz is very good!!!" src := bytes.NewBufferString(str) dst := network.NewWriter(&writeBuffer) // Use a non-empty buffer of length 0 emptyBuf := make([]byte, 0) func() { defer func() { if r := recover(); r != nil { assert.DeepEqual(t, "empty buffer in io.CopyBuffer", r) } }() written, err := CopyBuffer(dst, src, emptyBuf) assert.Nil(t, err) assert.DeepEqual(t, written, int64(len(str))) assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) }() } func TestIoutilCopyBufferWithLimitedReader(t *testing.T) { var writeBuffer bytes.Buffer str := "hertz is very good!!!" src := bytes.NewBufferString(str) limit := int64(5) limitedSrc := io.LimitedReader{R: src, N: limit} dst := network.NewWriter(&writeBuffer) var buf []byte // Test LimitedReader status written, err := CopyBuffer(dst, &limitedSrc, buf) assert.Nil(t, err) assert.DeepEqual(t, written, limit) assert.DeepEqual(t, []byte(str[:limit]), writeBuffer.Bytes()) } ================================================ FILE: pkg/common/utils/netaddr.go ================================================ /* * Copyright 2021 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package utils import "net" var _ net.Addr = (*NetAddr)(nil) // NetAddr implements the net.Addr interface. type NetAddr struct { network string address string } // NewNetAddr creates a new NetAddr object with the network and address provided. func NewNetAddr(network, address string) net.Addr { return &NetAddr{network, address} } // Network implements the net.Addr interface. func (na *NetAddr) Network() string { return na.network } // String implements the net.Addr interface. func (na *NetAddr) String() string { return na.address } ================================================ FILE: pkg/common/utils/netaddr_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package utils import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestNetAddr(t *testing.T) { networkAddr := NewNetAddr("127.0.0.1", "192.168.1.1") assert.DeepEqual(t, networkAddr.Network(), "127.0.0.1") assert.DeepEqual(t, networkAddr.String(), "192.168.1.1") } ================================================ FILE: pkg/common/utils/network.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package utils import "net" const ( UNKNOWN_IP_ADDR = "-" ) var localIP string // LocalIP returns host's ip func LocalIP() string { return localIP } // getLocalIp enumerates local net interfaces to find local ip, it should only be called in init phase func getLocalIp() string { inters, err := net.Interfaces() if err != nil { return UNKNOWN_IP_ADDR } for _, inter := range inters { if inter.Flags&net.FlagLoopback != net.FlagLoopback && inter.Flags&net.FlagUp != 0 { addrs, err := inter.Addrs() if err != nil { return UNKNOWN_IP_ADDR } for _, addr := range addrs { if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { return ipnet.IP.String() } } } } return UNKNOWN_IP_ADDR } func init() { localIP = getLocalIp() } // TLSRecordHeaderLooksLikeHTTP reports whether a TLS record header // looks like it might've been a misdirected plaintext HTTP request. 
func TLSRecordHeaderLooksLikeHTTP(hdr [5]byte) bool { switch string(hdr[:]) { case "GET /", "HEAD ", "POST ", "PUT /", "OPTIO": return true } return false } ================================================ FILE: pkg/common/utils/network_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package utils import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestTLSRecordHeaderLooksLikeHTTP(t *testing.T) { HeaderValueAndExpectedResult := [][]interface{}{ {[5]byte{'G', 'E', 'T', ' ', '/'}, true}, {[5]byte{'H', 'E', 'A', 'D', ' '}, true}, {[5]byte{'P', 'O', 'S', 'T', ' '}, true}, {[5]byte{'P', 'U', 'T', ' ', '/'}, true}, {[5]byte{'O', 'P', 'T', 'I', 'O'}, true}, {[5]byte{'G', 'E', 'T', '/', ' '}, false}, {[5]byte{' ', 'H', 'E', 'A', 'D'}, false}, {[5]byte{' ', 'P', 'O', 'S', 'T'}, false}, {[5]byte{'P', 'U', 'T', '/', ' '}, false}, {[5]byte{'H', 'E', 'R', 'T', 'Z'}, false}, } for _, testCase := range HeaderValueAndExpectedResult { value, expectedResult := testCase[0].([5]byte), testCase[1].(bool) assert.DeepEqual(t, expectedResult, TLSRecordHeaderLooksLikeHTTP(value)) } } func TestLocalIP(t *testing.T) { // Mock the localIP variable for testing purposes. localIP = "192.168.0.1" // Ensure that LocalIP() returns the expected local IP. 
expectedIP := "192.168.0.1" if got := LocalIP(); got != expectedIP { assert.DeepEqual(t, got, expectedIP) } } ================================================ FILE: pkg/common/utils/path.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors
 */

package utils

import "strings"

// CleanPath is the URL version of path.Clean, it returns a canonical URL path
// for p, eliminating . and .. elements.
//
// The following rules are applied iteratively until no further processing can
// be done:
//  1. Replace multiple slashes with a single slash.
//  2. Eliminate each . path name element (the current directory).
//  3. Eliminate each inner .. path name element (the parent directory)
//     along with the non-.. element that precedes it.
//  4. Eliminate .. elements that begin a rooted path:
//     that is, replace "/.." by "/" at the beginning of a path.
//
// If the result of this process is an empty string, "/" is returned.
//
// The result always starts with '/'. A trailing slash in p, or a trailing "."
// element, yields a trailing slash in the result unless the result is just
// "/". When no byte had to be rewritten, the result is a substring of p.
func CleanPath(p string) string {
	const stackBufSize = 128

	// Turn empty string into "/"
	if p == "" {
		return "/"
	}

	// Reasonably sized buffer on stack to avoid allocations in the common case.
	// If a larger buffer is required, it gets allocated dynamically.
	buf := make([]byte, 0, stackBufSize)

	n := len(p)

	// Invariants:
	//      reading from path; r is index of next byte to process.
	//      writing to buf; w is index of next byte to write.
	// While buf has length 0, the output is still identical to p[:w].

	// path must start with '/'
	r := 1
	w := 1

	if p[0] != '/' {
		// Unrooted input: the output gains a synthetic leading '/', so it can
		// no longer be a prefix of p. Materialize buf (n+1 bytes) right away,
		// start reading p at index 0 and writing at index 1.
		r = 0

		if n+1 > stackBufSize {
			buf = make([]byte, n+1)
		} else {
			buf = buf[:n+1]
		}
		buf[0] = '/'
	}

	// Interior slashes are dropped by the loop below; remember whether p
	// ended with one so it can be re-appended at the end.
	trailing := n > 1 && p[n-1] == '/'

	// A bit more clunky without a 'lazybuf' like the path package, but the loop
	// gets completely inlined (bufApp calls).
	// So in contrast to the path package this loop has no expensive function
	// calls (except make, if needed).

	for r < n {
		switch {
		case p[r] == '/':
			// empty path element, trailing slash is added after the end
			r++

		case p[r] == '.' && r+1 == n:
			// final "." element: drop it but keep a trailing slash
			trailing = true
			r++

		case p[r] == '.' && p[r+1] == '/':
			// . element
			r += 2

		case p[r] == '.' && p[r+1] == '.' && (r+2 == n || p[r+2] == '/'):
			// .. element: remove to last /
			r += 3

			if w > 1 {
				// can backtrack
				w--

				// Scan whichever holds the current output: p while nothing
				// has been rewritten yet, buf afterwards.
				if len(buf) == 0 {
					for w > 1 && p[w] != '/' {
						w--
					}
				} else {
					for w > 1 && buf[w] != '/' {
						w--
					}
				}
			}

		default:
			// Real path element.
			// Add slash if needed
			if w > 1 {
				bufApp(&buf, p, w, '/')
				w++
			}

			// Copy element
			for r < n && p[r] != '/' {
				bufApp(&buf, p, w, p[r])
				w++
				r++
			}
		}
	}

	// Re-append trailing slash
	if trailing && w > 1 {
		bufApp(&buf, p, w, '/')
		w++
	}

	// If the original string was not modified (or only shortened at the end),
	// return the respective substring of the original string.
	// Otherwise return a new string from the buffer.
	if len(buf) == 0 {
		return p[:w]
	}
	return string(buf[:w])
}

// Internal helper to lazily create a buffer if necessary.
// Calls to this function get inlined.
//
// bufApp writes byte c at output index w. While *buf has length 0, the output
// is implicitly s[:w], and nothing is written as long as c equals s[w]. On the
// first mismatch, *buf is resized to len(s) (reusing its capacity if large
// enough), s[:w] is copied in, and writing continues in *buf.
func bufApp(buf *[]byte, s string, w int, c byte) {
	b := *buf
	if len(b) == 0 {
		// No modification of the original string so far.
		// If the next character is the same as in the original string, we do
		// not yet have to allocate a buffer.
		if s[w] == c {
			return
		}

		// Otherwise use either the stack buffer, if it is large enough, or
		// allocate a new buffer on the heap, and copy all previous characters.
		if l := len(s); l > cap(b) {
			*buf = make([]byte, len(s))
		} else {
			*buf = (*buf)[:l]
		}
		b = *buf

		copy(b, s[:w])
	}
	b[w] = c
}

// AddMissingPort adds a port to a host if it is missing.
// A literal IPv6 address in hostport must be enclosed in square
// brackets, as in "[::1]:80", "[::1%lo0]:80".
func AddMissingPort(addr string, isTLS bool) string { if strings.IndexByte(addr, ':') >= 0 { endOfV6 := strings.IndexByte(addr, ']') // we do not care about the validity of the address, just check if it has more bytes after ']' if endOfV6 < len(addr)-1 { return addr } } if !isTLS { return addr + ":80" } return addr + ":443" } ================================================ FILE: pkg/common/utils/path_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */ package utils import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestPathCleanPath(t *testing.T) { normalPath := "/Foo/Bar/go/src/github.com/cloudwego/hertz/pkg/common/utils/path_test.go" expectedNormalPath := "/Foo/Bar/go/src/github.com/cloudwego/hertz/pkg/common/utils/path_test.go" cleanNormalPath := CleanPath(normalPath) assert.DeepEqual(t, expectedNormalPath, cleanNormalPath) singleDotPath := "/Foo/Bar/./././go/src" expectedSingleDotPath := "/Foo/Bar/go/src" cleanSingleDotPath := CleanPath(singleDotPath) assert.DeepEqual(t, expectedSingleDotPath, cleanSingleDotPath) doubleDotPath := "../../.." expectedDoubleDotPath := "/" cleanDoublePotPath := CleanPath(doubleDotPath) assert.DeepEqual(t, expectedDoubleDotPath, cleanDoublePotPath) // MultiDot can be treated as a file name multiDotPath := "/../...." expectedMultiDotPath := "/...." cleanMultiDotPath := CleanPath(multiDotPath) assert.DeepEqual(t, expectedMultiDotPath, cleanMultiDotPath) nullPath := "" expectedNullPath := "/" cleanNullPath := CleanPath(nullPath) assert.DeepEqual(t, expectedNullPath, cleanNullPath) relativePath := "/Foo/Bar/../go/src/../../github.com/cloudwego/hertz" expectedRelativePath := "/Foo/github.com/cloudwego/hertz" cleanRelativePath := CleanPath(relativePath) assert.DeepEqual(t, expectedRelativePath, cleanRelativePath) multiSlashPath := "///////Foo//Bar////go//src/github.com/cloudwego/hertz//.." 
expectedMultiSlashPath := "/Foo/Bar/go/src/github.com/cloudwego" cleanMultiSlashPath := CleanPath(multiSlashPath) assert.DeepEqual(t, expectedMultiSlashPath, cleanMultiSlashPath) inputPath := "/Foo/Bar/go/src/github.com/cloudwego/hertz/pkg/common/utils/path_test.go/." expectedPath := "/Foo/Bar/go/src/github.com/cloudwego/hertz/pkg/common/utils/path_test.go/" cleanedPath := CleanPath(inputPath) assert.DeepEqual(t, expectedPath, cleanedPath) } // The Function AddMissingPort can only add the missed port, don't consider the other error case. func TestPathAddMissingPort(t *testing.T) { ipList := []string{"127.0.0.1", "111.111.1.1", "[0:0:0:0:0:ffff:192.1.56.10]", "[0:0:0:0:0:ffff:c0a8:101]", "www.foobar.com"} for _, ip := range ipList { assert.DeepEqual(t, ip+":443", AddMissingPort(ip, true)) assert.DeepEqual(t, ip+":80", AddMissingPort(ip, false)) customizedPort := ":8080" assert.DeepEqual(t, ip+customizedPort, AddMissingPort(ip+customizedPort, true)) assert.DeepEqual(t, ip+customizedPort, AddMissingPort(ip+customizedPort, false)) } } ================================================ FILE: pkg/common/utils/utils.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package utils import ( "bytes" "reflect" "runtime" "strings" "github.com/cloudwego/hertz/internal/bytesconv" errs "github.com/cloudwego/hertz/pkg/common/errors" ) var errNeedMore = errs.New(errs.ErrNeedMore, errs.ErrorTypePublic, "cannot find trailing lf") func Assert(guard bool, text string) { if !guard { panic(text) } } // H is a shortcut for map[string]interface{} type H map[string]interface{} func IsTrueString(str string) bool { return strings.ToLower(str) == "true" } func NameOfFunction(f interface{}) string { return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() } func CaseInsensitiveCompare(a, b []byte) bool { if len(a) != len(b) { return false } for i := 0; i < len(a); i++ { if a[i]|0x20 != b[i]|0x20 { return false } } return true } func NormalizeHeaderKey(b []byte, disableNormalizing bool) { if disableNormalizing { return } n := len(b) if n == 0 { return } b[0] = bytesconv.ToUpperTable[b[0]] for i := 1; i < n; i++ { p := &b[i] if *p == '-' { i++ if i < n { b[i] = bytesconv.ToUpperTable[b[i]] } continue } *p = bytesconv.ToLowerTable[*p] } } func NextLine(b []byte) ([]byte, []byte, error) { nNext := bytes.IndexByte(b, '\n') if nNext < 0 { return nil, nil, errNeedMore } n := nNext if n > 0 && b[n-1] == '\r' { n-- } return b[:n], b[nNext+1:], nil } func FilterContentType(content string) string { for i, char := range content { if char == ' ' || char == ';' { return content[:i] } } return content } ================================================ FILE: pkg/common/utils/utils_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package utils import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) // test assert func func TestUtilsAssert(t *testing.T) { assertPanic := func() (panicked bool) { defer func() { if r := recover(); r != nil { panicked = true } }() Assert(false, "should panic") return false } // Checking if the assertPanic function results in a panic as expected. // We expect a true value because it should panic. assert.DeepEqual(t, true, assertPanic()) // Checking if a true assertion does not result in a panic. // We create a wrapper around Assert to capture if it panics when it should not. noPanic := func() (panicked bool) { defer func() { if r := recover(); r != nil { panicked = true } }() Assert(true, "should not panic") return false } // We expect a false value because it should not panic. assert.DeepEqual(t, false, noPanic()) } func TestUtilsIsTrueString(t *testing.T) { normalTrueStr := "true" upperTrueStr := "TRUE" otherStr := "hertz" assert.DeepEqual(t, true, IsTrueString(normalTrueStr)) assert.DeepEqual(t, true, IsTrueString(upperTrueStr)) assert.DeepEqual(t, false, IsTrueString(otherStr)) } // used for TestUtilsNameOfFunction func testName(a int) { } // return the relative path for the function func TestUtilsNameOfFunction(t *testing.T) { pathOfTestName := "github.com/cloudwego/hertz/pkg/common/utils.testName" pathOfIsTrueString := "github.com/cloudwego/hertz/pkg/common/utils.IsTrueString" nameOfTestName := NameOfFunction(testName) nameOfIsTrueString := NameOfFunction(IsTrueString) assert.DeepEqual(t, pathOfTestName, nameOfTestName) assert.DeepEqual(t, pathOfIsTrueString, nameOfIsTrueString) } func TestUtilsCaseInsensitiveCompare(t *testing.T) { lowerStr := []byte("content-length") upperStr := []byte("Content-Length") assert.DeepEqual(t, true, CaseInsensitiveCompare(lowerStr, upperStr)) lessStr := []byte("content-type") moreStr := []byte("content-length") assert.DeepEqual(t, false, CaseInsensitiveCompare(lessStr, moreStr)) firstStr := 
[]byte("content-type") secondStr := []byte("content0type") assert.DeepEqual(t, false, CaseInsensitiveCompare(firstStr, secondStr)) } // NormalizeHeaderKey can upper the first letter and lower the other letter in // HTTP header, invervaled by '-'. // Example: "content-type" -> "Content-Type" func TestUtilsNormalizeHeaderKey(t *testing.T) { contentTypeStr := []byte("Content-Type") lowerContentTypeStr := []byte("content-type") mixedContentTypeStr := []byte("conTENt-tYpE") mixedContertTypeStrWithoutNormalizing := []byte("Content-type") NormalizeHeaderKey(contentTypeStr, false) NormalizeHeaderKey(lowerContentTypeStr, false) NormalizeHeaderKey(mixedContentTypeStr, false) NormalizeHeaderKey(lowerContentTypeStr, true) assert.DeepEqual(t, "Content-Type", string(contentTypeStr)) assert.DeepEqual(t, "Content-Type", string(lowerContentTypeStr)) assert.DeepEqual(t, "Content-Type", string(mixedContentTypeStr)) assert.DeepEqual(t, "Content-type", string(mixedContertTypeStrWithoutNormalizing)) } // Cutting up the header Type. 
// Example: "Content-Type: application/x-www-form-urlencoded\r\nDate: Fri, 6 Aug 2021 11:00:31 GMT"
// ->"Content-Type: application/x-www-form-urlencoded" and "Date: Fri, 6 Aug 2021 11:00:31 GMT"
// The test also covers a bare "\n" separator and an input that starts with a
// newline (empty first line). It also covers input with no newline at all,
// which must yield errNeedMore.
func TestUtilsNextLine(t *testing.T) {
	// "\r\n" separator: the '\r' must not appear in the returned line.
	multiHeaderStr := []byte("Content-Type: application/x-www-form-urlencoded\r\nDate: Fri, 6 Aug 2021 11:00:31 GMT")
	contentTypeStr, dateStr, hErr := NextLine(multiHeaderStr)
	assert.DeepEqual(t, nil, hErr)
	assert.DeepEqual(t, "Content-Type: application/x-www-form-urlencoded", string(contentTypeStr))
	assert.DeepEqual(t, "Date: Fri, 6 Aug 2021 11:00:31 GMT", string(dateStr))

	// Bare "\n" separator.
	multiHeaderStrWithoutReturn := []byte("Content-Type: application/x-www-form-urlencoded\nDate: Fri, 6 Aug 2021 11:00:31 GMT")
	contentTypeStr, dateStr, hErr = NextLine(multiHeaderStrWithoutReturn)
	assert.DeepEqual(t, nil, hErr)
	assert.DeepEqual(t, "Content-Type: application/x-www-form-urlencoded", string(contentTypeStr))
	assert.DeepEqual(t, "Date: Fri, 6 Aug 2021 11:00:31 GMT", string(dateStr))

	// Leading newline: the first line is empty.
	singleHeaderStrWithFirstNewLine := []byte("\nContent-Type: application/x-www-form-urlencoded")
	firstStr, secondStr, sErr := NextLine(singleHeaderStrWithFirstNewLine)
	assert.DeepEqual(t, nil, sErr)
	assert.DeepEqual(t, string(""), string(firstStr))
	assert.DeepEqual(t, "Content-Type: application/x-www-form-urlencoded", string(secondStr))

	// No newline at all: more input is needed.
	singleHeaderStr := []byte("Content-Type: application/x-www-form-urlencoded")
	_, _, sErr = NextLine(singleHeaderStr)
	assert.DeepEqual(t, errNeedMore, sErr)
}

// TestFilterContentType checks that parameters after ';' are stripped.
func TestFilterContentType(t *testing.T) {
	contentType := "text/plain; charset=utf-8"
	contentType = FilterContentType(contentType)
	assert.DeepEqual(t, "text/plain", contentType)
}

// TestNormalizeHeaderKeyEdgeCases checks that an empty key stays empty
// whether or not normalization is disabled.
func TestNormalizeHeaderKeyEdgeCases(t *testing.T) {
	empty := []byte("")
	NormalizeHeaderKey(empty, false)
	assert.DeepEqual(t, []byte(""), empty)
	NormalizeHeaderKey(empty, true)
	assert.DeepEqual(t, []byte(""), empty)
}

// TestFilterContentTypeEdgeCases covers a value without parameters and a
// value with several parameters.
func TestFilterContentTypeEdgeCases(t *testing.T) {
	simpleContentType := "text/plain"
	assert.DeepEqual(t, "text/plain", FilterContentType(simpleContentType))
	complexContentType := "text/html; charset=utf-8; format=flowed"
	assert.DeepEqual(t, "text/html", FilterContentType(complexContentType))
}

================================================
FILE: pkg/network/connection.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package network

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"time"
)

// Reader is the interface for a buffered reader. It can peek, skip and read
// bytes. Memory backing previously returned slices is recycled only when
// Release is called explicitly.
type Reader interface {
	// Peek returns the next n bytes without advancing the reader.
	Peek(n int) ([]byte, error)
	// Skip discards the next n bytes.
	Skip(n int) error
	// Release releases the memory space occupied by all read slices. This
	// method must be called explicitly to recycle the memory, after
	// confirming that the previously read data is no longer in use.
	// After invoking Release, slices obtained from methods such as Peek
	// point to invalid memory and must not be used anymore.
	Release() error
	// Len returns the total length of the readable data in the reader.
	Len() int
	// ReadByte reads one byte and advances the read pointer.
	ReadByte() (byte, error)
	// ReadBinary reads the next n bytes into a copy and advances the read pointer.
	ReadBinary(n int) (p []byte, err error)
}

// Writer is the interface for a buffered writer. Callers either obtain
// space with Malloc or hand over their own buffer with WriteBinary, then
// call Flush to send the data.
type Writer interface {
	// Malloc provides an n-byte buffer to fill with data to send.
	Malloc(n int) (buf []byte, err error)
	// WriteBinary uses the caller's buffer for the next flush.
	// NOTE: the buffer b must stay valid until Flush has succeeded.
	WriteBinary(b []byte) (n int, err error)
	// Flush sends the buffered data to the peer.
	Flush() error
}

// ReadWriter combines Reader and Writer.
type ReadWriter interface {
	Reader
	Writer
}

// Conn is a net.Conn that also provides the buffered Reader and Writer
// interfaces plus read/write timeouts.
type Conn interface {
	net.Conn
	Reader
	Writer

	// SetReadTimeout should work for every Read process
	SetReadTimeout(t time.Duration) error
	// SetWriteTimeout is the write-side counterpart of SetReadTimeout.
	SetWriteTimeout(t time.Duration) error
}

// ConnTLSer is implemented by TLS-capable connections. Its methods have the
// same signatures as Handshake and ConnectionState on *tls.Conn.
type ConnTLSer interface {
	Handshake() error
	ConnectionState() tls.ConnectionState
}

// HandleSpecificError is implemented by types that decide whether err should
// be ignored. rip presumably identifies the remote peer (remote IP) — TODO
// confirm against implementations.
type HandleSpecificError interface {
	HandleSpecificError(err error, rip string) (needIgnore bool)
}

// ErrorNormalization is implemented by types that can convert their errors
// into hertz errors.
type ErrorNormalization interface {
	ToHertzError(err error) error
}

// DialFunc dials addr and returns the resulting Conn.
type DialFunc func(addr string) (Conn, error)

/****************** Stream-based connection *******************/

// StreamConn is the interface for a stream-based connection abstraction.
type StreamConn interface {
	// GetRawConnection returns the underlying connection object.
	GetRawConnection() interface{}
	// HandshakeComplete blocks until the handshake completes (or fails).
	HandshakeComplete() context.Context
	// GetVersion returns the version of the protocol used by the connection.
	GetVersion() uint32
	// CloseWithError closes the connection with an error.
	// The error string will be sent to the peer.
	CloseWithError(err ApplicationError, errMsg string) error
	// LocalAddr returns the local address.
	LocalAddr() net.Addr
	// RemoteAddr returns the address of the peer.
	RemoteAddr() net.Addr
	// Context returns a context that is cancelled when the connection is closed.
	Context() context.Context
	// Streamer is the interface for stream operations.
	Streamer
}

// Streamer is the interface for accepting and opening streams on a connection.
type Streamer interface {
	// AcceptStream returns the next stream opened by the peer, blocking until one is available.
	// If the connection was closed due to a timeout, the error satisfies
	// the net.Error interface, and Timeout() will be true.
	AcceptStream(context.Context) (Stream, error)
	// AcceptUniStream returns the next unidirectional stream opened by the peer, blocking until one is available.
	// If the connection was closed due to a timeout, the error satisfies
	// the net.Error interface, and Timeout() will be true.
	AcceptUniStream(context.Context) (ReceiveStream, error)
	// OpenStream opens a new bidirectional QUIC stream.
	// There is no signaling to the peer about new streams:
	// The peer can only accept the stream after data has been sent on the stream.
	// If the error is non-nil, it satisfies the net.Error interface.
	// When reaching the peer's stream limit, err.Temporary() will be true.
	// If the connection was closed due to a timeout, Timeout() will be true.
	OpenStream() (Stream, error)
	// OpenStreamSync opens a new bidirectional QUIC stream.
	// It blocks until a new stream can be opened.
	// If the error is non-nil, it satisfies the net.Error interface.
	// If the connection was closed due to a timeout, Timeout() will be true.
	OpenStreamSync(context.Context) (Stream, error)
	// OpenUniStream opens a new outgoing unidirectional QUIC stream.
	// If the error is non-nil, it satisfies the net.Error interface.
	// When reaching the peer's stream limit, Temporary() will be true.
	// If the connection was closed due to a timeout, Timeout() will be true.
	OpenUniStream() (SendStream, error)
	// OpenUniStreamSync opens a new outgoing unidirectional QUIC stream.
	// It blocks until a new stream can be opened.
	// If the error is non-nil, it satisfies the net.Error interface.
	// If the connection was closed due to a timeout, Timeout() will be true.
	OpenUniStreamSync(context.Context) (SendStream, error)
}

// Stream is a bidirectional stream: both a ReceiveStream and a SendStream.
type Stream interface {
	ReceiveStream
	SendStream
}

// ReceiveStream is the interface for receiving data on a stream.
type ReceiveStream interface {
	// StreamID returns the ID of this stream.
	StreamID() int64
	// io.Reader: Read reads data from the stream.
	io.Reader

	// CancelRead aborts receiving on this stream.
	// It will ask the peer to stop transmitting stream data.
	// Read will unblock immediately, and future Read calls will fail.
	// When called multiple times or after reading the io.EOF it is a no-op.
	CancelRead(err ApplicationError)

	// SetReadDeadline sets the deadline for future Read calls and
	// any currently-blocked Read call.
	// A zero value for t means Read will not time out.
	SetReadDeadline(t time.Time) error
}

// SendStream is the interface for sending data on a stream.
type SendStream interface {
	// StreamID returns the ID of this stream.
	StreamID() int64
	// io.Writer: Write writes data to the stream.
	// Write can be made to time out and return a net.Error with Timeout() == true
	// after a fixed time limit; see SetDeadline and SetWriteDeadline.
	// If the stream was canceled by the peer, the error implements the StreamError
	// interface, and Canceled() == true.
	// If the connection was closed due to a timeout, the error satisfies
	// the net.Error interface, and Timeout() will be true.
	io.Writer
	// CancelWrite aborts sending on this stream.
	// Data already written, but not yet delivered to the peer is not guaranteed to be delivered reliably.
	// Write will unblock immediately, and future calls to Write will fail.
	// When called multiple times or after closing the stream it is a no-op.
	CancelWrite(err ApplicationError)
	// io.Closer: Close closes the write-direction of the stream.
	// Future calls to Write are not permitted after calling Close.
	// It must not be called concurrently with Write.
	// It must not be called after calling CancelWrite.
	io.Closer

	// Context returns a context that is canceled as soon as the write-side of
	// the stream is closed. This happens when Close() or CancelWrite() is
	// called, or when the peer cancels the read-side of their stream.
	Context() context.Context
	// SetWriteDeadline sets the deadline for future Write calls
	// and any currently-blocked Write call.
	// Even if write times out, it may return n > 0, indicating that
	// some data was successfully written.
	// A zero value for t means Write will not time out.
	SetWriteDeadline(t time.Time) error
}

// ApplicationError is the application-level error passed to CloseWithError,
// CancelRead and CancelWrite. It consists of a numeric code plus a String
// description.
type ApplicationError interface {
	ErrCode() uint64
	fmt.Stringer
}

================================================
FILE: pkg/network/dialer/dialer.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package dialer

import (
	"crypto/tls"
	"net"
	"time"

	"github.com/cloudwego/hertz/pkg/network"
	"github.com/cloudwego/hertz/pkg/network/standard"
)

// defaultDialer is the package-wide dialer used by the helpers below.
// It will be netpoll.NewDialer() if available, see netpoll.go.
var defaultDialer network.Dialer = standard.NewDialer()

// SetDialer is used to set the global default dialer.
//
// Deprecated: use WithDialer instead.
func SetDialer(dialer network.Dialer) {
	defaultDialer = dialer
}

// DefaultDialer returns the current global default dialer.
func DefaultDialer() network.Dialer {
	return defaultDialer
}

// DialConnection dials address on the given network using the global default
// dialer and returns a hertz network.Conn.
func DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) {
	return defaultDialer.DialConnection(network, address, timeout, tlsConfig)
}

// DialTimeout dials address on the given network using the global default
// dialer and returns a standard net.Conn.
func DialTimeout(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn net.Conn, err error) {
	return defaultDialer.DialTimeout(network, address, timeout, tlsConfig)
}

// AddTLS is used to add tls to a persistent connection, i.e. negotiate a TLS session. If conn is already a TLS
// tunnel, this function establishes a nested TLS session inside the encrypted channel.
func AddTLS(conn network.Conn, tlsConfig *tls.Config) (network.Conn, error) {
	// Delegates to the package-level default dialer (replaced by SetDialer or
	// by the netpoll init hook).
	return defaultDialer.AddTLS(conn, tlsConfig)
}

================================================
FILE: pkg/network/dialer/dialer_test.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package dialer

import (
	"crypto/tls"
	"errors"
	"net"
	"testing"
	"time"

	"github.com/cloudwego/hertz/pkg/common/test/assert"
	"github.com/cloudwego/hertz/pkg/network"
)

// TestDialer checks that the package-level helpers delegate to whatever
// dialer was installed with SetDialer.
// NOTE(review): the global default dialer is replaced here and not restored.
func TestDialer(t *testing.T) {
	SetDialer(&mockDialer{})
	dialer := DefaultDialer()
	assert.DeepEqual(t, &mockDialer{}, dialer)

	// Every mockDialer method fails, so each helper must surface an error.
	_, err := AddTLS(nil, nil)
	assert.NotNil(t, err)

	_, err = DialConnection("", "", 0, nil)
	assert.NotNil(t, err)

	_, err = DialTimeout("", "", 0, nil)
	assert.NotNil(t, err)
}

// mockDialer is a network.Dialer stub whose methods all return an error.
type mockDialer struct{}

func (m *mockDialer) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) {
	return nil, errors.New("method not implement")
}

func (m *mockDialer) DialTimeout(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn net.Conn, err error) {
	return nil, errors.New("method not implement")
}

func (m *mockDialer) AddTLS(conn network.Conn, tlsConfig *tls.Config) (network.Conn, error) {
	return nil, errors.New("method not implement")
}

================================================
FILE: pkg/network/dialer/netpoll.go
================================================
//go:build (amd64 || arm64) && (linux || darwin)

/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package dialer

import (
	"os"
	"strconv"

	"github.com/cloudwego/hertz/pkg/network/netpoll"
)

// init switches the default dialer to netpoll unless the HERTZ_NO_NETPOLL
// environment variable parses as a true boolean.
func init() {
	// ParseBool yields false on error (including an unset variable), so
	// netpoll is used unless it is explicitly disabled.
	if v, _ := strconv.ParseBool(os.Getenv("HERTZ_NO_NETPOLL")); !v {
		defaultDialer = netpoll.NewDialer()
	}
}

================================================
FILE: pkg/network/dialer.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package network

import (
	"crypto/tls"
	"net"
	"time"
)

// Dialer abstracts how outbound connections are established and how an
// existing connection is switched to TLS.
type Dialer interface {
	// DialConnection is used to dial the peer end.
	DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn Conn, err error)

	// DialTimeout is used to dial the peer end with a timeout.
	//
	// NOTE: Not recommended to use this function. Just for compatibility.
	DialTimeout(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn net.Conn, err error)

	// AddTLS will transfer a common connection to a tls connection.
	AddTLS(conn Conn, tlsConfig *tls.Config) (Conn, error)
}

================================================
FILE: pkg/network/netpoll/connection.go
================================================
// Copyright 2022 CloudWeGo Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

package netpoll

import (
	"errors"
	"io"
	"strings"
	"syscall"

	errs "github.com/cloudwego/hertz/pkg/common/errors"
	"github.com/cloudwego/hertz/pkg/common/hlog"
	"github.com/cloudwego/hertz/pkg/network"
	"github.com/cloudwego/netpoll"
)

// Conn adapts a netpoll connection to network.Conn. Its read-side methods
// pass errors through normalizeErr so netpoll.ErrEOF surfaces as io.EOF.
type Conn struct {
	network.Conn
}

// ToHertzError maps transport errors onto hertz error values: closed
// connections and broken pipes become errs.ErrConnectionClosed, netpoll read
// timeouts become errs.ErrTimeout, and anything else is returned unchanged.
func (c *Conn) ToHertzError(err error) error {
	if errors.Is(err, netpoll.ErrConnClosed) || errors.Is(err, syscall.EPIPE) {
		return errs.ErrConnectionClosed
	}

	// only unify read timeout for now
	if errors.Is(err, netpoll.ErrReadTimeout) {
		return errs.ErrTimeout
	}
	return err
}

// Peek delegates to the embedded connection and normalizes the error.
func (c *Conn) Peek(n int) (b []byte, err error) {
	b, err = c.Conn.Peek(n)
	err = normalizeErr(err)
	return
}

// Read delegates to the embedded connection and normalizes the error.
func (c *Conn) Read(p []byte) (int, error) {
	n, err := c.Conn.Read(p)
	err = normalizeErr(err)
	return n, err
}

// Skip delegates to the embedded connection; its error is not normalized.
func (c *Conn) Skip(n int) error {
	return c.Conn.Skip(n)
}

// Release delegates to the embedded connection.
func (c *Conn) Release() error {
	return c.Conn.Release()
}

// Len delegates to the embedded connection.
func (c *Conn) Len() int {
	return c.Conn.Len()
}

// ReadByte delegates to the embedded connection and normalizes the error.
func (c *Conn) ReadByte() (b byte, err error) {
	b, err = c.Conn.ReadByte()
	err = normalizeErr(err)
	return
}

// ReadBinary delegates to the embedded connection and normalizes the error.
func (c *Conn) ReadBinary(n int) (b []byte, err error) {
	b, err = c.Conn.ReadBinary(n)
	err = normalizeErr(err)
	return
}

// Malloc delegates to the embedded connection.
func (c *Conn) Malloc(n int) (buf []byte, err error) {
	return c.Conn.Malloc(n)
}

// WriteBinary delegates to the embedded connection.
func (c *Conn) WriteBinary(b []byte) (n int, err error) {
	return c.Conn.WriteBinary(b)
}

// Flush delegates to the embedded connection.
func (c *Conn) Flush() error {
	return c.Conn.Flush()
}

// HandleSpecificError reports whether err may be ignored by the caller.
// Closed-connection, broken-pipe and connection-reset errors are always
// ignored; they are logged at debug level (with the remote address rip)
// unless the error message contains "when flush".
func (c *Conn) HandleSpecificError(err error, rip string) (needIgnore bool) {
	if errors.Is(err, netpoll.ErrConnClosed) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) {
		// ignore flushing error when connection is closed or reset
		if strings.Contains(err.Error(), "when flush") {
			return true
		}
		hlog.SystemLogger().Debugf("Netpoll error=%s, remoteAddr=%s", err.Error(), rip)
		return true
	}
	return false
}

// normalizeErr converts errors matching netpoll.ErrEOF (via errors.Is) into
// io.EOF; every other error, including nil, is returned unchanged.
func normalizeErr(err error) error {
	if errors.Is(err, netpoll.ErrEOF) {
		return io.EOF
	}
	return err
}

// newConn wraps c in a *Conn. The unchecked type assertion panics if c does
// not also implement network.Conn.
func newConn(c netpoll.Connection) network.Conn {
	return &Conn{Conn: c.(network.Conn)}
}

================================================
FILE: pkg/network/netpoll/connection_test.go
================================================
// Copyright 2022 CloudWeGo Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// //go:build !windows package netpoll import ( "errors" "net" "testing" "time" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/netpoll" ) func TestReadBytes(t *testing.T) { c := &mockConn{[]byte("a"), nil, 0} conn := newConn(c) assert.DeepEqual(t, 1, conn.Len()) b, _ := conn.Peek(1) assert.DeepEqual(t, []byte{'a'}, b) readByte, _ := conn.ReadByte() assert.DeepEqual(t, byte('a'), readByte) _, err := conn.ReadByte() assert.DeepEqual(t, errors.New("readByte error: index out of range"), err) c = &mockConn{[]byte("bcd"), nil, 0} conn = newConn(c) readBinary, _ := conn.ReadBinary(2) assert.DeepEqual(t, []byte{'b', 'c'}, readBinary) _, err = conn.ReadBinary(2) assert.DeepEqual(t, errors.New("readBinary error: index out of range"), err) } func TestPeekRelease(t *testing.T) { c := &mockConn{[]byte("abcdefg"), nil, 0} conn := newConn(c) // release the buf conn.Release() _, err := conn.Peek(1) assert.DeepEqual(t, errors.New("peek error"), err) assert.DeepEqual(t, errors.New("skip error"), conn.Skip(2)) } func TestWriteLogin(t *testing.T) { c := &mockConn{nil, []byte("abcdefg"), 0} conn := newConn(c) buf, _ := conn.Malloc(10) assert.DeepEqual(t, 10, len(buf)) n, _ := conn.WriteBinary([]byte("abcdefg")) assert.DeepEqual(t, 7, n) assert.DeepEqual(t, errors.New("flush error"), conn.Flush()) } func TestHandleSpecificError(t *testing.T) { conn := &Conn{} assert.DeepEqual(t, false, conn.HandleSpecificError(nil, "")) assert.DeepEqual(t, true, conn.HandleSpecificError(netpoll.ErrConnClosed, "")) } type mockConn struct { readBuf []byte writeBuf []byte // index for the first readable byte in readBuf off int } func (m *mockConn) SetWriteTimeout(timeout time.Duration) error { // TODO implement me panic("implement me") } // mockConn's methods is simplified for unit test // Peek returns the next n bytes without advancing the reader func (m *mockConn) Peek(n int) (b []byte, err error) { if m.off+n-1 < len(m.readBuf) { return m.readBuf[m.off : m.off+n], nil } return 
nil, errors.New("peek error") } // Skip discards the next n bytes func (m *mockConn) Skip(n int) error { if m.off+n < len(m.readBuf) { m.off += n return nil } return errors.New("skip error") } // Release the memory space occupied by all read slices func (m *mockConn) Release() error { m.readBuf = nil m.off = 0 return nil } // Len returns the total length of the readable data in the reader func (m *mockConn) Len() int { return len(m.readBuf) - m.off } // ReadByte is used to read one byte with advancing the read pointer func (m *mockConn) ReadByte() (byte, error) { if m.off < len(m.readBuf) { m.off++ return m.readBuf[m.off-1], nil } return 0, errors.New("readByte error: index out of range") } // ReadBinary is used to read next n byte with copy, and the read pointer will be advanced func (m *mockConn) ReadBinary(n int) (b []byte, err error) { if m.off+n < len(m.readBuf) { m.off += n return m.readBuf[m.off-n : m.off], nil } return nil, errors.New("readBinary error: index out of range") } // Malloc will provide a n bytes buffer to send data func (m *mockConn) Malloc(n int) (buf []byte, err error) { m.writeBuf = make([]byte, n) return m.writeBuf, nil } // WriteBinary will use the user buffer to flush func (m *mockConn) WriteBinary(b []byte) (n int, err error) { return len(b), nil } // Flush will send data to the peer end func (m *mockConn) Flush() error { return errors.New("flush error") } func (m *mockConn) HandleSpecificError(err error, rip string) (needIgnore bool) { panic("implement me") } func (m *mockConn) Read(b []byte) (n int, err error) { panic("implement me") } func (m *mockConn) Write(b []byte) (n int, err error) { panic("implement me") } func (m *mockConn) Close() error { panic("implement me") } func (m *mockConn) LocalAddr() net.Addr { panic("implement me") } func (m *mockConn) RemoteAddr() net.Addr { panic("implement me") } func (m *mockConn) SetDeadline(deadline time.Time) error { panic("implement me") } func (m *mockConn) SetReadDeadline(deadline 
time.Time) error { panic("implement me") } func (m *mockConn) SetWriteDeadline(deadline time.Time) error { panic("implement me") } func (m *mockConn) Reader() netpoll.Reader { panic("implement me") } func (m *mockConn) Writer() netpoll.Writer { panic("implement me") } func (m *mockConn) IsActive() bool { panic("implement me") } func (m *mockConn) SetReadTimeout(timeout time.Duration) error { panic("implement me") } func (m *mockConn) SetIdleTimeout(timeout time.Duration) error { panic("implement me") } func (m *mockConn) SetOnRequest(on netpoll.OnRequest) error { panic("implement me") } func (m *mockConn) AddCloseCallback(callback netpoll.CloseCallback) error { panic("implement me") } ================================================ FILE: pkg/network/netpoll/dial.go ================================================ // Copyright 2022 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// package netpoll import ( "crypto/tls" "net" "time" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/netpoll" ) var errNotSupportTLS = errors.NewPublic("not support tls") type dialer struct { netpoll.Dialer } func (d dialer) DialConnection(n, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) { if tlsConfig != nil { // https return nil, errNotSupportTLS } c, err := d.Dialer.DialConnection(n, address, timeout) if err != nil { return nil, err } conn = newConn(c) return } func (d dialer) DialTimeout(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn net.Conn, err error) { if tlsConfig != nil { return nil, errNotSupportTLS } conn, err = d.Dialer.DialTimeout(network, address, timeout) return } func (d dialer) AddTLS(conn network.Conn, tlsConfig *tls.Config) (network.Conn, error) { return nil, errNotSupportTLS } func NewDialer() network.Dialer { return dialer{Dialer: netpoll.NewDialer()} } ================================================ FILE: pkg/network/netpoll/dial_test.go ================================================ // Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
//
//go:build !windows

package netpoll

import (
	"context"
	"crypto/tls"
	"testing"
	"time"

	"github.com/cloudwego/hertz/internal/testutils"
	"github.com/cloudwego/hertz/pkg/common/config"
	"github.com/cloudwego/hertz/pkg/common/test/assert"
	"github.com/cloudwego/hertz/pkg/common/test/mock"
)

// TestDial exercises the netpoll dialer against a live transporter and
// checks that every TLS-requesting entry point returns errNotSupportTLS.
func TestDial(t *testing.T) {
	t.Run("NetpollDial", func(t *testing.T) {
		ln := testutils.NewTestListener(t)
		defer ln.Close()

		transporter := NewTransporter(&config.Options{
			Listener: ln,
		})
		go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error {
			return nil
		})
		defer transporter.Close()
		// Presumably gives ListenAndServe time to build its event loop before dialing.
		time.Sleep(100 * time.Millisecond)

		dial := NewDialer()
		addr := ln.Addr().String()
		nw := ln.Addr().Network()
		// DialConnection
		// NOTE(review): assumes nothing listens on localhost:10101; otherwise this is flaky.
		_, err := dial.DialConnection(nw, "localhost:10101", time.Second, nil) // wrong addr
		assert.NotNil(t, err)

		nwConn, err := dial.DialConnection(nw, addr, time.Second, nil)
		assert.Nil(t, err)
		defer nwConn.Close()
		_, err = nwConn.Write([]byte("abcdef"))
		assert.Nil(t, err)
		// DialTimeout
		nConn, err := dial.DialTimeout("tcp", addr, time.Second, nil)
		assert.Nil(t, err)
		defer nConn.Close()
	})

	t.Run("NotSupportTLS", func(t *testing.T) {
		dial := NewDialer()
		_, err := dial.AddTLS(mock.NewConn(""), nil)
		assert.DeepEqual(t, errNotSupportTLS, err)
		_, err = dial.DialConnection("tcp", "localhost:10102", time.Microsecond, &tls.Config{})
		assert.DeepEqual(t, errNotSupportTLS, err)
		_, err = dial.DialTimeout("tcp", "localhost:10102", time.Microsecond, &tls.Config{})
		assert.DeepEqual(t, errNotSupportTLS, err)
	})
}

================================================
FILE: pkg/network/netpoll/transport.go
================================================
// Copyright 2022 CloudWeGo Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //go:build !windows package netpoll import ( "context" "io" "net" "sync" "time" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/netpoll" ) func init() { // disable netpoll's log netpoll.SetLoggerOutput(io.Discard) } type ctxCancelKeyStruct struct{} var ctxCancelKey = ctxCancelKeyStruct{} func cancelContext(ctx context.Context) context.Context { ctx, cancel := context.WithCancel(ctx) ctx = context.WithValue(ctx, ctxCancelKey, cancel) return ctx } type transporter struct { senseClientDisconnection bool network string addr string keepAliveTimeout time.Duration readTimeout time.Duration writeTimeout time.Duration listenConfig *net.ListenConfig OnAccept func(conn net.Conn) context.Context OnConnect func(ctx context.Context, conn network.Conn) context.Context mu sync.RWMutex ln net.Listener el netpoll.EventLoop } // For transporter switch func NewTransporter(options *config.Options) network.Transporter { return &transporter{ senseClientDisconnection: options.SenseClientDisconnection, network: options.Network, addr: options.Addr, keepAliveTimeout: options.KeepAliveTimeout, readTimeout: options.ReadTimeout, writeTimeout: options.WriteTimeout, ln: options.Listener, listenConfig: options.ListenConfig, OnAccept: options.OnAccept, OnConnect: options.OnConnect, } } func (t *transporter) Listener() net.Listener { t.mu.RLock() defer t.mu.RUnlock() return t.ln } // ListenAndServe binds listen address and keep serving, until an error occurs // or the transport 
shutdowns func (t *transporter) ListenAndServe(onReq network.OnData) (err error) { network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck t.mu.Lock() if t.ln == nil { if t.listenConfig != nil { t.ln, err = t.listenConfig.Listen(context.Background(), t.network, t.addr) } else { t.ln, err = net.Listen(t.network, t.addr) } if err != nil { t.mu.Unlock() panic("create netpoll listener fail: " + err.Error()) } } ln := t.ln t.mu.Unlock() // Initialize custom option for EventLoop opts := []netpoll.Option{ netpoll.WithIdleTimeout(t.keepAliveTimeout), netpoll.WithOnPrepare(func(conn netpoll.Connection) context.Context { conn.SetReadTimeout(t.readTimeout) // nolint:errcheck if t.writeTimeout > 0 { conn.SetWriteTimeout(t.writeTimeout) } ctx := context.Background() if t.OnAccept != nil { ctx = t.OnAccept(newConn(conn)) } if t.senseClientDisconnection { ctx = cancelContext(ctx) } return ctx }), } if t.OnConnect != nil { opts = append(opts, netpoll.WithOnConnect(func(ctx context.Context, conn netpoll.Connection) context.Context { return t.OnConnect(ctx, newConn(conn)) })) } if t.senseClientDisconnection { opts = append(opts, netpoll.WithOnDisconnect(func(ctx context.Context, connection netpoll.Connection) { cancelFunc, ok := ctx.Value(ctxCancelKey).(context.CancelFunc) if cancelFunc != nil && ok { cancelFunc() } })) } // Create EventLoop t.mu.Lock() t.el, err = netpoll.NewEventLoop(func(ctx context.Context, connection netpoll.Connection) error { return onReq(ctx, newConn(connection)) }, opts...) 
eventLoop := t.el t.mu.Unlock() if err != nil { panic("create netpoll event-loop fail") } // Start Server hlog.SystemLogger().Infof("HTTP server listening on address=%s", ln.Addr().String()) err = eventLoop.Serve(ln) if err != nil { panic("netpoll server exit") } return nil } // Close forces transport to close immediately (no wait timeout) func (t *transporter) Close() error { ctx, cancel := context.WithTimeout(context.Background(), 0) defer cancel() return t.Shutdown(ctx) } // Shutdown will trigger listener stop and graceful shutdown // It will wait all connections close until reaching ctx.Deadline() func (t *transporter) Shutdown(ctx context.Context) error { defer func() { network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck t.mu.RUnlock() }() t.mu.RLock() if t.el == nil { return nil } return t.el.Shutdown(ctx) } ================================================ FILE: pkg/network/netpoll/transport_test.go ================================================ // Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// //go:build !windows package netpoll import ( "context" "net" "sync/atomic" "syscall" "testing" "time" "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/network" "golang.org/x/sys/unix" ) func TestTransport(t *testing.T) { t.Run("TestDefault", func(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() var onConnFlag, onAcceptFlag, onDataFlag int32 transporter := NewTransporter(&config.Options{ Listener: ln, OnConnect: func(ctx context.Context, conn network.Conn) context.Context { atomic.StoreInt32(&onConnFlag, 1) return ctx }, WriteTimeout: time.Second, OnAccept: func(conn net.Conn) context.Context { atomic.StoreInt32(&onAcceptFlag, 1) return context.Background() }, }) go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error { atomic.StoreInt32(&onDataFlag, 1) return nil }) defer transporter.Close() time.Sleep(100 * time.Millisecond) addr := ln.Addr().String() nw := ln.Addr().Network() dial := NewDialer() conn, err := dial.DialConnection(nw, addr, time.Second, nil) assert.Nil(t, err) _, err = conn.Write([]byte("123")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) assert.Assert(t, atomic.LoadInt32(&onConnFlag) == 1) assert.Assert(t, atomic.LoadInt32(&onAcceptFlag) == 1) assert.Assert(t, atomic.LoadInt32(&onDataFlag) == 1) }) t.Run("TestSenseClientDisconnection", func(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() var onReqFlag int32 transporter := NewTransporter(&config.Options{ Listener: ln, SenseClientDisconnection: true, }) go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error { atomic.StoreInt32(&onReqFlag, 1) time.Sleep(100 * time.Millisecond) assert.DeepEqual(t, context.Canceled, ctx.Err()) return nil }) defer transporter.Close() time.Sleep(100 * time.Millisecond) addr := ln.Addr().String() nw := ln.Addr().Network() dial := NewDialer() conn, 
err := dial.DialConnection(nw, addr, time.Second, nil) assert.Nil(t, err) _, err = conn.Write([]byte("123")) assert.Nil(t, err) err = conn.Close() assert.Nil(t, err) time.Sleep(100 * time.Millisecond) assert.Assert(t, atomic.LoadInt32(&onReqFlag) == 1) }) t.Run("TestListenConfig", func(t *testing.T) { listenCfg := &net.ListenConfig{Control: func(network, address string, c syscall.RawConn) error { return c.Control(func(fd uintptr) { syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEADDR, 1) syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1) }) }} transporter := NewTransporter(&config.Options{ Network: "tcp", Addr: "127.0.0.1:0", ListenConfig: listenCfg, }) go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error { return nil }) defer transporter.Close() }) t.Run("TestExceptionCase", func(t *testing.T) { assert.Panic(t, func() { // listen err transporter := NewTransporter(&config.Options{ Network: "unknown", }) transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error { return nil }) }) }) t.Run("TestWithListener", func(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() var onDataFlag int32 trans := NewTransporter(&config.Options{ Listener: ln, }).(*transporter) go trans.ListenAndServe(func(ctx context.Context, conn interface{}) error { atomic.StoreInt32(&onDataFlag, 1) return nil }) defer trans.Close() time.Sleep(100 * time.Millisecond) // Verify listener is used assert.DeepEqual(t, ln.Addr().String(), trans.Listener().Addr().String()) nw := ln.Addr().Network() // Connect and send data dial := NewDialer() conn, err := dial.DialConnection(nw, ln.Addr().String(), time.Second, nil) assert.Nil(t, err) _, err = conn.Write([]byte("test")) assert.Nil(t, err) time.Sleep(100 * time.Millisecond) assert.Assert(t, atomic.LoadInt32(&onDataFlag) == 1) }) } ================================================ FILE: pkg/network/standard/connection.go 
================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package standard import ( "crypto/tls" "errors" "io" "net" "strconv" "syscall" "time" "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/net/connstate" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/network" ) type Conn struct { c net.Conn br *bufiox.DefaultReader bw *bufiox.DefaultWriter stater connstate.ConnStater buf [8]byte } func (c *Conn) ToHertzError(err error) error { if errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ENOTCONN) { return errs.ErrConnectionClosed } if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { return errs.ErrTimeout } return err } func (c *Conn) SetWriteTimeout(t time.Duration) error { if t <= 0 { return c.c.SetWriteDeadline(time.Time{}) } return c.c.SetWriteDeadline(time.Now().Add(t)) } func (c *Conn) SetReadTimeout(t time.Duration) error { if t <= 0 { return c.c.SetReadDeadline(time.Time{}) } return c.c.SetReadDeadline(time.Now().Add(t)) } type TLSConn struct { Conn } // Peek returns the next n bytes without advancing the reader. If Peek returns // fewer than n bytes, it also returns an error explaining why the read is short. 
func (c *Conn) Peek(n int) ([]byte, error) { buf, err := c.br.Peek(n) // bufiox readAtLeast converts partial-read+EOF to ErrUnexpectedEOF, // but hertz protocol code expects the original io.EOF. if err == io.ErrUnexpectedEOF { err = io.EOF } return buf, err } // Skip discards the next n bytes. func (c *Conn) Skip(n int) error { if c.Len() < n { return errs.NewPrivate("skip[" + strconv.Itoa(n) + "] not enough") } return c.br.Skip(n) } // Release frees internal read buffers. func (c *Conn) Release() error { return c.br.Release(nil) } // Len returns the total length of the readable data in the reader. func (c *Conn) Len() int { return c.br.Buffered() } // ReadByte is used to read one byte with advancing the read pointer. func (c *Conn) ReadByte() (b byte, err error) { // Use Read instead of Peek+Skip to avoid holding a ref to the underlying buffer. _, err = c.br.Read(c.buf[:1]) if err == nil { b = c.buf[0] } return } // ReadBinary is used to read next n byte with copy, and the read pointer will be advanced. func (c *Conn) ReadBinary(n int) ([]byte, error) { out := make([]byte, n) _, err := c.br.ReadBinary(out) if err != nil { return nil, err } return out, nil } // Read implements io.Reader. func (c *Conn) Read(b []byte) (int, error) { return c.br.Read(b) } // Write calls Write syscall directly to send data. // Will flush buffer immediately, for performance considerations use WriteBinary instead. func (c *Conn) Write(b []byte) (n int, err error) { if err = c.Flush(); err != nil { return } return c.c.Write(b) } // ReadFrom implements io.ReaderFrom. If the underlying writer // supports the ReadFrom method, this calls the underlying ReadFrom // without buffering. 
func (c *Conn) ReadFrom(r io.Reader) (n int64, err error) { if err = c.Flush(); err != nil { return } if w, ok := c.c.(io.ReaderFrom); ok { n, err = w.ReadFrom(r) return } var buf [32 * 1024]byte for { m, rerr := r.Read(buf[:]) if m > 0 { dst, werr := c.bw.Malloc(m) if werr != nil { return n, werr } copy(dst, buf[:m]) n += int64(m) } if rerr != nil { if rerr != io.EOF { err = rerr } return } } } // Close closes the connection func (c *Conn) Close() error { // Close stater first to stop epoll monitoring if c.stater != nil { c.stater.Close() c.stater = nil } return c.c.Close() } // LocalAddr returns the local address of the connection. func (c *Conn) LocalAddr() net.Addr { return c.c.LocalAddr() } // RemoteAddr returns the remote address of the connection. func (c *Conn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } // SetDeadline sets the connection deadline. func (c *Conn) SetDeadline(t time.Time) error { return c.c.SetDeadline(t) } // SetReadDeadline sets the read deadline of the connection. func (c *Conn) SetReadDeadline(t time.Time) error { return c.c.SetReadDeadline(t) } // SetWriteDeadline sets the write deadline of the connection. func (c *Conn) SetWriteDeadline(t time.Time) error { return c.c.SetWriteDeadline(t) } // Malloc will provide a n bytes buffer to send data. func (c *Conn) Malloc(n int) ([]byte, error) { return c.bw.Malloc(n) } // WriteBinary will use the user buffer to flush. // NOTE: Before flush successfully, the buffer b should be valid. func (c *Conn) WriteBinary(b []byte) (int, error) { return c.bw.WriteBinary(b) } // Flush will send data to the peer end. 
func (c *Conn) Flush() error { return c.bw.Flush() } func (c *Conn) HandleSpecificError(err error, rip string) (needIgnore bool) { if errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ECONNRESET) { hlog.SystemLogger().Debugf("Go net library error=%s, remoteAddr=%s", err.Error(), rip) return true } return false } func (c *TLSConn) Handshake() error { return c.c.(network.ConnTLSer).Handshake() } func (c *TLSConn) ConnectionState() tls.ConnectionState { return c.c.(network.ConnTLSer).ConnectionState() } func newConn(c net.Conn, size int) network.Conn { return &Conn{ c: c, br: bufiox.NewDefaultReaderSize(c, size), bw: bufiox.NewDefaultWriter(c), } } func newTLSConn(c net.Conn, size int) network.Conn { return &TLSConn{ Conn{ c: c, br: bufiox.NewDefaultReaderSize(c, size), bw: bufiox.NewDefaultWriter(c), }, } } ================================================ FILE: pkg/network/standard/connection_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package standard import ( "bytes" "crypto/tls" "errors" "io" "net" "strings" "syscall" "testing" "time" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" ) // --- helpers --- func mkConn(data []byte) (*bufConn, *Conn) { c := &bufConn{r: bytes.NewReader(data), w: &bytes.Buffer{}} return c, newConn(c, 4096).(*Conn) } type bufConn struct { mockConn r io.Reader w *bytes.Buffer } func (c *bufConn) Read(b []byte) (int, error) { return c.r.Read(b) } func (c *bufConn) Write(b []byte) (int, error) { return c.w.Write(b) } type mockConn struct { buffer bytes.Buffer localAddr net.Addr remoteAddr net.Addr } func (m *mockConn) Handshake() error { return errors.New("not supported") } func (m *mockConn) ConnectionState() tls.ConnectionState { return tls.ConnectionState{} } func (m mockConn) Read(b []byte) (int, error) { for i := range b { b[i] = 0 } n := len(b) if n > 8192 { n = 8192 } return n, nil } func (m *mockConn) Write(b []byte) (int, error) { return m.buffer.Write(b) } func (m *mockConn) Close() error { return errors.New("not supported") } func (m *mockConn) LocalAddr() net.Addr { return m.localAddr } func (m *mockConn) RemoteAddr() net.Addr { return m.remoteAddr } func (m *mockConn) SetDeadline(t time.Time) error { return m.SetWriteDeadline(t) } func (m *mockConn) SetReadDeadline(time.Time) error { return errors.New("read deadline not supported") } func (m *mockConn) SetWriteDeadline(time.Time) error { return errors.New("write deadline not supported") } type errReader struct { data []byte err error done bool } func (r *errReader) Read(b []byte) (int, error) { if r.done { return 0, r.err } r.done = true return copy(b, r.data), nil } type readerFromConn struct { mockConn } func (c *readerFromConn) ReadFrom(r io.Reader) (int64, error) { return io.Copy(&c.buffer, r) } // --- tests --- func TestRead(t *testing.T) { data := bytes.Repeat([]byte{1}, 10000) _, conn := mkConn(data) b := make([]byte, 5) n, err := conn.Read(b) 
assert.Nil(t, err) assert.DeepEqual(t, 5, n) b = make([]byte, 20000) n, err = conn.Read(b) assert.Nil(t, err) assert.True(t, n > 0) } func TestPeekSkipRelease(t *testing.T) { data := bytes.Repeat([]byte{0}, 10000) _, conn := mkConn(data) b, err := conn.Peek(100) assert.Nil(t, err) assert.DeepEqual(t, 100, len(b)) // skip more than buffered fails assert.NotNil(t, conn.Skip(conn.Len()+1)) // skip all succeeds assert.Nil(t, conn.Skip(conn.Len())) assert.DeepEqual(t, 0, conn.Len()) // release then peek refills conn.Release() b, err = conn.Peek(1) assert.Nil(t, err) assert.DeepEqual(t, 1, len(b)) } func TestReadByteAndReadBinary(t *testing.T) { _, conn := mkConn([]byte("abcdef")) rb, err := conn.ReadBinary(3) assert.Nil(t, err) assert.DeepEqual(t, []byte("abc"), rb) by, err := conn.ReadByte() assert.Nil(t, err) assert.DeepEqual(t, byte('d'), by) } func TestWriteAndFlush(t *testing.T) { c := &mockConn{} conn := newConn(c, 4096).(*Conn) buf, _ := conn.Malloc(5) copy(buf, []byte("hello")) conn.WriteBinary([]byte(" world")) assert.Nil(t, conn.Flush()) assert.DeepEqual(t, "hello world", c.buffer.String()) } func TestReadFrom(t *testing.T) { t.Run("fallback loop", func(t *testing.T) { c := &mockConn{} conn := newConn(c, 4096) n, err := conn.(io.ReaderFrom).ReadFrom(strings.NewReader("hello")) assert.Nil(t, err) assert.DeepEqual(t, int64(5), n) conn.Flush() assert.DeepEqual(t, "hello", c.buffer.String()) }) t.Run("reader error", func(t *testing.T) { readErr := errors.New("disk failure") c := &mockConn{} conn := newConn(c, 4096) n, err := conn.(io.ReaderFrom).ReadFrom(&errReader{data: []byte("partial"), err: readErr}) assert.DeepEqual(t, readErr, err) assert.DeepEqual(t, int64(7), n) }) t.Run("underlying ReaderFrom", func(t *testing.T) { c := &readerFromConn{} conn := newConn(c, 4096) n, err := conn.(io.ReaderFrom).ReadFrom(strings.NewReader("hello")) assert.Nil(t, err) assert.DeepEqual(t, int64(5), n) assert.DeepEqual(t, "hello", c.buffer.String()) }) } func 
TestPeekPartialEOF(t *testing.T) { _, conn := mkConn(make([]byte, 100)) b, err := conn.Peek(100) assert.Nil(t, err) assert.DeepEqual(t, 100, len(b)) conn.Skip(10) b, err = conn.Peek(100) assert.DeepEqual(t, io.EOF, err) assert.DeepEqual(t, 90, len(b)) } func TestConnAddrs(t *testing.T) { local := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80} remote := &net.TCPAddr{IP: net.IPv4(10, 0, 0, 1), Port: 9000} c := &mockConn{localAddr: local, remoteAddr: remote} conn := newConn(c, 4096) assert.DeepEqual(t, local, conn.LocalAddr()) assert.DeepEqual(t, remote, conn.RemoteAddr()) } func TestSetDeadlines(t *testing.T) { c := &mockConn{} conn := newConn(c, 4096) assert.NotNil(t, conn.SetDeadline(time.Time{})) assert.NotNil(t, conn.SetReadDeadline(time.Time{})) assert.NotNil(t, conn.SetWriteDeadline(time.Time{})) } func TestSetTimeouts(t *testing.T) { c := &mockConn{} conn := newConn(c, 4096) assert.NotNil(t, conn.SetReadTimeout(time.Second)) assert.NotNil(t, conn.SetReadTimeout(-1)) assert.NotNil(t, conn.SetWriteTimeout(time.Second)) assert.NotNil(t, conn.SetWriteTimeout(-1)) } func TestToHertzError(t *testing.T) { conn := &Conn{} other := errors.New("other") assert.DeepEqual(t, errs.ErrConnectionClosed, conn.ToHertzError(syscall.EPIPE)) assert.DeepEqual(t, errs.ErrConnectionClosed, conn.ToHertzError(syscall.ENOTCONN)) assert.DeepEqual(t, errs.ErrTimeout, conn.ToHertzError(&net.OpError{Op: "read", Err: &timeoutErr{}})) assert.DeepEqual(t, other, conn.ToHertzError(other)) } func TestHandleSpecificError(t *testing.T) { conn := &Conn{} assert.DeepEqual(t, false, conn.HandleSpecificError(nil, "")) assert.DeepEqual(t, true, conn.HandleSpecificError(syscall.EPIPE, "")) } func TestTLSConn(t *testing.T) { c := &mockConn{} tc := newTLSConn(c, 4096).(*TLSConn) assert.NotNil(t, tc.Handshake()) assert.DeepEqual(t, tls.ConnectionState{}, tc.ConnectionState()) } type mockAddr struct { network string address string } func (m *mockAddr) Network() string { return m.network } func (m 
*mockAddr) String() string { return m.address } type timeoutErr struct{} func (e *timeoutErr) Error() string { return "timeout" } func (e *timeoutErr) Timeout() bool { return true } func (e *timeoutErr) Temporary() bool { return true } ================================================ FILE: pkg/network/standard/dial.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package standard import ( "crypto/tls" "net" "time" "github.com/cloudwego/hertz/pkg/network" ) type dialer struct{} func (d *dialer) DialConnection(n, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) { c, err := net.DialTimeout(n, address, timeout) if tlsConfig != nil { cTLS := tls.Client(c, tlsConfig) conn = newTLSConn(cTLS, 0) return } conn = newConn(c, 0) return } func (d *dialer) DialTimeout(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn net.Conn, err error) { conn, err = net.DialTimeout(network, address, timeout) return } func (d *dialer) AddTLS(conn network.Conn, tlsConfig *tls.Config) (network.Conn, error) { cTlS := tls.Client(conn, tlsConfig) err := cTlS.Handshake() if err != nil { return nil, err } conn = newTLSConn(cTlS, 0) return conn, nil } func NewDialer() network.Dialer { return &dialer{} } ================================================ FILE: pkg/network/standard/dial_test.go ================================================ /* * Copyright 2023 
CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package standard import ( "context" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "math/big" "net" "testing" "time" "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestDial(t *testing.T) { const nw = "tcp" ln := testutils.NewTestListener(t) defer ln.Close() transporter := NewTransporter(&config.Options{ Listener: ln, Network: nw, }) go transporter.ListenAndServe(func(ctx context.Context, conn interface{}) error { return nil }) defer transporter.Close() time.Sleep(time.Millisecond * 100) addr := ln.Addr().String() dial := NewDialer() _, err := dial.DialConnection(nw, addr, time.Second, nil) assert.Nil(t, err) nConn, err := dial.DialTimeout(nw, addr, time.Second, nil) assert.Nil(t, err) defer nConn.Close() } func TestDialTLS(t *testing.T) { const nw = "tcp" addr := "127.0.0.1:0" data := []byte("abcdefg") listened := make(chan net.Listener) go func() { mockTLSServe(nw, addr, func(conn net.Conn) { defer conn.Close() _, err := conn.Write(data) assert.Nil(t, err) }, listened) }() select { case ln := <-listened: addr = ln.Addr().String() case <-time.After(time.Second * 5): t.Fatalf("timeout") } buf := make([]byte, len(data)) dial := NewDialer() conn, err := dial.DialConnection(nw, addr, time.Second, &tls.Config{ InsecureSkipVerify: true, }) assert.Nil(t, err) _, err = 
conn.Read(buf) assert.Nil(t, err) assert.DeepEqual(t, string(data), string(buf)) conn, err = dial.DialConnection(nw, addr, time.Second, nil) assert.Nil(t, err) nConn, err := dial.AddTLS(conn, &tls.Config{ InsecureSkipVerify: true, }) assert.Nil(t, err) _, err = nConn.Read(buf) assert.Nil(t, err) assert.DeepEqual(t, string(data), string(buf)) } func mockTLSServe(nw, addr string, handle func(conn net.Conn), listened chan net.Listener) (err error) { certData, keyData, err := generateTestCertificate("") if err != nil { return } cert, err := tls.X509KeyPair(certData, keyData) if err != nil { return } tlsConfig := &tls.Config{ Certificates: []tls.Certificate{cert}, } ln, err := tls.Listen(nw, addr, tlsConfig) if err != nil { return } defer ln.Close() listened <- ln for { conn, err := ln.Accept() if err != nil { continue } go handle(conn) } } // generateTestCertificate generates a test certificate and private key based on the given host. func generateTestCertificate(host string) ([]byte, []byte, error) { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, nil, err } serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return nil, nil, err } cert := &x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{ Organization: []string{"fasthttp test"}, }, NotBefore: time.Now(), NotAfter: time.Now().Add(365 * 24 * time.Hour), KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, SignatureAlgorithm: x509.SHA256WithRSA, DNSNames: []string{host}, BasicConstraintsValid: true, IsCA: true, } certBytes, err := x509.CreateCertificate( rand.Reader, cert, cert, &priv.PublicKey, priv, ) p := pem.EncodeToMemory( &pem.Block{ Type: "PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv), }, ) b := pem.EncodeToMemory( &pem.Block{ Type: "CERTIFICATE", Bytes: certBytes, }, ) return b, 
p, err } ================================================ FILE: pkg/network/standard/transport.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package standard import ( "context" "crypto/tls" "errors" "net" "strings" "sync" "sync/atomic" "time" "github.com/cloudwego/gopkg/net/connstate" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/network" ) type transport struct { // The underlying read buffer is a buffer node list, `readBufferSize` is the size of a single node. // `defaultReadBufferSize` (4KB) is used if not set. readBufferSize int network string addr string keepAliveTimeout time.Duration senseClientDisconnection bool readTimeout time.Duration handler network.OnData tls *tls.Config listenConfig *net.ListenConfig OnAccept func(conn net.Conn) context.Context OnConnect func(ctx context.Context, conn network.Conn) context.Context // active connections. 
it +1 after accept and -1 after handler returns active int32 shuttingDown int32 mu sync.RWMutex ln net.Listener } func (t *transport) Listener() net.Listener { t.mu.RLock() defer t.mu.RUnlock() return t.ln } func (t *transport) serve() (err error) { network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck t.mu.Lock() if t.ln == nil { if t.listenConfig != nil { t.ln, err = t.listenConfig.Listen(context.Background(), t.network, t.addr) } else { t.ln, err = net.Listen(t.network, t.addr) } if err != nil { t.mu.Unlock() return err } } // fix concurrency issue // normally listener must not be changed during serve() ln := t.ln t.mu.Unlock() hlog.SystemLogger().Infof("HTTP server listening on address=%s", ln.Addr().String()) for { ctx := context.Background() conn, err := ln.Accept() if err != nil { if atomic.LoadInt32(&t.shuttingDown) > 0 { return nil } if strings.Contains(err.Error(), "closed") { return nil } hlog.SystemLogger().Errorf("Accept err: %v", err) return err } t.updateActive(1) if t.OnAccept != nil { ctx = t.OnAccept(conn) } var c network.Conn if t.tls != nil { c = newTLSConn(tls.Server(conn, t.tls), t.readBufferSize) } else { c = newConn(conn, t.readBufferSize) } if t.OnConnect != nil { ctx = t.OnConnect(ctx, c) } go func(ctx context.Context, conn network.Conn) { if t.senseClientDisconnection { // Get the underlying net.Conn for connstate registration var rawConn net.Conn var stdConn *Conn switch v := conn.(type) { case *Conn: stdConn = v rawConn = v.c case *TLSConn: // TLSConn embeds Conn stdConn = &v.Conn rawConn = v.c default: // Other connection types are not supported t.handler(ctx, conn) t.updateActive(-1) return } // Register connection close callback var cancelCtx context.CancelFunc ctx, cancelCtx = context.WithCancel(ctx) stater, err := connstate.ListenConnState(rawConn, connstate.WithOnRemoteClosed(connstate.OnRemoteClosed(cancelCtx)), ) if err != nil { hlog.SystemLogger().Errorf("ListenConnState failed: %v, connection close detection disabled", 
err) } else { // Set stater to Conn, it will be cleaned up when Close is called stdConn.stater = stater } } t.handler(ctx, conn) t.updateActive(-1) }(ctx, c) } } func (t *transport) updateActive(delta int32) int32 { return atomic.AddInt32(&t.active, delta) } func (t *transport) ListenAndServe(onData network.OnData) (err error) { t.handler = onData return t.serve() } func (t *transport) Close() error { ctx, cancel := context.WithTimeout(context.Background(), 0) defer cancel() return t.Shutdown(ctx) } var ( shutdownTimeout = 30 * time.Second shutdownTicker = 10 * time.Millisecond errShutdownTimeout = errors.New("shutdown timeout") ) func (t *transport) Shutdown(ctx context.Context) error { atomic.StoreInt32(&t.shuttingDown, 1) defer func() { network.UnlinkUdsFile(t.network, t.addr) //nolint:errcheck }() if ln := t.Listener(); ln != nil { _ = ln.Close() } tk := time.NewTicker(shutdownTicker) defer tk.Stop() // make sure t.active is updated correctly under concurrency <-tk.C // luckily the server is idle, no more active connections if t.updateActive(0) <= 0 { return nil } // check periodically to see if all connections closed t0 := time.Now() for { select { case now := <-tk.C: if t.updateActive(0) <= 0 { return nil } if now.Sub(t0) > shutdownTimeout { return errShutdownTimeout } case <-ctx.Done(): return ctx.Err() } } } // For transporter switch func NewTransporter(options *config.Options) network.Transporter { return &transport{ readBufferSize: options.ReadBufferSize, network: options.Network, addr: options.Addr, keepAliveTimeout: options.KeepAliveTimeout, readTimeout: options.ReadTimeout, senseClientDisconnection: options.SenseClientDisconnection, tls: options.TLS, ln: options.Listener, listenConfig: options.ListenConfig, OnAccept: options.OnAccept, OnConnect: options.OnConnect, } } ================================================ FILE: pkg/network/standard/transport_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * 
Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package standard import ( "context" "io" "net" "testing" "time" "github.com/cloudwego/hertz/internal/testutils" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/network" ) func assertWriteRead(t *testing.T, c io.ReadWriter, w, r string) { t.Helper() _, err := c.Write([]byte(w)) if err != nil { t.Fatal("write err", err) } b := make([]byte, len(r)) _, err = io.ReadFull(c, b) if err != nil { t.Fatal("read err", err) } if s := string(b); s != r { t.Fatal("read", r) } } func TestTransporter(t *testing.T) { handlerExit := make(chan struct{}) req := "hello" resp := "world" trans := NewTransporter(&config.Options{Network: "tcp", Addr: "127.0.0.1:0", SenseClientDisconnection: true}).(*transport) go trans.ListenAndServe(func(ctx context.Context, conn interface{}) error { // SenseClientDisconnection is configured, connection close detection is handled by connstate c := conn.(network.Conn) defer c.Close() assertWriteRead(t, c, resp, req) <-handlerExit return nil }) for trans.Listener() == nil { // wait server up time.Sleep(5 * time.Millisecond) } // dial and test c, err := net.Dial("tcp", trans.Listener().Addr().String()) if err != nil { t.Fatal(err) } assertWriteRead(t, c, req, resp) checkActiveConn := func(n int) { if v := trans.updateActive(0); v != int32(n) { t.Helper() t.Fatal("trans active conn", v) } } checkActiveConn(1) // make sure shutdownTimeout will be reset after this test defer func(old time.Duration) { 
shutdownTimeout = old }(shutdownTimeout) defer func(old time.Duration) { shutdownTicker = old }(shutdownTicker) shutdownTicker = time.Millisecond // shorter for saving test time // case: wait util shutdownTimeout shutdownTimeout = time.Millisecond if err := trans.Shutdown(context.Background()); err != errShutdownTimeout { t.Fatal(err) } // case: ctx done shutdownTimeout = time.Second // long enough ctx, cancel := context.WithCancel(context.Background()) cancel() err = trans.Shutdown(ctx) if err != ctx.Err() { t.Fatal(err) } // case: even after listener is closed, handler may still active checkActiveConn(1) // case: Shutdown blocks at ticker and check periodically. shutdownTimeout = time.Second // long enough go trans.Shutdown(context.Background()) time.Sleep(30 * time.Millisecond) // make sure Shutdown blocks at for loop close(handlerExit) // signal handler to return time.Sleep(10 * time.Millisecond) // wait handler returns, and active conn to be updated. checkActiveConn(0) } func TestAcceptError(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() trans := NewTransporter(&config.Options{Listener: ln}).(*transport) errCh := make(chan error, 1) go func() { errCh <- trans.ListenAndServe(func(context.Context, interface{}) error { return nil }) }() time.Sleep(10 * time.Millisecond) // Wait for serve to start // Close listener to trigger error ln.Close() // Wait for serve to exit with error if err := <-errCh; err != nil { t.Fatal("expected nil after listener close") } } ================================================ FILE: pkg/network/standard/unix_test.go ================================================ // Copyright 2026 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris

package standard

import (
	"context"
	"net"
	"testing"
	"time"

	"github.com/cloudwego/gopkg/net/connstate"
	"github.com/cloudwego/hertz/pkg/common/config"
)

// TestConnCloseWithStater verifies that Conn.Close closes and detaches the
// attached stater even when closing the underlying connection fails.
func TestConnCloseWithStater(t *testing.T) {
	c := &mockConn{
		localAddr: &mockAddr{
			network: "tcp",
			address: "127.0.0.1:8080",
		},
		remoteAddr: &mockAddr{
			network: "tcp",
			address: "127.0.0.1:12345",
		},
	}
	conn := newConn(c, 0).(*Conn)
	// Attach a mock stater that records whether Close was called.
	mockStater := &mockConnStater{}
	conn.stater = mockStater
	// Close the connection - mockConn.Close() returns an error, but the stater must still be closed.
	_ = conn.Close()
	// After Close, the stater must be detached.
	if conn.stater != nil {
		t.Errorf("Stater should be nil after Close, got %v", conn.stater)
	}
	// The mock stater's Close method must have been called.
	if !mockStater.closed {
		t.Error("MockStater's Close method was not called")
	}
}

// TestConnstateConnectionCloseDetection exercises connstate-based remote-close
// detection directly: the registered callback cancels ctx once the peer closes.
// It is skipped when ListenConnState rejects the connection (net.Pipe conns may
// not be supported — the skip message reports the error).
func TestConnstateConnectionCloseDetection(t *testing.T) {
	// Use net.Pipe to create a pair of connected connections.
	cliConn, svrConn := net.Pipe()
	defer cliConn.Close()
	defer svrConn.Close()
	// Context the remote-closed callback will cancel.
	ctx, cancelCtx := context.WithCancel(context.Background())
	defer cancelCtx()
	// Create server-side Conn.
	svrStdConn := newConn(svrConn, 0).(*Conn)
	// Register connstate callback.
	stater, err := connstate.ListenConnState(svrConn,
		connstate.WithOnRemoteClosed(func() {
			cancelCtx()
		}),
	)
	if err != nil {
		// If connstate cannot watch this connection, skip the test.
		t.Skipf("connstate.ListenConnState not supported: %v", err)
		return
	}
	defer stater.Close()
	// Attach the stater so Conn.Close cleans it up.
	svrStdConn.stater = stater
	// Test 1: while the peer is open, ctx must stay live.
	select {
	case <-ctx.Done():
		t.Fatal("Context should not be canceled in normal operation")
	case <-time.After(100 * time.Millisecond):
		// Expected - context not canceled
	}
	// Test 2: closing the client side must cancel ctx.
	cliConn.Close()
	// Wait (bounded) for the cancellation.
	timeout := time.NewTimer(2 * time.Second)
	defer timeout.Stop()
	select {
	case <-ctx.Done():
		// Expected - context canceled after client closes connection
	case <-timeout.C:
		t.Fatal("Context was not canceled after client closed connection")
	}
	// Close should properly clean up the stater.
	err = svrStdConn.Close()
	if err != nil {
		t.Errorf("Close failed: %v", err)
	}
}

// TestSenseClientDisconnectionContextCancel verifies that, with
// SenseClientDisconnection enabled, the handler's context is canceled after
// the client disconnects. The handler panics if that does not happen within 3s.
func TestSenseClientDisconnectionContextCancel(t *testing.T) {
	handlerRunning := make(chan struct{})
	handlerExited := make(chan struct{})
	trans := NewTransporter(&config.Options{
		Network:                  "tcp",
		Addr:                     "127.0.0.1:0",
		SenseClientDisconnection: true,
	}).(*transport)
	go trans.ListenAndServe(func(ctx context.Context, conn interface{}) error {
		close(handlerRunning)
		select {
		case <-ctx.Done():
			// Context was canceled as expected
		case <-time.After(3 * time.Second):
			panic("Context was not canceled after client disconnected")
		}
		close(handlerExited)
		return nil
	})
	for trans.Listener() == nil {
		time.Sleep(2 * time.Millisecond)
	}
	clientConn, err := net.Dial("tcp", trans.Listener().Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	// Wait for handler to start running.
	<-handlerRunning
	// Close client connection to trigger disconnection detection.
	clientConn.Close()
	// Wait for handler to exit (its context should be canceled).
	select {
	case <-handlerExited:
		// Handler exited as expected
	case <-time.After(4 * time.Second):
		t.Fatal("Handler did not exit after client disconnected")
	}
}

// TestSenseClientDisconnectionDisabled verifies that, with
// SenseClientDisconnection disabled, a client disconnect does NOT cancel the
// handler's context: the handler panics if ctx is canceled and otherwise exits
// on its own 200ms timer.
func TestSenseClientDisconnectionDisabled(t *testing.T) {
	handlerRunning := make(chan struct{})
	handlerExited := make(chan struct{})
	trans := NewTransporter(&config.Options{
		Network:                  "tcp",
		Addr:                     "127.0.0.1:0",
		SenseClientDisconnection: false,
	}).(*transport)
	go trans.ListenAndServe(func(ctx context.Context, conn interface{}) error {
		close(handlerRunning)
		select {
		case <-ctx.Done():
			panic("Context was canceled after client disconnected")
		case <-time.After(200 * time.Millisecond):
		}
		close(handlerExited)
		return nil
	})
	for trans.Listener() == nil {
		time.Sleep(1 * time.Millisecond)
	}
	clientConn, err := net.Dial("tcp", trans.Listener().Addr().String())
	if err != nil {
		t.Fatal(err)
	}
	// Wait for handler to start running.
	<-handlerRunning
	// Close the client connection; with detection disabled this must not cancel ctx.
	clientConn.Close()
	// Wait for the handler to exit on its own timer (its context must NOT be canceled).
	select {
	case <-handlerExited:
		// Handler exited as expected
	case <-time.After(500 * time.Millisecond):
		t.Fatal("Handler did not exit after client disconnected")
	}
}

// mockConnStater is a test double for a connstate stater: Close records that
// it was called, and State returns the preset state.
type mockConnStater struct {
	closed bool
	state  connstate.ConnState
}

func (m *mockConnStater) Close() error {
	m.closed = true
	return nil
}

func (m *mockConnStater) State() connstate.ConnState {
	return m.state
}
================================================
FILE: pkg/network/transport.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package network import ( "context" ) type Transporter interface { // Close the transporter immediately Close() error // Graceful shutdown the transporter Shutdown(ctx context.Context) error // Start listen and ready to accept connection ListenAndServe(onData OnData) error } // Callback when data is ready on the connection type OnData func(ctx context.Context, conn interface{}) error ================================================ FILE: pkg/network/utils.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package network import "syscall" func UnlinkUdsFile(network, addr string) error { if network == "unix" { return syscall.Unlink(addr) } return nil } ================================================ FILE: pkg/network/utils_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package network import ( "os" "runtime" "testing" ) func TestUnlinkUdsFile(t *testing.T) { if runtime.GOOS == "windows" { t.SkipNow() } tmp := "tmpFile" var err error err = UnlinkUdsFile("unix", tmp) if err == nil { t.Errorf("should have error when unlinking a nonexistent file") } os.Create(tmp) err = UnlinkUdsFile("unix", tmp) if err != nil { t.Errorf("unlink file failed: %s", err.Error()) } isExist, _ := pathExists(tmp) if isExist { t.Errorf("unlink file failed, file still exist") } } func pathExists(path string) (bool, error) { _, err := os.Stat(path) if err == nil { return true, nil } if os.IsNotExist(err) { return false, nil } return false, err } ================================================ FILE: pkg/network/version.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package network const ( // QUIC version codes VersionTLS uint32 = 0x1 VersionDraft29 uint32 = 0xff00001d Version1 uint32 = 0x1 Version2 uint32 = 0x709a50c4 ) ================================================ FILE: pkg/network/writer.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package network import ( "io" "sync" "github.com/bytedance/gopkg/lang/mcache" ) const size4K = 1024 * 4 type node struct { data []byte readOnly bool } var nodePool = sync.Pool{} func init() { nodePool.New = func() interface{} { return &node{} } } type networkWriter struct { caches []*node w io.Writer } func (w *networkWriter) release() { for _, n := range w.caches { if !n.readOnly { mcache.Free(n.data) } n.data = nil n.readOnly = false nodePool.Put(n) } w.caches = w.caches[:0] } func (w *networkWriter) Malloc(length int) (buf []byte, err error) { idx := len(w.caches) if idx > 0 { idx -= 1 inUse := len(w.caches[idx].data) if !w.caches[idx].readOnly && cap(w.caches[idx].data)-inUse >= length { end := inUse + length w.caches[idx].data = w.caches[idx].data[:end] return w.caches[idx].data[inUse:end], nil } } buf = mcache.Malloc(length) n := nodePool.Get().(*node) n.data = buf w.caches = append(w.caches, n) return } func (w *networkWriter) WriteBinary(b []byte) (length int, err error) { length = len(b) if length < size4K { buf, _ := w.Malloc(length) copy(buf, b) return } node := nodePool.Get().(*node) node.readOnly = 
true node.data = b w.caches = append(w.caches, node) return } func (w *networkWriter) Flush() (err error) { for _, c := range w.caches { _, err = w.w.Write(c.data) if err != nil { break } } w.release() return } func NewWriter(w io.Writer) Writer { return &networkWriter{ w: w, } } type ExtWriter interface { io.Writer // Flush sends data to peer immediately Flush() error // Finalize will be called by framework before the writer is released. // Implementations must guarantee that Finalize is safe for multiple calls. Finalize() error } ================================================ FILE: pkg/network/writer_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package network import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) const ( size1K = 1024 ) func TestConvertNetworkWriter(t *testing.T) { iw := &mockIOWriter{} w := NewWriter(iw) nw, _ := w.(*networkWriter) // Test malloc buf, _ := w.Malloc(size1K) assert.DeepEqual(t, len(buf), size1K) assert.DeepEqual(t, len(nw.caches), 1) assert.DeepEqual(t, len(nw.caches[0].data), size1K) assert.DeepEqual(t, cap(nw.caches[0].data), size1K) err := w.Flush() assert.Nil(t, err) assert.DeepEqual(t, size1K, iw.WriteNum) assert.DeepEqual(t, len(nw.caches), 0) assert.DeepEqual(t, cap(nw.caches), 1) // Test malloc left size buf, _ = w.Malloc(size1K + 1) assert.DeepEqual(t, len(buf), size1K+1) assert.DeepEqual(t, len(nw.caches), 1) assert.DeepEqual(t, len(nw.caches[0].data), size1K+1) assert.DeepEqual(t, cap(nw.caches[0].data), size1K*2) buf, _ = w.Malloc(size1K / 2) assert.DeepEqual(t, len(buf), size1K/2) assert.DeepEqual(t, len(nw.caches), 1) assert.DeepEqual(t, len(nw.caches[0].data), size1K+1+size1K/2) assert.DeepEqual(t, cap(nw.caches[0].data), size1K*2) buf, _ = w.Malloc(size1K / 2) assert.DeepEqual(t, len(buf), size1K/2) assert.DeepEqual(t, len(nw.caches), 2) assert.DeepEqual(t, len(nw.caches[0].data), size1K+1+size1K/2) assert.DeepEqual(t, cap(nw.caches[0].data), size1K*2) assert.DeepEqual(t, len(nw.caches[1].data), size1K/2) assert.DeepEqual(t, cap(nw.caches[1].data), size1K/2) err = w.Flush() assert.Nil(t, err) assert.DeepEqual(t, size1K*3+1, iw.WriteNum) assert.DeepEqual(t, len(nw.caches), 0) assert.DeepEqual(t, cap(nw.caches), 2) // Test WriteBinary after Malloc buf, _ = w.Malloc(size1K * 6) assert.DeepEqual(t, len(buf), size1K*6) assert.DeepEqual(t, len(nw.caches[0].data), size1K*6) b := make([]byte, size1K) w.WriteBinary(b) assert.DeepEqual(t, size1K*3+1, iw.WriteNum) assert.DeepEqual(t, len(nw.caches[0].data), size1K*7) assert.DeepEqual(t, cap(nw.caches[0].data), size1K*8) b = make([]byte, size1K*4) w.WriteBinary(b) assert.DeepEqual(t, 
len(nw.caches[1].data), size1K*4) assert.DeepEqual(t, cap(nw.caches[1].data), size1K*4) assert.DeepEqual(t, nw.caches[1].readOnly, true) w.Flush() assert.DeepEqual(t, size1K*14+1, iw.WriteNum) } type mockIOWriter struct { WriteNum int } func (m *mockIOWriter) Write(p []byte) (n int, err error) { m.WriteNum += len(p) return len(p), nil } ================================================ FILE: pkg/protocol/args.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package protocol import ( "bytes" "io" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/nocopy" ) const ( argsNoValue = true ArgsHasValue = false ) type argsScanner struct { b []byte } type Args struct { noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used args []argsKV buf []byte } // Set sets 'key=value' argument. func (a *Args) Set(key, value string) { a.args = setArg(a.args, key, value, ArgsHasValue) } // Reset clears query args. func (a *Args) Reset() { a.args = a.args[:0] } // CopyTo copies all args to dst. func (a *Args) CopyTo(dst *Args) { dst.Reset() dst.args = copyArgs(dst.args, a.args) } // Del deletes argument with the given key from query args. func (a *Args) Del(key string) { a.args = delAllArgs(a.args, key) } // DelBytes deletes argument with the given key from query args. 
func (a *Args) DelBytes(key []byte) {
	a.args = delAllArgs(a.args, bytesconv.B2s(key))
}

// next parses the next '&'-separated "key[=value]" pair from s.b into kv and
// advances s.b past it. It returns false only when s.b is already empty.
//
// Key and value are decoded with decodeArgAppend ('%XX' escapes and '+').
// Only the first '=' splits key from value; any later '=' stays in the value.
// A pair without '=' gets an empty value and kv.noValue = argsNoValue.
// kv's existing key/value buffers are truncated and reused rather than
// reallocated.
func (s *argsScanner) next(kv *argsKV) bool {
	if len(s.b) == 0 {
		return false
	}
	kv.noValue = ArgsHasValue

	isKey := true
	// k is the index where the value starts (just past the first '=').
	k := 0
	for i, c := range s.b {
		switch c {
		case '=':
			if isKey {
				isKey = false
				kv.key = decodeArgAppend(kv.key[:0], s.b[:i])
				k = i + 1
			}
		case '&':
			if isKey {
				kv.key = decodeArgAppend(kv.key[:0], s.b[:i])
				kv.value = kv.value[:0]
				kv.noValue = argsNoValue
			} else {
				kv.value = decodeArgAppend(kv.value[:0], s.b[k:i])
			}
			s.b = s.b[i+1:]
			return true
		}
	}

	// No '&' left: the remainder of s.b is the final pair.
	if isKey {
		kv.key = decodeArgAppend(kv.key[:0], s.b)
		kv.value = kv.value[:0]
		kv.noValue = argsNoValue
	} else {
		kv.value = decodeArgAppend(kv.value[:0], s.b[k:])
	}
	s.b = s.b[len(s.b):]
	return true
}

// decodeArgAppend appends the query-decoded form of src to dst and returns the
// extended slice: '%XX' becomes the byte 0xXX and '+' becomes ' '.
// Malformed escapes are kept literally. With a non-hex digit, '%' is emitted
// and the following bytes are processed normally. A '%' with fewer than two
// bytes after it is copied through together with the rest of src.
func decodeArgAppend(dst, src []byte) []byte {
	if bytes.IndexByte(src, '%') < 0 && bytes.IndexByte(src, '+') < 0 {
		// fast path: src doesn't contain encoded chars
		return append(dst, src...)
	}

	// slow path
	for i := 0; i < len(src); i++ {
		c := src[i]
		if c == '%' {
			if i+2 >= len(src) {
				return append(dst, src[i:]...)
			}
			x2 := bytesconv.Hex2intTable[src[i+2]]
			x1 := bytesconv.Hex2intTable[src[i+1]]
			// 16 is compared as the "not a hex digit" marker of Hex2intTable.
			if x1 == 16 || x2 == 16 {
				dst = append(dst, '%')
			} else {
				dst = append(dst, x1<<4|x2)
				i += 2
			}
		} else if c == '+' {
			dst = append(dst, ' ')
		} else {
			dst = append(dst, c)
		}
	}
	return dst
}

// allocArg extends h by one element and returns the new slice together with a
// pointer to the new element. When h has spare capacity, the slot past len(h)
// is reused as-is (keeping whatever buffers it already holds), so callers must
// overwrite its key/value. The returned element's value is never nil.
func allocArg(h []argsKV) ([]argsKV, *argsKV) {
	n := len(h)
	if cap(h) > n {
		h = h[:n+1]
		kv := &h[n]
		if kv.value == nil {
			// bytes in value would be reused, and it's not always nil
			// only set to empty when it's nil
			kv.value = []byte{}
		}
		return h, kv
	}
	h = append(h, argsKV{value: []byte{}})
	return h, &h[n]
}

// releaseArg drops the last element of h (the counterpart of allocArg). h must
// not be empty: h[:len(h)-1] panics on an empty slice.
func releaseArg(h []argsKV) []argsKV {
	return h[:len(h)-1]
}

// updateArgBytes gives value to the first argument named key that currently
// has no value (noValue set) and marks it as having one. Arguments that
// already carry a value are left untouched; h is returned unchanged in length.
func updateArgBytes(h []argsKV, key, value []byte) []argsKV {
	n := len(h)
	for i := 0; i < n; i++ {
		kv := &h[i]
		if kv.noValue && bytes.Equal(key, kv.key) {
			kv.value = append(kv.value[:0], value...)
			kv.noValue = ArgsHasValue
			return h
		}
	}
	return h
}

// setArgBytes overwrites the first argument named key (clearing its value when
// noValue is set) or, if none exists, appends a new one via appendArgBytes.
// The caller must use the returned slice, since appending may reallocate.
func setArgBytes(h []argsKV, key, value []byte, noValue bool) []argsKV {
	n := len(h)
	for i := 0; i < n; i++ {
		kv := &h[i]
		if bytes.Equal(key, kv.key) {
			if noValue {
				kv.value = kv.value[:0]
			} else {
				kv.value = append(kv.value[:0], value...)
			}
			kv.noValue = noValue
			return h
		}
	}
	return appendArgBytes(h, key, value, noValue)
}

// setArg is the string-keyed variant of setArgBytes; it falls back to
// appendArg when key is not present.
func setArg(h []argsKV, key, value string, noValue bool) []argsKV {
	n := len(h)
	for i := 0; i < n; i++ {
		kv := &h[i]
		if key == string(kv.key) {
			if noValue {
				kv.value = kv.value[:0]
			} else {
				kv.value = append(kv.value[:0], value...)
			}
			kv.noValue = noValue
			return h
		}
	}
	return appendArg(h, key, value, noValue)
}

// peekArgBytes returns the value of the first argument whose key equals k, or
// nil if there is none. The result aliases the stored value.
func peekArgBytes(h []argsKV, k []byte) []byte {
	for i, n := 0, len(h); i < n; i++ {
		kv := &h[i]
		if bytes.Equal(kv.key, k) {
			return kv.value
		}
	}
	return nil
}

// peekAllArgBytesToDst appends the values of every argument whose key equals k
// to dst, in order, and returns the extended slice. The appended values alias
// the stored ones.
func peekAllArgBytesToDst(dst [][]byte, h []argsKV, k []byte) [][]byte {
	for i, n := 0, len(h); i < n; i++ {
		kv := &h[i]
		if bytes.Equal(kv.key, k) {
			dst = append(dst, kv.value)
		}
	}
	return dst
}

// delAllArgsBytes is the []byte-keyed variant of delAllArgs.
func delAllArgsBytes(args []argsKV, key []byte) []argsKV {
	return delAllArgs(args, bytesconv.B2s(key))
}

// delAllArgs removes every argument named key, preserving the order of the
// rest, and returns the shortened slice. Each removed element is moved just
// past the new length instead of being discarded, so its buffers can be
// reused when allocArg later grows the slice into that slot.
func delAllArgs(args []argsKV, key string) []argsKV {
	for i, n := 0, len(args); i < n; i++ {
		kv := &args[i]
		if key == string(kv.key) {
			tmp := *kv
			copy(args[i:], args[i+1:])
			n--
			// Re-examine index i: it now holds the element that followed.
			i--
			args[n] = tmp
			args = args[:n]
		}
	}
	return args
}

// Has returns true if the given key exists in Args.
func (a *Args) Has(key string) bool {
	return hasArg(a.args, key)
}

// hasArg reports whether any argument in h is named key.
func hasArg(h []argsKV, key string) bool {
	for i, n := 0, len(h); i < n; i++ {
		kv := &h[i]
		if key == string(kv.key) {
			return true
		}
	}
	return false
}

// String returns string representation of query args.
// The result is a copy, so unlike QueryString it stays valid after further
// Args calls.
func (a *Args) String() string {
	return string(a.QueryString())
}

// decodeArgAppendNoPlus is almost identical to decodeArgAppend, but it doesn't
// substitute '+' with ' '.
//
// The function is copy-pasted from decodeArgAppend due to the performance
// reasons only.
func decodeArgAppendNoPlus(dst, src []byte) []byte { if bytes.IndexByte(src, '%') < 0 { // fast path: src doesn't contain encoded chars return append(dst, src...) } // slow path for i := 0; i < len(src); i++ { c := src[i] if c == '%' { if i+2 >= len(src) { return append(dst, src[i:]...) } x2 := bytesconv.Hex2intTable[src[i+2]] x1 := bytesconv.Hex2intTable[src[i+1]] if x1 == 16 || x2 == 16 { dst = append(dst, '%') } else { dst = append(dst, x1<<4|x2) i += 2 } } else { dst = append(dst, c) } } return dst } func peekArgStr(h []argsKV, k string) []byte { for i, n := 0, len(h); i < n; i++ { kv := &h[i] if string(kv.key) == k { return kv.value } } return nil } func peekArgStrExists(h []argsKV, k string) (string, bool) { for i, n := 0, len(h); i < n; i++ { kv := &h[i] if string(kv.key) == k { return string(kv.value), true } } return "", false } // QueryString returns query string for the args. // // The returned value is valid until the next call to Args methods. func (a *Args) QueryString() []byte { a.buf = a.AppendBytes(a.buf[:0]) return a.buf } // ParseBytes parses the given b containing query args. func (a *Args) ParseBytes(b []byte) { a.Reset() var s argsScanner s.b = b var kv *argsKV a.args, kv = allocArg(a.args) for s.next(kv) { if len(kv.key) > 0 || len(kv.value) > 0 { a.args, kv = allocArg(a.args) } } a.args = releaseArg(a.args) if len(a.args) == 0 { return } } // Peek returns query arg value for the given key. // // Returned value is valid until the next Args call. func (a *Args) Peek(key string) []byte { return peekArgStr(a.args, key) } func (a *Args) PeekExists(key string) (string, bool) { return peekArgStrExists(a.args, key) } // PeekAll returns all the arg values for the given key. 
func (a *Args) PeekAll(key string) [][]byte { var values [][]byte a.VisitAll(func(k, v []byte) { if bytesconv.B2s(k) == key { values = append(values, v) } }) return values } func visitArgs(args []argsKV, f func(k, v []byte)) { for i, n := 0, len(args); i < n; i++ { kv := &args[i] f(kv.key, kv.value) } } // Len returns the number of query args. func (a *Args) Len() int { return len(a.args) } // AppendBytes appends query string to dst and returns the extended dst. func (a *Args) AppendBytes(dst []byte) []byte { for i, n := 0, len(a.args); i < n; i++ { kv := &a.args[i] dst = bytesconv.AppendQuotedArg(dst, kv.key) if !kv.noValue { dst = append(dst, '=') if len(kv.value) > 0 { dst = bytesconv.AppendQuotedArg(dst, kv.value) } } if i+1 < n { dst = append(dst, '&') } } return dst } // VisitAll calls f for each existing arg. // // f must not retain references to key and value after returning. // Make key and/or value copies if you need storing them after returning. func (a *Args) VisitAll(f func(key, value []byte)) { visitArgs(a.args, f) } // WriteTo writes query string to w. // // WriteTo implements io.WriterTo interface. func (a *Args) WriteTo(w io.Writer) (int64, error) { n, err := w.Write(a.QueryString()) return int64(n), err } // Add adds 'key=value' argument. // // Multiple values for the same key may be added. func (a *Args) Add(key, value string) { a.args = appendArg(a.args, key, value, ArgsHasValue) } ================================================ FILE: pkg/protocol/args_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package protocol import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestArgsDeleteAll(t *testing.T) { t.Parallel() var a Args a.Add("q1", "foo") a.Add("q1", "bar") a.Add("q1", "baz") a.Add("q1", "quux") a.Add("q2", "1234") a.Del("q1") if a.Len() != 1 || a.Has("q1") { t.Fatalf("Expected q1 arg to be completely deleted. Current Args: %s", a.String()) } } func TestArgsBytesOperation(t *testing.T) { var a Args a.Add("q1", "foo") a.Add("q2", "bar") setArgBytes(a.args, a.args[0].key, a.args[0].value, false) assert.DeepEqual(t, []byte("foo"), peekArgBytes(a.args, []byte("q1"))) setArgBytes(a.args, a.args[1].key, a.args[1].value, true) assert.DeepEqual(t, []byte(""), peekArgBytes(a.args, []byte("q2"))) } func TestArgsPeekExists(t *testing.T) { var a Args a.Add("q1", "foo") a.Add("", "") a.Add("?", "=") v1, b1 := a.PeekExists("q1") assert.DeepEqual(t, []byte("foo"), []byte(v1)) assert.True(t, b1) v2, b2 := a.PeekExists("") assert.DeepEqual(t, []byte(""), []byte(v2)) assert.True(t, b2) v3, b3 := a.PeekExists("q3") assert.DeepEqual(t, "", v3) assert.False(t, b3) v4, b4 := a.PeekExists("?") assert.DeepEqual(t, "=", v4) assert.True(t, b4) } func TestSetArg(t *testing.T) { a := Args{args: setArg(nil, "q1", "foo", true)} a.Add("", "") setArgBytes(a.args, []byte("q3"), []byte("bar"), false) s := a.String() assert.DeepEqual(t, []byte("q1&="), []byte(s)) } // Test the encoding of special parameters func TestArgsParseBytes(t *testing.T) { var ta1 Args ta1.Add("q1", "foo") ta1.Add("q1", "bar") ta1.Add("q2", "123") ta1.Add("q3", "") var a1 Args a1.ParseBytes([]byte("q1=foo&q1=bar&q2=123&q3=")) assert.DeepEqual(t, &ta1, &a1) var ta2 Args ta2.Add("?", "foo") ta2.Add("&", "bar") ta2.Add("&", "?") ta2.Add("=", "=") var a2 Args a2.ParseBytes([]byte("%3F=foo&%26=bar&%26=%3F&%3D=%3D")) assert.DeepEqual(t, &ta2, &a2) } func TestArgsVisitAll(t *testing.T) { var a Args var s []string a.Add("cloudwego", "hertz") a.Add("hello", "world") a.VisitAll(func(key, value 
[]byte) { s = append(s, string(key), string(value)) }) assert.DeepEqual(t, []string{"cloudwego", "hertz", "hello", "world"}, s) } func TestArgsCopyTo(t *testing.T) { var a Args a.Add("cloudwego", "") a.Add("hello", "world") var b Args a.CopyTo(&b) assert.Assert(t, b.Len() == 2) assert.Assert(t, a.Peek("cloudwego") != nil && len(a.Peek("cloudwego")) == 0) assert.Assert(t, string(a.Peek("hello")) == "world") assert.Assert(t, a.Peek("key-not-exists") == nil) } func TestArgsPeek(t *testing.T) { var a Args a.Add("cloudwego", "") a.Add("hello", "world") assert.Assert(t, a.Peek("cloudwego") != nil && len(a.Peek("cloudwego")) == 0) assert.Assert(t, string(a.Peek("hello")) == "world") assert.Assert(t, a.Peek("key-not-exists") == nil) // reset and reuse a.Reset() a.Add("cloudwego", "") a.Add("hello", "world") assert.Assert(t, a.Peek("cloudwego") != nil && len(a.Peek("cloudwego")) == 0) assert.Assert(t, string(a.Peek("hello")) == "world") assert.Assert(t, a.Peek("key-not-exists") == nil) } func TestArgsPeekMulti(t *testing.T) { var a Args a.Add("cloudwego", "hertz") a.Add("cloudwego", "kitex") a.Add("cloudwego", "") a.Add("hello", "world") vv := a.PeekAll("cloudwego") expectedVV := [][]byte{ []byte("hertz"), []byte("kitex"), []byte{}, } assert.DeepEqual(t, expectedVV, vv) vv = a.PeekAll("aaaa") assert.DeepEqual(t, 0, len(vv)) vv = a.PeekAll("hello") expectedVV = [][]byte{[]byte("world")} assert.DeepEqual(t, expectedVV, vv) } ================================================ FILE: pkg/protocol/client/client.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package client import ( "context" "sync" "time" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/timer" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) const defaultMaxRedirectsCount = 16 var ( errTimeout = errors.New(errors.ErrTimeout, errors.ErrorTypePublic, "host client") errMissingLocation = errors.NewPublic("missing Location header for http redirect") errTooManyRedirects = errors.NewPublic("too many redirects detected when doing the request") clientURLResponseChPool sync.Pool ) type HostClient interface { Doer SetDynamicConfig(dc *DynamicConfig) CloseIdleConnections() ShouldRemove() bool ConnectionCount() int } type Doer interface { Do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error } // DefaultRetryIf Default retry condition, mainly used for idempotent requests. // If this cannot be satisfied, you can implement your own retry condition. func DefaultRetryIf(req *protocol.Request, resp *protocol.Response, err error) bool { // cannot retry if the request body is not rewindable if req.IsBodyStream() { return false } if isIdempotent(req, resp, err) { return true } return false } func isIdempotent(req *protocol.Request, resp *protocol.Response, err error) bool { return req.Header.IsGet() || req.Header.IsHead() || req.Header.IsPut() || req.Header.IsDelete() || req.Header.IsOptions() || req.Header.IsTrace() } // DynamicConfig is config set which will be confirmed when starts a request. 
type DynamicConfig struct { Addr string ProxyURI *protocol.URI IsTLS bool } // RetryIfFunc signature of retry if function // Judge whether to retry by request,response or error , return true is retry type RetryIfFunc func(req *protocol.Request, resp *protocol.Response, err error) bool type clientURLResponse struct { statusCode int body []byte err error } func GetURL(ctx context.Context, dst []byte, url string, c Doer, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { req := protocol.AcquireRequest() req.SetOptions(requestOptions...) statusCode, body, err = doRequestFollowRedirectsBuffer(ctx, req, dst, url, c) protocol.ReleaseRequest(req) return statusCode, body, err } func GetURLTimeout(ctx context.Context, dst []byte, url string, timeout time.Duration, c Doer, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { deadline := time.Now().Add(timeout) return GetURLDeadline(ctx, dst, url, deadline, c, requestOptions...) } func GetURLDeadline(ctx context.Context, dst []byte, url string, deadline time.Time, c Doer, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { timeout := -time.Since(deadline) if timeout <= 0 { return 0, dst, errTimeout } var ch chan clientURLResponse chv := clientURLResponseChPool.Get() if chv == nil { chv = make(chan clientURLResponse, 1) } ch = chv.(chan clientURLResponse) req := protocol.AcquireRequest() req.SetOptions(requestOptions...) // Note that the request continues execution on errTimeout until // client-specific ReadTimeout exceeds. This helps to limit load // on slow hosts by MaxConns* concurrent requests. // // Without this 'hack' the load on slow host could exceed MaxConns* // concurrent requests, since timed out requests on client side // usually continue execution on the host. 
go func() { statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirectsBuffer(ctx, req, dst, url, c) ch <- clientURLResponse{ statusCode: statusCodeCopy, body: bodyCopy, err: errCopy, } }() tc := timer.AcquireTimer(timeout) select { case resp := <-ch: protocol.ReleaseRequest(req) clientURLResponseChPool.Put(chv) statusCode = resp.statusCode body = resp.body err = resp.err case <-tc.C: body = dst err = errTimeout } timer.ReleaseTimer(tc) return statusCode, body, err } func PostURL(ctx context.Context, dst []byte, url string, postArgs *protocol.Args, c Doer, requestOptions ...config.RequestOption) (statusCode int, body []byte, err error) { req := protocol.AcquireRequest() req.Header.SetMethodBytes(bytestr.StrPost) req.Header.SetContentTypeBytes(bytestr.MIMEPostForm) req.SetOptions(requestOptions...) if postArgs != nil { if _, err := postArgs.WriteTo(req.BodyWriter()); err != nil { return 0, nil, err } } statusCode, body, err = doRequestFollowRedirectsBuffer(ctx, req, dst, url, c) protocol.ReleaseRequest(req) return statusCode, body, err } func doRequestFollowRedirectsBuffer(ctx context.Context, req *protocol.Request, dst []byte, url string, c Doer) (statusCode int, body []byte, err error) { resp := protocol.AcquireResponse() bodyBuf := resp.BodyBuffer() oldBody := bodyBuf.B bodyBuf.B = dst statusCode, _, err = DoRequestFollowRedirects(ctx, req, resp, url, defaultMaxRedirectsCount, c) // In HTTP2 scenario, client use stream mode to create a request and its body is in body stream. // In HTTP1, only client recv body exceed max body size and client is in stream mode can trig it. 
body = resp.Body() bodyBuf.B = oldBody protocol.ReleaseResponse(resp) return statusCode, body, err } func DoRequestFollowRedirects(ctx context.Context, req *protocol.Request, resp *protocol.Response, url string, maxRedirectsCount int, c Doer) (statusCode int, body []byte, err error) { redirectsCount := 0 for { req.SetRequestURI(url) req.ParseURI() if err = c.Do(ctx, req, resp); err != nil { break } statusCode = resp.Header.StatusCode() if !StatusCodeIsRedirect(statusCode) { break } redirectsCount++ if redirectsCount > maxRedirectsCount { err = errTooManyRedirects break } location := resp.Header.PeekLocation() if len(location) == 0 { err = errMissingLocation break } url = getRedirectURL(url, location) // Remove the former host header. req.Header.Del(consts.HeaderHost) } return statusCode, body, err } // StatusCodeIsRedirect returns true if the status code indicates a redirect. func StatusCodeIsRedirect(statusCode int) bool { return statusCode == consts.StatusMovedPermanently || statusCode == consts.StatusFound || statusCode == consts.StatusSeeOther || statusCode == consts.StatusTemporaryRedirect || statusCode == consts.StatusPermanentRedirect } func getRedirectURL(baseURL string, location []byte) string { u := protocol.AcquireURI() u.Update(baseURL) u.UpdateBytes(location) redirectURL := u.String() protocol.ReleaseURI(u) return redirectURL } func DoTimeout(ctx context.Context, req *protocol.Request, resp *protocol.Response, timeout time.Duration, c Doer) error { if timeout <= 0 { return errTimeout } // Note: it will overwrite the reqTimeout. req.SetOptions(config.WithRequestTimeout(timeout)) return c.Do(ctx, req, resp) } func DoDeadline(ctx context.Context, req *protocol.Request, resp *protocol.Response, deadline time.Time, c Doer) error { timeout := time.Until(deadline) if timeout <= 0 { return errTimeout } // Note: it will overwrite the reqTimeout. 
req.SetOptions(config.WithRequestTimeout(timeout)) return c.Do(ctx, req, resp) } ================================================ FILE: pkg/protocol/client/client_test.go ================================================ /* * Copyright 2024 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package client import ( "context" "errors" "testing" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/protocol" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) var firstTime = true type MockDoer struct { mock.Mock } func (m *MockDoer) Do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error { // this is the real logic in (c *HostClient) doNonNilReqResp method if len(req.Header.Host()) == 0 { req.Header.SetHostBytes(req.URI().Host()) } if firstTime { // req.Header.Host() is the real host writing to the wire if string(req.Header.Host()) != "example.com" { return errors.New("host not match") } // this is the real logic in (c *HostClient) doNonNilReqResp method if len(req.Header.Host()) == 0 { req.Header.SetHostBytes(req.URI().Host()) } resp.Header.SetCanonical(bytestr.StrLocation, []byte("https://a.b.c/foo")) resp.SetStatusCode(301) firstTime = false return nil } if string(req.Header.Host()) != "a.b.c" { resp.SetStatusCode(400) return errors.New("host not match") } resp.SetStatusCode(200) return nil } func TestDoRequestFollowRedirects(t *testing.T) { mockDoer := new(MockDoer) mockDoer.On("Do", 
mock.Anything, mock.Anything, mock.Anything).Return(nil) statusCode, _, err := DoRequestFollowRedirects(context.Background(), &protocol.Request{}, &protocol.Response{}, "https://example.com", defaultMaxRedirectsCount, mockDoer) assert.NoError(t, err) assert.Equal(t, 200, statusCode) } ================================================ FILE: pkg/protocol/consts/default.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package consts

import "time"

// Default values used by the server and the client when the corresponding
// option is not configured.
const (
	// *** Server default value ***

	// DefaultMaxInMemoryFileSize is the amount of multipart form data
	// (16 MiB) kept in memory while parsing; if the size exceeds it, hertz
	// writes the data to disk instead.
	DefaultMaxInMemoryFileSize = 16 * 1024 * 1024

	// *** Client default value start from here ***

	// DefaultDialTimeout is the timeout used by Dialer and DialDualStack
	// for establishing TCP connections.
	DefaultDialTimeout = time.Second

	// DefaultMaxConnsPerHost was formerly the implicit per-host connection cap.
	//
	// Deprecated: no longer used as a default. Previously, unconfigured clients
	// silently fell back to this value, capping connections even when no limit was
	// intended, causing ErrNoFreeConns when busy connections reached this cap.
	// Since v0.10.3: MaxConnsPerHost now defaults to 0 (no limit).
	DefaultMaxConnsPerHost = 512

	// DefaultMaxIdleConnDuration is the default duration before an idle
	// keep-alive connection is closed.
	DefaultMaxIdleConnDuration = 10 * time.Second

	// DefaultMaxIdempotentCallAttempts is the default number of attempts for
	// idempotent calls.
	DefaultMaxIdempotentCallAttempts = 1

	// DefaultMaxRetryTimes is the default value of the retry-times limit.
	// NOTE(review): whether this count includes the initial call is decided
	// by the retry logic, not here — confirm there.
	DefaultMaxRetryTimes = 1
)

================================================
FILE: pkg/protocol/consts/fs.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package consts

import "time"

// Tuning constants for the file-system (FS) handler.
const (
	// MaxSmallFileSize is the size threshold (2*4096 = 8192): files bigger
	// than this size are sent with sendfile.
	MaxSmallFileSize = 2 * 4096

	// FSHandlerCacheDuration is the default expiration duration for inactive
	// file handlers opened by FS.
	FSHandlerCacheDuration = 10 * time.Second

	// FSCompressedFileSuffix is the suffix FS adds to the original file names
	// when trying to store compressed file under the new file name.
	// See FS.Compress for details.
	FSCompressedFileSuffix = ".hertz.gz"

	// NOTE(review): the two constants below carry no documentation here.
	// Judging by their names they presumably bound FS compression (a minimum
	// compression ratio, and a maximum compressible file size of
	// 8*1024*1024) — confirm against the FS implementation.
	FsMinCompressRatio        = 0.8
	FsMaxCompressibleFileSize = 8 * 1024 * 1024
)

================================================
FILE: pkg/protocol/consts/headers.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ package consts const ( HeaderDate = "Date" HeaderIfModifiedSince = "If-Modified-Since" HeaderLastModified = "Last-Modified" // Redirects HeaderLocation = "Location" // Transfer coding HeaderTE = "TE" HeaderTrailer = "Trailer" HeaderTrailerLower = "trailer" HeaderTransferEncoding = "Transfer-Encoding" // Controls HeaderCookie = "Cookie" HeaderExpect = "Expect" HeaderMaxForwards = "Max-Forwards" HeaderSetCookie = "Set-Cookie" HeaderSetCookieLower = "set-cookie" // Connection management HeaderConnection = "Connection" HeaderKeepAlive = "Keep-Alive" HeaderProxyConnection = "Proxy-Connection" // Authentication HeaderAuthorization = "Authorization" HeaderProxyAuthenticate = "Proxy-Authenticate" HeaderProxyAuthorization = "Proxy-Authorization" HeaderWWWAuthenticate = "WWW-Authenticate" // Range requests HeaderAcceptRanges = "Accept-Ranges" HeaderContentRange = "Content-Range" HeaderIfRange = "If-Range" HeaderRange = "Range" // Response context HeaderAllow = "Allow" HeaderServer = "Server" HeaderServerLower = "server" // Request context HeaderFrom = "From" HeaderHost = "Host" HeaderReferer = "Referer" HeaderReferrerPolicy = "Referrer-Policy" HeaderUserAgent = "User-Agent" // Message body information HeaderContentEncoding = "Content-Encoding" HeaderContentLanguage = "Content-Language" HeaderContentLength = "Content-Length" HeaderContentLocation = "Content-Location" HeaderContentType = "Content-Type" // Content negotiation HeaderAccept = "Accept" HeaderAcceptCharset = "Accept-Charset" HeaderAcceptEncoding = "Accept-Encoding" HeaderAcceptLanguage = "Accept-Language" HeaderAltSvc = "Alt-Svc" // Protocol HTTP11 = "HTTP/1.1" HTTP10 = "HTTP/1.0" HTTP20 = "HTTP/2.0" // MIME text MIMETextPlain = "text/plain" MIMETextPlainUTF8 = "text/plain; charset=utf-8" MIMETextPlainISO88591 = "text/plain; charset=iso-8859-1" MIMETextPlainFormatFlowed = "text/plain; format=flowed" MIMETextPlainDelSpaceYes = "text/plain; delsp=yes" MiMETextPlainDelSpaceNo = "text/plain; delsp=no" MIMETextHtml = 
"text/html" MIMETextCss = "text/css" MIMETextJavascript = "text/javascript" MIMEMultipartPOSTForm = "multipart/form-data" // MIME application MIMEApplicationOctetStream = "application/octet-stream" MIMEApplicationFlash = "application/x-shockwave-flash" MIMEApplicationHTMLForm = "application/x-www-form-urlencoded" MIMEApplicationHTMLFormUTF8 = "application/x-www-form-urlencoded; charset=UTF-8" MIMEApplicationTar = "application/x-tar" MIMEApplicationGZip = "application/gzip" MIMEApplicationXGZip = "application/x-gzip" MIMEApplicationBZip2 = "application/bzip2" MIMEApplicationXBZip2 = "application/x-bzip2" MIMEApplicationShell = "application/x-sh" MIMEApplicationDownload = "application/x-msdownload" MIMEApplicationJSON = "application/json" MIMEApplicationJSONUTF8 = "application/json; charset=utf-8" MIMEApplicationXML = "application/xml" MIMEApplicationXMLUTF8 = "application/xml; charset=utf-8" MIMEApplicationZip = "application/zip" MIMEApplicationPdf = "application/pdf" MIMEApplicationWord = "application/msword" MIMEApplicationExcel = "application/vnd.ms-excel" MIMEApplicationPPT = "application/vnd.ms-powerpoint" MIMEApplicationOpenXMLWord = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" MIMEApplicationOpenXMLExcel = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" MIMEApplicationOpenXMLPPT = "application/vnd.openxmlformats-officedocument.presentationml.presentation" MIMEPROTOBUF = "application/x-protobuf" // MIME image MIMEImageJPEG = "image/jpeg" MIMEImagePNG = "image/png" MIMEImageGIF = "image/gif" MIMEImageBitmap = "image/bmp" MIMEImageWebP = "image/webp" MIMEImageIco = "image/x-icon" MIMEImageMicrosoftICO = "image/vnd.microsoft.icon" MIMEImageTIFF = "image/tiff" MIMEImageSVG = "image/svg+xml" MIMEImagePhotoshop = "image/vnd.adobe.photoshop" // MIME audio MIMEAudioBasic = "audio/basic" MIMEAudioL24 = "audio/L24" MIMEAudioMP3 = "audio/mp3" MIMEAudioMP4 = "audio/mp4" MIMEAudioMPEG = "audio/mpeg" MIMEAudioOggVorbis = 
"audio/ogg" MIMEAudioWAVE = "audio/vnd.wave" MIMEAudioWebM = "audio/webm" MIMEAudioAAC = "audio/x-aac" MIMEAudioAIFF = "audio/x-aiff" MIMEAudioMIDI = "audio/x-midi" MIMEAudioM3U = "audio/x-mpegurl" MIMEAudioRealAudio = "audio/x-pn-realaudio" // MIME video MIMEVideoMPEG = "video/mpeg" MIMEVideoOgg = "video/ogg" MIMEVideoMP4 = "video/mp4" MIMEVideoQuickTime = "video/quicktime" MIMEVideoWinMediaVideo = "video/x-ms-wmv" MIMEVideWebM = "video/webm" MIMEVideoFlashVideo = "video/x-flv" MIMEVideo3GPP = "video/3gpp" MIMEVideoAVI = "video/x-msvideo" MIMEVideoMatroska = "video/x-matroska" ) ================================================ FILE: pkg/protocol/consts/http2.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package consts // ClientPreface is the string that must be sent by new // connections from clients. const ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" ================================================ FILE: pkg/protocol/consts/methods.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package consts

// HTTP methods were copied from net/http. Each trailing comment names the
// RFC section that defines the method.
const (
	MethodGet     = "GET"     // RFC 7231, 4.3.1
	MethodHead    = "HEAD"    // RFC 7231, 4.3.2
	MethodPost    = "POST"    // RFC 7231, 4.3.3
	MethodPut     = "PUT"     // RFC 7231, 4.3.4
	MethodPatch   = "PATCH"   // RFC 5789
	MethodDelete  = "DELETE"  // RFC 7231, 4.3.5
	MethodConnect = "CONNECT" // RFC 7231, 4.3.6
	MethodOptions = "OPTIONS" // RFC 7231, 4.3.7
	MethodTrace   = "TRACE"   // RFC 7231, 4.3.8
)

================================================
FILE: pkg/protocol/consts/status.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package consts

import (
	"fmt"
	"sync/atomic"
)

// statusMessageMin and statusMessageMax are the inclusive bounds of the
// status codes StatusMessage looks up; codes outside this range are reported
// as "Unknown Status Code" without consulting the table.
const (
	statusMessageMin = 100
	statusMessageMax = 511
)

// HTTP status codes were copied from net/http.
const (
	StatusContinue           = 100 // RFC 7231, 6.2.1
	StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2
	StatusProcessing         = 102 // RFC 2518, 10.1

	StatusOK                   = 200 // RFC 7231, 6.3.1
	StatusCreated              = 201 // RFC 7231, 6.3.2
	StatusAccepted             = 202 // RFC 7231, 6.3.3
	StatusNonAuthoritativeInfo = 203 // RFC 7231, 6.3.4
	StatusNoContent            = 204 // RFC 7231, 6.3.5
	StatusResetContent         = 205 // RFC 7231, 6.3.6
	StatusPartialContent       = 206 // RFC 7233, 4.1
	StatusMultiStatus          = 207 // RFC 4918, 11.1
	StatusAlreadyReported      = 208 // RFC 5842, 7.1
	StatusIMUsed               = 226 // RFC 3229, 10.4.1

	StatusMultipleChoices   = 300 // RFC 7231, 6.4.1
	StatusMovedPermanently  = 301 // RFC 7231, 6.4.2
	StatusFound             = 302 // RFC 7231, 6.4.3
	StatusSeeOther          = 303 // RFC 7231, 6.4.4
	StatusNotModified       = 304 // RFC 7232, 4.1
	StatusUseProxy          = 305 // RFC 7231, 6.4.5
	_                       = 306 // RFC 7231, 6.4.6 (Unused)
	StatusTemporaryRedirect = 307 // RFC 7231, 6.4.7
	StatusPermanentRedirect = 308 // RFC 7538, 3

	StatusBadRequest                   = 400 // RFC 7231, 6.5.1
	StatusUnauthorized                 = 401 // RFC 7235, 3.1
	StatusPaymentRequired              = 402 // RFC 7231, 6.5.2
	StatusForbidden                    = 403 // RFC 7231, 6.5.3
	StatusNotFound                     = 404 // RFC 7231, 6.5.4
	StatusMethodNotAllowed             = 405 // RFC 7231, 6.5.5
	StatusNotAcceptable                = 406 // RFC 7231, 6.5.6
	StatusProxyAuthRequired            = 407 // RFC 7235, 3.2
	StatusRequestTimeout               = 408 // RFC 7231, 6.5.7
	StatusConflict                     = 409 // RFC 7231, 6.5.8
	StatusGone                         = 410 // RFC 7231, 6.5.9
	StatusLengthRequired               = 411 // RFC 7231, 6.5.10
	StatusPreconditionFailed           = 412 // RFC 7232, 4.2
	StatusRequestEntityTooLarge        = 413 // RFC 7231, 6.5.11
	StatusRequestURITooLong            = 414 // RFC 7231, 6.5.12
	StatusUnsupportedMediaType         = 415 // RFC 7231, 6.5.13
	StatusRequestedRangeNotSatisfiable = 416 // RFC 7233, 4.4
	StatusExpectationFailed            = 417 // RFC 7231, 6.5.14
	StatusTeapot                       = 418 // RFC 7168, 2.3.3
	StatusUnprocessableEntity          = 422 // RFC 4918, 11.2
	StatusLocked                       = 423 // RFC 4918, 11.3
	StatusFailedDependency             = 424 // RFC 4918, 11.4
	StatusUpgradeRequired              = 426 // RFC 7231, 6.5.15
	StatusPreconditionRequired         = 428 // RFC 6585, 3
	StatusTooManyRequests              = 429 // RFC 6585, 4
	StatusRequestHeaderFieldsTooLarge  = 431 // RFC 6585, 5
	StatusUnavailableForLegalReasons   = 451 // RFC 7725, 3

	StatusInternalServerError           = 500 // RFC 7231, 6.6.1
	StatusNotImplemented                = 501 // RFC 7231, 6.6.2
	StatusBadGateway                    = 502 // RFC 7231, 6.6.3
	StatusServiceUnavailable            = 503 // RFC 7231, 6.6.4
	StatusGatewayTimeout                = 504 // RFC 7231, 6.6.5
	StatusHTTPVersionNotSupported       = 505 // RFC 7231, 6.6.6
	StatusVariantAlsoNegotiates         = 506 // RFC 2295, 8.1
	StatusInsufficientStorage           = 507 // RFC 4918, 11.5
	StatusLoopDetected                  = 508 // RFC 5842, 7.2
	StatusNotExtended                   = 510 // RFC 2774, 7
	StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6
)

var (
	// statusLines caches the serialized status line per code for StatusLine.
	// It always holds a map[int][]byte; the map is replaced with an updated
	// copy on a cache miss and is never mutated in place.
	statusLines atomic.Value

	// statusMessages maps each known status code to its reason phrase.
	// Codes missing from this table are reported as "Unknown Status Code".
	statusMessages = map[int]string{
		StatusContinue:           "Continue",
		StatusSwitchingProtocols: "Switching Protocols",
		StatusProcessing:         "Processing",

		StatusOK:                   "OK",
		StatusCreated:              "Created",
		StatusAccepted:             "Accepted",
		StatusNonAuthoritativeInfo: "Non-Authoritative Information",
		StatusNoContent:            "No Content",
		StatusResetContent:         "Reset Content",
		StatusPartialContent:       "Partial Content",
		StatusMultiStatus:          "Multi-Status",
		StatusAlreadyReported:      "Already Reported",
		StatusIMUsed:               "IM Used",

		StatusMultipleChoices:   "Multiple Choices",
		StatusMovedPermanently:  "Moved Permanently",
		StatusFound:             "Found",
		StatusSeeOther:          "See Other",
		StatusNotModified:       "Not Modified",
		StatusUseProxy:          "Use Proxy",
		StatusTemporaryRedirect: "Temporary Redirect",
		StatusPermanentRedirect: "Permanent Redirect",

		StatusBadRequest:                   "Bad Request",
		StatusUnauthorized:                 "Unauthorized",
		StatusPaymentRequired:              "Payment Required",
		StatusForbidden:                    "Forbidden",
		StatusNotFound:                     "Not Found",
		StatusMethodNotAllowed:             "Method Not Allowed",
		StatusNotAcceptable:                "Not Acceptable",
		StatusProxyAuthRequired:            "Proxy Authentication Required",
		StatusRequestTimeout:               "Request Timeout",
		StatusConflict:                     "Conflict",
		StatusGone:                         "Gone",
		StatusLengthRequired:               "Length Required",
		StatusPreconditionFailed:           "Precondition Failed",
		StatusRequestEntityTooLarge:        "Request Entity Too Large",
		StatusRequestURITooLong:            "Request URI Too Long",
		StatusUnsupportedMediaType:         "Unsupported Media Type",
		StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable",
		StatusExpectationFailed:            "Expectation Failed",
		StatusTeapot:                       "I'm a teapot",
		StatusUnprocessableEntity:          "Unprocessable Entity",
		StatusLocked:                       "Locked",
		StatusFailedDependency:             "Failed Dependency",
		StatusUpgradeRequired:              "Upgrade Required",
		StatusPreconditionRequired:         "Precondition Required",
		StatusTooManyRequests:              "Too Many Requests",
		StatusRequestHeaderFieldsTooLarge:  "Request Header Fields Too Large",
		StatusUnavailableForLegalReasons:   "Unavailable For Legal Reasons",

		StatusInternalServerError:           "Internal Server Error",
		StatusNotImplemented:                "Not Implemented",
		StatusBadGateway:                    "Bad Gateway",
		StatusServiceUnavailable:            "Service Unavailable",
		StatusGatewayTimeout:                "Gateway Timeout",
		StatusHTTPVersionNotSupported:       "HTTP Version Not Supported",
		StatusVariantAlsoNegotiates:         "Variant Also Negotiates",
		StatusInsufficientStorage:           "Insufficient Storage",
		StatusLoopDetected:                  "Loop Detected",
		StatusNotExtended:                   "Not Extended",
		StatusNetworkAuthenticationRequired: "Network Authentication Required",
	}
)

// StatusMessage returns HTTP status message for the given status code.
func StatusMessage(statusCode int) string { if statusCode < statusMessageMin || statusCode > statusMessageMax { return "Unknown Status Code" } s := statusMessages[statusCode] if s == "" { s = "Unknown Status Code" } return s } func init() { statusLines.Store(make(map[int][]byte)) } func StatusLine(statusCode int) []byte { m := statusLines.Load().(map[int][]byte) h := m[statusCode] if h != nil { return h } statusText := StatusMessage(statusCode) h = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, statusText)) newM := make(map[int][]byte, len(m)+1) for k, v := range m { newM[k] = v } newM[statusCode] = h statusLines.Store(newM) return h } ================================================ FILE: pkg/protocol/cookie.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/

package protocol

import (
	"bytes"
	"sync"
	"time"

	"github.com/cloudwego/hertz/internal/bytesconv"
	"github.com/cloudwego/hertz/internal/bytestr"
	"github.com/cloudwego/hertz/internal/nocopy"
	"github.com/cloudwego/hertz/pkg/common/errors"
	"github.com/cloudwego/hertz/pkg/common/hlog"
	"github.com/cloudwego/hertz/pkg/common/utils"
)

const (
	// CookieSameSiteDisabled removes the SameSite flag
	CookieSameSiteDisabled CookieSameSite = iota
	// CookieSameSiteDefaultMode sets the SameSite flag as a bare attribute,
	// without a value
	CookieSameSiteDefaultMode
	// CookieSameSiteLaxMode sets the SameSite flag with the "Lax" parameter
	CookieSameSiteLaxMode
	// CookieSameSiteStrictMode sets the SameSite flag with the "Strict" parameter
	CookieSameSiteStrictMode
	// CookieSameSiteNoneMode sets the SameSite flag with the "None" parameter
	// see https://tools.ietf.org/html/draft-west-cookie-incrementalism-00
	// third-party cookies are phasing out, use Partitioned cookies instead
	// see https://developers.google.com/privacy-sandbox/3pcd
	CookieSameSiteNoneMode
)

// zeroTime is the zero time.Time. A cookie whose expire field is zero is
// serialized without an Expires attribute.
var zeroTime time.Time

var (
	// errNoCookies is returned by Cookie.ParseBytes when its input is empty.
	errNoCookies = errors.NewPublic("no cookies found")

	// CookieExpireDelete may be passed to Cookie.SetExpire for expiring the
	// given cookie; it is a fixed date in the past.
	CookieExpireDelete = time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC)

	// CookieExpireUnlimited indicates that the cookie doesn't expire.
	CookieExpireUnlimited = zeroTime
)

// CookieSameSite is an enum for the mode in which the SameSite flag should be set for the given cookie.
// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details.
type CookieSameSite int

// Cookie represents HTTP response cookie.
//
// Do not copy Cookie objects. Create a new object instead (for example with
// AcquireCookie) and set its fields.
// NOTE(review): this comment used to recommend a CopyTo method, but none is
// defined alongside Cookie in this file — confirm whether it exists elsewhere
// in the package before pointing users to it.
//
// Cookie instance MUST NOT be used from concurrently running goroutines.
type Cookie struct { noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used key []byte value []byte expire time.Time maxAge int domain []byte path []byte httpOnly bool secure bool // A partitioned third-party cookie is tied to the top-level site // where it's initially set and cannot be accessed from elsewhere. partitioned bool sameSite CookieSameSite bufKV argsKV buf []byte } var cookiePool = &sync.Pool{ New: func() interface{} { return &Cookie{} }, } // AcquireCookie returns an empty Cookie object from the pool. // // The returned object may be returned back to the pool with ReleaseCookie. // This allows reducing GC load. func AcquireCookie() *Cookie { return cookiePool.Get().(*Cookie) } // ReleaseCookie returns the Cookie object acquired with AcquireCookie back // to the pool. // // Do not access released Cookie object, otherwise data races may occur. func ReleaseCookie(c *Cookie) { c.Reset() cookiePool.Put(c) } // SetDomain sets cookie domain. func (c *Cookie) SetDomain(domain string) { c.domain = append(c.domain[:0], domain...) } // SetPath sets cookie path. func (c *Cookie) SetPath(path string) { c.buf = append(c.buf[:0], path...) c.path = normalizePath(c.path, c.buf) } // SetPathBytes sets cookie path. func (c *Cookie) SetPathBytes(path []byte) { c.buf = append(c.buf[:0], path...) c.path = normalizePath(c.path, c.buf) } // SetExpire sets cookie expiration time. // // Set expiration time to CookieExpireDelete for expiring (deleting) // the cookie on the client. // // By default cookie lifetime is limited by browser session. func (c *Cookie) SetExpire(expire time.Time) { c.expire = expire } // SetKey sets cookie name. func (c *Cookie) SetKey(key string) { c.key = append(c.key[:0], key...) } // SetKeyBytes sets cookie name. func (c *Cookie) SetKeyBytes(key []byte) { c.key = append(c.key[:0], key...) } // SetValue sets cookie value. func (c *Cookie) SetValue(value string) { warnIfInvalid(bytesconv.S2b(value)) c.value = append(c.value[:0], value...) 
} // SetValueBytes sets cookie value. func (c *Cookie) SetValueBytes(value []byte) { warnIfInvalid(value) c.value = append(c.value[:0], value...) } // AppendBytes appends cookie representation to dst and returns // the extended dst. func (c *Cookie) AppendBytes(dst []byte) []byte { if len(c.key) > 0 { dst = append(dst, c.key...) dst = append(dst, '=') } dst = append(dst, c.value...) if c.maxAge != 0 { dst = append(dst, ';', ' ') dst = append(dst, bytestr.StrCookieMaxAge...) dst = append(dst, '=') if c.maxAge < 0 { dst = append(dst, '0') } else { dst = bytesconv.AppendUint(dst, c.maxAge) } } if !c.expire.IsZero() { dst = append(dst, ';', ' ') dst = append(dst, bytestr.StrCookieExpires...) dst = append(dst, '=') dst = bytesconv.AppendHTTPDate(dst, c.expire) } if len(c.domain) > 0 { dst = appendCookiePart(dst, bytestr.StrCookieDomain, c.domain) } if len(c.path) > 0 { dst = appendCookiePart(dst, bytestr.StrCookiePath, c.path) } if c.httpOnly { dst = append(dst, ';', ' ') dst = append(dst, bytestr.StrCookieHTTPOnly...) } if c.secure { dst = append(dst, ';', ' ') dst = append(dst, bytestr.StrCookieSecure...) } switch c.sameSite { case CookieSameSiteDefaultMode: dst = append(dst, ';', ' ') dst = append(dst, bytestr.StrCookieSameSite...) case CookieSameSiteLaxMode: dst = appendCookiePart(dst, bytestr.StrCookieSameSite, bytestr.StrCookieSameSiteLax) case CookieSameSiteStrictMode: dst = appendCookiePart(dst, bytestr.StrCookieSameSite, bytestr.StrCookieSameSiteStrict) case CookieSameSiteNoneMode: dst = appendCookiePart(dst, bytestr.StrCookieSameSite, bytestr.StrCookieSameSiteNone) } if c.partitioned { dst = append(dst, ';', ' ') dst = append(dst, bytestr.StrCookiePartitioned...) } return dst } func appendCookiePart(dst, key, value []byte) []byte { dst = append(dst, ';', ' ') dst = append(dst, key...) dst = append(dst, '=') return append(dst, value...) 
} func appendRequestCookieBytes(dst []byte, cookies []argsKV) []byte { for i, n := 0, len(cookies); i < n; i++ { kv := &cookies[i] if len(kv.key) > 0 { dst = append(dst, kv.key...) dst = append(dst, '=') } dst = append(dst, kv.value...) if i+1 < n { dst = append(dst, ';', ' ') } } return dst } // For Response we can not use the above function as response cookies // already contain the key= in the value. func appendResponseCookieBytes(dst []byte, cookies []argsKV) []byte { for i, n := 0, len(cookies); i < n; i++ { kv := &cookies[i] dst = append(dst, kv.value...) if i+1 < n { dst = append(dst, ';', ' ') } } return dst } type cookieScanner struct { b []byte } func parseRequestCookies(cookies []argsKV, src []byte) []argsKV { var s cookieScanner s.b = src var kv *argsKV cookies, kv = allocArg(cookies) for s.next(kv) { if len(kv.key) > 0 || len(kv.value) > 0 { cookies, kv = allocArg(cookies) } } return releaseArg(cookies) } func (s *cookieScanner) next(kv *argsKV) bool { b := s.b if len(b) == 0 { return false } isKey := true k := 0 for i, c := range b { switch c { case '=': if isKey { isKey = false kv.key = decodeCookieArg(kv.key, b[:i], false) k = i + 1 } case ';': if isKey { kv.key = kv.key[:0] } kv.value = decodeCookieArg(kv.value, b[k:i], true) s.b = b[i+1:] return true } } if isKey { kv.key = kv.key[:0] } kv.value = decodeCookieArg(kv.value, b[k:], true) s.b = b[len(b):] return true } // Key returns cookie name. // // The returned value is valid until the next Cookie modification method call. func (c *Cookie) Key() []byte { return c.key } // Cookie returns cookie representation. // // The returned value is valid until the next call to Cookie methods. func (c *Cookie) Cookie() []byte { c.buf = c.AppendBytes(c.buf[:0]) return c.buf } // Reset clears the cookie. 
func (c *Cookie) Reset() { c.key = c.key[:0] c.value = c.value[:0] c.expire = zeroTime c.maxAge = 0 c.domain = c.domain[:0] c.path = c.path[:0] c.httpOnly = false c.secure = false c.sameSite = CookieSameSiteDisabled c.partitioned = false } // Value returns cookie value. // // The returned value is valid until the next Cookie modification method call. func (c *Cookie) Value() []byte { return c.value } // Parse parses Set-Cookie header. func (c *Cookie) Parse(src string) error { c.buf = append(c.buf[:0], src...) return c.ParseBytes(c.buf) } // ParseBytes parses Set-Cookie header. func (c *Cookie) ParseBytes(src []byte) error { c.Reset() var s cookieScanner s.b = src kv := &c.bufKV if !s.next(kv) { return errNoCookies } c.key = append(c.key[:0], kv.key...) c.value = append(c.value[:0], kv.value...) for s.next(kv) { if len(kv.key) != 0 { // Case-insensitive switch on first char switch kv.key[0] | 0x20 { case 'm': if utils.CaseInsensitiveCompare(bytestr.StrCookieMaxAge, kv.key) { maxAge, err := bytesconv.ParseUint(kv.value) if err != nil { return err } c.maxAge = maxAge } case 'e': // "expires" if utils.CaseInsensitiveCompare(bytestr.StrCookieExpires, kv.key) { v := bytesconv.B2s(kv.value) // Try the same two formats as net/http // See: https://github.com/golang/go/blob/00379be17e63a5b75b3237819392d2dc3b313a27/src/net/http/cookie.go#L133-L135 exptime, err := time.ParseInLocation(time.RFC1123, v, time.UTC) if err != nil { exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", v) if err != nil { return err } } c.expire = exptime } case 'd': // "domain" if utils.CaseInsensitiveCompare(bytestr.StrCookieDomain, kv.key) { c.domain = append(c.domain[:0], kv.value...) } case 'p': // "path" if utils.CaseInsensitiveCompare(bytestr.StrCookiePath, kv.key) { c.path = append(c.path[:0], kv.value...) 
} case 's': // "samesite" if utils.CaseInsensitiveCompare(bytestr.StrCookieSameSite, kv.key) { // Case-insensitive switch on first char switch kv.value[0] | 0x20 { case 'l': // "lax" if utils.CaseInsensitiveCompare(bytestr.StrCookieSameSiteLax, kv.value) { c.sameSite = CookieSameSiteLaxMode } case 's': // "strict" if utils.CaseInsensitiveCompare(bytestr.StrCookieSameSiteStrict, kv.value) { c.sameSite = CookieSameSiteStrictMode } case 'n': // "none" if utils.CaseInsensitiveCompare(bytestr.StrCookieSameSiteNone, kv.value) { c.sameSite = CookieSameSiteNoneMode } } } } } else if len(kv.value) != 0 { // Case-insensitive switch on first char switch kv.value[0] | 0x20 { case 'h': // "httponly" if utils.CaseInsensitiveCompare(bytestr.StrCookieHTTPOnly, kv.value) { c.httpOnly = true } case 's': // "secure" if utils.CaseInsensitiveCompare(bytestr.StrCookieSecure, kv.value) { c.secure = true } else if utils.CaseInsensitiveCompare(bytestr.StrCookieSameSite, kv.value) { c.sameSite = CookieSameSiteDefaultMode } case 'p': // "partitioned" if utils.CaseInsensitiveCompare(bytestr.StrCookiePartitioned, kv.value) { c.partitioned = true } } } // else empty or no match } return nil } // MaxAge returns the seconds until the cookie is meant to expire or 0 // if no max age. func (c *Cookie) MaxAge() int { return c.maxAge } // SetMaxAge sets cookie expiration time based on seconds. // // Values: // // > 0: Set max-age to the specified number of seconds // = 0: Unset the max-age attribute (no max-age appears in the cookie) // < 0: Set max-age=0 to instruct the browser to immediately delete the cookie func (c *Cookie) SetMaxAge(seconds int) { c.maxAge = seconds } // Expire returns cookie expiration time. // // CookieExpireUnlimited is returned if cookie doesn't expire func (c *Cookie) Expire() time.Time { expire := c.expire if expire.IsZero() { expire = CookieExpireUnlimited } return expire } // Domain returns cookie domain. 
// // The returned domain is valid until the next Cookie modification method call. func (c *Cookie) Domain() []byte { return c.domain } // Path returns cookie path. func (c *Cookie) Path() []byte { return c.path } // Secure returns true if the cookie is secure. func (c *Cookie) Secure() bool { return c.secure } // SetSecure sets cookie's secure flag to the given value. func (c *Cookie) SetSecure(secure bool) { c.secure = secure } // SameSite returns the SameSite mode. func (c *Cookie) SameSite() CookieSameSite { return c.sameSite } // Partitioned returns if cookie is partitioned. func (c *Cookie) Partitioned() bool { return c.partitioned } // SetSameSite sets the cookie's SameSite flag to the given value. // set value CookieSameSiteNoneMode will set Secure to true also to avoid browser rejection func (c *Cookie) SetSameSite(mode CookieSameSite) { c.sameSite = mode if mode == CookieSameSiteNoneMode { c.SetSecure(true) } } // HTTPOnly returns true if the cookie is http only. func (c *Cookie) HTTPOnly() bool { return c.httpOnly } // SetHTTPOnly sets cookie's httpOnly flag to the given value. func (c *Cookie) SetHTTPOnly(httpOnly bool) { c.httpOnly = httpOnly } // SetPartitioned sets cookie as partitioned. Setting Partitioned to true will also set Secure. func (c *Cookie) SetPartitioned(partitioned bool) { c.partitioned = partitioned if partitioned { c.SetSecure(true) } } // String returns cookie representation. func (c *Cookie) String() string { return string(c.Cookie()) } func decodeCookieArg(dst, src []byte, skipQuotes bool) []byte { for len(src) > 0 && src[0] == ' ' { src = src[1:] } for len(src) > 0 && src[len(src)-1] == ' ' { src = src[:len(src)-1] } if skipQuotes { if len(src) > 1 && src[0] == '"' && src[len(src)-1] == '"' { src = src[1 : len(src)-1] } } return append(dst[:0], src...) 
} func getCookieKey(dst, src []byte) []byte { n := bytes.IndexByte(src, '=') if n >= 0 { src = src[:n] } return decodeCookieArg(dst, src, false) } func warnIfInvalid(value []byte) bool { for i := range value { if bytesconv.ValidCookieValueTable[value[i]] == 0 { hlog.SystemLogger().Warnf("Invalid byte %q in Cookie.Value, "+ "it may cause compatibility problems with user agents", value[i]) return false } } return true } ================================================ FILE: pkg/protocol/cookie_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. 
* * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package protocol import ( "math/rand" "strings" "testing" "time" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestCookieAppendBytes(t *testing.T) { t.Parallel() c := &Cookie{} testCookieAppendBytes(t, c, "", "bar", "bar") testCookieAppendBytes(t, c, "foo", "", "foo=") testCookieAppendBytes(t, c, "ффф", "12 лодлы", "ффф=12 лодлы") c.SetDomain("foobar.com") testCookieAppendBytes(t, c, "a", "b", "a=b; domain=foobar.com") c.SetPath("/a/b") testCookieAppendBytes(t, c, "aa", "bb", "aa=bb; domain=foobar.com; path=/a/b") c.SetExpire(CookieExpireDelete) testCookieAppendBytes(t, c, "xxx", "yyy", "xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b") c.SetPartitioned(true) testCookieAppendBytes(t, c, "xxx", "yyy", "xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b; secure; Partitioned") } func testCookieAppendBytes(t *testing.T, c *Cookie, key, value, expectedS string) { c.SetKey(key) c.SetValue(value) result := string(c.AppendBytes(nil)) if result != expectedS { t.Fatalf("Unexpected cookie %q. 
Expecting %q", result, expectedS) } } func TestParseRequestCookies(t *testing.T) { t.Parallel() testParseRequestCookies(t, "", "") testParseRequestCookies(t, "=", "") testParseRequestCookies(t, "foo", "foo") testParseRequestCookies(t, "=foo", "foo") testParseRequestCookies(t, "bar=", "bar=") testParseRequestCookies(t, "xxx=aa;bb=c; =d; ;;e=g", "xxx=aa; bb=c; d; e=g") testParseRequestCookies(t, "a;b;c; d=1;d=2", "a; b; c; d=1; d=2") testParseRequestCookies(t, " %D0%B8%D0%B2%D0%B5%D1%82=a%20b%3Bc ;s%20s=aaa ", "%D0%B8%D0%B2%D0%B5%D1%82=a%20b%3Bc; s%20s=aaa") } func testParseRequestCookies(t *testing.T, s, expectedS string) { cookies := parseRequestCookies(nil, []byte(s)) ss := string(appendRequestCookieBytes(nil, cookies)) if ss != expectedS { t.Fatalf("Unexpected cookies after parsing: %q. Expecting %q. String to parse %q", ss, expectedS, s) } } func TestAppendRequestCookieBytes(t *testing.T) { t.Parallel() testAppendRequestCookieBytes(t, "=", "") testAppendRequestCookieBytes(t, "foo=", "foo=") testAppendRequestCookieBytes(t, "=bar", "bar") testAppendRequestCookieBytes(t, "привет=a bc&s s=aaa", "привет=a bc; s s=aaa") } func testAppendRequestCookieBytes(t *testing.T, s, expectedS string) { kvs := strings.Split(s, "&") cookies := make([]argsKV, 0, len(kvs)) for _, ss := range kvs { tmp := strings.SplitN(ss, "=", 2) if len(tmp) != 2 { t.Fatalf("Cannot find '=' in %q, part of %q", ss, s) } cookies = append(cookies, argsKV{ key: []byte(tmp[0]), value: []byte(tmp[1]), }) } prefix := "foobar" result := string(appendRequestCookieBytes([]byte(prefix), cookies)) if result[:len(prefix)] != prefix { t.Fatalf("unexpected prefix %q. Expecting %q for cookie %q", result[:len(prefix)], prefix, s) } result = result[len(prefix):] if result != expectedS { t.Fatalf("Unexpected result %q. 
Expecting %q for cookie %q", result, expectedS, s) } } func TestCookieSecureHTTPOnly(t *testing.T) { t.Parallel() var c Cookie if err := c.Parse("foo=bar; HttpOnly; secure"); err != nil { t.Fatalf("unexpected error: %s", err) } if !c.Secure() { t.Fatalf("secure must be set") } if !c.HTTPOnly() { t.Fatalf("HttpOnly must be set") } s := c.String() if !strings.Contains(s, "; secure") { t.Fatalf("missing secure flag in cookie %q", s) } if !strings.Contains(s, "; HttpOnly") { t.Fatalf("missing HttpOnly flag in cookie %q", s) } } func TestCookieSecure(t *testing.T) { t.Parallel() var c Cookie if err := c.Parse("foo=bar; secure"); err != nil { t.Fatalf("unexpected error: %s", err) } if !c.Secure() { t.Fatalf("secure must be set") } s := c.String() if !strings.Contains(s, "; secure") { t.Fatalf("missing secure flag in cookie %q", s) } if err := c.Parse("foo=bar"); err != nil { t.Fatalf("unexpected error: %s", err) } if c.Secure() { t.Fatalf("Unexpected secure flag set") } s = c.String() if strings.Contains(s, "secure") { t.Fatalf("unexpected secure flag in cookie %q", s) } } func TestCookieSameSite(t *testing.T) { t.Parallel() var c Cookie if err := c.Parse("foo=bar; samesite"); err != nil { t.Fatalf("unexpected error: %s", err) } if c.SameSite() != CookieSameSiteDefaultMode { t.Fatalf("SameSite must be set") } s := c.String() if !strings.Contains(s, "; SameSite") { t.Fatalf("missing SameSite flag in cookie %q", s) } if err := c.Parse("foo=bar; samesite=lax"); err != nil { t.Fatalf("unexpected error: %s", err) } if c.SameSite() != CookieSameSiteLaxMode { t.Fatalf("SameSite Lax Mode must be set") } s = c.String() if !strings.Contains(s, "; SameSite=Lax") { t.Fatalf("missing SameSite flag in cookie %q", s) } if err := c.Parse("foo=bar; samesite=strict"); err != nil { t.Fatalf("unexpected error: %s", err) } if c.SameSite() != CookieSameSiteStrictMode { t.Fatalf("SameSite Strict Mode must be set") } s = c.String() if !strings.Contains(s, "; SameSite=Strict") { 
t.Fatalf("missing SameSite flag in cookie %q", s) } if err := c.Parse("foo=bar; samesite=none"); err != nil { t.Fatalf("unexpected error: %s", err) } if c.SameSite() != CookieSameSiteNoneMode { t.Fatalf("SameSite None Mode must be set") } s = c.String() if !strings.Contains(s, "; SameSite=None") { t.Fatalf("missing SameSite flag in cookie %q", s) } if err := c.Parse("foo=bar"); err != nil { t.Fatalf("unexpected error: %s", err) } c.SetSameSite(CookieSameSiteNoneMode) s = c.String() if !strings.Contains(s, "; SameSite=None") { t.Fatalf("missing SameSite flag in cookie %q", s) } if !strings.Contains(s, "; secure") { t.Fatalf("missing Secure flag in cookie %q", s) } if err := c.Parse("foo=bar"); err != nil { t.Fatalf("unexpected error: %s", err) } if c.SameSite() != CookieSameSiteDisabled { t.Fatalf("Unexpected SameSite flag set") } s = c.String() if strings.Contains(s, "SameSite") { t.Fatalf("unexpected SameSite flag in cookie %q", s) } } func TestCookiePartitioned(t *testing.T) { t.Parallel() var c Cookie if err := c.Parse("__Host-name=value; Secure; Path=/; SameSite=None; Partitioned;"); err != nil { t.Fatalf("unexpected error for valid paritionedd cookie: %s", err) } if !c.Partitioned() { t.Fatalf("partitioned must be set") } if err := c.Parse("foo=bar"); err != nil { t.Fatalf("unexpected error: %s", err) } c.SetPartitioned(true) s := c.String() if !strings.Contains(s, "; Partitioned") { t.Fatalf("missing Partitioned flag in cookie %q", s) } if !strings.Contains(s, "; secure") { t.Fatalf("missing Secure flag in cookie %q", s) } } func TestCookieMaxAgeExpires(t *testing.T) { t.Parallel() var c Cookie maxAge := 100 if err := c.Parse("foo=bar; max-age=100"); err != nil { t.Fatalf("unexpected error: %s", err) } if maxAge != c.MaxAge() { t.Fatalf("max-age must be set") } s := c.String() if !strings.Contains(s, "; max-age=100") { t.Fatalf("missing max-age flag in cookie %q", s) } if err := c.Parse("foo=bar; expires=Tue, 10 Nov 2009 23:00:00 GMT; max-age=100;"); err != 
nil { t.Fatalf("unexpected error: %s", err) } if maxAge != c.MaxAge() { t.Fatalf("max-age ignored") } expectedExpires := time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC) if !c.Expire().Equal(expectedExpires) { t.Fatalf("expires not parsed correctly. Got %v, expected %v", c.Expire(), expectedExpires) } c.SetExpire(time.Time{}) s = c.String() if s != "foo=bar; max-age=100" { t.Fatalf("missing max-age in cookie %q", s) } c.SetMaxAge(-1) s = c.String() if s != "foo=bar; max-age=0" { t.Fatalf("negative max-age should output 0: %q", s) } expires := time.Unix(100, 0) c.SetExpire(expires) s = c.String() if s != "foo=bar; max-age=0; expires=Thu, 01 Jan 1970 00:01:40 GMT" { t.Fatalf("expires should be included along with negative max-age (output as 0): %q", s) } c.SetMaxAge(0) s = c.String() if s != "foo=bar; expires=Thu, 01 Jan 1970 00:01:40 GMT" { t.Fatalf("missing expires %q", s) } } func TestCookieHttpOnly(t *testing.T) { t.Parallel() var c Cookie if err := c.Parse("foo=bar; HttpOnly"); err != nil { t.Fatalf("unexpected error: %s", err) } if !c.HTTPOnly() { t.Fatalf("HTTPOnly must be set") } s := c.String() if !strings.Contains(s, "; HttpOnly") { t.Fatalf("missing HttpOnly flag in cookie %q", s) } if err := c.Parse("foo=bar"); err != nil { t.Fatalf("unexpected error: %s", err) } if c.HTTPOnly() { t.Fatalf("Unexpected HTTPOnly flag set") } s = c.String() if strings.Contains(s, "HttpOnly") { t.Fatalf("unexpected HttpOnly flag in cookie %q", s) } } func TestCookieParse(t *testing.T) { t.Parallel() testCookieParse(t, "foo", "foo") testCookieParse(t, "foo=bar", "foo=bar") testCookieParse(t, "foo=", "foo=") testCookieParse(t, `foo="bar"`, "foo=bar") testCookieParse(t, `"foo"=bar`, `"foo"=bar`) testCookieParse(t, "foo=bar; Domain=aaa.com; PATH=/foo/bar", "foo=bar; domain=aaa.com; path=/foo/bar") testCookieParse(t, "foo=bar; max-age= 101 ; expires= Tue, 10 Nov 2009 23:00:00 GMT", "foo=bar; max-age=101; expires=Tue, 10 Nov 2009 23:00:00 GMT") testCookieParse(t, " xxx = yyy ; 
path=/a/b;;;domain=foobar.com ; expires= Tue, 10 Nov 2009 23:00:00 GMT ; ;;", "xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b") } func Test_decodeCookieArg(t *testing.T) { src := []byte(" \"aaaaabbbbb\" ") dst := make([]byte, 0) dst = decodeCookieArg(dst, src, true) assert.DeepEqual(t, []byte("aaaaabbbbb"), dst) } func testCookieParse(t *testing.T, s, expectedS string) { var c Cookie if err := c.Parse(s); err != nil { t.Fatalf("unexpected error: %s", err) } result := string(c.Cookie()) if result != expectedS { t.Fatalf("unexpected cookies %q. Expecting %q. Original %q", result, expectedS, s) } } func Test_WarnIfInvalid(t *testing.T) { assert.False(t, warnIfInvalid([]byte(";"))) assert.False(t, warnIfInvalid([]byte("\\"))) assert.False(t, warnIfInvalid([]byte("\""))) assert.True(t, warnIfInvalid([]byte(""))) for i := 0; i < 5; i++ { validCookie := getValidCookie() assert.True(t, warnIfInvalid(validCookie)) } } func getValidCookie() []byte { var validCookie []byte for i := 0; i < 100; i++ { r := rand.Intn(0x78-0x20) + 0x20 if r == ';' || r == '\\' || r == '"' { continue } validCookie = append(validCookie, byte(r)) } return validCookie } ================================================ FILE: pkg/protocol/doc.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ // The files in the protocol package are forked from fasthttp[github.com/valyala/fasthttp], // and we keep the original Copyright[Copyright 2015 fasthttp authors] and License of fasthttp for those files. // We have modified them as needed; the modifications are Copyright of 2022 CloudWeGo Authors. // Thanks to the fasthttp authors! Below is the source code information: // Repo: github.com/valyala/fasthttp // Forked Version: v1.36.0 package protocol ================================================ FILE: pkg/protocol/header.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software.
*
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package protocol

import (
	"bytes"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"github.com/cloudwego/hertz/internal/bytesconv"
	"github.com/cloudwego/hertz/internal/bytestr"
	"github.com/cloudwego/hertz/internal/nocopy"
	errs "github.com/cloudwego/hertz/pkg/common/errors"
	"github.com/cloudwego/hertz/pkg/common/hlog"
	"github.com/cloudwego/hertz/pkg/common/utils"
	"github.com/cloudwego/hertz/pkg/protocol/consts"
)

var (
	// ServerDate is shared, atomically swappable state. NOTE(review): by its
	// name it presumably caches the Date response header value; its writer
	// is outside this section — confirm.
	ServerDate atomic.Value
	// ServerDateOnce guards one-time initialization, per the trailing note
	// (serverDateOnce.Do(updateServerDate)).
	ServerDateOnce sync.Once // serverDateOnce.Do(updateServerDate)
)

// RequestHeader represents an HTTP request header. It holds the request line
// parts (method, requestURI, protocol) and several headers kept in dedicated
// fields (host, contentType, userAgent, contentLength). All other header
// lines are stored in h, and cookies in cookies.
//
// It embeds a nocopy.NoCopy marker, so instances must not be copied.
// NOTE(review): concurrency guarantees are not stated here; presumably the
// same "not for concurrent goroutines" rule as ResponseHeader applies —
// confirm.
type RequestHeader struct {
	noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used

	disableNormalizing   bool
	connectionClose      bool
	noDefaultContentType bool

	// These two fields have been moved close to other bool fields
	// for reducing RequestHeader object size.
	// NOTE(review): only cookiesCollected follows; the "two fields" wording
	// appears to be stale.
	cookiesCollected bool

	contentLength      int
	contentLengthBytes []byte

	method      []byte
	requestURI  []byte
	host        []byte
	contentType []byte

	userAgent []byte
	mulHeader [][]byte
	protocol  string

	// h holds header lines that have no dedicated field.
	h []argsKV
	// bufKV is scratch space reused by methods such as Header and
	// InitBufValue/GetBufValue.
	bufKV argsKV

	trailer *Trailer

	cookies []argsKV

	// stores an immutable copy of headers as they were received from the
	// wire.
	rawHeaders []byte
}

// SetRawHeaders stores r as the raw header bytes later returned by
// RawHeaders. The slice is retained as-is, not copied.
func (h *RequestHeader) SetRawHeaders(r []byte) {
	h.rawHeaders = r
}

// ResponseHeader represents HTTP response header.
//
// It is forbidden copying ResponseHeader instances.
// Create new instances instead and use CopyTo.
// // ResponseHeader instance MUST NOT be used from concurrently running // goroutines. type ResponseHeader struct { noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used disableNormalizing bool connectionClose bool noDefaultContentType bool noDefaultDate bool statusCode int contentLength int contentLengthBytes []byte contentEncoding []byte contentType []byte server []byte mulHeader [][]byte protocol string h []argsKV bufKV argsKV trailer *Trailer cookies []argsKV headerLength int } // SetHeaderLength sets the size of header for tracer. func (h *ResponseHeader) SetHeaderLength(length int) { h.headerLength = length } // GetHeaderLength gets the size of header for tracer. func (h *ResponseHeader) GetHeaderLength() int { return h.headerLength } // SetContentRange sets 'Content-Range: bytes startPos-endPos/contentLength' // header. func (h *ResponseHeader) SetContentRange(startPos, endPos, contentLength int) { b := h.bufKV.value[:0] b = append(b, bytestr.StrBytes...) b = append(b, ' ') b = bytesconv.AppendUint(b, startPos) b = append(b, '-') b = bytesconv.AppendUint(b, endPos) b = append(b, '/') b = bytesconv.AppendUint(b, contentLength) h.bufKV.value = b h.SetCanonical(bytestr.StrContentRange, h.bufKV.value) } func (h *ResponseHeader) NoDefaultContentType() bool { return h.noDefaultContentType } // SetConnectionClose sets 'Connection: close' header. func (h *ResponseHeader) SetConnectionClose(close bool) { h.connectionClose = close } func (h *ResponseHeader) PeekArgBytes(key []byte) []byte { return peekArgBytes(h.h, key) } // Deprecated: Use ResponseHeader.SetProtocol(consts.HTTP11) instead // // Now SetNoHTTP11(true) equal to SetProtocol(consts.HTTP10) // SetNoHTTP11(false) equal to SetProtocol(consts.HTTP11) func (h *ResponseHeader) SetNoHTTP11(b bool) { if b { h.protocol = consts.HTTP10 return } h.protocol = consts.HTTP11 } // Cookie fills cookie for the given cookie.Key. // // Returns false if cookie with the given cookie.Key is missing. 
func (h *ResponseHeader) Cookie(cookie *Cookie) bool { v := peekArgBytes(h.cookies, cookie.Key()) if v == nil { return false } cookie.ParseBytes(v) //nolint:errcheck return true } // FullCookie returns complete cookie bytes func (h *ResponseHeader) FullCookie() []byte { return h.Peek(consts.HeaderSetCookie) } // IsHTTP11 returns true if the response is HTTP/1.1. func (h *ResponseHeader) IsHTTP11() bool { return h.protocol == consts.HTTP11 } // SetContentType sets Content-Type header value. func (h *ResponseHeader) SetContentType(contentType string) { h.contentType = append(h.contentType[:0], contentType...) } func (h *ResponseHeader) GetHeaders() []argsKV { return h.h } // Reset clears response header. func (h *ResponseHeader) Reset() { h.disableNormalizing = false h.Trailer().disableNormalizing = false h.noDefaultContentType = false h.noDefaultDate = false h.ResetSkipNormalize() } // CopyTo copies all the headers to dst. func (h *ResponseHeader) CopyTo(dst *ResponseHeader) { dst.Reset() dst.disableNormalizing = h.disableNormalizing dst.connectionClose = h.connectionClose dst.noDefaultContentType = h.noDefaultContentType dst.noDefaultDate = h.noDefaultDate dst.statusCode = h.statusCode dst.contentLength = h.contentLength dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...) dst.contentEncoding = append(dst.contentEncoding[:0], h.contentEncoding...) dst.contentType = append(dst.contentType[:0], h.contentType...) dst.server = append(dst.server[:0], h.server...) dst.h = copyArgs(dst.h, h.h) dst.cookies = copyArgs(dst.cookies, h.cookies) dst.protocol = h.protocol dst.headerLength = h.headerLength h.Trailer().CopyTo(dst.Trailer()) } // Multiple headers with the same key may be added with this function. // Use Set for setting a single header for the given key. // // the Content-Type, Content-Length, Connection, Cookie, // Transfer-Encoding, Host and User-Agent headers can only be set once // and will overwrite the previous value. 
func (h *RequestHeader) Add(key, value string) { if h.setSpecialHeader(bytesconv.S2b(key), bytesconv.S2b(value)) { return } k := []byte(key) utils.NormalizeHeaderKey(k, h.disableNormalizing) h.h = appendArg(h.h, bytesconv.B2s(k), value, ArgsHasValue) } // VisitAll calls f for each header. // // f must not retain references to key and/or value after returning. // Copy key and/or value contents before returning if you need retaining them. func (h *ResponseHeader) VisitAll(f func(key, value []byte)) { if len(h.contentLengthBytes) > 0 { f(bytestr.StrContentLength, h.contentLengthBytes) } contentType := h.ContentType() if len(contentType) > 0 { f(bytestr.StrContentType, contentType) } contentEncoding := h.ContentEncoding() if len(contentEncoding) > 0 { f(bytestr.StrContentEncoding, contentEncoding) } server := h.Server() if len(server) > 0 { f(bytestr.StrServer, server) } if len(h.cookies) > 0 { visitArgs(h.cookies, func(k, v []byte) { f(bytestr.StrSetCookie, v) }) } if !h.Trailer().Empty() { f(bytestr.StrTrailer, h.Trailer().GetBytes()) } visitArgs(h.h, f) if h.ConnectionClose() { f(bytestr.StrConnection, bytestr.StrClose) } } // IsHTTP11 returns true if the request is HTTP/1.1. 
func (h *RequestHeader) IsHTTP11() bool { return h.protocol == consts.HTTP11 } func (h *RequestHeader) SetProtocol(p string) { h.protocol = p } func (h *RequestHeader) GetProtocol() string { return h.protocol } // Deprecated: Use RequestHeader.SetProtocol(consts.HTTP11) instead // // Now SetNoHTTP11(true) equal to SetProtocol(consts.HTTP10) // SetNoHTTP11(false) equal to SetProtocol(consts.HTTP11) func (h *RequestHeader) SetNoHTTP11(b bool) { if b { h.protocol = consts.HTTP10 return } h.protocol = consts.HTTP11 } func (h *RequestHeader) InitBufValue(size int) { if size > cap(h.bufKV.value) { h.bufKV.value = make([]byte, 0, size) } } func (h *RequestHeader) GetBufValue() []byte { return h.bufKV.value } // HasAcceptEncodingBytes returns true if the header contains // the given Accept-Encoding value. func (h *RequestHeader) HasAcceptEncodingBytes(acceptEncoding []byte) bool { ae := h.peek(consts.HeaderAcceptEncoding) n := bytes.Index(ae, acceptEncoding) if n < 0 { return false } b := ae[n+len(acceptEncoding):] if len(b) > 0 && b[0] != ',' { return false } if n == 0 { return true } return ae[n-1] == ' ' } func (h *RequestHeader) PeekIfModifiedSinceBytes() []byte { return h.peek(consts.HeaderIfModifiedSince) } // RequestURI returns RequestURI from the first HTTP request line. func (h *RequestHeader) RequestURI() []byte { requestURI := h.requestURI if len(requestURI) == 0 { requestURI = bytestr.StrSlash } return requestURI } func (h *RequestHeader) PeekArgBytes(key []byte) []byte { return peekArgBytes(h.h, key) } // RawHeaders returns raw header key/value bytes. // // Depending on server configuration, header keys may be normalized to // capital-case in place. // // This copy is set aside during parsing, so empty slice is returned for all // cases where parsing did not happen. Similarly, request line is not stored // during parsing and can not be returned. // // The slice is not safe to use after the handler returns. 
func (h *RequestHeader) RawHeaders() []byte {
	return h.rawHeaders
}

// AppendBytes appends the request header representation to dst and returns
// the extended slice.
//
// The output is the request line ("<method> <uri> HTTP/1.1"), then
// User-Agent, Host, Content-Type (the form MIME type by default for methods
// with a body unless disabled), Content-Length, the remaining headers,
// Trailer, Cookie and Connection: close. A blank line terminates it.
func (h *RequestHeader) AppendBytes(dst []byte) []byte {
	// Request line.
	dst = append(append(dst, h.Method()...), ' ')
	dst = append(append(dst, h.RequestURI()...), ' ')
	dst = append(append(dst, bytestr.StrHTTP11...), bytestr.StrCRLF...)

	if ua := h.UserAgent(); len(ua) > 0 {
		dst = appendHeaderLine(dst, bytestr.StrUserAgent, ua)
	}
	if host := h.Host(); len(host) > 0 {
		dst = appendHeaderLine(dst, bytestr.StrHost, host)
	}

	ct := h.ContentType()
	if len(ct) == 0 && !h.IgnoreBody() && !h.noDefaultContentType {
		ct = bytestr.MIMEPostForm
	}
	if len(ct) > 0 {
		dst = appendHeaderLine(dst, bytestr.StrContentType, ct)
	}
	if len(h.contentLengthBytes) > 0 {
		dst = appendHeaderLine(dst, bytestr.StrContentLength, h.contentLengthBytes)
	}

	for i := range h.h {
		dst = appendHeaderLine(dst, h.h[i].key, h.h[i].value)
	}

	if !h.Trailer().Empty() {
		dst = appendHeaderLine(dst, bytestr.StrTrailer, h.Trailer().GetBytes())
	}

	// Cookies not collected yet still live in h.h and were emitted above,
	// so h.collectCookies() is not needed here.
	if len(h.cookies) > 0 {
		dst = append(dst, bytestr.StrCookie...)
		dst = append(dst, bytestr.StrColonSpace...)
		dst = appendRequestCookieBytes(dst, h.cookies)
		dst = append(dst, bytestr.StrCRLF...)
	}

	if h.ConnectionClose() {
		dst = appendHeaderLine(dst, bytestr.StrConnection, bytestr.StrClose)
	}

	return append(dst, bytestr.StrCRLF...)
}

// Header returns the request header representation built by AppendBytes.
//
// The result lives in the header's scratch buffer and is valid only until
// the next call to RequestHeader methods.
func (h *RequestHeader) Header() []byte {
	out := h.AppendBytes(h.bufKV.value[:0])
	h.bufKV.value = out
	return out
}

// IsPut returns true if request method is PUT.
func (h *RequestHeader) IsPut() bool {
	return bytes.Equal(h.Method(), bytestr.StrPut)
}

// IsHead reports whether the request method is HEAD.
func (h *RequestHeader) IsHead() bool {
	return bytes.Equal(h.Method(), bytestr.StrHead)
}

// IsPost reports whether the request method is POST.
func (h *RequestHeader) IsPost() bool {
	return bytes.Equal(h.Method(), bytestr.StrPost)
}

// IsDelete reports whether the request method is DELETE.
func (h *RequestHeader) IsDelete() bool {
	return bytes.Equal(h.Method(), bytestr.StrDelete)
}

// IsConnect reports whether the request method is CONNECT.
func (h *RequestHeader) IsConnect() bool {
	return bytes.Equal(h.Method(), bytestr.StrConnect)
}

// IgnoreBody reports whether the method is GET or HEAD, for which no default
// Content-Type is written by AppendBytes.
func (h *RequestHeader) IgnoreBody() bool {
	return h.IsGet() || h.IsHead()
}

// ContentLength returns the Content-Length header value.
//
// It may be negative:
// -1 means Transfer-Encoding: chunked.
func (h *RequestHeader) ContentLength() int {
	return h.contentLength
}

// SetHost sets the Host header value.
func (h *RequestHeader) SetHost(host string) {
	h.host = append(h.host[:0], host...)
}

// SetStatusCode sets the response status code, warning (but not failing) on
// codes outside 100-599.
func (h *ResponseHeader) SetStatusCode(statusCode int) {
	checkWriteHeaderCode(statusCode)
	h.statusCode = statusCode
}

// checkWriteHeaderCode logs a warning for status codes outside 100-599.
// Rejecting such codes outright may come later; for now it only warns.
func checkWriteHeaderCode(code int) {
	if code >= 100 && code <= 599 {
		return
	}
	hlog.SystemLogger().Warnf("Invalid StatusCode code %v, status code should not be under 100 or over 599.\n"+
		"For more info: https://www.rfc-editor.org/rfc/rfc9110.html#name-status-codes", code)
}

// ResetSkipNormalize clears the header contents. It keeps the normalization
// and default Content-Type/Date switches, and it truncates slices instead of
// freeing them so that their capacity is reused.
func (h *ResponseHeader) ResetSkipNormalize() {
	h.protocol = ""
	h.connectionClose = false
	h.statusCode = 0
	h.contentLength = 0

	h.contentLengthBytes = h.contentLengthBytes[:0]
	h.contentEncoding = h.contentEncoding[:0]
	h.contentType = h.contentType[:0]
	h.server = h.server[:0]

	h.h = h.h[:0]
	h.cookies = h.cookies[:0]
	h.Trailer().ResetSkipNormalize()
	h.mulHeader = h.mulHeader[:0]
}

// ContentLength returns the Content-Length header value.
//
// It may be negative:
// -1 means Transfer-Encoding: chunked.
// -2 means Transfer-Encoding: identity.
func (h *ResponseHeader) ContentLength() int {
	return h.contentLength
}

// Set sets the given 'key: value' header, replacing any previous value.
//
// Use Add for setting multiple header values under the same key.
func (h *ResponseHeader) Set(key, value string) {
	initHeaderKV(&h.bufKV, key, value, h.disableNormalizing)
	h.SetCanonical(h.bufKV.key, h.bufKV.value)
}

// Add adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
// Use Set for setting a single header for the given key.
//
// the Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
func (h *ResponseHeader) Add(key, value string) {
	if h.setSpecialHeader(bytesconv.S2b(key), bytesconv.S2b(value)) {
		return
	}

	normalized := []byte(key)
	utils.NormalizeHeaderKey(normalized, h.disableNormalizing)
	h.h = appendArg(h.h, bytesconv.B2s(normalized), value, ArgsHasValue)
}

// SetContentLength sets Content-Length header value.
//
// Content-Length may be negative:
// -1 means Transfer-Encoding: chunked.
// -2 means Transfer-Encoding: identity.
func (h *ResponseHeader) SetContentLength(contentLength int) { if h.MustSkipContentLength() { return } h.contentLength = contentLength if contentLength >= 0 { h.contentLengthBytes = bytesconv.AppendUint(h.contentLengthBytes[:0], contentLength) h.h = delAllArgsBytes(h.h, bytestr.StrTransferEncoding) } else { h.contentLengthBytes = h.contentLengthBytes[:0] value := bytestr.StrChunked if contentLength == -2 { h.SetConnectionClose(true) value = bytestr.StrIdentity } h.h = setArgBytes(h.h, bytestr.StrTransferEncoding, value, ArgsHasValue) } } func (h *ResponseHeader) ContentLengthBytes() []byte { return h.contentLengthBytes } func (h *ResponseHeader) InitContentLengthWithValue(contentLength int) { h.contentLength = contentLength } // VisitAllCookie calls f for each response cookie. // // Cookie name is passed in key and the whole Set-Cookie header value // is passed in value on each f invocation. Value may be parsed // with Cookie.ParseBytes(). // // f must not retain references to key and/or value after returning. func (h *ResponseHeader) VisitAllCookie(f func(key, value []byte)) { visitArgs(h.cookies, f) } // DelAllCookies removes all the cookies from response headers. func (h *ResponseHeader) DelAllCookies() { h.cookies = h.cookies[:0] } // DelCookie removes cookie under the given key from response header. // // Note that DelCookie doesn't remove the cookie from the client. // Use DelClientCookie instead. func (h *ResponseHeader) DelCookie(key string) { h.cookies = delAllArgs(h.cookies, key) } // DelCookieBytes removes cookie under the given key from response header. // // Note that DelCookieBytes doesn't remove the cookie from the client. // Use DelClientCookieBytes instead. func (h *ResponseHeader) DelCookieBytes(key []byte) { h.DelCookie(bytesconv.B2s(key)) } // DelBytes deletes header with the given key. 
func (h *ResponseHeader) DelBytes(key []byte) { k := []byte(string(key)) utils.NormalizeHeaderKey(k, h.disableNormalizing) h.del(k) } // Header returns response header representation. // // The returned value is valid until the next call to ResponseHeader methods. func (h *ResponseHeader) Header() []byte { h.bufKV.value = h.AppendBytes(h.bufKV.value[:0]) return h.bufKV.value } func (h *ResponseHeader) PeekLocation() []byte { return h.peek(consts.HeaderLocation) } // DelClientCookie instructs the client to remove the given cookie. // This doesn't work for a cookie with specific domain or path, // you should delete it manually like: // // c := AcquireCookie() // c.SetKey(key) // c.SetDomain("example.com") // c.SetPath("/path") // c.SetExpire(CookieExpireDelete) // h.SetCookie(c) // ReleaseCookie(c) // // Use DelCookie if you want just removing the cookie from response header. func (h *ResponseHeader) DelClientCookie(key string) { h.DelCookie(key) c := AcquireCookie() c.SetKey(key) c.SetExpire(CookieExpireDelete) h.SetCookie(c) ReleaseCookie(c) } // DelClientCookieBytes instructs the client to remove the given cookie. // This doesn't work for a cookie with specific domain or path, // you should delete it manually like: // // c := AcquireCookie() // c.SetKey(key) // c.SetDomain("example.com") // c.SetPath("/path") // c.SetExpire(CookieExpireDelete) // h.SetCookie(c) // ReleaseCookie(c) // // Use DelCookieBytes if you want just removing the cookie from response header. func (h *ResponseHeader) DelClientCookieBytes(key []byte) { h.DelClientCookie(bytesconv.B2s(key)) } // Peek returns header value for the given key. // // Returned value is valid until the next call to ResponseHeader. // Do not store references to returned value. Make copies instead. 
func (h *ResponseHeader) Peek(key string) []byte { k := []byte(key) utils.NormalizeHeaderKey(k, h.disableNormalizing) return h.peek(bytesconv.B2s(k)) } func (h *ResponseHeader) IsDisableNormalizing() bool { return h.disableNormalizing } func (h *ResponseHeader) ParseSetCookie(value []byte) { var kv *argsKV h.cookies, kv = allocArg(h.cookies) kv.key = getCookieKey(kv.key, value) kv.value = append(kv.value[:0], value...) } func (h *ResponseHeader) peek(key string) []byte { switch key { case consts.HeaderContentType: return h.ContentType() case consts.HeaderContentEncoding: return h.ContentEncoding() case consts.HeaderServer: return h.Server() case consts.HeaderConnection: if h.ConnectionClose() { return bytestr.StrClose } return peekArgStr(h.h, key) case consts.HeaderContentLength: return h.contentLengthBytes case consts.HeaderSetCookie: return appendResponseCookieBytes(nil, h.cookies) case consts.HeaderTrailer: return h.Trailer().GetBytes() default: return peekArgStr(h.h, key) } } // PeekAll returns all header value for the given key. // // The returned value is valid until the request is released, // either though ReleaseResponse or your request handler returning. // Any future calls to the Peek* will modify the returned value. // Do not store references to returned value. Use ResponseHeader.GetAll(key) instead. 
func (h *ResponseHeader) PeekAll(key string) [][]byte { k := []byte(key) utils.NormalizeHeaderKey(k, h.disableNormalizing) return h.peekAll(k) } func (h *ResponseHeader) peekAll(key []byte) [][]byte { h.mulHeader = h.mulHeader[:0] switch string(key) { case consts.HeaderContentType: if contentType := h.ContentType(); len(contentType) > 0 { h.mulHeader = append(h.mulHeader, contentType) } case consts.HeaderContentEncoding: if contentEncoding := h.ContentEncoding(); len(contentEncoding) > 0 { h.mulHeader = append(h.mulHeader, contentEncoding) } case consts.HeaderServer: if server := h.Server(); len(server) > 0 { h.mulHeader = append(h.mulHeader, server) } case consts.HeaderConnection: if h.ConnectionClose() { h.mulHeader = append(h.mulHeader, bytestr.StrClose) } else { h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) } case consts.HeaderContentLength: h.mulHeader = append(h.mulHeader, h.contentLengthBytes) case consts.HeaderSetCookie: h.mulHeader = append(h.mulHeader, appendResponseCookieBytes(nil, h.cookies)) default: h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) } return h.mulHeader } // PeekAll returns all header value for the given key. // // The returned value is valid until the request is released, // either though ReleaseRequest or your request handler returning. // Any future calls to the Peek* will modify the returned value. // Do not store references to returned value. Use RequestHeader.GetAll(key) instead. 
func (h *RequestHeader) PeekAll(key string) [][]byte { k := []byte(key) utils.NormalizeHeaderKey(k, h.disableNormalizing) return h.peekAll(k) } func (h *RequestHeader) peekAll(key []byte) [][]byte { h.mulHeader = h.mulHeader[:0] switch string(key) { case consts.HeaderHost: if host := h.Host(); len(host) > 0 { h.mulHeader = append(h.mulHeader, host) } case consts.HeaderContentType: if contentType := h.ContentType(); len(contentType) > 0 { h.mulHeader = append(h.mulHeader, contentType) } case consts.HeaderUserAgent: if ua := h.UserAgent(); len(ua) > 0 { h.mulHeader = append(h.mulHeader, ua) } case consts.HeaderConnection: if h.ConnectionClose() { h.mulHeader = append(h.mulHeader, bytestr.StrClose) } else { h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) } case consts.HeaderContentLength: h.mulHeader = append(h.mulHeader, h.contentLengthBytes) case consts.HeaderCookie: if h.cookiesCollected { h.mulHeader = append(h.mulHeader, appendRequestCookieBytes(nil, h.cookies)) } else { h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) } default: h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key) } return h.mulHeader } // SetContentTypeBytes sets Content-Type header value. func (h *ResponseHeader) SetContentTypeBytes(contentType []byte) { h.contentType = append(h.contentType[:0], contentType...) } // ContentEncoding returns Content-Encoding header value. func (h *ResponseHeader) ContentEncoding() []byte { return h.contentEncoding } // SetContentEncoding sets Content-Encoding header value. func (h *ResponseHeader) SetContentEncoding(contentEncoding string) { h.contentEncoding = append(h.contentEncoding[:0], contentEncoding...) } // SetContentEncodingBytes sets Content-Encoding header value. func (h *ResponseHeader) SetContentEncodingBytes(contentEncoding []byte) { h.contentEncoding = append(h.contentEncoding[:0], contentEncoding...) 
} func (h *ResponseHeader) SetContentLengthBytes(contentLength []byte) { h.contentLengthBytes = append(h.contentLengthBytes[:0], contentLength...) } // SetCanonical sets the given 'key: value' header assuming that // key is in canonical form. func (h *ResponseHeader) SetCanonical(key, value []byte) { if h.setSpecialHeader(key, value) { return } h.h = setArgBytes(h.h, key, value, ArgsHasValue) } // ResetConnectionClose clears 'Connection: close' header if it exists. func (h *ResponseHeader) ResetConnectionClose() { if h.connectionClose { h.connectionClose = false h.h = delAllArgsBytes(h.h, bytestr.StrConnection) } } // Server returns Server header value. func (h *ResponseHeader) Server() []byte { return h.server } func (h *ResponseHeader) AddArgBytes(key, value []byte, noValue bool) { h.h = appendArgBytes(h.h, key, value, noValue) } func (h *ResponseHeader) SetArgBytes(key, value []byte, noValue bool) { h.h = setArgBytes(h.h, key, value, noValue) } // AppendBytes appends response header representation to dst and returns // the extended dst. func (h *ResponseHeader) AppendBytes(dst []byte) []byte { statusCode := h.StatusCode() if statusCode < 0 { statusCode = consts.StatusOK } dst = append(dst, consts.StatusLine(statusCode)...) server := h.Server() if len(server) != 0 { dst = appendHeaderLine(dst, bytestr.StrServer, server) } if !h.noDefaultDate { ServerDateOnce.Do(UpdateServerDate) dst = appendHeaderLine(dst, bytestr.StrDate, ServerDate.Load().([]byte)) } // Append Content-Type only for non-zero responses // or if it is explicitly set. 
if h.ContentLength() != 0 || len(h.contentType) > 0 { contentType := h.ContentType() if len(contentType) > 0 { dst = appendHeaderLine(dst, bytestr.StrContentType, contentType) } } contentEncoding := h.ContentEncoding() if len(contentEncoding) > 0 { dst = appendHeaderLine(dst, bytestr.StrContentEncoding, contentEncoding) } if len(h.contentLengthBytes) > 0 { dst = appendHeaderLine(dst, bytestr.StrContentLength, h.contentLengthBytes) } for i, n := 0, len(h.h); i < n; i++ { kv := &h.h[i] if h.noDefaultDate || !bytes.Equal(kv.key, bytestr.StrDate) { dst = appendHeaderLine(dst, kv.key, kv.value) } } if !h.Trailer().Empty() { dst = appendHeaderLine(dst, bytestr.StrTrailer, h.Trailer().GetBytes()) } n := len(h.cookies) if n > 0 { for i := 0; i < n; i++ { kv := &h.cookies[i] dst = appendHeaderLine(dst, bytestr.StrSetCookie, kv.value) } } if h.ConnectionClose() { dst = appendHeaderLine(dst, bytestr.StrConnection, bytestr.StrClose) } return append(dst, bytestr.StrCRLF...) } // ConnectionClose returns true if 'Connection: close' header is set. func (h *ResponseHeader) ConnectionClose() bool { return h.connectionClose } func (h *ResponseHeader) GetCookies() []argsKV { return h.cookies } // ContentType returns Content-Type header value. func (h *ResponseHeader) ContentType() []byte { contentType := h.contentType if !h.noDefaultContentType && len(h.contentType) == 0 { contentType = bytestr.DefaultContentType } return contentType } // SetNoDefaultContentType set noDefaultContentType value of ResponseHeader. func (h *ResponseHeader) SetNoDefaultContentType(b bool) { h.noDefaultContentType = b } // SetNoDefaultDate set noDefaultDate value of ResponseHeader. func (h *ResponseHeader) SetNoDefaultDate(b bool) { h.noDefaultDate = b } // SetServerBytes sets Server header value. func (h *ResponseHeader) SetServerBytes(server []byte) { h.server = append(h.server[:0], server...) 
} func (h *ResponseHeader) MustSkipContentLength() bool { // From http/1.1 specs: // All 1xx (informational), 204 (no content), and 304 (not modified) responses MUST NOT include a message-body statusCode := h.StatusCode() // Fast path. if statusCode < 100 || statusCode == consts.StatusOK { return false } // Slow path. return statusCode == consts.StatusNotModified || statusCode == consts.StatusNoContent || statusCode < 200 } // StatusCode returns response status code. func (h *ResponseHeader) StatusCode() int { if h.statusCode == 0 { return consts.StatusOK } return h.statusCode } // Del deletes header with the given key. func (h *ResponseHeader) Del(key string) { k := []byte(key) utils.NormalizeHeaderKey(k, h.disableNormalizing) h.del(k) } func (h *ResponseHeader) del(key []byte) { switch string(key) { case consts.HeaderContentType: h.contentType = h.contentType[:0] case consts.HeaderContentEncoding: h.contentEncoding = h.contentEncoding[:0] case consts.HeaderServer: h.server = h.server[:0] case consts.HeaderSetCookie: h.cookies = h.cookies[:0] case consts.HeaderContentLength: h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] case consts.HeaderConnection: h.connectionClose = false case consts.HeaderTrailer: h.Trailer().ResetSkipNormalize() } h.h = delAllArgsBytes(h.h, key) } // SetBytesV sets the given 'key: value' header. // // Use AddBytesV for setting multiple header values under the same key. func (h *ResponseHeader) SetBytesV(key string, value []byte) { k := []byte(key) utils.NormalizeHeaderKey(k, h.disableNormalizing) h.SetCanonical(k, value) } // Len returns the number of headers set, // i.e. the number of times f is called in VisitAll. func (h *ResponseHeader) Len() int { n := 0 h.VisitAll(func(k, v []byte) { n++ }) return n } // Len returns the number of headers set, // i.e. the number of times f is called in VisitAll. 
func (h *RequestHeader) Len() int { n := 0 h.VisitAll(func(k, v []byte) { n++ }) return n } // Reset clears request header. func (h *RequestHeader) Reset() { h.disableNormalizing = false h.Trailer().disableNormalizing = false h.ResetSkipNormalize() } // SetByteRange sets 'Range: bytes=startPos-endPos' header. // // - If startPos is negative, then 'bytes=-startPos' value is set. // - If endPos is negative, then 'bytes=startPos-' value is set. func (h *RequestHeader) SetByteRange(startPos, endPos int) { b := h.bufKV.value[:0] b = append(b, bytestr.StrBytes...) b = append(b, '=') if startPos >= 0 { b = bytesconv.AppendUint(b, startPos) } else { endPos = -startPos } b = append(b, '-') if endPos >= 0 { b = bytesconv.AppendUint(b, endPos) } h.bufKV.value = b h.SetCanonical(bytestr.StrRange, h.bufKV.value) } // DelBytes deletes header with the given key. func (h *RequestHeader) DelBytes(key []byte) { k := []byte(string(key)) utils.NormalizeHeaderKey(k, h.disableNormalizing) h.del(k) } // Del deletes header with the given key. func (h *RequestHeader) Del(key string) { k := []byte(key) utils.NormalizeHeaderKey(k, h.disableNormalizing) h.del(k) } func (h *RequestHeader) SetArgBytes(key, value []byte, noValue bool) { h.h = setArgBytes(h.h, key, value, noValue) } func (h *RequestHeader) del(key []byte) { switch string(key) { case consts.HeaderHost: h.host = h.host[:0] case consts.HeaderContentType: h.contentType = h.contentType[:0] case consts.HeaderUserAgent: h.userAgent = h.userAgent[:0] case consts.HeaderCookie: h.cookies = h.cookies[:0] case consts.HeaderContentLength: h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] case consts.HeaderConnection: h.connectionClose = false case consts.HeaderTrailer: h.Trailer().ResetSkipNormalize() } h.h = delAllArgsBytes(h.h, key) } // CopyTo copies all the headers to dst. 
func (h *RequestHeader) CopyTo(dst *RequestHeader) { dst.Reset() dst.disableNormalizing = h.disableNormalizing dst.connectionClose = h.connectionClose dst.noDefaultContentType = h.noDefaultContentType dst.contentLength = h.contentLength dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...) dst.method = append(dst.method[:0], h.method...) dst.requestURI = append(dst.requestURI[:0], h.requestURI...) dst.host = append(dst.host[:0], h.host...) dst.contentType = append(dst.contentType[:0], h.contentType...) dst.userAgent = append(dst.userAgent[:0], h.userAgent...) h.Trailer().CopyTo(dst.Trailer()) dst.h = copyArgs(dst.h, h.h) dst.cookies = copyArgs(dst.cookies, h.cookies) dst.cookiesCollected = h.cookiesCollected dst.rawHeaders = append(dst.rawHeaders[:0], h.rawHeaders...) dst.protocol = h.protocol } // Peek returns header value for the given key. // // Returned value is valid until the next call to RequestHeader. // Do not store references to returned value. Make copies instead. func (h *RequestHeader) Peek(key string) []byte { k := []byte(key) utils.NormalizeHeaderKey(k, h.disableNormalizing) return h.peek(bytesconv.B2s(k)) } // SetMultipartFormBoundary sets the following Content-Type: // 'multipart/form-data; boundary=...' // where ... is substituted by the given boundary. func (h *RequestHeader) SetMultipartFormBoundary(boundary string) { b := h.bufKV.value[:0] b = append(b, bytestr.MIMEFormData...) b = append(b, ';', ' ') b = append(b, bytestr.StrBoundary...) b = append(b, '=') b = append(b, boundary...) h.bufKV.value = b h.SetContentTypeBytes(h.bufKV.value) } func (h *RequestHeader) ContentLengthBytes() []byte { return h.contentLengthBytes } func (h *RequestHeader) SetContentLengthBytes(contentLength []byte) { h.contentLengthBytes = append(h.contentLengthBytes[:0], contentLength...) } // SetContentTypeBytes sets Content-Type header value. 
func (h *RequestHeader) SetContentTypeBytes(contentType []byte) {
	h.contentType = append(h.contentType[:0], contentType...)
}

// ContentType returns the Content-Type header value as stored; unlike the
// response header, no default is substituted when it is unset.
func (h *RequestHeader) ContentType() []byte {
	return h.contentType
}

// SetNoDefaultContentType controls the default Content-Type header behaviour.
//
// When set to false, the Content-Type header is sent with a default value if no Content-Type value is specified.
// When set to true, no Content-Type header is sent if no Content-Type value is specified.
func (h *RequestHeader) SetNoDefaultContentType(b bool) {
	h.noDefaultContentType = b
}

// SetContentLength sets the Content-Length header value.
//
// A negative contentLength clears the textual Content-Length and sets
// 'Transfer-Encoding: chunked' instead; a non-negative one removes any
// Transfer-Encoding entry.
func (h *RequestHeader) SetContentLength(contentLength int) {
	h.contentLength = contentLength
	if contentLength < 0 {
		h.contentLengthBytes = h.contentLengthBytes[:0]
		h.h = setArgBytes(h.h, bytestr.StrTransferEncoding, bytestr.StrChunked, ArgsHasValue)
		return
	}
	h.contentLengthBytes = bytesconv.AppendUint(h.contentLengthBytes[:0], contentLength)
	h.h = delAllArgsBytes(h.h, bytestr.StrTransferEncoding)
}

// InitContentLengthWithValue sets only the numeric content length, leaving the
// textual Content-Length and any Transfer-Encoding entry untouched.
func (h *RequestHeader) InitContentLengthWithValue(contentLength int) {
	h.contentLength = contentLength
}

// MultipartFormBoundary returns boundary part
// from 'multipart/form-data; boundary=...' Content-Type.
func (h *RequestHeader) MultipartFormBoundary() []byte {
	b := h.ContentType()
	// The Content-Type must start with the multipart/form-data media type ...
	if !bytes.HasPrefix(b, bytestr.MIMEFormData) {
		return nil
	}
	b = b[len(bytestr.MIMEFormData):]
	// ... immediately followed by the ';' that introduces its parameters.
	if len(b) == 0 || b[0] != ';' {
		return nil
	}

	var n int
	// Invariant at the top of each iteration: b[n] is the ';' preceding the
	// next parameter. Step past it and any spaces that follow.
	for len(b) > 0 {
		n++
		for len(b) > n && b[n] == ' ' {
			n++
		}
		b = b[n:]
		if !bytes.HasPrefix(b, bytestr.StrBoundary) {
			// Not the boundary parameter: jump to the next ';' or give up.
			if n = bytes.IndexByte(b, ';'); n < 0 {
				return nil
			}
			continue
		}

		b = b[len(bytestr.StrBoundary):]
		if len(b) == 0 || b[0] != '=' {
			return nil
		}
		b = b[1:]
		// The boundary value ends at the next ';', if there is one.
		if n = bytes.IndexByte(b, ';'); n >= 0 {
			b = b[:n]
		}
		// Strip the surrounding double quotes of a quoted boundary value.
		if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' {
			b = b[1 : len(b)-1]
		}
		// The result is a sub-slice of the stored Content-Type value (no copy).
		return b
	}
	return nil
}

// ConnectionClose returns true if 'Connection: close' header is set.
func (h *RequestHeader) ConnectionClose() bool {
	return h.connectionClose
}

// Method returns HTTP request method.
// When no method has been set, GET is returned.
func (h *RequestHeader) Method() []byte {
	if len(h.method) == 0 {
		return bytestr.StrGet
	}
	return h.method
}

// IsGet returns true if request method is GET.
// An unset method counts as GET (see Method); the comparison is case-sensitive.
func (h *RequestHeader) IsGet() bool {
	return bytes.Equal(h.Method(), bytestr.StrGet)
}

// IsOptions returns true if request method is Options.
// The comparison is case-sensitive.
func (h *RequestHeader) IsOptions() bool {
	return bytes.Equal(h.Method(), bytestr.StrOptions)
}

// IsTrace returns true if request method is Trace.
// The comparison is case-sensitive.
func (h *RequestHeader) IsTrace() bool {
	return bytes.Equal(h.Method(), bytestr.StrTrace)
}

// SetHostBytes sets Host header value.
// The bytes are copied into the header's own buffer.
func (h *RequestHeader) SetHostBytes(host []byte) {
	h.host = append(h.host[:0], host...)
}

// SetRequestURIBytes sets RequestURI for the first HTTP request line.
// RequestURI must be properly encoded.
// Use URI.RequestURI for constructing proper RequestURI if unsure.
func (h *RequestHeader) SetRequestURIBytes(requestURI []byte) {
	h.requestURI = append(h.requestURI[:0], requestURI...)
}

// SetBytesKV sets the given 'key: value' header.
//
// Use AddBytesKV for setting multiple header values under the same key.
func (h *RequestHeader) SetBytesKV(key, value []byte) { k := []byte(string(key)) utils.NormalizeHeaderKey(k, h.disableNormalizing) h.SetCanonical(k, value) } func (h *RequestHeader) AddArgBytes(key, value []byte, noValue bool) { h.h = appendArgBytes(h.h, key, value, noValue) } // SetUserAgentBytes sets User-Agent header value. func (h *RequestHeader) SetUserAgentBytes(userAgent []byte) { h.userAgent = append(h.userAgent[:0], userAgent...) } // SetCookie sets 'key: value' cookies. func (h *RequestHeader) SetCookie(key, value string) { h.collectCookies() h.cookies = setArg(h.cookies, key, value, ArgsHasValue) } // SetCookie sets the given response cookie. // It is save re-using the cookie after the function returns. func (h *ResponseHeader) SetCookie(cookie *Cookie) { h.cookies = setArgBytes(h.cookies, cookie.Key(), cookie.Cookie(), ArgsHasValue) } // Cookie returns cookie for the given key. func (h *RequestHeader) Cookie(key string) []byte { h.collectCookies() return peekArgStr(h.cookies, key) } // Cookies returns all the request cookies. // // It's a good idea to call protocol.ReleaseCookie to reduce GC load after the cookie used. func (h *RequestHeader) Cookies() []*Cookie { var cookies []*Cookie h.VisitAllCookie(func(key, value []byte) { cookie := AcquireCookie() cookie.SetKeyBytes(key) cookie.SetValueBytes(value) cookies = append(cookies, cookie) }) return cookies } func (h *RequestHeader) PeekRange() []byte { return h.peek(consts.HeaderRange) } func (h *RequestHeader) PeekContentEncoding() []byte { return h.peek(consts.HeaderContentEncoding) } // FullCookie returns complete cookie bytes func (h *RequestHeader) FullCookie() []byte { return h.Peek(consts.HeaderCookie) } // DelCookie removes cookie under the given key. func (h *RequestHeader) DelCookie(key string) { h.collectCookies() h.cookies = delAllArgs(h.cookies, key) } // DelAllCookies removes all the cookies from request headers. 
func (h *RequestHeader) DelAllCookies() { h.collectCookies() h.cookies = h.cookies[:0] } // VisitAllCookie calls f for each request cookie. // // f must not retain references to key and/or value after returning. func (h *RequestHeader) VisitAllCookie(f func(key, value []byte)) { h.collectCookies() visitArgs(h.cookies, f) } func (h *RequestHeader) collectCookies() { if h.cookiesCollected { return } for i, n := 0, len(h.h); i < n; i++ { kv := &h.h[i] if bytes.Equal(kv.key, bytestr.StrCookie) { h.cookies = parseRequestCookies(h.cookies, kv.value) tmp := *kv copy(h.h[i:], h.h[i+1:]) n-- i-- h.h[n] = tmp h.h = h.h[:n] } } h.cookiesCollected = true } func (h *RequestHeader) SetConnectionClose(close bool) { h.connectionClose = close } // ResetConnectionClose clears 'Connection: close' header if it exists. func (h *RequestHeader) ResetConnectionClose() { if h.connectionClose { h.connectionClose = false h.h = delAllArgsBytes(h.h, bytestr.StrConnection) } } // SetMethod sets HTTP request method. func (h *RequestHeader) SetMethod(method string) { h.method = append(h.method[:0], method...) } // SetRequestURI sets RequestURI for the first HTTP request line. // RequestURI must be properly encoded. // Use URI.RequestURI for constructing proper RequestURI if unsure. func (h *RequestHeader) SetRequestURI(requestURI string) { h.requestURI = append(h.requestURI[:0], requestURI...) } // Set sets the given 'key: value' header. // // Use Add for setting multiple header values under the same key. func (h *RequestHeader) Set(key, value string) { initHeaderKV(&h.bufKV, key, value, h.disableNormalizing) h.SetCanonical(h.bufKV.key, h.bufKV.value) } func initHeaderKV(kv *argsKV, key, value string, disableNormalizing bool) { kv.key = append(kv.key[:0], key...) utils.NormalizeHeaderKey(kv.key, disableNormalizing) kv.value = append(kv.value[:0], value...) } // SetCanonical sets the given 'key: value' header assuming that // key is in canonical form. 
func (h *RequestHeader) SetCanonical(key, value []byte) { if h.setSpecialHeader(key, value) { return } h.h = setArgBytes(h.h, key, value, ArgsHasValue) } func (h *RequestHeader) ResetSkipNormalize() { h.connectionClose = false h.protocol = "" h.noDefaultContentType = false h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] h.method = h.method[:0] h.requestURI = h.requestURI[:0] h.host = h.host[:0] h.contentType = h.contentType[:0] h.userAgent = h.userAgent[:0] h.h = h.h[:0] h.cookies = h.cookies[:0] h.cookiesCollected = false h.rawHeaders = h.rawHeaders[:0] h.mulHeader = h.mulHeader[:0] h.Trailer().ResetSkipNormalize() } func peekRawHeader(buf, key []byte) []byte { n := bytes.Index(buf, key) if n < 0 { return nil } if n > 0 && buf[n-1] != '\n' { return nil } n += len(key) if n >= len(buf) { return nil } if buf[n] != ':' { return nil } n++ if buf[n] != ' ' { return nil } n++ buf = buf[n:] n = bytes.IndexByte(buf, '\n') if n < 0 { return nil } if n > 0 && buf[n-1] == '\r' { n-- } return buf[:n] } // Host returns Host header value. func (h *RequestHeader) Host() []byte { return h.host } // UserAgent returns User-Agent header value. func (h *RequestHeader) UserAgent() []byte { return h.userAgent } // DisableNormalizing disables header names' normalization. // // By default all the header names are normalized by uppercasing // the first letter and all the first letters following dashes, // while lowercasing all the other letters. // Examples: // // - CONNECTION -> Connection // - conteNT-tYPE -> Content-Type // - foo-bar-baz -> Foo-Bar-Baz // // Disable header names' normalization only if you know what are you doing. func (h *RequestHeader) DisableNormalizing() { h.disableNormalizing = true h.Trailer().DisableNormalizing() } func (h *RequestHeader) IsDisableNormalizing() bool { return h.disableNormalizing } // String returns request header representation. 
func (h *RequestHeader) String() string { return string(h.Header()) } // VisitAll calls f for each header. // // f must not retain references to key and/or value after returning. // Copy key and/or value contents before returning if you need retaining them. // // To get the headers in order they were received use VisitAllInOrder. func (h *RequestHeader) VisitAll(f func(key, value []byte)) { host := h.Host() if len(host) > 0 { f(bytestr.StrHost, host) } if len(h.contentLengthBytes) > 0 { f(bytestr.StrContentLength, h.contentLengthBytes) } contentType := h.ContentType() if len(contentType) > 0 { f(bytestr.StrContentType, contentType) } userAgent := h.UserAgent() if len(userAgent) > 0 { f(bytestr.StrUserAgent, userAgent) } if !h.Trailer().Empty() { f(bytestr.StrTrailer, h.Trailer().GetBytes()) } h.collectCookies() if len(h.cookies) > 0 { h.bufKV.value = appendRequestCookieBytes(h.bufKV.value[:0], h.cookies) f(bytestr.StrCookie, h.bufKV.value) } visitArgs(h.h, f) if h.ConnectionClose() { f(bytestr.StrConnection, bytestr.StrClose) } } // VisitAllCustomHeader calls f for each header in header.h which contains all headers // except cookie, host, content-length, content-type, user-agent and connection. // // f must not retain references to key and/or value after returning. // Copy key and/or value contents before returning if you need retaining them. // // To get the headers in order they were received use VisitAllInOrder. func (h *RequestHeader) VisitAllCustomHeader(f func(key, value []byte)) { visitArgs(h.h, f) } func ParseContentLength(b []byte) (int, error) { v, n, err := bytesconv.ParseUintBuf(b) if err != nil { return -1, err } if n != len(b) { return -1, errs.NewPublic("non-numeric chars at the end of Content-Length") } return v, nil } func appendArgBytes(args []argsKV, key, value []byte, noValue bool) []argsKV { var kv *argsKV args, kv = allocArg(args) kv.key = append(kv.key[:0], key...) 
if noValue { kv.value = kv.value[:0] } else { kv.value = append(kv.value[:0], value...) } kv.noValue = noValue return args } func appendArg(args []argsKV, key, value string, noValue bool) []argsKV { var kv *argsKV args, kv = allocArg(args) kv.key = append(kv.key[:0], key...) if noValue { kv.value = kv.value[:0] } else { kv.value = append(kv.value[:0], value...) } kv.noValue = noValue return args } func (h *RequestHeader) peek(key string) []byte { switch key { case consts.HeaderHost: return h.Host() case consts.HeaderContentType: return h.ContentType() case consts.HeaderUserAgent: return h.UserAgent() case consts.HeaderConnection: if h.ConnectionClose() { return bytestr.StrClose } return peekArgStr(h.h, key) case consts.HeaderContentLength: return h.contentLengthBytes case consts.HeaderCookie: if h.cookiesCollected { return appendRequestCookieBytes(nil, h.cookies) } return peekArgStr(h.h, key) case consts.HeaderTrailer: return h.Trailer().GetBytes() default: return peekArgStr(h.h, key) } } func (h *RequestHeader) Get(key string) string { return string(h.Peek(key)) } func (h *ResponseHeader) Get(key string) string { return string(h.Peek(key)) } // GetAll returns all header value for the given key // it is concurrent safety and long lifetime. func (h *RequestHeader) GetAll(key string) []string { headers := h.PeekAll(key) res := make([]string, 0, len(headers)) for _, header := range headers { res = append(res, string(header)) } return res } // GetAll returns all header value for the given key and is concurrent safety. // it is concurrent safety and long lifetime. func (h *ResponseHeader) GetAll(key string) []string { headers := h.PeekAll(key) res := make([]string, 0, len(headers)) for _, header := range headers { res = append(res, string(header)) } return res } func appendHeaderLine(dst, key, value []byte) []byte { for _, k := range key { // if header field contains invalid key, just skip it. 
if bytesconv.ValidHeaderFieldNameTable[k] == 0 {
			return dst
		}
	}
	dst = append(dst, key...)
	dst = append(dst, bytestr.StrColonSpace...)
	dst = appendHeaderValue(dst, value)
	return append(dst, bytestr.StrCRLF...)
}

// appendHeaderValue appends v to dst and returns the extended slice.
//
// Every '\r' or '\n' byte in the appended region (ret[len(dst):]) is replaced
// with ' ', so a header value cannot introduce extra line breaks into the
// serialized header. The first len(dst) bytes of the result are not touched.
func appendHeaderValue(dst, v []byte) []byte {
	ret := append(dst, v...)
	// Re-point v at the freshly appended copy so the sanitization below edits
	// the output buffer.
	v = ret[len(dst):]
	for i, c := range v {
		// '\r' or '\n' -> ' '
		if c == '\r' || c == '\n' {
			v[i] = ' '
		}
	}
	return ret
}

// UpdateServerDate refreshes ServerDate once synchronously and then starts a
// background goroutine that refreshes it again every second.
//
// The goroutine loops forever and has no stop mechanism, and every call starts
// a new one. NOTE(review): presumably meant to be called once per process —
// confirm with callers.
func UpdateServerDate() {
	refreshServerDate()
	go func() {
		for {
			time.Sleep(time.Second)
			refreshServerDate()
		}
	}()
}

// refreshServerDate formats the current time with bytesconv.AppendHTTPDate and
// stores the result in ServerDate.
//
// Each call allocates a new buffer pre-sized to len(http.TimeFormat) instead
// of reusing the previous one. Presumably this is so readers still holding an
// older value never see it overwritten.
func refreshServerDate() {
	b := bytesconv.AppendHTTPDate(make([]byte, 0, len(http.TimeFormat)), time.Now())
	ServerDate.Store(b)
}

// SetMethodBytes sets HTTP request method.
//
// The bytes are copied into h.method (reusing its capacity), so the caller may
// modify method afterwards.
func (h *RequestHeader) SetMethodBytes(method []byte) {
	h.method = append(h.method[:0], method...)
}

// DisableNormalizing disables header names' normalization.
//
// By default all the header names are normalized by uppercasing
// the first letter and all the first letters following dashes,
// while lowercasing all the other letters.
// Examples:
//
//   - CONNECTION -> Connection
//   - conteNT-tYPE -> Content-Type
//   - foo-bar-baz -> Foo-Bar-Baz
//
// Normalization is also disabled on the response trailer (h.Trailer()).
//
// Disable header names' normalization only if you know what you are doing.
func (h *ResponseHeader) DisableNormalizing() {
	h.disableNormalizing = true
	h.Trailer().DisableNormalizing()
}

// setSpecialHeader handles special headers and returns true when a header is processed.
//
// Recognized keys (compared case-insensitively) are Content-Type,
// Content-Length, Content-Encoding, Connection, Server, Set-Cookie,
// Transfer-Encoding, Trailer and Date. It returns false for an empty key or
// for any other key.
func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool {
	if len(key) == 0 {
		return false
	}

	// ORing with 0x20 folds an upper-case ASCII first letter to lower case;
	// the full key is then compared case-insensitively below.
	switch key[0] | 0x20 {
	case 'c':
		if utils.CaseInsensitiveCompare(bytestr.StrContentType, key) {
			h.SetContentTypeBytes(value)
			return true
		} else if utils.CaseInsensitiveCompare(bytestr.StrContentLength, key) {
			// An unparsable value is ignored, but the header is still
			// reported as processed.
			if contentLength, err := ParseContentLength(value); err == nil {
				h.contentLength = contentLength
				h.contentLengthBytes = append(h.contentLengthBytes[:0], value...)
			}
			return true
		} else if utils.CaseInsensitiveCompare(bytestr.StrContentEncoding, key) {
			h.SetContentEncodingBytes(value)
			return true
		} else if utils.CaseInsensitiveCompare(bytestr.StrConnection, key) {
			// Only a byte-exact match with bytestr.StrClose sets the close
			// flag; any other value clears it and is stored verbatim in h.h.
			if bytes.Equal(bytestr.StrClose, value) {
				h.SetConnectionClose(true)
			} else {
				h.ResetConnectionClose()
				h.h = setArgBytes(h.h, key, value, ArgsHasValue)
			}
			return true
		}
	case 's':
		if utils.CaseInsensitiveCompare(bytestr.StrServer, key) {
			h.SetServerBytes(value)
			return true
		} else if utils.CaseInsensitiveCompare(bytestr.StrSetCookie, key) {
			// Each Set-Cookie header gets its own h.cookies entry whose key is
			// derived from the cookie value by getCookieKey.
			var kv *argsKV
			h.cookies, kv = allocArg(h.cookies)
			kv.key = getCookieKey(kv.key, value)
			kv.value = append(kv.value[:0], value...)
			return true
		}
	case 't':
		if utils.CaseInsensitiveCompare(bytestr.StrTransferEncoding, key) {
			// Transfer-Encoding is managed automatically.
			return true
		} else if utils.CaseInsensitiveCompare(bytestr.StrTrailer, key) {
			// copy value to avoid panic
			// NOTE(review): presumably SetTrailers rewrites its argument in
			// place, which would fault on read-only bytes backed by a string.
			// The copy reuses h.bufKV.value's backing array, but the result is
			// only kept in the local value. The error returned by SetTrailers
			// is discarded.
			value = append(h.bufKV.value[:0], value...)
			h.Trailer().SetTrailers(value)
			return true
		}
	case 'd':
		if utils.CaseInsensitiveCompare(bytestr.StrDate, key) {
			// Date is managed automatically.
			return true
		}
	}

	return false
}

// setSpecialHeader handles special headers and returns true when a header is processed.
//
// Recognized keys (compared case-insensitively) are Content-Type,
// Content-Length, Connection, Cookie, Transfer-Encoding, Trailer, Host and
// User-Agent. It returns false for an empty key or for any other key.
func (h *RequestHeader) setSpecialHeader(key, value []byte) bool {
	if len(key) == 0 {
		return false
	}

	// ORing with 0x20 folds an upper-case ASCII first letter to lower case;
	// the full key is then compared case-insensitively below.
	switch key[0] | 0x20 {
	case 'c':
		if utils.CaseInsensitiveCompare(bytestr.StrContentType, key) {
			h.SetContentTypeBytes(value)
			return true
		} else if utils.CaseInsensitiveCompare(bytestr.StrContentLength, key) {
			// An unparsable value is ignored, but the header is still
			// reported as processed.
			if contentLength, err := ParseContentLength(value); err == nil {
				h.contentLength = contentLength
				h.contentLengthBytes = append(h.contentLengthBytes[:0], value...)
			}
			return true
		} else if utils.CaseInsensitiveCompare(bytestr.StrConnection, key) {
			// Only a byte-exact match with bytestr.StrClose sets the close
			// flag; any other value clears it and is stored verbatim in h.h.
			if bytes.Equal(bytestr.StrClose, value) {
				h.SetConnectionClose(true)
			} else {
				h.ResetConnectionClose()
				h.h = setArgBytes(h.h, key, value, ArgsHasValue)
			}
			return true
		} else if utils.CaseInsensitiveCompare(bytestr.StrCookie, key) {
			h.collectCookies()
			h.cookies = parseRequestCookies(h.cookies, value)
			return true
		}
	case 't':
		if utils.CaseInsensitiveCompare(bytestr.StrTransferEncoding, key) {
			// Transfer-Encoding is managed automatically.
			return true
		} else if utils.CaseInsensitiveCompare(bytestr.StrTrailer, key) {
			// copy value to avoid panic
			// NOTE(review): same copy-into-bufKV pattern as the response
			// variant; the SetTrailers error is discarded.
			value = append(h.bufKV.value[:0], value...)
			h.Trailer().SetTrailers(value)
			return true
		}
	case 'h':
		if utils.CaseInsensitiveCompare(bytestr.StrHost, key) {
			h.SetHostBytes(value)
			return true
		}
	case 'u':
		if utils.CaseInsensitiveCompare(bytestr.StrUserAgent, key) {
			h.SetUserAgentBytes(value)
			return true
		}
	}

	return false
}

// Trailer returns the Trailer of HTTP Header.
//
// The Trailer is allocated lazily on first access.
func (h *ResponseHeader) Trailer() *Trailer {
	if h.trailer == nil {
		h.trailer = new(Trailer)
	}
	return h.trailer
}

// Trailer returns the Trailer of HTTP Header.
//
// The Trailer is allocated lazily on first access.
func (h *RequestHeader) Trailer() *Trailer {
	if h.trailer == nil {
		h.trailer = new(Trailer)
	}
	return h.trailer
}

// SetProtocol stores the protocol string (for example consts.HTTP11) on the
// response header.
func (h *ResponseHeader) SetProtocol(p string) {
	h.protocol = p
}

// GetProtocol returns the protocol string previously stored by SetProtocol.
func (h *ResponseHeader) GetProtocol() string {
	return h.protocol
}

================================================
FILE: pkg/protocol/header_test.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package protocol import ( "bytes" "fmt" "net/http" "strings" "testing" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func TestRequestHeaderSetRawHeaders(t *testing.T) { h := RequestHeader{} h.SetRawHeaders([]byte("foo")) assert.DeepEqual(t, h.rawHeaders, []byte("foo")) } func TestResponseHeaderSetHeaderLength(t *testing.T) { h := ResponseHeader{} h.SetHeaderLength(15) assert.DeepEqual(t, h.headerLength, 15) assert.DeepEqual(t, h.GetHeaderLength(), 15) } func TestSetNoHTTP11(t *testing.T) { rh := ResponseHeader{} rh.SetProtocol(consts.HTTP10) assert.DeepEqual(t, consts.HTTP10, rh.protocol) rh.SetProtocol(consts.HTTP11) assert.DeepEqual(t, consts.HTTP11, rh.protocol) assert.True(t, rh.IsHTTP11()) h := RequestHeader{} h.SetProtocol(consts.HTTP10) assert.DeepEqual(t, consts.HTTP10, h.protocol) h.SetProtocol(consts.HTTP11) assert.DeepEqual(t, consts.HTTP11, h.protocol) assert.True(t, h.IsHTTP11()) } func TestResponseHeaderSetContentType(t *testing.T) { h := ResponseHeader{} h.SetContentType("foo") assert.DeepEqual(t, h.contentType, []byte("foo")) } func TestSetContentLengthBytes(t *testing.T) { h := RequestHeader{} h.SetContentLengthBytes([]byte("foo")) assert.DeepEqual(t, h.contentLengthBytes, []byte("foo")) rh := ResponseHeader{} rh.SetContentLengthBytes([]byte("foo")) assert.DeepEqual(t, rh.contentLengthBytes, []byte("foo")) } func TestInitContentLengthWithValue(t *testing.T) { initLength := 100 h := RequestHeader{} h.InitContentLengthWithValue(initLength) assert.DeepEqual(t, h.contentLength, initLength) rh := ResponseHeader{} rh.InitContentLengthWithValue(initLength) assert.DeepEqual(t, rh.contentLength, initLength) } func TestSetContentEncoding(t *testing.T) { rh := ResponseHeader{} rh.SetContentEncoding("gzip") assert.DeepEqual(t, rh.contentEncoding, []byte("gzip")) } func Test_peekRawHeader(t *testing.T) { s := 
"Expect: 100-continue\r\nUser-Agent: foo\r\nHost: 127.0.0.1\r\nConnection: Keep-Alive\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" assert.DeepEqual(t, []byte("127.0.0.1"), peekRawHeader([]byte(s), []byte("Host"))) } func TestResponseHeader_SetContentLength(t *testing.T) { rh := new(ResponseHeader) rh.SetContentLength(-1) assert.True(t, strings.Contains(string(rh.Header()), "Transfer-Encoding: chunked")) rh.SetContentLength(-2) assert.True(t, strings.Contains(string(rh.Header()), "Transfer-Encoding: identity")) } func TestResponseHeader_SetContentRange(t *testing.T) { rh := new(ResponseHeader) rh.SetContentRange(1, 5, 10) assert.DeepEqual(t, rh.bufKV.value, []byte("bytes 1-5/10")) } func TestSetCanonical(t *testing.T) { h := ResponseHeader{} h.SetCanonical([]byte(consts.HeaderContentType), []byte("foo")) h.SetCanonical([]byte(consts.HeaderServer), []byte("foo1")) h.SetCanonical([]byte(consts.HeaderSetCookie), []byte("foo2")) h.SetCanonical([]byte(consts.HeaderContentLength), []byte("3")) h.SetCanonical([]byte(consts.HeaderConnection), []byte("foo4")) h.SetCanonical([]byte(consts.HeaderTransferEncoding), []byte("foo5")) h.SetCanonical([]byte(consts.HeaderTrailer), []byte("foo7")) h.SetCanonical([]byte("bar"), []byte("foo6")) assert.DeepEqual(t, []byte("foo"), h.ContentType()) assert.DeepEqual(t, []byte("foo1"), h.Server()) assert.DeepEqual(t, true, strings.Contains(string(h.Header()), "foo2")) assert.DeepEqual(t, 3, h.ContentLength()) assert.DeepEqual(t, false, h.ConnectionClose()) assert.DeepEqual(t, false, strings.Contains(string(h.ContentType()), "foo5")) assert.DeepEqual(t, true, strings.Contains(string(h.Header()), "Trailer: Foo7")) assert.DeepEqual(t, true, strings.Contains(string(h.Header()), "bar: foo6")) } func TestHasAcceptEncodingBytes(t *testing.T) { h := RequestHeader{} h.Set(consts.HeaderAcceptEncoding, "gzip") assert.True(t, h.HasAcceptEncodingBytes([]byte("gzip"))) } func TestRequestHeaderGet(t *testing.T) { h := RequestHeader{} 
rightVal := "yyy" h.Set("xxx", rightVal) val := h.Get("xxx") if val != rightVal { t.Fatalf("Unexpected %v. Expected %v", val, rightVal) } } func TestResponseHeaderGet(t *testing.T) { h := ResponseHeader{} rightVal := "yyy" h.Set("xxx", rightVal) val := h.Get("xxx") assert.DeepEqual(t, val, rightVal) } func TestRequestHeaderGetAll(t *testing.T) { h := RequestHeader{} h.Set("Foo-Bar", "foo") h.Add("Foo-Bar", "bar") h.Add("Foo-Bar", "foo-bar") values := h.GetAll("Foo-Bar") assert.DeepEqual(t, values, []string{"foo", "bar", "foo-bar"}) } func TestResponseHeaderGetAll(t *testing.T) { h := ResponseHeader{} h.Set("Foo-Bar", "foo") h.Add("Foo-Bar", "bar") h.Add("Foo-Bar", "foo-bar") values := h.GetAll("Foo-Bar") assert.DeepEqual(t, values, []string{"foo", "bar", "foo-bar"}) } func TestRequestHeaderVisitAll(t *testing.T) { h := RequestHeader{} h.Set("xxx", "yyy") h.Set("xxx2", "yyy2") h.SetHost("host") h.SetContentLengthBytes([]byte("content-length")) h.Set(consts.HeaderContentType, "content-type") h.Set(consts.HeaderUserAgent, "user-agent") err := h.Trailer().SetTrailers([]byte("foo, bar")) if err != nil { t.Fatalf("Set trailer err %v", err) } h.SetCookie("foo", "bar") h.Set(consts.HeaderConnection, "close") h.VisitAll(func(k, v []byte) { key := string(k) value := string(v) switch key { case consts.HeaderHost: assert.DeepEqual(t, value, "host") case consts.HeaderContentLength: assert.DeepEqual(t, value, "content-length") case consts.HeaderContentType: assert.DeepEqual(t, value, "content-type") case consts.HeaderUserAgent: assert.DeepEqual(t, value, "user-agent") case consts.HeaderTrailer: assert.DeepEqual(t, value, "Foo, Bar") case consts.HeaderCookie: assert.DeepEqual(t, value, "foo=bar") case consts.HeaderConnection: assert.DeepEqual(t, value, "close") case "Xxx": assert.DeepEqual(t, value, "yyy") case "Xxx2": assert.DeepEqual(t, value, "yyy2") default: t.Fatalf("Unexpected key %v", key) } }) } func TestRequestHeaderCookie(t *testing.T) { var h RequestHeader 
h.SetCookie("foo", "bar") cookie := h.Cookie("foo") assert.DeepEqual(t, []byte("bar"), cookie) } func TestRequestHeaderCookies(t *testing.T) { var h RequestHeader h.SetCookie("foo", "bar") h.SetCookie("привет", "мир") cookies := h.Cookies() assert.DeepEqual(t, 2, len(cookies)) assert.DeepEqual(t, []byte("foo"), cookies[0].Key()) assert.DeepEqual(t, []byte("bar"), cookies[0].Value()) assert.DeepEqual(t, []byte("привет"), cookies[1].Key()) assert.DeepEqual(t, []byte("мир"), cookies[1].Value()) } func TestRequestHeaderDel(t *testing.T) { t.Parallel() var h RequestHeader h.Set("Foo-Bar", "baz") h.Set("aaa", "bbb") h.Set("ccc", "ddd") h.Set(consts.HeaderConnection, "keep-alive") h.Set(consts.HeaderContentType, "aaa") h.Set(consts.HeaderServer, "aaabbb") h.Set(consts.HeaderContentLength, "1123") h.Set(consts.HeaderTrailer, "foo, bar") h.Set(consts.HeaderUserAgent, "foo-bar") h.SetHost("foobar") h.SetCookie("foo", "bar") h.del([]byte("Foo-Bar")) h.del([]byte("Connection")) h.DelBytes([]byte("Content-Type")) h.del([]byte(consts.HeaderServer)) h.del([]byte("Content-Length")) h.del([]byte("Set-Cookie")) h.del([]byte("Host")) h.del([]byte(consts.HeaderTrailer)) h.del([]byte(consts.HeaderUserAgent)) h.DelCookie("foo") h.Del("ccc") hv := h.Peek("aaa") if string(hv) != "bbb" { t.Fatalf("unexpected header value: %q. Expecting %q", hv, "bbb") } hv = h.Peek("ccc") if string(hv) != "" { t.Fatalf("unexpected header value: %q. 
Expecting %q", hv, "") } hv = h.Peek("Foo-Bar") if len(hv) > 0 { t.Fatalf("non-zero header value: %q", hv) } hv = h.Peek(consts.HeaderConnection) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek(consts.HeaderContentType) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek(consts.HeaderServer) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek(consts.HeaderContentLength) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.FullCookie() if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek(consts.HeaderCookie) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek(consts.HeaderTrailer) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek(consts.HeaderUserAgent) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } if h.ContentLength() != 0 { t.Fatalf("unexpected content-length: %d. Expecting 0", h.ContentLength()) } } func TestResponseHeaderDel(t *testing.T) { t.Parallel() var h ResponseHeader h.Set("Foo-Bar", "baz") h.Set("aaa", "bbb") h.Set(consts.HeaderConnection, "keep-alive") h.Set(consts.HeaderContentType, "aaa") h.Set(consts.HeaderContentEncoding, "gzip") h.Set(consts.HeaderServer, "aaabbb") h.Set(consts.HeaderContentLength, "1123") h.Set(consts.HeaderTrailer, "foo, bar") var c Cookie c.SetKey("foo") c.SetValue("bar") h.SetCookie(&c) h.Del("foo-bar") h.Del("connection") h.DelBytes([]byte("content-type")) h.Del(consts.HeaderServer) h.Del("content-length") h.Del("set-cookie") h.Del("content-encoding") h.Del(consts.HeaderTrailer) hv := h.Peek("aaa") if string(hv) != "bbb" { t.Fatalf("unexpected header value: %q. Expecting %q", hv, "bbb") } hv = h.Peek("Foo-Bar") if len(hv) > 0 { t.Fatalf("non-zero header value: %q", hv) } hv = h.Peek(consts.HeaderConnection) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek(consts.HeaderContentType) if string(hv) != string(bytestr.DefaultContentType) { t.Fatalf("unexpected content-type: %q. 
Expecting %q", hv, bytestr.DefaultContentType) } hv = h.Peek(consts.HeaderContentEncoding) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek(consts.HeaderServer) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek(consts.HeaderContentLength) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } hv = h.Peek(consts.HeaderTrailer) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } if h.Cookie(&c) { t.Fatalf("unexpected cookie obtained: %v", &c) } if h.ContentLength() != 0 { t.Fatalf("unexpected content-length: %d. Expecting 0", h.ContentLength()) } } func TestResponseHeaderDelClientCookie(t *testing.T) { t.Parallel() cookieName := "foobar" var h ResponseHeader c := AcquireCookie() c.SetKey(cookieName) c.SetValue("aasdfsdaf") h.SetCookie(c) h.DelClientCookieBytes([]byte(cookieName)) if !h.Cookie(c) { t.Fatalf("expecting cookie %q", c.Key()) } if !c.Expire().Equal(CookieExpireDelete) { t.Fatalf("unexpected cookie expiration time: %s. Expecting %s", c.Expire(), CookieExpireDelete) } if len(c.Value()) > 0 { t.Fatalf("unexpected cookie value: %q. 
Expecting empty value", c.Value()) } ReleaseCookie(c) } func TestResponseHeaderResetConnectionClose(t *testing.T) { h := ResponseHeader{} h.Set(consts.HeaderConnection, "close") hv := h.Peek(consts.HeaderConnection) assert.DeepEqual(t, hv, []byte("close")) h.SetConnectionClose(true) h.ResetConnectionClose() assert.False(t, h.connectionClose) hv = h.Peek(consts.HeaderConnection) if len(hv) > 0 { t.Fatalf("ResetConnectionClose do not work,Connection: %q", hv) } } func TestRequestHeaderResetConnectionClose(t *testing.T) { h := RequestHeader{} h.Set(consts.HeaderConnection, "close") hv := h.Peek(consts.HeaderConnection) assert.DeepEqual(t, hv, []byte("close")) h.connectionClose = true h.ResetConnectionClose() assert.False(t, h.connectionClose) hv = h.Peek(consts.HeaderConnection) if len(hv) > 0 { t.Fatalf("ResetConnectionClose do not work,Connection: %q", hv) } } func TestCheckWriteHeaderCode(t *testing.T) { buffer := bytes.NewBuffer(make([]byte, 0, 1024)) hlog.SetOutput(buffer) checkWriteHeaderCode(99) assert.True(t, strings.Contains(buffer.String(), "[Warn] HERTZ: Invalid StatusCode code")) buffer.Reset() checkWriteHeaderCode(600) assert.True(t, strings.Contains(buffer.String(), "[Warn] HERTZ: Invalid StatusCode code")) buffer.Reset() checkWriteHeaderCode(100) assert.False(t, strings.Contains(buffer.String(), "[Warn] HERTZ: Invalid StatusCode code")) buffer.Reset() checkWriteHeaderCode(599) assert.False(t, strings.Contains(buffer.String(), "[Warn] HERTZ: Invalid StatusCode code")) } func TestResponseHeaderAdd(t *testing.T) { t.Parallel() m := make(map[string]struct{}) var h ResponseHeader h.Add("aaa", "bbb") h.Add("content-type", "xxx") h.SetContentEncoding("gzip") m["bbb"] = struct{}{} m["xxx"] = struct{}{} m["gzip"] = struct{}{} for i := 0; i < 10; i++ { v := fmt.Sprintf("%d", i) h.Add("Foo-Bar", v) m[v] = struct{}{} } if h.Len() != 13 { t.Fatalf("unexpected header len %d. 
Expecting 13", h.Len()) } h.VisitAll(func(k, v []byte) { switch string(k) { case "Aaa", "Foo-Bar", "Content-Type", "Content-Encoding": if _, ok := m[string(v)]; !ok { t.Fatalf("unexpected value found %q. key %q", v, k) } delete(m, string(v)) default: t.Fatalf("unexpected key found: %q", k) } }) if len(m) > 0 { t.Fatalf("%d headers are missed", len(m)) } } func TestRequestHeaderAdd(t *testing.T) { t.Parallel() m := make(map[string]struct{}) var h RequestHeader h.Add("aaa", "bbb") h.Add("user-agent", "xxx") m["bbb"] = struct{}{} m["xxx"] = struct{}{} for i := 0; i < 10; i++ { v := fmt.Sprintf("%d", i) h.Add("Foo-Bar", v) m[v] = struct{}{} } if h.Len() != 12 { t.Fatalf("unexpected header len %d. Expecting 12", h.Len()) } h.VisitAll(func(k, v []byte) { switch string(k) { case "Aaa", "Foo-Bar", "User-Agent": if _, ok := m[string(v)]; !ok { t.Fatalf("unexpected value found %q. key %q", v, k) } delete(m, string(v)) default: t.Fatalf("unexpected key found: %q", k) } }) if len(m) > 0 { t.Fatalf("%d headers are missed", len(m)) } } func TestResponseHeaderAddContentType(t *testing.T) { t.Parallel() var h ResponseHeader h.Add("Content-Type", "test") got := string(h.Peek("Content-Type")) expected := "test" if got != expected { t.Errorf("expected %q got %q", expected, got) } if n := strings.Count(string(h.Header()), "Content-Type: "); n != 1 { t.Errorf("Content-Type occurred %d times", n) } } func TestResponseHeaderAddContentEncoding(t *testing.T) { t.Parallel() var h ResponseHeader h.Add("Content-Encoding", "test") got := string(h.ContentEncoding()) expected := "test" if got != expected { t.Errorf("expected %q got %q", expected, got) } if n := strings.Count(string(h.Header()), "Content-Encoding: "); n != 1 { t.Errorf("Content-Encoding occurred %d times", n) } } func TestRequestHeaderAddContentType(t *testing.T) { t.Parallel() var h RequestHeader h.Add("Content-Type", "test") got := string(h.Peek("Content-Type")) expected := "test" if got != expected { t.Errorf("expected %q got 
%q", expected, got) } if n := strings.Count(h.String(), "Content-Type: "); n != 1 { t.Errorf("Content-Type occurred %d times", n) } } func TestSetMultipartFormBoundary(t *testing.T) { h := RequestHeader{} h.SetMultipartFormBoundary("foo") assert.DeepEqual(t, h.contentType, []byte("multipart/form-data; boundary=foo")) } func TestRequestHeaderSetByteRange(t *testing.T) { var h RequestHeader h.SetByteRange(1, 5) hv := h.Peek(consts.HeaderRange) assert.DeepEqual(t, hv, []byte("bytes=1-5")) } func TestRequestHeaderSetMethodBytes(t *testing.T) { var h RequestHeader h.SetMethodBytes([]byte("foo")) assert.DeepEqual(t, h.Method(), []byte("foo")) } func TestRequestHeaderSetBytesKV(t *testing.T) { var h RequestHeader h.SetBytesKV([]byte("foo"), []byte("foo1")) hv := h.Peek("foo") assert.DeepEqual(t, hv, []byte("foo1")) } func TestResponseHeaderSetBytesV(t *testing.T) { var h ResponseHeader h.SetBytesV("foo", []byte("foo1")) hv := h.Peek("foo") assert.DeepEqual(t, hv, []byte("foo1")) } func TestRequestHeaderInitBufValue(t *testing.T) { var h RequestHeader slice := make([]byte, 0, 10) h.InitBufValue(10) assert.DeepEqual(t, cap(h.bufKV.value), cap(slice)) assert.DeepEqual(t, h.GetBufValue(), slice) } func TestRequestHeaderDelAllCookies(t *testing.T) { var h RequestHeader h.SetCanonical([]byte(consts.HeaderSetCookie), []byte("foo2")) h.DelAllCookies() hv := h.FullCookie() if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } } func TestResponseHeaderDelAllCookies(t *testing.T) { var h ResponseHeader h.SetCanonical([]byte(consts.HeaderSetCookie), []byte("foo")) h.DelAllCookies() hv := h.FullCookie() if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } } func TestRequestHeaderSetNoDefaultContentType(t *testing.T) { var h RequestHeader h.SetMethod(http.MethodPost) b := h.AppendBytes(nil) assert.DeepEqual(t, b, []byte("POST / HTTP/1.1\r\nContent-Type: application/x-www-form-urlencoded\r\n\r\n")) h.SetNoDefaultContentType(true) b = h.AppendBytes(nil) assert.DeepEqual(t, b, 
[]byte("POST / HTTP/1.1\r\n\r\n")) } func TestRequestHeader_PeekAll(t *testing.T) { t.Parallel() h := &RequestHeader{} h.Add(consts.HeaderConnection, "keep-alive") h.Add("Content-Type", "aaa") h.Add(consts.HeaderHost, "aaabbb") h.Add("User-Agent", "asdfas") h.Add("Content-Length", "1123") h.Add("Cookie", "foobar=baz") h.Add("aaa", "aaa") h.Add("aaa", "bbb") expectRequestHeaderAll(t, h, consts.HeaderConnection, [][]byte{[]byte("keep-alive")}) expectRequestHeaderAll(t, h, "Content-Type", [][]byte{[]byte("aaa")}) expectRequestHeaderAll(t, h, consts.HeaderHost, [][]byte{[]byte("aaabbb")}) expectRequestHeaderAll(t, h, "User-Agent", [][]byte{[]byte("asdfas")}) expectRequestHeaderAll(t, h, "Content-Length", [][]byte{[]byte("1123")}) expectRequestHeaderAll(t, h, "Cookie", [][]byte{[]byte("foobar=baz")}) expectRequestHeaderAll(t, h, "aaa", [][]byte{[]byte("aaa"), []byte("bbb")}) h.DelBytes([]byte("Content-Type")) h.DelBytes([]byte((consts.HeaderHost))) h.DelBytes([]byte("aaa")) expectRequestHeaderAll(t, h, "Content-Type", [][]byte{}) expectRequestHeaderAll(t, h, consts.HeaderHost, [][]byte{}) expectRequestHeaderAll(t, h, "aaa", [][]byte{}) } func expectRequestHeaderAll(t *testing.T, h *RequestHeader, key string, expectedValue [][]byte) { if len(h.PeekAll(key)) != len(expectedValue) { t.Fatalf("Unexpected size for key %q: %d. 
Expected %d", key, len(h.PeekAll(key)), len(expectedValue)) } assert.DeepEqual(t, h.PeekAll(key), expectedValue) } func TestResponseHeader_PeekAll(t *testing.T) { t.Parallel() h := &ResponseHeader{} h.Add(consts.HeaderContentType, "aaa/bbb") h.Add(consts.HeaderContentEncoding, "gzip") h.Add(consts.HeaderConnection, "close") h.Add(consts.HeaderContentLength, "1234") h.Add(consts.HeaderServer, "aaaa") h.Add(consts.HeaderSetCookie, "cccc") h.Add("aaa", "aaa") h.Add("aaa", "bbb") expectResponseHeaderAll(t, h, consts.HeaderContentType, [][]byte{[]byte("aaa/bbb")}) expectResponseHeaderAll(t, h, consts.HeaderContentEncoding, [][]byte{[]byte("gzip")}) expectResponseHeaderAll(t, h, consts.HeaderConnection, [][]byte{[]byte("close")}) expectResponseHeaderAll(t, h, consts.HeaderContentLength, [][]byte{[]byte("1234")}) expectResponseHeaderAll(t, h, consts.HeaderServer, [][]byte{[]byte("aaaa")}) expectResponseHeaderAll(t, h, consts.HeaderSetCookie, [][]byte{[]byte("cccc")}) expectResponseHeaderAll(t, h, "aaa", [][]byte{[]byte("aaa"), []byte("bbb")}) h.Del(consts.HeaderContentType) h.Del(consts.HeaderContentEncoding) expectResponseHeaderAll(t, h, consts.HeaderContentType, [][]byte{bytestr.DefaultContentType}) expectResponseHeaderAll(t, h, consts.HeaderContentEncoding, [][]byte{}) } func expectResponseHeaderAll(t *testing.T, h *ResponseHeader, key string, expectedValue [][]byte) { if len(h.PeekAll(key)) != len(expectedValue) { t.Fatalf("Unexpected size for key %q: %d. 
Expected %d", key, len(h.PeekAll(key)), len(expectedValue)) } assert.DeepEqual(t, h.PeekAll(key), expectedValue) } func TestRequestHeaderCopyTo(t *testing.T) { t.Parallel() h, hCopy := &RequestHeader{}, &RequestHeader{} h.SetProtocol(consts.HTTP10) h.SetMethod(consts.MethodPatch) h.SetNoDefaultContentType(true) h.Add(consts.HeaderConnection, "keep-alive") h.Add("Content-Type", "aaa") h.Add(consts.HeaderHost, "aaabbb") h.Add("User-Agent", "asdfas") h.Add("Content-Length", "1123") h.Add("Cookie", "foobar=baz") h.Add("aaa", "aaa") h.Add("aaa", "bbb") h.CopyTo(hCopy) expectRequestHeaderAll(t, hCopy, consts.HeaderConnection, [][]byte{[]byte("keep-alive")}) expectRequestHeaderAll(t, hCopy, "Content-Type", [][]byte{[]byte("aaa")}) expectRequestHeaderAll(t, hCopy, consts.HeaderHost, [][]byte{[]byte("aaabbb")}) expectRequestHeaderAll(t, hCopy, "User-Agent", [][]byte{[]byte("asdfas")}) expectRequestHeaderAll(t, hCopy, "Content-Length", [][]byte{[]byte("1123")}) expectRequestHeaderAll(t, hCopy, "Cookie", [][]byte{[]byte("foobar=baz")}) expectRequestHeaderAll(t, hCopy, "aaa", [][]byte{[]byte("aaa"), []byte("bbb")}) assert.DeepEqual(t, hCopy.GetProtocol(), consts.HTTP10) assert.DeepEqual(t, hCopy.noDefaultContentType, true) assert.DeepEqual(t, string(hCopy.Method()), consts.MethodPatch) } func TestResponseHeaderCopyTo(t *testing.T) { t.Parallel() h, hCopy := &ResponseHeader{}, &ResponseHeader{} h.SetProtocol(consts.HTTP10) h.SetHeaderLength(100) h.SetNoDefaultContentType(true) h.Add(consts.HeaderContentType, "aaa/bbb") h.Add(consts.HeaderContentEncoding, "gzip") h.Add(consts.HeaderConnection, "close") h.Add(consts.HeaderContentLength, "1234") h.Add(consts.HeaderServer, "aaaa") h.Add(consts.HeaderSetCookie, "cccc") h.Add("aaa", "aaa") h.Add("aaa", "bbb") h.CopyTo(hCopy) expectResponseHeaderAll(t, hCopy, consts.HeaderContentType, [][]byte{[]byte("aaa/bbb")}) expectResponseHeaderAll(t, hCopy, consts.HeaderContentEncoding, [][]byte{[]byte("gzip")}) expectResponseHeaderAll(t, hCopy, 
consts.HeaderConnection, [][]byte{[]byte("close")}) expectResponseHeaderAll(t, hCopy, consts.HeaderContentLength, [][]byte{[]byte("1234")}) expectResponseHeaderAll(t, hCopy, consts.HeaderServer, [][]byte{[]byte("aaaa")}) expectResponseHeaderAll(t, hCopy, consts.HeaderSetCookie, [][]byte{[]byte("cccc")}) expectResponseHeaderAll(t, hCopy, "aaa", [][]byte{[]byte("aaa"), []byte("bbb")}) assert.DeepEqual(t, hCopy.GetProtocol(), consts.HTTP10) assert.DeepEqual(t, hCopy.noDefaultContentType, true) assert.DeepEqual(t, hCopy.GetHeaderLength(), 100) } func TestResponseHeaderDateEmpty(t *testing.T) { t.Parallel() var h ResponseHeader h.noDefaultDate = true headers := string(h.Header()) if strings.Contains(headers, "\r\nDate: ") { t.Fatalf("ResponseDateNoDefaultNotEmpty fail, response: \n%+v\noutcome: \n%q\n", h, headers) //nolint:govet } } func TestSetTrailerWithROString(t *testing.T) { h := &RequestHeader{} h.Add(consts.HeaderTrailer, "foo,bar,hertz") assert.DeepEqual(t, "Foo, Bar, Hertz", h.Get(consts.HeaderTrailer)) h1 := &ResponseHeader{} h1.Add(consts.HeaderTrailer, "foo,bar,hertz") assert.DeepEqual(t, "Foo, Bar, Hertz", h1.Get(consts.HeaderTrailer)) } func Benchmark_RequestHeader_Peek(b *testing.B) { h := &RequestHeader{} h.Add("hello", "world") if s := string(h.Peek("hello")); s != "world" { b.Fatal(s) } b.ResetTimer() for i := 0; i < b.N; i++ { h.Peek("hello") } } func TestAppendHeaderLine(t *testing.T) { tests := []struct { name string dst []byte key []byte value []byte expected []byte }{ { name: "basic header", dst: []byte{}, key: []byte("Content-Type"), value: []byte("application/json"), expected: []byte("Content-Type: application/json\r\n"), }, { name: "value with newlines", dst: []byte{}, key: []byte("X-Custom"), value: []byte("value\nwith\rnewlines"), expected: []byte("X-Custom: value with newlines\r\n"), }, { name: "invalid key", dst: []byte("initial"), key: []byte("Invalid\x00Key"), value: []byte("value"), expected: []byte("initial"), }, } for _, tt := range 
tests { t.Run(tt.name, func(t *testing.T) { result := appendHeaderLine(tt.dst, tt.key, tt.value) if !bytes.Equal(result, tt.expected) { t.Errorf("appendHeaderLine() = %q, want %q", result, tt.expected) } }) } } ================================================ FILE: pkg/protocol/header_timing_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package protocol import ( "net/http" "strconv" "testing" ) func BenchmarkHTTPHeaderGet(b *testing.B) { hh := make(http.Header) hh.Set("X-tt-logid", "abc123456789") b.ResetTimer() for i := 0; i < b.N; i++ { hh.Get("X-tt-logid") } } func BenchmarkHertzHeaderGet(b *testing.B) { zh := new(ResponseHeader) zh.Set("X-tt-logid", "abc123456789") b.ResetTimer() for i := 0; i < b.N; i++ { zh.Get("X-tt-logid") } } func BenchmarkHTTPHeaderSet(b *testing.B) { hh := make(http.Header) b.ResetTimer() for i := 0; i < b.N; i++ { hh.Set("X-tt-logid", "abc123456789") } } func BenchmarkHertzHeaderSet(b *testing.B) { zh := new(ResponseHeader) b.ResetTimer() for i := 0; i < b.N; i++ { zh.Set("X-tt-logid", "abc123456789") } } func BenchmarkHTTPHeaderAdd(b *testing.B) { hh := make(http.Header) b.ResetTimer() for i := 0; i < b.N; i++ { hh.Add("X-tt-"+strconv.Itoa(i), "abc123456789") } } func BenchmarkHertzHeaderAdd(b *testing.B) { zh := new(ResponseHeader) b.ResetTimer() for i := 0; i < b.N; i++ { zh.Add("X-tt-"+strconv.Itoa(i), "abc123456789") } } func BenchmarkRefreshServerDate(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { refreshServerDate() } } func BenchmarkHeaderAppendBytes(b *testing.B) { h := new(ResponseHeader) h.Set("X-tt-logid", "abc123456789") h.SetServerBytes([]byte("hertz")) buf := make([]byte, 0, 1024) b.ResetTimer() for i := 0; i < b.N; i++ { _ = h.AppendBytes(buf) } } ================================================ FILE: pkg/protocol/http1/client.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under 
the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package http1 import ( "bytes" "context" "crypto/tls" "errors" "io" "net" "runtime" "strings" "sync" "sync/atomic" "syscall" "time" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/internal/nocopy" "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/timer" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/dialer" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/client" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/proxy" reqI "github.com/cloudwego/hertz/pkg/protocol/http1/req" respI "github.com/cloudwego/hertz/pkg/protocol/http1/resp" ) var ( errConnectionClosed = errs.NewPublic("the server closed connection before returning the first response byte. " + "Make sure the server returns 'Connection: close' response header before closing the connection") errTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "host client") ) // HostClient balances http requests among hosts listed in Addr. // // HostClient may be used for balancing load among multiple upstream hosts. // While multiple addresses passed to HostClient.Addr may be used for balancing // load among them, it would be better using LBClient instead, since HostClient // may unevenly balance load among upstream hosts. // // It is forbidden copying HostClient instances. Create new instances instead. // // It is safe calling HostClient methods from concurrently running goroutines. type HostClient struct { noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used *ClientOptions // Comma-separated list of upstream HTTP server host addresses, // which are passed to Dialer in a round-robin manner. 
// // Each address may contain port if default dialer is used. // For example, // // - foobar.com:80 // - foobar.com:443 // - foobar.com:8080 Addr string IsTLS bool ProxyURI *protocol.URI clientName atomic.Value lastUseTime uint32 connsLock sync.Mutex connsCount int conns []*clientConn connsWait *wantConnQueue addrsLock sync.Mutex addrs []string addrIdx uint32 tlsConfigMap map[string]*tls.Config tlsConfigMapLock sync.Mutex pendingRequests int32 connsCleanerRun bool closed chan struct{} } func (c *HostClient) SetDynamicConfig(dc *client.DynamicConfig) { c.Addr = dc.Addr c.ProxyURI = dc.ProxyURI c.IsTLS = dc.IsTLS // start observation after setting addr to avoid race if c.StateObserve != nil { go func() { t := time.NewTicker(c.ObservationInterval) for { select { case <-c.closed: return case <-t.C: c.StateObserve(c) } } }() } } type clientConn struct { c network.Conn createdTime time.Time lastUseTime time.Time } var startTimeUnix = time.Now().Unix() // LastUseTime returns time the client was last used func (c *HostClient) LastUseTime() time.Time { n := atomic.LoadUint32(&c.lastUseTime) return time.Unix(startTimeUnix+int64(n), 0) } // Get returns the status code and body of url. // // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. // // The function follows redirects. Use Do* for manually handling redirects. 
func (c *HostClient) Get(ctx context.Context, dst []byte, url string) (statusCode int, body []byte, err error) { return client.GetURL(ctx, dst, url, c) } func (c *HostClient) ConnectionCount() (count int) { c.connsLock.Lock() count = len(c.conns) c.connsLock.Unlock() return } func (c *HostClient) WantConnectionCount() (count int) { return c.connsWait.len() } func (c *HostClient) ConnPoolState() config.ConnPoolState { c.connsLock.Lock() defer c.connsLock.Unlock() cps := config.ConnPoolState{ PoolConnNum: len(c.conns), TotalConnNum: c.connsCount, Addr: c.Addr, MaxConns: c.MaxConns, } if c.connsWait != nil { cps.WaitConnNum = c.connsWait.len() } return cps } // GetTimeout returns the status code and body of url. // // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. // // The function follows redirects. Use Do* for manually handling redirects. // // errTimeout error is returned if url contents couldn't be fetched // during the given timeout. func (c *HostClient) GetTimeout(ctx context.Context, dst []byte, url string, timeout time.Duration) (statusCode int, body []byte, err error) { return client.GetURLTimeout(ctx, dst, url, timeout, c) } // GetDeadline returns the status code and body of url. // // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. // // The function follows redirects. Use Do* for manually handling redirects. // // errTimeout error is returned if url contents couldn't be fetched // until the given deadline. func (c *HostClient) GetDeadline(ctx context.Context, dst []byte, url string, deadline time.Time) (statusCode int, body []byte, err error) { return client.GetURLDeadline(ctx, dst, url, deadline, c) } // Post sends POST request to the given url with the given POST arguments. // // The contents of dst will be replaced by the body and returned, if the dst // is too small a new slice will be allocated. 
// // The function follows redirects. Use Do* for manually handling redirects. // // Empty POST body is sent if postArgs is nil. func (c *HostClient) Post(ctx context.Context, dst []byte, url string, postArgs *protocol.Args) (statusCode int, body []byte, err error) { return client.PostURL(ctx, dst, url, postArgs, c) } // A wantConnQueue is a queue of wantConns. // // inspired by net/http/transport.go type wantConnQueue struct { // This is a queue, not a deque. // It is split into two stages - head[headPos:] and tail. // popFront is trivial (headPos++) on the first stage, and // pushBack is trivial (append) on the second stage. // If the first stage is empty, popFront can swap the // first and second stages to remedy the situation. // // This two-stage split is analogous to the use of two lists // in Okasaki's purely functional queue but without the // overhead of reversing the list when swapping stages. head []*wantConn headPos int tail []*wantConn } // A wantConn records state about a wanted connection // (that is, an active call to getConn). // The conn may be gotten by dialing or by finding an idle connection, // or a cancellation may make the conn no longer wanted. // These three options are racing against each other and use // wantConn to coordinate and agree about the winning outcome. // // inspired by net/http/transport.go type wantConn struct { ready chan struct{} mu sync.Mutex // protects conn, err, close(ready) conn *clientConn err error } // DoTimeout performs the given request and waits for response during // the given timeout duration. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // The function doesn't follow redirects. Use Get* for following redirects. // // Response is ignored if resp is nil. // // errTimeout is returned if the response wasn't returned during // the given timeout. 
// // If MaxConns is configured (> 0), ErrNoFreeConns is returned // when all connections to the host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. // // Warning: DoTimeout does not terminate the request itself. The request will // continue in the background and the response will be discarded. // If requests take too long and the connection pool gets filled up please // try setting a ReadTimeout. func (c *HostClient) DoTimeout(ctx context.Context, req *protocol.Request, resp *protocol.Response, timeout time.Duration) error { return client.DoTimeout(ctx, req, resp, timeout, c) } // DoDeadline performs the given request and waits for response until // the given deadline. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // The function doesn't follow redirects. Use Get* for following redirects. // // Response is ignored if resp is nil. // // errTimeout is returned if the response wasn't returned until // the given deadline. // // If MaxConns is configured (> 0), ErrNoFreeConns is returned // when all connections to the host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *HostClient) DoDeadline(ctx context.Context, req *protocol.Request, resp *protocol.Response, deadline time.Time) error { return client.DoDeadline(ctx, req, resp, deadline, c) } // DoRedirects performs the given http request and fills the given http response, // following up to maxRedirectsCount redirects. When the redirect count exceeds // maxRedirectsCount, ErrTooManyRedirects is returned. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. 
// // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // // Response is ignored if resp is nil. // // If MaxConns is configured (> 0), ErrNoFreeConns is returned // when all connections to the requested host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *HostClient) DoRedirects(ctx context.Context, req *protocol.Request, resp *protocol.Response, maxRedirectsCount int) error { _, _, err := client.DoRequestFollowRedirects(ctx, req, resp, req.URI().String(), maxRedirectsCount, c) return err } // Do performs the given http request and sets the corresponding response. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // // The function doesn't follow redirects. Use Get* for following redirects. // // Response is ignored if resp is nil. // // If MaxConns is configured (> 0), ErrNoFreeConns is returned // when all connections to the host are busy. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *HostClient) Do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error { if ctx == nil { panic("ctx is nil") } var ( err error canIdempotentRetry bool isDefaultRetryFunc = true attempts uint = 0 connAttempts uint = 0 maxAttempts uint = 1 isRequestRetryable client.RetryIfFunc = client.DefaultRetryIf ) retryCfg := c.ClientOptions.RetryConfig if retryCfg != nil { maxAttempts = retryCfg.MaxAttemptTimes } if c.ClientOptions.RetryIfFunc != nil { isRequestRetryable = c.ClientOptions.RetryIfFunc // if the user has provided a custom retry function, the canIdempotentRetry has no meaning anymore. // User will have full control over the retry logic through the custom retry function. 
isDefaultRetryFunc = false } atomic.AddInt32(&c.pendingRequests, 1) req.Options().StartRequest() for { select { case <-ctx.Done(): req.CloseBodyStream() //nolint:errcheck return ctx.Err() default: } canIdempotentRetry, err = c.do(req, resp) // If there is no custom retry and err is equal to nil, the loop simply exits. if err == nil && isDefaultRetryFunc { if connAttempts != 0 { hlog.SystemLogger().Warnf("Client connection attempt times: %d, url: %s. "+ "This is mainly because the connection in pool is closed by peer in advance. "+ "If this number is too high which indicates that long-connection are basically unavailable, "+ "try to change the request to short-connection.\n", connAttempts, req.URI().FullURI()) } break } // This connection is closed by the peer when it is in the connection pool. // // This case is possible if the server closes the idle // keep-alive connection on timeout. // // Apache and nginx usually do this. if canIdempotentRetry && client.DefaultRetryIf(req, resp, err) && errors.Is(err, errs.ErrBadPoolConn) { connAttempts++ continue } if isDefaultRetryFunc { break } attempts++ if attempts >= maxAttempts { break } // Check whether this request should be retried if !isRequestRetryable(req, resp, err) { break } wait := retry.Delay(attempts, err, retryCfg) // Retry after wait time time.Sleep(wait) } atomic.AddInt32(&c.pendingRequests, -1) if err == io.EOF { err = errConnectionClosed } return err } // PendingRequests returns the current number of requests the client // is executing. // // This function may be used for balancing load among multiple HostClient // instances. 
func (c *HostClient) PendingRequests() int { return int(atomic.LoadInt32(&c.pendingRequests)) } func (c *HostClient) do(req *protocol.Request, resp *protocol.Response) (bool, error) { nilResp := false if resp == nil { nilResp = true resp = protocol.AcquireResponse() } canIdempotentRetry, err := c.doNonNilReqResp(req, resp) if nilResp { protocol.ReleaseResponse(resp) } return canIdempotentRetry, err } func timeUntil(deadline time.Time) time.Duration { timeout := time.Until(deadline) if timeout <= 0 { return -1 } return timeout } // calcTimeout checks deadline and returns timeout for conn.SetXXXTimeout // // returns 0 which means no timeout // returns -1 if deadline exceeded func calcTimeout(deadline time.Time, timeout time.Duration) time.Duration { if timeout <= 0 { if deadline.IsZero() { return 0 } return timeUntil(deadline) } if deadline.IsZero() { return timeout // must > 0 } if d := timeUntil(deadline); d < timeout { return d } return timeout } func (c *HostClient) getTimeouts(o *config.RequestOptions) (dtimeout, rtimeout, wtimeout time.Duration) { dtimeout = c.DialTimeout if v := o.DialTimeout(); v > 0 { dtimeout = v } rtimeout = c.ReadTimeout if v := o.ReadTimeout(); v > 0 { rtimeout = v } wtimeout = c.WriteTimeout if v := o.WriteTimeout(); v > 0 { wtimeout = v } return } func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Response) (bool, error) { if req == nil { panic("BUG: req cannot be nil") } if resp == nil { panic("BUG: resp cannot be nil") } atomic.StoreUint32(&c.lastUseTime, uint32(time.Now().Unix()-startTimeUnix)) // Free up resources occupied by response before sending the request, // so the GC may reclaim these resources (e.g. response body). 
// backing up SkipBody in case it was set explicitly customSkipBody := resp.SkipBody resp.Reset() resp.SkipBody = customSkipBody if c.DisablePathNormalizing { req.URI().DisablePathNormalizing = true } o := req.Options() deadline := time.Time{} if v := o.RequestTimeout(); v > 0 { deadline = o.StartTime().Add(v) } dtimeout, rtimeout, wtimeout := c.getTimeouts(o) // dial starts timeout := calcTimeout(deadline, dtimeout) if timeout < 0 { return false, errTimeout } cc, inPool, err := c.acquireConn(timeout) if err != nil { return false, err } conn := cc.c resp.ParseNetAddr(conn) if c.IsTLS && timeout > 0 { // force handshake using dial timeout // NOTE: Handshake() here is optional as Write would tirigger handshake // but for tls handshake, it writes and reads, and we need to set deadline for that. tlsconn, ok := conn.(network.ConnTLSer) if ok { // currently netpoll doesn't support conn.SetDeadline nor tls, but crypto/tls.Conn does. // in case netpoll supports tls in the future, may need to change this to // call both conn.SetReadTimeout, and conn.SetWriteTimeout err := conn.SetDeadline(time.Now().Add(timeout)) if err == nil { err = tlsconn.Handshake() // NOTE: no need conn.SetDeadline(time.Time{})? 
// we always reset before Write and Read } if err != nil { c.closeConn(cc) return true, err } } } usingProxy := false if c.ProxyURI != nil && bytes.Equal(req.Scheme(), bytestr.StrHTTP) { usingProxy = true proxy.SetProxyAuthHeader(&req.Header, c.ProxyURI) } // write starts timeout = calcTimeout(deadline, wtimeout) if timeout < 0 { c.closeConn(cc) return false, errTimeout } if err = conn.SetWriteTimeout(timeout); err != nil { c.closeConn(cc) return true, err } resetConnection := false if c.MaxConnDuration > 0 && time.Since(cc.createdTime) > c.MaxConnDuration && !req.ConnectionClose() { req.SetConnectionClose() resetConnection = true } userAgentOld := req.Header.UserAgent() if len(userAgentOld) == 0 { req.Header.SetUserAgentBytes(c.getClientName()) } zw := c.acquireWriter(conn) if !usingProxy { err = reqI.Write(req, zw) } else { err = reqI.ProxyWrite(req, zw) } if resetConnection { req.Header.ResetConnectionClose() } if err == nil { err = zw.Flush() } if err != nil { defer c.closeConn(cc) errNorm, ok := conn.(network.ErrorNormalization) if ok { err = errNorm.ToHertzError(err) } if !errors.Is(err, errs.ErrConnectionClosed) { return true, err } // introduced by https://github.com/cloudwego/hertz/pull/412 // only for reading 4xx err // short period of time (50ms) is enough for this case // NOTE: can't use deadline since it likely already exceeded deadline when write timeout = 50 * time.Millisecond if rtimeout > 0 && timeout > rtimeout { timeout = rtimeout } if conn.SetReadTimeout(timeout) != nil { return true, err } zr := c.acquireReader(conn) defer zr.Release() if respI.ReadHeaderAndLimitBody(resp, zr, c.MaxResponseBodySize) == nil { if code := resp.StatusCode(); code >= 400 && code < 600 { // strictly for 4xx only, but 5xx is also acceptable. 
// both can be considered better response rather than write err return false, nil } } if inPool { err = errs.ErrBadPoolConn } return true, err } // read starts timeout = calcTimeout(deadline, rtimeout) if timeout < 0 { c.closeConn(cc) return false, errTimeout } if err = conn.SetReadTimeout(timeout); err != nil { c.closeConn(cc) return true, err } if customSkipBody || req.Header.IsHead() || req.Header.IsConnect() { resp.SkipBody = true } if c.DisableHeaderNamesNormalizing { resp.Header.DisableNormalizing() } zr := c.acquireReader(conn) // errs.ErrBadPoolConn error are returned when the // 1 byte peek read fails, and we're actually anticipating a response. // Usually this is just due to the inherent keep-alive shut down race, // where the server closed the connection at the same time the client // wrote. The underlying err field is usually io.EOF or some // ECONNRESET sort of thing which varies by platform. _, err = zr.Peek(1) if err != nil { zr.Release() //nolint:errcheck c.closeConn(cc) if inPool && (err == io.EOF || err == syscall.ECONNRESET) { return true, errs.ErrBadPoolConn } // if this is not a pooled connection, // we should not retry to avoid getting stuck in an endless retry loop. 
errNorm, ok := conn.(network.ErrorNormalization) if ok { err = errNorm.ToHertzError(err) } return false, err } // init here for passing in ReadBodyStream's closure // and this value will be assigned after reading Response's Header // // This is to solve the circular dependency problem of Response and BodyStream shouldCloseConn := false if err = respI.ReadHeaders(resp, zr); err != nil { _ = zr.Release() c.closeConn(cc) return true, err } stream := c.ResponseBodyStream // if it's server-sent event response, // we should set stream=true or it may block till timeout if !stream && resp.Header.ContentLength() < 0 && bytes.HasPrefix(resp.Header.ContentType(), bytestr.MIMETextEventStream) { stream = true } if !stream { err = respI.ReadRespBody(resp, zr, c.MaxResponseBodySize) } else { err = respI.ReadRespBodyStream(resp, zr, c.MaxResponseBodySize, func(shouldClose bool) error { if shouldCloseConn || shouldClose { c.closeConn(cc) } else { c.releaseConn(cc) } return nil }) } zr.Release() //nolint:errcheck if err != nil { c.closeConn(cc) // Don't retry in case of ErrBodyTooLarge since we will just get the same again. retry := !errors.Is(err, errs.ErrBodyTooLarge) return retry, err } shouldCloseConn = resetConnection || req.ConnectionClose() || resp.ConnectionClose() if resp.Header.StatusCode() == consts.StatusSwitchingProtocols && bytes.EqualFold(resp.Header.Peek(consts.HeaderConnection), bytestr.StrUpgrade) { // can not reuse connection in this case, it's no longer http1 protocol. // set BodyStream for (*Response).Hijack resp.SetBodyStream(newUpgradeConn(c, cc), -1) return false, nil } // In stream mode, we still can close/release the connection immediately if there is no content on the wire. 
if stream && resp.BodyStream() != protocol.NoResponseBody { return false, nil } if shouldCloseConn { c.closeConn(cc) } else { c.releaseConn(cc) } return false, nil } var poolUpgradeConn = sync.Pool{ New: func() interface{} { return &upgradeConn{} }, } type upgradeConn struct { c *HostClient cc *clientConn } func newUpgradeConn(c *HostClient, cc *clientConn) *upgradeConn { p := poolUpgradeConn.Get().(*upgradeConn) p.c = c p.cc = cc runtime.SetFinalizer(p, (*upgradeConn).gc) return p } // Read implements io.Reader func (p *upgradeConn) Read(b []byte) (int, error) { return p.cc.c.Read(b) } // Hijack returns underlying network.Conn. This method is called by (*Response).Hijack func (p *upgradeConn) Hijack() (network.Conn, error) { return p.cc.c, nil } // gc closes conn and reuse upgradeConn. // // It MUST be called only by go runtime to avoid concurenccy issue. // For the 1st GC, it closes conn, and put upgradeConn back to pool // For the 2nd GC, it will be recycled if it's still in pool func (p *upgradeConn) gc() error { if p.c != nil { runtime.SetFinalizer(p, nil) p.c.closeConn(p.cc) p.c = nil p.cc = nil poolUpgradeConn.Put(p) } return nil } func (c *HostClient) Close() error { close(c.closed) return nil } // SetMaxConns sets up the maximum number of connections which may be established to all hosts listed in Addr. 
func (c *HostClient) SetMaxConns(newMaxConns int) { c.connsLock.Lock() c.MaxConns = newMaxConns c.connsLock.Unlock() } func (c *HostClient) acquireConn(dialTimeout time.Duration) (cc *clientConn, inPool bool, err error) { createConn := false startCleaner := false var n int c.connsLock.Lock() n = len(c.conns) if n == 0 { if c.MaxConns <= 0 || c.connsCount < c.MaxConns { c.connsCount++ createConn = true if !c.connsCleanerRun { startCleaner = true c.connsCleanerRun = true } } } else { n-- cc = c.conns[n] c.conns[n] = nil c.conns = c.conns[:n] } c.connsLock.Unlock() if cc != nil { return cc, true, nil } if !createConn { if c.MaxConnWaitTimeout <= 0 { return nil, true, errs.ErrNoFreeConns } timeout := c.MaxConnWaitTimeout // wait for a free connection tc := timer.AcquireTimer(timeout) defer timer.ReleaseTimer(tc) w := &wantConn{ ready: make(chan struct{}, 1), } defer func() { if err != nil { w.cancel(c, err) } }() // Note: In the case of setting MaxConnWaitTimeout, if the number // of connections in the connection pool exceeds the maximum // number of connections and needs to establish a connection while // waiting, the dialtimeout on the hostclient is used instead of // the dialtimeout in request options. 
c.queueForIdle(w) select { case <-w.ready: return w.conn, true, w.err case <-tc.C: return nil, true, errs.ErrNoFreeConns } } if startCleaner { go c.connsCleaner() } conn, err := c.dialHostHard(dialTimeout) if err != nil { c.decConnsCount() return nil, false, err } cc = acquireClientConn(conn) return cc, false, nil } func (c *HostClient) queueForIdle(w *wantConn) { c.connsLock.Lock() defer c.connsLock.Unlock() if c.connsWait == nil { c.connsWait = &wantConnQueue{} } c.connsWait.clearFront() c.connsWait.pushBack(w) } func (c *HostClient) dialConnFor(w *wantConn) { conn, err := c.dialHostHard(c.DialTimeout) if err != nil { w.tryDeliver(nil, err) c.decConnsCount() return } cc := acquireClientConn(conn) delivered := w.tryDeliver(cc, nil) if !delivered { // not delivered, return idle connection c.releaseConn(cc) } } // CloseIdleConnections closes any connections which were previously // connected from previous requests but are now sitting idle in a // "keep-alive" state. It does not interrupt any connections currently // in use. func (c *HostClient) CloseIdleConnections() { c.connsLock.Lock() scratch := append([]*clientConn{}, c.conns...) for i := range c.conns { c.conns[i] = nil } c.conns = c.conns[:0] c.connsLock.Unlock() for _, cc := range scratch { c.closeConn(cc) } } func (c *HostClient) ShouldRemove() bool { c.connsLock.Lock() defer c.connsLock.Unlock() return c.connsCount == 0 } func (c *HostClient) connsCleaner() { var ( scratch []*clientConn maxIdleConnDuration = c.MaxIdleConnDuration ) if maxIdleConnDuration <= 0 { maxIdleConnDuration = consts.DefaultMaxIdleConnDuration } for { currentTime := time.Now() // Determine idle connections to be closed. c.connsLock.Lock() conns := c.conns n := len(conns) i := 0 for i < n && currentTime.Sub(conns[i].lastUseTime) > maxIdleConnDuration { i++ } sleepFor := maxIdleConnDuration if i < n { // + 1 so we actually sleep past the expiration time and not up to it. // Otherwise the > check above would still fail. 
sleepFor = maxIdleConnDuration - currentTime.Sub(conns[i].lastUseTime) + 1 } scratch = append(scratch[:0], conns[:i]...) if i > 0 { m := copy(conns, conns[i:]) for i = m; i < n; i++ { conns[i] = nil } c.conns = conns[:m] } c.connsLock.Unlock() // Close idle connections. for i, cc := range scratch { c.closeConn(cc) scratch[i] = nil } // Determine whether to stop the connsCleaner. c.connsLock.Lock() mustStop := c.connsCount == 0 if mustStop { c.connsCleanerRun = false } c.connsLock.Unlock() if mustStop { break } time.Sleep(sleepFor) } } func (c *HostClient) closeConn(cc *clientConn) { c.decConnsCount() cc.c.Close() releaseClientConn(cc) } func (c *HostClient) decConnsCount() { if c.MaxConnWaitTimeout <= 0 { c.connsLock.Lock() c.connsCount-- c.connsLock.Unlock() return } c.connsLock.Lock() defer c.connsLock.Unlock() dialed := false if q := c.connsWait; q != nil && q.len() > 0 { for q.len() > 0 { w := q.popFront() if w.waiting() { go c.dialConnFor(w) dialed = true break } } } if !dialed { c.connsCount-- } } func acquireClientConn(conn network.Conn) *clientConn { v := clientConnPool.Get() if v == nil { v = &clientConn{} } cc := v.(*clientConn) cc.c = conn cc.createdTime = time.Now() return cc } func releaseClientConn(cc *clientConn) { // Reset all fields. 
*cc = clientConn{} clientConnPool.Put(cc) } var clientConnPool sync.Pool func (c *HostClient) releaseConn(cc *clientConn) { if cc.c.Len() > 0 { // unexpected buffered data due to malformed response c.closeConn(cc) return } cc.lastUseTime = time.Now() if c.MaxConnWaitTimeout <= 0 { c.connsLock.Lock() c.conns = append(c.conns, cc) c.connsLock.Unlock() return } // try to deliver an idle connection to a *wantConn c.connsLock.Lock() defer c.connsLock.Unlock() delivered := false if q := c.connsWait; q != nil && q.len() > 0 { for q.len() > 0 { w := q.popFront() if w.waiting() { delivered = w.tryDeliver(cc, nil) break } } } if !delivered { c.conns = append(c.conns, cc) } } func (c *HostClient) acquireWriter(conn network.Conn) network.Writer { return conn } func (c *HostClient) acquireReader(conn network.Conn) network.Reader { return conn } func newClientTLSConfig(c *tls.Config, addr string) *tls.Config { if c == nil { c = &tls.Config{} } else { c = c.Clone() } if c.ClientSessionCache == nil { c.ClientSessionCache = tls.NewLRUClientSessionCache(0) } if len(c.ServerName) == 0 { serverName := tlsServerName(addr) if serverName == "*" { c.InsecureSkipVerify = true } else { c.ServerName = serverName } } return c } func tlsServerName(addr string) string { if !strings.Contains(addr, ":") { return addr } host, _, err := net.SplitHostPort(addr) if err != nil { return "*" } return host } func (c *HostClient) nextAddr() string { c.addrsLock.Lock() if c.addrs == nil { c.addrs = strings.Split(c.Addr, ",") } addr := c.addrs[0] if len(c.addrs) > 1 { addr = c.addrs[c.addrIdx%uint32(len(c.addrs))] c.addrIdx++ } c.addrsLock.Unlock() return addr } func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn network.Conn, err error) { // attempt to dial all the available hosts before giving up. c.addrsLock.Lock() n := len(c.addrs) c.addrsLock.Unlock() if n == 0 { // It looks like c.addrs isn't initialized yet. 
n = 1 } deadline := time.Now().Add(dialTimeout) for n > 0 { addr := c.nextAddr() tlsConfig := c.cachedTLSConfig(addr) conn, err = dialAddr(addr, c.Dialer, c.DialDualStack, tlsConfig, dialTimeout, c.ProxyURI, c.IsTLS) if err == nil { return conn, nil } if time.Since(deadline) >= 0 { break } n-- } return nil, err } func (c *HostClient) cachedTLSConfig(addr string) *tls.Config { var cfgAddr string if c.ProxyURI != nil && bytes.Equal(c.ProxyURI.Scheme(), bytestr.StrHTTPS) { cfgAddr = bytesconv.B2s(c.ProxyURI.Host()) } if c.IsTLS && cfgAddr == "" { cfgAddr = addr } if cfgAddr == "" { return nil } c.tlsConfigMapLock.Lock() if c.tlsConfigMap == nil { c.tlsConfigMap = make(map[string]*tls.Config) } cfg := c.tlsConfigMap[cfgAddr] if cfg == nil { cfg = newClientTLSConfig(c.TLSConfig, cfgAddr) c.tlsConfigMap[cfgAddr] = cfg } c.tlsConfigMapLock.Unlock() return cfg } func dialAddr(addr string, dial network.Dialer, dialDualStack bool, tlsConfig *tls.Config, timeout time.Duration, proxyURI *protocol.URI, isTLS bool) (network.Conn, error) { var conn network.Conn var err error if dial == nil { hlog.SystemLogger().Warn("HostClient: no dialer specified, trying to use default dialer") dial = dialer.DefaultDialer() } dialFunc := dial.DialConnection // addr has already been added port, no need to do it here if proxyURI != nil { // use tcp connection first, proxy will AddTLS to it conn, err = dialFunc("tcp", string(proxyURI.Host()), timeout, nil) } else { conn, err = dialFunc("tcp", addr, timeout, tlsConfig) } if err != nil { return nil, err } if conn == nil { panic("BUG: dial.DialConnection returned (nil, nil)") } if proxyURI != nil { conn, err = proxy.SetupProxy(conn, addr, proxyURI, tlsConfig, isTLS, dial) } // conn must be nil when got error, so doesn't need to close it if err != nil { return nil, err } return conn, nil } func (c *HostClient) getClientName() []byte { v := c.clientName.Load() var clientName []byte if v == nil { clientName = []byte(c.Name) if len(clientName) == 0 && 
!c.NoDefaultUserAgentHeader { clientName = bytestr.DefaultUserAgent } c.clientName.Store(clientName) } else { clientName = v.([]byte) } return clientName } // waiting reports whether w is still waiting for an answer (connection or error). func (w *wantConn) waiting() bool { select { case <-w.ready: return false default: return true } } // tryDeliver attempts to deliver conn, err to w and reports whether it succeeded. func (w *wantConn) tryDeliver(conn *clientConn, err error) bool { w.mu.Lock() defer w.mu.Unlock() if w.conn != nil || w.err != nil { return false } w.conn = conn w.err = err if w.conn == nil && w.err == nil { panic("hertz: internal error: misuse of tryDeliver") } close(w.ready) return true } // cancel marks w as no longer wanting a result (for example, due to cancellation). // If a connection has been delivered already, cancel returns it with c.releaseConn. func (w *wantConn) cancel(c *HostClient, err error) { w.mu.Lock() if w.conn == nil && w.err == nil { close(w.ready) // catch misbehavior in future delivery } conn := w.conn w.conn = nil w.err = err w.mu.Unlock() if conn != nil { c.releaseConn(conn) } } // len returns the number of items in the queue. func (q *wantConnQueue) len() int { return len(q.head) - q.headPos + len(q.tail) } // pushBack adds w to the back of the queue. func (q *wantConnQueue) pushBack(w *wantConn) { q.tail = append(q.tail, w) } // popFront removes and returns the wantConn at the front of the queue. func (q *wantConnQueue) popFront() *wantConn { if q.headPos >= len(q.head) { if len(q.tail) == 0 { return nil } // Pick up tail as new head, clear tail. q.head, q.headPos, q.tail = q.tail, 0, q.head[:0] } w := q.head[q.headPos] q.head[q.headPos] = nil q.headPos++ return w } // peekFront returns the wantConn at the front of the queue without removing it. 
func (q *wantConnQueue) peekFront() *wantConn { if q.headPos < len(q.head) { return q.head[q.headPos] } if len(q.tail) > 0 { return q.tail[0] } return nil } // cleanFront pops any wantConns that are no longer waiting from the head of the // queue, reporting whether any were popped. func (q *wantConnQueue) clearFront() (cleaned bool) { for { w := q.peekFront() if w == nil || w.waiting() { return cleaned } q.popFront() cleaned = true } } func NewHostClient(c *ClientOptions) client.HostClient { hc := &HostClient{ ClientOptions: c, closed: make(chan struct{}), } return hc } type ClientOptions struct { // Client name. Used in User-Agent request header. Name string // NoDefaultUserAgentHeader when set to true, causes the default // User-Agent header to be excluded from the Request. NoDefaultUserAgentHeader bool // Callback for establishing new connection to the host. // // Default Dialer is used if not set. Dialer network.Dialer // Timeout for establishing new connections to hosts. // // Default DialTimeout is used if not set. DialTimeout time.Duration // Attempt to connect to both ipv4 and ipv6 host addresses // if set to true. // // This option is used only if default TCP dialer is used, // i.e. if Dialer is blank. // // By default client connects only to ipv4 addresses, // since unfortunately ipv6 remains broken in many networks worldwide :) DialDualStack bool // Whether to use TLS (aka SSL or HTTPS) for host connections. // Optional TLS config. TLSConfig *tls.Config // Maximum number of connections which may be established to all hosts // listed in Addr. // // You can change this value while the HostClient is being used // using HostClient.SetMaxConns(value) // // no limit if <= 0. MaxConns int // Keep-alive connections are closed after this duration. // // By default connection duration is unlimited. MaxConnDuration time.Duration // Idle keep-alive connections are closed after this duration. 
// // By default idle connections are closed // after DefaultMaxIdleConnDuration. MaxIdleConnDuration time.Duration // Maximum duration for full response reading (including body). // // By default response read timeout is unlimited. ReadTimeout time.Duration // Maximum duration for full request writing (including body). // // By default request write timeout is unlimited. WriteTimeout time.Duration // Maximum response body size. // // The client returns errBodyTooLarge if this limit is greater than 0 // and response body is greater than the limit. // // By default response body size is unlimited. MaxResponseBodySize int // Header names are passed as-is without normalization // if this option is set. // // Disabled header names' normalization may be useful only for proxying // responses to other clients expecting case-sensitive header names. // // By default request and response header names are normalized, i.e. // The first letter and the first letters following dashes // are uppercased, while all the other letters are lowercased. // Examples: // // * HOST -> Host // * content-type -> Content-Type // * cONTENT-lenGTH -> Content-Length DisableHeaderNamesNormalizing bool // Path values are sent as-is without normalization // // Disabled path normalization may be useful for proxying incoming requests // to servers that are expecting paths to be forwarded as-is. // // By default path values are normalized, i.e. // extra slashes are removed, special characters are encoded. DisablePathNormalizing bool // Maximum duration for waiting for a free connection. 
// // By default will not wait, return ErrNoFreeConns immediately MaxConnWaitTimeout time.Duration // ResponseBodyStream enables response body streaming ResponseBodyStream bool // All configurations related to retry RetryConfig *retry.Config RetryIfFunc client.RetryIfFunc // Observe hostclient state StateObserve config.HostClientStateFunc // StateObserve execution interval ObservationInterval time.Duration } ================================================ FILE: pkg/protocol/http1/client_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. 
* * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package http1 import ( "bytes" "context" "crypto/tls" "errors" "fmt" "io/ioutil" "net" "net/http" "strings" "sync" "sync/atomic" "testing" "time" "github.com/cloudwego/netpoll" "github.com/cloudwego/hertz/pkg/app/client/retry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/dialer" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/client" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" ) var errDialTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "dial timeout") func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) { var ( emptyBodyCount uint8 wg sync.WaitGroup // make deadline reach earlier than conns wait timeout timeout = 10 * time.Millisecond ) c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.SlowReadDialer(addr) }), MaxConns: 1, MaxConnWaitTimeout: 50 * time.Millisecond, }, Addr: "foobar", } var errTimeoutCount uint32 for i := 0; i 
< 5; i++ { wg.Add(1) go func() { defer wg.Done() req := protocol.AcquireRequest() req.SetRequestURI("http://foobar/baz") req.Header.SetMethod(consts.MethodPost) req.SetBodyString("bar") resp := protocol.AcquireResponse() if err := c.DoDeadline(context.Background(), req, resp, time.Now().Add(timeout)); err != nil { if !errors.Is(err, errs.ErrTimeout) { t.Errorf("unexpected error: %s. Expecting %s", err, errs.ErrTimeout) } atomic.AddUint32(&errTimeoutCount, 1) } else { if resp.StatusCode() != consts.StatusOK { t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), consts.StatusOK) } body := resp.Body() if string(body) != "foo" { t.Errorf("unexpected body %q. Expecting %q", body, "abcd") } } }() } wg.Wait() c.connsLock.Lock() for { w := c.connsWait.popFront() if w == nil { break } w.mu.Lock() if w.err != nil && !errors.Is(w.err, errs.ErrNoFreeConns) { t.Errorf("unexpected error: %s. Expecting %s", w.err, errs.ErrNoFreeConns) } w.mu.Unlock() } c.connsLock.Unlock() if errTimeoutCount == 0 { t.Errorf("unexpected errTimeoutCount: %d. 
Expecting > 0", errTimeoutCount) } if emptyBodyCount > 0 { t.Fatalf("at least one request body was empty") } } func TestReadHeaderErr(t *testing.T) { ln, _ := net.Listen("tcp", "localhost:0") defer ln.Close() svr := http.Server{} svr.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { hj := w.(http.Hijacker) conn, rw, err := hj.Hijack() assert.Nil(t, err) defer conn.Close() rw.Write([]byte("HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Ty")) rw.Flush() }) go svr.Serve(ln) req := protocol.AcquireRequest() defer protocol.ReleaseRequest(req) req.SetRequestURI("http://" + ln.Addr().String()) resp := protocol.AcquireResponse() defer protocol.ReleaseResponse(resp) c := &HostClient{ Addr: ln.Addr().String(), ClientOptions: &ClientOptions{ Dialer: dialer.DefaultDialer(), }, } err := c.Do(context.Background(), req, resp) assert.NotNil(t, err) } func TestResponseReadBodyStream(t *testing.T) { // small body genBody := "abcdef4343" s := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 5\r\n\r\n" testContinueReadResponseBodyStream(t, s, genBody, 10, 5, 0, 5) testContinueReadResponseBodyStream(t, s, genBody, 1, 5, 0, 0) // big body (> 8193) s1 := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 9216\r\nContent-Type: foo/bar\r\n\r\n" genBody = strings.Repeat("1", 9*1024) testContinueReadResponseBodyStream(t, s1, genBody, 10*1024, 5*1024, 4*1024, 0) testContinueReadResponseBodyStream(t, s1, genBody, 10*1024, 1*1024, 8*1024, 0) testContinueReadResponseBodyStream(t, s1, genBody, 10*1024, 9*1024, 0*1024, 0) // normal stream testContinueReadResponseBodyStream(t, s1, genBody, 1*1024, 5*1024, 4*1024, 0) testContinueReadResponseBodyStream(t, s1, genBody, 1*1024, 1*1024, 8*1024, 0) testContinueReadResponseBodyStream(t, s1, genBody, 1*1024, 9*1024, 0*1024, 0) testContinueReadResponseBodyStream(t, s1, genBody, 5, 5*1024, 4*1024, 0) testContinueReadResponseBodyStream(t, s1, genBody, 5, 1*1024, 8*1024, 0) testContinueReadResponseBodyStream(t, s1, genBody, 
5, 9*1024, 0, 0) // critical point testContinueReadResponseBodyStream(t, s1, genBody, 8*1024+1, 5*1024, 4*1024, 0) testContinueReadResponseBodyStream(t, s1, genBody, 8*1024+1, 1*1024, 8*1024, 0) testContinueReadResponseBodyStream(t, s1, genBody, 8*1024+1, 9*1024, 0*1024, 0) // chunked body s2 := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\ntrail" testContinueReadResponseBodyStream(t, s2, "", 10*1024, 3, 5, 5) s3 := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\n" testContinueReadResponseBodyStream(t, s3, "", 10*1024, 3, 5, 0) } func testContinueReadResponseBodyStream(t *testing.T, header, body string, maxBodySize, firstRead, leftBytes, bytesLeftInReader int) { mr := netpoll.NewReader(bytes.NewBufferString(header + body)) var r protocol.Response if err := resp.ReadHeaderBodyStream(&r, mr, maxBodySize, nil); err != nil { t.Fatalf("error when reading request body stream: %s", err) } fRead := firstRead streamRead := make([]byte, fRead) sR, _ := r.BodyStream().Read(streamRead) if sR != firstRead { t.Fatalf("should read %d from stream body, but got %d", firstRead, sR) } leftB, _ := ioutil.ReadAll(r.BodyStream()) if len(leftB) != leftBytes { t.Fatalf("should left %d bytes from stream body, but left %d", leftBytes, len(leftB)) } if r.Header.ContentLength() > 0 { gotBody := append(streamRead, leftB...) if !bytes.Equal([]byte(body[:r.Header.ContentLength()]), gotBody) { t.Fatalf("body read from stream is not equal to the origin. Got: %s", gotBody) } } left, _ := mr.Next(mr.Len()) if len(left) != bytesLeftInReader { fmt.Printf("##########header:%s,body:%s,%d:max,first:%d,left:%d,leftin:%d\n", header, body, maxBodySize, firstRead, leftBytes, bytesLeftInReader) fmt.Printf("##########left: %s\n", left) t.Fatalf("should left %d bytes in original reader. 
got %q", bytesLeftInReader, len(left)) } } type dialerFunc func(network, addr string, timeout time.Duration) (network.Conn, error) func (f dialerFunc) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) { return f(network, address, timeout) } func (_ dialerFunc) DialTimeout(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn net.Conn, err error) { return nil, nil } func (_ dialerFunc) AddTLS(conn network.Conn, tlsConfig *tls.Config) (network.Conn, error) { return nil, nil } type slowDialer struct { network.Dialer } func (s *slowDialer) DialConnection(network, address string, timeout time.Duration, tlsConfig *tls.Config) (conn network.Conn, err error) { time.Sleep(timeout) return nil, errDialTimeout } func TestTimeoutPriority(t *testing.T) { rtimeoutDialer := dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.SlowReadDialer(addr) }) wtimeoutDialer := dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.SlowWriteDialer(addr) }) noopRequestOpt := config.RequestOption{F: func(o *config.RequestOptions) {}} tests := []struct { name string dialer network.Dialer clientOpts *ClientOptions reqOpt config.RequestOption expectDelay time.Duration expectedErr error }{ // ReadTimeout cases { "ReadTimeout_cli_60ms_req_100ms", rtimeoutDialer, &ClientOptions{ReadTimeout: 60 * time.Millisecond}, config.WithReadTimeout(100 * time.Millisecond), 100 * time.Millisecond, mock.ErrReadTimeout, }, { "ReadTimeout_cli_100ms_req_60ms", rtimeoutDialer, &ClientOptions{ReadTimeout: 100 * time.Millisecond}, config.WithReadTimeout(60 * time.Millisecond), 60 * time.Millisecond, mock.ErrReadTimeout, }, { "ReadTimeout_cli_unset_req_60ms", rtimeoutDialer, &ClientOptions{}, config.WithReadTimeout(60 * time.Millisecond), 60 * time.Millisecond, mock.ErrReadTimeout, }, { "ReadTimeout_cli_60ms_req_unset", rtimeoutDialer, 
&ClientOptions{ReadTimeout: 60 * time.Millisecond}, noopRequestOpt, 60 * time.Millisecond, mock.ErrReadTimeout, }, // WriteTimeout cases { "WriteTimeout_cli_100ms_req_150ms", wtimeoutDialer, &ClientOptions{WriteTimeout: 100 * time.Millisecond}, config.WithWriteTimeout(150 * time.Millisecond), 150 * time.Millisecond, mock.ErrWriteTimeout, }, { "WriteTimeout_cli_150ms_req_100ms", wtimeoutDialer, &ClientOptions{WriteTimeout: 150 * time.Millisecond}, config.WithWriteTimeout(100 * time.Millisecond), 100 * time.Millisecond, mock.ErrWriteTimeout, }, { "WriteTimeout_cli_unset_req_120ms", wtimeoutDialer, &ClientOptions{}, config.WithWriteTimeout(120 * time.Millisecond), 120 * time.Millisecond, mock.ErrWriteTimeout, }, { "WriteTimeout_cli_120ms_req_unset", wtimeoutDialer, &ClientOptions{WriteTimeout: 120 * time.Millisecond}, noopRequestOpt, 120 * time.Millisecond, mock.ErrWriteTimeout, }, // DialTimeout cases { "DialTimeout_cli_60ms_req_100ms", &slowDialer{}, &ClientOptions{DialTimeout: 60 * time.Millisecond}, config.WithDialTimeout(100 * time.Millisecond), 100 * time.Millisecond, errDialTimeout, }, { "DialTimeout_cli_100ms_req_60ms", &slowDialer{}, &ClientOptions{DialTimeout: 100 * time.Millisecond}, config.WithDialTimeout(60 * time.Millisecond), 60 * time.Millisecond, errDialTimeout, }, { "DialTimeout_cli_unset_req_60ms", &slowDialer{}, &ClientOptions{}, config.WithDialTimeout(60 * time.Millisecond), 60 * time.Millisecond, errDialTimeout, }, { "DialTimeout_cli_60ms_req_unset", &slowDialer{}, &ClientOptions{DialTimeout: 60 * time.Millisecond}, noopRequestOpt, 60 * time.Millisecond, errDialTimeout, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.clientOpts.Dialer = tt.dialer c := &HostClient{ClientOptions: tt.clientOpts, Addr: "foobar"} req := protocol.AcquireRequest() req.SetRequestURI("http://foobar/baz") req.SetOptions(tt.reqOpt) start := time.Now() err := c.Do(context.Background(), req, protocol.AcquireResponse()) duration := time.Since(start) 
assert.DeepEqual(t, tt.expectedErr, err) // Check if duration is within expected delay ±30ms tolerance := 30 * time.Millisecond if !(duration >= tt.expectDelay-tolerance && duration <= tt.expectDelay+tolerance) { t.Errorf("Duration %v not within expected %v ±%v", duration, tt.expectDelay, tolerance) } }) } } func TestDoNonNilReqResp(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.NewConn("HTTP/1.1 400 OK\nContent-Length: 6\n\n123456"), nil }), }, } req := protocol.AcquireRequest() resp := protocol.AcquireResponse() req.SetHost("foobar") retry, err := c.doNonNilReqResp(req, resp) assert.False(t, retry) assert.Nil(t, err) assert.DeepEqual(t, resp.StatusCode(), 400) assert.DeepEqual(t, resp.Body(), []byte("123456")) } func TestDoNonNilReqResp_WriteErr(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{}, } req := protocol.AcquireRequest() resp := protocol.AcquireResponse() req.SetHost("foobar") req.SetConnectionClose() // won't reuse the conn // 200 with write err, will return write err c.ClientOptions.Dialer = dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return &writeErrConn{mock.NewConn("HTTP/1.1 200 OK\nContent-Length: 6\n\n123456")}, nil }) retry, err := c.doNonNilReqResp(req, resp) assert.True(t, retry) assert.NotNil(t, err) c = &HostClient{ ClientOptions: &ClientOptions{ Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return &writeErrConn{mock.NewConn("HTTP/1.1 400 OK\nContent-Length: 6\n\n123456")}, nil }), }, } // 400 with write err, will NOT return write err c.ClientOptions.Dialer = dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return &writeErrConn{mock.NewConn("HTTP/1.1 400 OK\nContent-Length: 6\n\n123456")}, nil }) retry, err = c.doNonNilReqResp(req, resp) assert.False(t, retry) assert.Nil(t, err) 
assert.DeepEqual(t, resp.StatusCode(), 400) assert.DeepEqual(t, resp.Body(), []byte("123456")) } func TestDoNonNilReqResp_TLS(t *testing.T) { const ( dialTimeout = 123 * time.Millisecond dev = 10 * time.Millisecond ) conn := mock.NewConn("HTTP/1.1 200 OK\nContent-Length: 5\n\n54321") tlsconn := mock.NewTLSConn(conn) c := &HostClient{ IsTLS: true, ClientOptions: &ClientOptions{ DialTimeout: dialTimeout, Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return tlsconn, nil }), }, } req := protocol.AcquireRequest() resp := protocol.AcquireResponse() req.SetHost("foobar") // HandshakeErr != nil tlsconn.HandshakeErr = errors.New("testerr") retry, err := c.doNonNilReqResp(req, resp) assert.True(t, retry) assert.True(t, err == tlsconn.HandshakeErr) if diff := conn.GetReadTimeout() - dialTimeout; diff < -dev || diff > dev { t.Fatal("unexpected timeout. got", conn.GetReadTimeout(), "expect", dialTimeout) } assert.True(t, conn.GetReadTimeout() == conn.GetWriteTimeout()) // HandshakeErr == nil tlsconn.HandshakeErr = nil retry, err = c.doNonNilReqResp(req, resp) assert.False(t, retry) assert.Nil(t, err) assert.DeepEqual(t, resp.StatusCode(), 200) assert.DeepEqual(t, resp.Body(), []byte("54321")) } func TestDoNonNilReqResp_Err(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return peekErrConn{writeErrConn{mock.NewConn("")}}, nil }), }, } req := protocol.AcquireRequest() resp := protocol.AcquireResponse() req.SetHost("foobar") retry, err := c.doNonNilReqResp(req, resp) assert.True(t, retry) assert.NotNil(t, err) assert.Assert(t, err == errs.ErrConnectionClosed, err) // returned by writeErrConn } func doGET(t *testing.T, addr, path string) *protocol.Response { req := protocol.AcquireRequest() defer protocol.ReleaseRequest(req) req.SetRequestURI("http://" + addr + path) resp := protocol.AcquireResponse() c := &HostClient{ Addr: 
addr, ClientOptions: &ClientOptions{ Dialer: dialer.DefaultDialer(), }, } err := c.Do(context.Background(), req, resp) assert.Nil(t, err) return resp } func TestStreamResponse_EventStream(t *testing.T) { ln, _ := net.Listen("tcp", "localhost:0") defer ln.Close() svr := http.Server{} svr.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") f := w.(http.Flusher) for i := 0; i < 5; i++ { _, err := w.Write([]byte(fmt.Sprintf("data:%d", i))) assert.Nil(t, err) f.Flush() // Transfer-Encoding chunked time.Sleep(20 * time.Millisecond) } }) go svr.Serve(ln) resp := doGET(t, ln.Addr().String(), "/") defer protocol.ReleaseResponse(resp) assert.Assert(t, resp.IsBodyStream()) r := resp.BodyStream() b := make([]byte, 10) for i := 0; i < 5; i++ { n, err := r.Read(b) assert.Nil(t, err) assert.Assert(t, string(b[:n]) == fmt.Sprintf("data:%d", i)) } } func TestStreamResponse_ConnUpgrade(t *testing.T) { ln, _ := net.Listen("tcp", "localhost:0") defer ln.Close() svr := http.Server{} svr.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { hj, ok := w.(http.Hijacker) if !ok { http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError) return } conn, rw, err := hj.Hijack() if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } defer conn.Close() _, err = rw.WriteString("HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\n\n") assert.Nil(t, err) assert.Nil(t, rw.Flush()) b := make([]byte, 100) for { // echo with "echo:" prefix n, err := rw.Read(b) if err != nil { return } _, err = rw.Write([]byte("echo:" + string(b[:n]))) if err != nil { return } _ = rw.Flush() } }) go svr.Serve(ln) resp := doGET(t, ln.Addr().String(), "/") defer protocol.ReleaseResponse(resp) assert.DeepEqual(t, resp.StatusCode(), 101) s := resp.BodyStream() assert.NotNil(t, s) conn, err := resp.Hijack() assert.Nil(t, err) b := make([]byte, 100) _, _ = 
conn.Write(append(b[:0], "hello"...)) n, err := s.Read(b) // same as conn.Read assert.Nil(t, err) assert.DeepEqual(t, string(b[:n]), "echo:hello") } func TestStateObserve(t *testing.T) { syncState := struct { mu sync.Mutex state config.ConnPoolState }{} c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.SlowReadDialer(addr) }), StateObserve: func(hcs config.HostClientState) { syncState.mu.Lock() defer syncState.mu.Unlock() syncState.state = hcs.ConnPoolState() }, ObservationInterval: 50 * time.Millisecond, }, Addr: "foobar", closed: make(chan struct{}), } c.SetDynamicConfig(&client.DynamicConfig{ Addr: utils.AddMissingPort(c.Addr, true), }) time.Sleep(500 * time.Millisecond) assert.Nil(t, c.Close()) syncState.mu.Lock() assert.DeepEqual(t, "foobar:443", syncState.state.Addr) syncState.mu.Unlock() } func TestCachedTLSConfig(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.SlowReadDialer(addr) }), TLSConfig: &tls.Config{ InsecureSkipVerify: true, }, }, Addr: "foobar", IsTLS: true, } cfg1 := c.cachedTLSConfig("foobar") cfg2 := c.cachedTLSConfig("baz") assert.NotEqual(t, cfg1, cfg2) cfg3 := c.cachedTLSConfig("foobar") assert.DeepEqual(t, cfg1, cfg3) } func TestRetry(t *testing.T) { var times int32 c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { times++ if times < 3 { return &retryConn{ Conn: mock.NewConn(""), }, nil } return mock.NewConn("HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789"), nil }), RetryConfig: &retry.Config{ MaxAttemptTimes: 5, Delay: time.Millisecond * 10, }, RetryIfFunc: func(req *protocol.Request, resp *protocol.Response, err error) bool { return resp.Header.ContentLength() != 10 }, }, Addr: "foobar", } req 
:= protocol.AcquireRequest() req.SetRequestURI("http://foobar/baz") req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100)) resp := protocol.AcquireResponse() ch := make(chan error, 1) go func() { ch <- c.Do(context.Background(), req, resp) }() select { case <-time.After(time.Second * 2): t.Fatalf("should use writeTimeout in request options") case err := <-ch: assert.Nil(t, err) assert.True(t, times == 3) assert.DeepEqual(t, resp.StatusCode(), 200) assert.DeepEqual(t, resp.Body(), []byte("0123456789")) } } // mockConn for getting error when write binary data. type writeErrConn struct { network.Conn } func (w writeErrConn) WriteBinary(b []byte) (n int, err error) { return 0, errs.ErrConnectionClosed } type peekErrConn struct { network.Conn } func (c peekErrConn) Peek(n int) ([]byte, error) { return nil, errors.New("peek err") } type retryConn struct { network.Conn } func (w retryConn) SetWriteTimeout(t time.Duration) error { return errors.New("should retry") } func TestConnInPoolRetry(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.NewOneTimeConn("HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789"), nil }), }, Addr: "foobar", } req := protocol.AcquireRequest() req.SetRequestURI("http://foobar/baz") req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100)) resp := protocol.AcquireResponse() logbuf := &bytes.Buffer{} hlog.SetOutput(logbuf) err := c.Do(context.Background(), req, resp) assert.Nil(t, err) assert.DeepEqual(t, resp.StatusCode(), 200) assert.DeepEqual(t, string(resp.Body()), "0123456789") assert.True(t, logbuf.String() == "") protocol.ReleaseResponse(resp) resp = protocol.AcquireResponse() err = c.Do(context.Background(), req, resp) assert.Nil(t, err) assert.DeepEqual(t, resp.StatusCode(), 200) assert.DeepEqual(t, string(resp.Body()), "0123456789") assert.True(t, 
strings.Contains(logbuf.String(), "Client connection attempt times: 1")) } func TestConnNotRetry(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return mock.NewBrokenConn(""), nil }), }, Addr: "foobar", } req := protocol.AcquireRequest() req.SetRequestURI("http://foobar/baz") req.SetOptions(config.WithWriteTimeout(time.Millisecond * 100)) resp := protocol.AcquireResponse() logbuf := &bytes.Buffer{} hlog.SetOutput(logbuf) err := c.Do(context.Background(), req, resp) assert.DeepEqual(t, errs.ErrConnectionClosed, err) assert.True(t, logbuf.String() == "") protocol.ReleaseResponse(resp) } type countCloseConn struct { network.Conn isClose bool } func (c *countCloseConn) Close() error { c.isClose = true return nil } func newCountCloseConn(s string) *countCloseConn { return &countCloseConn{ Conn: mock.NewConn(s), } } func TestStreamNoContent(t *testing.T) { conn := newCountCloseConn("HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTrailer: Foo\r\nContent-Encoding: deflate\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo: bar\r\n\r\nHTTP/1.2") c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return conn, nil }), }, Addr: "foobar", } c.ResponseBodyStream = true req := protocol.AcquireRequest() req.SetRequestURI("http://foobar/baz") req.Header.SetConnectionClose(true) resp := protocol.AcquireResponse() c.Do(context.Background(), req, resp) assert.True(t, conn.isClose) } func TestDialTimeout(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ DialTimeout: time.Second * 10, Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { assert.DeepEqual(t, time.Second*10, timeout) return nil, errors.New("test error") }), }, Addr: "foobar", } req := protocol.AcquireRequest() req.SetRequestURI("http://foobar/baz") resp := 
protocol.AcquireResponse() c.Do(context.Background(), req, resp) } func TestContextNil(t *testing.T) { defer func() { v := recover() assert.NotNil(t, v) assert.True(t, fmt.Sprint(v) == "ctx is nil") }() c := &HostClient{} c.Do(nil, nil, nil) //nolint:staticcheck // SA1012: do not pass a nil Context } func TestCalcimeout(t *testing.T) { now := time.Now() tests := []struct { name string deadline time.Time timeout time.Duration expected time.Duration }{ {"zero deadline, positive timeout", time.Time{}, 5 * time.Second, 5 * time.Second}, {"zero deadline, zero timeout", time.Time{}, 0, 0}, {"zero deadline, negative timeout", time.Time{}, -1 * time.Second, 0}, {"future deadline, zero timeout", now.Add(10 * time.Second), 0, 10 * time.Second}, {"future deadline, positive timeout (deadline < timeout)", now.Add(3 * time.Second), 5 * time.Second, 3 * time.Second}, {"future deadline, positive timeout (deadline > timeout)", now.Add(8 * time.Second), 5 * time.Second, 5 * time.Second}, {"past deadline, zero timeout", now.Add(-5 * time.Second), 0, -1}, {"past deadline, positive timeout", now.Add(-5 * time.Second), time.Second, -1}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := calcTimeout(tt.deadline, tt.timeout) diff := result - tt.expected if diff < -50*time.Millisecond || diff > 50*time.Millisecond { t.Errorf("calcTimeout(%v, %v) = %v, expected %v", tt.deadline, tt.timeout, result, tt.expected) } }) } } type mockConnClosed struct { closed bool network.Conn } func (m *mockConnClosed) Close() error { m.closed = true return m.Conn.Close() } // mock CRLF attacking func TestDoNonNilReqResp_releaseConn(t *testing.T) { respStr := "HTTP/1.1 400 OK\nContent-Length: 6\n\n123456" conn := &mockConnClosed{Conn: mock.NewConn(respStr + respStr)} c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: dialerFunc(func(network, addr string, timeout time.Duration) (network.Conn, error) { return conn, nil }), }, } req := protocol.AcquireRequest() resp := 
protocol.AcquireResponse() req.SetHost("foobar") retry, err := c.doNonNilReqResp(req, resp) assert.False(t, retry) assert.Nil(t, err) assert.DeepEqual(t, resp.StatusCode(), 400) assert.DeepEqual(t, resp.Body(), []byte("123456")) assert.True(t, conn.closed) } ================================================ FILE: pkg/protocol/http1/client_unix_test.go ================================================ // Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris package http1 import ( "context" "errors" "net/http" "runtime" "sync" "sync/atomic" "testing" "time" "github.com/cloudwego/hertz/internal/testutils" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/network/netpoll" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func TestGcBodyStream(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { for range [1024]int{} { w.Write([]byte("hello world\n")) } })} go srv.Serve(ln) addr := ln.Addr().String() c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: netpoll.NewDialer(), ResponseBodyStream: true, }, Addr: addr, } for i := 0; i < 10; i++ { req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() 
req.SetRequestURI("http://" + addr) req.SetMethod(consts.MethodPost) err := c.Do(context.Background(), req, resp) if err != nil { t.Errorf("client Do error=%v", err.Error()) } } runtime.GC() // wait for gc time.Sleep(100 * time.Millisecond) c.CloseIdleConnections() assert.DeepEqual(t, 0, c.ConnPoolState().TotalConnNum) } func TestMaxConn(t *testing.T) { ln := testutils.NewTestListener(t) defer ln.Close() srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.Write([]byte("hello world\n")) })} go srv.Serve(ln) addr := ln.Addr().String() c := &HostClient{ ClientOptions: &ClientOptions{ Dialer: netpoll.NewDialer(), ResponseBodyStream: true, MaxConnWaitTimeout: time.Millisecond * 100, MaxConns: 5, }, Addr: addr, } var successCount int32 var noFreeCount int32 wg := sync.WaitGroup{} for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() req.SetRequestURI("http://" + addr) req.SetMethod(consts.MethodPost) err := c.Do(context.Background(), req, resp) if err != nil { if errors.Is(err, errs.ErrNoFreeConns) { atomic.AddInt32(&noFreeCount, 1) return } t.Errorf("client Do error=%v", err.Error()) } atomic.AddInt32(&successCount, 1) }() } wg.Wait() assert.True(t, atomic.LoadInt32(&successCount) == 5) assert.True(t, atomic.LoadInt32(&noFreeCount) == 5) assert.DeepEqual(t, 0, c.ConnectionCount()) assert.DeepEqual(t, 5, c.WantConnectionCount()) runtime.GC() // wait for gc time.Sleep(100 * time.Millisecond) c.CloseIdleConnections() assert.DeepEqual(t, 0, c.WantConnectionCount()) } ================================================ FILE: pkg/protocol/http1/ext/common.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package ext import ( "bytes" "errors" "fmt" "io" "strings" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) const maxContentLengthInStream = 8 * 1024 var errBrokenChunk = errs.NewPublic("cannot find crlf at the end of chunk").SetMeta("when read body chunk") func MustPeekBuffered(r network.Reader) []byte { l := r.Len() buf, err := r.Peek(l) if len(buf) == 0 || err != nil { panic(fmt.Sprintf("bufio.Reader.Peek() returned unexpected data (%q, %v)", buf, err)) } return buf } func MustDiscard(r network.Reader, n int) { if err := r.Skip(n); err != nil { panic(fmt.Sprintf("bufio.Reader.Discard(%d) failed: %s", n, err)) } } func ReadRawHeaders(dst, buf []byte) ([]byte, int, error) { n := bytes.IndexByte(buf, '\n') if n < 0 { return dst[:0], 0, errNeedMore } if (n == 1 && buf[0] == '\r') || n == 0 { // empty headers return dst, n + 1, nil } n++ b := buf m := n for { b = b[m:] m = bytes.IndexByte(b, '\n') if m < 0 { return dst, 0, errNeedMore } m++ n += m if (m == 2 && b[0] == '\r') || m == 1 { dst = append(dst, buf[:n]...) 
return dst, n, nil } } } func WriteBodyChunked(w network.Writer, r io.Reader) error { vbuf := utils.CopyBufPool.Get() buf := vbuf.([]byte) var err error var n int for { n, err = r.Read(buf) if n == 0 { if err == nil { panic("BUG: io.Reader returned 0, nil") } if !errors.Is(err, io.EOF) { hlog.SystemLogger().Warnf("writing chunked response body encountered an error from the reader, "+ "this may cause the short of the content in response body, error: %s", err.Error()) } if err = WriteChunk(w, buf[:0], true); err != nil { break } err = nil break } if err = WriteChunk(w, buf[:n], true); err != nil { break } } utils.CopyBufPool.Put(vbuf) return err } func WriteBodyFixedSize(w network.Writer, r io.Reader, size int64) error { if size == 0 { return nil } if size > consts.MaxSmallFileSize { if err := w.Flush(); err != nil { return err } } if size > 0 { r = io.LimitReader(r, size) } n, err := utils.CopyZeroAlloc(w, r) if n != size && err == nil { err = fmt.Errorf("copied %d bytes from body stream instead of %d bytes", n, size) } return err } func appendBodyFixedSize(r network.Reader, dst []byte, n int) ([]byte, error) { if n == 0 { return dst, nil } offset := len(dst) dstLen := offset + n if cap(dst) < dstLen { b := make([]byte, round2(dstLen)) copy(b, dst) dst = b } dst = dst[:dstLen] // Prefer io.Reader over Peek to avoid holding a ref to the underlying buffer. 
if rd, ok := r.(io.Reader); ok { rn, err := io.ReadFull(rd, dst[offset:]) return dst[:offset+rn], err } // Peek can get all data, otherwise it will through error buf, err := r.Peek(n) if err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } return dst[:offset], err } copy(dst[offset:], buf) r.Skip(len(buf)) // nolint: errcheck return dst, nil } func readBodyIdentity(r network.Reader, maxBodySize int, dst []byte) ([]byte, error) { dst = dst[:cap(dst)] if len(dst) == 0 { dst = make([]byte, 1024) } offset := 0 for { nn := r.Len() if nn == 0 { _, err := r.Peek(1) if err != nil { return dst[:offset], nil } nn = r.Len() } if nn >= (len(dst) - offset) { nn = len(dst) - offset } buf, err := r.Peek(nn) if err != nil { return dst[:offset], err } copy(dst[offset:], buf) r.Skip(nn) // nolint: errcheck offset += nn if maxBodySize > 0 && offset > maxBodySize { return dst[:offset], errBodyTooLarge } if len(dst) == offset { n := round2(2 * offset) if maxBodySize > 0 && n > maxBodySize { n = maxBodySize + 1 } b := make([]byte, n) copy(b, dst) dst = b } } } func ReadBody(r network.Reader, contentLength, maxBodySize int, dst []byte) ([]byte, error) { dst = dst[:0] if contentLength >= 0 { if maxBodySize > 0 && contentLength > maxBodySize { return dst, errBodyTooLarge } return appendBodyFixedSize(r, dst, contentLength) } if contentLength == -1 { return readBodyChunked(r, maxBodySize, dst) } return readBodyIdentity(r, maxBodySize, dst) } func LimitedReaderSize(r io.Reader) int64 { lr, ok := r.(*io.LimitedReader) if !ok { return -1 } return lr.N } func readBodyChunked(r network.Reader, maxBodySize int, dst []byte) ([]byte, error) { if len(dst) > 0 { panic("BUG: expected zero-length buffer") } strCRLFLen := len(bytestr.StrCRLF) for { chunkSize, err := utils.ParseChunkSize(r) if err != nil { return dst, err } // If it is the end of chunk, Read CRLF after reading trailer if chunkSize == 0 { return dst, nil } if maxBodySize > 0 && len(dst)+chunkSize > maxBodySize { return dst, 
errBodyTooLarge } dst, err = appendBodyFixedSize(r, dst, chunkSize+strCRLFLen) if err != nil { return dst, err } if !bytes.Equal(dst[len(dst)-strCRLFLen:], bytestr.StrCRLF) { return dst, errBrokenChunk } dst = dst[:len(dst)-strCRLFLen] } } func round2(n int) int { if n <= 0 { return 0 } n-- x := uint(0) for n > 0 { n >>= 1 x++ } return 1 << x } func WriteChunk(w network.Writer, b []byte, withFlush bool) error { n := len(b) sz := bytesconv.EncodedIntHexLen(uint64(n)) + 2 // len + CRLF if n > 0 { sz += n + 2 // data + CRLF } wb, err := w.Malloc(sz) if err != nil { return err } wb = bytesconv.AppendIntHex(wb[:0], uint64(n)) wb = append(wb, '\r', '\n') if n > 0 { wb = append(wb, b...) wb = append(wb, '\r', '\n') } if len(wb) != sz { panic("[BUG] len mismatch") } if !withFlush { return nil } return w.Flush() } func isOnlyCRLF(b []byte) bool { for _, ch := range b { if ch != '\r' && ch != '\n' { return false } } return true } func BufferSnippet(b []byte) string { n := len(b) start := 20 end := n - start if start >= end { start = n end = n } bStart, bEnd := b[:start], b[end:] if len(bEnd) == 0 { return fmt.Sprintf("%q", b) } return fmt.Sprintf("%q...%q", bStart, bEnd) } func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl int) { nv = ov length := len(ov) if length <= 0 { return } write := 0 shrunk := 0 lineStart := false for read := 0; read < length; read++ { c := ov[read] if c == '\r' || c == '\n' { shrunk++ if c == '\n' { lineStart = true } continue } else if lineStart && c == '\t' { c = ' ' } else { lineStart = false } nv[write] = c write++ } nv = nv[:write] copy(ob[write:], ob[write+shrunk:]) // Check if we need to skip \r\n or just \n skip := 0 if ob[write] == '\r' { if ob[write+1] == '\n' { skip += 2 } else { skip++ } } else if ob[write] == '\n' { skip++ } nb = ob[write+skip : len(ob)-shrunk] nhl = headerLength - shrunk return } func stripSpace(b []byte) []byte { for len(b) > 0 && b[0] == ' ' { b = b[1:] } for len(b) > 0 && b[len(b)-1] == ' 
' { b = b[:len(b)-1] } return b } func SkipTrailer(r network.Reader) error { n := 1 for { err := trySkipTrailer(r, n) if err == nil { return nil } if !errors.Is(err, errs.ErrNeedMore) { return err } // No more data available on the wire, try block peek(by netpoll) if n == r.Len() { n++ continue } n = r.Len() } } func trySkipTrailer(r network.Reader, n int) error { b, err := r.Peek(n) if len(b) == 0 { // Return ErrTimeout on any timeout. if err != nil && strings.Contains(err.Error(), "timeout") { return errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "read response header") } if n == 1 || err == io.EOF { return io.EOF } return errs.NewPublicf("error when reading request trailer: %w", err) } b = MustPeekBuffered(r) headersLen, errParse := skipTrailer(b) if errParse != nil { if err == io.EOF { return err } return HeaderError("response", err, errParse, b) } MustDiscard(r, headersLen) return nil } func skipTrailer(buf []byte) (int, error) { skip := 0 strCRLFLen := len(bytestr.StrCRLF) for { index := bytes.Index(buf, bytestr.StrCRLF) if index == -1 { return 0, errs.ErrNeedMore } buf = buf[index+strCRLFLen:] skip += index + strCRLFLen if index == 0 { return skip, nil } } } func ReadTrailer(t *protocol.Trailer, r network.Reader) error { n := 1 for { err := tryReadTrailer(t, r, n) if err == nil { return nil } if !errors.Is(err, errs.ErrNeedMore) { t.ResetSkipNormalize() return err } // No more data available on the wire, try block peek(by netpoll) if n == r.Len() { n++ continue } n = r.Len() } } func tryReadTrailer(t *protocol.Trailer, r network.Reader, n int) error { b, err := r.Peek(n) if len(b) == 0 { // Return ErrTimeout on any timeout. 
if err != nil && strings.Contains(err.Error(), "timeout") { return errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "read response header") } if n == 1 || err == io.EOF { return io.EOF } return errs.NewPublicf("error when reading request trailer: %w", err) } b = MustPeekBuffered(r) headersLen, errParse := parseTrailer(t, b) if errParse != nil { if err == io.EOF { return err } return HeaderError("response", err, errParse, b) } MustDiscard(r, headersLen) return nil } func parseTrailer(t *protocol.Trailer, buf []byte) (int, error) { // Skip any 0 length chunk. if buf[0] == '0' { skip := len(bytestr.StrCRLF) + 1 if len(buf) < skip { return 0, io.EOF } buf = buf[skip:] } var s HeaderScanner s.B = buf s.DisableNormalizing = t.IsDisableNormalizing() var err error for s.Next() { if len(s.Key) > 0 { if bytes.IndexByte(s.Key, ' ') != -1 || bytes.IndexByte(s.Key, '\t') != -1 { err = fmt.Errorf("invalid trailer key %q", s.Key) continue } err = t.UpdateArgBytes(s.Key, s.Value) } } if s.Err != nil { return 0, s.Err } if err != nil { return 0, err } return s.HLen, nil } // writeTrailer writes response trailer to w func WriteTrailer(t *protocol.Trailer, w network.Writer) error { _, err := w.WriteBinary(t.Header()) return err } ================================================ FILE: pkg/protocol/http1/ext/common_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package ext import ( "bytes" "errors" "io" "strings" "testing" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/netpoll" ) func Test_stripSpace(t *testing.T) { a := stripSpace([]byte(" a")) b := stripSpace([]byte("b ")) c := stripSpace([]byte(" c ")) assert.DeepEqual(t, []byte("a"), a) assert.DeepEqual(t, []byte("b"), b) assert.DeepEqual(t, []byte("c"), c) } func Test_bufferSnippet(t *testing.T) { a := make([]byte, 39) b := make([]byte, 41) assert.False(t, strings.Contains(BufferSnippet(a), "\"...\"")) assert.True(t, strings.Contains(BufferSnippet(b), "\"...\"")) } func Test_isOnlyCRLF(t *testing.T) { assert.True(t, isOnlyCRLF([]byte("\r\n"))) assert.True(t, isOnlyCRLF([]byte("\n"))) } func TestReadTrailer(t *testing.T) { exceptedTrailers := map[string]string{"Hertz": "test"} zr := mock.NewZeroCopyReader("0\r\nHertz: test\r\n\r\n") trailer := protocol.Trailer{} keys := make([]string, 0, len(exceptedTrailers)) for k := range exceptedTrailers { keys = append(keys, k) } trailer.SetTrailers([]byte(strings.Join(keys, ", "))) err := ReadTrailer(&trailer, zr) if err != nil { t.Fatalf("Cannot read trailer: %v", err) } for k, v := range exceptedTrailers { got := trailer.Peek(k) if !bytes.Equal(got, []byte(v)) { t.Fatalf("Unexpected trailer %q. Expected %q. 
Got %q", k, v, got) } } } func TestReadTrailerError(t *testing.T) { // with bad trailer zr := mock.NewZeroCopyReader("0\r\nHertz: test\r\nContent-Type: aaa\r\n\r\n") trailer := protocol.Trailer{} err := ReadTrailer(&trailer, zr) if err == nil { t.Fatalf("expecting error.") } // eof er := mock.EOFReader{} trailer = protocol.Trailer{} err = ReadTrailer(&trailer, &er) assert.DeepEqual(t, io.EOF, err) } func TestReadTrailer1(t *testing.T) { exceptedTrailers := map[string]string{} zr := mock.NewZeroCopyReader("0\r\n\r\n") trailer := protocol.Trailer{} err := ReadTrailer(&trailer, zr) if err != nil { t.Fatalf("Cannot read trailer: %v", err) } for k, v := range exceptedTrailers { got := trailer.Peek(k) if !bytes.Equal(got, []byte(v)) { t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", k, v, got) } } } func TestReadRawHeaders(t *testing.T) { s := "HTTP/1.1 200 OK\r\n" + "EmptyValue1:\r\n" + "Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" + "Foo: Bar\r\n" + "Multi-Line: one;\r\n two\r\n" + "Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" + "Content-Length: 5\r\n\r\n" + "HELLOaaa" var dst []byte rawHeaders, index, err := ReadRawHeaders(dst, []byte(s)) assert.Nil(t, err) assert.DeepEqual(t, s[:index], string(rawHeaders)) } func TestBodyChunked(t *testing.T) { var log bytes.Buffer hlog.SetOutput(&log) body := "foobar baz aaa bbb ccc" chunk := "16\r\nfoobar baz aaa bbb ccc\r\n0\r\n" b := bytes.NewBufferString(body) var w bytes.Buffer zw := netpoll.NewWriter(&w) WriteBodyChunked(zw, b) assert.DeepEqual(t, chunk, w.String()) zr := mock.NewZeroCopyReader(chunk) rb, err := ReadBody(zr, -1, 0, nil) assert.Nil(t, err) assert.DeepEqual(t, body, string(rb)) assert.DeepEqual(t, 0, log.Len()) } func TestBrokenBodyChunked(t *testing.T) { brokenReader := mock.NewBrokenConn("") var log bytes.Buffer hlog.SetOutput(&log) var w bytes.Buffer zw := netpoll.NewWriter(&w) err := WriteBodyChunked(zw, brokenReader) assert.Nil(t, err) assert.DeepEqual(t, []byte("0\r\n"), w.Bytes()) 
assert.True(t, bytes.Contains(log.Bytes(), []byte("writing chunked response body encountered an error from the reader"))) } func TestBodyFixedSize(t *testing.T) { body := mock.CreateFixedBody(10) b := bytes.NewBuffer(body) var w bytes.Buffer zw := netpoll.NewWriter(&w) WriteBodyFixedSize(zw, b, int64(len(body))) assert.DeepEqual(t, body, w.Bytes()) zr := mock.NewZeroCopyReader(string(body)) rb, err := ReadBody(zr, len(body), 0, nil) assert.Nil(t, err) assert.DeepEqual(t, body, rb) } func TestBodyFixedSizeQuickPath(t *testing.T) { conn := mock.NewBrokenConn("") err := WriteBodyFixedSize(conn.Writer(), conn, 0) assert.Nil(t, err) } func TestBodyIdentity(t *testing.T) { body := mock.CreateFixedBody(1024) zr := mock.NewZeroCopyReader(string(body)) rb, err := ReadBody(zr, -2, 0, nil) assert.Nil(t, err) assert.DeepEqual(t, string(body), string(rb)) } func TestBodySkipTrailer(t *testing.T) { t.Run("TestBodySkipTrailer", func(t *testing.T) { body := mock.CreateFixedBody(10) trailer := map[string]string{"Foo": "chunked shit"} chunkedBody := mock.CreateChunkedBody(body, trailer, true) r := mock.NewSlowReadConn(string(chunkedBody)) err := SkipTrailer(r) assert.Nil(t, err) _, err = r.ReadByte() assert.NotNil(t, err) assert.True(t, errors.Is(err, netpoll.ErrEOF)) }) t.Run("TestBodySkipTrailerError", func(t *testing.T) { // timeout error sr := mock.NewSlowReadConn("") err := SkipTrailer(sr) assert.NotNil(t, err) assert.True(t, errors.Is(err, errs.ErrTimeout)) // EOF error er := &mock.EOFReader{} err = SkipTrailer(er) assert.NotNil(t, err) assert.True(t, errors.Is(err, io.EOF)) }) } ================================================ FILE: pkg/protocol/http1/ext/error.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package ext import ( "errors" "fmt" "io" errs "github.com/cloudwego/hertz/pkg/common/errors" ) var ( errNeedMore = errs.New(errs.ErrNeedMore, errs.ErrorTypePublic, "cannot find trailing lf") errBodyTooLarge = errs.New(errs.ErrBodyTooLarge, errs.ErrorTypePublic, "ext") ) func HeaderError(typ string, err, errParse error, b []byte) error { if !errors.Is(errParse, errs.ErrNeedMore) { return headerErrorMsg(typ, errParse, b) } if err == nil { return errNeedMore } // Buggy servers may leave trailing CRLFs after http body. // Treat this case as EOF. if isOnlyCRLF(b) { return io.EOF } return headerErrorMsg(typ, err, b) } func headerErrorMsg(typ string, err error, b []byte) error { return errs.NewPublic(fmt.Sprintf("error when reading %s headers: %s. Buffer size=%d, contents: %s", typ, err, len(b), BufferSnippet(b))) } ================================================ FILE: pkg/protocol/http1/ext/headerscanner.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package ext

import (
	"bytes"

	errs "github.com/cloudwego/hertz/pkg/common/errors"
	"github.com/cloudwego/hertz/pkg/common/utils"
)

// errInvalidName is reported when a '\n' appears before the ':' that should
// end a header name.
var errInvalidName = errs.NewPublic("invalid header name")

// HeaderScanner iterates over the "Key: Value" entries of a raw header block.
// Set B (and optionally DisableNormalizing), call Next until it returns false,
// then check Err to tell the end of the block from incomplete/invalid input.
type HeaderScanner struct {
	// B holds the part of the header block that has not been scanned yet.
	B []byte
	// Key and Value hold the entry found by the latest successful Next.
	Key   []byte
	Value []byte
	// Err is set when Next stops on incomplete or malformed input.
	Err error

	// HLen stores header subslice len: the bytes consumed by Next so far,
	// as adjusted by normalizeHeaderValue for folded values.
	HLen int

	// DisableNormalizing is passed to utils.NormalizeHeaderKey for every key.
	DisableNormalizing bool

	// by checking whether the Next line contains a colon or not to tell
	// it's a header entry or a multi line value of current header entry.
	// the side effect of this operation is that we know the index of the
	// Next colon and new line, so this can be used during Next iteration,
	// instead of find them again.
	nextColon   int
	nextNewLine int

	// initialized records whether nextColon/nextNewLine were set to -1 yet.
	initialized bool
}

// HeaderValueScanner iterates over the comma-separated items in B; each item
// is stored in Value with surrounding spaces stripped.
type HeaderValueScanner struct {
	B     []byte
	Value []byte
}

// Next advances to the next header entry and stores it in Key and Value.
//
// It returns false without setting Err when it reaches the empty line that
// terminates the block ("\r\n" or "\n"), which is consumed and counted in
// HLen. It returns false with Err set to errNeedMore when the data ends early,
// or to errInvalidName when a newline precedes the colon of a header name.
// Continuation lines (starting with ' ' or '\t') are folded into the value via
// normalizeHeaderValue; the value's trailing '\r' and spaces are trimmed.
func (s *HeaderScanner) Next() bool {
	if !s.initialized {
		s.nextColon = -1
		s.nextNewLine = -1
		s.initialized = true
	}
	bLen := len(s.B)
	// An empty line ends the header block.
	if bLen >= 2 && s.B[0] == '\r' && s.B[1] == '\n' {
		s.B = s.B[2:]
		s.HLen += 2
		return false
	}
	if bLen >= 1 && s.B[0] == '\n' {
		s.B = s.B[1:]
		s.HLen++
		return false
	}
	var n int
	// Reuse the colon position found while probing the previous value's
	// continuation lines, if any.
	if s.nextColon >= 0 {
		n = s.nextColon
		s.nextColon = -1
	} else {
		n = bytes.IndexByte(s.B, ':')

		// There can't be a \n inside the header name, check for this.
		x := bytes.IndexByte(s.B, '\n')
		if x < 0 {
			// A header name should always at some point be followed by a \n
			// even if it's the one that terminates the header block.
			s.Err = errNeedMore
			return false
		}
		if x < n {
			// There was a \n before the :
			s.Err = errInvalidName
			return false
		}
	}
	if n < 0 {
		s.Err = errNeedMore
		return false
	}
	s.Key = s.B[:n]
	utils.NormalizeHeaderKey(s.Key, s.DisableNormalizing)
	n++
	// Skip the spaces between the colon and the value.
	for len(s.B) > n && s.B[n] == ' ' {
		n++
		// the newline index is a relative index, and lines below trimmed `s.b` by `n`,
		// so the relative newline index also shifted forward. it's safe to decrease
		// to a minus value, it means it's invalid, and will find the newline again.
		s.nextNewLine--
	}
	s.HLen += n
	s.B = s.B[n:]
	if s.nextNewLine >= 0 {
		n = s.nextNewLine
		s.nextNewLine = -1
	} else {
		n = bytes.IndexByte(s.B, '\n')
	}
	if n < 0 {
		s.Err = errNeedMore
		return false
	}
	isMultiLineValue := false
	// Extend n over continuation lines. A line that starts with ' ' or '\t'
	// but contains a ':' is treated as the next entry; its colon and newline
	// offsets are remembered for the following Next call.
	for {
		if n+1 >= len(s.B) {
			break
		}
		if s.B[n+1] != ' ' && s.B[n+1] != '\t' {
			break
		}
		d := bytes.IndexByte(s.B[n+1:], '\n')
		if d <= 0 {
			break
		} else if d == 1 && s.B[n+1] == '\r' {
			break
		}
		e := n + d + 1
		if c := bytes.IndexByte(s.B[n+1:e], ':'); c >= 0 {
			s.nextColon = c
			s.nextNewLine = d - c - 1
			break
		}
		isMultiLineValue = true
		n = e
	}
	if n >= len(s.B) {
		s.Err = errNeedMore
		return false
	}
	// Keep the buffer starting at the value so normalizeHeaderValue can
	// compact it in place.
	oldB := s.B
	s.Value = s.B[:n]
	s.HLen += n + 1
	s.B = s.B[n+1:]

	// Trim the trailing '\r' of the line and any trailing spaces.
	if n > 0 && s.Value[n-1] == '\r' {
		n--
	}
	for n > 0 && s.Value[n-1] == ' ' {
		n--
	}
	s.Value = s.Value[:n]
	if isMultiLineValue {
		s.Value, s.B, s.HLen = normalizeHeaderValue(s.Value, oldB, s.HLen)
	}
	return true
}

// next stores the next comma-separated item of B (spaces stripped) in Value
// and reports whether there was one.
func (s *HeaderValueScanner) next() bool {
	b := s.B
	if len(b) == 0 {
		return false
	}
	n := bytes.IndexByte(b, ',')
	if n < 0 {
		s.Value = stripSpace(b)
		s.B = b[len(b):]
		return true
	}
	s.Value = stripSpace(b[:n])
	s.B = b[n+1:]
	return true
}

// HasHeaderValue reports whether the comma-separated list s contains value,
// compared with utils.CaseInsensitiveCompare.
func HasHeaderValue(s, value []byte) bool {
	var vs HeaderValueScanner
	vs.B = s
	for vs.next() {
		if utils.CaseInsensitiveCompare(vs.Value, value) {
			return true
		}
	}
	return false
}

================================================
FILE: pkg/protocol/http1/ext/headerscanner_test.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package ext import ( "bufio" "errors" "net/http" "strings" "testing" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestHasHeaderValue(t *testing.T) { s := []byte("Expect: 100-continue, User-Agent: foo, Host: 127.0.0.1, Connection: Keep-Alive, Content-Length: 5") assert.True(t, HasHeaderValue(s, []byte("Connection: Keep-Alive"))) assert.False(t, HasHeaderValue(s, []byte("Connection: Keep-Alive1"))) } func TestResponseHeaderMultiLineValue(t *testing.T) { firstLine := "HTTP/1.1 200 OK\r\n" rawHeaders := "EmptyValue1:\r\n" + "Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" + "Foo: Bar\r\n" + "Multi-Line: one;\r\n two\r\n" + "Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" + "\r\n" // compared with http response response, err := http.ReadResponse(bufio.NewReader(strings.NewReader(firstLine+rawHeaders)), nil) assert.Nil(t, err) defer func() { response.Body.Close() }() hs := &HeaderScanner{} hs.B = []byte(rawHeaders) hs.DisableNormalizing = false hmap := make(map[string]string, len(response.Header)) for hs.Next() { if len(hs.Key) > 0 { hmap[string(hs.Key)] = string(hs.Value) } } for name, vals := range response.Header { got := hmap[name] want := vals[0] assert.DeepEqual(t, want, got) } } func TestHeaderScannerError(t *testing.T) { t.Run("TestHeaderScannerErrorInvalidName", func(t *testing.T) { rawHeaders := "Host: go.dev\r\nGopher-New-\r\n Line: This is a header on multiple lines\r\n\r\n" testTestHeaderScannerError(t, rawHeaders, errInvalidName) }) t.Run("TestHeaderScannerErrorNeedMore", func(t *testing.T) { rawHeaders := "This is a header on multiple lines" testTestHeaderScannerError(t, rawHeaders, errs.ErrNeedMore) rawHeaders = "Gopher-New-\r\n Line" testTestHeaderScannerError(t, rawHeaders, errs.ErrNeedMore) }) } func testTestHeaderScannerError(t *testing.T, rawHeaders string, expectError error) { hs := &HeaderScanner{} hs.B = []byte(rawHeaders) hs.DisableNormalizing = false for hs.Next() { } 
assert.NotNil(t, hs.Err) assert.True(t, errors.Is(hs.Err, expectError)) } ================================================ FILE: pkg/protocol/http1/ext/stream.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
* * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package ext import ( "bytes" "io" "sync" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" ) var ( errChunkedStream = errs.New(errs.ErrChunkedStream, errs.ErrorTypePublic, nil) bodyStreamPool = sync.Pool{ New: func() interface{} { return &bodyStream{} }, } ) // Deprecated: Use github.com/cloudwego/hertz/pkg/protocol.NoBody instead. var NoBody = protocol.NoBody type bodyStream struct { prefetchedBytes *bytes.Reader reader network.Reader trailer *protocol.Trailer offset int contentLength int chunkLeft int // whether the chunk has reached the EOF chunkEOF bool } func ReadBodyWithStreaming(zr network.Reader, contentLength, maxBodySize int, dst []byte) (b []byte, err error) { if contentLength == -1 { // handled in requestStream.Read() return b, errChunkedStream } dst = dst[:0] if maxBodySize <= 0 { maxBodySize = maxContentLengthInStream } readN := maxBodySize if readN > contentLength { readN = contentLength } if readN > maxContentLengthInStream { readN = maxContentLengthInStream } if contentLength >= 0 && maxBodySize >= contentLength { b, err = appendBodyFixedSize(zr, dst, readN) } else { b, err = readBodyIdentity(zr, readN, dst) } if err != nil { return b, err } if contentLength > maxBodySize { return b, errBodyTooLarge } return b, nil } func AcquireBodyStream(b *bytebufferpool.ByteBuffer, r network.Reader, t *protocol.Trailer, contentLength int) io.Reader { rs := bodyStreamPool.Get().(*bodyStream) rs.prefetchedBytes = bytes.NewReader(b.B) rs.reader = r rs.contentLength = contentLength rs.trailer = t rs.chunkEOF = false return rs } func (rs *bodyStream) Read(p []byte) (int, error) { defer func() { if rs.reader != 
nil { rs.reader.Release() //nolint:errcheck } }() if rs.contentLength == -1 { if rs.chunkEOF { return 0, io.EOF } if rs.chunkLeft == 0 { chunkSize, err := utils.ParseChunkSize(rs.reader) if err != nil { return 0, err } if chunkSize == 0 { err = ReadTrailer(rs.trailer, rs.reader) if err == nil { rs.chunkEOF = true err = io.EOF } return 0, err } rs.chunkLeft = chunkSize } bytesToRead := len(p) if bytesToRead > rs.chunkLeft { bytesToRead = rs.chunkLeft } src, err := rs.reader.Peek(bytesToRead) copied := copy(p, src) rs.reader.Skip(copied) // nolint: errcheck rs.chunkLeft -= copied if err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } return copied, err } if rs.chunkLeft == 0 { err = utils.SkipCRLF(rs.reader) if err == io.EOF { err = io.ErrUnexpectedEOF } } return copied, err } if rs.offset == rs.contentLength { return 0, io.EOF } var n int var err error // read from the pre-read buffer if int(rs.prefetchedBytes.Size()) > rs.offset { n, err = rs.prefetchedBytes.Read(p) rs.offset += n if rs.offset == rs.contentLength { return n, io.EOF } if err != nil || len(p) == n { return n, err } } // read from the wire m := len(p) - n remain := rs.contentLength - rs.offset if m > remain { m = remain } if conn, ok := rs.reader.(io.Reader); ok { m, err = conn.Read(p[n:]) } else { var tmp []byte tmp, err = rs.reader.Peek(m) m = copy(p[n:], tmp) rs.reader.Skip(m) // nolint: errcheck } rs.offset += m n += m if err != nil { // the data on stream may be incomplete if err == io.EOF { if rs.offset != rs.contentLength && rs.contentLength != -2 { err = io.ErrUnexpectedEOF } // ensure that skipRest works fine rs.offset = rs.contentLength } return n, err } if rs.offset == rs.contentLength { err = io.EOF } return n, err } func (rs *bodyStream) skipRest() error { // The body length doesn't exceed the maxContentLengthInStream or // the bodyStream has been skip rest if rs.prefetchedBytes == nil { return nil } // the request is chunked encoding if rs.contentLength == -1 { if rs.chunkEOF { 
return nil } strCRLFLen := len(bytestr.StrCRLF) for { chunkSize, err := utils.ParseChunkSize(rs.reader) if err != nil { return err } if chunkSize == 0 { rs.chunkEOF = true return SkipTrailer(rs.reader) } err = rs.reader.Skip(chunkSize) if err != nil { return err } crlf, err := rs.reader.Peek(strCRLFLen) if err != nil { return err } if !bytes.Equal(crlf, bytestr.StrCRLF) { return errBrokenChunk } err = rs.reader.Skip(strCRLFLen) if err != nil { return err } // After Skip, the buffer needs to be released to prevent OOM if there are too much data on conn. err = rs.reader.Release() if err != nil { return err } } } // max value of pSize is 8193, it's safe. pSize := int(rs.prefetchedBytes.Size()) if rs.contentLength <= pSize || rs.offset == rs.contentLength { return nil } needSkipLen := 0 if rs.offset > pSize { needSkipLen = rs.contentLength - rs.offset } else { needSkipLen = rs.contentLength - pSize } // must skip size for { skip := rs.reader.Len() if skip == 0 { _, err := rs.reader.Peek(1) if err != nil { return err } skip = rs.reader.Len() } if skip > needSkipLen { skip = needSkipLen } err := rs.reader.Skip(skip) if err != nil { return err } // After Skip, the buffer needs to be released to prevent OOM if there are too much data on conn. err = rs.reader.Release() if err != nil { return err } needSkipLen -= skip if needSkipLen == 0 { return nil } } } // ReleaseBodyStream releases the body stream. // Error of skipRest may be returned if there is one. // // NOTE: Be careful to use this method unless you know what it's for. 
func ReleaseBodyStream(requestReader io.Reader) (err error) { if rs, ok := requestReader.(*bodyStream); ok { err = rs.skipRest() rs.reset() bodyStreamPool.Put(rs) } return } func (rs *bodyStream) reset() { rs.prefetchedBytes = nil rs.offset = 0 rs.reader = nil rs.trailer = nil rs.chunkEOF = false rs.chunkLeft = 0 rs.contentLength = 0 } ================================================ FILE: pkg/protocol/http1/ext/stream_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ext import ( "bytes" "errors" "fmt" "io" "testing" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/protocol" ) func createChunkedBody(body, rest []byte, trailer map[string]string, hasTrailer bool) []byte { var b []byte chunkSize := 1 for len(body) > 0 { if chunkSize > len(body) { chunkSize = len(body) } b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...) b = append(b, body[:chunkSize]...) b = append(b, []byte("\r\n")...) body = body[chunkSize:] chunkSize++ } if hasTrailer { b = append(b, "0\r\n"...) for k, v := range trailer { b = append(b, k...) b = append(b, ": "...) b = append(b, v...) b = append(b, "\r\n"...) } b = append(b, "\r\n"...) } return append(b, rest...) 
} func testChunkedSkipRest(t *testing.T, data, rest string) { var pool bytebufferpool.Pool reader := mock.NewZeroCopyReader(data) bs := AcquireBodyStream(pool.Get(), reader, &protocol.Trailer{}, -1) err := bs.(*bodyStream).skipRest() assert.Nil(t, err) rest_data, err := io.ReadAll(reader) assert.Nil(t, err) assert.DeepEqual(t, rest, string(rest_data)) } func testChunkedSkipRestWithBodySize(t *testing.T, bodySize int) { body := mock.CreateFixedBody(bodySize) rest := mock.CreateFixedBody(bodySize) data := createChunkedBody(body, rest, map[string]string{"foo": "bar"}, true) testChunkedSkipRest(t, string(data), string(rest)) } func TestChunkedSkipRest(t *testing.T) { t.Parallel() testChunkedSkipRest(t, "0\r\n\r\n", "") testChunkedSkipRest(t, "0\r\n\r\nHTTP/1.1 / POST", "HTTP/1.1 / POST") testChunkedSkipRest(t, "0\r\nHertz: test\r\nfoo: bar\r\n\r\nHTTP/1.1 / POST", "HTTP/1.1 / POST") testChunkedSkipRestWithBodySize(t, 5) // medium-size body testChunkedSkipRestWithBodySize(t, 43488) // big body testChunkedSkipRestWithBodySize(t, 3*1024*1024) } func TestBodyStream_Reset(t *testing.T) { t.Parallel() bs := bodyStream{ prefetchedBytes: bytes.NewReader([]byte("aaa")), reader: mock.NewZeroCopyReader("bbb"), trailer: &protocol.Trailer{}, offset: 10, contentLength: 20, chunkLeft: 50, chunkEOF: true, } bs.reset() assert.Nil(t, bs.prefetchedBytes) assert.Nil(t, bs.reader) assert.Nil(t, bs.trailer) assert.DeepEqual(t, 0, bs.offset) assert.DeepEqual(t, 0, bs.contentLength) assert.DeepEqual(t, 0, bs.chunkLeft) assert.False(t, bs.chunkEOF) } func TestReadBodyWithStreaming(t *testing.T) { t.Run("TestBodyFixedSize", func(t *testing.T) { bodySize := 1024 body := mock.CreateFixedBody(bodySize) reader := mock.NewZeroCopyReader(string(body)) dst, err := ReadBodyWithStreaming(reader, bodySize, -1, nil) assert.Nil(t, err) assert.DeepEqual(t, body, dst) }) t.Run("TestBodyFixedSizeMaxContentLength", func(t *testing.T) { bodySize := 8 * 1024 * 2 body := mock.CreateFixedBody(bodySize) reader := 
mock.NewZeroCopyReader(string(body)) dst, err := ReadBodyWithStreaming(reader, bodySize, 8*1024*10, nil) assert.Nil(t, err) assert.DeepEqual(t, body[:maxContentLengthInStream], dst) }) t.Run("TestBodyIdentity", func(t *testing.T) { bodySize := 1024 body := mock.CreateFixedBody(bodySize) reader := mock.NewZeroCopyReader(string(body)) dst, err := ReadBodyWithStreaming(reader, -2, 512, nil) assert.Nil(t, err) assert.DeepEqual(t, body, dst) }) t.Run("TestErrBodyTooLarge", func(t *testing.T) { bodySize := 2048 body := mock.CreateFixedBody(bodySize) reader := mock.NewZeroCopyReader(string(body)) dst, err := ReadBodyWithStreaming(reader, bodySize, 1024, nil) assert.True(t, errors.Is(err, errBodyTooLarge)) assert.DeepEqual(t, body[:len(dst)], dst) }) t.Run("TestErrChunkedStream", func(t *testing.T) { bodySize := 1024 body := mock.CreateFixedBody(bodySize) reader := mock.NewZeroCopyReader(string(body)) dst, err := ReadBodyWithStreaming(reader, -1, bodySize, nil) assert.True(t, errors.Is(err, errChunkedStream)) assert.Nil(t, dst) }) } func TestBodyStream(t *testing.T) { t.Run("TestBodyStreamPrereadBuffer", func(t *testing.T) { bodySize := 1024 body := mock.CreateFixedBody(bodySize) byteBuffer := &bytebufferpool.ByteBuffer{} byteBuffer.Set(body) bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(""), nil, len(body)) defer func() { ReleaseBodyStream(bs) }() b := make([]byte, bodySize) err := bodyStreamRead(bs, b) assert.Nil(t, err) assert.DeepEqual(t, len(body), len(b)) assert.DeepEqual(t, string(body), string(b)) }) t.Run("TestBodyStreamRelease", func(t *testing.T) { bodySize := 1024 body := mock.CreateFixedBody(bodySize) byteBuffer := &bytebufferpool.ByteBuffer{} byteBuffer.Set(body) bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(body)), nil, bodySize*2) err := ReleaseBodyStream(bs) assert.Nil(t, err) }) t.Run("TestBodyStreamChunked", func(t *testing.T) { bodySize := 5 body := mock.CreateFixedBody(bodySize) expectedTrailer := map[string]string{"Foo": 
"chunked shit"} chunkedBody := mock.CreateChunkedBody(body, expectedTrailer, true) byteBuffer := &bytebufferpool.ByteBuffer{} byteBuffer.Set(chunkedBody) bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(chunkedBody)), &protocol.Trailer{}, -1) defer func() { ReleaseBodyStream(bs) }() b := make([]byte, bodySize) err := bodyStreamRead(bs, b) assert.Nil(t, err) assert.DeepEqual(t, len(body), len(b)) assert.DeepEqual(t, string(body), string(b)) }) t.Run("TestBodyStreamReadFromWire", func(t *testing.T) { bodySize := 1024 body := mock.CreateFixedBody(bodySize) byteBuffer := &bytebufferpool.ByteBuffer{} byteBuffer.Set(body) rcBodySize := 128 rcBody := mock.CreateFixedBody(rcBodySize) bs := AcquireBodyStream(byteBuffer, mock.NewSlowReadConn(string(rcBody)), nil, -2) defer func() { ReleaseBodyStream(bs) }() b := make([]byte, bodySize) err := bodyStreamRead(bs, b) assert.Nil(t, err) assert.DeepEqual(t, len(body), len(b)) assert.DeepEqual(t, string(body), string(b)) }) } func bodyStreamRead(bs io.Reader, b []byte) (err error) { nb := 0 for { p := make([]byte, 64) n, rErr := bs.Read(p) if n > 0 { copy(b[nb:], p[:]) nb = nb + n } if rErr != nil { if rErr != io.EOF { err = rErr } break } } return } ================================================ FILE: pkg/protocol/http1/factory/client.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package factory import ( "github.com/cloudwego/hertz/pkg/protocol/client" "github.com/cloudwego/hertz/pkg/protocol/http1" "github.com/cloudwego/hertz/pkg/protocol/suite" ) var _ suite.ClientFactory = (*clientFactory)(nil) type clientFactory struct { option *http1.ClientOptions } func (s *clientFactory) NewHostClient() (client client.HostClient, err error) { return http1.NewHostClient(s.option), nil } func NewClientFactory(option *http1.ClientOptions) suite.ClientFactory { return &clientFactory{ option: option, } } ================================================ FILE: pkg/protocol/http1/factory/server.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package factory import ( "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/http1" "github.com/cloudwego/hertz/pkg/protocol/suite" ) var _ suite.ServerFactory = (*serverFactory)(nil) type serverFactory struct { option *http1.Option } // New is called by Hertz during engine.Run() func (s *serverFactory) New(core suite.Core) (server protocol.Server, err error) { serv := http1.NewServer() serv.Option = *s.option serv.Core = core return serv, nil } func NewServerFactory(option *http1.Option) suite.ServerFactory { return &serverFactory{ option: option, } } ================================================ FILE: pkg/protocol/http1/proxy/proxy.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * Copyright 2016 The Go Authors. All rights reserved. * Use of this source code is governed by a BSD-style * license that can be found in the LICENSE file. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package proxy import ( "bytes" "context" "crypto/tls" "encoding/base64" "time" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" reqI "github.com/cloudwego/hertz/pkg/protocol/http1/req" respI "github.com/cloudwego/hertz/pkg/protocol/http1/resp" ) func SetupProxy(conn network.Conn, addr string, proxyURI *protocol.URI, tlsConfig *tls.Config, isTLS bool, dialer network.Dialer) (network.Conn, error) { var err error if bytes.Equal(proxyURI.Scheme(), bytestr.StrHTTPS) { conn, err = dialer.AddTLS(conn, tlsConfig) if err != nil { return nil, err } } switch { case proxyURI == nil: // Do nothing. Not using a proxy. case isTLS: // target addr is https connectReq, connectResp := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { protocol.ReleaseRequest(connectReq) protocol.ReleaseResponse(connectResp) }() SetProxyAuthHeader(&connectReq.Header, proxyURI) connectReq.SetMethod(consts.MethodConnect) connectReq.SetHost(addr) // Skip response body when send CONNECT request. connectResp.SkipBody = true // If there's no done channel (no deadline or cancellation // from the caller possible), at least set some (long) // timeout here. This will make sure we don't block forever // and leak a goroutine if the connection stops replying // after the TCP connect. connectCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails // Write the CONNECT request & read the response. 
go func() { defer close(didReadResponse) err = reqI.Write(connectReq, conn) if err != nil { return } err = conn.Flush() if err != nil { return } err = respI.Read(connectResp, conn) }() select { case <-connectCtx.Done(): conn.Close() <-didReadResponse return nil, connectCtx.Err() case <-didReadResponse: } if err != nil { conn.Close() return nil, err } if connectResp.StatusCode() != consts.StatusOK { conn.Close() return nil, errors.NewPublic(consts.StatusMessage(connectResp.StatusCode())) } } if proxyURI != nil && isTLS { conn, err = dialer.AddTLS(conn, tlsConfig) if err != nil { return nil, err } } return conn, nil } func SetProxyAuthHeader(h *protocol.RequestHeader, proxyURI *protocol.URI) { if username := proxyURI.Username(); username != nil { password := proxyURI.Password() auth := base64.StdEncoding.EncodeToString(bytesconv.S2b(bytesconv.B2s(username) + ":" + bytesconv.B2s(password))) h.Set("Proxy-Authorization", "Basic "+auth) } } ================================================ FILE: pkg/protocol/http1/req/header.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package req

import (
	"bytes"
	"errors"
	"fmt"
	"io"

	"github.com/cloudwego/hertz/internal/bytesconv"
	"github.com/cloudwego/hertz/internal/bytestr"
	errs "github.com/cloudwego/hertz/pkg/common/errors"
	"github.com/cloudwego/hertz/pkg/common/utils"
	"github.com/cloudwego/hertz/pkg/network"
	"github.com/cloudwego/hertz/pkg/protocol"
	"github.com/cloudwego/hertz/pkg/protocol/consts"
	"github.com/cloudwego/hertz/pkg/protocol/http1/ext"
)

// errEOFReadHeader is returned by tryReadWithLimit when the reader hits io.EOF
// with nothing to peek after at least one earlier, larger-than-one-byte read
// attempt, i.e. the connection ended in the middle of a request header.
var errEOFReadHeader = errs.NewPublic("error when reading request headers: EOF")

// WriteHeader writes the serialized request header h to w.
func WriteHeader(h *protocol.RequestHeader, w network.Writer) error { header := h.Header() _, err := w.WriteBinary(header) return err } func ReadHeader(h *protocol.RequestHeader, r network.Reader) error { return ReadHeaderWithLimit(h, r, 0) } func ReadHeaderWithLimit(h *protocol.RequestHeader, r network.Reader, maxHeaderBytes int) error { n := 1 for { err := tryReadWithLimit(h, r, n, maxHeaderBytes) if err == nil { return nil } if !errors.Is(err, errs.ErrNeedMore) { h.ResetSkipNormalize() return err } // No more data available on the wire, try block peek if n == r.Len() { n++ continue } n = r.Len() } } func tryReadWithLimit(h *protocol.RequestHeader, r network.Reader, n, maxHeaderBytes int) error { h.ResetSkipNormalize() b, err := r.Peek(n) if len(b) == 0 { if err != io.EOF { return err } // n == 1 on the first read for the request. if n == 1 { // We didn't read a single byte. return errs.New(errs.ErrNothingRead, errs.ErrorTypePrivate, err) } return errEOFReadHeader } b = ext.MustPeekBuffered(r) if maxHeaderBytes > 0 && len(b) > maxHeaderBytes { b = b[:maxHeaderBytes] } headersLen, errParse := parse(h, b) if errParse != nil { if maxHeaderBytes > 0 && len(b) >= maxHeaderBytes && errors.Is(errParse, errs.ErrNeedMore) { return errHeaderTooLarge } return ext.HeaderError("request", err, errParse, b) } ext.MustDiscard(r, headersLen) return nil } func parse(h *protocol.RequestHeader, buf []byte) (int, error) { m, err := parseFirstLine(h, buf) if err != nil { return 0, err } rawHeaders, _, err := ext.ReadRawHeaders(h.RawHeaders()[:0], buf[m:]) h.SetRawHeaders(rawHeaders) if err != nil { return 0, err } n, err := parseHeaders(h, buf[m:]) if err != nil { return 0, err } return m + n, nil } const ( maxCheckMethodLen = 10 // reuse ValidHeaderFieldNameTable for Method, both are `token` // see: // https://www.rfc-editor.org/rfc/rfc9110.html#name-methods // https://www.rfc-editor.org/rfc/rfc9110.html#name-field-names validMethodCharTable = bytesconv.ValidHeaderFieldNameTable ) 
var errMalformedHTTPRequest = errors.New("malformed HTTP request") // request-line = method SP request-target SP HTTP-version CRLF func parseFirstLine(h *protocol.RequestHeader, buf []byte) (int, error) { b, leftb, err := utils.NextLine(buf) if err != nil { // errs.ErrNeedMore? // check malformed HTTP request before reading more data // NOTE: // only check method bytes if errs.ErrNeedMore for closing malformed connections. // for performance concern, it won't be checked in the hot path. for i, c := range buf { if c == ' ' || i > maxCheckMethodLen { break // skip if SP or reach maxCheckMethodLen } if validMethodCharTable[c] == 0 { return 0, errMalformedHTTPRequest } } return 0, err } // parse method n := bytes.IndexByte(b, ' ') if n <= 0 { return 0, errMalformedHTTPRequest } h.SetMethodBytes(b[:n]) b = b[n+1:] // parse request-target (uri) n = bytes.IndexByte(b, ' ') if n <= 0 { return 0, errMalformedHTTPRequest } h.SetRequestURIBytes(b[:n]) b = b[n+1:] // parse http protocol switch string(b) { case consts.HTTP11: // likely HTTP/1.1 h.SetProtocol(consts.HTTP11) case consts.HTTP10: h.SetProtocol(consts.HTTP10) default: if len(b) < 5 || string(b[:5]) != "HTTP/" { return 0, errMalformedHTTPRequest } // XXX: all other cases are considered to be HTTP/1.0 for safe h.SetProtocol(consts.HTTP10) } return len(buf) - len(leftb), nil } // validHeaderFieldValue is equal to httpguts.ValidHeaderFieldValue(shares the same context) func validHeaderFieldValue(val []byte) bool { for _, v := range val { if bytesconv.ValidHeaderFieldValueTable[v] == 0 { return false } } return true } func parseHeaders(h *protocol.RequestHeader, buf []byte) (int, error) { h.InitContentLengthWithValue(-2) var s ext.HeaderScanner s.B = buf s.DisableNormalizing = h.IsDisableNormalizing() var err error for s.Next() { if len(s.Key) > 0 { // Spaces between the header key and colon are not allowed. // See RFC 7230, Section 3.2.4. 
if bytes.IndexByte(s.Key, ' ') != -1 || bytes.IndexByte(s.Key, '\t') != -1 { err = fmt.Errorf("invalid header key %q", s.Key) return 0, err } // Check the invalid chars in header value if !validHeaderFieldValue(s.Value) { err = fmt.Errorf("invalid header value %q", s.Value) return 0, err } switch s.Key[0] | 0x20 { case 'h': if utils.CaseInsensitiveCompare(s.Key, bytestr.StrHost) { h.SetHostBytes(s.Value) continue } case 'u': if utils.CaseInsensitiveCompare(s.Key, bytestr.StrUserAgent) { h.SetUserAgentBytes(s.Value) continue } case 'c': if utils.CaseInsensitiveCompare(s.Key, bytestr.StrContentType) { h.SetContentTypeBytes(s.Value) continue } if utils.CaseInsensitiveCompare(s.Key, bytestr.StrContentLength) { if h.ContentLength() != -1 { var nerr error var contentLength int if contentLength, nerr = protocol.ParseContentLength(s.Value); nerr != nil { if err == nil { err = nerr } h.InitContentLengthWithValue(-2) } else { h.InitContentLengthWithValue(contentLength) h.SetContentLengthBytes(s.Value) } } continue } if utils.CaseInsensitiveCompare(s.Key, bytestr.StrConnection) { if bytes.Equal(s.Value, bytestr.StrClose) { h.SetConnectionClose(true) } else { h.SetConnectionClose(false) h.AddArgBytes(s.Key, s.Value, protocol.ArgsHasValue) } continue } case 't': if utils.CaseInsensitiveCompare(s.Key, bytestr.StrTransferEncoding) { if !bytes.Equal(s.Value, bytestr.StrIdentity) { h.InitContentLengthWithValue(-1) h.SetArgBytes(bytestr.StrTransferEncoding, bytestr.StrChunked, protocol.ArgsHasValue) } continue } if utils.CaseInsensitiveCompare(s.Key, bytestr.StrTrailer) { if nerr := h.Trailer().SetTrailers(s.Value); nerr != nil { if err == nil { err = nerr } } continue } } } h.AddArgBytes(s.Key, s.Value, protocol.ArgsHasValue) } if s.Err != nil && err == nil { err = s.Err } if err != nil { h.SetConnectionClose(true) return 0, err } if h.ContentLength() < 0 { h.SetContentLengthBytes(h.ContentLengthBytes()[:0]) } if !h.IsHTTP11() && !h.ConnectionClose() { // close connection for 
non-http/1.1 request unless 'Connection: keep-alive' is set. v := h.PeekArgBytes(bytestr.StrConnection) h.SetConnectionClose(!ext.HasHeaderValue(v, bytestr.StrKeepAlive)) } return s.HLen, nil } ================================================ FILE: pkg/protocol/http1/req/header_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package req import ( "bufio" "bytes" "errors" "fmt" "strings" "testing" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/netpoll" ) func TestRequestHeader_Read(t *testing.T) { s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nUser-Agent: foo\r\nHost: 127.0.0.1\r\nConnection: Keep-Alive\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) rh := protocol.RequestHeader{} ReadHeader(&rh, zr) // firstline assert.DeepEqual(t, []byte(consts.MethodPut), rh.Method()) assert.DeepEqual(t, []byte("/foo/bar"), rh.RequestURI()) assert.True(t, rh.IsHTTP11()) // headers assert.DeepEqual(t, 5, rh.ContentLength()) assert.DeepEqual(t, []byte("foo/bar"), rh.ContentType()) count := 0 rh.VisitAll(func(key, value []byte) { count += 1 }) assert.DeepEqual(t, 6, count) assert.DeepEqual(t, []byte("foo"), rh.UserAgent()) assert.DeepEqual(t, []byte("127.0.0.1"), rh.Host()) assert.DeepEqual(t, []byte("100-continue"), rh.Peek("Expect")) } func TestRequestHeader_Peek(t *testing.T) { s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nUser-Agent: foo\r\nHost: 127.0.0.1\r\nConnection: Keep-Alive\r\nContent-Length: 5\r\nTransfer-Encoding: foo\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) rh := protocol.RequestHeader{} ReadHeader(&rh, zr) assert.DeepEqual(t, []byte("100-continue"), rh.Peek("Expect")) 
assert.DeepEqual(t, []byte("127.0.0.1"), rh.Peek("Host")) assert.DeepEqual(t, []byte("foo"), rh.Peek("User-Agent")) assert.DeepEqual(t, []byte("Keep-Alive"), rh.Peek("Connection")) assert.DeepEqual(t, []byte(""), rh.Peek("Content-Length")) assert.DeepEqual(t, []byte("foo/bar"), rh.Peek("Content-Type")) } func TestRequestHeaderSetGet(t *testing.T) { t.Parallel() h := &protocol.RequestHeader{} h.SetRequestURI("/aa/bbb") h.SetMethod(consts.MethodPost) h.Set("foo", "bar") h.Set("host", "12345") h.Set("content-type", "aaa/bbb") h.Set("content-length", "1234") h.Set("user-agent", "aaabbb") h.Set("referer", "axcv") h.Set("baz", "xxxxx") h.Set("transfer-encoding", "chunked") h.Set("connection", "close") expectRequestHeaderGet(t, h, "Foo", "bar") expectRequestHeaderGet(t, h, consts.HeaderHost, "12345") expectRequestHeaderGet(t, h, consts.HeaderContentType, "aaa/bbb") expectRequestHeaderGet(t, h, consts.HeaderContentLength, "1234") expectRequestHeaderGet(t, h, "USER-AGent", "aaabbb") expectRequestHeaderGet(t, h, consts.HeaderReferer, "axcv") expectRequestHeaderGet(t, h, "baz", "xxxxx") expectRequestHeaderGet(t, h, consts.HeaderTransferEncoding, "") expectRequestHeaderGet(t, h, "connecTION", "close") if !h.ConnectionClose() { t.Fatalf("unset connection: close") } if h.ContentLength() != 1234 { t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength(), 1234) } w := &bytes.Buffer{} bw := bufio.NewWriter(w) zw := netpoll.NewWriter(bw) err := WriteHeader(h, zw) if err != nil { t.Fatalf("Unexpected error when writing request header: %s", err) } if err := bw.Flush(); err != nil { t.Fatalf("Unexpected error when flushing request header: %s", err) } zw.Flush() bw.Flush() var h1 protocol.RequestHeader br := bufio.NewReader(w) zr := mock.ZeroCopyReader{Reader: br} if err = ReadHeader(&h1, zr); err != nil { t.Fatalf("Unexpected error when reading request header: %s", err) } if h1.ContentLength() != h.ContentLength() { t.Fatalf("Unexpected Content-Length %d. 
Expected %d", h1.ContentLength(), h.ContentLength()) } expectRequestHeaderGet(t, &h1, "Foo", "bar") expectRequestHeaderGet(t, &h1, "HOST", "12345") expectRequestHeaderGet(t, &h1, consts.HeaderContentType, "aaa/bbb") expectRequestHeaderGet(t, &h1, consts.HeaderContentLength, "1234") expectRequestHeaderGet(t, &h1, "USER-AGent", "aaabbb") expectRequestHeaderGet(t, &h1, consts.HeaderReferer, "axcv") expectRequestHeaderGet(t, &h1, "baz", "xxxxx") expectRequestHeaderGet(t, &h1, consts.HeaderTransferEncoding, "") expectRequestHeaderGet(t, &h1, consts.HeaderConnection, "close") if !h1.ConnectionClose() { t.Fatalf("unset connection: close") } } func TestRequestHeaderCookie(t *testing.T) { t.Parallel() var h protocol.RequestHeader h.SetRequestURI("/foobar") h.Set(consts.HeaderHost, "foobar.com") h.SetCookie("foo", "bar") h.SetCookie("привет", "мир") if string(h.Cookie("foo")) != "bar" { t.Fatalf("Unexpected cookie value %q. Expected %q", h.Cookie("foo"), "bar") } if string(h.Cookie("привет")) != "мир" { t.Fatalf("Unexpected cookie value %q. Expected %q", h.Cookie("привет"), "мир") } w := &bytes.Buffer{} zw := netpoll.NewWriter(w) if err := WriteHeader(&h, zw); err != nil { t.Fatalf("Unexpected error: %s", err) } if err := zw.Flush(); err != nil { t.Fatalf("Unexpected error: %s", err) } var h1 protocol.RequestHeader br := bufio.NewReader(w) zr := mock.ZeroCopyReader{Reader: br} if err := ReadHeader(&h1, zr); err != nil { t.Fatalf("Unexpected error: %s", err) } if !bytes.Equal(h1.Cookie("foo"), h.Cookie("foo")) { t.Fatalf("Unexpected cookie value %q. Expected %q", h1.Cookie("foo"), h.Cookie("foo")) } h1.DelCookie("foo") if len(h1.Cookie("foo")) > 0 { t.Fatalf("Unexpected cookie found: %q", h1.Cookie("foo")) } if !bytes.Equal(h1.Cookie("привет"), h.Cookie("привет")) { t.Fatalf("Unexpected cookie value %q. 
Expected %q", h1.Cookie("привет"), h.Cookie("привет")) } h1.DelCookie("привет") if len(h1.Cookie("привет")) > 0 { t.Fatalf("Unexpected cookie found: %q", h1.Cookie("привет")) } h.DelAllCookies() if len(h.Cookie("foo")) > 0 { t.Fatalf("Unexpected cookie found: %q", h.Cookie("foo")) } if len(h.Cookie("привет")) > 0 { t.Fatalf("Unexpected cookie found: %q", h.Cookie("привет")) } } func TestRequestRawHeaders(t *testing.T) { t.Parallel() kvs := "hOsT: foobar\r\n" + "value: b\r\n" + "\r\n" t.Run("normalized", func(t *testing.T) { s := "GET / HTTP/1.1\r\n" + kvs exp := kvs var h protocol.RequestHeader zr := mock.NewZeroCopyReader(s) if err := ReadHeader(&h, zr); err != nil { t.Fatalf("unexpected error: %s", err) } if string(h.Host()) != "foobar" { t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar") } v2 := h.Peek("Value") if !bytes.Equal(v2, []byte{'b'}) { t.Fatalf("expecting non empty value. Got %q", v2) } if raw := h.RawHeaders(); string(raw) != exp { t.Fatalf("expected header %q, got %q", exp, raw) } }) for _, n := range []int{0, 1, 4, 8} { t.Run(fmt.Sprintf("post-%dk", n), func(t *testing.T) { l := 1024 * n body := make([]byte, l) for i := range body { body[i] = 'a' } cl := fmt.Sprintf("Content-Length: %d\r\n", l) s := "POST / HTTP/1.1\r\n" + cl + kvs + string(body) exp := cl + kvs var h protocol.RequestHeader zr := mock.NewZeroCopyReader(s) if err := ReadHeader(&h, zr); err != nil { t.Fatalf("unexpected error: %s", err) } if string(h.Host()) != "foobar" { t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar") } v2 := h.Peek("Value") if !bytes.Equal(v2, []byte{'b'}) { t.Fatalf("expecting non empty value. 
Got %q", v2) } if raw := h.RawHeaders(); string(raw) != exp { t.Fatalf("expected header %q, got %q", exp, raw) } }) } t.Run("http10", func(t *testing.T) { s := "GET / HTTP/1.0\r\n" + kvs exp := kvs var h protocol.RequestHeader zr := mock.NewZeroCopyReader(s) if err := ReadHeader(&h, zr); err != nil { t.Fatalf("unexpected error: %s", err) } if string(h.Host()) != "foobar" { t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar") } v2 := h.Peek("Value") if !bytes.Equal(v2, []byte{'b'}) { t.Fatalf("expecting non empty value. Got %q", v2) } if raw := h.RawHeaders(); string(raw) != exp { t.Fatalf("expected header %q, got %q", exp, raw) } }) t.Run("no-kvs", func(t *testing.T) { s := "GET / HTTP/1.1\r\n\r\n" exp := "" var h protocol.RequestHeader h.DisableNormalizing() zr := mock.NewZeroCopyReader(s) if err := ReadHeader(&h, zr); err != nil { t.Fatalf("unexpected error: %s", err) } if string(h.Host()) != "" { t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "") } v1 := h.Peek("NoKey") if len(v1) > 0 { t.Fatalf("expecting empty value. Got %q", v1) } if raw := h.RawHeaders(); string(raw) != exp { t.Fatalf("expected header %q, got %q", exp, raw) } }) } func TestRequestHeaderEmptyValueFromHeader(t *testing.T) { t.Parallel() var h1 protocol.RequestHeader h1.SetRequestURI("/foo/bar") h1.SetHost("foobar") h1.Set("EmptyValue1", "") h1.Set("EmptyValue2", " ") s := h1.String() var h protocol.RequestHeader zr := mock.NewZeroCopyReader(s) if err := ReadHeader(&h, zr); err != nil { t.Fatalf("unexpected error: %s", err) } if string(h.Host()) != string(h1.Host()) { t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), h1.Host()) } v1 := h.Peek("EmptyValue1") if len(v1) > 0 { t.Fatalf("expecting empty value. Got %q", v1) } v2 := h.Peek("EmptyValue2") if len(v2) > 0 { t.Fatalf("expecting empty value. 
Got %q", v2) } } func TestRequestHeaderEmptyValueFromString(t *testing.T) { t.Parallel() s := "GET / HTTP/1.1\r\n" + "EmptyValue1:\r\n" + "Host: foobar\r\n" + "EmptyValue2: \r\n" + "\r\n" var h protocol.RequestHeader zr := mock.NewZeroCopyReader(s) if err := ReadHeader(&h, zr); err != nil { t.Fatalf("unexpected error: %s", err) } if string(h.Host()) != "foobar" { t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar") } v1 := h.Peek("EmptyValue1") if len(v1) > 0 { t.Fatalf("expecting empty value. Got %q", v1) } v2 := h.Peek("EmptyValue2") if len(v2) > 0 { t.Fatalf("expecting empty value. Got %q", v2) } } func expectRequestHeaderGet(t *testing.T, h *protocol.RequestHeader, key, expectedValue string) { if string(h.Peek(key)) != expectedValue { t.Fatalf("Unexpected value for key %q: %q. Expected %q", key, h.Peek(key), expectedValue) } } func TestRequestHeader_PeekIfExists(t *testing.T) { s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nexists: \r\nContent-Type: foo/bar\r\n\r\nabcdef4343" rh := protocol.RequestHeader{} err := ReadHeader(&rh, mock.NewZeroCopyReader(s)) if err != nil { t.Fatal(err) } assert.DeepEqual(t, []byte{}, rh.Peek("exists")) assert.DeepEqual(t, []byte(nil), rh.Peek("non-exists")) } func TestRequestHeaderError(t *testing.T) { er := mock.EOFReader{} rh := protocol.RequestHeader{} err := ReadHeader(&rh, &er) assert.True(t, errors.Is(err, errs.ErrNothingRead)) } func TestReadHeader(t *testing.T) { s := "P" zr := mock.NewZeroCopyReader(s) rh := protocol.RequestHeader{} err := ReadHeader(&rh, zr) assert.NotNil(t, err) } func TestParseHeaders(t *testing.T) { rh := protocol.RequestHeader{} _, err := parseHeaders(&rh, []byte{' '}) assert.NotNil(t, err) } func TestTryRead(t *testing.T) { rh := protocol.RequestHeader{} s := "P" zr := mock.NewZeroCopyReader(s) err := tryReadWithLimit(&rh, zr, 0, 0) assert.Nil(t, err) } func TestReadHeaderWithLimit(t *testing.T) { validRequest := "GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n" zr := 
mock.NewZeroCopyReader(validRequest) h := &protocol.RequestHeader{} err := ReadHeaderWithLimit(h, zr, 0) assert.Nil(t, err) assert.DeepEqual(t, string(h.Method()), "GET") } func TestReadHeaderWithLimitExceeded(t *testing.T) { largeRequest := "GET /path HTTP/1.1\r\nHost: example.com\r\nLarge-Header: " + strings.Repeat("x", 100) + "\r\n\r\n" zr := mock.NewZeroCopyReader(largeRequest) h := &protocol.RequestHeader{} err := ReadHeaderWithLimit(h, zr, 50) assert.NotNil(t, err) } func TestParseFirstLine(t *testing.T) { tests := []struct { name string input []byte method string uri string protocol string err error }{ { name: "case: normal", input: []byte("GET /path/to/resource HTTP/1.0\r\n"), method: "GET", uri: "/path/to/resource", protocol: "HTTP/1.0", }, { name: "case: empty uri", input: []byte("GET HTTP/1.1\r\n"), err: errMalformedHTTPRequest, }, { name: "case: unknown protocol should use HTTP/1.0", input: []byte("POST /path/to/resource HTTP/1.2\r\n"), method: "POST", uri: "/path/to/resource", protocol: "HTTP/1.0", }, { name: "case: invalid protocol", input: []byte("POST /path/to/resource XTTP/1.1\r\n"), err: errMalformedHTTPRequest, }, { name: "case: input too large", input: make([]byte, 9<<10), err: errMalformedHTTPRequest, }, { name: "case: method invalid", input: []byte("< / HTTP/1."), err: errMalformedHTTPRequest, }, { name: "case: need more err", input: []byte("GET / HTTP/1."), err: errs.ErrNeedMore, }, } for _, tc := range tests { h := &protocol.RequestHeader{} _, err := parseFirstLine(h, tc.input) if tc.err != nil { assert.Assert(t, errors.Is(err, tc.err), tc.name, err) continue } assert.Assert(t, err == nil, tc.name, err) if string(h.Method()) != tc.method || string(h.RequestURI()) != tc.uri || h.GetProtocol() != tc.protocol { t.Fatal(tc.name, "got", h.String()) } } } func TestParse(t *testing.T) { tests := []struct { name string input []byte expected int wantErr bool }{ // normal test { name: "normal", input: []byte("GET /path/to/resource HTTP/1.1\r\nHost: 
example.com\r\n\r\n"), expected: len([]byte("GET /path/to/resource HTTP/1.1\r\nHost: example.com\r\n\r\n")), wantErr: false, }, // parseFirstLine error { name: "parseFirstLine error", input: []byte("INVALID_LINE\r\nHost: example.com\r\n\r\n"), expected: 0, wantErr: true, }, // ext.ReadRawHeaders error { name: "ext.ReadRawHeaders error", input: []byte("GET /path/to/resource HTTP/1.1\r\nINVALID_HEADER\r\n\r\n"), expected: 0, wantErr: true, }, // parseHeaders error { name: "parseHeaders error", input: []byte("GET /path/to/resource HTTP/1.1\r\nHost: example.com\r\nINVALID_HEADER\r\n"), expected: 0, wantErr: true, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { header := &protocol.RequestHeader{} bytesRead, err := parse(header, tc.input) if (err != nil) != tc.wantErr { t.Errorf("Expected error: %v, but got: %v", tc.wantErr, err) } if bytesRead != tc.expected { t.Errorf("Expected bytes read: %d, but got: %d", tc.expected, bytesRead) } }) } } ================================================ FILE: pkg/protocol/http1/req/request.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package req import ( "bytes" "encoding/base64" "errors" "fmt" "io" "mime/multipart" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/ext" ) var ( errRequestHostRequired = errs.NewPublic("missing required Host header in request") errGetOnly = errs.NewPublic("non-GET request received") errBodyTooLarge = errs.New(errs.ErrBodyTooLarge, errs.ErrorTypePublic, "http1/req") errHeaderTooLarge = errs.New(errs.ErrHeaderTooLarge, errs.ErrorTypePublic, "http1/req") ) type h1Request struct { *protocol.Request } // String returns request representation. // // Returns error message instead of request representation on error. // // Use Write instead of String for performance-critical code. func (h1Req *h1Request) String() string { w := bytebufferpool.Get() zw := network.NewWriter(w) if err := Write(h1Req.Request, zw); err != nil { return err.Error() } if err := zw.Flush(); err != nil { return err.Error() } s := string(w.B) bytebufferpool.Put(w) return s } func GetHTTP1Request(req *protocol.Request) fmt.Stringer { return &h1Request{req} } // ReadHeaderAndLimitBody reads request from the given r, limiting the body size. // // If maxBodySize > 0 and the body size exceeds maxBodySize, // then errBodyTooLarge is returned. // // RemoveMultipartFormFiles or Reset must be called after // reading multipart/form-data request in order to delete temporarily // uploaded files. // // If MayContinue returns true, the caller must: // // - Either send StatusExpectationFailed response if request headers don't // satisfy the caller. // - Or send StatusContinue response before reading request body // with ContinueReadBody. // - Or close the connection. 
// // io.EOF is returned if r is closed before reading the first header byte. func ReadHeaderAndLimitBody(req *protocol.Request, r network.Reader, maxBodySize int, preParse ...bool) error { var parse bool if len(preParse) == 0 { parse = true } else { parse = preParse[0] } req.ResetSkipHeader() if err := ReadHeader(&req.Header, r); err != nil { return err } return ReadLimitBody(req, r, maxBodySize, false, parse) } // Read reads request (including body) from the given r. // // RemoveMultipartFormFiles or Reset must be called after // reading multipart/form-data request in order to delete temporarily // uploaded files. // // If MayContinue returns true, the caller must: // // - Either send StatusExpectationFailed response if request headers don't // satisfy the caller. // - Or send StatusContinue response before reading request body // with ContinueReadBody. // - Or close the connection. // // io.EOF is returned if r is closed before reading the first header byte. func Read(req *protocol.Request, r network.Reader, preParse ...bool) error { return ReadHeaderAndLimitBody(req, r, 0, preParse...) } // Write writes request to w. // // Write doesn't flush request to w for performance reasons. // // See also WriteTo. func Write(req *protocol.Request, w network.Writer) error { return write(req, w, false) } // ProxyWrite is like Write but writes the request in the form // expected by an HTTP proxy. In particular, ProxyWrite writes the // initial Request-URI line of the request with an absolute URI, per // section 5.3 of RFC 7230, including the scheme and host. func ProxyWrite(req *protocol.Request, w network.Writer) error { return write(req, w, true) } // write writes request to w. // It supports proxy situation. 
func write(req *protocol.Request, w network.Writer, usingProxy bool) error { if len(req.Header.Host()) == 0 || req.IsURIParsed() { uri := req.URI() host := uri.Host() if len(host) == 0 { return errRequestHostRequired } if len(req.Header.Host()) == 0 { req.Header.SetHostBytes(host) } ruri := uri.RequestURI() if bytes.Equal(req.Method(), bytestr.StrConnect) { ruri = uri.Host() } else if usingProxy { ruri = uri.FullURI() } req.Header.SetRequestURIBytes(ruri) if len(uri.Username()) > 0 { // RequestHeader.SetBytesKV only uses RequestHeader.bufKV.key // So we are free to use RequestHeader.bufKV.value as a scratch pad for // the base64 encoding. nl := len(uri.Username()) + len(uri.Password()) + 1 nb := nl + len(bytestr.StrBasicSpace) tl := nb + base64.StdEncoding.EncodedLen(nl) req.Header.InitBufValue(tl) buf := req.Header.GetBufValue()[:0] buf = append(buf, uri.Username()...) buf = append(buf, bytestr.StrColon...) buf = append(buf, uri.Password()...) buf = append(buf, bytestr.StrBasicSpace...) 
base64.StdEncoding.Encode(buf[nb:tl], buf[:nl]) req.Header.SetBytesKV(bytestr.StrAuthorization, buf[nl:tl]) } } if req.IsBodyStream() { return writeBodyStream(req, w) } body := req.BodyBytes() err := handleMultipart(req) if err != nil { return fmt.Errorf("error when handle multipart: %s", err) } if req.OnlyMultipartForm() { m, _ := req.MultipartForm() // req.multipartForm != nil body, err = protocol.MarshalMultipartForm(m, req.MultipartFormBoundary()) if err != nil { return fmt.Errorf("error when marshaling multipart form: %s", err) } req.Header.SetMultipartFormBoundary(req.MultipartFormBoundary()) } hasBody := false if len(body) == 0 { body = req.PostArgString() } if len(body) != 0 || !req.Header.IgnoreBody() { hasBody = true req.Header.SetContentLength(len(body)) } header := req.Header.Header() if _, err := w.WriteBinary(header); err != nil { return err } // Write body if hasBody { w.WriteBinary(body) //nolint:errcheck } else if len(body) > 0 { return fmt.Errorf("non-zero body for non-POST request. body=%q", body) } return nil } // ContinueReadBodyStream reads request body in stream if request header contains // 'Expect: 100-continue'. // // The caller must send StatusContinue response before calling this method. // // If maxBodySize > 0 and the body size exceeds maxBodySize, // then errBodyTooLarge is returned. func ContinueReadBodyStream(req *protocol.Request, zr network.Reader, maxBodySize int, preParseMultipartForm ...bool) error { var err error contentLength := req.Header.ContentLength() if contentLength > 0 { if len(preParseMultipartForm) == 0 || preParseMultipartForm[0] { // Pre-read multipart form data of known length. // This way we limit memory usage for large file uploads, since their contents // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize. 
req.SetMultipartFormBoundary(string(req.Header.MultipartFormBoundary())) if len(req.MultipartFormBoundary()) > 0 && len(req.Header.PeekContentEncoding()) == 0 { err := protocol.ParseMultipartForm(zr.(io.Reader), req, contentLength, consts.DefaultMaxInMemoryFileSize) if err != nil { req.Reset() } return err } } } if contentLength == -2 { // identity body has no sense for http requests, since // the end of body is determined by connection close. // So just ignore request body for requests without // 'Content-Length' and 'Transfer-Encoding' headers. // refer to https://tools.ietf.org/html/rfc7230#section-3.3.2 if !req.Header.IgnoreBody() { req.Header.SetContentLength(0) } return nil } bodyBuf := req.BodyBuffer() bodyBuf.Reset() bodyBuf.B, err = ext.ReadBodyWithStreaming(zr, contentLength, maxBodySize, bodyBuf.B) if err != nil { if errors.Is(err, errs.ErrBodyTooLarge) { req.Header.SetContentLength(contentLength) req.ConstructBodyStream(bodyBuf, ext.AcquireBodyStream(bodyBuf, zr, req.Header.Trailer(), contentLength)) return nil } if errors.Is(err, errs.ErrChunkedStream) { req.ConstructBodyStream(bodyBuf, ext.AcquireBodyStream(bodyBuf, zr, req.Header.Trailer(), contentLength)) return nil } req.Reset() return err } req.ConstructBodyStream(bodyBuf, ext.AcquireBodyStream(bodyBuf, zr, req.Header.Trailer(), contentLength)) return nil } func ContinueReadBody(req *protocol.Request, r network.Reader, maxBodySize int, preParseMultipartForm ...bool) error { var err error contentLength := req.Header.ContentLength() if contentLength > 0 { if maxBodySize > 0 && contentLength > maxBodySize { return errBodyTooLarge } if len(preParseMultipartForm) == 0 || preParseMultipartForm[0] { // Pre-read multipart form data of known length. // This way we limit memory usage for large file uploads, since their contents // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize. 
req.SetMultipartFormBoundary(string(req.Header.MultipartFormBoundary())) if len(req.MultipartFormBoundary()) > 0 && len(req.Header.PeekContentEncoding()) == 0 { err := protocol.ParseMultipartForm(r.(io.Reader), req, contentLength, consts.DefaultMaxInMemoryFileSize) if err != nil { req.Reset() } return err } } // This optimization is just suitable for ping-pong case and the ext.ReadBody is // a common function, so we just handle this situation before ext.ReadBody buf, err := r.Peek(contentLength) if err != nil { return err } r.Skip(contentLength) // nolint: errcheck req.SetBodyRaw(buf) return nil } if contentLength == -2 { // identity body has no sense for http requests, since // the end of body is determined by connection close. // So just ignore request body for requests without // 'Content-Length' and 'Transfer-Encoding' headers. // refer to https://tools.ietf.org/html/rfc7230#section-3.3.2 if !req.Header.IgnoreBody() { req.Header.SetContentLength(0) } return nil } bodyBuf := req.BodyBuffer() bodyBuf.Reset() bodyBuf.B, err = ext.ReadBody(r, contentLength, maxBodySize, bodyBuf.B) if err != nil { req.Reset() return err } if req.Header.ContentLength() == -1 { err = ext.ReadTrailer(req.Header.Trailer(), r) if err != nil && err != io.EOF { return err } } req.Header.SetContentLength(len(bodyBuf.B)) return nil } func ReadBodyStream(req *protocol.Request, zr network.Reader, maxBodySize int, getOnly, preParseMultipartForm bool) error { if getOnly && !req.Header.IsGet() { return errGetOnly } if req.MayContinue() { // 'Expect: 100-continue' header found. Let the caller deciding // whether to read request body or // to return StatusExpectationFailed. return nil } return ContinueReadBodyStream(req, zr, maxBodySize, preParseMultipartForm) } func ReadLimitBody(req *protocol.Request, r network.Reader, maxBodySize int, getOnly, preParseMultipartForm bool) error { // Do not reset the request here - the caller must reset it before // calling this method. 
if getOnly && !req.Header.IsGet() { return errGetOnly } if req.MayContinue() { // 'Expect: 100-continue' header found. Let the caller deciding // whether to read request body or // to return StatusExpectationFailed. return nil } return ContinueReadBody(req, r, maxBodySize, preParseMultipartForm) } func writeBodyStream(req *protocol.Request, w network.Writer) error { var err error contentLength := req.Header.ContentLength() if contentLength < 0 { lrSize := ext.LimitedReaderSize(req.BodyStream()) if lrSize >= 0 { contentLength = int(lrSize) if int64(contentLength) != lrSize { contentLength = -1 } if contentLength >= 0 { req.Header.SetContentLength(contentLength) } } } if contentLength >= 0 { if err = WriteHeader(&req.Header, w); err == nil { err = ext.WriteBodyFixedSize(w, req.BodyStream(), int64(contentLength)) } } else { req.Header.SetContentLength(-1) err = WriteHeader(&req.Header, w) if err == nil { err = ext.WriteBodyChunked(w, req.BodyStream()) } if err == nil { err = ext.WriteTrailer(req.Header.Trailer(), w) } } err1 := req.CloseBodyStream() if err == nil { err = err1 } return err } func handleMultipart(req *protocol.Request) error { if len(req.MultipartFiles()) == 0 && len(req.MultipartFields()) == 0 { return nil } var err error bodyBuffer := &bytes.Buffer{} w := multipart.NewWriter(bodyBuffer) if len(req.MultipartFiles()) > 0 { for _, f := range req.MultipartFiles() { if f.Reader != nil { err = protocol.WriteMultipartFormFile(w, f.ParamName, f.Name, f.Reader) } else { err = protocol.AddFile(w, f.ParamName, f.Name) } if err != nil { return err } } } if len(req.MultipartFields()) > 0 { for _, mf := range req.MultipartFields() { if err = protocol.AddMultipartFormField(w, mf); err != nil { return err } } } req.Header.Set(consts.HeaderContentType, w.FormDataContentType()) if err = w.Close(); err != nil { return err } r := multipart.NewReader(bodyBuffer, w.Boundary()) f, err := r.ReadForm(int64(bodyBuffer.Len())) if err != nil { return err } 
protocol.SetMultipartFormWithBoundary(req, f, w.Boundary()) return nil } ================================================ FILE: pkg/protocol/http1/req/request_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
* * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package req import ( "bufio" "bytes" "encoding/base64" "errors" "fmt" "io" "io/ioutil" "mime/multipart" "net/url" "strings" "testing" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" "github.com/cloudwego/hertz/pkg/common/compress" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/ext" "github.com/cloudwego/netpoll" ) func TestRequestContinueReadBody(t *testing.T) { t.Parallel() s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) var r protocol.Request if err := Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } if err := ContinueReadBody(&r, zr, 0, true); err != nil { t.Fatalf("error when reading request body: %s", err) } body := r.Body() if string(body) != "abcde" { t.Fatalf("unexpected body %q. Expecting %q", body, "abcde") } tail, err := zr.Peek(zr.Len()) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(tail) != "f4343" { t.Fatalf("unexpected tail %q. 
Expecting %q", tail, "f4343") } } func TestRequestReadNoBody(t *testing.T) { t.Parallel() var r protocol.Request s := "GET / HTTP/1.1\r\n\r\n" zr := mock.NewZeroCopyReader(s) if err := Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } r.SetHost("foobar") headerStr := r.Header.String() if strings.Contains(headerStr, "Content-Length: ") { t.Fatalf("unexpected Content-Length") } } func TestRequestRead(t *testing.T) { t.Parallel() var r protocol.Request s := "POST / HTTP/1.1\r\n\r\n" zr := mock.NewZeroCopyReader(s) if err := Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } r.SetHost("foobar") headerStr := r.Header.String() if !strings.Contains(headerStr, "Content-Length: ") { t.Fatalf("should contain Content-Length") } cLen := r.Header.Peek(consts.HeaderContentLength) if string(cLen) != "0" { t.Fatalf("unexpected Content-Length: %s, Expecting 0", string(cLen)) } } func TestRequestReadNoBodyStreaming(t *testing.T) { t.Parallel() var r protocol.Request r.Header.SetContentLength(-2) r.Header.SetMethod("GET") s := "" zr := mock.NewZeroCopyReader(s) if err := ContinueReadBodyStream(&r, zr, 2048, true); err != nil { t.Fatalf("unexpected error: %s", err) } r.SetHost("foobar") headerStr := r.Header.String() if strings.Contains(headerStr, "Content-Length: ") { t.Fatalf("unexpected Content-Length") } } func TestRequestReadStreaming(t *testing.T) { t.Parallel() var r protocol.Request r.Header.SetContentLength(-2) r.Header.SetMethod("POST") s := "" zr := mock.NewZeroCopyReader(s) if err := ContinueReadBodyStream(&r, zr, 2048, true); err != nil { t.Fatalf("unexpected error: %s", err) } r.SetHost("foobar") headerStr := r.Header.String() if !strings.Contains(headerStr, "Content-Length: ") { t.Fatalf("should contain Content-Length") } cLen := r.Header.Peek(consts.HeaderContentLength) if string(cLen) != "0" { t.Fatalf("unexpected Content-Length: %s, Expecting 0", string(cLen)) } } func TestMethodAndPathAndQueryString(t *testing.T) { s := "PUT 
/foo/bar?query=1 HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) var r protocol.Request if err := Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } if string(r.RequestURI()) != "/foo/bar?query=1" { t.Fatalf("unexpected request uri %s. Expecting %s", r.RequestURI(), "/foo/bar?query=1") } if string(r.Method()) != "PUT" { t.Fatalf("unexpected method %s. Expecting %s", r.Header.Method(), "PUT") } if string(r.Path()) != "/foo/bar" { t.Fatalf("unexpected uri path %s. Expecting %s", r.URI().Path(), "/foo/bar") } if string(r.QueryString()) != "query=1" { t.Fatalf("unexpected query string %s. Expecting %s", r.URI().QueryString(), "query=1") } } func TestRequestSuccess(t *testing.T) { t.Parallel() // empty method, user-agent and body testRequestSuccess(t, "", "/foo/bar", "google.com", "", "", consts.MethodGet) // non-empty user-agent testRequestSuccess(t, consts.MethodGet, "/foo/bar", "google.com", "MSIE", "", consts.MethodGet) // non-empty method testRequestSuccess(t, consts.MethodHead, "/aaa", "fobar", "", "", consts.MethodHead) // POST method with body testRequestSuccess(t, consts.MethodPost, "/bbb", "aaa.com", "Chrome aaa", "post body", consts.MethodPost) // PUT method with body testRequestSuccess(t, consts.MethodPut, "/aa/bb", "a.com", "aaa", "put body", consts.MethodPut) // only host is set testRequestSuccess(t, "", "", "gooble.com", "", "", consts.MethodGet) // get with body testRequestSuccess(t, consts.MethodGet, "/foo/bar", "aaa.com", "", "foobar", consts.MethodGet) } func TestRequestMultipartFormBoundary(t *testing.T) { t.Parallel() testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; boundary=foobar\r\n\r\n", "foobar") // incorrect content-type testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: foo/bar\r\n\r\n", "") // empty boundary testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: 
multipart/form-data; boundary=\r\n\r\n", "") // missing boundary testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data\r\n\r\n", "") // boundary after other content-type params testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; foo=bar; boundary=--aaabb \r\n\r\n", "--aaabb") // quoted boundary testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; boundary=\"foobar\"\r\n\r\n", "foobar") var h protocol.RequestHeader h.SetMultipartFormBoundary("foobarbaz") b := h.MultipartFormBoundary() if string(b) != "foobarbaz" { t.Fatalf("unexpected boundary %q. Expecting %q", b, "foobarbaz") } } func testRequestSuccess(t *testing.T, method, requestURI, host, userAgent, body, expectedMethod string) { var req protocol.Request req.Header.SetMethod(method) req.Header.SetRequestURI(requestURI) req.Header.Set(consts.HeaderHost, host) req.Header.Set(consts.HeaderUserAgent, userAgent) req.SetBody([]byte(body)) contentType := "foobar" if method == consts.MethodPost { req.Header.Set(consts.HeaderContentType, contentType) } w := &bytes.Buffer{} zw := netpoll.NewWriter(w) err := Write(&req, zw) if err != nil { t.Fatalf("Unexpected error when calling Write(): %s", err) } if err = zw.Flush(); err != nil { t.Fatalf("Unexpected error when flushing bufio.Writer: %s", err) } var req1 protocol.Request br := bufio.NewReader(w) zr := netpoll.NewReader(br) if err = Read(&req1, zr); err != nil { t.Fatalf("Unexpected error when calling Read(): %s", err) } if string(req1.Header.Method()) != expectedMethod { t.Fatalf("Unexpected method: %q. Expected %q", req1.Header.Method(), expectedMethod) } if len(requestURI) == 0 { requestURI = "/" } if string(req1.Header.RequestURI()) != requestURI { t.Fatalf("Unexpected RequestURI: %q. Expected %q", req1.Header.RequestURI(), requestURI) } if string(req1.Header.Peek(consts.HeaderHost)) != host { t.Fatalf("Unexpected host: %q. 
Expected %q", req1.Header.Peek(consts.HeaderHost), host) } if string(req1.Header.Peek(consts.HeaderUserAgent)) != userAgent { t.Fatalf("Unexpected user-agent: %q. Expected %q", req1.Header.Peek(consts.HeaderUserAgent), userAgent) } if !bytes.Equal(req1.Body(), []byte(body)) { t.Fatalf("Unexpected body: %q. Expected %q", req1.Body(), body) } if method == consts.MethodPost && string(req1.Header.Peek(consts.HeaderContentType)) != contentType { t.Fatalf("Unexpected content-type: %q. Expected %q", req1.Header.Peek(consts.HeaderContentType), contentType) } } func TestRequestWriteError(t *testing.T) { t.Parallel() // no host testRequestWriteError(t, "", "/foo/bar", "", "", "") } func TestRequestPostArgsSuccess(t *testing.T) { t.Parallel() var req protocol.Request testRequestPostArgsSuccess(t, &req, "POST / HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 0\r\n\r\n", 0, "foo=", "=") testRequestPostArgsSuccess(t, &req, "POST / HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 18\r\n\r\nfoo&b%20r=b+z=&qwe", 3, "foo=", "b r=b z=", "qwe=") } func testRequestPostArgsSuccess(t *testing.T, req *protocol.Request, s string, expectedArgsLen int, expectedArgs ...string) { r := bytes.NewBufferString(s) zr := netpoll.NewReader(r) err := Read(req, zr) if err != nil { t.Fatalf("Unexpected error when reading %q: %s", s, err) } args := req.PostArgs() if args.Len() != expectedArgsLen { t.Fatalf("Unexpected args len %d. Expected %d for %q", args.Len(), expectedArgsLen, s) } for _, x := range expectedArgs { tmp := strings.SplitN(x, "=", 2) k := tmp[0] v := tmp[1] vv := string(args.Peek(k)) if vv != v { t.Fatalf("Unexpected value for key %q: %q. 
Expected %q for %q", k, vv, v, s) } } } func TestRequestPostArgsBodyStream(t *testing.T) { var req protocol.Request s := "POST / HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 8196\r\n\r\n" contentB := make([]byte, 8192) for i := 0; i < len(contentB); i++ { contentB[i] = 'a' } content := string(contentB) requestString := s + url.Values{"key": []string{content}}.Encode() r := bytes.NewBufferString(requestString) zr := netpoll.NewReader(r) if err := ReadHeader(&req.Header, zr); err != nil { t.Fatalf("Unexpected error when reading header %q: %s", s, err) } err := ReadBodyStream(&req, zr, 1024*4, false, false) if err != nil { t.Fatalf("Unexpected error when reading bodystream %q: %s", s, err) } if string(req.PostArgs().Peek("key")) != content { assert.DeepEqual(t, content, string(req.PostArgs().Peek("key"))) } } func testRequestWriteError(t *testing.T, method, requestURI, host, userAgent, body string) { var req protocol.Request req.Header.SetMethod(method) req.Header.SetRequestURI(requestURI) req.Header.Set(consts.HeaderHost, host) req.Header.Set(consts.HeaderUserAgent, userAgent) req.SetBody([]byte(body)) w := &bytebufferpool.ByteBuffer{} zw := netpoll.NewWriter(w) err := Write(&req, zw) if err == nil { t.Fatalf("Expecting error when writing request=%#v", &req) } } func TestChunkedUnexpectedEOF(t *testing.T) { reader := &mock.EOFReader{} _, err := ext.ReadBody(reader, -1, 0, nil) if err != io.ErrUnexpectedEOF { assert.DeepEqual(t, io.ErrUnexpectedEOF, err) } var pool bytebufferpool.Pool var req1 protocol.Request bs := ext.AcquireBodyStream(pool.Get(), reader, req1.Header.Trailer(), -1) byteSlice := make([]byte, 4096) _, err = bs.Read(byteSlice) if err != io.ErrUnexpectedEOF { assert.DeepEqual(t, io.ErrUnexpectedEOF, err) } } func TestReadBodyChunked(t *testing.T) { t.Parallel() // zero-size body testReadBodyChunked(t, 0) // small-size body testReadBodyChunked(t, 5) // medium-size body testReadBodyChunked(t, 43488) // 
big body testReadBodyChunked(t, 3*1024*1024) // smaller body after big one testReadBodyChunked(t, 12343) } /* TestReadBodyFixedSize reads fixed-size bodies of several sizes via testReadBodyFixedSize. */ func TestReadBodyFixedSize(t *testing.T) { t.Parallel() // zero-size body testReadBodyFixedSize(t, 0) // small-size body testReadBodyFixedSize(t, 3) // medium-size body testReadBodyFixedSize(t, 1024) // large-size body testReadBodyFixedSize(t, 1024*1024) // smaller body after big one testReadBodyFixedSize(t, 34345) } /* TestRequestWriteRequestURINoHost writes a request that only has an absolute RequestURI with userinfo and reads it back. It checks the host taken from the URI, the path-only RequestURI left on the sent request, and that the Basic Authorization header decodes to "user:pass". It then checks that Write fails for a non-absolute RequestURI without a host. */ func TestRequestWriteRequestURINoHost(t *testing.T) { t.Parallel() var req protocol.Request req.Header.SetRequestURI("http://user:pass@google.com/foo/bar?baz=aaa") var w bytes.Buffer zw := netpoll.NewWriter(&w) if err := Write(&req, zw); err != nil { t.Fatalf("unexpected error: %s", err) } if err := zw.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } var req1 protocol.Request br := bufio.NewReader(&w) zr := netpoll.NewReader(br) if err := Read(&req1, zr); err != nil { t.Fatalf("unexpected error: %s", err) } if string(req1.Header.Host()) != "google.com" { t.Fatalf("unexpected host: %q. Expecting %q", req1.Header.Host(), "google.com") } if string(req.Header.RequestURI()) != "/foo/bar?baz=aaa" { t.Fatalf("unexpected requestURI: %q. Expecting %q", req.Header.RequestURI(), "/foo/bar?baz=aaa") } // authorization authorization := req.Header.Get(string(bytestr.StrAuthorization)) /* Guard the slice below so a missing or non-Basic header fails the test instead of panicking. */ if !strings.HasPrefix(authorization, string(bytestr.StrBasicSpace)) { t.Fatalf("unexpected Authorization %q. Expecting prefix %q", authorization, bytestr.StrBasicSpace) } author, err := base64.StdEncoding.DecodeString(authorization[len(bytestr.StrBasicSpace):]) if err != nil { t.Fatalf("unexpected error when decoding Authorization %q: %s", authorization, err) } if string(author) != "user:pass" { t.Fatalf("unexpected Authorization: %q.
Expecting %q", author, "user:pass") /* Report the decoded credentials that were actually compared. */ } // verify that Write returns error on non-absolute RequestURI req.Reset() req.Header.SetRequestURI("/foo/bar") w.Reset() if err := Write(&req, zw); err == nil { t.Fatalf("expecting error") } } /* TestRequestWriteMultipartFile writes a POST carrying two multipart files (one via SetFileReader, one via SetMultipartField), reads it back and checks both uploaded file names. */ func TestRequestWriteMultipartFile(t *testing.T) { t.Parallel() var req protocol.Request req.Header.SetHost("foobar.com") req.Header.SetMethod(consts.MethodPost) req.SetFileReader("filea", "filea.txt", bytes.NewReader([]byte("This is filea."))) req.SetMultipartField("fileb", "fileb.txt", "text/plain", bytes.NewReader([]byte("This is fileb."))) var w bytes.Buffer zw := netpoll.NewWriter(&w) if err := Write(&req, zw); err != nil { t.Fatalf("unexpected error: %s", err) } if err := zw.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } var req1 protocol.Request zr := mock.NewZeroCopyReader(w.String()) if err := Read(&req1, zr); err != nil { t.Fatalf("unexpected error: %s", err) } filea, err := req1.FormFile("filea") assert.Nil(t, err) assert.DeepEqual(t, "filea.txt", filea.Filename) fileb, err := req1.FormFile("fileb") assert.Nil(t, err) assert.DeepEqual(t, "fileb.txt", fileb.Filename) } /* TestSetRequestBodyStreamChunked round-trips body streams of unknown size, with and without trailers, through testSetRequestBodyStreamChunked. */ func TestSetRequestBodyStreamChunked(t *testing.T) { t.Parallel() testSetRequestBodyStreamChunked(t, "", map[string]string{"Foo": "bar"}) body := "foobar baz aaa bbb ccc" testSetRequestBodyStreamChunked(t, body, nil) body = string(mock.CreateFixedBody(10001)) testSetRequestBodyStreamChunked(t, body, map[string]string{"Foo": "test", "Bar": "test"}) } /* TestSetRequestBodyStreamFixedSize round-trips body streams of known size through testSetRequestBodyStream. */ func TestSetRequestBodyStreamFixedSize(t *testing.T) { t.Parallel() testSetRequestBodyStream(t, "a") testSetRequestBodyStream(t, string(mock.CreateFixedBody(4097))) testSetRequestBodyStream(t, string(mock.CreateFixedBody(100500))) } /* TestRequestHostFromRequestURI checks that an explicit SetHost takes precedence over the host contained in the RequestURI. */ func TestRequestHostFromRequestURI(t *testing.T) { t.Parallel() hExpected := "foobar.com" var req protocol.Request req.SetRequestURI("http://proxy-host:123/foobar?baz") req.SetHost(hExpected) h := bytesconv.B2s(req.Host()) if h != hExpected { t.Fatalf("unexpected host
set: %q. Expecting %q", h, hExpected) } } /* TestRequestHostFromHeader checks that Host() returns the value stored with Header.SetHost. */ func TestRequestHostFromHeader(t *testing.T) { t.Parallel() hExpected := "foobar.com" var req protocol.Request req.Header.SetHost(hExpected) h := bytesconv.B2s(req.Host()) if h != hExpected { t.Fatalf("unexpected host set: %q. Expecting %q", h, hExpected) } } /* TestRequestContentTypeWithCharset reads a form POST whose Content-Type is consts.MIMEApplicationHTMLFormUTF8 (presumably the urlencoded type with a charset parameter) and checks the body, the Content-Type and the single parsed POST arg. */ func TestRequestContentTypeWithCharset(t *testing.T) { t.Parallel() expectedContentType := consts.MIMEApplicationHTMLFormUTF8 expectedBody := "0123=56789" s := fmt.Sprintf("POST / HTTP/1.1\r\nContent-Type: %s\r\nContent-Length: %d\r\n\r\n%s", expectedContentType, len(expectedBody), expectedBody) zr := mock.NewZeroCopyReader(s) var r protocol.Request if err := Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } body := r.Body() if string(body) != expectedBody { t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) } ct := r.Header.ContentType() if string(ct) != expectedContentType { t.Fatalf("unexpected content-type %q. Expecting %q", ct, expectedContentType) } args := r.PostArgs() if args.Len() != 1 { t.Fatalf("unexpected number of POST args: %d. Expecting 1", args.Len()) } av := args.Peek("0123") if string(av) != "56789" { t.Fatalf("unexpected POST arg value: %q. Expecting %q", av, "56789") } } /* TestRequestBodyStreamMultipleBodyCalls checks that IsBodyStream follows SetBodyStream and that repeated Body() calls keep returning the full stream content. */ func TestRequestBodyStreamMultipleBodyCalls(t *testing.T) { t.Parallel() var r protocol.Request s := "foobar baz abc" if r.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } r.SetBodyStream(bytes.NewBufferString(s), len(s)) if !r.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } for i := 0; i < 10; i++ { body := r.Body() if string(body) != s { t.Fatalf("unexpected body %q. Expecting %q.
iteration %d", body, s, i) } } } /* TestRequestNoContentLength checks that a serialized HEAD request has no Content-Length header and that a POST with a body written through BodyWriter has one. */ func TestRequestNoContentLength(t *testing.T) { t.Parallel() var r protocol.Request r.Header.SetMethod(consts.MethodHead) r.Header.SetHost("foobar") s := GetHTTP1Request(&r).String() if strings.Contains(s, "Content-Length: ") { t.Fatalf("unexpected content-length in HEAD request %q", s) } r.Header.SetMethod(consts.MethodPost) fmt.Fprintf(r.BodyWriter(), "foobar body") s = GetHTTP1Request(&r).String() if !strings.Contains(s, "Content-Length: ") { t.Fatalf("missing content-length header in non-GET request %q", s) } } /* TestRequestReadGzippedBody reads a POST with Content-Encoding: gzip and checks that Body() returns the still-compressed bytes, which gunzip back to the original text. */ func TestRequestReadGzippedBody(t *testing.T) { t.Parallel() var r protocol.Request bodyOriginal := "foo bar baz compress me better!" body := compress.AppendGzipBytes(nil, []byte(bodyOriginal)) s := fmt.Sprintf("POST /foobar HTTP/1.1\r\nContent-Type: foo/bar\r\nContent-Encoding: gzip\r\nContent-Length: %d\r\n\r\n%s", len(body), body) zr := mock.NewZeroCopyReader(s) if err := Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } if string(r.Header.Peek(consts.HeaderContentEncoding)) != "gzip" { t.Fatalf("unexpected content-encoding: %q. Expecting %q", r.Header.Peek(consts.HeaderContentEncoding), "gzip") } if r.Header.ContentLength() != len(body) { t.Fatalf("unexpected content-length: %d. Expecting %d", r.Header.ContentLength(), len(body)) } if string(r.Body()) != string(body) { t.Fatalf("unexpected body: %q. Expecting %q", r.Body(), body) } bodyGunzipped, err := compress.AppendGunzipBytes(nil, r.Body()) if err != nil { t.Fatalf("unexpected error when uncompressing data: %s", err) } if string(bodyGunzipped) != bodyOriginal { t.Fatalf("unexpected uncompressed body %q.
Expecting %q", bodyGunzipped, bodyOriginal) } } /* TestRequestReadPostNoBody reads a POST without Content-Length and checks that no body is consumed (empty body, Content-Length 0) and that the following bytes stay in the reader. */ func TestRequestReadPostNoBody(t *testing.T) { t.Parallel() var r protocol.Request s := "POST /foo/bar HTTP/1.1\r\nContent-Type: aaa/bbb\r\n\r\naaaa" zr := mock.NewZeroCopyReader(s) if err := Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } if string(r.Header.RequestURI()) != "/foo/bar" { t.Fatalf("unexpected request uri %q. Expecting %q", r.Header.RequestURI(), "/foo/bar") } if string(r.Header.ContentType()) != "aaa/bbb" { t.Fatalf("unexpected content-type %q. Expecting %q", r.Header.ContentType(), "aaa/bbb") } if len(r.Body()) != 0 { t.Fatalf("unexpected body found %q. Expecting empty body", r.Body()) } if r.Header.ContentLength() != 0 { t.Fatalf("unexpected content-length: %d. Expecting 0", r.Header.ContentLength()) } tail, err := ioutil.ReadAll(zr) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(tail) != "aaaa" { t.Fatalf("unexpected tail %q. Expecting %q", tail, "aaaa") } } /* TestRequestContinueReadBodyDisablePrereadMultipartForm reads a multipart POST with ReadHeader plus ReadLimitBody(..., 10000, false, false) and checks that no multipart form was parsed while the raw body is kept. NOTE(review): the last flag presumably controls multipart pre-parsing; confirm against ReadLimitBody. */ func TestRequestContinueReadBodyDisablePrereadMultipartForm(t *testing.T) { t.Parallel() var w bytes.Buffer mw := multipart.NewWriter(&w) for i := 0; i < 10; i++ { k := fmt.Sprintf("key_%d", i) v := fmt.Sprintf("value_%d", i) if err := mw.WriteField(k, v); err != nil { t.Fatalf("unexpected error: %s", err) } } boundary := mw.Boundary() if err := mw.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } formData := w.Bytes() s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=%s\r\nContent-Length: %d\r\n\r\n%s", boundary, len(formData), formData) zr := mock.NewZeroCopyReader(s) var r protocol.Request if err := ReadHeader(&r.Header, zr); err != nil { t.Fatalf("unexpected error reading headers: %s", err) } if err := ReadLimitBody(&r, zr, 10000, false, false); err != nil { t.Fatalf("unexpected error reading body: %s", err) } if r.HasMultipartForm() { t.Fatalf("The multipartForm of the Request must be nil") } if string(formData) != string(r.Body()) { t.Fatalf("The
body given must equal the body in the Request") } } /* TestRequestReadLimitBody exercises ReadLimitBody in GET-only mode, and fixed-length and chunked bodies that fit within or exceed maxBodySize. */ func TestRequestReadLimitBody(t *testing.T) { t.Parallel() testRequestReadLimitBodyReadOnly(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789") // request with content-length testRequestReadLimitBodySuccess(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 9) testRequestReadLimitBodySuccess(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 92) testRequestReadLimitBodyError(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 5) // chunked request testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 9) testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 999) testRequestReadLimitBodyError(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 8) } /* testRequestReadLimitBodyReadOnly checks that ReadLimitBody fails when its first flag is true for the non-GET request s (the failure message names errGetOnly). */ func testRequestReadLimitBodyReadOnly(t *testing.T, s string) { var req protocol.Request zr := mock.NewZeroCopyReader(s) /* A header parse failure must not be mistaken for the expected body error below. */ if err := ReadHeader(&req.Header, zr); err != nil { t.Fatalf("unexpected error when reading header %q: %s", s, err) } if err := ReadLimitBody(&req, zr, 10, true, false); err == nil { t.Fatalf("expected error: %s", errGetOnly.Error()) } } /* TestRequestString checks the HTTP/1.1 serialization of a GET built from an absolute URI. */ func TestRequestString(t *testing.T) { t.Parallel() var r protocol.Request r.SetRequestURI("http://foobar.com/aaa") s := GetHTTP1Request(&r).String() expectedS := "GET /aaa HTTP/1.1\r\nHost: foobar.com\r\n\r\n" if s != expectedS { t.Fatalf("unexpected request: %q.
Expecting %q", s, expectedS) } } /* TestRequestReadChunked reads a chunked POST followed by a "Trail: test" trailer and checks the de-chunked body, the resulting headers (Content-Length 8) and the trailer value. */ func TestRequestReadChunked(t *testing.T) { t.Parallel() var req protocol.Request s := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\nTrail: test\r\n\r\n" zr := netpoll.NewReader(bytes.NewBufferString(s)) err := Read(&req, zr) if err != nil { t.Fatalf("Unexpected error when reading chunked request: %s", err) } expectedBody := "abc12345" if string(req.Body()) != expectedBody { t.Fatalf("Unexpected body %q. Expected %q", req.Body(), expectedBody) } verifyRequestHeader(t, &req.Header, 8, "/foo", "google.com", "", "aa/bb") verifyTrailer(t, zr, map[string]string{"Trail": "test"}) } /* verifyTrailer registers the keys of exceptedTrailers on a fresh Trailer, reads trailer fields from r with ext.ReadTrailer and checks each expected value. With a nil map, an io.EOF from ReadTrailer is accepted as "no trailer". */ func verifyTrailer(t *testing.T, r network.Reader, exceptedTrailers map[string]string) { trailer := protocol.Trailer{} keys := make([]string, 0, len(exceptedTrailers)) for k := range exceptedTrailers { keys = append(keys, k) } trailer.SetTrailers([]byte(strings.Join(keys, ", "))) err := ext.ReadTrailer(&trailer, r) if err == io.EOF && exceptedTrailers == nil { return } if err != nil { t.Fatalf("Cannot read trailer: %v", err) } for k, v := range exceptedTrailers { got := trailer.Peek(k) if !bytes.Equal(got, []byte(v)) { t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", k, v, got) } } } /* TestRequestChunkedWhitespace checks that a chunk-size line with trailing whitespace ("3 ") is accepted. */ func TestRequestChunkedWhitespace(t *testing.T) { t.Parallel() var req protocol.Request s := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3 \r\nabc\r\n0\r\n\r\n" zr := mock.NewZeroCopyReader(s) err := Read(&req, zr) if err != nil { t.Fatalf("Unexpected error when reading chunked request: %s", err) } expectedBody := "abc" if string(req.Body()) != expectedBody { t.Fatalf("Unexpected body %q.
Expected %q", req.Body(), expectedBody) } } /* testRequestReadLimitBodyError checks that ReadHeaderAndLimitBody rejects s under maxBodySize with an error matching errs.ErrBodyTooLarge. */ func testRequestReadLimitBodyError(t *testing.T, s string, maxBodySize int) { var req protocol.Request zr := mock.NewZeroCopyReader(s) err := ReadHeaderAndLimitBody(&req, zr, maxBodySize) if err == nil { t.Fatalf("expecting error. s=%q, maxBodySize=%d", s, maxBodySize) } if !errors.Is(err, errs.ErrBodyTooLarge) { t.Fatalf("unexpected error: %s. Expecting %s. s=%q, maxBodySize=%d", err, errs.ErrBodyTooLarge, s, maxBodySize) } } /* testRequestReadLimitBodySuccess checks that ReadHeaderAndLimitBody accepts s under maxBodySize. */ func testRequestReadLimitBodySuccess(t *testing.T, s string, maxBodySize int) { var req protocol.Request zr := mock.NewZeroCopyReader(s) if err := ReadHeaderAndLimitBody(&req, zr, maxBodySize); err != nil { t.Fatalf("unexpected error: %s. s=%q, maxBodySize=%d", err, s, maxBodySize) } } /* testSetRequestBodyStream writes a POST whose body is streamed with a known size, reads it back and checks that the body round-trips unchanged. */ func testSetRequestBodyStream(t *testing.T, body string) { var req protocol.Request req.Header.SetHost("foobar.com") req.Header.SetMethod(consts.MethodPost) bodySize := len(body) if req.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } req.SetBodyStream(bytes.NewBufferString(body), bodySize) if !req.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } var w bytes.Buffer zw := netpoll.NewWriter(&w) if err := Write(&req, zw); err != nil { t.Fatalf("unexpected error when writing request: %s. body=%q", err, body) } if err := zw.Flush(); err != nil { t.Fatalf("unexpected error when flushing request: %s. body=%q", err, body) } var req1 protocol.Request br := bufio.NewReader(&w) zr := netpoll.NewReader(br) if err := Read(&req1, zr); err != nil { t.Fatalf("unexpected error when reading request: %s. body=%q", err, body) } if string(req1.Body()) != body { t.Fatalf("unexpected body %q.
Expecting %q", req1.Body(), body) } } /* testSetRequestBodyStreamChunked writes a POST whose body is streamed with unknown size (-1) plus the given trailer entries. It reads the request back into a second Request and checks both the body and the trailers that arrived on the wire. */ func testSetRequestBodyStreamChunked(t *testing.T, body string, trailer map[string]string) { var req protocol.Request req.Header.SetHost("foobar.com") req.Header.SetMethod(consts.MethodPost) if req.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } req.SetBodyStream(bytes.NewBufferString(body), -1) if !req.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } var w bytes.Buffer zw := netpoll.NewWriter(&w) for k, v := range trailer { err := req.Header.Trailer().Add(k, v) if err != nil { t.Fatalf("unexpected error: %v", err) } } if err := Write(&req, zw); err != nil { t.Fatalf("unexpected error when writing request: %v, body=%q", err, body) } if err := zw.Flush(); err != nil { t.Fatalf("unexpected error when flushing request: %v, body=%q", err, body) } var req1 protocol.Request br := bufio.NewReader(&w) zr := netpoll.NewReader(br) if err := Read(&req1, zr); err != nil { t.Fatalf("unexpected error when reading request: %v. body=%q", err, body) } if string(req1.Body()) != body { t.Fatalf("unexpected body %q. Expecting %q", req1.Body(), body) } for k, v := range trailer { /* Check the request that was read back; the sender's own trailer trivially holds these values. */ r := req1.Header.Trailer().Peek(k) if string(r) != v { t.Fatalf("unexpected trailer %q. Expecting %q.
Got %q", k, v, r) } } } /* TestRequestMultipartForm builds a 10-field multipart form and re-reads it five times through testRequestMultipartForm and testRequestMultipartFormNotPreParse. It then checks that a hand-written 3-field request survives a Read, String, Read round-trip. */ func TestRequestMultipartForm(t *testing.T) { t.Parallel() var w bytes.Buffer mw := multipart.NewWriter(&w) for i := 0; i < 10; i++ { k := fmt.Sprintf("key_%d", i) v := fmt.Sprintf("value_%d", i) if err := mw.WriteField(k, v); err != nil { t.Fatalf("unexpected error: %s", err) } } boundary := mw.Boundary() if err := mw.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } formData := w.Bytes() for i := 0; i < 5; i++ { formData = testRequestMultipartForm(t, boundary, formData, 10) testRequestMultipartFormNotPreParse(t, boundary, formData, 10) } // verify request unmarshalling / marshalling s := "POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=foobar\r\nContent-Length: 213\r\n\r\n--foobar\r\nContent-Disposition: form-data; name=\"key_0\"\r\n\r\nvalue_0\r\n--foobar\r\nContent-Disposition: form-data; name=\"key_1\"\r\n\r\nvalue_1\r\n--foobar\r\nContent-Disposition: form-data; name=\"key_2\"\r\n\r\nvalue_2\r\n--foobar--\r\n" var r protocol.Request zr := mock.NewZeroCopyReader(s) if err := Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } s = GetHTTP1Request(&r).String() zr = mock.NewZeroCopyReader(s) if err := Read(&r, zr); err != nil { t.Fatalf("unexpected error: %s", err) } testRequestMultipartForm(t, "foobar", r.Body(), 3) } /* testRequestMultipartForm reads a multipart POST built from boundary and formData and checks that it has no files and exactly partsCount single-valued key_N=value_N fields. It returns the request body. */ func testRequestMultipartForm(t *testing.T, boundary string, formData []byte, partsCount int) []byte { s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=%s\r\nContent-Length: %d\r\n\r\n%s", boundary, len(formData), formData) var req protocol.Request zr := mock.NewZeroCopyReader(s) if err := Read(&req, zr); err != nil { t.Fatalf("unexpected error: %s", err) } f, err := req.MultipartForm() if err != nil { t.Fatalf("unexpected error: %s", err) } defer req.RemoveMultipartFormFiles() if len(f.File) > 0 { t.Fatalf("unexpected files found in the multipart form: %d", len(f.File)) } if len(f.Value) != partsCount { t.Fatalf("unexpected
number of values found: %d. Expecting %d", len(f.Value), partsCount) } for k, vv := range f.Value { if len(vv) != 1 { t.Fatalf("unexpected number of values found for key=%q: %d. Expecting 1", k, len(vv)) } if !strings.HasPrefix(k, "key_") { t.Fatalf("unexpected key prefix=%q. Expecting %q", k, "key_") } v := vv[0] if !strings.HasPrefix(v, "value_") { t.Fatalf("unexpected value prefix=%q. expecting %q", v, "value_") } if k[len("key_"):] != v[len("value_"):] { t.Fatalf("key and value suffixes don't match: %q vs %q", k, v) } } return req.Body() } /* testRequestMultipartFormNotPreParse performs the same checks as testRequestMultipartForm but reads with Read(&req, zr, false). NOTE(review): the false argument presumably disables multipart pre-parsing, so MultipartForm parses on demand; confirm against Read. */ func testRequestMultipartFormNotPreParse(t *testing.T, boundary string, formData []byte, partsCount int) []byte { s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=%s\r\nContent-Length: %d\r\n\r\n%s", boundary, len(formData), formData) var req protocol.Request zr := mock.NewZeroCopyReader(s) if err := Read(&req, zr, false); err != nil { t.Fatalf("unexpected error: %s", err) } f, err := req.MultipartForm() if err != nil { t.Fatalf("unexpected error: %s", err) } defer req.RemoveMultipartFormFiles() if len(f.File) > 0 { t.Fatalf("unexpected files found in the multipart form: %d", len(f.File)) } if len(f.Value) != partsCount { t.Fatalf("unexpected number of values found: %d. Expecting %d", len(f.Value), partsCount) } for k, vv := range f.Value { if len(vv) != 1 { t.Fatalf("unexpected number of values found for key=%q: %d. Expecting 1", k, len(vv)) } if !strings.HasPrefix(k, "key_") { t.Fatalf("unexpected key prefix=%q. Expecting %q", k, "key_") } v := vv[0] if !strings.HasPrefix(v, "value_") { t.Fatalf("unexpected value prefix=%q.
expecting %q", v, "value_") } if k[len("key_"):] != v[len("value_"):] { t.Fatalf("key and value suffixes don't match: %q vs %q", k, v) } } return req.Body() } /* testReadBodyChunked chunk-encodes a bodySize-byte body with a "Foo" trailer, reads it back with ext.ReadBody (size -1) and checks the body and then the trailer. */ func testReadBodyChunked(t *testing.T, bodySize int) { body := mock.CreateFixedBody(bodySize) expectedTrailer := map[string]string{"Foo": "chunked shit"} chunkedBody := mock.CreateChunkedBody(body, expectedTrailer, true) zr := mock.NewZeroCopyReader(string(chunkedBody)) // p,_ := mr.Next(3687) b, err := ext.ReadBody(zr, -1, 0, nil) if err != nil { t.Fatalf("Unexpected error for bodySize=%d: %s. body=%q, chunkedBody=%q", bodySize, err, body, chunkedBody) } if !bytes.Equal(b, body) { t.Fatalf("Unexpected response read for bodySize=%d: %q. Expected %q. chunkedBody=%q", bodySize, b, body, chunkedBody) } verifyTrailer(t, zr, expectedTrailer) } /* testReadBodyFixedSize reads a bodySize-byte body with ext.ReadBody, checks it, and then runs verifyTrailer with no expected trailers. */ func testReadBodyFixedSize(t *testing.T, bodySize int) { body := mock.CreateFixedBody(bodySize) zr := mock.NewZeroCopyReader(string(body)) b, err := ext.ReadBody(zr, bodySize, 0, nil) if err != nil { t.Fatalf("Unexpected error in ReadResponseBody(%d): %s", bodySize, err) } if !bytes.Equal(b, body) { t.Fatalf("Unexpected response read for bodySize=%d: %q. Expected %q", bodySize, b, body) } verifyTrailer(t, zr, nil) } /* TestRequestFormFile reads a multipart upload followed by extra bytes and checks that the extra bytes stay unread and that FormFile finds the uploaded file. */ func TestRequestFormFile(t *testing.T) { t.Parallel() s := `POST /upload HTTP/1.1 Host: localhost:10000 Content-Length: 521 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="f1" value1 ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="fileaaa"; filename="TODO" Content-Type: application/octet-stream - SessionClient with referer and cookies support. - Client with requests' pipelining support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 .
------WebKitFormBoundaryJwfATyF8tmxSJnLg-- tailfoobar` mr := mock.NewZeroCopyReader(s) var r protocol.Request if err := Read(&r, mr); err != nil { t.Fatalf("unexpected error: %s", err) } tail, err := ioutil.ReadAll(mr) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(tail) != "tailfoobar" { t.Fatalf("unexpected tail %q. Expecting %q", tail, "tailfoobar") } fh, err := r.FormFile("fileaaa") if err != nil { t.Fatalf("TestRequestFormFile error: %#v", err.Error()) } if fh == nil { t.Fatalf("fh unexpected nil") } } /* TestRequest_ContinueReadBodyStream drives testContinueReadBodyStream over small, large (9 KiB), limit-boundary (8 KiB + 1) and chunked requests with varying maxBodySize and first-read lengths. */ func TestRequest_ContinueReadBodyStream(t *testing.T) { // small body genBody := "abcdef4343" s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\n" testContinueReadBodyStream(t, s, genBody, 10, 5, 0, 5) testContinueReadBodyStream(t, s, genBody, 1, 5, 0, 0) // big body (> 8193) s1 := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 9216\r\nContent-Type: foo/bar\r\n\r\n" genBody = strings.Repeat("1", 9*1024) testContinueReadBodyStream(t, s1, genBody, 10*1024, 5*1024, 4*1024, 0) testContinueReadBodyStream(t, s1, genBody, 10*1024, 1*1024, 8*1024, 0) testContinueReadBodyStream(t, s1, genBody, 10*1024, 9*1024, 0*1024, 0) // normal stream testContinueReadBodyStream(t, s1, genBody, 1*1024, 5*1024, 4*1024, 0) testContinueReadBodyStream(t, s1, genBody, 1*1024, 1*1024, 8*1024, 0) testContinueReadBodyStream(t, s1, genBody, 1*1024, 9*1024, 0*1024, 0) testContinueReadBodyStream(t, s1, genBody, 5, 5*1024, 4*1024, 0) testContinueReadBodyStream(t, s1, genBody, 5, 1*1024, 8*1024, 0) testContinueReadBodyStream(t, s1, genBody, 5, 9*1024, 0, 0) // critical point testContinueReadBodyStream(t, s1, genBody, 8*1024+1, 5*1024, 4*1024, 0) testContinueReadBodyStream(t, s1, genBody, 8*1024+1, 1*1024, 8*1024, 0) testContinueReadBodyStream(t, s1, genBody, 8*1024+1, 9*1024, 0*1024, 0) // chunked body s2 := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type:
aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\ntrail" testContinueReadBodyStream(t, s2, "", 10*1024, 3, 5, 5) s3 := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\n" testContinueReadBodyStream(t, s3, "", 10*1024, 3, 5, 0) } /* TestRequest_Chunked checks partial reads from chunked body streams, including one whose second chunk declares 3 bytes but carries 4. */ func TestRequest_Chunked(t *testing.T) { t.Parallel() s4 := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n5\r\n12345\r\n0\r\n\r\n" testReadChunked(t, s4, "", 3, 2) s5 := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n5\r\n12345\r\n3\r\n1230\r\n\r\n" testReadChunked(t, s5, "", 3, 5) } /* TestRequest_ReadIncompleteStream checks small and large bodies that are shorter than their declared Content-Length via testReadIncompleteStream. */ func TestRequest_ReadIncompleteStream(t *testing.T) { t.Parallel() // small body genBody := "abcdef4343" s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 100\r\nContent-Type: foo/bar\r\n\r\n" testReadIncompleteStream(t, s, genBody) // big body (> 8193) s1 := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 10000\r\nContent-Type: foo/bar\r\n\r\n" genBody = strings.Repeat("1", 9*1024) testReadIncompleteStream(t, s1, genBody) } /* testReadIncompleteStream reads the header, opens a body stream with maxBodySize 1, and checks that draining the stream yields exactly body together with io.ErrUnexpectedEOF. */ func testReadIncompleteStream(t *testing.T, header, body string) { mr := mock.NewZeroCopyReader(header + body) var r protocol.Request if err := ReadHeader(&r.Header, mr); err != nil { t.Fatalf("unexpected error: %s", err) } if err := ContinueReadBodyStream(&r, mr, 1, true); err != nil { t.Fatalf("error when reading request body stream: %s", err) } readBody, err := ioutil.ReadAll(r.BodyStream()) /* bytes.Equal already implies equal lengths; %q keeps the diagnostic readable. */ if !bytes.Equal(readBody, []byte(body)) { t.Fatalf("readBody is not equal to the rawBody: %q(len: %d)", readBody, len(readBody)) } if !errors.Is(err, io.ErrUnexpectedEOF) { t.Fatalf("error should be io.ErrUnexpectedEOF, but got: %v", err) } } /* testReadChunked reads a chunked request as a stream, reads firstRead bytes, then checks that leftBytes more can be drained. */ func testReadChunked(t *testing.T, header, body string, firstRead, leftBytes int) { mr := mock.NewZeroCopyReader(header + body) var r protocol.Request if err :=
ReadHeader(&r.Header, mr); err != nil { t.Fatalf("unexpected error: %s", err) } if err := ContinueReadBodyStream(&r, mr, 2048, true); err != nil { t.Fatalf("error when reading request body stream: %s", err) } if r.Header.ContentLength() >= 0 { t.Fatalf("expect a chunked body") } streamRead := make([]byte, firstRead) fr, err := r.BodyStream().Read(streamRead) if err != nil { t.Fatalf("read stream error=%v", err) } if fr != firstRead { t.Fatalf("should read %d from stream body, but got %d", firstRead, fr) } leftB, _ := ioutil.ReadAll(r.BodyStream()) if len(leftB) != leftBytes { t.Fatalf("should left %d bytes from stream body, but left %d", leftBytes, len(leftB)) } } /* testContinueReadBodyStream reads header+body through ContinueReadBodyStream(maxBodySize), reads firstRead bytes and drains the rest. It checks that leftBytes remained in the stream, that a fixed-length body matches body, and that bytesLeftInReader bytes are still unconsumed in the underlying reader. */ func testContinueReadBodyStream(t *testing.T, header, body string, maxBodySize, firstRead, leftBytes, bytesLeftInReader int) { mr := mock.NewZeroCopyReader(header + body) var r protocol.Request if err := ReadHeader(&r.Header, mr); err != nil { t.Fatalf("unexpected error: %s", err) } if err := ContinueReadBodyStream(&r, mr, maxBodySize, true); err != nil { t.Fatalf("error when reading request body stream: %s", err) } streamRead := make([]byte, firstRead) sR, _ := r.BodyStream().Read(streamRead) if sR != firstRead { t.Fatalf("should read %d from stream body, but got %d", firstRead, sR) } leftB, _ := ioutil.ReadAll(r.BodyStream()) if len(leftB) != leftBytes { t.Fatalf("should left %d bytes from stream body, but left %d", leftBytes, len(leftB)) } if r.Header.ContentLength() > 0 { gotBody := append(streamRead, leftB...) if !bytes.Equal([]byte(body[:r.Header.ContentLength()]), gotBody) { t.Fatalf("body read from stream is not equal to the origin. Got: %s", gotBody) } } left, _ := mr.Peek(mr.Len()) if len(left) != bytesLeftInReader { /* Route diagnostics through the test log instead of stdout so they stay attached to the failing case. */ t.Logf("header=%q bodyLen=%d maxBodySize=%d firstRead=%d leftBytes=%d bytesLeftInReader=%d", header, len(body), maxBodySize, firstRead, leftBytes, bytesLeftInReader) t.Logf("left in reader: %q", left) t.Fatalf("should left %d bytes in original reader.
got %d", bytesLeftInReader, len(left)) /* len(left) is an int, so %d rather than %q. */ } } /* verifyRequestHeader fails the test unless h carries the given Content-Length, RequestURI, Host, Referer and Content-Type. */ func verifyRequestHeader(t *testing.T, h *protocol.RequestHeader, expectedContentLength int, expectedRequestURI, expectedHost, expectedReferer, expectedContentType string, ) { if h.ContentLength() != expectedContentLength { t.Fatalf("Unexpected Content-Length %d. Expected %d", h.ContentLength(), expectedContentLength) } if string(h.RequestURI()) != expectedRequestURI { t.Fatalf("Unexpected RequestURI %q. Expected %q", h.RequestURI(), expectedRequestURI) } if string(h.Peek(consts.HeaderHost)) != expectedHost { t.Fatalf("Unexpected host %q. Expected %q", h.Peek(consts.HeaderHost), expectedHost) } if string(h.Peek(consts.HeaderReferer)) != expectedReferer { t.Fatalf("Unexpected referer %q. Expected %q", h.Peek(consts.HeaderReferer), expectedReferer) } if string(h.Peek(consts.HeaderContentType)) != expectedContentType { t.Fatalf("Unexpected content-type %q. Expected %q", h.Peek(consts.HeaderContentType), expectedContentType) } } /* TestRequestReadMultipartFormWithFile reads a multipart upload followed by extra bytes. It checks the unread tail, the single "f1" value, and the single "fileaaa" file with its name and Content-Type. */ func TestRequestReadMultipartFormWithFile(t *testing.T) { t.Parallel() s := `POST /upload HTTP/1.1 Host: localhost:10000 Content-Length: 521 Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="f1" value1 ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="fileaaa"; filename="TODO" Content-Type: application/octet-stream - SessionClient with referer and cookies support. - Client with requests' pipelining support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 .
------WebKitFormBoundaryJwfATyF8tmxSJnLg-- tailfoobar` mr := mock.NewZeroCopyReader(s) var r protocol.Request if err := Read(&r, mr); err != nil { t.Fatalf("unexpected error: %s", err) } tail, err := ioutil.ReadAll(mr) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(tail) != "tailfoobar" { t.Fatalf("unexpected tail %q. Expecting %q", tail, "tailfoobar") } f, err := r.MultipartForm() if err != nil { t.Fatalf("unexpected error: %s", err) } defer r.RemoveMultipartFormFiles() // verify values if len(f.Value) != 1 { t.Fatalf("unexpected number of values in multipart form: %d. Expecting 1", len(f.Value)) } for k, vv := range f.Value { if k != "f1" { t.Fatalf("unexpected value name %q. Expecting %q", k, "f1") } if len(vv) != 1 { t.Fatalf("unexpected number of values %d. Expecting 1", len(vv)) } v := vv[0] if v != "value1" { t.Fatalf("unexpected value %q. Expecting %q", v, "value1") } } // verify files if len(f.File) != 1 { t.Fatalf("unexpected number of file values in multipart form: %d. Expecting 1", len(f.File)) } for k, vv := range f.File { if k != "fileaaa" { t.Fatalf("unexpected file value name %q. Expecting %q", k, "fileaaa") } if len(vv) != 1 { t.Fatalf("unexpected number of file values %d. Expecting 1", len(vv)) } v := vv[0] if v.Filename != "TODO" { t.Fatalf("unexpected filename %q. Expecting %q", v.Filename, "TODO") } ct := v.Header.Get("Content-Type") if ct != "application/octet-stream" { t.Fatalf("unexpected content-type %q. Expecting %q", ct, "application/octet-stream") } } } /* testRequestMultipartFormBoundary reads only the header from s and checks that MultipartFormBoundary returns boundary. */ func testRequestMultipartFormBoundary(t *testing.T, s, boundary string) { var h protocol.RequestHeader zr := mock.NewZeroCopyReader(s) if err := ReadHeader(&h, zr); err != nil { t.Fatalf("unexpected error: %s. s=%q, boundary=%q", err, s, boundary) } b := h.MultipartFormBoundary() if string(b) != boundary { t.Fatalf("unexpected boundary %q. Expecting %q.
s=%q", b, boundary, s) } } /* TestStreamNotEnoughData opens a body stream for a 64 KiB Content-Length on mock.NewStreamConn, releases it, and checks that conn.Data ends up empty and conn.HasReleased is set. */ func TestStreamNotEnoughData(t *testing.T) { req := protocol.AcquireRequest() req.Header.SetContentLength(1 << 16) conn := mock.NewStreamConn() const maxBodySize = 4 * 1024 * 1024 err := ContinueReadBodyStream(req, conn, maxBodySize) assert.Nil(t, err) err = ext.ReleaseBodyStream(req.BodyStream()) assert.Nil(t, err) assert.DeepEqual(t, 0, len(conn.Data)) assert.DeepEqual(t, true, conn.HasReleased) } /* TestRequestBodyStreamWithTrailer round-trips chunked body streams of several sizes with trailers via testRequestBodyStreamWithTrailer. */ func TestRequestBodyStreamWithTrailer(t *testing.T) { t.Parallel() testRequestBodyStreamWithTrailer(t, []byte("test"), false) testRequestBodyStreamWithTrailer(t, mock.CreateFixedBody(4097), false) testRequestBodyStreamWithTrailer(t, mock.CreateFixedBody(105000), false) } /* testRequestBodyStreamWithTrailer writes a request with a size -1 body stream and two trailers, then reads it into a second request. Header normalizing is optionally disabled on both requests. It checks the body and every trailer value. */ func testRequestBodyStreamWithTrailer(t *testing.T, body []byte, disableNormalizing bool) { expectedTrailer := map[string]string{ "foo": "testfoo", "bar": "testbar", } var req1 protocol.Request if disableNormalizing { req1.Header.DisableNormalizing() } req1.SetHost("google.com") req1.SetBodyStream(bytes.NewBuffer(body), -1) for k, v := range expectedTrailer { err := req1.Header.Trailer().Set(k, v) if err != nil { t.Fatalf("unexpected error: %s", err) } } w := &bytes.Buffer{} zw := netpoll.NewWriter(w) if err := Write(&req1, zw); err != nil { t.Fatalf("unexpected error: %s", err) } if err := zw.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } var req2 protocol.Request if disableNormalizing { req2.Header.DisableNormalizing() } br := netpoll.NewReader(w) if err := Read(&req2, br); err != nil { t.Fatalf("unexpected error: %s", err) } reqBody := req2.Body() if !bytes.Equal(reqBody, body) { t.Fatalf("unexpected body: %q. Excepting %q", reqBody, body) } for k, v := range expectedTrailer { kBytes := []byte(k) utils.NormalizeHeaderKey(kBytes, disableNormalizing) r := req2.Header.Trailer().Peek(k) if string(r) != v { t.Fatalf("unexpected trailer header %q: %q.
Expecting %s", kBytes, r, v) } } } ================================================ FILE: pkg/protocol/http1/resp/header.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
* * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package resp import ( "bytes" "errors" "fmt" "io" "strings" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/ext" ) var errTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "read response header") // Read reads response header from r. // // io.EOF is returned if r is closed before reading the first header byte. func ReadHeader(h *protocol.ResponseHeader, r network.Reader) error { n := 1 for { err := tryRead(h, r, n) if err == nil { return nil } if !errors.Is(err, errs.ErrNeedMore) { h.ResetSkipNormalize() return err } // No more data available on the wire, try block peek(by netpoll) if n == r.Len() { n++ continue } n = r.Len() } } // WriteHeader writes response header to w. func WriteHeader(h *protocol.ResponseHeader, w network.Writer) error { // Data may become invalid after the next call of ResponseHeader. // copy before WriteHeader returns header := h.Header() b, err := w.Malloc(len(header)) if err != nil { return err } h.SetHeaderLength(copy(b, header)) return nil } // ConnectionUpgrade returns true if 'Connection: Upgrade' header is set. func ConnectionUpgrade(h *protocol.ResponseHeader) bool { return ext.HasHeaderValue(h.Peek(consts.HeaderConnection), bytestr.StrKeepAlive) } func tryRead(h *protocol.ResponseHeader, r network.Reader, n int) error { h.ResetSkipNormalize() b, err := r.Peek(n) if len(b) == 0 { // Return ErrTimeout on any timeout. 
if err != nil && strings.Contains(err.Error(), "timeout") { return errTimeout } // treat all other errors on the first byte read as EOF if n == 1 || err == io.EOF { return io.EOF } return fmt.Errorf("error when reading response headers: %s", err) } b = ext.MustPeekBuffered(r) headersLen, errParse := parse(h, b) if errParse != nil { return ext.HeaderError("response", err, errParse, b) } ext.MustDiscard(r, headersLen) return nil } func parseHeaders(h *protocol.ResponseHeader, buf []byte) (int, error) { // 'identity' content-length by default h.InitContentLengthWithValue(-2) var s ext.HeaderScanner s.B = buf s.DisableNormalizing = h.IsDisableNormalizing() var err error for s.Next() { if len(s.Key) > 0 { switch s.Key[0] | 0x20 { case 'c': if utils.CaseInsensitiveCompare(s.Key, bytestr.StrContentType) { h.SetContentTypeBytes(s.Value) continue } if utils.CaseInsensitiveCompare(s.Key, bytestr.StrContentEncoding) { h.SetContentEncodingBytes(s.Value) continue } if utils.CaseInsensitiveCompare(s.Key, bytestr.StrContentLength) { var contentLength int if h.ContentLength() != -1 { if contentLength, err = protocol.ParseContentLength(s.Value); err != nil { h.InitContentLengthWithValue(-2) } else { h.InitContentLengthWithValue(contentLength) h.SetContentLengthBytes(s.Value) } } continue } if utils.CaseInsensitiveCompare(s.Key, bytestr.StrConnection) { if bytes.Equal(s.Value, bytestr.StrClose) { h.SetConnectionClose(true) } else { h.SetConnectionClose(false) h.AddArgBytes(s.Key, s.Value, protocol.ArgsHasValue) } continue } case 's': if utils.CaseInsensitiveCompare(s.Key, bytestr.StrServer) { h.SetServerBytes(s.Value) continue } if utils.CaseInsensitiveCompare(s.Key, bytestr.StrSetCookie) { h.ParseSetCookie(s.Value) continue } case 't': if utils.CaseInsensitiveCompare(s.Key, bytestr.StrTransferEncoding) { if !bytes.Equal(s.Value, bytestr.StrIdentity) { h.InitContentLengthWithValue(-1) h.SetArgBytes(bytestr.StrTransferEncoding, bytestr.StrChunked, protocol.ArgsHasValue) } continue } 
if utils.CaseInsensitiveCompare(s.Key, bytestr.StrTrailer) { err = h.Trailer().SetTrailers(s.Value) continue } } h.AddArgBytes(s.Key, s.Value, protocol.ArgsHasValue) } } if s.Err != nil { h.SetConnectionClose(true) return 0, s.Err } if h.ContentLength() < 0 { h.SetContentLengthBytes(h.ContentLengthBytes()[:0]) } if h.ContentLength() == -2 && !ConnectionUpgrade(h) && !h.MustSkipContentLength() { h.SetArgBytes(bytestr.StrTransferEncoding, bytestr.StrIdentity, protocol.ArgsHasValue) h.SetConnectionClose(true) } if !h.IsHTTP11() && !h.ConnectionClose() { // close connection for non-http/1.1 response unless 'Connection: keep-alive' is set. v := h.PeekArgBytes(bytestr.StrConnection) h.SetConnectionClose(!ext.HasHeaderValue(v, bytestr.StrKeepAlive)) } return len(buf) - len(s.B), err } func parse(h *protocol.ResponseHeader, buf []byte) (int, error) { m, err := parseFirstLine(h, buf) if err != nil { return 0, err } n, err := parseHeaders(h, buf[m:]) if err != nil { return 0, err } return m + n, nil } func parseFirstLine(h *protocol.ResponseHeader, buf []byte) (int, error) { bNext := buf var b []byte var err error for len(b) == 0 { if b, bNext, err = utils.NextLine(bNext); err != nil { return 0, err } } // parse protocol n := bytes.IndexByte(b, ' ') if n < 0 { return 0, fmt.Errorf("cannot find whitespace in the first line of response %q", buf) } isHTTP11 := bytes.Equal(b[:n], bytestr.StrHTTP11) if !isHTTP11 { h.SetProtocol(consts.HTTP10) } else { h.SetProtocol(consts.HTTP11) } b = b[n+1:] // parse status code var statusCode int statusCode, n, err = bytesconv.ParseUintBuf(b) h.SetStatusCode(statusCode) if err != nil { return 0, fmt.Errorf("cannot parse response status code: %s. Response %q", err, buf) } if len(b) > n && b[n] != ' ' { return 0, fmt.Errorf("unexpected char at the end of status code. 
Response %q", buf) } return len(buf) - len(bNext), nil } ================================================ FILE: pkg/protocol/http1/resp/header_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
*
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package resp

import (
	"bufio"
	"bytes"
	"net/http"
	"strings"
	"testing"

	"github.com/cloudwego/hertz/pkg/protocol"
	"github.com/cloudwego/netpoll"
)

// TestResponseHeaderCookie checks that cookies set on a ResponseHeader can be
// looked up, visited, serialized with WriteHeader, parsed back with
// ReadHeader, and deleted independently on the original and parsed headers.
func TestResponseHeaderCookie(t *testing.T) {
	t.Parallel()

	var h protocol.ResponseHeader
	var c protocol.Cookie

	// The same Cookie value is reused: the second SetCookie keeps Value "aaa"
	// from the first one and only changes key and domain.
	c.SetKey("foobar")
	c.SetValue("aaa")
	h.SetCookie(&c)

	c.SetKey("йцук")
	c.SetDomain("foobar.com")
	h.SetCookie(&c)

	// Look up the first cookie by key into a freshly reset Cookie.
	c.Reset()
	c.SetKey("foobar")
	if !h.Cookie(&c) {
		t.Fatalf("Cannot find cookie %q", c.Key())
	}

	var expectedC1 protocol.Cookie
	expectedC1.SetKey("foobar")
	expectedC1.SetValue("aaa")
	if !equalCookie(&expectedC1, &c) {
		t.Fatalf("unexpected cookie\n%#v\nExpected\n%#v\n", &c, &expectedC1)
	}

	c.SetKey("йцук")
	if !h.Cookie(&c) {
		t.Fatalf("cannot find cookie %q", c.Key())
	}

	var expectedC2 protocol.Cookie
	expectedC2.SetKey("йцук")
	expectedC2.SetValue("aaa")
	expectedC2.SetDomain("foobar.com")
	if !equalCookie(&expectedC2, &c) {
		t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &c, &expectedC2)
	}

	// Every visited raw Set-Cookie value must parse back to the cookie stored
	// under the visited key.
	h.VisitAllCookie(func(key, value []byte) {
		var cc protocol.Cookie
		if err := cc.ParseBytes(value); err != nil {
			t.Fatal(err)
		}
		if !bytes.Equal(key, cc.Key()) {
			t.Fatalf("Unexpected cookie key %q. Expected %q", key, cc.Key())
		}
		switch {
		case bytes.Equal(key, []byte("foobar")):
			if !equalCookie(&expectedC1, &cc) {
				t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &cc, &expectedC1)
			}
		case bytes.Equal(key, []byte("йцук")):
			if !equalCookie(&expectedC2, &cc) {
				t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &cc, &expectedC2)
			}
		default:
			t.Fatalf("unexpected cookie key %q", key)
		}
	})

	// Round-trip: serialize h, then parse it into h1.
	w := &bytes.Buffer{}
	zw := netpoll.NewWriter(w)
	if err := WriteHeader(&h, zw); err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	if err := zw.Flush(); err != nil {
		t.Fatalf("unexpected error: %s", err)
	}

	// From here on h has no cookies, so every h.Cookie lookup below must fail.
	h.DelAllCookies()

	var h1 protocol.ResponseHeader
	zr := netpoll.NewReader(w)
	if err := ReadHeader(&h1, zr); err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	c.SetKey("foobar")
	if !h1.Cookie(&c) {
		t.Fatalf("Cannot find cookie %q", c.Key())
	}
	if !equalCookie(&expectedC1, &c) {
		t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &c, &expectedC1)
	}

	h1.DelCookie("foobar")
	if h.Cookie(&c) {
		t.Fatalf("Unexpected cookie found: %v", &c)
	}
	if h1.Cookie(&c) {
		t.Fatalf("Unexpected cookie found: %v", &c)
	}

	c.SetKey("йцук")
	if !h1.Cookie(&c) {
		t.Fatalf("cannot find cookie %q", c.Key())
	}
	if !equalCookie(&expectedC2, &c) {
		t.Fatalf("unexpected cookie\n%v\nExpected\n%v\n", &c, &expectedC2)
	}

	h1.DelCookie("йцук")
	if h.Cookie(&c) {
		t.Fatalf("Unexpected cookie found: %v", &c)
	}
	if h1.Cookie(&c) {
		t.Fatalf("Unexpected cookie found: %v", &c)
	}
}

// equalCookie reports whether c1 and c2 have the same key, value, expiry,
// domain and path. Other cookie attributes are not compared.
func equalCookie(c1, c2 *protocol.Cookie) bool {
	if !bytes.Equal(c1.Key(), c2.Key()) {
		return false
	}
	if !bytes.Equal(c1.Value(), c2.Value()) {
		return false
	}
	if !c1.Expire().Equal(c2.Expire()) {
		return false
	}
	if !bytes.Equal(c1.Domain(), c2.Domain()) {
		return false
	}
	if !bytes.Equal(c1.Path(), c2.Path()) {
		return false
	}
	return true
}

// TestResponseHeaderMultiLineValue parses a header with empty and folded
// (multi-line) values. For every header net/http reports, it checks that
// Peek returns the same value as net/http's first value.
func TestResponseHeaderMultiLineValue(t *testing.T) {
	s := "HTTP/1.1 200 OK\r\n" +
		"EmptyValue1:\r\n" +
		"Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" +
		"Foo: Bar\r\n" +
		"Multi-Line: one;\r\n two\r\n" +
		"Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" +
		"\r\n"
	header := new(protocol.ResponseHeader)
	if _, err := parse(header, []byte(s)); err != nil {
		t.Fatalf("parse headers with multi-line values failed, %s", err)
	}
	// net/http serves as the reference implementation for folded values.
	response, err := http.ReadResponse(bufio.NewReader(strings.NewReader(s)), nil)
	if err != nil {
		t.Fatalf("parse response using net/http failed, %s", err)
	}

	for name, vals := range response.Header {
		got := string(header.Peek(name))
		want := vals[0]

		if got != want {
			t.Errorf("unexpected %s got: %q want: %q", name, got, want)
		}
	}
}

================================================
FILE: pkg/protocol/http1/resp/response.go
================================================
/*
 * Copyright 2022 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
*
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package resp

import (
	"errors"
	"fmt"
	"io"
	"runtime"
	"sync"

	"github.com/cloudwego/hertz/pkg/common/bytebufferpool"
	errs "github.com/cloudwego/hertz/pkg/common/errors"
	"github.com/cloudwego/hertz/pkg/common/hlog"
	"github.com/cloudwego/hertz/pkg/network"
	"github.com/cloudwego/hertz/pkg/protocol"
	"github.com/cloudwego/hertz/pkg/protocol/consts"
	"github.com/cloudwego/hertz/pkg/protocol/http1/ext"
)

// ErrBodyStreamWritePanic is returned when panic happens during writing body stream.
//
// It embeds the underlying error, so its Error method returns the wrapped message.
type ErrBodyStreamWritePanic struct {
	error
}

// h1Response wraps *protocol.Response so it can be rendered in HTTP/1.x wire
// format through its String method.
type h1Response struct {
	*protocol.Response
}

// String returns the HTTP/1.x wire representation of the response.
//
// Returns error message instead of response representation on error.
//
// Use Write instead of String for performance-critical code.
func (h1Resp *h1Response) String() string { w := bytebufferpool.Get() zw := network.NewWriter(w) if err := Write(h1Resp.Response, zw); err != nil { return err.Error() } if err := zw.Flush(); err != nil { return err.Error() } s := string(w.B) bytebufferpool.Put(w) return s } func GetHTTP1Response(resp *protocol.Response) fmt.Stringer { return &h1Response{resp} } // ReadHeaders reads http header into *protocol.Response func ReadHeaders(resp *protocol.Response, r network.Reader) error { resp.ResetBody() err := ReadHeader(&resp.Header, r) if err != nil { return err } if resp.Header.StatusCode() == consts.StatusContinue { // Read the next response according to http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html . if err = ReadHeader(&resp.Header, r); err != nil { return err } } return nil } // ReadHeaderAndLimitBody ... func ReadHeaderAndLimitBody(resp *protocol.Response, r network.Reader, maxBodySize int) error { if err := ReadHeaders(resp, r); err != nil { return err } return ReadRespBody(resp, r, maxBodySize) } // ReadRespBody reads response body from the given r, limiting the body size. // // If maxBodySize > 0 and the body size exceeds maxBodySize, // then ErrBodyTooLarge is returned. // // io.EOF is returned if r is closed before reading the first header byte. func ReadRespBody(resp *protocol.Response, r network.Reader, maxBodySize int) (err error) { if resp.MustSkipBody() { return nil } bodyBuf := resp.BodyBuffer() bodyBuf.Reset() bodyBuf.B, err = ext.ReadBody(r, resp.Header.ContentLength(), maxBodySize, bodyBuf.B) if err != nil { return err } if resp.Header.ContentLength() == -1 { err = ext.ReadTrailer(resp.Header.Trailer(), r) if err != nil && err != io.EOF { return err } } resp.Header.SetContentLength(len(bodyBuf.B)) return nil } type clientRespStream struct { mu sync.Mutex r io.Reader closeCallback func(shouldClose bool) error } // ForceClose closes underlying conn. It enables `Read` call to return instead of blocking. 
//
// This method is ONLY used by hertz internally.
// Normally, users call `Close` when the body is no longer used.
func (c *clientRespStream) ForceClose() (err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	// Run the callback once with shouldClose=true, then clear it so a later
	// Close does not invoke it again.
	if c.closeCallback != nil {
		err = c.closeCallback(true)
		c.closeCallback = nil
	}
	// NOTE: DO NOT put back to pool here,
	// user may still use clientRespStream and call Close() like `defer body.Close()`
	return
}

// Close closes response stream gracefully.
//
// It releases the body stream via ext.ReleaseBodyStream. If a close callback
// is still set (ForceClose has not consumed it), the callback is called with
// shouldClose reporting whether the release failed, and the callback's error
// is returned. Otherwise the release error is returned. The stream is then
// reset and put back into clientRespStreamPool.
//
// NOTE:
// * Since Close() will put it back to pool, MUST ensure it only be called when no longer use.
// * MUST NOT call Close() and `Read()` concurrently to avoid race issue
func (c *clientRespStream) Close() (err error) {
	c.mu.Lock()
	defer c.mu.Unlock()
	// Drop the finalizer set in convertClientRespStream: the stream is being
	// closed explicitly and is about to be recycled.
	runtime.SetFinalizer(c, nil)
	// If error happened in ReleaseBodyStream, the connection may be in abnormal state.
	// Close it in the callback in order to avoid other unexpected problems.
	err = ext.ReleaseBodyStream(c.r)
	if c.closeCallback != nil {
		if err != nil {
			hlog.SystemLogger().Warnf("error occurred during the stream body close: %s", err)
		}
		err = c.closeCallback(err != nil)
	}
	// Clear references before handing the object back to the pool.
	c.r = nil
	c.closeCallback = nil
	clientRespStreamPool.Put(c)
	return
}

// Read reads from the wrapped body stream. It does not take c.mu, and Close
// sets c.r to nil, so Read must not run concurrently with or after Close.
func (c *clientRespStream) Read(p []byte) (n int, err error) {
	return c.r.Read(p)
}

// clientRespStreamPool recycles clientRespStream values; Close returns them here.
var clientRespStreamPool = sync.Pool{
	New: func() interface{} {
		return &clientRespStream{}
	},
}

// convertClientRespStream takes a clientRespStream from the pool, binds it to
// bs and fn, and registers Close as a finalizer. That way the stream is still
// released if the user drops it without closing; the GC decides when
// finalizers run.
func convertClientRespStream(bs io.Reader, fn func(shouldClose bool) error) *clientRespStream {
	clientStream := clientRespStreamPool.Get().(*clientRespStream)
	clientStream.r = bs
	clientStream.closeCallback = fn
	runtime.SetFinalizer(clientStream, (*clientRespStream).Close)
	return clientStream
}

// ReadHeaderBodyStream reads the response header from r and then prepares the
// body as a stream via ReadRespBodyStream.
func ReadHeaderBodyStream(resp *protocol.Response, r network.Reader, maxBodySize int, closeCallBack func(shouldClose bool) error) error { if err := ReadHeaders(resp, r); err != nil { return err } return ReadRespBodyStream(resp, r, maxBodySize, closeCallBack) } // Deprecated: use ReadHeaderBodyStream var ReadBodyStream = ReadHeaderBodyStream // ReadRespBodyStream reads response body in stream func ReadRespBodyStream(resp *protocol.Response, r network.Reader, maxBodySize int, closeCallBack func(shouldClose bool) error) (err error) { if resp.MustSkipBody() { return nil } bodyBuf := resp.BodyBuffer() bodyBuf.Reset() bodyBuf.B, err = ext.ReadBodyWithStreaming(r, resp.Header.ContentLength(), maxBodySize, bodyBuf.B) if err != nil { if errors.Is(err, errs.ErrBodyTooLarge) { bodyStream := ext.AcquireBodyStream(bodyBuf, r, resp.Header.Trailer(), resp.Header.ContentLength()) resp.ConstructBodyStream(bodyBuf, convertClientRespStream(bodyStream, closeCallBack)) return nil } if errors.Is(err, errs.ErrChunkedStream) { bodyStream := ext.AcquireBodyStream(bodyBuf, r, resp.Header.Trailer(), -1) resp.ConstructBodyStream(bodyBuf, convertClientRespStream(bodyStream, closeCallBack)) return nil } resp.Reset() return err } bodyStream := ext.AcquireBodyStream(bodyBuf, r, resp.Header.Trailer(), resp.Header.ContentLength()) resp.ConstructBodyStream(bodyBuf, convertClientRespStream(bodyStream, closeCallBack)) return nil } // Read reads response (including body) from the given r. // // io.EOF is returned if r is closed before reading the first header byte. func Read(resp *protocol.Response, r network.Reader) error { return ReadHeaderAndLimitBody(resp, r, 0) } // Write writes response to w. // // Write doesn't flush response to w for performance reasons. // // See also WriteTo. 
func Write(resp *protocol.Response, w network.Writer) error { sendBody := !resp.MustSkipBody() if resp.IsBodyStream() { return writeBodyStream(resp, w, sendBody) } body := resp.BodyBytes() bodyLen := len(body) if sendBody || bodyLen > 0 { resp.Header.SetContentLength(bodyLen) } header := resp.Header.Header() _, err := w.WriteBinary(header) if err != nil { return err } resp.Header.SetHeaderLength(len(header)) // Write body if sendBody && bodyLen > 0 { _, err = w.WriteBinary(body) } return err } func writeBodyStream(resp *protocol.Response, w network.Writer, sendBody bool) (err error) { defer func() { if r := recover(); r != nil { err = &ErrBodyStreamWritePanic{ error: fmt.Errorf("panic while writing body stream: %+v", r), } } }() contentLength := resp.Header.ContentLength() if contentLength < 0 { lrSize := ext.LimitedReaderSize(resp.BodyStream()) if lrSize >= 0 { contentLength = int(lrSize) if int64(contentLength) != lrSize { contentLength = -1 } if contentLength >= 0 { resp.Header.SetContentLength(contentLength) } } } if contentLength >= 0 { if err = WriteHeader(&resp.Header, w); err == nil && sendBody { if resp.ImmediateHeaderFlush { err = w.Flush() } if err == nil { err = ext.WriteBodyFixedSize(w, resp.BodyStream(), int64(contentLength)) } } } else { resp.Header.SetContentLength(-1) if err = WriteHeader(&resp.Header, w); err == nil && sendBody { if resp.ImmediateHeaderFlush { err = w.Flush() } if err == nil { err = ext.WriteBodyChunked(w, resp.BodyStream()) } if err == nil { err = ext.WriteTrailer(resp.Header.Trailer(), w) } } } err1 := resp.CloseBodyStream() if err == nil { err = err1 } return err } ================================================ FILE: pkg/protocol/http1/resp/response_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package resp import ( "bufio" "bytes" "errors" "io" "io/ioutil" "strings" "testing" "github.com/cloudwego/hertz/internal/bytestr" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/ext" "github.com/cloudwego/netpoll" ) var errBodyTooLarge = errs.New(errs.ErrBodyTooLarge, errs.ErrorTypePublic, "test") type ErroneousBodyStream struct { errOnRead bool errOnClose bool } type testReader struct { read chan (int) cb chan (struct{}) } func (r *testReader) Read(b []byte) (int, error) { read := <-r.read if read == -1 { return 0, io.EOF } r.cb <- struct{}{} total := len(b) if total > read { total = read } for i := 0; i < total; i++ { b[i] = 'x' } return total, nil } func (ebs *ErroneousBodyStream) Read(p []byte) (n int, err error) { if ebs.errOnRead { panic("reading erroneous body stream") } return 0, io.EOF } func (ebs *ErroneousBodyStream) Close() error { if ebs.errOnClose { panic("closing erroneous body stream") } return nil } func TestResponseBodyStreamErrorOnPanicDuringClose(t *testing.T) { t.Parallel() var resp protocol.Response var w bytes.Buffer zw := netpoll.NewWriter(&w) ebs := &ErroneousBodyStream{errOnRead: false, errOnClose: true} resp.SetBodyStream(ebs, 42) err := Write(&resp, zw) if err == nil { t.Fatalf("expected error when writing response.") } e, ok := err.(*ErrBodyStreamWritePanic) if !ok { t.Fatalf("expected error struct to be *ErrBodyStreamWritePanic, got: %+v.", e) } if e.Error() != "panic while writing body stream: closing erroneous body stream" { t.Fatalf("unexpected error value, got: %+v.", e.Error()) } } func TestResponseBodyStreamErrorOnPanicDuringRead(t *testing.T) { t.Parallel() var resp protocol.Response var w bytes.Buffer zw := netpoll.NewWriter(&w) ebs := 
&ErroneousBodyStream{errOnRead: true, errOnClose: false} resp.SetBodyStream(ebs, 42) err := Write(&resp, zw) if err == nil { t.Fatalf("expected error when writing response.") } e, ok := err.(*ErrBodyStreamWritePanic) if !ok { t.Fatalf("expected error struct to be *ErrBodyStreamWritePanic, got: %+v.", e) } if e.Error() != "panic while writing body stream: reading erroneous body stream" { t.Fatalf("unexpected error value, got: %+v.", e.Error()) } } func testResponseReadError(t *testing.T, resp *protocol.Response, response string) { zr := mock.NewZeroCopyReader(response) err := Read(resp, zr) if err == nil { t.Fatalf("Expecting error for response=%q", response) } testResponseReadSuccess(t, resp, "HTTP/1.1 303 Redisred sedfs sdf\r\nContent-Type: aaa\r\nContent-Length: 5\r\n\r\nHELLOaaa", consts.StatusSeeOther, 5, "aaa", "HELLO", nil, consts.HTTP11) } func testResponseReadSuccess(t *testing.T, resp *protocol.Response, response string, expectedStatusCode, expectedContentLength int, expectedContentType, expectedBody string, expectedTrailer map[string]string, expectedProtocol string, ) { zr := mock.NewZeroCopyReader(response) err := Read(resp, zr) if err != nil { t.Fatalf("Unexpected error: %s", err) } verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType, "", expectedProtocol) if !bytes.Equal(resp.Body(), []byte(expectedBody)) { t.Fatalf("Unexpected body %q. 
Expected %q", resp.Body(), []byte(expectedBody))
	}
	verifyResponseTrailer(t, &resp.Header, expectedTrailer)
}

func TestResponseReadSuccess(t *testing.T) {
	t.Parallel()

	resp := &protocol.Response{}

	// usual response
	testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789",
		consts.StatusOK, 10, "foo/bar", "0123456789", nil, consts.HTTP11)

	// zero response
	testResponseReadSuccess(t, resp, "HTTP/1.1 500 OK\r\nContent-Length: 0\r\nContent-Type: foo/bar\r\n\r\n",
		consts.StatusInternalServerError, 0, "foo/bar", "", nil, consts.HTTP11)

	// response with trailer
	testResponseReadSuccess(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nTrailer: foo\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\nfoo: bar\r\n\r\n",
		consts.StatusMultipleChoices, 5, "bar", "56789", map[string]string{"Foo": "bar"}, consts.HTTP11)

	// response with trailer disableNormalizing
	resp.Header.DisableNormalizing()
	resp.Header.Trailer().DisableNormalizing()
	testResponseReadSuccess(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nTrailer: foo\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\nfoo: bar\r\n\r\n",
		consts.StatusMultipleChoices, 5, "bar", "56789", map[string]string{"foo": "bar"}, consts.HTTP11)

	// no content-length ('identity' transfer-encoding)
	testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: foobar\r\n\r\nzxxxx",
		consts.StatusOK, 5, "foobar", "zxxxx", nil, consts.HTTP11)

	// explicitly stated 'Transfer-Encoding: identity'
	testResponseReadSuccess(t, resp, "HTTP/1.1 234 ss\r\nContent-Type: xxx\r\n\r\nxag",
		234, 3, "xxx", "xag", nil, consts.HTTP11)

	// big 'identity' response
	body := string(mock.CreateFixedBody(100500))
	testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\n\r\n"+body,
		consts.StatusOK, 100500, "aa", body, nil, consts.HTTP11)

	// chunked response
	testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTrailer: Foo2\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nqwer\r\n2\r\nty\r\n0\r\nFoo2: bar2\r\n\r\n",
		200, 6, "text/html", "qwerty", map[string]string{"Foo2": "bar2"}, consts.HTTP11)

	// chunked response with non-chunked Transfer-Encoding.
	testResponseReadSuccess(t, resp, "HTTP/1.1 230 OK\r\nContent-Type: text\r\nTrailer: Foo3\r\nTransfer-Encoding: aaabbb\r\n\r\n2\r\ner\r\n2\r\nty\r\n0\r\nFoo3: bar3\r\n\r\n",
		230, 4, "text", "erty", map[string]string{"Foo3": "bar3"}, consts.HTTP11)

	// chunked response with empty body
	testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTrailer: Foo5\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo5: bar5\r\n\r\n",
		consts.StatusOK, 0, "text/html", "", map[string]string{"Foo5": "bar5"}, consts.HTTP11)
}

func TestResponseReadError(t *testing.T) {
	t.Parallel()

	resp := &protocol.Response{}

	// empty response
	testResponseReadError(t, resp, "")

	// invalid header
	testResponseReadError(t, resp, "foobar")

	// empty body
	testResponseReadError(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 1234\r\n\r\n")

	// short body
	testResponseReadError(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 1234\r\n\r\nshort")
}

func TestResponseImmediateHeaderFlushChunked(t *testing.T) {
	t.Parallel()

	var r protocol.Response

	r.ImmediateHeaderFlush = true

	ch := make(chan int)
	cb := make(chan struct{})

	buf := &testReader{read: ch, cb: cb}

	r.SetBodyStream(buf, -1)

	w := bytes.NewBuffer([]byte{})
	zw := netpoll.NewWriter(w)

	waitForIt := make(chan struct{})

	go func() {
		if err := Write(&r, zw); err != nil {
			t.Errorf("unexpected error: %s", err)
		}

		waitForIt <- struct{}{}
	}()

	ch <- 3

	if !strings.Contains(w.String(), "Transfer-Encoding: chunked") {
		t.Fatalf("Expected headers to be flushed")
	}

	if strings.Contains(w.String(), "xxx") {
		t.Fatalf("Did not expect body to be written yet")
	}

	<-cb
	ch <- -1

	<-waitForIt
}

func TestResponseImmediateHeaderFlushFixedLength(t *testing.T) {
	t.Parallel()

	var r protocol.Response

	r.ImmediateHeaderFlush = true

	ch := make(chan int)
	cb := make(chan struct{})

	buf := &testReader{read: ch, cb: cb}

	r.SetBodyStream(buf, 3)

	w := bytes.NewBuffer([]byte{})
	zw := netpoll.NewWriter(w)

	waitForIt := make(chan struct{})

	go func() {
		if err := Write(&r, zw); err != nil {
			t.Errorf("unexpected error: %s", err)
		}

		waitForIt <- struct{}{}
	}()

	// The reader has more data than bodySize, but only bodySize bytes will be sent.
	ch <- 10

	if !strings.Contains(w.String(), "Content-Length: 3") {
		t.Fatalf("Expected headers to be flushed")
	}

	if strings.Contains(w.String(), "xxx") {
		t.Fatalf("Did not expect body to be written yet")
	}

	<-cb
	// ch <- -1

	<-waitForIt
}

func TestResponseImmediateHeaderFlushFixedLengthWithFewerData(t *testing.T) {
	t.Parallel()

	var r protocol.Response

	r.ImmediateHeaderFlush = true

	ch := make(chan int)
	cb := make(chan struct{})

	buf := &testReader{read: ch, cb: cb}

	r.SetBodyStream(buf, 3)

	w := bytes.NewBuffer([]byte{})
	zw := netpoll.NewWriter(w)

	// Hand the Write result back to the test goroutine so it can be asserted
	// there; the previous `if err != nil { assert.NotNil(t, err) }` could never fail.
	errCh := make(chan error, 1)

	go func() {
		errCh <- Write(&r, zw)
	}()

	// The reader has less data than bodySize; Write is expected to report an error.
	ch <- 2

	<-cb
	ch <- -1

	if err := <-errCh; err == nil {
		t.Fatalf("expected an error when the body stream is shorter than the declared body size")
	}
}

func TestResponseSuccess(t *testing.T) {
	t.Parallel()

	// 200 response
	testResponseSuccess(t, consts.StatusOK, "test/plain", "server", "foobar",
		consts.StatusOK, "test/plain", "server")

	// response with missing statusCode
	testResponseSuccess(t, 0, "text/plain", "server", "foobar",
		consts.StatusOK, "text/plain", "server")

	// response with missing server
	testResponseSuccess(t, consts.StatusInternalServerError, "aaa", "", "aaadfsd",
		consts.StatusInternalServerError, "aaa", "")

	// empty body
	testResponseSuccess(t, consts.StatusOK, "bbb", "qwer", "",
		consts.StatusOK, "bbb", "qwer")

	// missing content-type
	testResponseSuccess(t, consts.StatusOK, "", "asdfsd", "asdf",
		consts.StatusOK, string(bytestr.DefaultContentType), "asdfsd")
}

func TestResponseReadLimitBody(t *testing.T) {
	t.Parallel()

	// response with content-length
	testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 10)
	testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 100)
	testResponseReadLimitBodyError(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 9)

	// response with content-encoding
	testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Encoding: gzip\r\n\r\n9876543210", 10)

	// chunked response
	testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 9)
	testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\nFoo: bar\r\n\r\n", 9)
	testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 100)
	testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\nfoobar\r\n\r\n", 100)
	testResponseReadLimitBodyError(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 2)

	// identity response
	testResponseReadLimitBodySuccess(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 6)
	testResponseReadLimitBodySuccess(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 106)
	testResponseReadLimitBodyError(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 5)
}

func TestResponseReadWithoutBody(t *testing.T) {
	var resp protocol.Response

	testResponseReadWithoutBody(t, &resp, "HTTP/1.1 304 Not Modified\r\nContent-Type: aa\r\nContent-Encoding: gzip\r\nContent-Length: 1235\r\n\r\n", false,
		consts.StatusNotModified, 1235, "aa", nil, "gzip", consts.HTTP11)

	testResponseReadWithoutBody(t, &resp, "HTTP/1.1 200 Foo Bar\r\nContent-Type: aab\r\nTrailer: Foo\r\nContent-Encoding: deflate\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo: bar\r\n\r\nHTTP/1.2", false,
		consts.StatusOK, 0, "aab", map[string]string{"Foo": "bar"}, "deflate", consts.HTTP11)

	testResponseReadWithoutBody(t, &resp, "HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTrailer: Foo\r\nContent-Encoding: deflate\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo: bar\r\n\r\nHTTP/1.2", true,
		consts.StatusNoContent, -1, "aab", nil, "deflate", consts.HTTP11)

	testResponseReadWithoutBody(t, &resp, "HTTP/1.1 123 AAA\r\nContent-Type: xxx\r\nContent-Encoding: gzip\r\nContent-Length: 3434\r\n\r\n", false,
		123, 3434, "xxx", nil, "gzip", consts.HTTP11)

	testResponseReadWithoutBody(t, &resp, "HTTP 200 OK\r\nContent-Type: text/xml\r\nContent-Encoding: deflate\r\nContent-Length: 123\r\n\r\nfoobar\r\n", true,
		consts.StatusOK, 123, "text/xml", nil, "deflate", consts.HTTP10)

	// '100 Continue' must be skipped.
	testResponseReadWithoutBody(t, &resp, "HTTP/1.1 100 Continue\r\nFoo-bar: baz\r\n\r\nHTTP/1.1 329 aaa\r\nContent-Type: qwe\r\nContent-Encoding: gzip\r\nContent-Length: 894\r\n\r\n", true,
		329, 894, "qwe", nil, "gzip", consts.HTTP11)
}

// verifyResponseHeader fails the test if any of the status code, content
// length, content type, content encoding or protocol of h differ from the
// expected values.
func verifyResponseHeader(t *testing.T, h *protocol.ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType, expectedContentEncoding, expectedProtocol string) {
	t.Helper()
	if h.StatusCode() != expectedStatusCode {
		t.Fatalf("Unexpected status code %d. Expected %d", h.StatusCode(), expectedStatusCode)
	}
	if h.ContentLength() != expectedContentLength {
		t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength(), expectedContentLength)
	}
	if string(h.ContentType()) != expectedContentType {
		t.Fatalf("Unexpected content type %q. Expected %q", h.ContentType(), expectedContentType)
	}
	if string(h.ContentEncoding()) != expectedContentEncoding {
		t.Fatalf("Unexpected content encoding %q. Expected %q", h.ContentEncoding(), expectedContentEncoding)
	}
	if h.GetProtocol() != expectedProtocol {
		t.Fatalf("Unexpected protocol %q. Expected %q", h.GetProtocol(), expectedProtocol)
	}
}

// testResponseSuccess writes a response built from the given fields, reads it
// back, and checks the parsed status code, content length, content type,
// server header and body.
func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName, body string,
	expectedStatusCode int, expectedContentType, expectedServerName string,
) {
	t.Helper()
	var resp protocol.Response
	resp.SetStatusCode(statusCode)
	resp.Header.Set("Content-Type", contentType)
	resp.Header.Set("Server", serverName)
	resp.SetBody([]byte(body))

	w := &bytes.Buffer{}
	zw := netpoll.NewWriter(w)
	err := Write(&resp, zw)
	if err != nil {
		t.Fatalf("Unexpected error when calling Response.Write(): %s", err)
	}

	if err = zw.Flush(); err != nil {
		t.Fatalf("Unexpected error when flushing bufio.Writer: %s", err)
	}

	var resp1 protocol.Response
	br := bufio.NewReader(w)
	zr := netpoll.NewReader(br)
	if err = Read(&resp1, zr); err != nil {
		t.Fatalf("Unexpected error when calling Response.Read(): %s", err)
	}
	if resp1.StatusCode() != expectedStatusCode {
		t.Fatalf("Unexpected status code: %d. Expected %d", resp1.StatusCode(), expectedStatusCode)
	}
	if resp1.Header.ContentLength() != len(body) {
		t.Fatalf("Unexpected content-length: %d. Expected %d", resp1.Header.ContentLength(), len(body))
	}
	if string(resp1.Header.Peek(consts.HeaderContentType)) != expectedContentType {
		t.Fatalf("Unexpected content-type: %q. Expected %q", resp1.Header.Peek(consts.HeaderContentType), expectedContentType)
	}
	if string(resp1.Header.Peek(consts.HeaderServer)) != expectedServerName {
		t.Fatalf("Unexpected server: %q. Expected %q", resp1.Header.Peek(consts.HeaderServer), expectedServerName)
	}
	if !bytes.Equal(resp1.Body(), []byte(body)) {
		t.Fatalf("Unexpected body: %q. Expected %q", resp1.Body(), body)
	}
}

// testResponseReadWithoutBody reads s into resp with SkipBody set to skipBody,
// expects an empty body plus the given header/trailer values, and then checks
// that resp can still be reused to read an ordinary response.
func testResponseReadWithoutBody(t *testing.T, resp *protocol.Response, s string, skipBody bool,
	expectedStatusCode, expectedContentLength int, expectedContentType string, expectedTrailer map[string]string, expectedContentEncoding, expectedProtocol string,
) {
	t.Helper()
	zr := mock.NewZeroCopyReader(s)
	resp.SkipBody = skipBody
	err := Read(resp, zr)
	if err != nil {
		t.Fatalf("Unexpected error when reading response without body: %s. response=%q", err, s)
	}
	if len(resp.Body()) != 0 {
		t.Fatalf("Unexpected response body %q. Expected %q. response=%q", resp.Body(), "", s)
	}
	verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType, expectedContentEncoding, expectedProtocol)
	verifyResponseTrailer(t, &resp.Header, expectedTrailer)

	// verify that an ordinary response is read after a null-body response
	resp.SkipBody = false
	testResponseReadSuccess(t, resp, "HTTP/1.1 300 OK\r\nContent-Length: 5\r\nContent-Type: bar\r\n\r\n56789",
		consts.StatusMultipleChoices, 5, "bar", "56789", nil, consts.HTTP11)
}

// verifyResponseTrailer fails the test unless h's trailer contains exactly the
// expected key/value pairs: every expected key is present with its value, and
// every visited trailer entry matches the expectation.
func verifyResponseTrailer(t *testing.T, h *protocol.ResponseHeader, expectedTrailers map[string]string) {
	t.Helper()
	for k, v := range expectedTrailers {
		got := h.Trailer().Peek(k)
		if !bytes.Equal(got, []byte(v)) {
			t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", k, v, got)
		}
	}
	h.Trailer().VisitAll(func(key, value []byte) {
		if v := expectedTrailers[string(key)]; string(value) != v {
			t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", string(key), v, string(value))
		}
	})
}

// testResponseReadLimitBodyError expects ReadHeaderAndLimitBody to reject s
// with errs.ErrBodyTooLarge under the given maxBodySize.
func testResponseReadLimitBodyError(t *testing.T, s string, maxBodySize int) {
	t.Helper()
	var resp protocol.Response
	zr := netpoll.NewReader(bytes.NewBufferString(s))
	err := ReadHeaderAndLimitBody(&resp, zr, maxBodySize)
	if err == nil {
		t.Fatalf("expecting error. s=%q, maxBodySize=%d", s, maxBodySize)
	}
	if !errors.Is(err, errs.ErrBodyTooLarge) {
		// Report the same sentinel the condition above checks against.
		t.Fatalf("unexpected error: %s. Expecting %s. s=%q, maxBodySize=%d", err, errs.ErrBodyTooLarge, s, maxBodySize)
	}
}

// testResponseReadLimitBodySuccess expects ReadHeaderAndLimitBody to accept s
// under the given maxBodySize.
func testResponseReadLimitBodySuccess(t *testing.T, s string, maxBodySize int) {
	t.Helper()
	var resp protocol.Response
	mr := mock.NewZeroCopyReader(s)
	if err := ReadHeaderAndLimitBody(&resp, mr, maxBodySize); err != nil {
		t.Fatalf("unexpected error: %s. s=%q, maxBodySize=%d", err, s, maxBodySize)
	}
}

func TestResponseBodyStreamWithTrailer(t *testing.T) {
	t.Parallel()

	testResponseBodyStreamWithTrailer(t, nil, false)

	body := mock.CreateFixedBody(1e5)
	testResponseBodyStreamWithTrailer(t, body, false)
	testResponseBodyStreamWithTrailer(t, body, true)
}

// testResponseBodyStreamWithTrailer writes a streamed (chunked) response with
// two trailers, reads it back, and checks both the body and the trailers.
func testResponseBodyStreamWithTrailer(t *testing.T, body []byte, disableNormalizing bool) {
	t.Helper()
	expectedTrailer := map[string]string{
		"foo": "testfoo",
		"bar": "testbar",
	}

	var resp1 protocol.Response
	if disableNormalizing {
		resp1.Header.DisableNormalizing()
	}
	resp1.SetBodyStream(bytes.NewReader(body), -1)
	for k, v := range expectedTrailer {
		err := resp1.Header.Trailer().Add(k, v)
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
	}

	var w bytes.Buffer
	zw := netpoll.NewWriter(&w)
	if err := Write(&resp1, zw); err != nil {
		t.Fatalf("unexpected error: %s", err)
	}
	if err := zw.Flush(); err != nil {
		t.Fatalf("unexpected error: %s", err)
	}

	var resp2 protocol.Response
	if disableNormalizing {
		resp2.Header.DisableNormalizing()
	}
	br := bufio.NewReader(&w)
	zr := netpoll.NewReader(br)
	if err := Read(&resp2, zr); err != nil {
		t.Fatalf("unexpected error: %s", err)
	}

	respBody := resp2.Body()
	if !bytes.Equal(respBody, body) {
		t.Fatalf("unexpected body: %q. Expecting %q", respBody, body)
	}

	for k, v := range expectedTrailer {
		kBytes := []byte(k)
		utils.NormalizeHeaderKey(kBytes, disableNormalizing)
		r := resp2.Header.Trailer().Peek(k)
		if string(r) != v {
			t.Fatalf("unexpected trailer header %q: %q. Expecting %s", kBytes, r, v)
		}
	}
}

func TestResponseReadBodyStreamBadReader(t *testing.T) {
	t.Parallel()

	resp := protocol.AcquireResponse()

	errReader := mock.NewErrorReadConn(errors.New("test error"))

	bodyBuf := resp.BodyBuffer()
	bodyBuf.Reset()

	bodyStream := ext.AcquireBodyStream(bodyBuf, errReader, resp.Header.Trailer(), 100)
	resp.ConstructBodyStream(bodyBuf, convertClientRespStream(bodyStream, func(shouldClose bool) error {
		assert.True(t, shouldClose)
		return nil
	}))

	// Fail cleanly instead of panicking on a nil interface if the stream is
	// not closable.
	closer, ok := resp.BodyStream().(io.Closer)
	if !ok {
		t.Fatalf("expected the response body stream to implement io.Closer")
	}
	// The close result is intentionally not asserted (as before); the check
	// of interest is the shouldClose flag passed to the callback above.
	closer.Close() //nolint:errcheck
}

func TestSetResponseBodyStreamFixedSize(t *testing.T) {
	t.Parallel()

	testSetResponseBodyStream(t, "a")
	testSetResponseBodyStream(t, string(mock.CreateFixedBody(4097)))
	testSetResponseBodyStream(t, string(mock.CreateFixedBody(100500)))
}

func TestSetResponseBodyStreamChunked(t *testing.T) {
	t.Parallel()

	testSetResponseBodyStreamChunked(t, "", map[string]string{"Foo": "bar"})

	body := "foobar baz aaa bbb ccc"
	testSetResponseBodyStreamChunked(t, body, nil)

	body = string(mock.CreateFixedBody(10001))
	testSetResponseBodyStreamChunked(t, body, map[string]string{"Foo": "test", "Bar": "test"})
}

// testSetResponseBodyStream round-trips body through a fixed-size body stream
// and checks that the parsed response carries the same body.
func testSetResponseBodyStream(t *testing.T, body string) {
	t.Helper()
	var resp protocol.Response
	bodySize := len(body)
	if resp.IsBodyStream() {
		t.Fatalf("IsBodyStream must return false")
	}
	resp.SetBodyStream(bytes.NewBufferString(body), bodySize)
	if !resp.IsBodyStream() {
		t.Fatalf("IsBodyStream must return true")
	}

	var w bytes.Buffer
	zw := netpoll.NewWriter(&w)
	if err := Write(&resp, zw); err != nil {
		t.Fatalf("unexpected error when writing response: %s. body=%q", err, body)
	}
	if err := zw.Flush(); err != nil {
		t.Fatalf("unexpected error when flushing response: %s. body=%q", err, body)
	}

	var resp1 protocol.Response
	br := bufio.NewReader(&w)
	zr := netpoll.NewReader(br)
	if err := Read(&resp1, zr); err != nil {
		t.Fatalf("unexpected error when reading response: %s. body=%q", err, body)
	}
	if string(resp1.Body()) != body {
		t.Fatalf("unexpected body %q. Expecting %q", resp1.Body(), body)
	}
}

// testSetResponseBodyStreamChunked round-trips body through a chunked body
// stream with the given trailers and checks body and trailers on the parsed
// response.
func testSetResponseBodyStreamChunked(t *testing.T, body string, trailer map[string]string) {
	t.Helper()
	var resp protocol.Response
	if resp.IsBodyStream() {
		t.Fatalf("IsBodyStream must return false")
	}
	resp.SetBodyStream(bytes.NewBufferString(body), -1)
	if !resp.IsBodyStream() {
		t.Fatalf("IsBodyStream must return true")
	}

	var w bytes.Buffer
	zw := netpoll.NewWriter(&w)
	for k, v := range trailer {
		err := resp.Header.Trailer().Add(k, v)
		if err != nil {
			t.Fatalf("unexpected error: %s", err)
		}
	}
	if err := Write(&resp, zw); err != nil {
		t.Fatalf("unexpected error when writing response: %s. body=%q", err, body)
	}
	if err := zw.Flush(); err != nil {
		t.Fatalf("unexpected error when flushing response: %s. body=%q", err, body)
	}

	var resp1 protocol.Response
	br := bufio.NewReader(&w)
	zr := netpoll.NewReader(br)
	if err := Read(&resp1, zr); err != nil {
		t.Fatalf("unexpected error when reading response: %s. body=%q", err, body)
	}
	if string(resp1.Body()) != body {
		t.Fatalf("unexpected body %q. Expecting %q", resp1.Body(), body)
	}
	// Check the trailers on the response that was read back (resp1), not on
	// the one we wrote; otherwise the round trip is never verified.
	for k, v := range trailer {
		r := resp1.Header.Trailer().Peek(k)
		if string(r) != v {
			t.Fatalf("unexpected trailer %s. Expecting %s. Got %q", k, v, r)
		}
	}
}

// testResponseReadBodyStreamSuccess reads response in body-stream mode, drains
// the stream, and checks header fields, body and trailers.
func testResponseReadBodyStreamSuccess(t *testing.T, resp *protocol.Response, response string, expectedStatusCode, expectedContentLength int,
	expectedContentType, expectedBody string, expectedTrailer map[string]string, expectedProtocol string,
) {
	t.Helper()
	zr := mock.NewZeroCopyReader(response)
	err := ReadHeaderBodyStream(resp, zr, 0, nil)
	if err != nil {
		t.Fatalf("Unexpected error: %s", err)
	}
	assert.True(t, resp.IsBodyStream())
	body, err := ioutil.ReadAll(resp.BodyStream())
	if err != nil && err != io.EOF {
		t.Fatalf("Unexpected error: %s", err)
	}
	verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType, "", expectedProtocol)
	if !bytes.Equal(body, []byte(expectedBody)) {
		// Report the bytes actually drained from the stream, which is what was compared.
		t.Fatalf("Unexpected body %q. Expected %q", body, expectedBody)
	}
	verifyResponseTrailer(t, &resp.Header, expectedTrailer)
}

// testResponseReadBodyStreamBadTrailer expects the header to parse but
// draining the body stream to fail because of a forbidden trailer field.
func testResponseReadBodyStreamBadTrailer(t *testing.T, resp *protocol.Response, response string) {
	t.Helper()
	zr := mock.NewZeroCopyReader(response)
	err := ReadHeaderBodyStream(resp, zr, 0, nil)
	if err != nil {
		t.Fatalf("Unexpected error: %s", err)
	}
	assert.True(t, resp.IsBodyStream())
	_, err = ioutil.ReadAll(resp.BodyStream())
	if err == nil || err == io.EOF {
		t.Fatalf("expected error when reading response.")
	}
}

func TestResponseReadBodyStream(t *testing.T) {
	t.Parallel()

	resp := &protocol.Response{}

	// usual response
	testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789",
		consts.StatusOK, 10, "foo/bar", "0123456789", nil, consts.HTTP11)

	// zero response
	testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 500 OK\r\nContent-Length: 0\r\nContent-Type: foo/bar\r\n\r\n",
		consts.StatusInternalServerError, 0, "foo/bar", "", nil, consts.HTTP11)

	// response with trailer
	testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nTrailer: Foo\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\nfoo: bar\r\n\r\n",
		consts.StatusMultipleChoices, -1, "bar", "56789", map[string]string{"Foo": "bar"}, consts.HTTP11)

	bodyWithLongLength := strings.Repeat("1", 8*1024+1)
	testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Length: 8193\r\nContent-Type: foo/bar\r\n\r\n"+bodyWithLongLength,
		consts.StatusOK, 8193, "foo/bar", bodyWithLongLength, nil, consts.HTTP11)

	// response with trailer disableNormalizing
	resp.Header.DisableNormalizing()
	resp.Header.Trailer().DisableNormalizing()
	testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nTrailer: foo\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\nfoo: bar\r\n\r\n",
		consts.StatusMultipleChoices, -1, "bar", "56789", map[string]string{"foo": "bar"}, consts.HTTP11)

	// no content-length ('identity' transfer-encoding)
	testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: foobar\r\n\r\nzxxxx",
		consts.StatusOK, -2, "foobar", "zxxxx", nil, consts.HTTP11)

	// explicitly stated 'Transfer-Encoding: identity'
	testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 234 ss\r\nContent-Type: xxx\r\n\r\nxag",
		234, -2, "xxx", "xag", nil, consts.HTTP11)

	// big 'identity' response
	body := string(mock.CreateFixedBody(100500))
	testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\n\r\n"+body,
		consts.StatusOK, -2, "aa", body, nil, consts.HTTP11)

	// chunked response
	testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTrailer: Foo2\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nqwer\r\n2\r\nty\r\n0\r\nFoo2: bar2\r\n\r\n",
		200, -1, "text/html", "qwerty", map[string]string{"Foo2": "bar2"}, consts.HTTP11)

	// chunked response with non-chunked Transfer-Encoding.
	testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 230 OK\r\nContent-Type: text\r\nTrailer: Foo3\r\nTransfer-Encoding: aaabbb\r\n\r\n2\r\ner\r\n2\r\nty\r\n0\r\nFoo3: bar3\r\n\r\n",
		230, -1, "text", "erty", map[string]string{"Foo3": "bar3"}, consts.HTTP11)

	// chunked response with empty body
	testResponseReadBodyStreamSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTrailer: Foo5\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo5: bar5\r\n\r\n",
		consts.StatusOK, -1, "text/html", "", map[string]string{"Foo5": "bar5"}, consts.HTTP11)
}

func TestResponseReadBodyStreamBadTrailer(t *testing.T) {
	t.Parallel()

	resp := &protocol.Response{}

	testResponseReadBodyStreamBadTrailer(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\ncontent-type: bar\r\n\r\n")
	testResponseReadBodyStreamBadTrailer(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nqwer\r\n2\r\nty\r\n0\r\nproxy-connection: bar2\r\n\r\n")
}
================================================ FILE: pkg/protocol/http1/resp/writer.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package resp import ( "errors" "runtime" "sync" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/http1/ext" ) var chunkWriterPool sync.Pool func init() { chunkWriterPool = sync.Pool{ New: func() interface{} { return &chunkedBodyWriter{} }, } } type chunkedBodyWriter struct { r *protocol.Response w network.Writer err error finalized bool wroteHeader bool } var errChunkedFinished = errors.New("chunked response is finished; no more data will be written.") // Write implements network.ExtWriter.Write / io.Writer.Write func (c *chunkedBodyWriter) Write(p []byte) (n int, err error) { if c.finalized { return 0, errChunkedFinished } if c.err != nil { return 0, c.err } if err := c.WriteHeader(); err != nil { return 0, err } if len(p) == 0 { // prevent from sending zero-len chunk which indicates stream ends. // callers may write with zero-len buf unintentionally. // use Finalize() instead. 
return 0, nil } if err := c.writeChunk(p); err != nil { return 0, err } return len(p), nil } // WriteHeader writes the response header for chunked encoding func (c *chunkedBodyWriter) WriteHeader() error { if c.wroteHeader { return c.err } c.wroteHeader = true c.r.Header.SetContentLength(-1) if c.err = WriteHeader(&c.r.Header, c.w); c.err != nil { return c.err } return nil } func (c *chunkedBodyWriter) writeChunk(b []byte) error { if c.err = ext.WriteChunk(c.w, b, false); c.err != nil { return c.err } return nil } func (c *chunkedBodyWriter) Flush() error { return c.w.Flush() } // Finalize will write the ending chunk as well as trailer and flush the writer. // Warning: do not call this method by yourself, unless you know what you are doing. func (c *chunkedBodyWriter) Finalize() error { if c.finalized || c.err != nil { return c.err } c.finalized = true if err := c.WriteHeader(); err != nil { return err } // zero-len chunk if err := c.writeChunk(nil); err != nil { return err } // trailer which ends with \r\n _, c.err = c.w.WriteBinary(c.r.Header.Trailer().Header()) if c.err == nil { c.err = c.Flush() } return c.err } func (c *chunkedBodyWriter) release() { c.r = nil c.w = nil c.err = nil c.finalized = false c.wroteHeader = false chunkWriterPool.Put(c) } // NewChunkedBodyWriter creates a new chunked body writer. func NewChunkedBodyWriter(r *protocol.Response, w network.Writer) network.ExtWriter { extWriter := chunkWriterPool.Get().(*chunkedBodyWriter) extWriter.r = r extWriter.w = w runtime.SetFinalizer(extWriter, (*chunkedBodyWriter).release) return extWriter } ================================================ FILE: pkg/protocol/http1/resp/writer_test.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package resp import ( "errors" "strings" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/protocol" ) func TestNewChunkedBodyWriter(t *testing.T) { response := protocol.AcquireResponse() defer protocol.ReleaseResponse(response) mockConn := mock.NewConn("") w := NewChunkedBodyWriter(response, mockConn) _, _ = w.Write([]byte("hello")) _, _ = w.Write(nil) // noop assert.Nil(t, w.Flush()) out, _ := mockConn.WriterRecorder().Peek(mockConn.WriterRecorder().WroteLen()) resp := string(out) assert.True(t, strings.Contains(resp, "Transfer-Encoding: chunked")) assert.True(t, strings.HasSuffix(resp, "5\r\nhello\r\n")) // Finalize adds 0\r\n\r\n assert.Nil(t, w.Finalize()) assert.Nil(t, w.Finalize()) // noop out, _ = mockConn.WriterRecorder().Peek(mockConn.WriterRecorder().WroteLen()) resp = string(out) assert.True(t, strings.HasSuffix(resp, "5\r\nhello\r\n0\r\n\r\n")) _, err := w.Write([]byte("world")) assert.True(t, err == errChunkedFinished) } func TestNewChunkedBodyWriter_Err(t *testing.T) { response := protocol.AcquireResponse() defer protocol.ReleaseResponse(response) mw := mock.NewMockWriter(nil) w := NewChunkedBodyWriter(response, mw) expectErr := errors.New("mock malloc err") mw.MockMalloc = func(n int) ([]byte, error) { return nil, expectErr } _, err := w.Write([]byte("hello")) assert.True(t, err == expectErr) mw.MockMalloc = nil _, err = w.Write([]byte("world")) assert.True(t, err == expectErr) // next call will return last err w = NewChunkedBodyWriter(response, mw) 
mw.MockMalloc = func(n int) ([]byte, error) { return nil, expectErr } err = w.Finalize() assert.True(t, err == expectErr) } ================================================ FILE: pkg/protocol/http1/server.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package http1 import ( "context" "crypto/tls" "errors" "io" "net" "sync" "time" "github.com/cloudwego/hertz/internal/bytestr" internalStats "github.com/cloudwego/hertz/internal/stats" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app/server/render" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/tracer/stats" "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/ext" "github.com/cloudwego/hertz/pkg/protocol/http1/req" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" "github.com/cloudwego/hertz/pkg/protocol/suite" ) func init() { if b, err := utils.GetBoolFromEnv("HERTZ_DISABLE_REQUEST_CONTEXT_POOL"); err == nil { disabaleRequestContextPool = b } } // NextProtoTLS is the NPN/ALPN protocol negotiated during // HTTP/1.1's TLS setup. 
// Also used for server addressing const NextProtoTLS = suite.HTTP1 var ( errHijacked = errs.New(errs.ErrHijacked, errs.ErrorTypePublic, nil) errIdleTimeout = errs.New(errs.ErrIdleTimeout, errs.ErrorTypePrivate, nil) errShortConnection = errs.New(errs.ErrShortConnection, errs.ErrorTypePublic, "server is going to close the connection") errUnexpectedEOF = errs.NewPublic(io.ErrUnexpectedEOF.Error() + " when reading request") disabaleRequestContextPool = false ) type Option struct { StreamRequestBody bool GetOnly bool NoDefaultDate bool NoDefaultContentType bool DisablePreParseMultipartForm bool DisableKeepalive bool NoDefaultServerHeader bool DisableHeaderNamesNormalizing bool MaxRequestBodySize int MaxHeaderBytes int IdleTimeout time.Duration ReadTimeout time.Duration ServerName []byte TLS *tls.Config HTMLRender render.HTMLRender EnableTrace bool ContinueHandler func(header *protocol.RequestHeader) bool HijackConnHandle func(c network.Conn, h app.HijackHandler) } type Server struct { Option Core suite.Core eventStackPool *sync.Pool } func (s Server) getRequestContext() *app.RequestContext { if disabaleRequestContextPool { return &app.RequestContext{} } return s.Core.GetCtxPool().Get().(*app.RequestContext) } func (s Server) putRequestContext(ctx *app.RequestContext) { if disabaleRequestContextPool { return } ctx.Reset() s.Core.GetCtxPool().Put(ctx) } func (s Server) Serve(c context.Context, conn network.Conn) (err error) { var ( zr network.Reader zw network.Writer serverName []byte isHTTP11 bool connectionClose bool continueReadingRequest = true hijackHandler app.HijackHandler // HTTP1 path // 1. Get a request context // 2. Prepare it // 3. Process it // 4. Reset and recycle(in pooled mode) ctx = s.getRequestContext() traceCtl = s.Core.GetTracer() eventsToTrigger *eventStack // Use a new variable to hold the standard context to avoid modify the initial // context. 
cc = c ) if s.EnableTrace { eventsToTrigger = s.eventStackPool.Get().(*eventStack) } defer func() { if s.EnableTrace { // in case of error, we need to trigger all events if eventsToTrigger != nil { for last := eventsToTrigger.pop(); last != nil; last = eventsToTrigger.pop() { last(ctx.GetTraceInfo(), err) } s.eventStackPool.Put(eventsToTrigger) } if shouldRecordInTraceError(err) { traceCtl.DoFinish(cc, ctx, err) } else { traceCtl.DoFinish(cc, ctx, nil) } } // Hijack may release and close the connection already if zr != nil && !errors.Is(err, errs.ErrHijacked) { zr.Release() //nolint:errcheck zr = nil } if ctx.IsExiled() { return } s.putRequestContext(ctx) }() ctx.HTMLRender = s.HTMLRender ctx.SetConn(conn) ctx.Request.SetIsTLS(s.TLS != nil) ctx.SetEnableTrace(s.EnableTrace) if !s.NoDefaultServerHeader { serverName = s.ServerName } connRequestNum := uint64(0) for { connRequestNum++ if zr == nil { zr = ctx.GetReader() } // If this is a keep-alive connection we want to try and read the first bytes // within the idle time. if connRequestNum > 1 { ctx.GetConn().SetReadTimeout(s.IdleTimeout) //nolint:errcheck _, err = zr.Peek(4) // This is not the first request, and we haven't read a single byte // of a new request yet. This means it's just a keep-alive connection // closing down either because the remote closed it or because // or a read timeout on our side. Either way just close the connection // and don't return any error response. 
if err != nil { err = errIdleTimeout return } // Reset the real read timeout for the coming request ctx.GetConn().SetReadTimeout(s.ReadTimeout) //nolint:errcheck } if s.EnableTrace { cc = traceCtl.DoStart(c, ctx) internalStats.Record(ctx.GetTraceInfo(), stats.ReadHeaderStart, err) eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) { internalStats.Record(ti, stats.ReadHeaderFinish, err) }) } ctx.Response.Header.SetNoDefaultDate(s.NoDefaultDate) ctx.Response.Header.SetNoDefaultContentType(s.NoDefaultContentType) if s.DisableHeaderNamesNormalizing { ctx.Request.Header.DisableNormalizing() ctx.Response.Header.DisableNormalizing() } // Read Headers if err = req.ReadHeaderWithLimit(&ctx.Request.Header, zr, s.MaxHeaderBytes); err == nil { if s.EnableTrace { // read header finished if last := eventsToTrigger.pop(); last != nil { last(ctx.GetTraceInfo(), err) } internalStats.Record(ctx.GetTraceInfo(), stats.ReadBodyStart, err) eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) { internalStats.Record(ti, stats.ReadBodyFinish, err) }) } // Read body if s.StreamRequestBody { err = req.ReadBodyStream(&ctx.Request, zr, s.MaxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm) } else { err = req.ReadLimitBody(&ctx.Request, zr, s.MaxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm) } } if s.EnableTrace { if ctx.Request.Header.ContentLength() >= 0 { ctx.GetTraceInfo().Stats().SetRecvSize(len(ctx.Request.Header.RawHeaders()) + ctx.Request.Header.ContentLength()) } else { ctx.GetTraceInfo().Stats().SetRecvSize(0) } // read body finished if last := eventsToTrigger.pop(); last != nil { last(ctx.GetTraceInfo(), err) } } if err != nil { if errors.Is(err, errs.ErrNothingRead) { return nil } if err == io.EOF { return errUnexpectedEOF } writeErrorResponse(zw, ctx, serverName, err) return } // 'Expect: 100-continue' request handling. // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details. 
if ctx.Request.MayContinue() { // Allow the ability to deny reading the incoming request body if s.ContinueHandler != nil { if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest { ctx.SetStatusCode(consts.StatusExpectationFailed) } } if continueReadingRequest { zw = ctx.GetWriter() // Send 'HTTP/1.1 100 Continue' response. _, err = zw.WriteBinary(bytestr.StrResponseContinue) if err != nil { return } err = zw.Flush() if err != nil { return } // Read body. if zr == nil { zr = ctx.GetReader() } if s.StreamRequestBody { err = req.ContinueReadBodyStream(&ctx.Request, zr, s.MaxRequestBodySize, !s.DisablePreParseMultipartForm) } else { err = req.ContinueReadBody(&ctx.Request, zr, s.MaxRequestBodySize, !s.DisablePreParseMultipartForm) } if err != nil { writeErrorResponse(zw, ctx, serverName, err) return } } } connectionClose = s.DisableKeepalive || ctx.Request.Header.ConnectionClose() isHTTP11 = ctx.Request.Header.IsHTTP11() if serverName != nil { ctx.Response.Header.SetServerBytes(serverName) } if s.EnableTrace { internalStats.Record(ctx.GetTraceInfo(), stats.ServerHandleStart, err) eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) { internalStats.Record(ti, stats.ServerHandleFinish, err) }) } if ctx.Request.IsURIParsed() { // ctx.Request.URI() must not be called before ServeHTTP // The only case is concurrency issue when parsing a new request, // and user is reading the old request in background. hlog.SystemLogger().Warnf("%s\n%s\n%s\n%s", "Race detected.", "Please be aware that the protocol.Request passed to handler is only valid before the handler returns.", "DO NOT attempt to keep and access protocol.Request after the handler returns.", "Try build with -race to check the race issue.") return errors.New("race detected") } // Handle the request // // NOTE: All middlewares and business handler will be executed in this. And at this point, the request has been parsed // and the route has been matched. 
s.Core.ServeHTTP(cc, ctx) if s.EnableTrace { // application layer handle finished if last := eventsToTrigger.pop(); last != nil { last(ctx.GetTraceInfo(), err) } } // exit check if !s.Core.IsRunning() { connectionClose = true } if !ctx.IsGet() && ctx.IsHead() { ctx.Response.SkipBody = true } hijackHandler = ctx.GetHijackHandler() ctx.SetHijackHandler(nil) connectionClose = connectionClose || ctx.Response.ConnectionClose() if connectionClose { ctx.Response.Header.SetCanonical(bytestr.StrConnection, bytestr.StrClose) } else if !isHTTP11 { ctx.Response.Header.SetCanonical(bytestr.StrConnection, bytestr.StrKeepAlive) } if zw == nil { zw = ctx.GetWriter() } if s.EnableTrace { internalStats.Record(ctx.GetTraceInfo(), stats.WriteStart, err) eventsToTrigger.push(func(ti traceinfo.TraceInfo, err error) { internalStats.Record(ti, stats.WriteFinish, err) }) } if err = writeResponse(ctx, zw); err != nil { return } if s.EnableTrace { if ctx.Response.Header.ContentLength() > 0 { ctx.GetTraceInfo().Stats().SetSendSize(ctx.Response.Header.GetHeaderLength() + ctx.Response.Header.ContentLength()) } else { ctx.GetTraceInfo().Stats().SetSendSize(0) } } // Release the zeroCopyReader before flush to prevent data race if zr != nil { zr.Release() //nolint:errcheck zr = nil } // Flush the response. if err = zw.Flush(); err != nil { return } if s.EnableTrace { // write finished if last := eventsToTrigger.pop(); last != nil { last(ctx.GetTraceInfo(), err) } } // Release request body stream if ctx.Request.IsBodyStream() { err = ext.ReleaseBodyStream(ctx.RequestBodyStream()) if err != nil { return } } if hijackHandler != nil { // Hijacked conn process the timeout by itself err = ctx.GetConn().SetReadTimeout(0) if err != nil { return } // Hijack and block the connection until the hijackHandler return s.HijackConnHandle(ctx.GetConn(), hijackHandler) err = errHijacked return } if connectionClose { return errShortConnection } // Back to network layer to trigger. 
// For now, only netpoll network mode has this feature. // FIXME: check if s.IdleTimeout == 0 { return } // general case if s.EnableTrace { if shouldRecordInTraceError(err) { traceCtl.DoFinish(cc, ctx, err) } else { traceCtl.DoFinish(cc, ctx, nil) } } ctx.ResetWithoutConn() } } func NewServer() *Server { return &Server{ eventStackPool: &sync.Pool{ New: func() interface{} { return &eventStack{} }, }, } } func writeErrorResponse(zw network.Writer, ctx *app.RequestContext, serverName []byte, err error) network.Writer { errorHandler := defaultErrorHandler errorHandler(ctx, err) if serverName != nil { ctx.Response.Header.SetServerBytes(serverName) } ctx.SetConnectionClose() if zw == nil { zw = ctx.GetWriter() } writeResponse(ctx, zw) //nolint:errcheck zw.Flush() //nolint:errcheck return zw } func writeResponse(ctx *app.RequestContext, w network.Writer) error { // Skip default response writing logic if it has been hijacked if ctx.Response.GetHijackWriter() != nil { return ctx.Response.GetHijackWriter().Finalize() } err := resp.Write(&ctx.Response, w) if err != nil { return err } return err } func defaultErrorHandler(ctx *app.RequestContext, err error) { if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { ctx.AbortWithMsg("Request timeout", consts.StatusRequestTimeout) } else if errors.Is(err, errs.ErrBodyTooLarge) { ctx.AbortWithMsg("Request Entity Too Large", consts.StatusRequestEntityTooLarge) } else if errors.Is(err, errs.ErrHeaderTooLarge) { ctx.AbortWithMsg("Request Header Fields Too Large", consts.StatusRequestHeaderFieldsTooLarge) } else { ctx.AbortWithMsg("Error when parsing request", consts.StatusBadRequest) } } type eventStack []func(ti traceinfo.TraceInfo, err error) func (e *eventStack) isEmpty() bool { return len(*e) == 0 } func (e *eventStack) push(f func(ti traceinfo.TraceInfo, err error)) { *e = append(*e, f) } func (e *eventStack) pop() func(ti traceinfo.TraceInfo, err error) { if e.isEmpty() { return nil } last := (*e)[len(*e)-1] *e = 
(*e)[:len(*e)-1] return last } func shouldRecordInTraceError(err error) bool { if err == nil { return false } if errors.Is(err, errs.ErrIdleTimeout) { return false } if errors.Is(err, errs.ErrHijacked) { return false } if errors.Is(err, errs.ErrShortConnection) { return false } return true } ================================================ FILE: pkg/protocol/http1/server_test.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package http1 import ( "bytes" "context" "errors" "strings" "sync" "testing" "time" inStats "github.com/cloudwego/hertz/internal/stats" "github.com/cloudwego/hertz/pkg/app" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/common/tracer" "github.com/cloudwego/hertz/pkg/common/tracer/stats" "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/req" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" ) var pool = &sync.Pool{New: func() interface{} { return &eventStack{} }} func TestTraceEventCompleted(t *testing.T) { server := &Server{} server.eventStackPool = pool server.EnableTrace = true reqCtx := &app.RequestContext{} server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { ti := traceinfo.NewTraceInfo() ti.Stats().SetLevel(2) reqCtx.SetTraceInfo(&mockTraceInfo{ti}) return reqCtx }}, controller: &inStats.Controller{}, } err := server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n")) assert.True(t, errors.Is(err, errs.ErrShortConnection)) traceInfo := reqCtx.GetTraceInfo() assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.WriteStart).IsNil()) assert.False(t, 
traceInfo.Stats().GetEvent(stats.WriteFinish).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil()) assert.Nil(t, traceInfo.Stats().Error()) } func TestTraceEventReadHeaderError(t *testing.T) { server := &Server{} server.eventStackPool = pool server.EnableTrace = true reqCtx := &app.RequestContext{} server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { ti := traceinfo.NewTraceInfo() ti.Stats().SetLevel(2) reqCtx.SetTraceInfo(&mockTraceInfo{ti}) return reqCtx }}, controller: &inStats.Controller{}, } err := server.Serve(context.TODO(), mock.NewConn("ErrorFirstLine\r\n\r\n")) assert.NotNil(t, err) traceInfo := reqCtx.GetTraceInfo() assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil()) assert.Nil(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart)) assert.Nil(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish)) assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart)) assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish)) assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteStart)) assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteFinish)) assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil()) } func TestTraceEventReadBodyError(t *testing.T) { server := &Server{} server.eventStackPool = pool server.EnableTrace = true server.GetOnly = true reqCtx := &app.RequestContext{} server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { ti := traceinfo.NewTraceInfo() ti.Stats().SetLevel(2) reqCtx.SetTraceInfo(&mockTraceInfo{ti}) return reqCtx }}, controller: &inStats.Controller{}, } err := server.Serve(context.TODO(), mock.NewConn("POST /aaa HTTP/1.1\nHost: foobar.com\nContent-Length: 5\nContent-Type: foo/bar\n\n12346\n\n")) assert.NotNil(t, err) traceInfo := reqCtx.GetTraceInfo() assert.False(t, 
traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish).IsNil()) assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart)) assert.Nil(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish)) assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteStart)) assert.Nil(t, traceInfo.Stats().GetEvent(stats.WriteFinish)) assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil()) } func TestTraceEventWriteError(t *testing.T) { server := &Server{} server.eventStackPool = pool server.EnableTrace = true reqCtx := &app.RequestContext{} server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { ti := traceinfo.NewTraceInfo() ti.Stats().SetLevel(2) reqCtx.SetTraceInfo(&mockTraceInfo{ti}) return reqCtx }}, controller: &inStats.Controller{}, } err := server.Serve( context.TODO(), &mockErrorWriter{ mock.NewConn("POST /aaa HTTP/1.1\nHost: foobar.com\nContent-Length: 5\nContent-Type: foo/bar\n\n12346\n\n"), }, ) assert.NotNil(t, err) traceInfo := reqCtx.GetTraceInfo() assert.False(t, traceInfo.Stats().GetEvent(stats.HTTPStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadHeaderFinish).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ReadBodyFinish).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.ServerHandleFinish).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.WriteStart).IsNil()) assert.False(t, traceInfo.Stats().GetEvent(stats.WriteFinish).IsNil()) assert.False(t, 
traceInfo.Stats().GetEvent(stats.HTTPFinish).IsNil()) } func TestEventStack(t *testing.T) { // Create a stack. s := &eventStack{} assert.True(t, s.isEmpty()) count := 0 // Push 10 events. for i := 0; i < 10; i++ { s.push(func(ti traceinfo.TraceInfo, err error) { count += 1 }) } assert.False(t, s.isEmpty()) // Pop 10 events and process them. for last := s.pop(); last != nil; last = s.pop() { last(nil, nil) } assert.DeepEqual(t, 10, count) // Pop an empty stack. e := s.pop() if e != nil { t.Fatalf("should be nil") } } func TestDefaultWriter(t *testing.T) { server := &Server{} reqCtx := &app.RequestContext{} server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { return reqCtx }}, mockHandler: func(c context.Context, ctx *app.RequestContext) { ctx.Write([]byte("hello, hertz")) ctx.Flush() }, } defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") err := server.Serve(context.TODO(), defaultConn) assert.True(t, errors.Is(err, errs.ErrShortConnection)) defaultResponseResult := defaultConn.WriterRecorder() assert.DeepEqual(t, 0, defaultResponseResult.Len()) // all data is flushed so the buffer length is 0 response := protocol.AcquireResponse() resp.Read(response, defaultResponseResult) assert.DeepEqual(t, "hello, hertz", string(response.Body())) } func TestServerDisableReqCtxPool(t *testing.T) { server := &Server{} reqCtx := &app.RequestContext{} server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { reqCtx.Set("POOL_KEY", "in pool") return reqCtx }}, mockHandler: func(c context.Context, ctx *app.RequestContext) { if ctx.GetString("POOL_KEY") != "in pool" { t.Fatal("reqCtx is not in pool") } }, isRunning: true, } defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") err := server.Serve(context.TODO(), defaultConn) assert.Nil(t, err) disabaleRequestContextPool = true defer func() { // reset global variable disabaleRequestContextPool = false }() server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { 
reqCtx.Set("POOL_KEY", "in pool") return reqCtx }}, mockHandler: func(c context.Context, ctx *app.RequestContext) { if len(ctx.GetString("POOL_KEY")) != 0 { t.Fatal("must not get pool key") } }, isRunning: true, } defaultConn = mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") err = server.Serve(context.TODO(), defaultConn) assert.Nil(t, err) } func TestServer_RaceDetect(t *testing.T) { c := &app.RequestContext{} _ = c.Request.URI() // parsedURI = true m := &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { return c }}, mockHandler: func(c context.Context, ctx *app.RequestContext) { panic("must not be called") }, } s := &Server{} s.Core = m conn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") err := s.Serve(context.Background(), conn) assert.NotNil(t, err) assert.Assert(t, err.Error() == "race detected") } func TestHijackResponseWriter(t *testing.T) { server := &Server{} reqCtx := &app.RequestContext{} buf := new(bytes.Buffer) isFinal := false server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { return reqCtx }}, mockHandler: func(c context.Context, ctx *app.RequestContext) { // response before write will be dropped ctx.Write([]byte("invalid data")) ctx.Response.HijackWriter(&mock.ExtWriter{ Buf: buf, IsFinal: &isFinal, }) ctx.Write([]byte("hello, hertz")) ctx.Flush() }, } defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") err := server.Serve(context.TODO(), defaultConn) assert.True(t, errors.Is(err, errs.ErrShortConnection)) defaultResponseResult := defaultConn.WriterRecorder() response := protocol.AcquireResponse() resp.Read(response, defaultResponseResult) assert.DeepEqual(t, 0, len(response.Body())) assert.DeepEqual(t, "hello, hertz", buf.String()) assert.True(t, isFinal) } func TestHijackHandler(t *testing.T) { server := NewServer() reqCtx := &app.RequestContext{} originReadTimeout := time.Second hijackReadTimeout := 200 * time.Millisecond reqCtx.SetHijackHandler(func(c network.Conn) { 
c.SetReadTimeout(hijackReadTimeout) // hijack read timeout }) server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { return reqCtx }}, } server.HijackConnHandle = func(c network.Conn, h app.HijackHandler) { h(c) } defaultConn := mock.NewConn("GET / HTTP/1.1\nHost: foobar.com\n\n") defaultConn.SetReadTimeout(originReadTimeout) assert.DeepEqual(t, originReadTimeout, defaultConn.GetReadTimeout()) err := server.Serve(context.TODO(), defaultConn) assert.True(t, errors.Is(err, errs.ErrHijacked)) assert.DeepEqual(t, hijackReadTimeout, defaultConn.GetReadTimeout()) } func TestKeepAlive(t *testing.T) { server := NewServer() reqCtx := &app.RequestContext{} times := 0 server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { return reqCtx }}, isRunning: true, mockHandler: func(c context.Context, ctx *app.RequestContext) { times++ if string(ctx.Path()) == "/close" { ctx.SetConnectionClose() } }, } server.IdleTimeout = time.Second var s strings.Builder s.WriteString("GET / HTTP/1.1\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") s.WriteString("GET /close HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") // set connection close defaultConn := mock.NewConn(s.String()) err := server.Serve(context.TODO(), defaultConn) assert.True(t, errors.Is(err, errs.ErrShortConnection)) assert.DeepEqual(t, times, 2) } func TestExpect100Continue(t *testing.T) { server := &Server{} reqCtx := &app.RequestContext{} server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { return reqCtx }}, mockHandler: func(c context.Context, ctx *app.RequestContext) { data, err := ctx.Body() if err == nil { ctx.Write(data) } }, } defaultConn := mock.NewConn("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") err := server.Serve(context.TODO(), defaultConn) assert.True(t, errors.Is(err, errs.ErrShortConnection)) defaultResponseResult := defaultConn.WriterRecorder() assert.DeepEqual(t, 0, 
defaultResponseResult.Len()) response := protocol.AcquireResponse() resp.Read(response, defaultResponseResult) assert.DeepEqual(t, "12345", string(response.Body())) } func TestExpect100ContinueHandler(t *testing.T) { server := &Server{} reqCtx := &app.RequestContext{} server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { return reqCtx }}, mockHandler: func(c context.Context, ctx *app.RequestContext) { data, err := ctx.Body() if err == nil { ctx.Write(data) } }, } server.ContinueHandler = func(header *protocol.RequestHeader) bool { return false } defaultConn := mock.NewConn("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") err := server.Serve(context.TODO(), defaultConn) assert.True(t, errors.Is(err, errs.ErrShortConnection)) defaultResponseResult := defaultConn.WriterRecorder() assert.DeepEqual(t, 0, defaultResponseResult.Len()) response := protocol.AcquireResponse() resp.Read(response, defaultResponseResult) assert.DeepEqual(t, consts.StatusExpectationFailed, response.StatusCode()) assert.DeepEqual(t, "", string(response.Body())) } type mockController struct { FinishTimes int } func (m *mockController) Append(col tracer.Tracer) {} func (m *mockController) DoStart(ctx context.Context, c *app.RequestContext) context.Context { return ctx } func (m *mockController) DoFinish(ctx context.Context, c *app.RequestContext, err error) { m.FinishTimes++ } func (m *mockController) HasTracer() bool { return true } func (m *mockController) reset() { m.FinishTimes = 0 } func TestTraceDoFinishTimes(t *testing.T) { server := &Server{} server.eventStackPool = pool server.EnableTrace = true reqCtx := &app.RequestContext{} controller := &mockController{} server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { ti := traceinfo.NewTraceInfo() ti.Stats().SetLevel(2) reqCtx.SetTraceInfo(&mockTraceInfo{ti}) return reqCtx }}, controller: controller, } // for disableKeepAlive case 
server.DisableKeepalive = true err := server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n")) assert.True(t, errors.Is(err, errs.ErrShortConnection)) assert.DeepEqual(t, 1, controller.FinishTimes) // for IdleTimeout==0 case server.IdleTimeout = 0 controller.reset() err = server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n")) assert.True(t, errors.Is(err, errs.ErrShortConnection)) assert.DeepEqual(t, 1, controller.FinishTimes) } type mockCore struct { ctxPool *sync.Pool controller tracer.Controller mockHandler func(c context.Context, ctx *app.RequestContext) isRunning bool } func (m *mockCore) IsRunning() bool { return m.isRunning } func (m *mockCore) GetCtxPool() *sync.Pool { return m.ctxPool } func (m *mockCore) ServeHTTP(c context.Context, ctx *app.RequestContext) { if m.mockHandler != nil { m.mockHandler(c, ctx) } } func (m *mockCore) GetTracer() tracer.Controller { return m.controller } type mockTraceInfo struct { traceinfo.TraceInfo } func (m *mockTraceInfo) Reset() {} type mockErrorWriter struct { network.Conn } func (errorWriter *mockErrorWriter) Flush() error { return errors.New("error") } func TestShouldRecordInTraceError(t *testing.T) { assert.False(t, shouldRecordInTraceError(nil)) assert.False(t, shouldRecordInTraceError(errHijacked)) assert.False(t, shouldRecordInTraceError(errIdleTimeout)) assert.False(t, shouldRecordInTraceError(errShortConnection)) assert.True(t, shouldRecordInTraceError(errTimeout)) assert.True(t, shouldRecordInTraceError(errors.New("foo error"))) } func TestServerMaxHeaderBytes(t *testing.T) { s := &Server{Option: Option{MaxHeaderBytes: 50}} largeHeaderReq := "GET / HTTP/1.1\r\nHost: example.com\r\nVery-Long-Header-Name: " + strings.Repeat("x", 200) + "\r\n\r\n" reader := mock.NewZeroCopyReader(largeHeaderReq) h := &protocol.RequestHeader{} err := req.ReadHeaderWithLimit(h, reader, s.MaxHeaderBytes) assert.NotNil(t, err) } func TestDefaultErrorHandlerHeaderTooLarge(t 
*testing.T) { ctx := app.NewContext(0) defaultErrorHandler(ctx, errs.ErrHeaderTooLarge) assert.DeepEqual(t, ctx.Response.StatusCode(), consts.StatusRequestHeaderFieldsTooLarge) } ================================================ FILE: pkg/protocol/http1/server_timing_test.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package http1 import ( "context" "fmt" "sync" "testing" inStats "github.com/cloudwego/hertz/internal/stats" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" ) func BenchmarkServer_Serve(b *testing.B) { server := &Server{} server.eventStackPool = &sync.Pool{ New: func() interface{} { return &eventStack{} }, } server.EnableTrace = true reqCtx := &app.RequestContext{} server.Core = &mockCore{ ctxPool: &sync.Pool{New: func() interface{} { ti := traceinfo.NewTraceInfo() ti.Stats().SetLevel(2) reqCtx.SetTraceInfo(&mockTraceInfo{ti}) return reqCtx }}, controller: &inStats.Controller{}, } err := server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n")) if err != nil { fmt.Println(err.Error()) } for i := 0; i < b.N; i++ { server.Serve(context.TODO(), mock.NewConn("GET /aaa HTTP/1.1\nHost: foobar.com\n\n")) } } ================================================ FILE: pkg/protocol/multipart.go ================================================ /* * Copyright 2022 
CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package protocol import ( "fmt" "io" "mime/multipart" "net/http" "net/textproto" "os" "path/filepath" "strings" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func ReadMultipartForm(r io.Reader, boundary string, size, maxInMemoryFileSize int) (*multipart.Form, error) { // Do not care about memory allocations here, since they are tiny // compared to multipart data (aka multi-MB files) usually sent // in multipart/form-data requests. if size <= 0 { return nil, fmt.Errorf("form size must be greater than 0. Given %d", size) } lr := io.LimitReader(r, int64(size)) mr := multipart.NewReader(lr, boundary) f, err := mr.ReadForm(int64(maxInMemoryFileSize)) if err != nil { return nil, fmt.Errorf("cannot read multipart/form-data body: %s", err) } return f, nil } // WriteMultipartForm writes the given multipart form f with the given // boundary to w. func WriteMultipartForm(w io.Writer, f *multipart.Form, boundary string) error { // Do not care about memory allocations here, since multipart // form processing is slow. 
if len(boundary) == 0 { panic("BUG: form boundary cannot be empty") } mw := multipart.NewWriter(w) if err := mw.SetBoundary(boundary); err != nil { return fmt.Errorf("cannot use form boundary %q: %s", boundary, err) } // marshal values for k, vv := range f.Value { for _, v := range vv { if err := mw.WriteField(k, v); err != nil { return fmt.Errorf("cannot write form field %q value %q: %s", k, v, err) } } } // marshal files for k, fvv := range f.File { for _, fv := range fvv { vw, err := mw.CreatePart(fv.Header) zw := network.NewWriter(vw) if err != nil { return fmt.Errorf("cannot create form file %q (%q): %s", k, fv.Filename, err) } fh, err := fv.Open() if err != nil { return fmt.Errorf("cannot open form file %q (%q): %s", k, fv.Filename, err) } if _, err = utils.CopyZeroAlloc(zw, fh); err != nil { return fmt.Errorf("error when copying form file %q (%q): %s", k, fv.Filename, err) } if err = fh.Close(); err != nil { return fmt.Errorf("cannot close form file %q (%q): %s", k, fv.Filename, err) } } } if err := mw.Close(); err != nil { return fmt.Errorf("error when closing multipart form writer: %s", err) } return nil } func MarshalMultipartForm(f *multipart.Form, boundary string) ([]byte, error) { var buf bytebufferpool.ByteBuffer if err := WriteMultipartForm(&buf, f, boundary); err != nil { return nil, err } return buf.B, nil } func WriteMultipartFormFile(w *multipart.Writer, fieldName, fileName string, r io.Reader) error { // Auto detect actual multipart content type cbuf := make([]byte, 512) size, err := r.Read(cbuf) if err != nil && err != io.EOF { return err } partWriter, err := w.CreatePart(CreateMultipartHeader(fieldName, fileName, http.DetectContentType(cbuf[:size]))) if err != nil { return err } if _, err = partWriter.Write(cbuf[:size]); err != nil { return err } _, err = io.Copy(partWriter, r) return err } func CreateMultipartHeader(param, fileName, contentType string) textproto.MIMEHeader { hdr := make(textproto.MIMEHeader) var contentDispositionValue string 
if len(strings.TrimSpace(fileName)) == 0 { contentDispositionValue = fmt.Sprintf(`form-data; name="%s"`, param) } else { contentDispositionValue = fmt.Sprintf(`form-data; name="%s"; filename="%s"`, param, fileName) } hdr.Set("Content-Disposition", contentDispositionValue) if len(contentType) > 0 { hdr.Set(consts.HeaderContentType, contentType) } return hdr } func AddFile(w *multipart.Writer, fieldName, path string) error { file, err := os.Open(path) if err != nil { return err } defer file.Close() return WriteMultipartFormFile(w, fieldName, filepath.Base(path), file) } func ParseMultipartForm(r io.Reader, request *Request, size, maxInMemoryFileSize int) error { m, err := ReadMultipartForm(r, request.multipartFormBoundary, size, maxInMemoryFileSize) if err != nil { return err } request.multipartForm = m return nil } func SetMultipartFormWithBoundary(req *Request, m *multipart.Form, boundary string) { req.multipartForm = m req.multipartFormBoundary = boundary } ================================================ FILE: pkg/protocol/multipart_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package protocol import ( "bytes" "mime/multipart" "net/textproto" "os" "strings" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestWriteMultipartForm(t *testing.T) { t.Parallel() var w bytes.Buffer s := strings.Replace(`--foo Content-Disposition: form-data; name="key" value --foo Content-Disposition: form-data; name="file"; filename="test.json" Content-Type: application/json {"foo": "bar"} --foo-- `, "\n", "\r\n", -1) mr := multipart.NewReader(strings.NewReader(s), "foo") form, err := mr.ReadForm(1024) if err != nil { t.Fatalf("unexpected error: %s", err) } // The length of boundary is in the range of [1,70], which can be verified for strings outside this range. err = WriteMultipartForm(&w, form, s) assert.NotNil(t, err) // set Boundary as empty assert.Panic(t, func() { err = WriteMultipartForm(&w, form, "") }) // call WriteField as twice var body bytes.Buffer mw := multipart.NewWriter(&body) if err = mw.WriteField("field1", "value1"); err != nil { t.Fatal(err) } err = WriteMultipartForm(&w, form, s) assert.NotNil(t, err) // normal test err = WriteMultipartForm(&w, form, "foo") if err != nil { t.Fatalf("unexpected error: %s", err) } if w.String() != s { t.Fatalf("unexpected output %q", w.Bytes()) } } func TestParseMultipartForm(t *testing.T) { t.Parallel() s := strings.Replace(`--foo Content-Disposition: form-data; name="key" value --foo-- `, "\n", "\r\n", -1) req1 := Request{} req1.SetMultipartFormBoundary("foo") // test size 0 assert.NotNil(t, ParseMultipartForm(strings.NewReader(s), &req1, 0, 0)) err := ParseMultipartForm(strings.NewReader(s), &req1, 1024, 1024) if err != nil { t.Fatalf("unexpected error %s", err) } req2 := Request{} mr := multipart.NewReader(strings.NewReader(s), "foo") form, err := mr.ReadForm(1024) if err != nil { t.Fatalf("unexpected error: %s", err) } SetMultipartFormWithBoundary(&req2, form, "foo") assert.DeepEqual(t, &req1, &req2) // set Boundary as " " req1.SetMultipartFormBoundary(" ") err = 
ParseMultipartForm(strings.NewReader(s), &req1, 1024, 1024) assert.NotNil(t, err) // set size 0 err = ParseMultipartForm(strings.NewReader(s), &req1, 0, 0) assert.NotNil(t, err) } func TestWriteMultipartFormFile(t *testing.T) { t.Parallel() bodyBuffer := &bytes.Buffer{} w := multipart.NewWriter(bodyBuffer) // read multipart.go to buf1 f1, err := os.Open("./multipart.go") if err != nil { t.Fatalf("open file %s error: %s", f1.Name(), err) } defer f1.Close() multipartFile := File{ Name: f1.Name(), ParamName: "multipartCode", Reader: f1, } err = WriteMultipartFormFile(w, multipartFile.ParamName, f1.Name(), multipartFile.Reader) if err != nil { t.Fatalf("write multipart error: %s", err) } fileInfo1, err := f1.Stat() if err != nil { t.Fatalf("get file state error: %s", err) } buf1 := make([]byte, fileInfo1.Size()) _, err = f1.ReadAt(buf1, 0) if err != nil { t.Fatalf("read file to bytes error: %s", err) } assert.True(t, strings.Contains(bodyBuffer.String(), string(buf1))) // test file not found assert.Nil(t, WriteMultipartFormFile(w, multipartFile.ParamName, "test.go", multipartFile.Reader)) // Test Add File Function err = AddFile(w, "responseCode", "./response.go") if err != nil { t.Fatalf("add file error: %s", err) } // read response.go to buf2 f2, err := os.Open("./response.go") if err != nil { t.Fatalf("open file %s error: %s", f2.Name(), err) } defer f2.Close() fileInfo2, err := f2.Stat() if err != nil { t.Fatalf("get file state error: %s", err) } buf2 := make([]byte, fileInfo2.Size()) _, err = f2.ReadAt(buf2, 0) if err != nil { t.Fatalf("read file to bytes error: %s", err) } assert.True(t, strings.Contains(bodyBuffer.String(), string(buf2))) // test file not found err = AddFile(w, "responseCode", "./test.go") assert.NotNil(t, err) // test WriteMultipartFormFile without file name bodyBuffer = &bytes.Buffer{} w = multipart.NewWriter(bodyBuffer) // read multipart.go to buf1 f3, err := os.Open("./multipart.go") if err != nil { t.Fatalf("open file %s error: %s", 
f3.Name(), err) } defer f3.Close() err = WriteMultipartFormFile(w, "multipart", " ", f3) if err != nil { t.Fatalf("write multipart error: %s", err) } assert.False(t, strings.Contains(bodyBuffer.String(), f3.Name())) // test empty file assert.Nil(t, WriteMultipartFormFile(w, "empty_test", "test.data", bytes.NewBuffer(nil))) } func TestMarshalMultipartForm(t *testing.T) { s := strings.Replace(`--foo Content-Disposition: form-data; name="key" value --foo Content-Disposition: form-data; name="file"; filename="test.json" Content-Type: application/json {"foo": "bar"} --foo-- `, "\n", "\r\n", -1) mr := multipart.NewReader(strings.NewReader(s), "foo") form, err := mr.ReadForm(1024) if err != nil { t.Fatalf("unexpected error: %s", err) } bufs, err := MarshalMultipartForm(form, "foo") assert.Nil(t, err) assert.DeepEqual(t, s, string(bufs)) // set boundary invalid _, err = MarshalMultipartForm(form, " ") assert.NotNil(t, err) } func TestAddFile(t *testing.T) { t.Parallel() bodyBuffer := &bytes.Buffer{} w := multipart.NewWriter(bodyBuffer) // add null file err := AddFile(w, "test", "/test") assert.NotNil(t, err) } func TestCreateMultipartHeader(t *testing.T) { t.Parallel() // filename == Null hdr1 := make(textproto.MIMEHeader) hdr1.Set("Content-Disposition", `form-data; name="test"`) hdr1.Set("Content-Type", "application/json") assert.DeepEqual(t, hdr1, CreateMultipartHeader("test", "", "application/json")) // normal test hdr2 := make(textproto.MIMEHeader) hdr2.Set("Content-Disposition", `form-data; name="test"; filename="/test.go"`) hdr2.Set("Content-Type", "application/json") assert.DeepEqual(t, hdr2, CreateMultipartHeader("test", "/test.go", "application/json")) } ================================================ FILE: pkg/protocol/request.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/

package protocol

import (
	"bytes"
	"compress/gzip"
	"encoding/base64"
	"fmt"
	"io"
	"mime/multipart"
	"net/url"
	"strings"
	"sync"

	"github.com/cloudwego/hertz/internal/bytesconv"
	"github.com/cloudwego/hertz/internal/bytestr"
	"github.com/cloudwego/hertz/internal/nocopy"
	"github.com/cloudwego/hertz/pkg/common/bytebufferpool"
	"github.com/cloudwego/hertz/pkg/common/compress"
	"github.com/cloudwego/hertz/pkg/common/config"
	"github.com/cloudwego/hertz/pkg/common/errors"
	"github.com/cloudwego/hertz/pkg/common/utils"
	"github.com/cloudwego/hertz/pkg/network"
	"github.com/cloudwego/hertz/pkg/protocol/consts"
)

var (
	// ErrMissingFile is returned by Request.FormFile when the multipart form
	// has no file for the requested key.
	ErrMissingFile = errors.NewPublic("http: no such file")

	// requestBodyPool supplies the pooled body buffers handed out by
	// Request.BodyBuffer and reclaimed by Request.ResetBody.
	// responseBodyPool is presumably the Response-side counterpart; it is not
	// referenced in this part of the file.
	responseBodyPool bytebufferpool.Pool
	requestBodyPool  bytebufferpool.Pool

	// requestPool backs AcquireRequest / ReleaseRequest.
	requestPool sync.Pool
)

// NoBody is an io.ReadCloser with no bytes. Read always returns EOF
// and Close always returns nil. It can be used in an outgoing client
// request to explicitly signal that a request has zero bytes.
var NoBody = noBody{}

type noBody struct{}

func (noBody) Read([]byte) (int, error) { return 0, io.EOF }
func (noBody) Close() error             { return nil }

// Request holds an HTTP request: header, URI, POST args, body state,
// multipart state and per-request options.
type Request struct {
	noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used

	Header RequestHeader

	// uri is parsed lazily from Header by ParseURI; parsedURI records that.
	uri      URI
	postArgs Args

	// Body sources. BodyE consults them in this order: bodyRaw, bodyStream,
	// the parsed multipart form, then the pooled body buffer.
	bodyStream io.Reader
	w          requestBodyWriter
	body       *bytebufferpool.ByteBuffer
	bodyRaw    []byte
	// Body buffers whose capacity exceeds this are returned to
	// requestBodyPool by ResetBody instead of being kept.
	maxKeepBodySize int

	multipartForm         *multipart.Form
	multipartFormBoundary string

	// Group bool members in order to reduce Request object size.
	parsedURI      bool
	parsedPostArgs bool

	isTLS bool

	// Pending client-side multipart uploads registered via SetFile*,
	// SetMultipartField* and SetMultipartFormData.
	multipartFiles  []*File
	multipartFields []*MultipartField

	// Request level options, service discovery options etc.
	options *config.RequestOptions
}

// requestBodyWriter is the io.Writer returned by Request.BodyWriter; every
// Write appends to r's body.
type requestBodyWriter struct {
	r *Request
}

// File struct represent file information for multipart request
type File struct {
	Name      string
	ParamName string
	io.Reader
}

// MultipartField struct represent custom data part for multipart request
type MultipartField struct {
	Param       string
	FileName    string
	ContentType string
	io.Reader
}

// Write appends p to the request body and always reports len(p), nil.
func (w *requestBodyWriter) Write(p []byte) (int, error) {
	w.r.AppendBody(p)
	return len(p), nil
}

// Options returns the request options, lazily creating an empty set via
// config.NewRequestOptions(nil) on first access.
func (req *Request) Options() *config.RequestOptions {
	if req.options == nil {
		req.options = config.NewRequestOptions(nil)
	}
	return req.options
}

// AppendBody appends p to request body.
//
// It is safe re-using p after the function returns.
func (req *Request) AppendBody(p []byte) {
	req.RemoveMultipartFormFiles()
	req.CloseBodyStream()     //nolint:errcheck
	req.BodyBuffer().Write(p) //nolint:errcheck
}

// BodyBuffer returns the pooled body buffer, taking one from requestBodyPool
// if none is attached yet. It also drops any raw body set via SetBodyRaw, so
// the buffer becomes the body source.
func (req *Request) BodyBuffer() *bytebufferpool.ByteBuffer {
	if req.body == nil {
		req.body = requestBodyPool.Get()
	}
	req.bodyRaw = nil
	return req.body
}

// MayContinue returns true if the request contains
// 'Expect: 100-continue' header.
//
// The caller must do one of the following actions if MayContinue returns true:
//
//   - Either send StatusExpectationFailed response if request headers don't
//     satisfy the caller.
//   - Or send StatusContinue response before reading request body
//     with ContinueReadBody.
//   - Or close the connection.
func (req *Request) MayContinue() bool {
	return bytes.Equal(req.Header.peek(consts.HeaderExpect), bytestr.Str100Continue)
}

// Scheme returns the scheme of the request.
// uri will be parsed in ServeHTTP(before user's process), so that there is no need for uri nil-check.
func (req *Request) Scheme() []byte {
	return req.uri.Scheme()
}

// For keepalive connection reuse.
// It is roughly the same as ResetSkipHeader, except that the connection related fields are removed: // - req.isTLS func (req *Request) resetSkipHeaderAndConn() { req.ResetBody() req.uri.Reset() req.parsedURI = false req.parsedPostArgs = false req.postArgs.Reset() } func (req *Request) ResetSkipHeader() { req.resetSkipHeaderAndConn() req.isTLS = false } func SwapRequestBody(a, b *Request) { a.body, b.body = b.body, a.body a.bodyRaw, b.bodyRaw = b.bodyRaw, a.bodyRaw a.bodyStream, b.bodyStream = b.bodyStream, a.bodyStream a.multipartFields, b.multipartFields = b.multipartFields, a.multipartFields a.multipartFiles, b.multipartFiles = b.multipartFiles, a.multipartFiles } // Reset clears request contents. func (req *Request) Reset() { req.Header.Reset() req.ResetSkipHeader() req.CloseBodyStream() req.options = nil } func (req *Request) IsURIParsed() bool { return req.parsedURI } func (req *Request) PostArgString() []byte { return req.postArgs.QueryString() } // MultipartForm returns request's multipart form. // // Returns errors.ErrNoMultipartForm if request's Content-Type // isn't 'multipart/form-data'. // // RemoveMultipartFormFiles must be called after returned multipart form // is processed. func (req *Request) MultipartForm() (*multipart.Form, error) { if req.multipartForm != nil { return req.multipartForm, nil } req.multipartFormBoundary = string(req.Header.MultipartFormBoundary()) if len(req.multipartFormBoundary) == 0 { return nil, errors.ErrNoMultipartForm } ce := req.Header.peek(consts.HeaderContentEncoding) var err error var f *multipart.Form if !req.IsBodyStream() { body := req.BodyBytes() if bytes.Equal(ce, bytestr.StrGzip) { // Do not care about memory usage here. 
var err error if body, err = compress.AppendGunzipBytes(nil, body); err != nil { return nil, fmt.Errorf("cannot gunzip request body: %s", err) } } else if len(ce) > 0 { return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce) } f, err = ReadMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body)) } else { bodyStream := req.bodyStream if req.Header.contentLength > 0 { bodyStream = io.LimitReader(bodyStream, int64(req.Header.contentLength)) } if bytes.Equal(ce, bytestr.StrGzip) { // Do not care about memory usage here. if bodyStream, err = gzip.NewReader(bodyStream); err != nil { return nil, fmt.Errorf("cannot gunzip request body: %w", err) } } else if len(ce) > 0 { return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce) } mr := multipart.NewReader(bodyStream, req.multipartFormBoundary) f, err = mr.ReadForm(8 * 1024) } if err != nil { return nil, err } req.multipartForm = f return f, nil } // AppendBodyString appends s to request body. func (req *Request) AppendBodyString(s string) { req.RemoveMultipartFormFiles() req.CloseBodyStream() //nolint:errcheck req.BodyBuffer().WriteString(s) //nolint:errcheck } // SetRequestURI sets RequestURI. func (req *Request) SetRequestURI(requestURI string) { req.Header.SetRequestURI(requestURI) req.parsedURI = false } func (req *Request) SetMaxKeepBodySize(n int) { req.maxKeepBodySize = n } // RequestURI returns the RequestURI for the given request. func (req *Request) RequestURI() []byte { return req.Header.RequestURI() } // FormFile returns the first file for the provided form key. func (req *Request) FormFile(name string) (*multipart.FileHeader, error) { mf, err := req.MultipartForm() if err != nil { return nil, err } if mf.File == nil { return nil, err } fhh := mf.File[name] if fhh == nil { return nil, ErrMissingFile } return fhh[0], nil } // SetHost sets host for the request. 
func (req *Request) SetHost(host string) {
	uri := req.URI()
	uri.SetHost(host)
}

// Host reports the host of the request, taken from its (lazily parsed) URI.
func (req *Request) Host() []byte {
	uri := req.URI()
	return uri.Host()
}

// SetIsTLS is used by TLS server to mark whether the request is a TLS request.
// Client shouldn't use this method but should depend on the uri.scheme instead.
func (req *Request) SetIsTLS(isTLS bool) {
	req.isTLS = isTLS
}

// SwapBody installs body as the new request body and hands back the previous
// one.
//
// A pending body stream is first drained into the body buffer and closed; if
// draining fails, the error text becomes the returned "previous body".
// The caller must not touch body after the call returns.
func (req *Request) SwapBody(body []byte) []byte {
	buf := req.BodyBuffer()
	if req.IsBodyStream() {
		buf.Reset()
		_, copyErr := utils.CopyZeroAlloc(network.NewWriter(buf), req.bodyStream)
		req.CloseBodyStream() //nolint:errcheck
		if copyErr != nil {
			buf.Reset()
			buf.SetString(copyErr.Error())
		}
	}
	req.bodyRaw = nil
	previous := buf.B
	buf.B = body
	return previous
}

// CopyTo copies req contents to dst except of body stream.
func (req *Request) CopyTo(dst *Request) {
	req.CopyToSkipBody(dst)
	switch {
	case req.bodyRaw != nil:
		dst.bodyRaw = append(dst.bodyRaw[:0], req.bodyRaw...)
		if dst.body != nil {
			dst.body.Reset()
		}
	case req.body != nil:
		dst.BodyBuffer().Set(req.body.B)
	case dst.body != nil:
		dst.body.Reset()
	}
}

// CopyToSkipBody resets dst and copies everything but the body into it:
// header, URI, POST args (with their parsed flags), the TLS flag and a deep
// copy of the request options.
func (req *Request) CopyToSkipBody(dst *Request) {
	dst.Reset()
	req.Header.CopyTo(&dst.Header)

	req.uri.CopyTo(&dst.uri)
	dst.parsedURI = req.parsedURI

	req.postArgs.CopyTo(&dst.postArgs)
	dst.parsedPostArgs = req.parsedPostArgs
	dst.isTLS = req.isTLS

	if opts := req.options; opts != nil {
		dst.options = &config.RequestOptions{}
		opts.CopyTo(dst.options)
	}

	// multipartForm is deliberately not copied: MultipartForm rebuilds it
	// on first use.
}

// BodyBytes returns the raw body if one was set, otherwise the contents of
// the body buffer (nil when there is none). Body streams are not consulted.
func (req *Request) BodyBytes() []byte {
	switch {
	case req.bodyRaw != nil:
		return req.bodyRaw
	case req.body == nil:
		return nil
	default:
		return req.body.B
	}
}

// ResetBody resets request body.
func (req *Request) ResetBody() { req.bodyRaw = nil req.RemoveMultipartFormFiles() req.CloseBodyStream() //nolint:errcheck if req.body != nil { if req.body.Cap() <= req.maxKeepBodySize { req.body.Reset() return } requestBodyPool.Put(req.body) req.body = nil } } // SetBodyRaw sets request body, but without copying it. // // From this point onward the body argument must not be changed. func (req *Request) SetBodyRaw(body []byte) { req.ResetBody() req.bodyRaw = body } // SetMultipartFormBoundary will set the multipart form boundary for the request. func (req *Request) SetMultipartFormBoundary(b string) { req.multipartFormBoundary = b } func (req *Request) MultipartFormBoundary() string { return req.multipartFormBoundary } // SetBody sets request body. // // It is safe re-using body argument after the function returns. func (req *Request) SetBody(body []byte) { req.RemoveMultipartFormFiles() req.CloseBodyStream() //nolint:errcheck req.BodyBuffer().Set(body) } // SetBodyString sets request body. func (req *Request) SetBodyString(body string) { req.RemoveMultipartFormFiles() req.CloseBodyStream() //nolint:errcheck req.BodyBuffer().SetString(body) } // SetQueryString sets query string. func (req *Request) SetQueryString(queryString string) { req.URI().SetQueryString(queryString) } // SetFormData sets x-www-form-urlencoded params func (req *Request) SetFormData(data map[string]string) { for k, v := range data { req.postArgs.Add(k, v) } req.parsedPostArgs = true req.Header.SetContentTypeBytes(bytestr.MIMEPostForm) } // SetFormDataFromValues sets x-www-form-urlencoded params from url values. func (req *Request) SetFormDataFromValues(data url.Values) { for k, v := range data { for _, kv := range v { req.postArgs.Add(k, kv) } } req.parsedPostArgs = true req.Header.SetContentTypeBytes(bytestr.MIMEPostForm) } // SetFile sets single file field name and its path for multipart upload. 
func (req *Request) SetFile(param, filePath string) {
	f := &File{Name: filePath, ParamName: param}
	req.multipartFiles = append(req.multipartFiles, f)
}

// SetFiles registers several files for multipart upload; keys are form field
// names and values are file paths.
func (req *Request) SetFiles(files map[string]string) {
	for param, path := range files {
		req.SetFile(param, path)
	}
}

// SetFileReader registers a file for multipart upload whose content comes
// from reader instead of a path on disk.
func (req *Request) SetFileReader(param, fileName string, reader io.Reader) {
	f := &File{Name: fileName, ParamName: param, Reader: reader}
	req.multipartFiles = append(req.multipartFiles, f)
}

// SetMultipartFormData attaches simple name/value pairs to the request as
// `multipart/form-data` fields (no file name, no content type).
func (req *Request) SetMultipartFormData(data map[string]string) {
	for name, value := range data {
		req.SetMultipartField(name, "", "", strings.NewReader(value))
	}
}

// MultipartFiles returns the files registered for multipart upload.
func (req *Request) MultipartFiles() []*File {
	return req.multipartFiles
}

// SetMultipartField registers a custom multipart part whose content is read
// from reader.
func (req *Request) SetMultipartField(param, fileName, contentType string, reader io.Reader) {
	field := &MultipartField{
		Param:       param,
		FileName:    fileName,
		ContentType: contentType,
		Reader:      reader,
	}
	req.multipartFields = append(req.multipartFields, field)
}

// SetMultipartFields registers several custom multipart parts at once.
func (req *Request) SetMultipartFields(fields ...*MultipartField) {
	req.multipartFields = append(req.multipartFields, fields...)
}

// MultipartFields returns the custom multipart parts registered so far.
func (req *Request) MultipartFields() []*MultipartField {
	return req.multipartFields
}

// SetBasicAuth sets the basic authentication header in the current HTTP request.
func (req *Request) SetBasicAuth(username, password string) { encodeStr := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) req.SetHeader(consts.HeaderAuthorization, "Basic "+encodeStr) } // BasicAuth can return the username and password in the request's Authorization // header, if the request uses the HTTP Basic Authorization. func (req *Request) BasicAuth() (username, password string, ok bool) { // Using Peek to reduce the cost for type transfer. auth := req.Header.Peek(consts.HeaderAuthorization) if auth == nil { return } return parseBasicAuth(auth) } var prefix = []byte{'B', 'a', 's', 'i', 'c', ' '} // parseBasicAuth can parse an HTTP Basic Authorization string encrypted by base64. // Example: "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true). func parseBasicAuth(auth []byte) (username, password string, ok bool) { if len(auth) < len(prefix) || !bytes.EqualFold(auth[:len(prefix)], prefix) { return } decodeLen := base64.StdEncoding.DecodedLen(len(auth[len(prefix):])) // base64.StdEncoding.Decode(dst,rsc []byte) will return less than DecodedLen(len(src))) decodeData := make([]byte, decodeLen) num, err := base64.StdEncoding.Decode(decodeData, auth[len(prefix):]) if err != nil { return } cs := bytesconv.B2s(decodeData[:num]) s := strings.IndexByte(cs, ':') if s < 0 { return } return cs[:s], cs[s+1:], true } // SetAuthToken sets the auth token header(Default Scheme: Bearer) in the current HTTP request. Header example: // // Authorization: Bearer func (req *Request) SetAuthToken(token string) { req.SetHeader(consts.HeaderAuthorization, "Bearer "+token) } // SetAuthSchemeToken sets the auth token scheme type in the HTTP request. For Example: // // Authorization: func (req *Request) SetAuthSchemeToken(scheme, token string) { req.SetHeader(consts.HeaderAuthorization, scheme+" "+token) } // SetHeader sets a single header field and its value in the current request. 
func (req *Request) SetHeader(header, value string) {
	req.Header.Set(header, value)
}

// SetHeaders sets every header field in headers on the current request.
func (req *Request) SetHeaders(headers map[string]string) {
	for name, val := range headers {
		req.SetHeader(name, val)
	}
}

// SetCookie appends a single cookie in the current request instance.
func (req *Request) SetCookie(key, value string) {
	req.Header.SetCookie(key, value)
}

// SetCookies sets each cookie in hc on the current request instance.
func (req *Request) SetCookies(hc map[string]string) {
	for name, val := range hc {
		req.SetCookie(name, val)
	}
}

// SetMethod sets http method for this request.
func (req *Request) SetMethod(method string) {
	req.Header.SetMethod(method)
}

// OnlyMultipartForm reports whether the request has a parsed multipart form
// and no buffered body bytes.
func (req *Request) OnlyMultipartForm() bool {
	if req.multipartForm == nil {
		return false
	}
	return req.body == nil || len(req.body.B) == 0
}

// HasMultipartForm reports whether a parsed multipart form is attached.
func (req *Request) HasMultipartForm() bool {
	return req.multipartForm != nil
}

// IsBodyStream reports whether the body is provided as a stream (via
// SetBodyStream and friends). A nil stream or NoBody does not count.
func (req *Request) IsBodyStream() bool {
	switch req.bodyStream {
	case nil, NoBody:
		return false
	default:
		return true
	}
}

// BodyStream returns the body stream, or NoBody if none is set.
func (req *Request) BodyStream() io.Reader {
	if stream := req.bodyStream; stream != nil {
		return stream
	}
	return NoBody
}

// SetBodyStream sets request body stream and, optionally body size.
//
// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes
// before returning io.EOF.
//
// If bodySize < 0, then bodyStream is read until io.EOF.
//
// bodyStream.Close() is called after finishing reading all body data
// if it implements io.Closer.
//
// Note that GET and HEAD requests cannot have body.
func (req *Request) SetBodyStream(bodyStream io.Reader, bodySize int) { req.ResetBody() req.bodyStream = bodyStream req.Header.SetContentLength(bodySize) } func (req *Request) ConstructBodyStream(body *bytebufferpool.ByteBuffer, bodyStream io.Reader) { req.body = body req.bodyStream = bodyStream } // BodyWriter returns writer for populating request body. func (req *Request) BodyWriter() io.Writer { req.w.r = req return &req.w } // PostArgs returns POST arguments. func (req *Request) PostArgs() *Args { req.parsePostArgs() return &req.postArgs } func (req *Request) parsePostArgs() { if req.parsedPostArgs { return } req.parsedPostArgs = true if !bytes.HasPrefix(req.Header.ContentType(), bytestr.MIMEPostForm) { return } req.postArgs.ParseBytes(req.Body()) } // BodyE returns request body. func (req *Request) BodyE() ([]byte, error) { if req.bodyRaw != nil { return req.bodyRaw, nil } if req.IsBodyStream() { bodyBuf := req.BodyBuffer() bodyBuf.Reset() zw := network.NewWriter(bodyBuf) _, err := utils.CopyZeroAlloc(zw, req.bodyStream) req.CloseBodyStream() //nolint:errcheck if err != nil { return nil, err } return req.BodyBytes(), nil } if req.OnlyMultipartForm() { body, err := MarshalMultipartForm(req.multipartForm, req.multipartFormBoundary) if err != nil { return nil, err } return body, nil } return req.BodyBytes(), nil } // Body returns request body. // if get body failed, returns nil. func (req *Request) Body() []byte { body, _ := req.BodyE() return body } // BodyWriteTo writes request body to w. 
func (req *Request) BodyWriteTo(w io.Writer) error { if req.IsBodyStream() { zw := network.NewWriter(w) _, err := utils.CopyZeroAlloc(zw, req.bodyStream) req.CloseBodyStream() //nolint:errcheck return err } if req.OnlyMultipartForm() { return WriteMultipartForm(w, req.multipartForm, req.multipartFormBoundary) } _, err := w.Write(req.BodyBytes()) return err } func (req *Request) CloseBodyStream() error { if req.bodyStream == nil { return nil } var err error if bsc, ok := req.bodyStream.(io.Closer); ok { err = bsc.Close() } req.bodyStream = nil return err } // URI returns request URI func (req *Request) URI() *URI { req.ParseURI() return &req.uri } func (req *Request) ParseURI() { if req.parsedURI { return } req.parsedURI = true req.uri.parse(req.Header.Host(), req.Header.RequestURI(), req.isTLS) } // RemoveMultipartFormFiles removes multipart/form-data temporary files // associated with the request. func (req *Request) RemoveMultipartFormFiles() { if req.multipartForm != nil { // Do not check for error, since these files may be deleted or moved // to new places by user code. req.multipartForm.RemoveAll() //nolint:errcheck req.multipartForm = nil } req.multipartFormBoundary = "" req.multipartFiles = nil req.multipartFields = nil } func AddMultipartFormField(w *multipart.Writer, mf *MultipartField) error { partWriter, err := w.CreatePart(CreateMultipartHeader(mf.Param, mf.FileName, mf.ContentType)) if err != nil { return err } _, err = io.Copy(partWriter, mf.Reader) return err } // Method returns request method func (req *Request) Method() []byte { return req.Header.Method() } // Path returns request path func (req *Request) Path() []byte { return req.URI().Path() } // QueryString returns request query func (req *Request) QueryString() []byte { return req.URI().QueryString() } // SetOptions is used to set request options. // These options can be used to do something in middlewares such as service discovery. 
func (req *Request) SetOptions(opts ...config.RequestOption) {
	options := req.Options()
	options.Apply(opts)
}

// ConnectionClose returns true if 'Connection: close' header is set.
func (req *Request) ConnectionClose() bool {
	return req.Header.ConnectionClose()
}

// SetConnectionClose sets 'Connection: close' header.
func (req *Request) SetConnectionClose() {
	req.Header.SetConnectionClose(true)
}

// ResetWithoutConn clears the request like Reset, but keeps
// connection-related state (isTLS).
func (req *Request) ResetWithoutConn() {
	req.Header.Reset()
	req.resetSkipHeaderAndConn()
	req.CloseBodyStream()
	req.options = nil
}

// AcquireRequest returns an empty Request instance from request pool.
//
// The returned Request instance may be passed to ReleaseRequest when it is
// no longer needed. This allows Request recycling, reduces GC pressure
// and usually improves performance.
func AcquireRequest() *Request {
	if pooled := requestPool.Get(); pooled != nil {
		return pooled.(*Request)
	}
	return &Request{}
}

// ReleaseRequest resets req and puts it back into the request pool.
//
// req must have been obtained from AcquireRequest, and neither req nor any
// of its members may be used after this call.
func ReleaseRequest(req *Request) {
	req.Reset()
	requestPool.Put(req)
}

// NewRequest makes a new Request given a method, URL, and
// optional body.
//
// An empty method defaults to GET.
//
// Url must contain fully qualified uri, i.e. with scheme and host,
// and http is assumed if scheme is omitted.
//
// Protocol version is always HTTP/1.1
//
// NewRequest is intended for unit tests; use AcquireRequest() in other cases.
func NewRequest(method, url string, body io.Reader) *Request { if method == "" { method = consts.MethodGet } req := new(Request) req.SetRequestURI(url) req.SetIsTLS(bytes.HasPrefix(bytesconv.S2b(url), bytestr.StrHTTPS)) req.ParseURI() req.SetMethod(method) req.Header.SetHost(string(req.URI().Host())) req.Header.SetRequestURIBytes(req.URI().RequestURI()) if !req.Header.IgnoreBody() { req.SetBodyStream(body, -1) switch v := req.BodyStream().(type) { case *bytes.Buffer: req.Header.SetContentLength(v.Len()) case *bytes.Reader: req.Header.SetContentLength(v.Len()) case *strings.Reader: req.Header.SetContentLength(v.Len()) default: } } return req } ================================================ FILE: pkg/protocol/request_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package protocol import ( "bytes" "encoding/base64" "fmt" "io" "io/ioutil" "math" "mime/multipart" "strings" "testing" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" "github.com/cloudwego/hertz/pkg/common/compress" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol/consts" ) type errorReader struct{} func (er errorReader) Read(p []byte) (int, error) { return 0, fmt.Errorf("dummy!") } func TestMultiForm(t *testing.T) { var r Request // r.Header.Set() _, err := r.MultipartForm() fmt.Println(err) } func TestRequestBodyWriterWrite(t *testing.T) { w := requestBodyWriter{&Request{}} w.Write([]byte("test")) assert.DeepEqual(t, "test", string(w.r.body.B)) } func TestRequestScheme(t *testing.T) { req := NewRequest("", "ptth://127.0.0.1:8080", nil) assert.DeepEqual(t, "ptth", string(req.Scheme())) req = NewRequest("", "127.0.0.1:8080", nil) assert.DeepEqual(t, "http", string(req.Scheme())) assert.DeepEqual(t, true, req.IsURIParsed()) } func TestRequestHost(t *testing.T) { req := &Request{} req.SetHost("127.0.0.1:8080") assert.DeepEqual(t, "127.0.0.1:8080", string(req.Host())) } func TestRequestSwapBody(t *testing.T) { reqA := &Request{} reqA.SetBodyRaw([]byte("testA")) reqB := &Request{} reqB.SetBodyRaw([]byte("testB")) SwapRequestBody(reqA, reqB) assert.DeepEqual(t, "testA", string(reqB.bodyRaw)) assert.DeepEqual(t, "testB", string(reqA.bodyRaw)) reqA.SetBody([]byte("testA")) reqB.SetBody([]byte("testB")) SwapRequestBody(reqA, reqB) assert.DeepEqual(t, "testA", string(reqB.body.B)) assert.DeepEqual(t, "", string(reqB.bodyRaw)) assert.DeepEqual(t, "testB", string(reqA.body.B)) assert.DeepEqual(t, "", string(reqA.bodyRaw)) reqA.SetBodyStream(strings.NewReader("testA"), len("testA")) reqB.SetBodyStream(strings.NewReader("testB"), len("testB")) SwapRequestBody(reqA, reqB) body := make([]byte, 5) reqB.bodyStream.Read(body) assert.DeepEqual(t, "testA", string(body)) 
reqA.bodyStream.Read(body) assert.DeepEqual(t, "testB", string(body)) } func TestRequestKnownSizeStreamMultipartFormWithFile(t *testing.T) { t.Parallel() s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="f1" value1 ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="fileaaa"; filename="TODO" Content-Type: application/octet-stream - SessionClient with referer and cookies support. - Client with requests' pipelining support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- tailfoobar` mr := strings.NewReader(s) r := NewRequest("POST", "/upload", mr) r.Header.SetContentLength(521) r.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) assert.DeepEqual(t, false, r.HasMultipartForm()) f, err := r.MultipartForm() assert.DeepEqual(t, true, r.HasMultipartForm()) if err != nil { t.Fatalf("unexpected error: %s", err) } defer r.RemoveMultipartFormFiles() // verify tail tail, err := ioutil.ReadAll(mr) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(tail) != "tailfoobar" { t.Fatalf("unexpected tail %q. Expecting %q", tail, "tailfoobar") } // verify values if len(f.Value) != 1 { t.Fatalf("unexpected number of values in multipart form: %d. Expecting 1", len(f.Value)) } for k, vv := range f.Value { if k != "f1" { t.Fatalf("unexpected value name %q. Expecting %q", k, "f1") } if len(vv) != 1 { t.Fatalf("unexpected number of values %d. Expecting 1", len(vv)) } v := vv[0] if v != "value1" { t.Fatalf("unexpected value %q. Expecting %q", v, "value1") } } // verify files if len(f.File) != 1 { t.Fatalf("unexpected number of file values in multipart form: %d. Expecting 1", len(f.File)) } for k, vv := range f.File { if k != "fileaaa" { t.Fatalf("unexpected file value name %q. 
Expecting %q", k, "fileaaa") } if len(vv) != 1 { t.Fatalf("unexpected number of file values %d. Expecting 1", len(vv)) } v := vv[0] if v.Filename != "TODO" { t.Fatalf("unexpected filename %q. Expecting %q", v.Filename, "TODO") } ct := v.Header.Get("Content-Type") if ct != "application/octet-stream" { t.Fatalf("unexpected content-type %q. Expecting %q", ct, "application/octet-stream") } } firstFile, err := r.FormFile("fileaaa") assert.DeepEqual(t, "TODO", firstFile.Filename) assert.Nil(t, err) } func TestRequestUnknownSizeStreamMultipartFormWithFile(t *testing.T) { t.Parallel() s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="f1" value1 ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="fileaaa"; filename="TODO" Content-Type: application/octet-stream - SessionClient with referer and cookies support. - Client with requests' pipelining support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- tailfoobar` mr := strings.NewReader(s) r := NewRequest("POST", "/upload", mr) r.Header.SetContentLength(-1) r.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) f, err := r.MultipartForm() if err != nil { t.Fatalf("unexpected error: %s", err) } defer r.RemoveMultipartFormFiles() // all data must be consumed if the content length is unknown tail, err := ioutil.ReadAll(mr) if err != nil { t.Fatalf("unexpected error: %s", err) } if string(tail) != "" { t.Fatalf("unexpected tail %q. Expecting empty string", tail) } // verify values if len(f.Value) != 1 { t.Fatalf("unexpected number of values in multipart form: %d. Expecting 1", len(f.Value)) } for k, vv := range f.Value { if k != "f1" { t.Fatalf("unexpected value name %q. Expecting %q", k, "f1") } if len(vv) != 1 { t.Fatalf("unexpected number of values %d. 
Expecting 1", len(vv)) } v := vv[0] if v != "value1" { t.Fatalf("unexpected value %q. Expecting %q", v, "value1") } } // verify files if len(f.File) != 1 { t.Fatalf("unexpected number of file values in multipart form: %d. Expecting 1", len(f.File)) } for k, vv := range f.File { if k != "fileaaa" { t.Fatalf("unexpected file value name %q. Expecting %q", k, "fileaaa") } if len(vv) != 1 { t.Fatalf("unexpected number of file values %d. Expecting 1", len(vv)) } v := vv[0] if v.Filename != "TODO" { t.Fatalf("unexpected filename %q. Expecting %q", v.Filename, "TODO") } ct := v.Header.Get("Content-Type") if ct != "application/octet-stream" { t.Fatalf("unexpected content-type %q. Expecting %q", ct, "application/octet-stream") } } } func TestRequestStreamMultipartFormWithFileGzip(t *testing.T) { t.Parallel() s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="f1" value1 ------WebKitFormBoundaryJwfATyF8tmxSJnLg Content-Disposition: form-data; name="fileaaa"; filename="TODO" Content-Type: application/octet-stream - SessionClient with referer and cookies support. - Client with requests' pipelining support. - ProxyHandler similar to FSHandler. - WebSockets. See https://tools.ietf.org/html/rfc6455 . - HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . ------WebKitFormBoundaryJwfATyF8tmxSJnLg-- tailfoobar` ns := compress.AppendGzipBytes(nil, []byte(s)) mr := bytes.NewBuffer(ns) r := NewRequest("POST", "/upload", mr) r.Header.Set("Content-Encoding", "gzip") r.Header.SetContentLength(len(s)) r.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) f, err := r.MultipartForm() if err != nil { t.Fatalf("unexpected error: %s", err) } defer r.RemoveMultipartFormFiles() // verify values if len(f.Value) != 1 { t.Fatalf("unexpected number of values in multipart form: %d. Expecting 1", len(f.Value)) } for k, vv := range f.Value { if k != "f1" { t.Fatalf("unexpected value name %q. 
Expecting %q", k, "f1") } if len(vv) != 1 { t.Fatalf("unexpected number of values %d. Expecting 1", len(vv)) } v := vv[0] if v != "value1" { t.Fatalf("unexpected value %q. Expecting %q", v, "value1") } } // verify files if len(f.File) != 1 { t.Fatalf("unexpected number of file values in multipart form: %d. Expecting 1", len(f.File)) } for k, vv := range f.File { if k != "fileaaa" { t.Fatalf("unexpected file value name %q. Expecting %q", k, "fileaaa") } if len(vv) != 1 { t.Fatalf("unexpected number of file values %d. Expecting 1", len(vv)) } v := vv[0] if v.Filename != "TODO" { t.Fatalf("unexpected filename %q. Expecting %q", v.Filename, "TODO") } ct := v.Header.Get("Content-Type") if ct != "application/octet-stream" { t.Fatalf("unexpected content-type %q. Expecting %q", ct, "application/octet-stream") } } } func TestRequestMultipartFormBoundary(t *testing.T) { r := &Request{} r.SetMultipartFormBoundary("----boundary----") assert.DeepEqual(t, "----boundary----", r.MultipartFormBoundary()) } func TestRequestSetQueryString(t *testing.T) { r := &Request{} r.SetQueryString("test") assert.DeepEqual(t, "test", string(r.URI().queryString)) } func TestRequestSetFormData(t *testing.T) { r := &Request{} data := map[string]string{"username": "admin"} r.SetFormData(data) assert.DeepEqual(t, "username", string(r.postArgs.args[0].key)) assert.DeepEqual(t, "admin", string(r.postArgs.args[0].value)) assert.DeepEqual(t, true, r.parsedPostArgs) assert.DeepEqual(t, consts.MIMEApplicationHTMLForm, string(r.Header.contentType)) r = &Request{} value := map[string][]string{"item": {"apple", "peach"}} r.SetFormDataFromValues(value) assert.DeepEqual(t, "item", string(r.postArgs.args[0].key)) assert.DeepEqual(t, "apple", string(r.postArgs.args[0].value)) assert.DeepEqual(t, "item", string(r.postArgs.args[1].key)) assert.DeepEqual(t, "peach", string(r.postArgs.args[1].value)) } func TestRequestSetFile(t *testing.T) { r := &Request{} r.SetFile("file", "/usr/bin/test.txt") assert.DeepEqual(t, 
&File{"/usr/bin/test.txt", "file", nil}, r.multipartFiles[0]) files := map[string]string{"f1": "/usr/bin/test1.txt"} r.SetFiles(files) assert.DeepEqual(t, &File{"/usr/bin/test1.txt", "f1", nil}, r.multipartFiles[1]) assert.DeepEqual(t, []*File{{"/usr/bin/test.txt", "file", nil}, {"/usr/bin/test1.txt", "f1", nil}}, r.MultipartFiles()) } func TestRequestSetFileReader(t *testing.T) { r := &Request{} r.SetFileReader("file", "/usr/bin/test.txt", nil) assert.DeepEqual(t, &File{"/usr/bin/test.txt", "file", nil}, r.multipartFiles[0]) } func TestRequestSetMultipartFormData(t *testing.T) { r := &Request{} data := map[string]string{"item": "apple"} r.SetMultipartFormData(data) assert.DeepEqual(t, &MultipartField{"item", "", "", strings.NewReader("apple")}, r.multipartFields[0]) r = &Request{} fields := []*MultipartField{{"item2", "", "", strings.NewReader("apple2")}, {"item3", "", "", strings.NewReader("apple3")}} r.SetMultipartFields(fields...) assert.DeepEqual(t, fields, r.MultipartFields()) } func TestRequestSetBasicAuth(t *testing.T) { r := &Request{} r.SetBasicAuth("admin", "admin") assert.DeepEqual(t, "Authorization", string(r.Header.h[0].key)) assert.DeepEqual(t, "Basic "+base64.StdEncoding.EncodeToString([]byte("admin:admin")), string(r.Header.h[0].value)) } func TestRequestSetAuthToken(t *testing.T) { r := &Request{} r.SetAuthToken("token") assert.DeepEqual(t, "Authorization", string(r.Header.h[0].key)) assert.DeepEqual(t, "Bearer token", string(r.Header.h[0].value)) r = &Request{} r.SetAuthSchemeToken("http", "token") assert.DeepEqual(t, "Authorization", string(r.Header.h[0].key)) assert.DeepEqual(t, "http token", string(r.Header.h[0].value)) } func TestRequestSetHeaders(t *testing.T) { r := &Request{} headers := map[string]string{"Key1": "value1"} r.SetHeaders(headers) assert.DeepEqual(t, "Key1", string(r.Header.h[0].key)) assert.DeepEqual(t, "value1", string(r.Header.h[0].value)) } func TestRequestSetCookie(t *testing.T) { r := &Request{} r.SetCookie("cookie1", 
"cookie1") assert.DeepEqual(t, "cookie1", string(r.Header.cookies[0].key)) assert.DeepEqual(t, "cookie1", string(r.Header.cookies[0].value)) r.SetCookies(map[string]string{"cookie2": "cookie2"}) assert.DeepEqual(t, "cookie2", string(r.Header.cookies[1].key)) assert.DeepEqual(t, "cookie2", string(r.Header.cookies[1].value)) } func TestRequestPath(t *testing.T) { r := NewRequest("POST", "/upload?test", nil) assert.DeepEqual(t, "/upload", string(r.Path())) assert.DeepEqual(t, "test", string(r.QueryString())) } func TestRequestConnectionClose(t *testing.T) { r := NewRequest("POST", "/upload?test", nil) assert.DeepEqual(t, false, r.ConnectionClose()) r.SetConnectionClose() assert.DeepEqual(t, true, r.ConnectionClose()) } func TestRequestBodyWriteToPlain(t *testing.T) { t.Parallel() var r Request expectedS := "foobarbaz" r.AppendBodyString(expectedS) testBodyWriteTo(t, &r, expectedS, true) } func TestRequestBodyWriteToMultipart(t *testing.T) { t.Parallel() expectedS := "--foobar\r\nContent-Disposition: form-data; name=\"key_0\"\r\n\r\nvalue_0\r\n--foobar--\r\n" var r Request SetMultipartFormWithBoundary(&r, &multipart.Form{Value: map[string][]string{"key_0": {"value_0"}}}, "foobar") testBodyWriteTo(t, &r, expectedS, true) } func TestNewRequest(t *testing.T) { // get req := NewRequest("GET", "http://www.google.com/hi", bytes.NewReader([]byte("hello"))) assert.NotNil(t, req) assert.DeepEqual(t, "GET /hi HTTP/1.1\r\nHost: www.google.com\r\n\r\n", string(req.Header.Header())) assert.Nil(t, req.Body()) // post + bytes reader req = NewRequest("POST", "http://www.google.com/hi", bytes.NewReader([]byte("hello"))) assert.NotNil(t, req) assert.DeepEqual(t, "POST /hi HTTP/1.1\r\nHost: www.google.com\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 5\r\n\r\n", string(req.Header.Header())) assert.DeepEqual(t, "hello", string(req.Body())) // post + string reader req = NewRequest("POST", "http://www.google.com/hi", strings.NewReader("hello world")) assert.NotNil(t, 
req) assert.DeepEqual(t, "POST /hi HTTP/1.1\r\nHost: www.google.com\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 11\r\n\r\n", string(req.Header.Header())) assert.DeepEqual(t, "hello world", string(req.Body())) // post + bytes buffer req = NewRequest("POST", "http://www.google.com/hi", bytes.NewBuffer([]byte("hello hertz!"))) assert.NotNil(t, req) assert.DeepEqual(t, "POST /hi HTTP/1.1\r\nHost: www.google.com\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 12\r\n\r\n", string(req.Header.Header())) assert.DeepEqual(t, "hello hertz!", string(req.Body())) // empty method req = NewRequest("", "/", bytes.NewBufferString("")) assert.DeepEqual(t, "GET", string(req.Method())) // unstandard method req = NewRequest("DUMMY", "/", bytes.NewBufferString("")) assert.DeepEqual(t, "DUMMY", string(req.Method())) // empty body req = NewRequest("GET", "/", nil) assert.NotNil(t, req) // wrong body req = NewRequest("POST", "/", errorReader{}) _, err := req.BodyE() assert.DeepEqual(t, err.Error(), "dummy!") req = NewRequest("POST", "/", errorReader{}) body := req.Body() assert.Nil(t, body) // GET RequestURI req = NewRequest("GET", "http://www.google.com/hi?a=1&b=2", nil) assert.DeepEqual(t, "/hi?a=1&b=2", string(req.RequestURI())) // POST RequestURI req = NewRequest("POST", "http://www.google.com/hi?a=1&b=2", nil) assert.DeepEqual(t, "/hi?a=1&b=2", string(req.RequestURI())) // nil-interface body assert.Panic(t, func() { fake := func() *errorReader { return nil } req = NewRequest("POST", "/", fake()) req.Body() }) } func TestRequestResetBody(t *testing.T) { req := Request{} req.BodyBuffer() assert.NotNil(t, req.body) req.maxKeepBodySize = math.MaxUint32 req.ResetBody() assert.NotNil(t, req.body) req.maxKeepBodySize = -1 req.ResetBody() assert.Nil(t, req.body) } func TestRequestConstructBodyStream(t *testing.T) { r := &Request{} b := []byte("test") r.ConstructBodyStream(&bytebufferpool.ByteBuffer{B: b}, strings.NewReader("test")) 
assert.DeepEqual(t, "test", string(r.body.B)) stream := make([]byte, 4) r.bodyStream.Read(stream) assert.DeepEqual(t, "test", string(stream)) } func TestRequestPostArgs(t *testing.T) { t.Parallel() s := `username=admin&password=admin` mr := strings.NewReader(s) r := &Request{} r.SetBodyStream(mr, len(s)) r.Header.contentType = []byte(consts.MIMEApplicationHTMLForm) arg := r.PostArgs() assert.DeepEqual(t, "username", string(arg.args[0].key)) assert.DeepEqual(t, "admin", string(arg.args[0].value)) assert.DeepEqual(t, "password", string(arg.args[1].key)) assert.DeepEqual(t, "admin", string(arg.args[1].value)) assert.DeepEqual(t, "username=admin&password=admin", string(r.PostArgString())) } func TestRequestMayContinue(t *testing.T) { t.Parallel() var r Request if r.MayContinue() { t.Fatalf("MayContinue on empty request must return false") } r.Header.Set("Expect", "123sdfds") if r.MayContinue() { t.Fatalf("MayContinue on invalid Expect header must return false") } r.Header.Set("Expect", "100-continue") if !r.MayContinue() { t.Fatalf("MayContinue on 'Expect: 100-continue' header must return true") } } func TestRequestSwapBodySerial(t *testing.T) { t.Parallel() testRequestSwapBody(t) } func testRequestSwapBody(t *testing.T) { var b []byte r := &Request{} for i := 0; i < 20; i++ { bOrig := r.Body() b = r.SwapBody(b) if !bytes.Equal(bOrig, b) { t.Fatalf("unexpected body returned: %q. Expecting %q", b, bOrig) } r.AppendBodyString("foobar") } s := "aaaabbbbcccc" b = b[:0] for i := 0; i < 10; i++ { r.SetBodyStream(bytes.NewBufferString(s), len(s)) b = r.SwapBody(b) if string(b) != s { t.Fatalf("unexpected body returned: %q. 
Expecting %q", b, s) } b = r.SwapBody(b) if len(b) > 0 { t.Fatalf("unexpected body with non-zero size returned: %q", b) } } } // Test case for testing BasicAuth var BasicAuthTests = []struct { header, username, password string ok bool }{ {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true}, // Case doesn't matter: {"BASIC " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true}, {"basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "Aladdin", "open sesame", true}, {"Basic " + base64.StdEncoding.EncodeToString([]byte("Aladdin:open:sesame")), "Aladdin", "open:sesame", true}, {"Basic " + base64.StdEncoding.EncodeToString([]byte(":")), "", "", true}, {"Basic" + base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false}, {base64.StdEncoding.EncodeToString([]byte("Aladdin:open sesame")), "", "", false}, {"Basic ", "", "", false}, {"Basic Aladdin:open sesame", "", "", false}, {`Digest username="Aladdin"`, "", "", false}, } // struct for type getBasicAuthTest struct { username, password string ok bool } func TestRequestBasicAuth(t *testing.T) { for _, tt := range BasicAuthTests { req := NewRequest("GET", "http://www.google.com/hi", bytes.NewReader([]byte("hello"))) req.SetHeader("Authorization", tt.header) username, password, ok := req.BasicAuth() if ok != tt.ok || username != tt.username || password != tt.password { t.Fatalf("BasicAuth() = %+v, want %+v", getBasicAuthTest{username, password, ok}, getBasicAuthTest{tt.username, tt.password, tt.ok}) } } } // Issue: NewRequest should create a Request that doesn't use input parameters as its struct, // otherwise it will cause panic when we pass a const string as method to NewRequest and call req.SetMethod() func TestNewRequestWithConstParam(t *testing.T) { const method = "POST" const uri = "http://www.google.com/hi" req := NewRequest(method, uri, nil) req.SetMethod("POST") 
req.SetRequestURI("http://www.google.com/hi") } func TestRequestCopyToWithOptions(t *testing.T) { req := AcquireRequest() k1 := "a" v1 := "A" k2 := "b" v2 := "B" req.SetOptions(config.WithTag(k1, v1), config.WithTag(k2, v2), config.WithSD(true)) reqCopy := AcquireRequest() req.CopyTo(reqCopy) assert.DeepEqual(t, v1, reqCopy.options.Tag(k1)) assert.DeepEqual(t, v2, reqCopy.options.Tag(k2)) assert.DeepEqual(t, true, reqCopy.options.IsSD()) } func TestRequestSetMaxKeepBodySize(t *testing.T) { r := &Request{} r.SetMaxKeepBodySize(1024) assert.DeepEqual(t, 1024, r.maxKeepBodySize) } func TestRequestBodyReuse(t *testing.T) { req := Request{} req.maxKeepBodySize = 1024 buf := req.BodyBuffer() // set a big body buf.Write(make([]byte, req.maxKeepBodySize+1)) req.ResetBody() assert.Nil(t, req.body) // NOTICE: bytebufferpool may not get a big enough buffer, // so we just mock a new one here req.body = &bytebufferpool.ByteBuffer{ B: make([]byte, 0, req.maxKeepBodySize+1), } // set a small body buf = req.BodyBuffer() buf.Write(make([]byte, 1)) req.ResetBody() assert.Nil(t, req.body) } func TestRequestGetBodyAfterGetBodyStream(t *testing.T) { req := AcquireRequest() req.SetBodyString("abc") req.BodyStream() assert.DeepEqual(t, req.Body(), []byte("abc")) } func TestRequestSetOptionsNotOverwrite(t *testing.T) { req := AcquireRequest() req.SetOptions(config.WithSD(true)) req.SetOptions(config.WithTag("a", "b")) req.SetOptions(config.WithTag("c", "d")) assert.DeepEqual(t, true, req.Options().IsSD()) assert.DeepEqual(t, "b", req.Options().Tag("a")) assert.DeepEqual(t, "d", req.Options().Tag("c")) req.SetOptions(config.WithTag("a", "c")) assert.DeepEqual(t, "c", req.Options().Tag("a")) } type bodyWriterTo interface { BodyWriteTo(writer io.Writer) error Body() []byte } func testBodyWriteTo(t *testing.T, bw bodyWriterTo, expectedS string, isRetainedBody bool) { var buf bytebufferpool.ByteBuffer if err := bw.BodyWriteTo(&buf); err != nil { t.Fatalf("unexpected error: %s", err) } s := 
buf.B if string(s) != expectedS { t.Fatalf("unexpected result %q. Expecting %q", s, expectedS) } body := bw.Body() if isRetainedBody { if string(body) != expectedS { t.Fatalf("unexpected body %q. Expecting %q", body, expectedS) } } else { if len(body) > 0 { t.Fatalf("unexpected non-zero body after BodyWriteTo: %q", body) } } } func TestReqSafeCopy(t *testing.T) { req := AcquireRequest() req.bodyRaw = make([]byte, 1) reqs := make([]*Request, 10) for i := 0; i < 10; i++ { req.bodyRaw[0] = byte(i) tmpReq := AcquireRequest() req.CopyTo(tmpReq) reqs[i] = tmpReq } for i := 0; i < 10; i++ { assert.DeepEqual(t, []byte{byte(i)}, reqs[i].Body()) } } ================================================ FILE: pkg/protocol/response.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 * This file may have been modified by CloudWeGo authors. All CloudWeGo
 * Modifications are Copyright 2022 CloudWeGo Authors.
 */

package protocol

import (
	"errors"
	"io"
	"net"
	"sync"

	"github.com/cloudwego/hertz/internal/bytesconv"
	"github.com/cloudwego/hertz/internal/nocopy"
	"github.com/cloudwego/hertz/pkg/common/bytebufferpool"
	"github.com/cloudwego/hertz/pkg/common/compress"
	"github.com/cloudwego/hertz/pkg/common/utils"
	"github.com/cloudwego/hertz/pkg/network"
)

var (
	responsePool sync.Pool

	// NoResponseBody is an io.ReadCloser with no bytes. Read always returns EOF
	// and Close always returns nil. It can be used in an incoming client
	// response to explicitly signal that a response has zero bytes.
	NoResponseBody = noBody{}
)

// Response represents an HTTP response.
//
// Response instances must not be copied. Create new instances
// and use CopyTo instead.
//
// A Response instance MUST NOT be used from concurrently running goroutines.
type Response struct {
	noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used

	// Response header
	//
	// Copying Header by value is forbidden. Use a pointer to Header instead.
	Header ResponseHeader

	// Flush headers as soon as possible without waiting for the first body bytes.
	// Relevant for bodyStream only.
	ImmediateHeaderFlush bool

	// bodyStream, when non-nil, is the body source; BodyE and BodyWriteTo
	// read from it in preference to body/bodyRaw.
	bodyStream io.Reader
	// w backs BodyWriter; its r field is pointed at this Response on demand.
	w responseBodyWriter
	// body is the body buffer, lazily taken from responseBodyPool by BodyBuffer.
	body *bytebufferpool.ByteBuffer
	// bodyRaw, when non-nil, is a body set without copying (SetBodyRaw);
	// BodyBytes prefers it over body.
	bodyRaw []byte
	// maxKeepBodySize is the largest body buffer capacity that ResetBody keeps
	// for reuse; larger buffers are returned to responseBodyPool.
	maxKeepBodySize int

	// Response.Read() skips reading body if set to true.
	// Use it for reading HEAD responses.
	//
	// Response.Write() skips writing body if set to true.
	// Use it for writing HEAD responses.
	SkipBody bool

	// Remote address of the net.Conn this response is associated with.
	raddr net.Addr
	// Local address of the net.Conn this response is associated with.
	laddr net.Addr

	// If set a hijackWriter, hertz will skip the default header/body writer process.
	hijackWriter network.ExtWriter
}

// GetHijackWriter returns the writer installed by HijackWriter, or nil if none is set.
func (resp *Response) GetHijackWriter() network.ExtWriter {
	return resp.hijackWriter
}

// HijackWriter installs writer as the hijack writer. While it is set, hertz
// skips its default header/body writing process, and SetBody/AppendBody
// route body bytes to it instead of the body buffer.
func (resp *Response) HijackWriter(writer network.ExtWriter) {
	resp.hijackWriter = writer
}

// responseBodyWriter is an io.Writer that appends everything written to it
// to the body of r.
type responseBodyWriter struct {
	r *Response
}

// Write appends p to the response body via AppendBody and always reports
// that all of p was written.
func (w *responseBodyWriter) Write(p []byte) (int, error) {
	w.r.AppendBody(p)
	return len(p), nil
}

// MustSkipBody reports whether the body must not be written: either SkipBody
// is set or the header's MustSkipContentLength reports true.
func (resp *Response) MustSkipBody() bool {
	return resp.SkipBody || resp.Header.MustSkipContentLength()
}

// BodyGunzip returns un-gzipped body data.
//
// This method may be used if the response header contains
// 'Content-Encoding: gzip' for reading un-gzipped body.
// Use Body for reading gzipped response body.
func (resp *Response) BodyGunzip() ([]byte, error) {
	return gunzipData(resp.Body())
}

// SetConnectionClose sets 'Connection: close' header.
func (resp *Response) SetConnectionClose() {
	resp.Header.SetConnectionClose(true)
}

// SetBodyString closes any body stream and replaces the contents of the
// body buffer with body.
func (resp *Response) SetBodyString(body string) { resp.CloseBodyStream() //nolint:errcheck resp.BodyBuffer().SetString(body) //nolint:errcheck } func (resp *Response) ConstructBodyStream(body *bytebufferpool.ByteBuffer, bodyStream io.Reader) { resp.body = body resp.bodyStream = bodyStream } // BodyWriter returns writer for populating response body. // // If used inside RequestHandler, the returned writer must not be used // after returning from RequestHandler. Use RequestContext.Write // or SetBodyStreamWriter in this case. func (resp *Response) BodyWriter() io.Writer { resp.w.r = resp return &resp.w } // SetStatusCode sets response status code. func (resp *Response) SetStatusCode(statusCode int) { resp.Header.SetStatusCode(statusCode) } func (resp *Response) SetMaxKeepBodySize(n int) { resp.maxKeepBodySize = n } func (resp *Response) BodyBytes() []byte { if resp.bodyRaw != nil { return resp.bodyRaw } if resp.body == nil { return nil } return resp.body.B } func (resp *Response) HasBodyBytes() bool { return len(resp.BodyBytes()) != 0 } func (resp *Response) CopyToSkipBody(dst *Response) { dst.Reset() resp.Header.CopyTo(&dst.Header) dst.SkipBody = resp.SkipBody dst.raddr = resp.raddr dst.laddr = resp.laddr } // IsBodyStream returns true if body is set via SetBodyStream* func (resp *Response) IsBodyStream() bool { return resp.bodyStream != nil } // SetBodyStream sets response body stream and, optionally body size. // // If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes // before returning io.EOF. // // If bodySize < 0, then bodyStream is read until io.EOF. // // bodyStream.Close() is called after finishing reading all body data // if it implements io.Closer. // // See also SetBodyStreamWriter. 
func (resp *Response) SetBodyStream(bodyStream io.Reader, bodySize int) { resp.ResetBody() resp.bodyStream = bodyStream resp.Header.SetContentLength(bodySize) } // SetBodyStreamNoReset is almost the same as SetBodyStream, // but it doesn't reset the bodyStream before. func (resp *Response) SetBodyStreamNoReset(bodyStream io.Reader, bodySize int) { resp.bodyStream = bodyStream resp.Header.SetContentLength(bodySize) } // BodyE returns response body. func (resp *Response) BodyE() ([]byte, error) { if resp.bodyStream != nil { bodyBuf := resp.BodyBuffer() bodyBuf.Reset() zw := network.NewWriter(bodyBuf) _, err := utils.CopyZeroAlloc(zw, resp.bodyStream) resp.CloseBodyStream() //nolint:errcheck if err != nil { return nil, err } } return resp.BodyBytes(), nil } // Body returns response body. // if get body failed, returns nil. func (resp *Response) Body() []byte { body, _ := resp.BodyE() return body } // BodyWriteTo writes response body to w. func (resp *Response) BodyWriteTo(w io.Writer) error { zw := network.NewWriter(w) if resp.bodyStream != nil { _, err := utils.CopyZeroAlloc(zw, resp.bodyStream) resp.CloseBodyStream() //nolint:errcheck return err } body := resp.BodyBytes() zw.WriteBinary(body) //nolint:errcheck return zw.Flush() } // CopyTo copies resp contents to dst except of body stream. func (resp *Response) CopyTo(dst *Response) { resp.CopyToSkipBody(dst) if resp.bodyRaw != nil { dst.bodyRaw = append(dst.bodyRaw[:0], resp.bodyRaw...) if dst.body != nil { dst.body.Reset() } } else if resp.body != nil { dst.BodyBuffer().Set(resp.body.B) } else if dst.body != nil { dst.body.Reset() } } func SwapResponseBody(a, b *Response) { a.body, b.body = b.body, a.body a.bodyRaw, b.bodyRaw = b.bodyRaw, a.bodyRaw a.bodyStream, b.bodyStream = b.bodyStream, a.bodyStream } // Reset clears response contents. 
func (resp *Response) Reset() { resp.Header.Reset() resp.resetSkipHeader() resp.SkipBody = false resp.raddr = nil resp.laddr = nil resp.ImmediateHeaderFlush = false resp.hijackWriter = nil } func (resp *Response) resetSkipHeader() { resp.ResetBody() } // ResetBody resets response body. func (resp *Response) ResetBody() { resp.bodyRaw = nil resp.CloseBodyStream() //nolint:errcheck if resp.body != nil { if resp.body.Cap() <= resp.maxKeepBodySize { resp.body.Reset() return } responseBodyPool.Put(resp.body) resp.body = nil } } // SetBodyRaw sets response body, but without copying it. // // From this point onward the body argument must not be changed. func (resp *Response) SetBodyRaw(body []byte) { resp.ResetBody() resp.bodyRaw = body } // StatusCode returns response status code. func (resp *Response) StatusCode() int { return resp.Header.StatusCode() } // SetBody sets response body. // // It is safe re-using body argument after the function returns. func (resp *Response) SetBody(body []byte) { resp.CloseBodyStream() //nolint:errcheck if resp.GetHijackWriter() == nil { resp.BodyBuffer().Set(body) //nolint:errcheck return } // If the hijack writer support .SetBody() api, then use it. if setter, ok := resp.GetHijackWriter().(interface { SetBody(b []byte) }); ok { setter.SetBody(body) return } // Otherwise, call .Write() api instead. resp.GetHijackWriter().Write(body) //nolint:errcheck } func (resp *Response) BodyStream() io.Reader { if resp.bodyStream == nil { return NoResponseBody } return resp.bodyStream } // Hijack returns the underlying network.Conn if available. // // It's only available when StatusCode() == 101 and "Connection: Upgrade", // coz Hertz will NOT reuse connection in this case, // then make it optional for users to implement their own protocols. 
// // The most common scenario is used with github.com/hertz-contrib/websocket func (resp *Response) Hijack() (network.Conn, error) { if resp.bodyStream != nil { h, ok := resp.bodyStream.(interface { Hijack() (network.Conn, error) }) if ok { return h.Hijack() } } return nil, errors.New("not available") } // AppendBody appends p to response body. // // It is safe re-using p after the function returns. func (resp *Response) AppendBody(p []byte) { resp.CloseBodyStream() //nolint:errcheck if resp.hijackWriter != nil { resp.hijackWriter.Write(p) //nolint:errcheck return } resp.BodyBuffer().Write(p) //nolint:errcheck } // AppendBodyString appends s to response body. func (resp *Response) AppendBodyString(s string) { resp.CloseBodyStream() //nolint:errcheck if resp.hijackWriter != nil { resp.hijackWriter.Write(bytesconv.S2b(s)) //nolint:errcheck return } resp.BodyBuffer().WriteString(s) //nolint:errcheck } // ConnectionClose returns true if 'Connection: close' header is set. func (resp *Response) ConnectionClose() bool { return resp.Header.ConnectionClose() } // CloseBodyStream tries call Close() of underlying body stream. // // NOTE: // * MUST NOT call CloseBodyStream() and BodyStream().Read() concurrently to avoid race issue. func (resp *Response) CloseBodyStream() error { if resp.bodyStream == nil { return nil } var err error if bsc, ok := resp.bodyStream.(io.Closer); ok { err = bsc.Close() } resp.bodyStream = nil return err } func (resp *Response) BodyBuffer() *bytebufferpool.ByteBuffer { if resp.body == nil { resp.body = responseBodyPool.Get() } resp.bodyRaw = nil return resp.body } func gunzipData(p []byte) ([]byte, error) { var bb bytebufferpool.ByteBuffer _, err := compress.WriteGunzip(&bb, p) if err != nil { return nil, err } return bb.B, nil } // RemoteAddr returns the remote network address. The Addr returned is shared // by all invocations of RemoteAddr, so do not modify it. 
func (resp *Response) RemoteAddr() net.Addr { return resp.raddr } // LocalAddr returns the local network address. The Addr returned is shared // by all invocations of LocalAddr, so do not modify it. func (resp *Response) LocalAddr() net.Addr { return resp.laddr } func (resp *Response) ParseNetAddr(conn network.Conn) { resp.raddr = conn.RemoteAddr() resp.laddr = conn.LocalAddr() } // AcquireResponse returns an empty Response instance from response pool. // // The returned Response instance may be passed to ReleaseResponse when it is // no longer needed. This allows Response recycling, reduces GC pressure // and usually improves performance. func AcquireResponse() *Response { v := responsePool.Get() if v == nil { return &Response{} } return v.(*Response) } // ReleaseResponse return resp acquired via AcquireResponse to response pool. // // It is forbidden accessing resp and/or its members after returning // it to response pool. func ReleaseResponse(resp *Response) { resp.Reset() responsePool.Put(resp) } ================================================ FILE: pkg/protocol/response_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package protocol import ( "bytes" "errors" "fmt" "math" "reflect" "testing" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" "github.com/cloudwego/hertz/pkg/common/compress" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func TestResponseCopyTo(t *testing.T) { t.Parallel() var resp Response // empty copy testResponseCopyTo(t, &resp) // init resp // resp.laddr = zeroTCPAddr resp.SkipBody = true resp.Header.SetStatusCode(consts.StatusOK) resp.SetBodyString("test") testResponseCopyTo(t, &resp) } func TestResponseBodyStreamMultipleBodyCalls(t *testing.T) { t.Parallel() var r Response s := "foobar baz abc" if r.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } r.SetBodyStream(bytes.NewBufferString(s), len(s)) if !r.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } for i := 0; i < 10; i++ { body := r.Body() if string(body) != s { t.Fatalf("unexpected body %q. Expecting %q. iteration %d", body, s, i) } } } func TestResponseBodyWriteToPlain(t *testing.T) { t.Parallel() var r Response expectedS := "foobarbaz" r.AppendBodyString(expectedS) testBodyWriteTo(t, &r, expectedS, true) } func TestResponseBodyWriteToStream(t *testing.T) { t.Parallel() var r Response expectedS := "aaabbbccc" buf := bytes.NewBufferString(expectedS) if r.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } r.SetBodyStream(buf, len(expectedS)) if !r.IsBodyStream() { t.Fatalf("IsBodyStream must return true") } testBodyWriteTo(t, &r, expectedS, false) } func TestResponseBodyWriter(t *testing.T) { t.Parallel() var r Response w := r.BodyWriter() for i := 0; i < 10; i++ { fmt.Fprintf(w, "%d", i) } if string(r.Body()) != "0123456789" { t.Fatalf("unexpected body %q. 
Expecting %q", r.Body(), "0123456789") } } func TestResponseRawBodySet(t *testing.T) { t.Parallel() var resp Response expectedS := "test" body := []byte(expectedS) resp.SetBodyRaw(body) testBodyWriteTo(t, &resp, expectedS, true) } func TestResponseRawBodyReset(t *testing.T) { t.Parallel() var resp Response body := []byte("test") resp.SetBodyRaw(body) resp.ResetBody() testBodyWriteTo(t, &resp, "", true) } func TestResponseResetBody(t *testing.T) { resp := Response{} resp.BodyBuffer() assert.NotNil(t, resp.body) resp.maxKeepBodySize = math.MaxUint32 resp.ResetBody() assert.NotNil(t, resp.body) resp.maxKeepBodySize = -1 resp.ResetBody() assert.Nil(t, resp.body) } func TestResponseBodyReuse(t *testing.T) { resp := Response{} resp.maxKeepBodySize = 1024 buf := resp.BodyBuffer() // set a big body buf.Write(make([]byte, resp.maxKeepBodySize+1)) resp.ResetBody() assert.Nil(t, resp.body) // NOTICE: bytebufferpool may not get a big enough buffer, // so we just mock a new one here resp.body = &bytebufferpool.ByteBuffer{ B: make([]byte, 0, resp.maxKeepBodySize+1), } // set a small body buf.Write(make([]byte, 1)) resp.ResetBody() assert.Nil(t, resp.body) } func testResponseCopyTo(t *testing.T, src *Response) { var dst Response src.CopyTo(&dst) if !reflect.DeepEqual(src, &dst) { //nolint:govet t.Fatalf("ResponseCopyTo fail, src: \n%+v\ndst: \n%+v\n", src, &dst) //nolint:govet } } func TestResponseMustSkipBody(t *testing.T) { resp := Response{} resp.SetStatusCode(consts.StatusOK) resp.SetBodyString("test") assert.False(t, resp.MustSkipBody()) // no content 204 means that skip body is necessary resp.SetStatusCode(consts.StatusNoContent) resp.ResetBody() assert.True(t, resp.MustSkipBody()) } func TestResponseBodyGunzip(t *testing.T) { t.Parallel() dst1 := []byte("") src1 := []byte("hello") res1 := compress.AppendGzipBytes(dst1, src1) resp := Response{} resp.SetBody(res1) zipData, err := resp.BodyGunzip() assert.Nil(t, err) assert.DeepEqual(t, zipData, src1) } func 
TestResponseSwapResponseBody(t *testing.T) { t.Parallel() resp1 := Response{} str1 := "resp1" byteBuffer1 := &bytebufferpool.ByteBuffer{} byteBuffer1.Set([]byte(str1)) resp1.ConstructBodyStream(byteBuffer1, bytes.NewBufferString(str1)) assert.True(t, resp1.HasBodyBytes()) resp2 := Response{} str2 := "resp2" byteBuffer2 := &bytebufferpool.ByteBuffer{} byteBuffer2.Set([]byte(str2)) resp2.ConstructBodyStream(byteBuffer2, bytes.NewBufferString(str2)) SwapResponseBody(&resp1, &resp2) assert.DeepEqual(t, resp1.body.B, []byte(str2)) assert.DeepEqual(t, resp1.BodyStream(), bytes.NewBufferString(str2)) assert.DeepEqual(t, resp2.body.B, []byte(str1)) assert.DeepEqual(t, resp2.BodyStream(), bytes.NewBufferString(str1)) } func TestResponseAcquireResponse(t *testing.T) { t.Parallel() for i := 0; i < 10; i++ { resp1 := AcquireResponse() assert.NotNil(t, resp1) assert.Nil(t, resp1.body) assert.Assert(t, resp1.BodyStream() == NoResponseBody) assert.Assert(t, resp1.IsBodyStream() == false) resp1.SetBody([]byte("test")) resp1.SetStatusCode(consts.StatusOK) ReleaseResponse(resp1) } } type closeBuffer struct { *bytes.Buffer } func (b *closeBuffer) Close() error { b.Reset() return nil } func TestSetBodyStreamNoReset(t *testing.T) { t.Parallel() resp := Response{} bsA := &closeBuffer{bytes.NewBufferString("A")} bsB := &closeBuffer{bytes.NewBufferString("B")} bsC := &closeBuffer{bytes.NewBufferString("C")} resp.SetBodyStream(bsA, 1) resp.SetBodyStreamNoReset(bsB, 1) // resp.Body() has closed bsB assert.DeepEqual(t, string(resp.Body()), "B") assert.DeepEqual(t, bsA.String(), "A") resp.bodyStream = bsA resp.SetBodyStream(bsC, 1) assert.DeepEqual(t, bsA.String(), "") } func TestRespSafeCopy(t *testing.T) { resp := AcquireResponse() defer ReleaseResponse(resp) resp.bodyRaw = make([]byte, 1) resps := make([]*Response, 10) for i := 0; i < 10; i++ { resp.bodyRaw[0] = byte(i) tmpResq := AcquireResponse() resp.CopyTo(tmpResq) resps[i] = tmpResq } for i := 0; i < 10; i++ { assert.DeepEqual(t, 
[]byte{byte(i)}, resps[i].Body()) } } func TestResponse_HijackWriter(t *testing.T) { resp := AcquireResponse() defer ReleaseResponse(resp) buf := new(bytes.Buffer) isFinal := false resp.HijackWriter(&mock.ExtWriter{Buf: buf, IsFinal: &isFinal}) resp.AppendBody([]byte("hello")) assert.DeepEqual(t, 0, buf.Len()) resp.GetHijackWriter().Flush() assert.DeepEqual(t, "hello", buf.String()) resp.AppendBodyString(", world") assert.DeepEqual(t, "hello", buf.String()) resp.GetHijackWriter().Flush() assert.DeepEqual(t, "hello, world", buf.String()) resp.SetBody([]byte("hello, hertz")) resp.GetHijackWriter().Flush() assert.DeepEqual(t, "hello, hertz", buf.String()) assert.False(t, isFinal) resp.GetHijackWriter().Finalize() assert.True(t, isFinal) } type HijackerFunc func() (network.Conn, error) func (h HijackerFunc) Read(_ []byte) (int, error) { return 0, errors.New("not implemented") } func (h HijackerFunc) Hijack() (network.Conn, error) { return h() } func TestResponse_Hijack(t *testing.T) { resp := AcquireResponse() defer ReleaseResponse(resp) _, err := resp.Hijack() assert.NotNil(t, err) resp.SetBodyStream(HijackerFunc(func() (network.Conn, error) { return nil, nil }), -1) _, err = resp.Hijack() assert.Nil(t, err) } ================================================ FILE: pkg/protocol/server.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package protocol import ( "context" "github.com/cloudwego/hertz/pkg/network" ) type Server interface { Serve(c context.Context, conn network.Conn) error } type StreamServer interface { Serve(c context.Context, conn network.StreamConn) error } ================================================ FILE: pkg/protocol/sse/event.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sse import ( "fmt" "sync" "time" ) const ( fieldID = 1 << iota fieldType fieldData fieldRetry ) // Event represents a Server-Sent Event (SSE). type Event struct { ID string Type string // aka `event` field, which means event type Data []byte // hertz only supports reading and writing the field, // and will not take care of retry policy, please implement on your own. Retry time.Duration bitset uint8 } var poolEvent = sync.Pool{} // NewEvent creates a new event. // // Call `Release` when you're done with the event. func NewEvent() *Event { if v := poolEvent.Get(); v != nil { ret := v.(*Event) ret.Reset() return ret } return &Event{} } // Release releases the event back to the pool. func (e *Event) Release() { poolEvent.Put(e) } // String returns a string representation of the event. func (e *Event) String() string { return fmt.Sprintf("Event{ID:%q, Type:%q, Retry:%s, Data:%q}", e.ID, e.Type, e.Retry, e.Data) } // Reset resets the event fields. 
func (e *Event) Reset() { e.ID = "" e.Type = "" e.Retry = time.Duration(0) e.Data = e.Data[:0] e.bitset = 0 } // Clone creates a copy of the event. // // When it's no longer needed, call `Release` to return it to the pool. func (e *Event) Clone() *Event { p := NewEvent() p.ID = e.ID p.Type = e.Type p.Retry = e.Retry p.Data = append(p.Data[:0], e.Data...) p.bitset = e.bitset return p } // IsSetID returns true if the event ID is set. // // Please use SetID to set the event ID for differentiating notset or empty func (e *Event) IsSetID() bool { return e.bitset&fieldID != 0 || len(e.ID) > 0 } // IsSetType returns true if the event type is set. // // Please use SetEvent to set the event type for differentiating notset or empty func (e *Event) IsSetType() bool { return e.bitset&fieldType != 0 || len(e.Type) > 0 } // IsSetRetry returns true if the retry duration is set. // // Please use SetRetry to set the event retry duration for differentiating notset or empty func (e *Event) IsSetRetry() bool { return e.bitset&fieldRetry != 0 || e.Retry != 0 } // IsSetData returns true if the event data is set. // // Please use SetData to set or AppendData to append the event data for differentiating notset or empty func (e *Event) IsSetData() bool { return e.bitset&fieldData != 0 || len(e.Data) > 0 } // SetID sets the event ID. func (e *Event) SetID(id string) { e.ID = id e.bitset |= fieldID } // SetEvent sets the event type. func (e *Event) SetEvent(eventType string) { e.Type = eventType e.bitset |= fieldType } // SetData sets the event data. func (e *Event) SetData(data []byte) { e.Data = append(e.Data[:0], data...) e.bitset |= fieldData } // SetDataString sets the event data from a string. func (e *Event) SetDataString(data string) { e.Data = append(e.Data[:0], data...) e.bitset |= fieldData } // AppendData appends data to the event data. func (e *Event) AppendData(data []byte) { e.Data = append(e.Data, data...) 
e.bitset |= fieldData } // AppendDataString appends string data to the event data. func (e *Event) AppendDataString(data string) { e.Data = append(e.Data, data...) e.bitset |= fieldData } // SetRetry sets the retry duration. func (e *Event) SetRetry(retry time.Duration) { e.Retry = retry e.bitset |= fieldRetry } ================================================ FILE: pkg/protocol/sse/event_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

package sse

import (
	"testing"
	"time"

	"github.com/cloudwego/hertz/pkg/common/test/assert"
)

// TestEvent_SetAndIsSet checks that each setter stores its value and makes the
// matching IsSet* predicate report true, and that Reset clears the data flag.
func TestEvent_SetAndIsSet(t *testing.T) {
	e := NewEvent()
	defer e.Release()

	// Test initial state
	assert.Assert(t, !e.IsSetID())
	assert.Assert(t, !e.IsSetType())
	assert.Assert(t, !e.IsSetRetry())
	assert.Assert(t, !e.IsSetData())

	// Test SetID and IsSetID
	e.SetID("test-id")
	assert.Assert(t, e.IsSetID())
	assert.DeepEqual(t, "test-id", e.ID)

	// Test SetEvent and IsSetType
	e.SetEvent("test-event")
	assert.Assert(t, e.IsSetType())
	assert.DeepEqual(t, "test-event", e.Type)

	// Test SetRetry and IsSetRetry
	r := 3 * time.Second
	e.SetRetry(r)
	assert.Assert(t, e.IsSetRetry())
	assert.DeepEqual(t, r, e.Retry)

	// Test SetData and IsSetData
	d := []byte("test-data")
	e.SetData(d)
	assert.Assert(t, e.IsSetData())
	assert.DeepEqual(t, d, e.Data)

	// After Reset the data flag is cleared; SetDataString sets it again.
	e.Reset()
	assert.Assert(t, e.IsSetData() == false)
	e.SetDataString(string(d))
	assert.Assert(t, e.IsSetData())
	assert.DeepEqual(t, d, e.Data)
}

// TestEvent_AppendData checks that AppendData marks data as set and that
// AppendDataString concatenates onto the existing data without a separator.
func TestEvent_AppendData(t *testing.T) {
	e := NewEvent()
	defer e.Release()

	// Test AppendData
	e.AppendData([]byte("first"))
	assert.Assert(t, e.IsSetData())
	assert.DeepEqual(t, []byte("first"), e.Data)

	// Append more data
	e.AppendDataString("second")
	assert.DeepEqual(t, []byte("firstsecond"), e.Data)
}

// TestEvent_Reset checks that Reset clears every field and every IsSet* flag.
func TestEvent_Reset(t *testing.T) {
	e := NewEvent()
	defer e.Release()

	// Set all fields
	e.SetID("test-id")
	e.SetEvent("test-event")
	e.SetRetry(3 * time.Second)
	e.SetData([]byte("test-data"))

	// Verify all fields are set
	assert.Assert(t, e.IsSetID())
	assert.Assert(t, e.IsSetType())
	assert.Assert(t, e.IsSetRetry())
	assert.Assert(t, e.IsSetData())

	// Reset and verify all fields are cleared
	e.Reset()
	assert.Assert(t, !e.IsSetID())
	assert.Assert(t, !e.IsSetType())
	assert.Assert(t, !e.IsSetRetry())
	assert.Assert(t, !e.IsSetData())
	assert.DeepEqual(t, "", e.ID)
	assert.DeepEqual(t, "", e.Type)
	assert.DeepEqual(t, time.Duration(0), e.Retry)
	assert.DeepEqual(t, 0, len(e.Data))
}

// TestEvent_Clone checks that Clone copies every field and its set-flag.
func TestEvent_Clone(t *testing.T) {
	e1 := NewEvent()
	e1.SetID("test-id")
	e1.SetEvent("test-event")
	e1.SetRetry(3 * time.Second)
	e1.SetData([]byte("test-data"))

	e2 := e1.Clone()
	assert.Assert(t, e2.IsSetID())
	assert.Assert(t, e2.IsSetType())
	assert.Assert(t, e2.IsSetRetry())
	assert.Assert(t, e2.IsSetData())
	assert.DeepEqual(t, "test-id", e2.ID)
	assert.DeepEqual(t, "test-event", e2.Type)
	assert.DeepEqual(t, 3*time.Second, e2.Retry)
	assert.DeepEqual(t, []byte("test-data"), e2.Data)

	e1.Release()
	e2.Release()
}

// TestEvent_PoolAndRelease checks that an event obtained after a Release comes
// back reset, whether or not the pool hands out the released instance.
func TestEvent_PoolAndRelease(t *testing.T) {
	e1 := NewEvent()
	e1.SetID("test-id")
	e1.Release()

	// Get another event from the pool. sync.Pool does not guarantee it is the
	// same instance, but either way it must come back reset.
	e2 := NewEvent()
	assert.Assert(t, !e2.IsSetID())
	assert.DeepEqual(t, "", e2.ID)
	e2.Release()
}

================================================
FILE: pkg/protocol/sse/example_test.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package sse

import (
	"context"
	"fmt"
	"net"
	"time"

	"github.com/cloudwego/hertz/pkg/app"
	"github.com/cloudwego/hertz/pkg/app/client"
	"github.com/cloudwego/hertz/pkg/common/config"
	"github.com/cloudwego/hertz/pkg/protocol"
	"github.com/cloudwego/hertz/pkg/route"
)

// Example demonstrates a simple SSE server and client interaction.
func Example() { // --- SSE Server --- ln, _ := net.Listen("tcp", "127.0.0.1:0") defer ln.Close() opt := config.NewOptions([]config.Option{}) opt.Listener = ln engine := route.NewEngine(opt) engine.GET("/", func(ctx context.Context, c *app.RequestContext) { println("Server Got LastEventID", GetLastEventID(&c.Request)) w := NewWriter(c) for i := 0; i < 5; i++ { w.WriteEvent(fmt.Sprintf("id-%d", i), "message", []byte("hello\n\nworld")) time.Sleep(10 * time.Millisecond) } // [optional] it writes 0\r\n\r\n to indicate the end of chunked response // hertz will do it after handler returns w.Close() }) go engine.Run() defer engine.Close() time.Sleep(20 * time.Millisecond) // wait for server to start opt.Addr = ln.Addr().String() // --- SSE Client --- c, _ := client.NewClient() req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() req.SetRequestURI("http://" + opt.Addr + "/") req.SetMethod("GET") req.SetHeader(LastEventIDHeader, "id-0") // adds `text/event-stream` to http `Accept` header // may required for some Model Context Protocol(MCP) servers AddAcceptMIME(req) if err := c.Do(context.Background(), req, resp); err != nil { panic(err) } r, err := NewReader(resp) if err != nil { panic(err) } defer r.Close() ctx, cancel := context.WithCancel(context.Background()) go func() { time.Sleep(200 * time.Millisecond) // cancel can be used to force ForEach returns by closing the remote connection _ = cancel }() err = r.ForEach(ctx, func(e *Event) error { println("Event:", e.String()) return nil }) if err != nil { panic(err) } println("Client LastEventID", r.LastEventID()) // Output: // } ================================================ FILE: pkg/protocol/sse/reader.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sse import ( "bufio" "bytes" "context" "errors" "io" "strconv" "strings" "time" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/protocol" ) // errNotSSEContentType is returned when the response's content type is not text/event-stream. var errNotSSEContentType = errors.New("Content-Type returned by server is NOT text/event-stream") // Reader represents a reader for Server-Sent Events (SSE). // // It is used to parse the response body and extract individual events. type Reader struct { resp *protocol.Response r io.Reader s *bufio.Scanner events int32 lastEventID string } // NewReader creates a new SSE reader from the given response. // // It returns an error if the response's content type is not text/event-stream. func NewReader(resp *protocol.Response) (*Reader, error) { if !bytes.HasPrefix(resp.Header.ContentType(), bytestr.MIMETextEventStream) { return nil, errNotSSEContentType } r := &Reader{resp: resp} if resp.IsBodyStream() { r.r = resp.BodyStream() } else { r.r = bytes.NewReader(resp.Body()) } r.s = bufio.NewScanner(r.r) r.s.Split(scanEOL) return r, nil } // SetMaxBufferSize sets the maximum buffer size for the scanner. // // The scanner will allocate its own buffer as needed, up to max bytes. // The default max size without calling this method is bufio.MaxScanTokenSize (64KB). // // It panics if it is called after reading event has started. func (r *Reader) SetMaxBufferSize(max int) { // NOTE: Consider using bytebufferpool if GC becomes an issue. // Currently using nil to let scanner manage its own buffer internally. 
r.s.Buffer(nil, max) } type forceCloseIf interface { ForceClose() error // implemented by *clientRespStream } // ForEach iterates over all SSE events in the response body, // invoking the provided handler function for each event. // // The handler MUST NOT keep the Event reference after returning. // Use (*Event).Clone to create a copy if needed. // // Iteration stops when: // - The handler returns an error // - Reading fails (e.g., bufio.ErrTooLong for events exceeding buffer size) // - Context is cancelled (if ctx.Done() != nil) // - All events are processed (returns nil) func (r *Reader) ForEach(ctx context.Context, f func(e *Event) error) error { if ctx.Done() != nil { ch := make(chan struct{}) defer close(ch) go func() { select { case <-ctx.Done(): // force close the underlying connection to release resource // or r.Read may block until remote server ends if s, ok := r.r.(forceCloseIf); ok { s.ForceClose() } case <-ch: return } }() } e := NewEvent() defer e.Release() for { if err := ctx.Err(); err != nil { return err } if err := r.Read(e); err != nil { if err == io.EOF { return nil } if er := ctx.Err(); er != nil { err = er } return err } if err := f(e); err != nil { return err } } } // LastEventID returns the last event ID read by the reader. func (r *Reader) LastEventID() string { return r.lastEventID } func (r *Reader) onEventRead(e *Event) { r.events++ if e.IsSetID() { r.lastEventID = e.ID } } // Read reads a single SSE event from the response body. // // It populates the provided Event struct with the parsed data. // Returns nil on success, io.EOF when no more events, or an error // (e.g., bufio.ErrTooLong if an event line exceeds the buffer size). // Use SetMaxBufferSize to handle larger events. 
func (r *Reader) Read(e *Event) error { e.Reset() for i := 0; r.s.Scan(); i++ { line := r.s.Bytes() // Trim UTF8 BOM if i == 0 && r.events == 0 && bytes.HasPrefix(line, []byte{0xEF, 0xBB, 0xBF}) { line = line[3:] } if len(line) == 0 { // Empty line marks the end of an event if e.bitset != 0 { r.onEventRead(e) return nil } continue // Skip empty lines at the beginning } if line[0] == ':' { // Comment which starts with colon continue } // Parse field var f, v []byte i := bytes.IndexByte(line, ':') if i < 0 { // No colon, the entire line is the field name with an empty value f = line } else { f = line[:i] // If the colon is followed by a space, remove it if i+1 < len(line) && line[i+1] == ' ' { v = line[i+2:] } else { v = line[i+1:] } } // Process the field switch string(f) { case "event": e.SetEvent(sseEventType(v)) case "data": if len(e.Data) > 0 { // If we already have data, append a newline before the new data e.Data = append(e.Data, '\n') } e.AppendData(v) case "id": id := string(v) // Ignore if it contains Null if !strings.Contains(id, "\u0000") { e.SetID(id) } case "retry": if retry, err := strconv.ParseInt(string(v), 10, 64); err == nil { e.SetRetry(time.Duration(retry) * time.Millisecond) } default: // As per spec, ignore if it's not defined. } } // Check if scanner encountered an error if err := r.s.Err(); err != nil { return err } if e.bitset == 0 { return io.EOF } r.onEventRead(e) return nil } // Close closes the underlying response body. // // NOTE: // * MUST NOT call Close() and Read() / ForEach() concurrently to avoid race issue. func (r *Reader) Close() error { return r.resp.CloseBodyStream() } ================================================ FILE: pkg/protocol/sse/reader_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package sse

import (
	"bufio"
	"bytes"
	"context"
	"errors"
	"io"
	"strings"
	"testing"
	"time"

	"github.com/cloudwego/hertz/internal/bytestr"
	"github.com/cloudwego/hertz/pkg/common/test/assert"
	"github.com/cloudwego/hertz/pkg/protocol"
)

// mockBodyStream is an io.ReadCloser over reader that records whether Close was called.
type mockBodyStream struct {
	reader io.Reader
	closed bool
}

func (m *mockBodyStream) Read(p []byte) (n int, err error) {
	return m.reader.Read(p)
}

func (m *mockBodyStream) Close() error {
	m.closed = true
	return nil
}

// TestNewReader checks that NewReader accepts text/event-stream and rejects other content types.
func TestNewReader(t *testing.T) {
	tests := []struct {
		name        string
		contentType []byte
		body        []byte
		wantErr     bool
	}{
		{
			name:        "Valid content type",
			contentType: bytestr.MIMETextEventStream,
			body:        []byte("event: message\ndata: test\n\n"),
			wantErr:     false,
		},
		{
			name:        "Invalid content type",
			contentType: []byte("text/plain"),
			body:        []byte("event: message\ndata: test\n\n"),
			wantErr:     true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			resp := &protocol.Response{}
			resp.Header.SetContentType(string(tt.contentType))
			resp.SetBody(tt.body)
			r, err := NewReader(resp)
			if tt.wantErr {
				assert.Assert(t, err != nil)
				assert.Assert(t, r == nil)
			} else {
				assert.Assert(t, err == nil)
				assert.Assert(t, r != nil)
			}
		})
	}
}

// TestReader_ReadEvent parses one event from a buffered body per case and
// compares every parsed field against the expected event.
func TestReader_ReadEvent(t *testing.T) {
	tests := []struct {
		name     string
		input    string
		expected *Event
		wantErr  bool
	}{
		{
			name:  "Basic event",
			input: "id: 123\nevent: update\ndata: test data\n\n",
			expected: func() *Event {
				e := NewEvent()
				e.SetID("123")
				e.SetEvent("update")
				e.SetData([]byte("test data"))
				return e
			}(),
			wantErr: false,
		},
		{
			name:  "Event with retry",
			input: "id: 123\nevent: update\nretry: 3000\ndata: test data\n\n",
			expected: func() *Event {
				e := NewEvent()
				e.SetID("123")
				e.SetEvent("update")
				e.SetRetry(3000 * time.Millisecond)
				e.SetData([]byte("test data"))
				return e
			}(),
			wantErr: false,
		},
		{
			// Mixes CR, CRLF and LF line endings; data lines are joined with LF.
			name:  "Event with multiline data",
			input: "id: 123\revent: update\r\ndata: line1\rdata: line2\r\ndata: line3\n\n",
			expected: func() *Event {
				e := NewEvent()
				e.SetID("123")
				e.SetEvent("update")
				e.SetData([]byte("line1\nline2\nline3"))
				return e
			}(),
			wantErr: false,
		},
		{
			name:  "Event with BOM",
			input: "\xEF\xBB\xBFid: 123\nevent: update\ndata: test data\n\n",
			expected: func() *Event {
				e := NewEvent()
				e.SetID("123")
				e.SetEvent("update")
				e.SetData([]byte("test data"))
				return e
			}(),
			wantErr: false,
		},
		{
			name:  "Event with comments",
			input: ": this is a comment\nid: 123\n: another comment\nevent: update\ndata: test data\n\n",
			expected: func() *Event {
				e := NewEvent()
				e.SetID("123")
				e.SetEvent("update")
				e.SetData([]byte("test data"))
				return e
			}(),
			wantErr: false,
		},
		{
			name:  "Event with no colon in field",
			input: "id\nevent: update\ndata: test data\n\n",
			expected: func() *Event {
				e := NewEvent()
				e.SetID("")
				e.SetEvent("update")
				e.SetData([]byte("test data"))
				return e
			}(),
			wantErr: false,
		},
		{
			name:  "Event with no space after colon",
			input: "id:123\nevent:update\ndata:test data\n\n",
			expected: func() *Event {
				e := NewEvent()
				e.SetID("123")
				e.SetEvent("update")
				e.SetData([]byte("test data"))
				return e
			}(),
			wantErr: false,
		},
		{
			name:  "Event with ID containing null character (should be ignored)",
			input: "id: test\u0000id\nevent: update\ndata: test data\n\n",
			expected: func() *Event {
				e := NewEvent()
				e.SetEvent("update")
				e.SetData([]byte("test data"))
				return e
			}(),
			wantErr: false,
		},
		{
			name:  "Event with invalid retry value",
			input: "id: 123\nevent: update\nretry: invalid\ndata: test data\n\n",
			expected: func() *Event {
				e := NewEvent()
				e.SetID("123")
				e.SetEvent("update")
				e.SetData([]byte("test data"))
				return e
			}(),
			wantErr: false,
		},
		{
			// Only blank lines: no event is produced, so Read reports an error (io.EOF).
			name:  "Empty event",
			input: "\n\n",
			expected: func() *Event {
				e := NewEvent()
				// Empty event doesn't set any fields, so bitset remains 0
				return e
			}(),
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			resp := &protocol.Response{}
			resp.Header.SetContentType(string(bytestr.MIMETextEventStream))
			resp.SetBody([]byte(tt.input))
			r, err := NewReader(resp)
			assert.Assert(t, err == nil)
			e := NewEvent()
			err = r.Read(e)
			if tt.wantErr {
				assert.Assert(t, err != nil)
			} else {
				assert.Assert(t, err == nil)
				assert.DeepEqual(t, tt.expected.ID, e.ID)
				assert.DeepEqual(t, tt.expected.Type, e.Type)
				assert.DeepEqual(t, tt.expected.Retry, e.Retry)
				assert.DeepEqual(t, tt.expected.Data, e.Data)
				// LastEventID check
				if e.ID != "" {
					assert.DeepEqual(t, r.LastEventID(), e.ID)
				}
			}
			e.Release()
		})
	}
}

// TestReader_ReadEvent_WithBodyStream reads from a streamed body and checks
// that Close reaches the underlying stream.
func TestReader_ReadEvent_WithBodyStream(t *testing.T) {
	input := "id: 123\nevent: update\ndata: test data\n\n"
	resp := &protocol.Response{}
	resp.Header.SetContentType(string(bytestr.MIMETextEventStream))
	// Create a mock body stream
	ms := &mockBodyStream{
		reader: strings.NewReader(input),
	}
	resp.SetBodyStream(ms, -1)
	r, err := NewReader(resp)
	assert.Assert(t, err == nil)
	e := NewEvent()
	err = r.Read(e)
	assert.Assert(t, err == nil)
	// Verify event data
	assert.DeepEqual(t, "123", e.ID)
	assert.DeepEqual(t, "update", e.Type)
	assert.DeepEqual(t, []byte("test data"), e.Data)
	// LastEventID check
	if e.ID != "" {
		assert.DeepEqual(t, r.LastEventID(), e.ID)
	}
	// Test Close
	err = r.Close()
	assert.Assert(t, err == nil)
	assert.Assert(t, ms.closed)
	e.Release()
}

// mockReadForceClose delegates Read and ForceClose to injectable functions,
// letting a test control when a blocked read returns.
type mockReadForceClose struct {
	readFunc  func(b []byte) (int, error)
	closeFunc func() error
}

func (m *mockReadForceClose) Read(b []byte) (int, error) {
	return m.readFunc(b)
}

func (m *mockReadForceClose) ForceClose() error {
	return m.closeFunc()
}

// TestReader_ReadEvent_Error reads from an empty body stream.
func TestReader_ReadEvent_Error(t *testing.T) {
	// An empty bytes.Reader: the first read already hits the end of input.
	errReader := &bytes.Reader{}
	resp := &protocol.Response{}
	resp.Header.SetContentType(string(bytestr.MIMETextEventStream))
	resp.SetBodyStream(errReader, -1)
	r, err := NewReader(resp)
	assert.Assert(t, err == nil)
	e := NewEvent()
	err = r.Read(e)
	// No event is pending, so Read reports io.EOF
	assert.Assert(t, err == io.EOF)
	e.Release()
}

// TestReader_ForEach checks that cancelling ctx unblocks a pending read via
// ForceClose and that ForEach returns ctx.Err().
func TestReader_ForEach(t *testing.T) {
	// mock Read & ForceClose: reads block on ch until ForceClose pushes an error into it
	mr := &mockReadForceClose{}
	ch := make(chan error, 1)
	defer close(ch)
	mr.readFunc = func(b []byte) (int, error) {
		return 0, <-ch
	}
	mr.closeFunc = func() error {
		ch <- errors.New("closed")
		return nil
	}
	// create protocol.Response
	resp := &protocol.Response{}
	resp.Header.SetContentType(string(bytestr.MIMETextEventStream))
	resp.SetBodyStream(mr, -1)
	r, err := NewReader(resp)
	assert.Assert(t, err == nil)
	// test ForEach with context
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		// cancel after 50ms
		time.Sleep(50 * time.Millisecond)
		cancel()
	}()
	err = r.ForEach(ctx, func(e *Event) error {
		panic("must not called")
	})
	assert.Assert(t, err == ctx.Err())
}

// TestReader_SetMaxBufferSize covers the default 64KB limit, a raised limit,
// and the panic when the limit is changed after reading has started.
func TestReader_SetMaxBufferSize(t *testing.T) {
	// Test that default buffer size fails for events > 64KB
	t.Run("default buffer size fails for large events", func(t *testing.T) {
		// Create a response with a large event (65KB) - just over default 64KB
		largeData := strings.Repeat("x", 65*1024)
		input := "event: large\ndata: " + largeData + "\n\n"
		resp := &protocol.Response{}
		resp.Header.SetContentType(string(bytestr.MIMETextEventStream))
		resp.SetBody([]byte(input))
		r, err := NewReader(resp)
		assert.Assert(t, err == nil)
		// Don't call SetMaxBufferSize, use default (64KB)
		// Reading should fail because the line is too long
		e := NewEvent()
		err = r.Read(e)
		assert.Assert(t, errors.Is(err, bufio.ErrTooLong))
		e.Release()
	})
	// Test with custom buffer size for large events
	t.Run("custom buffer size", func(t *testing.T) {
		// Create a response with a large event (65KB) - just over default 64KB
		largeData := strings.Repeat("x", 65*1024)
		input := "event: large\ndata: " + largeData + "\n\n"
		resp := &protocol.Response{}
		resp.Header.SetContentType(string(bytestr.MIMETextEventStream))
		resp.SetBody([]byte(input))
		r, err := NewReader(resp)
		assert.Assert(t, err == nil)
		// Set max buffer size to 70KB to handle the large event
		r.SetMaxBufferSize(70 * 1024)
		// Should be able to read the large event
		e := NewEvent()
		err = r.Read(e)
		assert.Assert(t, err == nil)
		assert.DeepEqual(t, "large", e.Type)
		assert.DeepEqual(t, largeData, string(e.Data))
		e.Release()
		// Test panic when SetMaxBufferSize is called after reading
		defer func() {
			if r := recover(); r == nil {
				t.Error("SetMaxBufferSize should panic after reading has started")
			}
		}()
		r.SetMaxBufferSize(80 * 1024)
	})
}

================================================
FILE: pkg/protocol/sse/request.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package sse

================================================
FILE: pkg/protocol/sse/request_test.go
================================================
/*
 * Copyright 2025 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * limitations under the License. */ package sse ================================================ FILE: pkg/protocol/sse/utils.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sse import ( "bytes" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/protocol" ) const LastEventIDHeader = "Last-Event-ID" // GetLastEventID returns the value of the Last-Event-ID header. func GetLastEventID(req *protocol.Request) string { return string(req.Header.Peek(LastEventIDHeader)) } // SetLastEventID sets the Last-Event-ID header. func SetLastEventID(req *protocol.Request, id string) { req.Header.Set(LastEventIDHeader, id) } // AddAcceptMIME adds `text/event-stream` to http `Accept` header. // // This is NOT required as per spec: // * User agents MAY set (`Accept`, `text/event-stream`) in request's header list. 
func AddAcceptMIME(req *protocol.Request) { v := req.Header.Peek("Accept") if len(v) > 0 { if bytes.Contains(v, bytestr.MIMETextEventStream) { return } // for better compatibility, only use one Accept header value // append `text/event-stream` to the end of the value req.Header.Set("Accept", string(v)+", "+string(bytestr.MIMETextEventStream)) } else { req.Header.Set("Accept", string(bytestr.MIMETextEventStream)) } } func sseEventType(v []byte) string { switch string(v) { case "message": return "message" } return string(v) } func hasCRLF(s string) bool { for i := len(s) - 1; i >= 0; i-- { switch s[i] { case '\r', '\n': return true } } return false } // https://html.spec.whatwg.org/multipage/server-sent-events.html#parsing-an-event-stream // end-of-line = ( cr lf / cr / lf ) func scanEOL(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil } i := bytes.IndexByte(data, '\r') j := bytes.IndexByte(data, '\n') if i >= 0 { if i+1 == j { // \r\n return i + 2, data[0:i], nil } if j >= 0 { // choose the nearer \r or \n as EOL if i < j { return i + 1, data[0:i], nil // \r } return j + 1, data[0:j], nil // \n } // if ends with '\r', we need to check the next char is NOT '\n' as per spec // this may cause unexpected blocks on reading more data. if i < len(data)-1 || atEOF { return i + 1, data[0:i], nil } } else if j >= 0 { return j + 1, data[0:j], nil } if atEOF { return len(data), data, nil } return 0, nil, nil // more data } ================================================ FILE: pkg/protocol/sse/utils_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sse import ( "testing" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" ) func TestSetGetLastEventID(t *testing.T) { req := protocol.AcquireRequest() defer protocol.ReleaseRequest(req) SetLastEventID(req, "123") assert.DeepEqual(t, "123", GetLastEventID(req)) } func TestAddAcceptMIME(t *testing.T) { // Test case 1: Empty Accept header req := protocol.AcquireRequest() defer protocol.ReleaseRequest(req) AddAcceptMIME(req) acceptHeader := req.Header.Peek("Accept") assert.DeepEqual(t, string(bytestr.MIMETextEventStream), string(acceptHeader)) // Test case 2: Existing Accept header without text/event-stream req.Reset() req.Header.Set("Accept", "text/html, application/json") AddAcceptMIME(req) acceptHeader = req.Header.Peek("Accept") assert.DeepEqual(t, "text/html, application/json, text/event-stream", string(acceptHeader)) // Test case 3: Existing Accept header already containing text/event-stream req.Reset() req.Header.Set("Accept", "text/html, text/event-stream, application/json") AddAcceptMIME(req) acceptHeader = req.Header.Peek("Accept") assert.DeepEqual(t, "text/html, text/event-stream, application/json", string(acceptHeader)) } func TestHasCRLF(t *testing.T) { assert.Assert(t, hasCRLF("\nThis is a test string")) assert.Assert(t, hasCRLF("This is \na test string")) assert.Assert(t, hasCRLF("This is a test string\n")) assert.Assert(t, hasCRLF("\rThis is a test string")) assert.Assert(t, hasCRLF("This is \rna test string")) assert.Assert(t, hasCRLF("This is a test 
string\r")) assert.Assert(t, hasCRLF("This is a test string") == false) } func TestSSEEventType(t *testing.T) { assert.DeepEqual(t, "message", sseEventType([]byte("message"))) assert.DeepEqual(t, "custom", sseEventType([]byte("custom"))) } func TestScanEOL(t *testing.T) { tests := []struct { data string atEOF bool advance int token string }{ {"", true, 0, ""}, {"", false, 0, ""}, {"hello\r\nworld", false, 7, "hello"}, {"hello\rworld", false, 6, "hello"}, {"hello\nworld", false, 6, "hello"}, {"hello world", false, 0, ""}, {"hello world", true, 11, "hello world"}, {"\r", false, 0, ""}, {"hello\r", false, 0, ""}, {"hello\r", true, 6, "hello"}, {"\n", false, 1, ""}, {"\r\nhello", false, 2, ""}, {"\r\n", false, 2, ""}, } for _, tc := range tests { advance, token, _ := scanEOL([]byte(tc.data), tc.atEOF) if advance != tc.advance || string(token) != tc.token { t.Fatalf("scanLines(data=%q, atEOF=%v) returns (%d, %q) expect (%d, %q)", tc.data, tc.atEOF, advance, string(token), tc.advance, tc.token) } } } ================================================ FILE: pkg/protocol/sse/writer.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package sse import ( "errors" "strconv" "sync" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/bytebufferpool" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" ) // Writer represents a writer for Server-Sent Events (SSE). // // It is used to write individual events to the response body. type Writer struct { w network.ExtWriter mu sync.Mutex } // NewWriter creates a new SSE writer. func NewWriter(c *app.RequestContext) *Writer { // make sure proxies won't cache the data c.Response.Header.Set("Cache-Control", "no-cache") // browsers may need charset=utf-8 for logging responses // even though it's unnecessary as per spec, coz chunks must be in utf8. c.Response.Header.SetContentType("text/event-stream; charset=utf-8") w := c.Response.GetHijackWriter() if w == nil { w = resp.NewChunkedBodyWriter(&c.Response, c.GetWriter()) c.Response.HijackWriter(w) } return &Writer{w: w} } var ( errIDContainsCRLR = errors.New(`id field contains '\r' or '\n'`) errTypeContainsCRLR = errors.New(`event field contains '\r' or '\n'`) ) // WriteEvent writes a single SSE event to the response body. // // If id, eventType, or data are zero-length, they will be ignored. // It returns an error if the event contains invalid characters or if the underlying writer fails. func (w *Writer) WriteEvent(id, eventType string, data []byte) error { return w.Write(&Event{ ID: id, Type: eventType, Data: data, }) } // WriteKeepAlive writes a comment line with "keep-alive" to the response body. // // It keeps the underlying connection alive, which is useful when using proxy servers. func (w *Writer) WriteKeepAlive() error { return w.WriteComment("keep-alive") } // WriteComment writes comment lines to the response body. 
// // Client-side will ignore lines starting with a U+003A COLON character (:) // see: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation func (w *Writer) WriteComment(s string) error { p := bytebufferpool.Get() defer bytebufferpool.Put(p) buf := p.B[:0] for data := bytesconv.S2b(s); len(data) > 0; { i, b, _ := scanEOL(data, true) buf = append(buf, ':') buf = append(buf, b...) buf = append(buf, '\n') data = data[i:] } if len(buf) == 0 { buf = append(buf, ':') } p.B = buf w.mu.Lock() defer w.mu.Unlock() if _, err := w.w.Write(p.B); err != nil { return err } return w.w.Flush() } // Write writes a single SSE event to the response body. // // It returns an error if the event contains invalid characters or underlying writer fails. func (w *Writer) Write(e *Event) error { p := bytebufferpool.Get() defer bytebufferpool.Put(p) buf := p.B[:0] if e.IsSetID() { if hasCRLF(e.ID) { return errIDContainsCRLR } buf = append(append(append(buf, "id: "...), e.ID...), '\n') } if e.IsSetType() { if e.Type == "message" { buf = append(buf, "event: message\n"...) // fast path for message } else { if hasCRLF(e.Type) { return errTypeContainsCRLR } buf = append(append(append(buf, "event: "...), e.Type...), '\n') } } if e.IsSetRetry() { buf = append(buf, "retry: "...) buf = strconv.AppendInt(buf, e.Retry.Milliseconds(), 10) buf = append(buf, '\n') } if e.IsSetData() { data := e.Data // replace EOLs with multiple "data: " lines for len(data) > 0 { i, b, _ := scanEOL(data, true) buf = append(buf, "data: "...) buf = append(buf, b...) 
buf = append(buf, '\n') data = data[i:] } } p.B = append(buf, '\n') // end of event w.mu.Lock() defer w.mu.Unlock() if _, err := w.w.Write(p.B); err != nil { return err } return w.w.Flush() } func (w *Writer) Close() error { w.mu.Lock() defer w.mu.Unlock() return w.w.Finalize() } ================================================ FILE: pkg/protocol/sse/writer_test.go ================================================ /* * Copyright 2025 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sse import ( "bytes" "errors" "testing" "time" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/test/assert" ) // mockWriteFlusher implements the writeFlusher interface for testing type mw struct { buf bytes.Buffer flushCalled bool writeErr error flushErr error finalizeErr error } func (m *mw) Write(p []byte) (n int, err error) { if m.writeErr != nil { return 0, m.writeErr } return m.buf.Write(p) } func (m *mw) Flush() error { m.flushCalled = true return m.flushErr } func (m *mw) String() string { return m.buf.String() } func (m *mw) Finalize() error { return m.finalizeErr } func TestWriter_WriteEvent(t *testing.T) { tests := []struct { name string id string eventType string data []byte wantErr bool expected string }{ { name: "Basic event", id: "123", eventType: "message", data: []byte("test data"), wantErr: false, expected: "id: 123\nevent: message\ndata: test data\n\n", }, { name: "Empty fields", id: "", eventType: "", data: nil, wantErr: false, 
expected: "\n", }, { name: "ID with CRLF", id: "test\nid", eventType: "update", data: []byte("test data"), wantErr: true, expected: "", }, { name: "Event type with CRLF", id: "123", eventType: "up\ndate", data: []byte("test data"), wantErr: true, expected: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m := &mw{} w := &Writer{w: m} err := w.WriteEvent(tt.id, tt.eventType, tt.data) if tt.wantErr { assert.Assert(t, err != nil) } else { assert.Assert(t, err == nil) assert.DeepEqual(t, tt.expected, m.String()) assert.Assert(t, m.flushCalled) } }) } } func TestWriter_Write(t *testing.T) { tests := []struct { name string event *Event writeErr error flushErr error wantErr bool expected string }{ { name: "Complete event", event: func() *Event { e := NewEvent() e.SetID("123") e.SetEvent("update") e.SetRetry(3 * time.Second) e.SetData([]byte("test data")) return e }(), wantErr: false, expected: "id: 123\nevent: update\nretry: 3000\ndata: test data\n\n", }, { name: "Multiline data", event: func() *Event { e := NewEvent() e.SetData([]byte("line1\rline2\nline3\r\nline4")) return e }(), wantErr: false, expected: "data: line1\ndata: line2\ndata: line3\ndata: line4\n\n", }, { name: "Write error", event: func() *Event { e := NewEvent() e.SetData([]byte("test data")) return e }(), writeErr: errors.New("write error"), wantErr: true, expected: "", }, { name: "Flush error", event: func() *Event { e := NewEvent() e.SetData([]byte("test data")) return e }(), flushErr: errors.New("flush error"), wantErr: true, expected: "data: test data\n\n", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m := &mw{ writeErr: tt.writeErr, flushErr: tt.flushErr, } w := &Writer{w: m} err := w.Write(tt.event) if tt.wantErr { assert.Assert(t, err != nil) } else { assert.Assert(t, err == nil) assert.DeepEqual(t, tt.expected, m.String()) assert.Assert(t, m.flushCalled) } }) } } func TestNewWriter(t *testing.T) { c := app.NewContext(0) w := NewWriter(c) 
assert.Assert(t, w != nil) assert.DeepEqual(t, "no-cache", string(c.Response.Header.Peek("Cache-Control"))) assert.DeepEqual(t, "text/event-stream; charset=utf-8", string(c.Response.Header.Peek("Content-Type"))) } func TestWriter_WriteComment(t *testing.T) { m := &mw{} w := &Writer{w: m} err := w.WriteComment("test\ncomment") assert.Assert(t, err == nil) assert.DeepEqual(t, ":test\n:comment\n", m.String()) assert.Assert(t, m.flushCalled) // empty string m = &mw{} w = &Writer{w: m} err = w.WriteComment("") assert.Assert(t, err == nil) assert.DeepEqual(t, ":", m.String()) // keep-alive m = &mw{} w = &Writer{w: m} err = w.WriteKeepAlive() assert.Assert(t, err == nil) assert.DeepEqual(t, ":keep-alive\n", m.String()) } func TestWriter_Close(t *testing.T) { // Create a mock writeFlusher m := &mw{} // Create a Writer with the mock w := &Writer{w: m} // Test Close method err := w.Close() // Verify no error occurred assert.Nil(t, err) // Set an error to be returned by Finalize expectedErr := errors.New("finalize error") m.finalizeErr = expectedErr // Test Close method with error err = w.Close() // Verify the error is propagated assert.DeepEqual(t, expectedErr, err) } ================================================ FILE: pkg/protocol/suite/client.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package suite import "github.com/cloudwego/hertz/pkg/protocol/client" type ClientFactory interface { NewHostClient() (hc client.HostClient, err error) } ================================================ FILE: pkg/protocol/suite/server.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package suite import ( "context" "sync" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/hlog" "github.com/cloudwego/hertz/pkg/common/tracer" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) const ( // must be the same with the ALPN nextProto values HTTP1 = "http/1.1" HTTP2 = "h2" // HTTP3Draft29 is the ALPN protocol negotiated during the TLS handshake, for QUIC draft 29. HTTP3Draft29 = "h3-29" // HTTP3 is the ALPN protocol negotiated during the TLS handshake, for QUIC v1 and v2. 
HTTP3 = "h3" ) // Core is the core interface that promises to be provided for the protocol layer extensions type Core interface { // IsRunning Check whether engine is running or not IsRunning() bool // A RequestContext pool ready for protocol server impl GetCtxPool() *sync.Pool // Business logic entrance // After pre-read works, protocol server may call this method // to introduce the middlewares and handlers ServeHTTP(c context.Context, ctx *app.RequestContext) // GetTracer for tracing requirement GetTracer() tracer.Controller } type ServerFactory interface { New(core Core) (server protocol.Server, err error) } type StreamServerFactory interface { New(core Core) (server protocol.StreamServer, err error) } type Config struct { altServerConfig *altServerConfig configMap map[string]ServerFactory streamConfigMap map[string]StreamServerFactory } type ServerMap map[string]protocol.Server type StreamServerMap map[string]protocol.StreamServer type altServerConfig struct { targetProtocol string setAltHeaderFunc func(ctx context.Context, reqCtx *app.RequestContext) } type coreWrapper struct { Core beforeHandler func(c context.Context, ctx *app.RequestContext) } func (c *coreWrapper) ServeHTTP(ctx context.Context, reqCtx *app.RequestContext) { c.beforeHandler(ctx, reqCtx) c.Core.ServeHTTP(ctx, reqCtx) } // SetAltHeader will set response header "Alt-Svc" for the target protocol, altHeader will be the value of the header. // Protocols other than the target protocol will carry the altHeader in the request header. 
func (c *Config) SetAltHeader(target, altHeader string) { c.altServerConfig = &altServerConfig{ targetProtocol: target, setAltHeaderFunc: func(ctx context.Context, reqCtx *app.RequestContext) { reqCtx.Response.Header.Add(consts.HeaderAltSvc, altHeader) }, } } func (c *Config) Add(protocol string, factory interface{}) { switch factory := factory.(type) { case ServerFactory: if fac := c.configMap[protocol]; fac != nil { hlog.SystemLogger().Warnf("ServerFactory of protocol: %s will be overridden by customized function", protocol) } c.configMap[protocol] = factory case StreamServerFactory: if fac := c.streamConfigMap[protocol]; fac != nil { hlog.SystemLogger().Warnf("StreamServerFactory of protocol: %s will be overridden by customized function", protocol) } c.streamConfigMap[protocol] = factory default: hlog.SystemLogger().Fatalf("Unsupported factory type: %T", factory) } } func (c *Config) Get(name string) ServerFactory { return c.configMap[name] } func (c *Config) Delete(protocol string) { delete(c.configMap, protocol) } func (c *Config) Load(core Core, protocol string) (server protocol.Server, err error) { if c.configMap[protocol] == nil { return nil, errors.NewPrivate("HERTZ: Load server error, not support protocol: " + protocol) } if c.altServerConfig == nil || c.altServerConfig.targetProtocol == protocol { return c.configMap[protocol].New(core) } return c.configMap[protocol].New(&coreWrapper{Core: core, beforeHandler: c.altServerConfig.setAltHeaderFunc}) } func (c *Config) LoadAll(core Core) (serverMap ServerMap, streamServerMap StreamServerMap, err error) { serverMap = make(ServerMap) var wrappedCore *coreWrapper if c.altServerConfig != nil { wrappedCore = &coreWrapper{Core: core, beforeHandler: c.altServerConfig.setAltHeaderFunc} } var server protocol.Server for proto := range c.configMap { if c.altServerConfig != nil && c.altServerConfig.targetProtocol != proto { core = wrappedCore } if server, err = c.configMap[proto].New(core); err != nil { return nil, nil, 
err } else { serverMap[proto] = server } } streamServerMap = make(StreamServerMap) var streamServer protocol.StreamServer for proto := range c.streamConfigMap { if c.altServerConfig != nil && c.altServerConfig.targetProtocol != proto { core = wrappedCore } if streamServer, err = c.streamConfigMap[proto].New(core); err != nil { return nil, nil, err } else { streamServerMap[proto] = streamServer } } return serverMap, streamServerMap, nil } // New return an empty Config suite, use .Add() to add protocol impl func New() *Config { c := &Config{ configMap: make(map[string]ServerFactory), streamConfigMap: make(map[string]StreamServerFactory), } return c } ================================================ FILE: pkg/protocol/trailer.go ================================================ /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package protocol import ( "bytes" "github.com/cloudwego/hertz/internal/bytestr" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol/consts" ) type Trailer struct { h []argsKV bufKV argsKV disableNormalizing bool } // Get returns trailer value for the given key. func (t *Trailer) Get(key string) string { return string(t.Peek(key)) } // Peek returns trailer value for the given key. // // Returned value is valid until the next call to Trailer. // Do not store references to returned value. Make copies instead. 
func (t *Trailer) Peek(key string) []byte {
	k := []byte(key)
	utils.NormalizeHeaderKey(k, t.disableNormalizing)
	return peekArgBytes(t.h, k)
}

// Del deletes trailer with the given key.
func (t *Trailer) Del(key string) {
	k := []byte(key)
	utils.NormalizeHeaderKey(k, t.disableNormalizing)
	t.h = delAllArgsBytes(t.h, k)
}

// VisitAll calls f for each trailer key/value pair.
func (t *Trailer) VisitAll(f func(key, value []byte)) {
	visitArgs(t.h, f)
}

// Set sets the given 'key: value' trailer.
//
// If the key is forbidden by RFC 7230, section 4.1.2, Set will return error
func (t *Trailer) Set(key, value string) error {
	initHeaderKV(&t.bufKV, key, value, t.disableNormalizing)
	return t.setArgBytes(t.bufKV.key, t.bufKV.value, ArgsHasValue)
}

// Add adds the given 'key: value' trailer.
//
// Multiple headers with the same key may be added with this function.
// Use Set for setting a single header for the given key.
//
// If the key is forbidden by RFC 7230, section 4.1.2, Add will return error
func (t *Trailer) Add(key, value string) error {
	initHeaderKV(&t.bufKV, key, value, t.disableNormalizing)
	return t.addArgBytes(t.bufKV.key, t.bufKV.value, ArgsHasValue)
}

// addArgBytes appends key/value, rejecting keys for which IsBadTrailer is true.
func (t *Trailer) addArgBytes(key, value []byte, noValue bool) error {
	if IsBadTrailer(key) {
		return errs.NewPublicf("forbidden trailer key: %q", key)
	}
	t.h = appendArgBytes(t.h, key, value, noValue)
	return nil
}

// setArgBytes stores key/value via setArgBytes, rejecting forbidden keys.
func (t *Trailer) setArgBytes(key, value []byte, noValue bool) error {
	if IsBadTrailer(key) {
		return errs.NewPublicf("forbidden trailer key: %q", key)
	}
	t.h = setArgBytes(t.h, key, value, noValue)
	return nil
}

// UpdateArgBytes updates the value stored for key via updateArgBytes;
// a forbidden trailer key is rejected with an error.
func (t *Trailer) UpdateArgBytes(key, value []byte) error {
	if IsBadTrailer(key) {
		return errs.NewPublicf("forbidden trailer key: %q", key)
	}
	t.h = updateArgBytes(t.h, key, value)
	return nil
}

// GetTrailers returns the underlying key/value slice (not a copy).
func (t *Trailer) GetTrailers() []argsKV {
	return t.h
}

// Empty reports whether no trailer is stored.
func (t *Trailer) Empty() bool {
	return len(t.h) == 0
}

// GetBytes returns the trailer keys joined by ", ", i.e. the value of the
// 'Trailer' header announcing these trailers.
func (t *Trailer) GetBytes() []byte {
	var dst []byte
	for i, n := 0, len(t.h); i < n; i++ {
		kv := &t.h[i]
		dst = append(dst, kv.key...)
		if i+1 < n {
			dst = append(dst, bytestr.StrCommaSpace...)
		}
	}
	return dst
}

// ResetSkipNormalize removes all trailers but keeps the normalizing setting.
func (t *Trailer) ResetSkipNormalize() {
	t.h = t.h[:0]
}

// Reset removes all trailers and re-enables key normalizing.
func (t *Trailer) Reset() {
	t.disableNormalizing = false
	t.ResetSkipNormalize()
}

// DisableNormalizing turns off normalization of trailer keys.
func (t *Trailer) DisableNormalizing() {
	t.disableNormalizing = true
}

// IsDisableNormalizing reports whether trailer key normalization is disabled.
func (t *Trailer) IsDisableNormalizing() bool {
	return t.disableNormalizing
}

// CopyTo copies all the trailer to dst.
func (t *Trailer) CopyTo(dst *Trailer) {
	dst.Reset()

	dst.disableNormalizing = t.disableNormalizing
	dst.h = copyArgs(dst.h, t.h)
}

// SetTrailers replaces the stored trailers with the keys listed in trailers,
// a comma-separated 'Trailer' header value. Surrounding spaces and tabs are
// trimmed, and empty list elements are ignored (RFC 7230, section 7). Keys are
// normalized in place inside trailers. Every valid key is added; if any key is
// forbidden, the first such error is returned.
func (t *Trailer) SetTrailers(trailers []byte) (err error) {
	t.ResetSkipNormalize()
	for len(trailers) > 0 {
		trailerKey := trailers
		if i := bytes.IndexByte(trailers, ','); i >= 0 {
			trailerKey, trailers = trailers[:i], trailers[i+1:]
		} else {
			trailers = nil
		}

		for len(trailerKey) > 0 && (trailerKey[0] == ' ' || trailerKey[0] == '\t') {
			trailerKey = trailerKey[1:]
		}
		for len(trailerKey) > 0 && (trailerKey[len(trailerKey)-1] == ' ' || trailerKey[len(trailerKey)-1] == '\t') {
			trailerKey = trailerKey[:len(trailerKey)-1]
		}
		// Empty elements such as in "a, ,b" or ",a" used to reach IsBadTrailer
		// with an empty key, which indexed key[0] and panicked.
		if len(trailerKey) == 0 {
			continue
		}

		utils.NormalizeHeaderKey(trailerKey, t.disableNormalizing)
		// Keep the first error: previously each iteration overwrote err, so a
		// forbidden key was silently reported as success if a later key was valid.
		if addErr := t.addArgBytes(trailerKey, nil, argsNoValue); addErr != nil && err == nil {
			err = addErr
		}
	}
	return err
}

// Header renders all trailers with AppendBytes into an internal buffer and
// returns it; the result is overwritten by the next call.
func (t *Trailer) Header() []byte {
	t.bufKV.value = t.AppendBytes(t.bufKV.value[:0])
	return t.bufKV.value
}

// AppendBytes appends one header line per trailer followed by a final CRLF
// to dst and returns the extended slice.
func (t *Trailer) AppendBytes(dst []byte) []byte {
	for i, n := 0, len(t.h); i < n; i++ {
		kv := &t.h[i]
		dst = appendHeaderLine(dst, kv.key, kv.value)
	}

	dst = append(dst, bytestr.StrCRLF...)
	return dst
}

// IsBadTrailer reports whether key must not be used as a trailer field
// (RFC 7230, section 4.1.2). The comparison is case-insensitive.
// An empty key is not a valid field name and is reported as bad.
func IsBadTrailer(key []byte) bool {
	if len(key) == 0 {
		return true
	}
	switch key[0] | 0x20 {
	case 'a':
		return utils.CaseInsensitiveCompare(key, bytestr.StrAuthorization)
	case 'c':
		if len(key) >= len(consts.HeaderContentType) && utils.CaseInsensitiveCompare(key[:8], bytestr.StrContentType[:8]) {
			// skip compare prefix 'Content-'
			return utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentEncoding[8:]) ||
				utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentLength[8:]) ||
				utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentType[8:]) ||
				utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentRange[8:])
		}
		return utils.CaseInsensitiveCompare(key, bytestr.StrConnection)
	case 'e':
		return utils.CaseInsensitiveCompare(key, bytestr.StrExpect)
	case 'h':
		return utils.CaseInsensitiveCompare(key, bytestr.StrHost)
	case 'k':
		return utils.CaseInsensitiveCompare(key, bytestr.StrKeepAlive)
	case 'm':
		return utils.CaseInsensitiveCompare(key, bytestr.StrMaxForwards)
	case 'p':
		if len(key) >= len(consts.HeaderProxyConnection) && utils.CaseInsensitiveCompare(key[:6], bytestr.StrProxyConnection[:6]) {
			// skip compare prefix 'Proxy-'
			return utils.CaseInsensitiveCompare(key[6:], bytestr.StrProxyConnection[6:]) ||
				utils.CaseInsensitiveCompare(key[6:], bytestr.StrProxyAuthenticate[6:]) ||
				utils.CaseInsensitiveCompare(key[6:], bytestr.StrProxyAuthorization[6:])
		}
	case 'r':
		return utils.CaseInsensitiveCompare(key, bytestr.StrRange)
	case 't':
		return utils.CaseInsensitiveCompare(key, bytestr.StrTE) ||
			utils.CaseInsensitiveCompare(key, bytestr.StrTrailer) ||
			utils.CaseInsensitiveCompare(key, bytestr.StrTransferEncoding)
	case 'w':
		return utils.CaseInsensitiveCompare(key, bytestr.StrWWWAuthenticate)
	}
	return false
}

================================================
FILE: pkg/protocol/trailer_test.go
================================================
/*
 * Copyright 2023 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the
License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package protocol import ( "strings" "testing" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol/consts" ) func TestTrailerAdd(t *testing.T) { var tr Trailer assert.Nil(t, tr.Add("foo", "value1")) assert.Nil(t, tr.Add("foo", "value2")) assert.Nil(t, tr.Add("bar", "value3")) assert.True(t, strings.Contains(string(tr.Header()), "Foo: value1")) assert.True(t, strings.Contains(string(tr.Header()), "Foo: value2")) assert.True(t, strings.Contains(string(tr.Header()), "Bar: value3")) } func TestHeaderTrailerSet(t *testing.T) { h := &RequestHeader{} // only one trailer h.Set("Trailer", "Foo") assert.True(t, strings.Contains(string(h.Trailer().Header()), "Foo:")) // multi trailer h.Set("Trailer", "Foo, bar, HERtz") assert.True(t, strings.Contains(string(h.Trailer().Header()), "Foo:")) assert.True(t, strings.Contains(string(h.Trailer().Header()), "Bar:")) assert.True(t, strings.Contains(string(h.Trailer().Header()), "Hertz:")) // all lowercase h.Set("Trailer", "foo,hertz,aaa") assert.True(t, strings.Contains(string(h.Trailer().Header()), "Foo:")) assert.True(t, strings.Contains(string(h.Trailer().Header()), "Hertz:")) assert.True(t, strings.Contains(string(h.Trailer().Header()), "Aaa:")) // all uppercase h.Set("Trailer", "FOO,HERTZ,AAA") assert.True(t, strings.Contains(string(h.Trailer().Header()), "Foo:")) assert.True(t, strings.Contains(string(h.Trailer().Header()), "Hertz:")) assert.True(t, strings.Contains(string(h.Trailer().Header()), "Aaa:")) // with '-' h.Set("Trailer", 
"FOO-HERTZ-AAA") assert.True(t, strings.Contains(string(h.Trailer().Header()), "Foo-Hertz-Aaa:")) // more space h.Set("Trailer", " foo, hertz , aaa ") assert.True(t, strings.Contains(string(h.Trailer().Header()), "Foo:")) assert.True(t, strings.Contains(string(h.Trailer().Header()), "Hertz:")) assert.True(t, strings.Contains(string(h.Trailer().Header()), "Aaa:")) } func TestTrailerAddError(t *testing.T) { var tr Trailer assert.NotNil(t, tr.Add(consts.HeaderContentType, "")) assert.NotNil(t, tr.Set(consts.HeaderProxyConnection, "")) } func TestTrailerDel(t *testing.T) { var tr Trailer assert.Nil(t, tr.Add("foo", "value1")) assert.Nil(t, tr.Add("foo", "value2")) assert.Nil(t, tr.Add("bar", "value3")) tr.Del("foo") assert.False(t, strings.Contains(string(tr.Header()), "Foo: value1")) assert.False(t, strings.Contains(string(tr.Header()), "Foo: value2")) assert.True(t, strings.Contains(string(tr.Header()), "Bar: value3")) } func TestTrailerSet(t *testing.T) { var tr Trailer assert.Nil(t, tr.Set("foo", "value1")) assert.Nil(t, tr.Set("foo", "value2")) assert.Nil(t, tr.Set("bar", "value3")) assert.False(t, strings.Contains(string(tr.Header()), "Foo: value1")) assert.True(t, strings.Contains(string(tr.Header()), "Foo: value2")) assert.True(t, strings.Contains(string(tr.Header()), "Bar: value3")) } func TestTrailerGet(t *testing.T) { var tr Trailer assert.Nil(t, tr.Add("foo", "value1")) assert.Nil(t, tr.Add("bar", "value3")) assert.DeepEqual(t, tr.Get("foo"), "value1") assert.DeepEqual(t, tr.Get("bar"), "value3") } func TestTrailerUpdateArgBytes(t *testing.T) { var tr Trailer assert.Nil(t, tr.addArgBytes([]byte("Foo"), []byte("value0"), argsNoValue)) assert.Nil(t, tr.UpdateArgBytes([]byte("Foo"), []byte("value1"))) assert.Nil(t, tr.UpdateArgBytes([]byte("Foo"), []byte("value2"))) assert.Nil(t, tr.UpdateArgBytes([]byte("Bar"), []byte("value3"))) assert.True(t, strings.Contains(string(tr.Header()), "Foo: value1")) assert.False(t, strings.Contains(string(tr.Header()), "Foo: 
value2")) assert.False(t, strings.Contains(string(tr.Header()), "Bar: value3")) } func TestTrailerEmpty(t *testing.T) { var tr Trailer assert.DeepEqual(t, tr.Empty(), true) assert.Nil(t, tr.Set("foo", "")) assert.DeepEqual(t, tr.Empty(), false) } func TestTrailerVisitAll(t *testing.T) { var tr Trailer assert.Nil(t, tr.Add("foo", "value1")) assert.Nil(t, tr.Add("bar", "value2")) tr.VisitAll( func(k, v []byte) { key := string(k) value := string(v) if (key != "Foo" || value != "value1") && (key != "Bar" || value != "value2") { t.Fatalf("Unexpected (%v, %v). Expected %v", key, value, "(foo, value1) or (bar, value2)") } }) } func TestIsBadTrailer(t *testing.T) { assert.True(t, IsBadTrailer(bytestr.StrAuthorization)) assert.True(t, IsBadTrailer(bytestr.StrContentEncoding)) assert.True(t, IsBadTrailer(bytestr.StrContentLength)) assert.True(t, IsBadTrailer(bytestr.StrContentType)) assert.True(t, IsBadTrailer(bytestr.StrContentRange)) assert.True(t, IsBadTrailer(bytestr.StrConnection)) assert.True(t, IsBadTrailer(bytestr.StrExpect)) assert.True(t, IsBadTrailer(bytestr.StrHost)) assert.True(t, IsBadTrailer(bytestr.StrKeepAlive)) assert.True(t, IsBadTrailer(bytestr.StrMaxForwards)) assert.True(t, IsBadTrailer(bytestr.StrProxyConnection)) assert.True(t, IsBadTrailer(bytestr.StrProxyAuthenticate)) assert.True(t, IsBadTrailer(bytestr.StrProxyAuthorization)) assert.True(t, IsBadTrailer(bytestr.StrRange)) assert.True(t, IsBadTrailer(bytestr.StrTE)) assert.True(t, IsBadTrailer(bytestr.StrTrailer)) assert.True(t, IsBadTrailer(bytestr.StrTransferEncoding)) assert.True(t, IsBadTrailer(bytestr.StrWWWAuthenticate)) } ================================================ FILE: pkg/protocol/uri.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package protocol import ( "bytes" "path/filepath" "sync" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/internal/nocopy" ) // AcquireURI returns an empty URI instance from the pool. 
//
// Release the URI with ReleaseURI after the URI is no longer needed.
// This allows reducing GC load.
func AcquireURI() *URI {
	return uriPool.Get().(*URI)
}

// ReleaseURI resets u and returns it to the pool used by AcquireURI.
//
// The released URI mustn't be used after releasing it, otherwise data races
// may occur.
func ReleaseURI(u *URI) {
	u.Reset()
	uriPool.Put(u)
}

// uriPool backs AcquireURI/ReleaseURI; it allocates a zero URI when empty.
var uriPool = &sync.Pool{
	New: func() interface{} {
		return &URI{}
	},
}

// URI holds the components of a parsed URI. Byte slices returned by its
// accessors alias internal buffers and are only valid until the next call
// that modifies the URI.
type URI struct {
	noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used

	// pathOriginal is the path exactly as passed to SetPath or parsed by Parse.
	pathOriginal []byte
	scheme       []byte
	// path is normalizePath applied to pathOriginal.
	path        []byte
	queryString []byte
	hash        []byte
	host        []byte

	// queryArgs is parsed lazily from queryString; parsedQueryArgs records
	// whether that parse has happened for the current queryString.
	queryArgs       Args
	parsedQueryArgs bool

	// DisablePathNormalizing makes RequestURI emit pathOriginal instead of
	// the normalized path.
	DisablePathNormalizing bool

	// fullURI and requestURI are output buffers reused by FullURI and RequestURI.
	fullURI    []byte
	requestURI []byte

	username []byte
	password []byte
}

// argsKV is a single key/value entry; noValue marks a key that carries no value.
type argsKV struct {
	key     []byte
	value   []byte
	noValue bool
}

// GetKey returns the entry's key.
func (kv *argsKV) GetKey() []byte {
	return kv.key
}

// GetValue returns the entry's value.
func (kv *argsKV) GetValue() []byte {
	return kv.value
}

// CopyTo copies uri contents to dst.
func (u *URI) CopyTo(dst *URI) {
	dst.Reset()
	dst.pathOriginal = append(dst.pathOriginal[:0], u.pathOriginal...)
	dst.scheme = append(dst.scheme[:0], u.scheme...)
	dst.path = append(dst.path[:0], u.path...)
	dst.queryString = append(dst.queryString[:0], u.queryString...)
	dst.hash = append(dst.hash[:0], u.hash...)
	dst.host = append(dst.host[:0], u.host...)
	dst.username = append(dst.username[:0], u.username...)
	dst.password = append(dst.password[:0], u.password...)

	u.queryArgs.CopyTo(&dst.queryArgs)
	dst.parsedQueryArgs = u.parsedQueryArgs
	dst.DisablePathNormalizing = u.DisablePathNormalizing

	// fullURI and requestURI shouldn't be copied, since they are created
	// from scratch on each FullURI() and RequestURI() call.
}

// QueryArgs returns query args, parsing them from the query string on first use.
func (u *URI) QueryArgs() *Args {
	u.parseQueryArgs()
	return &u.queryArgs
}

// parseQueryArgs parses queryString into queryArgs unless already done.
func (u *URI) parseQueryArgs() {
	if u.parsedQueryArgs {
		return
	}
	u.queryArgs.ParseBytes(u.queryString)
	u.parsedQueryArgs = true
}

// Hash returns URI hash, i.e. qwe of http://aaa.com/foo/bar?baz=123#qwe .
//
// The returned value is valid until the next URI method call.
func (u *URI) Hash() []byte {
	return u.hash
}

// SetHash sets URI hash (the fragment, without the leading '#').
func (u *URI) SetHash(hash string) {
	u.hash = append(u.hash[:0], hash...)
}

// SetHashBytes sets URI hash (the fragment, without the leading '#').
func (u *URI) SetHashBytes(hash []byte) {
	u.hash = append(u.hash[:0], hash...)
}

// Username returns URI username.
func (u *URI) Username() []byte {
	return u.username
}

// SetUsername sets URI username.
func (u *URI) SetUsername(username string) {
	u.username = append(u.username[:0], username...)
}

// SetUsernameBytes sets URI username.
func (u *URI) SetUsernameBytes(username []byte) {
	u.username = append(u.username[:0], username...)
}

// Password returns URI password.
func (u *URI) Password() []byte {
	return u.password
}

// SetPassword sets URI password.
func (u *URI) SetPassword(password string) {
	u.password = append(u.password[:0], password...)
}

// SetPasswordBytes sets URI password.
func (u *URI) SetPasswordBytes(password []byte) {
	u.password = append(u.password[:0], password...)
}

// QueryString returns URI query string,
// i.e. baz=123 of http://aaa.com/foo/bar?baz=123#qwe .
//
// The returned value is valid until the next URI method call.
func (u *URI) QueryString() []byte {
	return u.queryString
}

// SetQueryString sets URI query string (without the leading '?').
// Previously parsed query args are marked stale, so QueryArgs re-parses
// them on its next call.
func (u *URI) SetQueryString(queryString string) {
	u.queryString = append(u.queryString[:0], queryString...)
	u.parsedQueryArgs = false
}

// SetQueryStringBytes sets URI query string (without the leading '?').
// Previously parsed query args are marked stale, so QueryArgs re-parses
// them on its next call.
func (u *URI) SetQueryStringBytes(queryString []byte) {
	u.queryString = append(u.queryString[:0], queryString...)
	u.parsedQueryArgs = false
}

// Path returns URI path, i.e. /foo/bar of http://aaa.com/foo/bar?baz=123#qwe .
//
// The returned path is always urldecoded and normalized,
// i.e. '//f%20obar/baz/../zzz' becomes '/f obar/zzz'.
// bytestr.StrSlash is returned when no path is set.
//
// The returned value is valid until the next URI method call.
func (u *URI) Path() []byte {
	if len(u.path) == 0 {
		return bytestr.StrSlash
	}
	return u.path
}

// SetPath replaces the URI path. The raw value is kept for PathOriginal and
// its normalized form is what Path returns.
func (u *URI) SetPath(path string) {
	original := append(u.pathOriginal[:0], path...)
	u.pathOriginal = original
	u.path = normalizePath(u.path, original)
}

// String returns the full URI as a string (see FullURI).
func (u *URI) String() string {
	full := u.FullURI()
	return string(full)
}

// SetPathBytes replaces the URI path. The raw value is kept for PathOriginal
// and its normalized form is what Path returns.
func (u *URI) SetPathBytes(path []byte) {
	original := append(u.pathOriginal[:0], path...)
	u.pathOriginal = original
	u.path = normalizePath(u.path, original)
}

// PathOriginal returns the path as originally given to Parse or SetPath,
// before decoding and normalization.
//
// The returned value is valid until the next URI method call.
func (u *URI) PathOriginal() []byte {
	return u.pathOriginal
}

// Scheme returns the URI scheme, i.e. http of http://aaa.com/foo/bar?baz=123#qwe .
// When no scheme is set, bytestr.StrHTTP is returned.
//
// Returned scheme is always lowercased.
//
// The returned value is valid until the next URI method call.
func (u *URI) Scheme() []byte {
	if len(u.scheme) == 0 {
		return bytestr.StrHTTP
	}
	return u.scheme
}

// SetScheme stores scheme (e.g. http, https, ftp) in lowercase.
func (u *URI) SetScheme(scheme string) {
	buf := append(u.scheme[:0], scheme...)
	bytesconv.LowercaseBytes(buf)
	u.scheme = buf
}

// SetSchemeBytes stores scheme (e.g. http, https, ftp) in lowercase.
func (u *URI) SetSchemeBytes(scheme []byte) {
	buf := append(u.scheme[:0], scheme...)
	bytesconv.LowercaseBytes(buf)
	u.scheme = buf
}

// Reset clears every URI component and flag while keeping allocated buffers.
func (u *URI) Reset() {
	// Raw components.
	u.scheme = u.scheme[:0]
	u.username = u.username[:0]
	u.password = u.password[:0]
	u.host = u.host[:0]
	u.pathOriginal = u.pathOriginal[:0]
	u.path = u.path[:0]
	u.queryString = u.queryString[:0]
	u.hash = u.hash[:0]

	// Derived state and flags.
	u.parsedQueryArgs = false
	u.queryArgs.Reset()
	u.DisablePathNormalizing = false

	// fullURI and requestURI are deliberately left alone: FullURI() and
	// RequestURI() rebuild them from scratch on every call.
}

// Host returns the host part, i.e. aaa.com of http://aaa.com/foo/bar?baz=123#qwe .
//
// Host is always lowercased.
func (u *URI) Host() []byte {
	return u.host
}

// SetHost stores host in lowercase.
func (u *URI) SetHost(host string) {
	buf := append(u.host[:0], host...)
	bytesconv.LowercaseBytes(buf)
	u.host = buf
}

// SetHostBytes stores host in lowercase.
func (u *URI) SetHostBytes(host []byte) {
	buf := append(u.host[:0], host...)
	bytesconv.LowercaseBytes(buf)
	u.host = buf
}

// LastPathSegment returns the part of the URI path after the final '/'.
//
// Examples:
//
//   - For /foo/bar/baz.html path returns baz.html.
//   - For /foo/bar/ returns empty byte slice.
//   - For /foobar.js returns foobar.js.
func (u *URI) LastPathSegment() []byte {
	path := u.Path()
	// LastIndexByte yields -1 when there is no '/', so the slice then starts
	// at 0 and the whole path is returned.
	return path[bytes.LastIndexByte(path, '/')+1:]
}

// Update modifies the URI according to newURI.
//
// The following newURI types are accepted:
//
//   - Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original
//     uri is replaced by newURI.
//   - Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case
//     the original scheme is preserved.
//   - Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part
//     of the original uri is replaced.
//   - Relative path, i.e. xx?yy=abc . In this case the original RequestURI
//     is updated according to the new relative path.
func (u *URI) Update(newURI string) {
	raw := bytesconv.S2b(newURI)
	u.UpdateBytes(raw)
}

// UpdateBytes is the []byte form of Update and accepts the same newURI types:
//
//   - Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original
//     uri is replaced by newURI.
//   - Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case
//     the original scheme is preserved.
//   - Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part
//     of the original uri is replaced.
//   - Relative path, i.e. xx?yy=abc . In this case the original RequestURI
//     is updated according to the new relative path.
func (u *URI) UpdateBytes(newURI []byte) { u.requestURI = u.updateBytes(newURI, u.requestURI) } // Parse initializes URI from the given host and uri. // // host may be nil. In this case uri must contain fully qualified uri, // i.e. with scheme and host. http is assumed if scheme is omitted. // // uri may contain e.g. RequestURI without scheme and host if host is non-empty. func (u *URI) Parse(host, uri []byte) { u.parse(host, uri, false) } // Maybe rawURL is of the form scheme:path. // (Scheme must be [a-zA-Z][a-zA-Z0-9+-.]*) // If so, return scheme, path; else return nil, rawURL. func getScheme(rawURL []byte) (scheme, path []byte) { for i := 0; i < len(rawURL); i++ { c := rawURL[i] switch { case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z': // do nothing case '0' <= c && c <= '9' || c == '+' || c == '-' || c == '.': if i == 0 { return nil, rawURL } case c == ':': return checkSchemeWhenCharIsColon(i, rawURL) default: // we have encountered an invalid character, // so there is no valid scheme return nil, rawURL } } return nil, rawURL } func (u *URI) parse(host, uri []byte, isTLS bool) { u.Reset() if stringContainsCTLByte(uri) { return } if len(host) == 0 || bytes.Contains(uri, bytestr.StrColonSlashSlash) { scheme, newHost, newURI := splitHostURI(host, uri) u.scheme = append(u.scheme, scheme...) bytesconv.LowercaseBytes(u.scheme) host = newHost uri = newURI } if isTLS { u.scheme = append(u.scheme[:0], bytestr.StrHTTPS...) } if n := bytes.Index(host, bytestr.StrAt); n >= 0 { auth := host[:n] host = host[n+1:] if n := bytes.Index(auth, bytestr.StrColon); n >= 0 { u.username = append(u.username[:0], auth[:n]...) u.password = append(u.password[:0], auth[n+1:]...) } else { u.username = append(u.username[:0], auth...) u.password = u.password[:0] } } u.host = append(u.host, host...) 
bytesconv.LowercaseBytes(u.host) b := uri queryIndex := bytes.IndexByte(b, '?') fragmentIndex := bytes.IndexByte(b, '#') // Ignore query in fragment part if fragmentIndex >= 0 && queryIndex > fragmentIndex { queryIndex = -1 } if queryIndex < 0 && fragmentIndex < 0 { u.pathOriginal = append(u.pathOriginal, b...) u.path = normalizePath(u.path, u.pathOriginal) return } if queryIndex >= 0 { // Path is everything up to the start of the query u.pathOriginal = append(u.pathOriginal, b[:queryIndex]...) u.path = normalizePath(u.path, u.pathOriginal) if fragmentIndex < 0 { u.queryString = append(u.queryString, b[queryIndex+1:]...) } else { u.queryString = append(u.queryString, b[queryIndex+1:fragmentIndex]...) u.hash = append(u.hash, b[fragmentIndex+1:]...) } return } // fragmentIndex >= 0 && queryIndex < 0 // Path is up to the start of fragment u.pathOriginal = append(u.pathOriginal, b[:fragmentIndex]...) u.path = normalizePath(u.path, u.pathOriginal) u.hash = append(u.hash, b[fragmentIndex+1:]...) } // stringContainsCTLByte reports whether s contains any ASCII control character. func stringContainsCTLByte(s []byte) bool { for i := 0; i < len(s); i++ { b := s[i] if b < ' ' || b == 0x7f { return true } } return false } func splitHostURI(host, uri []byte) ([]byte, []byte, []byte) { scheme, path := getScheme(uri) if scheme == nil { return bytestr.StrHTTP, host, uri } uri = path[len(bytestr.StrSlashSlash):] n := bytes.IndexByte(uri, '/') if n < 0 { // A hack for bogus urls like foobar.com?a=b without // slash after host. if n = bytes.IndexByte(uri, '?'); n >= 0 { return scheme, uri[:n], uri[n:] } return scheme, uri, bytestr.StrSlash } return scheme, uri[:n], uri[n:] } func normalizePath(dst, src []byte) []byte { dst = dst[:0] dst = addLeadingSlash(dst, src) dst = decodeArgAppendNoPlus(dst, src) // Windows server need to replace all backslashes with // forward slashes to avoid path traversal attacks. 
if filepath.Separator == '\\' { for { n := bytes.IndexByte(dst, '\\') if n < 0 { break } dst[n] = '/' } } // remove duplicate slashes b := dst bSize := len(b) for { n := bytes.Index(b, bytestr.StrSlashSlash) if n < 0 { break } b = b[n:] copy(b, b[1:]) b = b[:len(b)-1] bSize-- } dst = dst[:bSize] // remove /./ parts b = dst for { n := bytes.Index(b, bytestr.StrSlashDotSlash) if n < 0 { break } nn := n + len(bytestr.StrSlashDotSlash) - 1 copy(b[n:], b[nn:]) b = b[:len(b)-nn+n] } // remove /foo/../ parts for { n := bytes.Index(b, bytestr.StrSlashDotDotSlash) if n < 0 { break } nn := bytes.LastIndexByte(b[:n], '/') if nn < 0 { nn = 0 } n += len(bytestr.StrSlashDotDotSlash) - 1 copy(b[nn:], b[n:]) b = b[:len(b)-n+nn] } // remove trailing /foo/.. n := bytes.LastIndex(b, bytestr.StrSlashDotDot) if n >= 0 && n+len(bytestr.StrSlashDotDot) == len(b) { nn := bytes.LastIndexByte(b[:n], '/') if nn < 0 { return bytestr.StrSlash } b = b[:nn+1] } return b } func copyArgs(dst, src []argsKV) []argsKV { if cap(dst) < len(src) { tmp := make([]argsKV, len(src)) copy(tmp, dst) dst = tmp } n := len(src) dst = dst[:n] for i := 0; i < n; i++ { dstKV := &dst[i] srcKV := &src[i] dstKV.key = append(dstKV.key[:0], srcKV.key...) if srcKV.noValue { dstKV.value = dstKV.value[:0] } else { dstKV.value = append(dstKV.value[:0], srcKV.value...) } dstKV.noValue = srcKV.noValue } return dst } func (u *URI) updateBytes(newURI, buf []byte) []byte { if len(newURI) == 0 { return buf } n := bytes.Index(newURI, bytestr.StrSlashSlash) if n >= 0 { // absolute uri var b [32]byte schemeOriginal := b[:0] if len(u.scheme) > 0 { schemeOriginal = append([]byte(nil), u.scheme...) } if n == 0 { newURI = bytes.Join([][]byte{u.scheme, bytestr.StrColon, newURI}, nil) } u.Parse(nil, newURI) if len(schemeOriginal) > 0 && len(u.scheme) == 0 { u.scheme = append(u.scheme[:0], schemeOriginal...) } return buf } if newURI[0] == '/' { // uri without host buf = u.appendSchemeHost(buf[:0]) buf = append(buf, newURI...) 
u.Parse(nil, buf) return buf } // relative path switch newURI[0] { case '?': // query string only update u.SetQueryStringBytes(newURI[1:]) return append(buf[:0], u.FullURI()...) case '#': // update only hash u.SetHashBytes(newURI[1:]) return append(buf[:0], u.FullURI()...) default: // update the last path part after the slash path := u.Path() n = bytes.LastIndexByte(path, '/') if n < 0 { panic("BUG: path must contain at least one slash") } buf = u.appendSchemeHost(buf[:0]) buf = bytesconv.AppendQuotedPath(buf, path[:n+1]) buf = append(buf, newURI...) u.Parse(nil, buf) return buf } } // AppendBytes appends full uri to dst and returns the extended dst. func (u *URI) AppendBytes(dst []byte) []byte { dst = u.appendSchemeHost(dst) dst = append(dst, u.RequestURI()...) if len(u.hash) > 0 { dst = append(dst, '#') dst = append(dst, u.hash...) } return dst } // RequestURI returns RequestURI - i.e. URI without Scheme and Host. func (u *URI) RequestURI() []byte { var dst []byte if u.DisablePathNormalizing { dst = append(u.requestURI[:0], u.PathOriginal()...) } else { dst = bytesconv.AppendQuotedPath(u.requestURI[:0], u.Path()) } if u.queryArgs.Len() > 0 { dst = append(dst, '?') dst = u.queryArgs.AppendBytes(dst) } else if len(u.queryString) > 0 { dst = append(dst, '?') dst = append(dst, u.queryString...) } u.requestURI = dst return u.requestURI } func (u *URI) appendSchemeHost(dst []byte) []byte { dst = append(dst, u.Scheme()...) dst = append(dst, bytestr.StrColonSlashSlash...) return append(dst, u.Host()...) } // FullURI returns full uri in the form {Scheme}://{Host}{RequestURI}#{Hash}. 
func (u *URI) FullURI() []byte { u.fullURI = u.AppendBytes(u.fullURI[:0]) return u.fullURI } func ParseURI(uriStr string) *URI { uri := &URI{} uri.Parse(nil, []byte(uriStr)) return uri } type Proxy func(*Request) (*URI, error) func ProxyURI(fixedURI *URI) Proxy { return func(*Request) (*URI, error) { return fixedURI, nil } } ================================================ FILE: pkg/protocol/uri_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package protocol import ( "bytes" "path/filepath" "reflect" "runtime" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestURI_Username(t *testing.T) { var req Request req.SetRequestURI("http://user:pass@example.com/foo/bar") u := req.URI() user1 := string(u.Username()) req.Header.SetRequestURIBytes([]byte("/foo/bar")) u = req.URI() user2 := string(u.Username()) assert.DeepEqual(t, user1, user2) expectUser3 := "user3" expectUser4 := "user4" u.SetUsername(expectUser3) user3 := string(u.Username()) assert.DeepEqual(t, expectUser3, user3) u.SetUsername(expectUser4) user4 := string(u.Username()) assert.DeepEqual(t, expectUser4, user4) u.SetUsernameBytes([]byte(user3)) assert.DeepEqual(t, expectUser3, user3) u.SetUsernameBytes([]byte(user4)) assert.DeepEqual(t, expectUser4, user4) } func TestURI_Password(t *testing.T) { u := AcquireURI() defer ReleaseURI(u) expectPassword1 := "password1" expectPassword2 := "password2" u.SetPassword(expectPassword1) password1 := string(u.Password()) assert.DeepEqual(t, expectPassword1, password1) u.SetPassword(expectPassword2) password2 := string(u.Password()) assert.DeepEqual(t, expectPassword2, password2) u.SetPasswordBytes([]byte(password1)) assert.DeepEqual(t, expectPassword1, password1) u.SetPasswordBytes([]byte(password2)) assert.DeepEqual(t, expectPassword2, password2) } func TestURI_Hash(t *testing.T) { u := AcquireURI() defer ReleaseURI(u) expectHash1 := "hash1" expectHash2 := "hash2" u.SetHash(expectHash1) hash1 := string(u.Hash()) assert.DeepEqual(t, expectHash1, hash1) u.SetHash(expectHash2) hash2 := string(u.Hash()) 
assert.DeepEqual(t, expectHash2, hash2) } func TestURI_QueryString(t *testing.T) { u := AcquireURI() defer ReleaseURI(u) expectQueryString1 := "key1=value1&key2=value2" expectQueryString2 := "key3=value3&key4=value4" u.SetQueryString(expectQueryString1) queryString1 := string(u.QueryString()) assert.DeepEqual(t, expectQueryString1, queryString1) u.SetQueryString(expectQueryString2) queryString2 := string(u.QueryString()) assert.DeepEqual(t, expectQueryString2, queryString2) } func TestURI_Path(t *testing.T) { u := AcquireURI() defer ReleaseURI(u) expectPath1 := "/" expectPath2 := "/path1" expectPath3 := "/path3" // When Path is not set, Path defaults to "/" path1 := string(u.Path()) assert.DeepEqual(t, expectPath1, path1) u.SetPath(expectPath2) path2 := string(u.Path()) assert.DeepEqual(t, expectPath2, path2) u.SetPath(expectPath3) path3 := string(u.Path()) assert.DeepEqual(t, expectPath3, path3) u.SetPathBytes([]byte(path2)) assert.DeepEqual(t, expectPath2, path2) u.SetPathBytes([]byte(path3)) assert.DeepEqual(t, expectPath3, path3) } func TestURI_Scheme(t *testing.T) { u := AcquireURI() defer ReleaseURI(u) expectScheme1 := "scheme1" expectScheme2 := "scheme2" u.SetScheme(expectScheme1) scheme1 := string(u.Scheme()) assert.DeepEqual(t, expectScheme1, scheme1) u.SetScheme(expectScheme2) scheme2 := string(u.Scheme()) assert.DeepEqual(t, expectScheme2, scheme2) u.SetSchemeBytes([]byte(scheme1)) assert.DeepEqual(t, expectScheme1, scheme1) u.SetSchemeBytes([]byte(scheme2)) assert.DeepEqual(t, expectScheme2, scheme2) } func TestURI_Host(t *testing.T) { u := AcquireURI() defer ReleaseURI(u) expectHost1 := "host1" expectHost2 := "host2" u.SetHost(expectHost1) host1 := string(u.Host()) assert.DeepEqual(t, expectHost1, host1) u.SetHost(expectHost2) host2 := string(u.Host()) assert.DeepEqual(t, expectHost2, host2) u.SetHostBytes([]byte(host1)) assert.DeepEqual(t, expectHost1, host1) u.SetHostBytes([]byte(host2)) assert.DeepEqual(t, expectHost2, host2) } func 
TestURI_PathOriginal(t *testing.T) {
	var u URI
	expectPath := "/path"
	u.Parse(nil, []byte(expectPath))
	uri := string(u.PathOriginal())
	assert.DeepEqual(t, expectPath, uri)
}

func TestArgsKV_Get(t *testing.T) {
	var argsKV argsKV
	expectKey := "key"
	expectValue := "value"
	argsKV.key = []byte(expectKey)
	argsKV.value = []byte(expectValue)
	key := string(argsKV.GetKey())
	value := string(argsKV.GetValue())
	assert.DeepEqual(t, expectKey, key)
	assert.DeepEqual(t, expectValue, value)
}

// TestURICopyToQueryArgs checks that CopyTo carries query args over to the destination URI.
func TestURICopyToQueryArgs(t *testing.T) {
	t.Parallel()

	var u URI
	a := u.QueryArgs()
	a.Set("foo", "bar")

	var u1 URI
	u.CopyTo(&u1)
	a1 := u1.QueryArgs()

	assert.DeepEqual(t, "bar", string(a1.Peek("foo")))
}

// TestURICopyTo checks that CopyTo produces a deep-equal copy, both for a zero
// URI and after an update.
func TestURICopyTo(t *testing.T) {
	t.Parallel()
	var u URI
	var copyU URI
	u.CopyTo(&copyU)
	if !reflect.DeepEqual(&u, &copyU) { //nolint:govet
		t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", &u, &copyU) //nolint:govet
	}

	u.UpdateBytes([]byte("https://google.com/foo?bar=baz&baraz#qqqq"))
	u.CopyTo(&copyU)
	if !reflect.DeepEqual(&u, &copyU) { //nolint:govet
		t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", &u, &copyU) //nolint:govet
	}
}

func TestURILastPathSegment(t *testing.T) {
	t.Parallel()

	testURILastPathSegment(t, "", "")
	testURILastPathSegment(t, "/", "")
	testURILastPathSegment(t, "/foo/bar/", "")
	testURILastPathSegment(t, "/foobar.js", "foobar.js")
	testURILastPathSegment(t, "/foo/bar/baz.html", "baz.html")
}

func testURILastPathSegment(t *testing.T, path, expectedSegment string) {
	var u URI
	u.SetPath(path)
	segment := u.LastPathSegment()
	assert.DeepEqual(t, expectedSegment, string(segment))
}

func TestURIPathEscape(t *testing.T) {
	t.Parallel()

	testURIPathEscape(t, "/foo/bar", "/foo/bar")
	testURIPathEscape(t, "/f_o-o=b:ar,b.c&q", "/f_o-o=b:ar,b.c&q")
	testURIPathEscape(t, "/aa?bb.тест~qq", "/aa%3Fbb.%D1%82%D0%B5%D1%81%D1%82~qq")
}

func TestURIUpdate(t *testing.T) {
	t.Parallel()

	// full uri
	testURIUpdate(t, "http://foo.bar/baz?aaa=22#aaa", "https://aaa.com/bb", "https://aaa.com/bb")

	// empty uri
	testURIUpdate(t, "http://aaa.com/aaa.html?234=234#add", "", "http://aaa.com/aaa.html?234=234#add")

	// request uri
	testURIUpdate(t, "ftp://aaa/xxx/yyy?aaa=bb#aa", "/boo/bar?xx", "ftp://aaa/boo/bar?xx")

	// relative uri
	testURIUpdate(t, "http://foo.bar/baz/xxx.html?aaa=22#aaa", "bb.html?xx=12#pp", "http://foo.bar/baz/bb.html?xx=12#pp")
	testURIUpdate(t, "http://xx/a/b/c/d", "../qwe/p?zx=34", "http://xx/a/b/qwe/p?zx=34")
	testURIUpdate(t, "https://qqq/aaa.html?foo=bar", "?baz=434&aaa#xcv", "https://qqq/aaa.html?baz=434&aaa#xcv")
	testURIUpdate(t, "http://foo.bar/baz", "~a/%20b=c,тест?йцу=ке", "http://foo.bar/~a/%20b=c,%D1%82%D0%B5%D1%81%D1%82?йцу=ке")
	testURIUpdate(t, "http://foo.bar/baz", "/qwe#fragment", "http://foo.bar/qwe#fragment")
	testURIUpdate(t, "http://foobar/baz/xxx", "aaa.html#bb?cc=dd&ee=dfd", "http://foobar/baz/aaa.html#bb?cc=dd&ee=dfd")

	// hash
	testURIUpdate(t, "http://foo.bar/baz#aaa", "#fragment", "http://foo.bar/baz#fragment")

	// uri without scheme
	testURIUpdate(t, "https://foo.bar/baz", "//aaa.bbb/cc?dd", "https://aaa.bbb/cc?dd")
	testURIUpdate(t, "http://foo.bar/baz", "//aaa.bbb/cc?dd", "http://aaa.bbb/cc?dd")
}

func testURIUpdate(t *testing.T, base, update, result string) {
	var u URI
	u.Parse(nil, []byte(base))
	u.Update(update)
	s := u.String()
	assert.DeepEqual(t, result, s)
}

func testURIPathEscape(t *testing.T, path, expectedRequestURI string) {
	var u URI
	u.SetPath(path)
	requestURI := u.RequestURI()
	assert.DeepEqual(t, expectedRequestURI, string(requestURI))
}

// TestDelArgs checks that Del and DelBytes remove previously set args.
// Arguments follow the (expected, actual) order used by the rest of this file.
func TestDelArgs(t *testing.T) {
	var args Args
	args.Set("foo", "bar")
	assert.DeepEqual(t, "bar", string(args.Peek("foo")))
	args.Del("foo")
	assert.DeepEqual(t, "", string(args.Peek("foo")))

	args.Set("foo2", "bar2")
	assert.DeepEqual(t, "bar2", string(args.Peek("foo2")))
	args.DelBytes([]byte("foo2"))
	assert.DeepEqual(t, "", string(args.Peek("foo2")))
}

func TestURIFullURI(t *testing.T) {
	t.Parallel()

	var args Args

	// empty scheme, path and hash
	testURIFullURI(t, "", "foobar.com", "", "", &args, "http://foobar.com/")

	// empty scheme and hash
	testURIFullURI(t, "", "aaa.com", "/foo/bar", "", &args, "http://aaa.com/foo/bar")

	// empty hash
	testURIFullURI(t, "fTP", "XXx.com", "/foo", "", &args, "ftp://xxx.com/foo")

	// empty args
	testURIFullURI(t, "https", "xx.com", "/", "aaa", &args, "https://xx.com/#aaa")

	// non-empty args and non-ASCII path
	args.Set("foo", "bar")
	args.Set("xxx", "йух")
	testURIFullURI(t, "", "xxx.com", "/тест123", "2er", &args, "http://xxx.com/%D1%82%D0%B5%D1%81%D1%82123?foo=bar&xxx=%D0%B9%D1%83%D1%85#2er")

	// test with empty args and non-empty query string
	var u URI
	u.Parse([]byte("google.com"), []byte("/foo?bar=baz&baraz#qqqq"))
	uri := u.FullURI()
	expectedURI := "http://google.com/foo?bar=baz&baraz#qqqq"
	assert.DeepEqual(t, expectedURI, string(uri))
}

func testURIFullURI(t *testing.T, scheme, host, path, hash string, args *Args, expectedURI string) {
	var u URI

	u.SetScheme(scheme)
	u.SetHost(host)
	u.SetPath(path)
	u.SetHash(hash)
	args.CopyTo(u.QueryArgs())

	uri := u.FullURI()
	assert.DeepEqual(t, expectedURI, string(uri))
}

func TestParsePathWindows(t *testing.T) {
	t.Parallel()

	testParsePathWindows(t, "/../../../../../foo", "/foo")
	testParsePathWindows(t, "/..\\..\\..\\..\\..\\foo", "/foo")
	testParsePathWindows(t, "/..%5c..%5cfoo", "/foo")
}

func TestURIPathNormalize(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.SkipNow()
	}

	t.Parallel()

	var u URI

	// double slash
	testURIPathNormalize(t, &u, "/aa//bb", "/aa/bb")

	// triple slash
	testURIPathNormalize(t, &u, "/x///y/", "/x/y/")

	// multi slashes
	testURIPathNormalize(t, &u, "/abc//de///fg////", "/abc/de/fg/")

	// encoded slashes
	testURIPathNormalize(t, &u, "/xxxx%2fyyy%2f%2F%2F", "/xxxx/yyy/")

	// dotdot
	testURIPathNormalize(t, &u, "/aaa/..", "/")

	// dotdot with trailing slash
	testURIPathNormalize(t, &u, "/xxx/yyy/../", "/xxx/")

	// multi dotdots
	testURIPathNormalize(t, &u, "/aaa/bbb/ccc/../../ddd", "/aaa/ddd")

	// dotdots separated by other data
	testURIPathNormalize(t, &u, "/a/b/../c/d/../e/..", "/a/c/")

	// too many dotdots
	testURIPathNormalize(t, &u, "/aaa/../../../../xxx", "/xxx")
	testURIPathNormalize(t, &u, "/../../../../../..", "/")
	testURIPathNormalize(t, &u, "/../../../../../../", "/")

	// encoded dotdots
	testURIPathNormalize(t, &u, "/aaa%2Fbbb%2F%2E.%2Fxxx", "/aaa/xxx")

	// double slash with dotdots
	testURIPathNormalize(t, &u, "/aaa////..//b", "/b")

	// fake dotdot
	testURIPathNormalize(t, &u, "/aaa/..bbb/ccc/..", "/aaa/..bbb/")

	// single dot
	testURIPathNormalize(t, &u, "/a/./b/././c/./d.html", "/a/b/c/d.html")
	testURIPathNormalize(t, &u, "./foo/", "/foo/")
	testURIPathNormalize(t, &u, "./../.././../../aaa/bbb/../../../././../", "/")
	testURIPathNormalize(t, &u, "./a/./.././../b/./foo.html", "/b/foo.html")
}

func testURIPathNormalize(t *testing.T, u *URI, requestURI, expectedPath string) {
	u.Parse(nil, []byte(requestURI)) //nolint:errcheck
	if string(u.Path()) != expectedPath {
		t.Fatalf("Unexpected path %q. Expected %q. requestURI=%q", u.Path(), expectedPath, requestURI)
	}
}

func testParsePathWindows(t *testing.T, path, expectedPath string) {
	var u URI
	u.Parse(nil, []byte(path))
	parsedPath := u.Path()
	if filepath.Separator == '\\' && string(parsedPath) != expectedPath {
		t.Fatalf("Unexpected Path: %q. Expected %q", parsedPath, expectedPath)
	}
}

func TestParseHostWithStr(t *testing.T) {
	expectUsername := "username"
	expectPassword := "password"

	testParseHostWithStr(t, "username", "", "")
	testParseHostWithStr(t, "username@", expectUsername, "")
	testParseHostWithStr(t, "username:password@", expectUsername, expectPassword)
	testParseHostWithStr(t, ":password@", "", expectPassword)
	testParseHostWithStr(t, ":password", "", "")
}

func testParseHostWithStr(t *testing.T, host, expectUsername, expectPassword string) {
	var u URI
	u.Parse([]byte(host), nil)
	assert.DeepEqual(t, expectUsername, string(u.Username()))
	assert.DeepEqual(t, expectPassword, string(u.Password()))
}

func TestParseURI(t *testing.T) {
	expectURI := "http://google.com/foo?bar=baz&baraz#qqqq"
	uri := string(ParseURI(expectURI).FullURI())
	assert.DeepEqual(t, expectURI, uri)
}

func TestSplitHostURI(t *testing.T) {
	cases := []struct {
		host, uri                      []byte
		wantScheme, wantHost, wantPath []byte
	}{
		{
			[]byte("example.com"), []byte("/foobar"),
			[]byte("http"), []byte("example.com"), []byte("/foobar"),
		},
		{
			[]byte("example2.com"), []byte("http://example2.com"),
			[]byte("http"), []byte("example2.com"), []byte("/"),
		},
		{
			[]byte("example2.com"), []byte("http://example3.com"),
			[]byte("http"), []byte("example3.com"), []byte("/"),
		},
		{
			[]byte("example3.com"), []byte("https://foobar.com?a=b"),
			[]byte("https"), []byte("foobar.com"), []byte("?a=b"),
		},
	}

	for _, c := range cases {
		gotScheme, gotHost, gotPath := splitHostURI(c.host, c.uri)
		if !bytes.Equal(gotScheme, c.wantScheme) || !bytes.Equal(gotHost, c.wantHost) || !bytes.Equal(gotPath, c.wantPath) {
			t.Errorf("splitHostURI(%q, %q) == (%q, %q, %q), want (%q, %q, %q)", c.host, c.uri, gotScheme, gotHost, gotPath, c.wantScheme, c.wantHost, c.wantPath)
		}
	}
}

================================================ FILE: pkg/protocol/uri_timing_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the
"License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package protocol import ( "testing" ) func BenchmarkURIParsePath(b *testing.B) { benchmarkURIParse(b, "google.com", "/foo/bar") } func BenchmarkURIParsePathQueryString(b *testing.B) { benchmarkURIParse(b, "google.com", "/foo/bar?query=string&other=value") } func BenchmarkURIParsePathQueryStringHash(b *testing.B) { benchmarkURIParse(b, "google.com", "/foo/bar?query=string&other=value#hashstring") } func BenchmarkURIParseHostname(b *testing.B) { benchmarkURIParse(b, "google.com", "http://foobar.com/foo/bar?query=string&other=value#hashstring") } func BenchmarkURIFullURI(b *testing.B) { host := []byte("foobar.com") requestURI := []byte("/foobar/baz?aaa=bbb&ccc=ddd") uriLen := len(host) + len(requestURI) + 7 b.RunParallel(func(pb *testing.PB) { var u URI u.Parse(host, requestURI) for pb.Next() { uri := u.FullURI() if len(uri) != uriLen { b.Fatalf("unexpected uri len %d. Expecting %d", len(uri), uriLen) } } }) } func benchmarkURIParse(b *testing.B, host, uri string) { strHost, strURI := []byte(host), []byte(uri) b.RunParallel(func(pb *testing.PB) { var u URI for pb.Next() { u.Parse(strHost, strURI) } }) } ================================================ FILE: pkg/protocol/uri_unix.go ================================================ //go:build !windows /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/ package protocol import "github.com/cloudwego/hertz/pkg/common/hlog" func addLeadingSlash(dst, src []byte) []byte { // add leading slash for unix paths if len(src) == 0 || src[0] != '/' { dst = append(dst, '/') } return dst } // checkSchemeWhenCharIsColon check url begin with : // Scenarios that handle protocols like "http:" func checkSchemeWhenCharIsColon(i int, rawURL []byte) (scheme, path []byte) { if i == 0 { hlog.Errorf("error happened when try to parse the rawURL(%s): missing protocol scheme", rawURL) return } return rawURL[:i], rawURL[i+1:] } ================================================ FILE: pkg/protocol/uri_unix_test.go ================================================ //go:build !windows /* * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package protocol import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestGetScheme(t *testing.T) { scheme, path := getScheme([]byte("https://foo.com")) assert.DeepEqual(t, "https", string(scheme)) assert.DeepEqual(t, "//foo.com", string(path)) scheme, path = getScheme([]byte(":")) assert.DeepEqual(t, "", string(scheme)) assert.DeepEqual(t, "", string(path)) scheme, path = getScheme([]byte("ws://127.0.0.1")) assert.DeepEqual(t, "ws", string(scheme)) assert.DeepEqual(t, "//127.0.0.1", string(path)) scheme, path = getScheme([]byte("/hertz/demo")) assert.DeepEqual(t, "", string(scheme)) assert.DeepEqual(t, "/hertz/demo", string(path)) } ================================================ FILE: pkg/protocol/uri_windows.go ================================================ //go:build windows /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * The MIT License (MIT) * * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package protocol import "github.com/cloudwego/hertz/pkg/common/hlog" func addLeadingSlash(dst, src []byte) []byte { // zero length and "C:/" case isDisk := len(src) > 2 && src[1] == ':' if len(src) == 0 || (!isDisk && src[0] != '/') { dst = append(dst, '/') } return dst } // checkSchemeWhenCharIsColon check url begin with : // Scenarios that handle protocols like "http:" // Add the path to the win file, e.g. "E:\gopath", "E:\". 
func checkSchemeWhenCharIsColon(i int, rawURL []byte) (scheme, path []byte) { if i == 0 { hlog.Errorf("error happened when trying to parse the rawURL(%s): missing protocol scheme", rawURL) return } // case :\ if i+1 < len(rawURL) && rawURL[i+1] == '\\' { return nil, rawURL } return rawURL[:i], rawURL[i+1:] } ================================================ FILE: pkg/protocol/uri_windows_test.go ================================================ // Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package protocol import ( "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestURIPathNormalizeIssue86(t *testing.T) { t.Parallel() var u URI testURIPathNormalize(t, &u, `a`, `/a`) testURIPathNormalize(t, &u, "/../../../../../foo", "/foo") testURIPathNormalize(t, &u, "/..\\..\\..\\..\\..\\", "/") testURIPathNormalize(t, &u, "/..%5c..%5cfoo", "/foo") } func TestGetScheme(t *testing.T) { scheme, path := getScheme([]byte("E:\\file.go")) assert.DeepEqual(t, "", string(scheme)) assert.DeepEqual(t, "E:\\file.go", string(path)) scheme, path = getScheme([]byte("E:\\")) assert.DeepEqual(t, "", string(scheme)) assert.DeepEqual(t, "E:\\", string(path)) scheme, path = getScheme([]byte("https://foo.com")) assert.DeepEqual(t, "https", string(scheme)) assert.DeepEqual(t, "//foo.com", string(path)) scheme, path = getScheme([]byte("://")) assert.DeepEqual(t, "", string(scheme)) assert.DeepEqual(t, "", string(path)) scheme, path = getScheme([]byte("ws://127.0.0.1")) assert.DeepEqual(t, "ws", string(scheme)) assert.DeepEqual(t, "//127.0.0.1", string(path)) } ================================================ FILE: pkg/route/consts/const.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ package consts import "math" const AbortIndex int8 = math.MaxInt8 / 2 ================================================ FILE: pkg/route/engine.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License (MIT) * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */

package route

import (
	"bytes"
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"html/template"
	"io"
	"net"
	"path/filepath"
	"reflect"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"

	"github.com/cloudwego/hertz/internal/bytesconv"
	"github.com/cloudwego/hertz/internal/bytestr"
	"github.com/cloudwego/hertz/internal/nocopy"
	internalStats "github.com/cloudwego/hertz/internal/stats"
	"github.com/cloudwego/hertz/pkg/app"
	"github.com/cloudwego/hertz/pkg/app/server/binding"
	"github.com/cloudwego/hertz/pkg/app/server/render"
	"github.com/cloudwego/hertz/pkg/common/config"
	errs "github.com/cloudwego/hertz/pkg/common/errors"
	"github.com/cloudwego/hertz/pkg/common/hlog"
	"github.com/cloudwego/hertz/pkg/common/tracer"
	"github.com/cloudwego/hertz/pkg/common/tracer/stats"
	"github.com/cloudwego/hertz/pkg/common/tracer/traceinfo"
	"github.com/cloudwego/hertz/pkg/common/utils"
	"github.com/cloudwego/hertz/pkg/network"
	"github.com/cloudwego/hertz/pkg/network/standard"
	"github.com/cloudwego/hertz/pkg/protocol"
	"github.com/cloudwego/hertz/pkg/protocol/consts"
	"github.com/cloudwego/hertz/pkg/protocol/http1"
	"github.com/cloudwego/hertz/pkg/protocol/http1/factory"
	"github.com/cloudwego/hertz/pkg/protocol/suite"
	"github.com/cloudwego/hertz/pkg/route/param"
)

// unknownTransporterName is the fallback name reported when a transporter's
// name cannot be derived (the derivation panics or yields an empty string).
const unknownTransporterName = "unknown"

var (
	// will be netpoll.NewTransporter if available, see: netpoll.go
	// It is the package-wide default and can be replaced via SetTransporter.
	defaultTransporter = standard.NewTransporter

	// Engine lifecycle errors.
	errInitFailed       = errs.NewPrivate("engine has been init already")
	errAlreadyRunning   = errs.NewPrivate("engine is already running")
	errStatusNotRunning = errs.NewPrivate("engine is not running")

	// Default response bodies.
	default404Body   = []byte("Not Found")
	default405Body   = []byte("Method Not Allowed")
	default400Body   = []byte("Bad Request")
	requiredHostBody = []byte("missing required Host header")
)

// hijackConn embeds a network.Conn and records the owning Engine in e.
type hijackConn struct {
	network.Conn
	e *Engine
}

// CtxCallback is a context-taking hook with no result; Engine.OnShutdown
// holds a slice of them.
type CtxCallback func(ctx context.Context)

// CtxErrCallback is a context-taking hook that may fail; Engine.OnRun holds a
// slice of them.
type CtxErrCallback func(ctx context.Context) error

// RouteInfo represents a
request route's specification which contains method and path and its handler.
type RouteInfo struct {
	Method      string
	Path        string
	Handler     string
	HandlerFunc app.HandlerFunc
}

// RoutesInfo defines a RouteInfo array.
type RoutesInfo []RouteInfo

// Engine holds a Hertz server's routing trees and handler chains, HTML render
// settings, transport, protocol servers, request-context pool, hooks and
// lifecycle status.
type Engine struct {
	noCopy nocopy.NoCopy //lint:ignore U1000 until noCopy is used

	// engine name
	Name       string
	serverName atomic.Value

	// Options for route and protocol server
	options *config.Options

	// route
	RouterGroup
	trees     MethodTrees
	maxParams uint16

	// Handler chains for unmatched routes / disallowed methods.
	allNoMethod app.HandlersChain
	allNoRoute  app.HandlersChain
	noRoute     app.HandlersChain
	noMethod    app.HandlersChain

	// For render HTML
	delims     render.Delims
	funcMap    template.FuncMap
	htmlRender render.HTMLRender

	// NoHijackConnPool controls whether hijackConn values are acquired from and
	// released to hijackConnPool; setting it to true opts out of the pool.
	// If it is difficult to guarantee that a hijackConn will not be closed
	// more than once, set it to true.
	NoHijackConnPool bool
	hijackConnPool   sync.Pool

	// KeepHijackedConns is an opt-in disable of connection
	// close by hertz after connections' HijackHandler returns.
	// This allows to save goroutines, e.g. when hertz used to upgrade
	// http connections to WS and connection goes to another handler,
	// which will close it when needed.
	KeepHijackedConns bool

	// underlying transport
	transport network.Transporter

	// trace
	tracerCtl   tracer.Controller
	enableTrace bool

	// protocol layer management
	protocolSuite         *suite.Config
	protocolServers       map[string]protocol.Server
	protocolStreamServers map[string]protocol.StreamServer

	// RequestContext pool
	ctxPool sync.Pool

	// Function to handle panics recovered from http handlers.
	// It should be used to generate an error page and return the http error code
	// 500 (Internal Server Error).
	// The handler can be used to keep your server from crashing because of
	// unrecovered panics.
	PanicHandler app.HandlerFunc

	// ContinueHandler is called after receiving the Expect 100 Continue Header
	//
	// https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3
	// https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.1.1
	// Using ContinueHandler a server can decide whether or not
	// to read a potentially large request body based on the headers
	//
	// The default is to automatically read request bodies of Expect 100 Continue requests
	// like they are normal requests
	ContinueHandler func(header *protocol.RequestHeader) bool

	// Indicates the engine status (Init/Running/Shutdown/Closed).
	status uint32

	// Hook functions triggered sequentially when the engine starts.
	OnRun []CtxErrCallback

	// Hook functions triggered concurrently when the engine shuts down.
	OnShutdown []CtxCallback

	// Custom Functions
	clientIPFunc  app.ClientIP
	formValueFunc app.FormValueFunc

	// Custom Binder
	binder binding.Binder
}

// IsTraceEnable reports whether tracing is enabled for this engine.
func (engine *Engine) IsTraceEnable() bool {
	return engine.enableTrace
}

// GetCtxPool returns a pointer to the engine's RequestContext pool.
func (engine *Engine) GetCtxPool() *sync.Pool {
	return &engine.ctxPool
}

// GetOptions returns the engine's configuration options.
func (engine *Engine) GetOptions() *config.Options {
	return engine.options
}

// SetTransporter only sets the global default value for the transporter.
// Use WithTransporter during engine creation to set the transporter for the engine.
func SetTransporter(transporter func(options *config.Options) network.Transporter) {
	defaultTransporter = transporter
}

// GetTransporterName returns the name derived from the transporter this engine
// actually uses (see getTransporterName), or "unknown" if it cannot be derived.
func (engine *Engine) GetTransporterName() (tName string) {
	return getTransporterName(engine.transport)
}

// getTransporterName derives a name from the transporter's dynamic type string:
// a type printed as "*pkg.T" yields "pkg". If reflection panics (for example,
// presumably, on a nil transporter) or the result is empty, it returns
// unknownTransporterName.
func getTransporterName(transporter network.Transporter) (tName string) {
	defer func() {
		err := recover()
		if err != nil || tName == "" {
			tName = unknownTransporterName
		}
	}()
	t := reflect.ValueOf(transporter).Type().String()
	tName = strings.Split(strings.TrimPrefix(t, "*"), ".")[0]
	return tName
}

// Deprecated: This only gets the global default transporter - may not be the real one used by the engine.
// Use engine.GetTransporterName for the real transporter used. func GetTransporterName() (tName string) { defer func() { err := recover() if err != nil || tName == "" { tName = unknownTransporterName } }() fName := runtime.FuncForPC(reflect.ValueOf(defaultTransporter).Pointer()).Name() fSlice := strings.Split(fName, "/") name := fSlice[len(fSlice)-1] fSlice = strings.Split(name, ".") tName = fSlice[0] return } func (engine *Engine) IsStreamRequestBody() bool { return engine.options.StreamRequestBody } func (engine *Engine) IsRunning() bool { if atomic.LoadUint32(&engine.status) != statusRunning { return false } // double check listener type ListenerIface interface { Listener() net.Listener } v, ok := engine.transport.(ListenerIface) if ok { return v.Listener() != nil } return true // default behavior if no ListenerIface } func (engine *Engine) HijackConnHandle(c network.Conn, h app.HijackHandler) { engine.hijackConnHandler(c, h) } func (engine *Engine) GetTracer() tracer.Controller { return engine.tracerCtl } const ( _ uint32 = iota statusInitialized statusRunning statusShutdown statusClosed ) // NewContext make a pure RequestContext without any http request/response information // // Set the Request filed before use it for handlers func (engine *Engine) NewContext() *app.RequestContext { return app.NewContext(engine.maxParams) } // Shutdown starts the server's graceful exit by next steps: // // 1. Trigger OnShutdown hooks concurrently and wait them until wait timeout or finish // 2. Close the net listener, which means new connection won't be accepted // 3. Wait all connections get closed: // One connection gets closed after reaching out the shorter time of processing // one request (in hand or next incoming), idleTimeout or ExitWaitTime // 4. 
Exit func (engine *Engine) Shutdown(ctx context.Context) (err error) { if atomic.LoadUint32(&engine.status) != statusRunning { return errStatusNotRunning } if !atomic.CompareAndSwapUint32(&engine.status, statusRunning, statusShutdown) { return } opt := engine.GetOptions() hlog.SystemLogger().Infof("Begin graceful shutdown, wait at most %s ...", opt.ExitWaitTimeout) ctx, cancel := context.WithTimeout(ctx, opt.ExitWaitTimeout) defer cancel() ch := make(chan struct{}) go func() { defer close(ch) engine.executeOnShutdownHooks(ctx) }() defer func() { // ensure that the hook is executed until wait timeout or finish select { case <-ctx.Done(): hlog.SystemLogger().Infof("Execute OnShutdownHooks timeout: error=%v", ctx.Err()) return case <-ch: hlog.SystemLogger().Info("Execute OnShutdownHooks finish") return } }() if opt.Registry != nil { if err = opt.Registry.Deregister(opt.RegistryInfo); err != nil { hlog.SystemLogger().Errorf("Deregister error=%v", err) return err } } // call transport shutdown if err := engine.transport.Shutdown(ctx); err != ctx.Err() { return err } return } func (engine *Engine) executeOnShutdownHooks(ctx context.Context) { wg := sync.WaitGroup{} for i := range engine.OnShutdown { wg.Add(1) go func(index int) { defer wg.Done() engine.OnShutdown[index](ctx) }(i) } wg.Wait() } func (engine *Engine) Run() (err error) { if err = engine.Init(); err != nil { return err } // trigger hooks if any ctx := context.Background() for i := range engine.OnRun { if err = engine.OnRun[i](ctx); err != nil { return err } } if err = engine.MarkAsRunning(); err != nil { return err } defer atomic.StoreUint32(&engine.status, statusClosed) return engine.listenAndServe() } func (engine *Engine) Init() error { // add built-in http1 server by default if !engine.HasServer(suite.HTTP1) { engine.AddProtocol(suite.HTTP1, factory.NewServerFactory(newHttp1OptionFromEngine(engine))) } serverMap, streamServerMap, err := engine.protocolSuite.LoadAll(engine) if err != nil { return 
errs.New(err, errs.ErrorTypePrivate, "LoadAll protocol suite error") } engine.protocolServers = serverMap engine.protocolStreamServers = streamServerMap if engine.alpnEnable() { engine.options.TLS.NextProtos = append(engine.options.TLS.NextProtos, suite.HTTP1) } if !atomic.CompareAndSwapUint32(&engine.status, 0, statusInitialized) { return errInitFailed } return nil } func (engine *Engine) alpnEnable() bool { return engine.options.TLS != nil && engine.options.ALPN } func (engine *Engine) listenAndServe() error { hlog.SystemLogger().Infof("Using network library=%s", engine.GetTransporterName()) return engine.transport.ListenAndServe(engine.onData) } func (c *hijackConn) Close() error { if !c.e.KeepHijackedConns { // when we do not keep hijacked connections, // it is closed in hijackConnHandler. return nil } conn := c.Conn c.e.releaseHijackConn(c) return conn.Close() } func (engine *Engine) getNextProto(conn network.Conn) (proto string, err error) { if tlsConn, ok := conn.(network.ConnTLSer); ok { if engine.options.ReadTimeout > 0 { if err := conn.SetReadTimeout(engine.options.ReadTimeout); err != nil { hlog.SystemLogger().Errorf("BUG: error in SetReadDeadline=%s: error=%s", engine.options.ReadTimeout, err) } } err = tlsConn.Handshake() if err == nil { proto = tlsConn.ConnectionState().NegotiatedProtocol } } return } func (engine *Engine) onData(c context.Context, conn interface{}) (err error) { switch conn := conn.(type) { case network.Conn: err = engine.Serve(c, conn) case network.StreamConn: err = engine.ServeStream(c, conn) } return } func logError(conn network.Conn, err error) { // Quiet close the connection if errors.Is(err, errs.ErrShortConnection) || errors.Is(err, errs.ErrIdleTimeout) { return } // Do not process the hijack connection error if errors.Is(err, errs.ErrHijacked) { return } // Get remote address rip := "" if addr := conn.RemoteAddr(); addr != nil { rip = addr.String() } // Handle Specific error if hsp, ok := conn.(network.HandleSpecificError); 
ok { if hsp.HandleSpecificError(err, rip) { return } } // other errors hlog.SystemLogger().Errorf(hlog.EngineErrorFormat, err.Error(), rip) } func (engine *Engine) Close() error { if engine.htmlRender != nil { engine.htmlRender.Close() //nolint:errcheck } return engine.transport.Close() } func (engine *Engine) GetServerName() []byte { v := engine.serverName.Load() var serverName []byte if v == nil { serverName = []byte(engine.Name) if len(serverName) == 0 { serverName = bytestr.DefaultServerName } engine.serverName.Store(serverName) } else { serverName = v.([]byte) } return serverName } func (engine *Engine) Serve(c context.Context, conn network.Conn) (err error) { defer func() { if err != nil { logError(conn, err) } // always close conn before Serve returns, // some implementations (e.g., netpoll) may reuse conn if not closed _ = conn.Close() }() // H2C path if engine.options.H2C { // protocol sniffer buf, _ := conn.Peek(len(bytestr.StrClientPreface)) if bytes.Equal(buf, bytestr.StrClientPreface) && engine.protocolServers[suite.HTTP2] != nil { return engine.protocolServers[suite.HTTP2].Serve(c, conn) } hlog.SystemLogger().Warn("HTTP2 server is not loaded, request is going to fallback to HTTP1 server") } // ALPN path if engine.options.ALPN && engine.options.TLS != nil { proto, err1 := engine.getNextProto(conn) if err1 != nil { // The client closes the connection when handshake. So just ignore it. 
if err1 == io.EOF { return nil } if re, ok := err1.(tls.RecordHeaderError); ok && re.Conn != nil && utils.TLSRecordHeaderLooksLikeHTTP(re.RecordHeader) { io.WriteString(re.Conn, "HTTP/1.0 400 Bad Request\r\n\r\nClient sent an HTTP request to an HTTPS server.\n") re.Conn.Close() return re } return err1 } if server, ok := engine.protocolServers[proto]; ok { return server.Serve(c, conn) } } // HTTP1 path err = engine.protocolServers[suite.HTTP1].Serve(c, conn) return } func (engine *Engine) ServeStream(ctx context.Context, conn network.StreamConn) error { // ALPN path if engine.options.ALPN && engine.options.TLS != nil { version := conn.GetVersion() nextProtocol := versionToALNP(version) if server, ok := engine.protocolStreamServers[nextProtocol]; ok { return server.Serve(ctx, conn) } } // default path if server, ok := engine.protocolStreamServers[suite.HTTP3]; ok { return server.Serve(ctx, conn) } return errs.ErrNotSupportProtocol } func (engine *Engine) initBinderAndValidator(opt *config.Options) { if opt.CustomBinder != nil { engine.initCustomBinder(opt.CustomBinder) return } vf := engine.initValidatorFunc(opt) if opt.BindConfig == nil { c := binding.NewBindConfig() c.ValidatorFunc = vf engine.binder = binding.NewDefaultBinder(c) return } engine.initDefaultBinder(opt.BindConfig, vf) } // initValidator initializes the validator function and returns whether a custom validator was used func (engine *Engine) initValidatorFunc(opt *config.Options) binding.ValidatorFunc { customValidator := opt.CustomValidator if customValidator == nil { conf := opt.ValidateConfig //nolint:staticcheck // Deprecated vc, ok := conf.(*binding.ValidateConfig) //nolint:staticcheck // Deprecated if !ok && conf != nil { panic(fmt.Errorf("ValidateConfig type err: %T", conf)) } if vc == nil { return nil } customValidator = binding.NewValidator(vc) //nolint:staticcheck // Deprecated } switch v := customValidator.(type) { case binding.ValidatorFunc: // ValidatorFunc (preferred approach) return v 
case func(*protocol.Request, interface{}) error: // Function with correct signature, convert to ValidatorFunc return binding.ValidatorFunc(v) case binding.StructValidator: //nolint:staticcheck // Deprecated // StructValidator (backwards compatibility) return binding.MakeValidatorFunc(v) default: panic(fmt.Errorf("customized validator type err: %T", v)) } } // initCustomBinder handles custom binder initialization func (engine *Engine) initCustomBinder(customBinder interface{}) { binder, ok := customBinder.(binding.Binder) if !ok { panic("customized binder does not implement binding.Binder") } engine.binder = binder } // initDefaultBinder initializes the default binder with optional custom config func (engine *Engine) initDefaultBinder(bindConfig interface{}, vf binding.ValidatorFunc) { bConf, ok := bindConfig.(*binding.BindConfig) if !ok { panic("opt.BindConfig is not the '*binding.BindConfig' type") } // User customized validator has the highest priority if vf != nil { bConf.ValidatorFunc = vf } engine.binder = binding.NewDefaultBinder(bConf) } func NewEngine(opt *config.Options) *Engine { engine := &Engine{ trees: make(MethodTrees, 0, 9), RouterGroup: RouterGroup{ Handlers: nil, basePath: opt.BasePath, root: true, }, transport: defaultTransporter(opt), tracerCtl: &internalStats.Controller{}, protocolServers: make(map[string]protocol.Server), protocolStreamServers: make(map[string]protocol.StreamServer), enableTrace: true, options: opt, } engine.initBinderAndValidator(opt) if opt.TransporterNewer != nil { engine.transport = opt.TransporterNewer(opt) } engine.RouterGroup.engine = engine traceLevel := initTrace(engine) // prepare RequestContext pool engine.ctxPool.New = func() interface{} { ctx := engine.allocateContext() if engine.enableTrace { ti := traceinfo.NewTraceInfo() ti.Stats().SetLevel(traceLevel) ctx.SetTraceInfo(ti) } return ctx } // Init protocolSuite engine.protocolSuite = suite.New() return engine } func initTrace(engine *Engine) stats.Level { for _, 
ti := range engine.options.Tracers { if tracer, ok := ti.(tracer.Tracer); ok { engine.tracerCtl.Append(tracer) } } if !engine.tracerCtl.HasTracer() { engine.enableTrace = false } traceLevel := stats.LevelDetailed if tl, ok := engine.options.TraceLevel.(stats.Level); ok { traceLevel = tl } return traceLevel } func debugPrintRoute(httpMethod, absolutePath string, handlers app.HandlersChain) { nuHandlers := len(handlers) handlerName := app.GetHandlerName(handlers.Last()) if handlerName == "" { handlerName = utils.NameOfFunction(handlers.Last()) } hlog.SystemLogger().Debugf("Method=%-6s absolutePath=%-25s --> handlerName=%s (num=%d handlers)", httpMethod, absolutePath, handlerName, nuHandlers) } func (engine *Engine) addRoute(method, path string, handlers app.HandlersChain) { if len(path) == 0 { panic("path should not be ''") } utils.Assert(path[0] == '/', "path must begin with '/'") utils.Assert(method != "", "HTTP method can not be empty") utils.Assert(len(handlers) > 0, "there must be at least one handler") if !engine.options.DisablePrintRoute { debugPrintRoute(method, path, handlers) } methodRouter := engine.trees.get(method) if methodRouter == nil { methodRouter = &router{method: method, root: &node{}} engine.trees = append(engine.trees, methodRouter) } methodRouter.addRoute(path, handlers) // Update maxParams if paramsCount := countParams(path); paramsCount > engine.maxParams { engine.maxParams = paramsCount } } func (engine *Engine) PrintRoute(method string) { root := engine.trees.get(method) printNode(root.root, 0) } // debug use func printNode(node *node, level int) { fmt.Println("node.prefix: " + node.prefix) fmt.Println("node.ppath: " + node.ppath) fmt.Printf("level: %#v\n\n", level) for i := 0; i < len(node.children); i++ { printNode(node.children[i], level+1) } } func (engine *Engine) recv(ctx *app.RequestContext) { if rcv := recover(); rcv != nil { engine.PanicHandler(context.Background(), ctx) } } // ServeHTTP makes the router implement the Handler 
interface. func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) { ctx.SetBinder(engine.binder) if engine.PanicHandler != nil { defer engine.recv(ctx) } rPath := string(ctx.Request.URI().Path()) // align with https://datatracker.ietf.org/doc/html/rfc2616#section-5.2 if len(ctx.Request.Host()) == 0 && ctx.Request.Header.IsHTTP11() && bytesconv.B2s(ctx.Request.Method()) != consts.MethodConnect { ctx.SetHandlers(engine.Handlers) serveError(c, ctx, consts.StatusBadRequest, requiredHostBody) return } httpMethod := bytesconv.B2s(ctx.Request.Header.Method()) unescape := false if engine.options.UseRawPath { rPath = string(ctx.Request.URI().PathOriginal()) unescape = engine.options.UnescapePathValues } if engine.options.RemoveExtraSlash { rPath = utils.CleanPath(rPath) } // Follow RFC7230#section-5.3 if rPath == "" || rPath[0] != '/' { ctx.SetHandlers(engine.Handlers) serveError(c, ctx, consts.StatusBadRequest, default400Body) return } // if Params is re-assigned in HandlerFunc and the capacity is not enough we need to realloc maxParams := int(engine.maxParams) if cap(ctx.Params) < maxParams { ctx.Params = make(param.Params, 0, maxParams) } // Find root of the tree for the given HTTP method t := engine.trees paramsPointer := &ctx.Params for i, tl := 0, len(t); i < tl; i++ { if t[i].method != httpMethod { continue } // Find route in tree value := t[i].find(rPath, paramsPointer, unescape) if value.handlers != nil { ctx.SetHandlers(value.handlers) ctx.SetFullPath(value.fullPath) ctx.Next(c) return } if httpMethod != consts.MethodConnect && rPath != "/" { if value.tsr && engine.options.RedirectTrailingSlash { redirectTrailingSlash(ctx) return } if engine.options.RedirectFixedPath && redirectFixedPath(ctx, t[i].root, engine.options.RedirectFixedPath) { return } } break } if engine.options.HandleMethodNotAllowed { for _, tree := range engine.trees { if tree.method == httpMethod { continue } if value := tree.find(rPath, paramsPointer, unescape); 
value.handlers != nil { ctx.SetHandlers(engine.allNoMethod) serveError(c, ctx, consts.StatusMethodNotAllowed, default405Body) return } } } ctx.SetHandlers(engine.allNoRoute) serveError(c, ctx, consts.StatusNotFound, default404Body) } func (engine *Engine) allocateContext() *app.RequestContext { ctx := engine.NewContext() ctx.Request.SetMaxKeepBodySize(engine.options.MaxKeepBodySize) ctx.Response.SetMaxKeepBodySize(engine.options.MaxKeepBodySize) ctx.SetClientIPFunc(engine.clientIPFunc) ctx.SetFormValueFunc(engine.formValueFunc) return ctx } func serveError(c context.Context, ctx *app.RequestContext, code int, defaultMessage []byte) { ctx.SetStatusCode(code) ctx.Next(c) if ctx.Response.StatusCode() == code { // if body exists(maybe customized by users), leave it alone. if ctx.Response.HasBodyBytes() || ctx.Response.IsBodyStream() { return } ctx.Response.Header.Set("Content-Type", "text/plain") ctx.Response.SetBody(defaultMessage) } } func trailingSlashURL(ts string) string { tmpURI := ts + "/" if length := len(ts); length > 1 && ts[length-1] == '/' { tmpURI = ts[:length-1] } return tmpURI } func redirectTrailingSlash(c *app.RequestContext) { p := bytesconv.B2s(c.Request.URI().Path()) if prefix := utils.CleanPath(bytesconv.B2s(c.Request.Header.Peek("X-Forwarded-Prefix"))); prefix != "." { p = prefix + "/" + p } tmpURI := trailingSlashURL(p) query := c.Request.URI().QueryString() if len(query) > 0 { tmpURI = tmpURI + "?" 
+ bytesconv.B2s(query) } c.Request.SetRequestURI(tmpURI) redirectRequest(c) } func redirectRequest(c *app.RequestContext) { code := consts.StatusMovedPermanently // Permanent redirect, request with GET method if bytesconv.B2s(c.Request.Header.Method()) != consts.MethodGet { code = consts.StatusTemporaryRedirect } c.Redirect(code, c.Request.URI().RequestURI()) } func redirectFixedPath(c *app.RequestContext, root *node, trailingSlash bool) bool { rPath := bytesconv.B2s(c.Request.URI().Path()) if fixedPath, ok := root.findCaseInsensitivePath(utils.CleanPath(rPath), trailingSlash); ok { c.Request.SetRequestURI(bytesconv.B2s(fixedPath)) redirectRequest(c) return true } return false } // NoRoute adds handlers for NoRoute. It returns a 404 code by default. func (engine *Engine) NoRoute(handlers ...app.HandlerFunc) { engine.noRoute = handlers engine.rebuild404Handlers() } // NoMethod sets the handlers called when the HTTP method does not match. func (engine *Engine) NoMethod(handlers ...app.HandlerFunc) { engine.noMethod = handlers engine.rebuild405Handlers() } func (engine *Engine) rebuild404Handlers() { engine.allNoRoute = engine.combineHandlers(engine.noRoute) } func (engine *Engine) rebuild405Handlers() { engine.allNoMethod = engine.combineHandlers(engine.noMethod) } // Use attaches a global middleware to the router. ie. the middleware attached though Use() will be // included in the handlers chain for every single request. Even 404, 405, static files... // // For example, this is the right place for a logger or error management middleware. func (engine *Engine) Use(middleware ...app.HandlerFunc) IRoutes { engine.RouterGroup.Use(middleware...) engine.rebuild404Handlers() engine.rebuild405Handlers() return engine } // LoadHTMLGlob loads HTML files identified by glob pattern // and associates the result with HTML renderer. func (engine *Engine) LoadHTMLGlob(pattern string) { tmpl := template.Must(template.New(""). Delims(engine.delims.Left, engine.delims.Right). 
Funcs(engine.funcMap). ParseGlob(pattern)) if engine.options.AutoReloadRender { files, err := filepath.Glob(pattern) if err != nil { hlog.SystemLogger().Errorf("LoadHTMLGlob: %v", err) return } engine.SetAutoReloadHTMLTemplate(tmpl, files) return } engine.SetHTMLTemplate(tmpl) } // LoadHTMLFiles loads a slice of HTML files // and associates the result with HTML renderer. func (engine *Engine) LoadHTMLFiles(files ...string) { tmpl := template.Must(template.New(""). Delims(engine.delims.Left, engine.delims.Right). Funcs(engine.funcMap). ParseFiles(files...)) if engine.options.AutoReloadRender { engine.SetAutoReloadHTMLTemplate(tmpl, files) return } engine.SetHTMLTemplate(tmpl) } // SetHTMLTemplate associate a template with HTML renderer. func (engine *Engine) SetHTMLTemplate(tmpl *template.Template) { engine.htmlRender = render.HTMLProduction{Template: tmpl.Funcs(engine.funcMap)} } // SetAutoReloadHTMLTemplate associate a template with HTML renderer. func (engine *Engine) SetAutoReloadHTMLTemplate(tmpl *template.Template, files []string) { engine.htmlRender = &render.HTMLDebug{ Template: tmpl, Files: files, FuncMap: engine.funcMap, Delims: engine.delims, RefreshInterval: engine.options.AutoReloadInterval, } } // SetFuncMap sets the funcMap used for template.funcMap. func (engine *Engine) SetFuncMap(funcMap template.FuncMap) { engine.funcMap = funcMap } func (engine *Engine) SetClientIPFunc(f app.ClientIP) { engine.clientIPFunc = f } func (engine *Engine) SetFormValueFunc(f app.FormValueFunc) { engine.formValueFunc = f } // Delims sets template left and right delims and returns an Engine instance. 
func (engine *Engine) Delims(left, right string) *Engine { engine.delims = render.Delims{Left: left, Right: right} return engine } func (engine *Engine) acquireHijackConn(c network.Conn) *hijackConn { if engine.NoHijackConnPool { return &hijackConn{ Conn: c, e: engine, } } v := engine.hijackConnPool.Get() if v == nil { return &hijackConn{ Conn: c, e: engine, } } hjc := v.(*hijackConn) hjc.Conn = c return hjc } func (engine *Engine) releaseHijackConn(hjc *hijackConn) { if engine.NoHijackConnPool { return } hjc.Conn = nil engine.hijackConnPool.Put(hjc) } func (engine *Engine) hijackConnHandler(c network.Conn, h app.HijackHandler) { hjc := engine.acquireHijackConn(c) h(hjc) if !engine.KeepHijackedConns { c.Close() engine.releaseHijackConn(hjc) } } // Routes returns a slice of registered routes, including some useful information, such as: // the http method, path and the handler name. func (engine *Engine) Routes() (routes RoutesInfo) { for _, tree := range engine.trees { routes = iterate(tree.method, routes, tree.root) } return routes } func (engine *Engine) AddProtocol(protocol string, factory interface{}) { engine.protocolSuite.Add(protocol, factory) } // SetAltHeader sets the value of "Alt-Svc" header for protocols other than targetProtocol. func (engine *Engine) SetAltHeader(targetProtocol, altHeaderValue string) { engine.protocolSuite.SetAltHeader(targetProtocol, altHeaderValue) } func (engine *Engine) HasServer(name string) bool { return engine.protocolSuite.Get(name) != nil } // iterate iterates the method tree by depth firstly. 
func iterate(method string, routes RoutesInfo, root *node) RoutesInfo { if len(root.handlers) > 0 { handlerFunc := root.handlers.Last() routes = append(routes, RouteInfo{ Method: method, Path: root.ppath, Handler: utils.NameOfFunction(handlerFunc), HandlerFunc: handlerFunc, }) } for _, child := range root.children { routes = iterate(method, routes, child) } if root.paramChild != nil { routes = iterate(method, routes, root.paramChild) } if root.anyChild != nil { routes = iterate(method, routes, root.anyChild) } return routes } // for built-in http1 impl only. func newHttp1OptionFromEngine(engine *Engine) *http1.Option { opt := &http1.Option{ StreamRequestBody: engine.options.StreamRequestBody, GetOnly: engine.options.GetOnly, DisablePreParseMultipartForm: engine.options.DisablePreParseMultipartForm, DisableKeepalive: engine.options.DisableKeepalive, NoDefaultServerHeader: engine.options.NoDefaultServerHeader, MaxRequestBodySize: engine.options.MaxRequestBodySize, MaxHeaderBytes: engine.options.MaxHeaderBytes, IdleTimeout: engine.options.IdleTimeout, ReadTimeout: engine.options.ReadTimeout, ServerName: engine.GetServerName(), ContinueHandler: engine.ContinueHandler, TLS: engine.options.TLS, HTMLRender: engine.htmlRender, EnableTrace: engine.IsTraceEnable(), HijackConnHandle: engine.HijackConnHandle, DisableHeaderNamesNormalizing: engine.options.DisableHeaderNamesNormalizing, NoDefaultDate: engine.options.NoDefaultDate, NoDefaultContentType: engine.options.NoDefaultContentType, } // Idle timeout of standard network must not be zero. Set it to -1 seconds if it is zero. // Due to the different triggering ways of the network library, see the actual use of this value for the detailed reasons. 
if opt.IdleTimeout == 0 && engine.GetTransporterName() == "standard" { opt.IdleTimeout = -1 } return opt } func versionToALNP(v uint32) string { if v == network.Version1 || v == network.Version2 { return suite.HTTP3 } if v == network.VersionTLS || v == network.VersionDraft29 { return suite.HTTP3Draft29 } return "" } // MarkAsRunning will mark the status of the hertz engine as "running". // Warning: do not call this method by yourself, unless you know what you are doing. func (engine *Engine) MarkAsRunning() (err error) { if !atomic.CompareAndSwapUint32(&engine.status, statusInitialized, statusRunning) { return errAlreadyRunning } return nil } ================================================ FILE: pkg/route/engine_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */ package route import ( "context" "crypto/tls" "errors" "fmt" "html/template" "io/ioutil" "net" "net/http" "sync" "sync/atomic" "testing" "time" "github.com/cloudwego/hertz/internal/test/mock/binder" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/standard" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/suite" "github.com/cloudwego/hertz/pkg/route/param" ) func TestNew_Engine(t *testing.T) { defaultTransporter = standard.NewTransporter opt := config.NewOptions([]config.Option{}) router := NewEngine(opt) assert.DeepEqual(t, "standard", router.GetTransporterName()) assert.DeepEqual(t, "/", router.basePath) assert.DeepEqual(t, router.engine, router) assert.DeepEqual(t, 0, len(router.Handlers)) } func TestNew_Engine_WithTransporter(t *testing.T) { defaultTransporter = newMockTransporter opt := config.NewOptions([]config.Option{}) router := NewEngine(opt) assert.DeepEqual(t, "route", router.GetTransporterName()) defaultTransporter = newMockTransporter opt.TransporterNewer = standard.NewTransporter router = NewEngine(opt) assert.DeepEqual(t, "standard", router.GetTransporterName()) assert.DeepEqual(t, "route", GetTransporterName()) } func TestGetTransporterName(t *testing.T) { name := getTransporterName(&fakeTransporter{}) assert.DeepEqual(t, "route", name) } func TestEngineUnescape(t *testing.T) { e := NewEngine(config.NewOptions(nil)) routes := []string{ "/*all", "/cmd/:tool/", "/src/*filepath", "/search/:query", "/info/:user/project/:project", 
"/info/:user", } for _, r := range routes { e.GET(r, func(c context.Context, ctx *app.RequestContext) { ctx.String(consts.StatusOK, ctx.Param(ctx.Query("key"))) }) } testRoutes := []struct { route string key string want string }{ {"/", "", ""}, {"/cmd/%E4%BD%A0%E5%A5%BD/", "tool", "你好"}, {"/src/some/%E4%B8%96%E7%95%8C.png", "filepath", "some/世界.png"}, {"/info/%E4%BD%A0%E5%A5%BD/project/%E4%B8%96%E7%95%8C", "user", "你好"}, {"/info/%E4%BD%A0%E5%A5%BD/project/%E4%B8%96%E7%95%8C", "project", "世界"}, } for _, tr := range testRoutes { w := performRequest(e, http.MethodGet, tr.route+"?key="+tr.key) assert.DeepEqual(t, consts.StatusOK, w.Code) assert.DeepEqual(t, tr.want, w.Body.String()) } } func TestEngineUnescapeRaw(t *testing.T) { e := NewEngine(config.NewOptions(nil)) e.options.UseRawPath = true routes := []string{ "/*all", "/cmd/:tool/", "/src/*filepath", "/search/:query", "/info/:user/project/:project", "/info/:user", } for _, r := range routes { e.GET(r, func(c context.Context, ctx *app.RequestContext) { ctx.String(consts.StatusOK, ctx.Param(ctx.Query("key"))) }) } testRoutes := []struct { route string key string want string }{ {"/", "", ""}, {"/cmd/test/", "tool", "test"}, {"/src/some/file.png", "filepath", "some/file.png"}, {"/src/some/file+test.png", "filepath", "some/file test.png"}, {"/src/some/file++++%%%%test.png", "filepath", "some/file++++%%%%test.png"}, {"/src/some/file%2Ftest.png", "filepath", "some/file/test.png"}, {"/search/someth!ng+in+ünìcodé", "query", "someth!ng in ünìcodé"}, {"/info/gordon/project/go", "user", "gordon"}, {"/info/gordon/project/go", "project", "go"}, {"/info/slash%2Fgordon", "user", "slash/gordon"}, {"/info/slash%2Fgordon/project/Project%20%231", "user", "slash/gordon"}, {"/info/slash%2Fgordon/project/Project%20%231", "project", "Project #1"}, {"/info/slash%%%%", "user", "slash%%%%"}, {"/info/slash%%%%2Fgordon/project/Project%%%%20%231", "user", "slash%%%%2Fgordon"}, {"/info/slash%%%%2Fgordon/project/Project%%%%20%231", "project", 
"Project%%%%20%231"}, } for _, tr := range testRoutes { w := performRequest(e, http.MethodGet, tr.route+"?key="+tr.key) assert.DeepEqual(t, consts.StatusOK, w.Code) assert.DeepEqual(t, tr.want, w.Body.String()) } } func TestConnectionClose(t *testing.T) { engine := NewEngine(config.NewOptions(nil)) atomic.StoreUint32(&engine.status, statusRunning) engine.Init() engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { ctx.String(consts.StatusOK, "ok") }) conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\nConnection: close\r\n\r\n") err := engine.Serve(context.Background(), conn) assert.True(t, errors.Is(err, errs.ErrShortConnection)) } func TestConnectionClose01(t *testing.T) { engine := NewEngine(config.NewOptions(nil)) atomic.StoreUint32(&engine.status, statusRunning) engine.Init() engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { ctx.SetConnectionClose() ctx.String(consts.StatusOK, "ok") }) conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") err := engine.Serve(context.Background(), conn) assert.True(t, errors.Is(err, errs.ErrShortConnection)) } func TestIdleTimeout(t *testing.T) { engine := NewEngine(config.NewOptions(nil)) engine.options.IdleTimeout = 0 atomic.StoreUint32(&engine.status, statusRunning) engine.Init() engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { time.Sleep(100 * time.Millisecond) ctx.String(consts.StatusOK, "ok") }) conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") ch := make(chan error) startCh := make(chan error) go func() { <-startCh ch <- engine.Serve(context.Background(), conn) }() close(startCh) select { case err := <-ch: if err != nil { t.Errorf("err happened: %s", err) } return case <-time.Tick(120 * time.Millisecond): t.Errorf("timeout! 
should have been finished in 120ms...") } } func TestIdleTimeout01(t *testing.T) { engine := NewEngine(config.NewOptions(nil)) engine.options.IdleTimeout = 1 * time.Second atomic.StoreUint32(&engine.status, statusRunning) engine.Init() atomic.StoreUint32(&engine.status, statusRunning) engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { time.Sleep(10 * time.Millisecond) ctx.String(consts.StatusOK, "ok") }) conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") ch := make(chan error) startCh := make(chan error) go func() { <-startCh ch <- engine.Serve(context.Background(), conn) }() close(startCh) select { case <-ch: t.Errorf("cannot return this early! should wait for at least 1s...") case <-time.Tick(1 * time.Second): return } } func TestIdleTimeout03(t *testing.T) { engine := NewEngine(config.NewOptions(nil)) engine.options.IdleTimeout = 0 engine.transport = standard.NewTransporter(engine.options) atomic.StoreUint32(&engine.status, statusRunning) engine.Init() atomic.StoreUint32(&engine.status, statusRunning) engine.GET("/foo", func(c context.Context, ctx *app.RequestContext) { time.Sleep(50 * time.Millisecond) ctx.String(consts.StatusOK, "ok") }) conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n" + "GET /foo HTTP/1.1\r\nHost: google.com\r\nConnection: close\r\n\r\n") ch := make(chan error) startCh := make(chan error) go func() { <-startCh ch <- engine.Serve(context.Background(), conn) }() close(startCh) select { case err := <-ch: if !errors.Is(err, errs.ErrShortConnection) { t.Errorf("err should be ErrShortConnection, but got %s", err) } return case <-time.Tick(200 * time.Millisecond): t.Errorf("timeout! 
should have been finished in 200ms...") } } func TestEngine_Routes(t *testing.T) { engine := NewEngine(config.NewOptions(nil)) engine.GET("/", handlerTest1) engine.GET("/user", handlerTest2) engine.GET("/user/:name/*action", handlerTest1) engine.GET("/anonymous1", func(c context.Context, ctx *app.RequestContext) {}) // TestEngine_Routes.func1 engine.POST("/user", handlerTest2) engine.POST("/user/:name/*action", handlerTest2) engine.POST("/anonymous2", func(c context.Context, ctx *app.RequestContext) {}) // TestEngine_Routes.func2 group := engine.Group("/v1") { group.GET("/user", handlerTest1) group.POST("/login", handlerTest2) } engine.Static("/static", ".") list := engine.Routes() assert.DeepEqual(t, 11, len(list)) assertRoutePresent(t, list, RouteInfo{ Method: "GET", Path: "/", Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest1", }) assertRoutePresent(t, list, RouteInfo{ Method: "GET", Path: "/user", Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest2", }) assertRoutePresent(t, list, RouteInfo{ Method: "GET", Path: "/user/:name/*action", Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest1", }) assertRoutePresent(t, list, RouteInfo{ Method: "GET", Path: "/v1/user", Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest1", }) assertRoutePresent(t, list, RouteInfo{ Method: "GET", Path: "/static/*filepath", Handler: "github.com/cloudwego/hertz/pkg/app.(*fsHandler).handleRequest-fm", }) assertRoutePresent(t, list, RouteInfo{ Method: "GET", Path: "/anonymous1", Handler: "github.com/cloudwego/hertz/pkg/route.TestEngine_Routes.func1", }) assertRoutePresent(t, list, RouteInfo{ Method: "POST", Path: "/user", Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest2", }) assertRoutePresent(t, list, RouteInfo{ Method: "POST", Path: "/user/:name/*action", Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest2", }) assertRoutePresent(t, list, RouteInfo{ Method: "POST", Path: "/anonymous2", Handler: 
"github.com/cloudwego/hertz/pkg/route.TestEngine_Routes.func2", }) assertRoutePresent(t, list, RouteInfo{ Method: "POST", Path: "/v1/login", Handler: "github.com/cloudwego/hertz/pkg/route.handlerTest2", }) assertRoutePresent(t, list, RouteInfo{ Method: "HEAD", Path: "/static/*filepath", Handler: "github.com/cloudwego/hertz/pkg/app.(*fsHandler).handleRequest-fm", }) } func handlerTest1(c context.Context, ctx *app.RequestContext) {} func handlerTest2(c context.Context, ctx *app.RequestContext) {} func assertRoutePresent(t *testing.T, gets RoutesInfo, want RouteInfo) { for _, get := range gets { if get.Path == want.Path && get.Method == want.Method && get.Handler == want.Handler { return } } t.Errorf("route not found: %v", want) } func TestGetNextProto(t *testing.T) { e := NewEngine(config.NewOptions(nil)) conn := &mockConn{} proto, err := e.getNextProto(conn) if proto != "h2" { t.Errorf("unexpected proto: %#v, expected: %#v", proto, "h2") } if err != nil { t.Errorf("unexpected error: %s", err.Error()) } } func formatAsDate(t time.Time) string { year, month, day := t.Date() return fmt.Sprintf("%d/%02d/%02d", year, month, day) } func TestRenderHtml(t *testing.T) { e := NewEngine(config.NewOptions(nil)) e.Delims("{[{", "}]}") e.SetFuncMap(template.FuncMap{ "formatAsDate": formatAsDate, }) e.LoadHTMLGlob("../common/testdata/template/htmltemplate.html") e.GET("/templateName", func(c context.Context, ctx *app.RequestContext) { ctx.HTML(http.StatusOK, "htmltemplate.html", map[string]interface{}{ "now": time.Date(2017, 0o7, 0o1, 0, 0, 0, 0, time.UTC), }) }) rr := performRequest(e, "GET", "/templateName") b, _ := ioutil.ReadAll(rr.Body) assert.DeepEqual(t, consts.StatusOK, rr.Code) assert.DeepEqual(t, []byte("

Date: 2017/07/01

"), b) assert.DeepEqual(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type")) } func TestTransporterName(t *testing.T) { SetTransporter(standard.NewTransporter) assert.DeepEqual(t, "standard", GetTransporterName()) SetTransporter(newMockTransporter) assert.DeepEqual(t, "route", GetTransporterName()) } func newMockTransporter(options *config.Options) network.Transporter { return &mockTransporter{} } type mockTransporter struct{} func (m *mockTransporter) ListenAndServe(onData network.OnData) (err error) { panic("implement me") } func (m *mockTransporter) Close() error { panic("implement me") } func (m *mockTransporter) Shutdown(ctx context.Context) error { panic("implement me") } func TestRenderHtmlOfGlobWithAutoRender(t *testing.T) { opt := config.NewOptions([]config.Option{}) opt.AutoReloadRender = true e := NewEngine(opt) e.Delims("{[{", "}]}") e.SetFuncMap(template.FuncMap{ "formatAsDate": formatAsDate, }) e.LoadHTMLGlob("../common/testdata/template/htmltemplate.html") e.GET("/templateName", func(c context.Context, ctx *app.RequestContext) { ctx.HTML(http.StatusOK, "htmltemplate.html", map[string]interface{}{ "now": time.Date(2017, 0o7, 0o1, 0, 0, 0, 0, time.UTC), }) }) rr := performRequest(e, "GET", "/templateName") b, _ := ioutil.ReadAll(rr.Body) assert.DeepEqual(t, consts.StatusOK, rr.Code) assert.DeepEqual(t, []byte("

Date: 2017/07/01

"), b) assert.DeepEqual(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type")) } func TestSetClientIPAndSetFormValue(t *testing.T) { opt := config.NewOptions([]config.Option{}) e := NewEngine(opt) e.SetClientIPFunc(func(ctx *app.RequestContext) string { return "1.1.1.1" }) e.SetFormValueFunc(func(requestContext *app.RequestContext, s string) []byte { return []byte(s) }) e.GET("/ping", func(c context.Context, ctx *app.RequestContext) { assert.DeepEqual(t, ctx.ClientIP(), "1.1.1.1") assert.DeepEqual(t, string(ctx.FormValue("key")), "key") }) _ = performRequest(e, "GET", "/ping") } func TestRenderHtmlOfFilesWithAutoRender(t *testing.T) { opt := config.NewOptions([]config.Option{}) opt.AutoReloadRender = true e := NewEngine(opt) e.Delims("{[{", "}]}") e.SetFuncMap(template.FuncMap{ "formatAsDate": formatAsDate, }) e.LoadHTMLFiles("../common/testdata/template/htmltemplate.html") e.GET("/templateName", func(c context.Context, ctx *app.RequestContext) { ctx.HTML(http.StatusOK, "htmltemplate.html", map[string]interface{}{ "now": time.Date(2017, 0o7, 0o1, 0, 0, 0, 0, time.UTC), }) }) rr := performRequest(e, "GET", "/templateName") b, _ := ioutil.ReadAll(rr.Body) assert.DeepEqual(t, consts.StatusOK, rr.Code) assert.DeepEqual(t, []byte("

Date: 2017/07/01

"), b) assert.DeepEqual(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type")) } func TestSetEngineRun(t *testing.T) { e := NewEngine(config.NewOptions(nil)) e.Init() assert.True(t, !e.IsRunning()) e.MarkAsRunning() assert.True(t, e.IsRunning()) } type mockConn struct{} func (m *mockConn) SetWriteTimeout(t time.Duration) error { // TODO implement me panic("implement me") } func (m *mockConn) ReadBinary(n int) (p []byte, err error) { panic("implement me") } func (m *mockConn) Handshake() error { return nil } func (m *mockConn) ConnectionState() tls.ConnectionState { return tls.ConnectionState{ NegotiatedProtocol: "h2", } } func (m *mockConn) SetReadTimeout(t time.Duration) error { return nil } func (m *mockConn) Read(b []byte) (n int, err error) { panic("implement me") } func (m *mockConn) Write(b []byte) (n int, err error) { panic("implement me") } func (m *mockConn) Close() error { panic("implement me") } func (m *mockConn) LocalAddr() net.Addr { panic("implement me") } func (m *mockConn) RemoteAddr() net.Addr { return &net.TCPAddr{ IP: net.ParseIP("126.0.0.5"), Port: 8888, Zone: "", } } func (m *mockConn) SetDeadline(t time.Time) error { panic("implement me") } func (m *mockConn) SetReadDeadline(t time.Time) error { panic("implement me") } func (m *mockConn) SetWriteDeadline(t time.Time) error { panic("implement me") } func (m *mockConn) Release() error { panic("implement me") } func (m *mockConn) Peek(i int) ([]byte, error) { panic("implement me") } func (m *mockConn) Skip(n int) error { panic("implement me") } func (m *mockConn) ReadByte() (byte, error) { panic("implement me") } func (m *mockConn) Next(i int) ([]byte, error) { panic("implement me") } func (m *mockConn) Len() int { panic("implement me") } func (m *mockConn) Malloc(n int) (buf []byte, err error) { panic("implement me") } func (m *mockConn) WriteBinary(b []byte) (n int, err error) { panic("implement me") } func (m *mockConn) Flush() error { panic("implement me") } type fakeTransporter 
struct{} func (f *fakeTransporter) Close() error { // TODO implement me panic("implement me") } func (f *fakeTransporter) Shutdown(ctx context.Context) error { // TODO implement me panic("implement me") } func (f *fakeTransporter) ListenAndServe(onData network.OnData) error { // TODO implement me panic("implement me") } type mockNonValidator struct{} func (m *mockNonValidator) ValidateStruct(interface{}) error { return fmt.Errorf("test mock") } func TestInitBinderAndValidator(t *testing.T) { defer func() { if r := recover(); r != nil { t.Errorf("unexpected panic, %v", r) } }() opt := config.NewOptions([]config.Option{}) bindConfig := binding.NewBindConfig() bindConfig.LooseZeroMode = true opt.BindConfig = bindConfig opt.CustomBinder = binder.NewBinder() mockValidatorFunc := binding.ValidatorFunc(func(_ *protocol.Request, _ interface{}) error { return errors.New("test mock") }) opt.CustomValidator = mockValidatorFunc NewEngine(opt) validateConfig := binding.NewValidateConfig() opt.ValidateConfig = validateConfig opt.CustomValidator = nil NewEngine(opt) } func TestInitValidatorPanic(t *testing.T) { defer func() { if r := recover(); r == nil { t.Errorf("expect a panic, but get nil") } }() opt := config.NewOptions([]config.Option{}) bindConfig := binding.NewBindConfig() bindConfig.LooseZeroMode = true opt.BindConfig = bindConfig opt.CustomValidator = &mockNonValidator{} NewEngine(opt) } func TestBindConfig(t *testing.T) { type Req struct { A int `query:"a"` } opt := config.NewOptions([]config.Option{}) bindConfig := binding.NewBindConfig() bindConfig.LooseZeroMode = false opt.BindConfig = bindConfig e := NewEngine(opt) e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) if err == nil { t.Fatal("expect an error") } }) performRequest(e, "GET", "/bind?a=") bindConfig = binding.NewBindConfig() bindConfig.LooseZeroMode = true opt.BindConfig = bindConfig e = NewEngine(opt) e.GET("/bind", func(c context.Context, ctx 
*app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) if err != nil { t.Fatal("unexpected error") } assert.DeepEqual(t, 0, req.A) }) performRequest(e, "GET", "/bind?a=") } type ValidateError struct { ErrType, FailField, Msg string } // Error implements error interface. func (e *ValidateError) Error() string { if e.Msg != "" { return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg } return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" } func TestValidateConfigSetErrorFactory(t *testing.T) { type TestValidate struct { B int `query:"b" vd:"$>100"` } opt := config.NewOptions([]config.Option{}) CustomValidateErrFunc := func(failField, msg string) error { err := ValidateError{ ErrType: "validateErr", FailField: "[validateFailField]: " + failField, Msg: "[validateErrMsg]: " + msg, } return &err } validateConfig := binding.NewValidateConfig() validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) opt.ValidateConfig = validateConfig e := NewEngine(opt) e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req TestValidate err := ctx.BindAndValidate(&req) if err == nil { t.Fatal("expect an error") } assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) }) performRequest(e, "GET", "/bind?b=1") } func TestValidateConfigSetErrorFactoryWithBindConfig(t *testing.T) { type TestValidate struct { B int `query:"b" vd:"$>100"` } opt := config.NewOptions([]config.Option{}) CustomValidateErrFunc := func(failField, msg string) error { err := ValidateError{ ErrType: "validateErr", FailField: "[validateFailField]: " + failField, Msg: "[validateErrMsg]: " + msg, } return &err } validateConfig := binding.NewValidateConfig() validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) opt.ValidateConfig = validateConfig opt.BindConfig = binding.NewBindConfig() e := NewEngine(opt) e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req TestValidate err := 
ctx.BindAndValidate(&req) if err == nil { t.Fatal("expect an error") } assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) }) performRequest(e, "GET", "/bind?b=1") } func TestCustomBinder(t *testing.T) { type Req struct { A int `query:"a"` } opt := config.NewOptions([]config.Option{}) opt.CustomBinder = binder.NewBinder() e := NewEngine(opt) e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) if err != nil { t.Fatal("unexpected error") } assert.NotEqual(t, 2, req.A) }) performRequest(e, "GET", "/bind?a=2") } func TestValidateRegValidateFunc(t *testing.T) { type Req struct { A int `query:"a" vd:"f($)"` } opt := config.NewOptions([]config.Option{}) validateConfig := &binding.ValidateConfig{} validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { return fmt.Errorf("test error") }) e := NewEngine(opt) e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) assert.NotNil(t, err) assert.DeepEqual(t, "test error", err.Error()) }) performRequest(e, "GET", "/validate?a=2") } func TestCustomValidator(t *testing.T) { type Req struct { A int `query:"a" vd:"d($)"` } opt := config.NewOptions([]config.Option{}) validateConfig := &binding.ValidateConfig{} validateConfig.MustRegValidateFunc("d", func(args ...interface{}) error { return fmt.Errorf("test error") }) opt.CustomValidator = binding.ValidatorFunc(func(_ *protocol.Request, _ interface{}) error { return errors.New("test mock") }) e := NewEngine(opt) e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) assert.NotNil(t, err) assert.DeepEqual(t, "test mock", err.Error()) }) performRequest(e, "GET", "/validate?a=2") } func TestCustomValidatorFunc(t *testing.T) { type Req struct { A int `query:"a" vd:"$>10"` } validatorFunc := func(req *protocol.Request, v any) error { return 
fmt.Errorf("test validator func") } opt := config.NewOptions([]config.Option{}) opt.CustomValidator = validatorFunc e := NewEngine(opt) e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) assert.NotNil(t, err) assert.DeepEqual(t, "test validator func", err.Error()) }) performRequest(e, "GET", "/validate?a=2") } func TestWithCustomValidatorFunc(t *testing.T) { type Req struct { A int `query:"a" vd:"$>10"` } validatorFunc := func(req *protocol.Request, v any) error { return fmt.Errorf("test with custom validator func") } opt := config.NewOptions([]config.Option{}) opt.CustomValidator = validatorFunc e := NewEngine(opt) e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) assert.NotNil(t, err) assert.DeepEqual(t, "test with custom validator func", err.Error()) }) performRequest(e, "GET", "/validate?a=2") } func TestCustomValidatorInvalidType(t *testing.T) { defer func() { if r := recover(); r == nil { t.Fatal("expected panic for invalid validator type") } }() opt := config.NewOptions([]config.Option{}) opt.CustomValidator = "invalid validator type" NewEngine(opt) } func TestWithCustomValidatorConversion(t *testing.T) { type Req struct { A int `query:"a" vd:"$>10"` } // Create a config using the deprecated WithCustomValidator function opt := config.NewOptions([]config.Option{}) // Simulate using the WithCustomValidator function withValidatorOpt := config.Option{F: func(o *config.Options) { o.CustomValidator = binding.ValidatorFunc(func(_ *protocol.Request, _ interface{}) error { return errors.New("test mock") }) }} withValidatorOpt.F(opt) // Verify it was converted to ValidatorFunc _, isValidatorFunc := opt.CustomValidator.(binding.ValidatorFunc) assert.True(t, isValidatorFunc) e := NewEngine(opt) e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) assert.NotNil(t, err) 
assert.DeepEqual(t, "test mock", err.Error()) }) performRequest(e, "GET", "/validate?a=2") } var errTestDeregsitry = fmt.Errorf("test deregsitry error") type mockDeregsitryErr struct{} var _ registry.Registry = &mockDeregsitryErr{} func (e mockDeregsitryErr) Register(*registry.Info) error { return nil } func (e mockDeregsitryErr) Deregister(*registry.Info) error { return errTestDeregsitry } type mockStandardTransporter struct { network.Transporter } func (m *mockStandardTransporter) Shutdown(ctx context.Context) error { // FIXME: standard.Transporter mindlessly blocks on ctx.Done() // This change help tests run faster newctx, cancel := context.WithCancel(ctx) cancel() return m.Transporter.Shutdown(newctx) } func newMockStandardTransporter(opt *config.Options) network.Transporter { return &mockStandardTransporter{standard.NewTransporter(opt)} } func TestEngineShutdown(t *testing.T) { opt := config.NewOptions(nil) opt.Addr = "127.0.0.1:0" opt.TransporterNewer = newMockStandardTransporter mockCtxCallback := func(ctx context.Context) { // Shutdown adds `ExitWaitTimeout` to the given context dl, ok := ctx.Deadline() assert.Assert(t, ok) assert.Assert(t, opt.ExitWaitTimeout-time.Until(dl) < 50*time.Millisecond, // runtime schedule latency opt.ExitWaitTimeout, time.Until(dl)) } var wg sync.WaitGroup var engine *Engine runEngine := func() { wg.Add(1) go func() { defer wg.Done() engine.Run() }() // wait for engine to start time.Sleep(100 * time.Millisecond) } shutdownEngine := func(ctx context.Context, expectErr error, expectStatus uint32) { t.Helper() err := engine.Shutdown(ctx) if expectErr == nil { assert.Nil(t, err) } else { assert.DeepEqual(t, expectErr, err) } if expectStatus != 0 { if expectStatus == statusShutdown { assert.DeepEqual(t, expectStatus, atomic.LoadUint32(&engine.status)) // make sure engine.Run() returns // in case registry fails, it blocks engine.transport.Shutdown(ctx) expectStatus = statusClosed } wg.Wait() // wait engine.Run() returns } 
assert.DeepEqual(t, expectStatus, atomic.LoadUint32(&engine.status)) } // case: serve not running error engine = NewEngine(opt) shutdownEngine(context.Background(), errStatusNotRunning, 0) // case: serve successfully running and shutdown engine = NewEngine(opt) engine.OnShutdown = []CtxCallback{mockCtxCallback} runEngine() shutdownEngine(context.Background(), nil, statusClosed) // case: serve successfully running and shutdown with deregistry error engine = NewEngine(opt) engine.OnShutdown = []CtxCallback{mockCtxCallback} engine.options.Registry = &mockDeregsitryErr{} runEngine() shutdownEngine(context.Background(), errTestDeregsitry, statusShutdown) engine.options.Registry = nil // case: ctx cancelled when Shutdown engine = NewEngine(opt) // make sure callback is in progress but ctx cancelled engine.OnShutdown = []CtxCallback{func(ctx context.Context) { time.Sleep(50 * time.Millisecond) }} runEngine() ctx, cancel := context.WithCancel(context.Background()) cancel() shutdownEngine(ctx, nil, statusClosed) } type mockStreamer struct{} type mockProtocolServer struct{} func (s *mockStreamer) Serve(c context.Context, conn network.StreamConn) error { return nil } func (s *mockProtocolServer) Serve(c context.Context, conn network.Conn) error { return nil } type mockStreamConn struct { network.StreamConn version string } var _ network.StreamConn = &mockStreamConn{} func (m *mockStreamConn) GetVersion() uint32 { return network.Version1 } func TestEngineServeStream(t *testing.T) { engine := &Engine{ options: &config.Options{ ALPN: true, TLS: &tls.Config{}, }, protocolStreamServers: map[string]protocol.StreamServer{ suite.HTTP3: &mockStreamer{}, }, } // Test ALPN path conn := &mockStreamConn{version: suite.HTTP3} err := engine.ServeStream(context.Background(), conn) assert.Nil(t, err) // Test default path engine.options.ALPN = false conn = &mockStreamConn{} err = engine.ServeStream(context.Background(), conn) assert.Nil(t, err) // Test unsupported protocol 
engine.protocolStreamServers = map[string]protocol.StreamServer{} conn = &mockStreamConn{} err = engine.ServeStream(context.Background(), conn) assert.DeepEqual(t, errs.ErrNotSupportProtocol, err) } func TestEngineServe(t *testing.T) { engine := NewEngine(config.NewOptions(nil)) engine.protocolServers[suite.HTTP1] = &mockProtocolServer{} engine.protocolServers[suite.HTTP2] = &mockProtocolServer{} // test H2C path ctx := context.Background() conn := mock.NewConn("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") engine.options.H2C = true err := engine.Serve(ctx, conn) assert.Nil(t, err) // test ALPN path ctx = context.Background() conn = mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") engine.options.H2C = false engine.options.ALPN = true engine.options.TLS = &tls.Config{} err = engine.Serve(ctx, conn) assert.Nil(t, err) // test HTTP1 path engine.options.ALPN = false err = engine.Serve(ctx, conn) assert.Nil(t, err) } func TestOndata(t *testing.T) { ctx := context.Background() engine := NewEngine(config.NewOptions(nil)) // test stream conn streamConn := &mockStreamConn{version: suite.HTTP3} engine.protocolStreamServers[suite.HTTP3] = &mockStreamer{} err := engine.onData(ctx, streamConn) assert.Nil(t, err) // test conn conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") engine.protocolServers[suite.HTTP1] = &mockProtocolServer{} err = engine.onData(ctx, conn) assert.Nil(t, err) } func TestAcquireHijackConn(t *testing.T) { engine := &Engine{ NoHijackConnPool: false, } // test conn pool conn := mock.NewConn("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") hijackConn := engine.acquireHijackConn(conn) assert.NotNil(t, hijackConn) assert.NotNil(t, hijackConn.Conn) assert.DeepEqual(t, engine, hijackConn.e) assert.DeepEqual(t, conn, hijackConn.Conn) // test no conn pool engine.NoHijackConnPool = true hijackConn = engine.acquireHijackConn(conn) assert.NotNil(t, hijackConn) assert.NotNil(t, hijackConn.Conn) assert.DeepEqual(t, engine, hijackConn.e) 
assert.DeepEqual(t, conn, hijackConn.Conn) } func TestHandleParamsReassignInHandleFunc(t *testing.T) { e := NewEngine(config.NewOptions(nil)) routes := []string{ "/:a/:b/:c", } for _, r := range routes { e.GET(r, func(c context.Context, ctx *app.RequestContext) { ctx.Params = make([]param.Param, 1) ctx.String(consts.StatusOK, "") }) } testRoutes := []string{ "/aaa/bbb/ccc", "/asd/alskja/alkdjad", "/asd/alskja/alkdjad", "/asd/alskja/alkdjad", "/asd/alskja/alkdjad", "/alksjdlakjd/ooo/askda", "/alksjdlakjd/ooo/askda", "/alksjdlakjd/ooo/askda", } ctx := e.ctxPool.Get().(*app.RequestContext) for _, tr := range testRoutes { r := protocol.NewRequest(http.MethodGet, tr, nil) r.CopyTo(&ctx.Request) e.ServeHTTP(context.Background(), ctx) ctx.ResetWithoutConn() } } ================================================ FILE: pkg/route/netpoll.go ================================================ // Copyright 2022 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// //go:build (amd64 || arm64) && (linux || darwin) package route import ( "os" "strconv" "github.com/cloudwego/hertz/pkg/network/netpoll" ) func init() { if v, _ := strconv.ParseBool(os.Getenv("HERTZ_NO_NETPOLL")); !v { defaultTransporter = netpoll.NewTransporter } } ================================================ FILE: pkg/route/param/param.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */ package param // Param is a single URL parameter, consisting of a key and a value. type Param struct { Key string Value string } // Params is a Param-slice, as returned by the router. // The slice is ordered, the first URL parameter is also the first slice value. // It is therefore safe to read values by the index. type Params []Param // Get returns the value of the first Param which key matches the given name. // If no matching Param is found, an empty string is returned. func (ps Params) Get(name string) (string, bool) { for _, entry := range ps { if entry.Key == name { return entry.Value, true } } return "", false } // ByName returns the value of the first Param which key matches the given name. // If no matching Param is found, an empty string is returned. func (ps Params) ByName(name string) (va string) { va, _ = ps.Get(name) return } ================================================ FILE: pkg/route/routergroup.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */ package route import ( "context" "path" "regexp" "strings" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/protocol/consts" rConsts "github.com/cloudwego/hertz/pkg/route/consts" ) // IRouter defines all router handle interface includes single and group router. type IRouter interface { IRoutes Group(string, ...app.HandlerFunc) *RouterGroup } // IRoutes defines all router handle interface. 
type IRoutes interface { Use(...app.HandlerFunc) IRoutes Handle(string, string, ...app.HandlerFunc) IRoutes Any(string, ...app.HandlerFunc) IRoutes GET(string, ...app.HandlerFunc) IRoutes POST(string, ...app.HandlerFunc) IRoutes DELETE(string, ...app.HandlerFunc) IRoutes PATCH(string, ...app.HandlerFunc) IRoutes PUT(string, ...app.HandlerFunc) IRoutes OPTIONS(string, ...app.HandlerFunc) IRoutes HEAD(string, ...app.HandlerFunc) IRoutes StaticFile(string, string) IRoutes Static(string, string) IRoutes StaticFS(string, *app.FS) IRoutes } // RouterGroup is used internally to configure router, a RouterGroup is associated with // a prefix and an array of handlers (middleware). type RouterGroup struct { Handlers app.HandlersChain basePath string engine *Engine root bool } var _ IRouter = (*RouterGroup)(nil) // Use adds middleware to the group, see example code in GitHub. func (group *RouterGroup) Use(middleware ...app.HandlerFunc) IRoutes { group.Handlers = append(group.Handlers, middleware...) return group.returnObj() } // Group creates a new router group. You should add all the routes that have common middlewares or the same path prefix. // For example, all the routes that use a common middleware for authorization could be grouped. func (group *RouterGroup) Group(relativePath string, handlers ...app.HandlerFunc) *RouterGroup { return &RouterGroup{ Handlers: group.combineHandlers(handlers), basePath: group.calculateAbsolutePath(relativePath), engine: group.engine, } } // BasePath returns the base path of router group. // For example, if v := router.Group("/rest/n/v1/api"), v.BasePath() is "/rest/n/v1/api". 
func (group *RouterGroup) BasePath() string { return group.basePath } func (group *RouterGroup) handle(httpMethod, relativePath string, handlers app.HandlersChain) IRoutes { absolutePath := group.calculateAbsolutePath(relativePath) handlers = group.combineHandlers(handlers) group.engine.addRoute(httpMethod, absolutePath, handlers) return group.returnObj() } var upperLetterReg = regexp.MustCompile("^[A-Z]+$") // Handle registers a new request handle and middleware with the given path and method. // The last handler should be the real handler, the other ones should be middleware that can and should be shared among different routes. // See the example code in GitHub. // // For GET, POST, PUT, PATCH and DELETE requests the respective shortcut // functions can be used. // // This function is intended for bulk loading and to allow the usage of less // frequently used, non-standardized or custom methods (e.g. for internal // communication with a proxy). func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers ...app.HandlerFunc) IRoutes { if matches := upperLetterReg.MatchString(httpMethod); !matches { panic("http method " + httpMethod + " is not valid") } return group.handle(httpMethod, relativePath, handlers) } // POST is a shortcut for router.Handle("POST", path, handle). func (group *RouterGroup) POST(relativePath string, handlers ...app.HandlerFunc) IRoutes { return group.handle(consts.MethodPost, relativePath, handlers) } // GET is a shortcut for router.Handle("GET", path, handle). func (group *RouterGroup) GET(relativePath string, handlers ...app.HandlerFunc) IRoutes { return group.handle(consts.MethodGet, relativePath, handlers) } // DELETE is a shortcut for router.Handle("DELETE", path, handle). func (group *RouterGroup) DELETE(relativePath string, handlers ...app.HandlerFunc) IRoutes { return group.handle(consts.MethodDelete, relativePath, handlers) } // PATCH is a shortcut for router.Handle("PATCH", path, handle). 
func (group *RouterGroup) PATCH(relativePath string, handlers ...app.HandlerFunc) IRoutes {
	const method = consts.MethodPatch
	return group.handle(method, relativePath, handlers)
}

// PUT is a shortcut for router.Handle("PUT", path, handle).
func (group *RouterGroup) PUT(relativePath string, handlers ...app.HandlerFunc) IRoutes {
	const method = consts.MethodPut
	return group.handle(method, relativePath, handlers)
}

// OPTIONS is a shortcut for router.Handle("OPTIONS", path, handle).
func (group *RouterGroup) OPTIONS(relativePath string, handlers ...app.HandlerFunc) IRoutes {
	const method = consts.MethodOptions
	return group.handle(method, relativePath, handlers)
}

// HEAD is a shortcut for router.Handle("HEAD", path, handle).
func (group *RouterGroup) HEAD(relativePath string, handlers ...app.HandlerFunc) IRoutes {
	const method = consts.MethodHead
	return group.handle(method, relativePath, handlers)
}

// Any registers the same handler chain on relativePath for every one of:
// GET, POST, PUT, PATCH, HEAD, OPTIONS, DELETE, CONNECT, TRACE
// (registered in that order).
func (group *RouterGroup) Any(relativePath string, handlers ...app.HandlerFunc) IRoutes {
	for _, method := range []string{
		consts.MethodGet,
		consts.MethodPost,
		consts.MethodPut,
		consts.MethodPatch,
		consts.MethodHead,
		consts.MethodOptions,
		consts.MethodDelete,
		consts.MethodConnect,
		consts.MethodTrace,
	} {
		group.handle(method, relativePath, handlers)
	}
	return group.returnObj()
}

// StaticFile registers a single route in order to Serve a single file of the local filesystem.
// router.StaticFile("favicon.ico", "./resources/favicon.ico") func (group *RouterGroup) StaticFile(relativePath, filepath string) IRoutes { if strings.Contains(relativePath, ":") || strings.Contains(relativePath, "*") { panic("URL parameters can not be used when serving a static file") } handler := func(c context.Context, ctx *app.RequestContext) { ctx.File(filepath) } group.GET(relativePath, handler) group.HEAD(relativePath, handler) return group.returnObj() } // Static serves files from the given file system root. // To use the operating system's file system implementation, // use : // // router.Static("/static", "/var/www") func (group *RouterGroup) Static(relativePath, root string) IRoutes { return group.StaticFS(relativePath, &app.FS{Root: root}) } // StaticFS works just like `Static()` but a custom `FS` can be used instead. func (group *RouterGroup) StaticFS(relativePath string, fs *app.FS) IRoutes { if strings.Contains(relativePath, ":") || strings.Contains(relativePath, "*") { panic("URL parameters can not be used when serving a static folder") } handler := fs.NewRequestHandler() urlPattern := path.Join(relativePath, "/*filepath") // Register GET and HEAD handlers group.GET(urlPattern, handler) group.HEAD(urlPattern, handler) return group.returnObj() } func (group *RouterGroup) combineHandlers(handlers app.HandlersChain) app.HandlersChain { finalSize := len(group.Handlers) + len(handlers) if finalSize >= int(rConsts.AbortIndex) { panic("too many handlers") } mergedHandlers := make(app.HandlersChain, finalSize) copy(mergedHandlers, group.Handlers) copy(mergedHandlers[len(group.Handlers):], handlers) return mergedHandlers } func (group *RouterGroup) calculateAbsolutePath(relativePath string) string { return joinPaths(group.basePath, relativePath) } func (group *RouterGroup) returnObj() IRoutes { if group.root { return group.engine } return group } // GETEX adds a handlerName param. 
When handler is decorated or handler is an anonymous function, // Hertz cannot get handler name directly. In this case, pass handlerName explicitly. func (group *RouterGroup) GETEX(relativePath string, handler app.HandlerFunc, handlerName string) IRoutes { app.SetHandlerName(handler, handlerName) return group.GET(relativePath, handler) } // POSTEX adds a handlerName param. When handler is decorated or handler is an anonymous function, // Hertz cannot get handler name directly. In this case, pass handlerName explicitly. func (group *RouterGroup) POSTEX(relativePath string, handler app.HandlerFunc, handlerName string) IRoutes { app.SetHandlerName(handler, handlerName) return group.POST(relativePath, handler) } // PUTEX adds a handlerName param. When handler is decorated or handler is an anonymous function, // Hertz cannot get handler name directly. In this case, pass handlerName explicitly. func (group *RouterGroup) PUTEX(relativePath string, handler app.HandlerFunc, handlerName string) IRoutes { app.SetHandlerName(handler, handlerName) return group.PUT(relativePath, handler) } // DELETEEX adds a handlerName param. When handler is decorated or handler is an anonymous function, // Hertz cannot get handler name directly. In this case, pass handlerName explicitly. func (group *RouterGroup) DELETEEX(relativePath string, handler app.HandlerFunc, handlerName string) IRoutes { app.SetHandlerName(handler, handlerName) return group.DELETE(relativePath, handler) } // HEADEX adds a handlerName param. When handler is decorated or handler is an anonymous function, // Hertz cannot get handler name directly. In this case, pass handlerName explicitly. func (group *RouterGroup) HEADEX(relativePath string, handler app.HandlerFunc, handlerName string) IRoutes { app.SetHandlerName(handler, handlerName) return group.HEAD(relativePath, handler) } // AnyEX adds a handlerName param. When handler is decorated or handler is an anonymous function, // Hertz cannot get handler name directly. 
In this case, pass handlerName explicitly. func (group *RouterGroup) AnyEX(relativePath string, handler app.HandlerFunc, handlerName string) IRoutes { app.SetHandlerName(handler, handlerName) return group.Any(relativePath, handler) } // HandleEX adds a handlerName param. When handler is decorated or handler is an anonymous function, // Hertz cannot get handler name directly. In this case, pass handlerName explicitly. func (group *RouterGroup) HandleEX(httpMethod, relativePath string, handler app.HandlerFunc, handlerName string) IRoutes { app.SetHandlerName(handler, handlerName) return group.Handle(httpMethod, relativePath, handler) } func joinPaths(absolutePath, relativePath string) string { if relativePath == "" { return absolutePath } finalPath := path.Join(absolutePath, relativePath) appendSlash := lastChar(relativePath) == '/' && lastChar(finalPath) != '/' if appendSlash { return finalPath + "/" } return finalPath } func lastChar(str string) uint8 { if str == "" { panic("The length of the string can't be 0") } return str[len(str)-1] } ================================================ FILE: pkg/route/routergroup_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */ package route import ( "context" "io/ioutil" "net/http" "os" "testing" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" ) func TestRouterGroupBasic(t *testing.T) { cfg := config.NewOptions(nil) router := NewEngine(cfg) group := router.Group("/hola", func(c context.Context, ctx *app.RequestContext) {}) group.Use(func(c context.Context, ctx *app.RequestContext) {}) assert.DeepEqual(t, len(group.Handlers), 2) assert.DeepEqual(t, "/hola", group.BasePath()) assert.DeepEqual(t, router, group.engine) group2 := group.Group("manu") group2.Use(func(c context.Context, ctx *app.RequestContext) {}, func(c context.Context, ctx *app.RequestContext) {}) assert.DeepEqual(t, len(group2.Handlers), 4) assert.DeepEqual(t, "/hola/manu", group2.BasePath()) assert.DeepEqual(t, router, group2.engine) } func TestRouterGroupBasicHandle(t *testing.T) { performRequestInGroup(t, http.MethodGet) performRequestInGroup(t, http.MethodPost) performRequestInGroup(t, http.MethodPut) performRequestInGroup(t, http.MethodPatch) performRequestInGroup(t, http.MethodDelete) performRequestInGroup(t, http.MethodHead) performRequestInGroup(t, http.MethodOptions) } func performRequestInGroup(t *testing.T, method string) { router := NewEngine(config.NewOptions(nil)) v1 := router.Group("v1", func(c context.Context, ctx *app.RequestContext) {}) assert.DeepEqual(t, "/v1", v1.BasePath()) login := v1.Group("/login/", func(c context.Context, ctx *app.RequestContext) {}, func(c context.Context, ctx *app.RequestContext) {}) assert.DeepEqual(t, "/v1/login/", login.BasePath()) handler := func(c context.Context, ctx *app.RequestContext) { ctx.String(http.StatusBadRequest, "the method was %s and index %d", string(ctx.Request.Header.Method()), ctx.GetIndex()) } switch method { case http.MethodGet: v1.GET("/test", handler) login.GET("/test", handler) case http.MethodPost: 
v1.POST("/test", handler) login.POST("/test", handler) case http.MethodPut: v1.PUT("/test", handler) login.PUT("/test", handler) case http.MethodPatch: v1.PATCH("/test", handler) login.PATCH("/test", handler) case http.MethodDelete: v1.DELETE("/test", handler) login.DELETE("/test", handler) case http.MethodHead: v1.HEAD("/test", handler) login.HEAD("/test", handler) case http.MethodOptions: v1.OPTIONS("/test", handler) login.OPTIONS("/test", handler) default: panic("unknown method") } w := performRequest(router, method, "/v1/login/test") assert.DeepEqual(t, http.StatusBadRequest, w.Code) assert.DeepEqual(t, "the method was "+method+" and index 3", w.Body.String()) w = performRequest(router, method, "/v1/test") assert.DeepEqual(t, http.StatusBadRequest, w.Code) assert.DeepEqual(t, "the method was "+method+" and index 1", w.Body.String()) } func TestRouterGroupStatic(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.Static("/", ".") w := performRequest(router, "GET", "/engine.go") fd, err := os.Open("./engine.go") if err != nil { panic(err) } assert.DeepEqual(t, http.StatusOK, w.Code) defer fd.Close() content, err := ioutil.ReadAll(fd) if err != nil { panic(err) } assert.DeepEqual(t, string(content), w.Body.String()) } func TestRouterGroupStaticFile(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.StaticFile("file", "./engine.go") w := performRequest(router, "GET", "/file") assert.DeepEqual(t, http.StatusOK, w.Code) fd, err := os.Open("./engine.go") if err != nil { panic(err) } defer fd.Close() content, err := ioutil.ReadAll(fd) if err != nil { panic(err) } assert.DeepEqual(t, string(content), w.Body.String()) } func TestRouterGroupInvalidStatic(t *testing.T) { router := &RouterGroup{ Handlers: nil, basePath: "/", root: true, } assert.Panic(t, func() { router.Static("/path/:param", "/") }) assert.Panic(t, func() { router.Static("/path/*param", "/") }) } func TestRouterGroupInvalidStaticFile(t *testing.T) { router := &RouterGroup{ 
Handlers: nil, basePath: "/", root: true, } assert.Panic(t, func() { router.StaticFile("/path/:param", "favicon.ico") }) assert.Panic(t, func() { router.StaticFile("/path/*param", "favicon.ico") }) } func TestRouterGroupTooManyHandlers(t *testing.T) { engine := NewEngine(config.NewOptions(nil)) handlers1 := make([]app.HandlerFunc, 40) engine.Use(handlers1...) handlers2 := make([]app.HandlerFunc, 26) assert.Panic(t, func() { engine.Use(handlers2...) }) assert.Panic(t, func() { engine.GET("/", handlers2...) }) } func TestRouterGroupBadMethod(t *testing.T) { router := &RouterGroup{ Handlers: nil, basePath: "/", root: true, } assert.Panic(t, func() { router.Handle(http.MethodGet, "/") }) assert.Panic(t, func() { router.Handle(" GET", "/") }) assert.Panic(t, func() { router.Handle("GET ", "/") }) assert.Panic(t, func() { router.Handle("", "/") }) assert.Panic(t, func() { router.Handle("PO ST", "/") }) assert.Panic(t, func() { router.Handle("1GET", "/") }) assert.Panic(t, func() { router.Handle("PATCh", "/") }) } func TestRouterGroupPipeline(t *testing.T) { opt := config.NewOptions([]config.Option{}) router := NewEngine(opt) testRoutesInterface(t, router) v1 := router.Group("/v1") testRoutesInterface(t, v1) } func testRoutesInterface(t *testing.T, r IRoutes) { handler := func(c context.Context, ctx *app.RequestContext) {} assert.DeepEqual(t, r, r.Use(handler)) assert.DeepEqual(t, r, r.Handle(http.MethodGet, "/handler", handler)) assert.DeepEqual(t, r, r.Any("/any", handler)) assert.DeepEqual(t, r, r.GET("/", handler)) assert.DeepEqual(t, r, r.POST("/", handler)) assert.DeepEqual(t, r, r.DELETE("/", handler)) assert.DeepEqual(t, r, r.PATCH("/", handler)) assert.DeepEqual(t, r, r.PUT("/", handler)) assert.DeepEqual(t, r, r.OPTIONS("/", handler)) assert.DeepEqual(t, r, r.HEAD("/", handler)) assert.DeepEqual(t, r, r.StaticFile("/file", ".")) assert.DeepEqual(t, r, r.Static("/static", ".")) assert.DeepEqual(t, r, r.StaticFS("/static2", &app.FS{})) } 
================================================ FILE: pkg/route/routes_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. 
All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */ package route import ( "context" "fmt" "io/ioutil" "net/http" "net/http/httptest" "os" "path/filepath" "strings" "testing" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" ) type header struct { Key string Value string } func performRequest(e *Engine, method, path string, headers ...header) *httptest.ResponseRecorder { ctx := e.ctxPool.Get().(*app.RequestContext) ctx.HTMLRender = e.htmlRender r := protocol.NewRequest(method, path, nil) r.CopyTo(&ctx.Request) for _, v := range headers { ctx.Request.Header.Add(v.Key, v.Value) } e.ServeHTTP(context.Background(), ctx) w := httptest.NewRecorder() h := w.Header() ctx.Response.Header.VisitAll(func(key, value []byte) { h.Add(string(key), string(value)) }) w.WriteHeader(ctx.Response.StatusCode()) if _, err := w.Write(ctx.Response.Body()); err != nil { panic(err.Error()) } ctx.Reset() e.ctxPool.Put(ctx) return w } func testRouteOK(method string, t *testing.T) { passed := false passedAny := false r := NewEngine(config.NewOptions(nil)) r.Any("/test2", func(c context.Context, ctx *app.RequestContext) { passedAny = true }) r.Handle(method, "/test", func(c context.Context, ctx *app.RequestContext) { passed = true }) w := performRequest(r, method, "/test") assert.DeepEqual(t, true, passed) assert.DeepEqual(t, consts.StatusOK, w.Code) performRequest(r, method, "/test2") assert.DeepEqual(t, true, passedAny) } // TestSingleRouteOK tests that POST route is correctly invoked. 
func testRouteNotOK(method string, t *testing.T) { passed := false router := NewEngine(config.NewOptions(nil)) router.Handle(method, "/test_2", func(c context.Context, ctx *app.RequestContext) { passed = true }) w := performRequest(router, method, "/test") assert.DeepEqual(t, false, passed) assert.DeepEqual(t, consts.StatusNotFound, w.Code) } // TestSingleRouteOK tests that POST route is correctly invoked. func testRouteNotOK2(method string, t *testing.T) { passed := false router := NewEngine(config.NewOptions(nil)) router.options.HandleMethodNotAllowed = true var methodRoute string if method == consts.MethodPost { methodRoute = consts.MethodGet } else { methodRoute = consts.MethodPost } router.Handle(methodRoute, "/test", func(c context.Context, ctx *app.RequestContext) { passed = true }) w := performRequest(router, method, "/test") assert.DeepEqual(t, false, passed) assert.DeepEqual(t, consts.StatusMethodNotAllowed, w.Code) } func testRouteNotOK3(method string, t *testing.T) { passed := false router := NewEngine(config.NewOptions(nil)) router.Handle("GET", "/api/v:version/product/local/products/list", func(c context.Context, ctx *app.RequestContext) { passed = true }) router.Handle("GET", "/api/v:version/product/product_creation/preload_all_categories", func(c context.Context, ctx *app.RequestContext) { passed = true }) w := performRequest(router, method, "/api/v1/product/products/activate") assert.DeepEqual(t, false, passed) assert.DeepEqual(t, consts.StatusNotFound, w.Code) } func TestRouterMethod(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.PUT("/hey2", func(c context.Context, ctx *app.RequestContext) { ctx.String(consts.StatusOK, "sup2") }) router.PUT("/hey", func(c context.Context, ctx *app.RequestContext) { ctx.String(consts.StatusOK, "called") }) router.PUT("/hey3", func(c context.Context, ctx *app.RequestContext) { ctx.String(consts.StatusOK, "sup3") }) w := performRequest(router, consts.MethodPut, "/hey") assert.DeepEqual(t, 
consts.StatusOK, w.Code) assert.DeepEqual(t, "called", w.Body.String()) } func TestRouterGroupRouteOK(t *testing.T) { testRouteOK(consts.MethodGet, t) testRouteOK(consts.MethodPost, t) testRouteOK(consts.MethodPut, t) testRouteOK(consts.MethodPatch, t) testRouteOK(consts.MethodHead, t) testRouteOK(consts.MethodOptions, t) testRouteOK(consts.MethodDelete, t) testRouteOK(consts.MethodConnect, t) testRouteOK(consts.MethodTrace, t) } func TestRouteNotOK1(t *testing.T) { testRouteNotOK(consts.MethodGet, t) testRouteNotOK(consts.MethodPost, t) testRouteNotOK(consts.MethodPut, t) testRouteNotOK(consts.MethodPatch, t) testRouteNotOK(consts.MethodHead, t) testRouteNotOK(consts.MethodOptions, t) testRouteNotOK(consts.MethodDelete, t) testRouteNotOK(consts.MethodConnect, t) testRouteNotOK(consts.MethodTrace, t) } func TestRouteNotOK2(t *testing.T) { testRouteNotOK2(consts.MethodGet, t) testRouteNotOK2(consts.MethodPost, t) testRouteNotOK2(consts.MethodPut, t) testRouteNotOK2(consts.MethodPatch, t) testRouteNotOK2(consts.MethodHead, t) testRouteNotOK2(consts.MethodOptions, t) testRouteNotOK2(consts.MethodDelete, t) testRouteNotOK2(consts.MethodConnect, t) testRouteNotOK2(consts.MethodTrace, t) } func TestRouteNotOK3(t *testing.T) { testRouteNotOK3(consts.MethodGet, t) testRouteNotOK3(consts.MethodPost, t) testRouteNotOK3(consts.MethodPut, t) testRouteNotOK3(consts.MethodPatch, t) testRouteNotOK3(consts.MethodHead, t) testRouteNotOK3(consts.MethodOptions, t) testRouteNotOK3(consts.MethodDelete, t) testRouteNotOK3(consts.MethodConnect, t) testRouteNotOK3(consts.MethodTrace, t) } func TestRouteRedirectTrailingSlash(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.options.RedirectFixedPath = false router.options.RedirectTrailingSlash = true router.GET("/path", func(c context.Context, ctx *app.RequestContext) {}) router.GET("/path2/", func(c context.Context, ctx *app.RequestContext) {}) router.POST("/path3", func(c context.Context, ctx *app.RequestContext) {}) 
router.PUT("/path4/", func(c context.Context, ctx *app.RequestContext) {}) w := performRequest(router, consts.MethodGet, "/path/") assert.DeepEqual(t, "/path", w.Header().Get("Location")) assert.DeepEqual(t, consts.StatusMovedPermanently, w.Code) w = performRequest(router, consts.MethodGet, "/path2") assert.DeepEqual(t, "/path2/", w.Header().Get("Location")) assert.DeepEqual(t, consts.StatusMovedPermanently, w.Code) w = performRequest(router, consts.MethodPost, "/path3/") assert.DeepEqual(t, "/path3", w.Header().Get("Location")) assert.DeepEqual(t, consts.StatusTemporaryRedirect, w.Code) w = performRequest(router, consts.MethodPut, "/path4") assert.DeepEqual(t, "/path4/", w.Header().Get("Location")) assert.DeepEqual(t, consts.StatusTemporaryRedirect, w.Code) w = performRequest(router, consts.MethodGet, "/path") assert.DeepEqual(t, consts.StatusOK, w.Code) w = performRequest(router, consts.MethodGet, "/path2/") assert.DeepEqual(t, consts.StatusOK, w.Code) w = performRequest(router, consts.MethodPost, "/path3") assert.DeepEqual(t, consts.StatusOK, w.Code) w = performRequest(router, consts.MethodPut, "/path4/") assert.DeepEqual(t, consts.StatusOK, w.Code) w = performRequest(router, consts.MethodGet, "/path2", header{Key: "X-Forwarded-Prefix", Value: "/api"}) assert.DeepEqual(t, "/api/path2/", w.Header().Get("Location")) assert.DeepEqual(t, consts.StatusMovedPermanently, w.Code) w = performRequest(router, consts.MethodGet, "/path2/", header{Key: "X-Forwarded-Prefix", Value: "/api/"}) assert.DeepEqual(t, consts.StatusOK, w.Code) router.options.RedirectTrailingSlash = false w = performRequest(router, consts.MethodGet, "/path/") assert.DeepEqual(t, consts.StatusNotFound, w.Code) w = performRequest(router, consts.MethodGet, "/path2") assert.DeepEqual(t, consts.StatusNotFound, w.Code) w = performRequest(router, consts.MethodPost, "/path3/") assert.DeepEqual(t, consts.StatusNotFound, w.Code) w = performRequest(router, consts.MethodPut, "/path4") assert.DeepEqual(t, 
consts.StatusNotFound, w.Code) } func TestRouteRedirectTrailingSlashWithQuery(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.options.RedirectFixedPath = false router.options.RedirectTrailingSlash = true router.GET("/path", func(c context.Context, ctx *app.RequestContext) {}) router.GET("/path2/", func(c context.Context, ctx *app.RequestContext) {}) w := performRequest(router, consts.MethodGet, "/path/?offset=2") assert.DeepEqual(t, "/path?offset=2", w.Header().Get("Location")) assert.DeepEqual(t, consts.StatusMovedPermanently, w.Code) w = performRequest(router, consts.MethodGet, "/path2?offset=2") assert.DeepEqual(t, "/path2/?offset=2", w.Header().Get("Location")) assert.DeepEqual(t, consts.StatusMovedPermanently, w.Code) } func TestRouteRedirectFixedPath(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.options.RedirectFixedPath = true router.options.RedirectTrailingSlash = false router.GET("/path", func(c context.Context, ctx *app.RequestContext) {}) router.GET("/Path2", func(c context.Context, ctx *app.RequestContext) {}) router.POST("/PATH3", func(c context.Context, ctx *app.RequestContext) {}) router.POST("/Path4/", func(c context.Context, ctx *app.RequestContext) {}) w := performRequest(router, consts.MethodGet, "/PATH") assert.DeepEqual(t, "/path", w.Header().Get("Location")) assert.DeepEqual(t, consts.StatusMovedPermanently, w.Code) w = performRequest(router, consts.MethodGet, "/path2") assert.DeepEqual(t, "/Path2", w.Header().Get("Location")) assert.DeepEqual(t, consts.StatusMovedPermanently, w.Code) w = performRequest(router, consts.MethodPost, "/path3") assert.DeepEqual(t, "/PATH3", w.Header().Get("Location")) assert.DeepEqual(t, consts.StatusTemporaryRedirect, w.Code) w = performRequest(router, consts.MethodPost, "/path4") assert.DeepEqual(t, "/Path4/", w.Header().Get("Location")) assert.DeepEqual(t, consts.StatusTemporaryRedirect, w.Code) } // TestContextParamsGet tests that a parameter can be parsed from the URL. 
func TestRouteParamsByName(t *testing.T) { name := "" lastName := "" wild := "" router := NewEngine(config.NewOptions(nil)) router.GET("/test/:name/:last_name/*wild", func(c context.Context, ctx *app.RequestContext) { name = ctx.Params.ByName("name") lastName = ctx.Params.ByName("last_name") var ok bool wild, ok = ctx.Params.Get("wild") assert.DeepEqual(t, true, ok) assert.DeepEqual(t, name, ctx.Param("name")) assert.DeepEqual(t, name, ctx.Param("name")) assert.DeepEqual(t, lastName, ctx.Param("last_name")) assert.DeepEqual(t, "", ctx.Param("wtf")) assert.DeepEqual(t, "", ctx.Params.ByName("wtf")) wtf, ok := ctx.Params.Get("wtf") assert.DeepEqual(t, "", wtf) assert.False(t, ok) }) w := performRequest(router, consts.MethodGet, "/test/john/smith/is/super/great") assert.DeepEqual(t, consts.StatusOK, w.Code) assert.DeepEqual(t, "john", name) assert.DeepEqual(t, "smith", lastName) assert.DeepEqual(t, "is/super/great", wild) } // TestContextParamsGet tests that a parameter can be parsed from the URL even with extra slashes. 
func TestRouteParamsByNameWithExtraSlash(t *testing.T) { name := "" lastName := "" wild := "" router := NewEngine(config.NewOptions(nil)) router.options.RemoveExtraSlash = true router.GET("/test/:name/:last_name/*wild", func(c context.Context, ctx *app.RequestContext) { name = ctx.Params.ByName("name") lastName = ctx.Params.ByName("last_name") var ok bool wild, ok = ctx.Params.Get("wild") assert.True(t, ok) assert.DeepEqual(t, name, ctx.Param("name")) assert.DeepEqual(t, name, ctx.Param("name")) assert.DeepEqual(t, lastName, ctx.Param("last_name")) assert.DeepEqual(t, "", ctx.Param("wtf")) assert.DeepEqual(t, "", ctx.Params.ByName("wtf")) wtf, ok := ctx.Params.Get("wtf") assert.DeepEqual(t, "", wtf) assert.False(t, ok) }) w := performRequest(router, consts.MethodGet, "/test//john//smith//is//super//great") assert.DeepEqual(t, consts.StatusOK, w.Code) assert.DeepEqual(t, "john", name) assert.DeepEqual(t, "smith", lastName) assert.DeepEqual(t, "is/super/great", wild) } // TestHandleStaticFile - ensure the static file handles properly func TestRouteStaticFile(t *testing.T) { // SETUP file testRoot, _ := os.Getwd() f, err := ioutil.TempFile(testRoot, "") if err != nil { t.Error(err) } defer os.Remove(f.Name()) _, err = f.WriteString("Hertz Web Framework") assert.Nil(t, err) f.Close() dir, filename := filepath.Split(f.Name()) // SETUP engine router := NewEngine(config.NewOptions(nil)) router.StaticFS("/using_static", &app.FS{Root: dir, AcceptByteRange: true, PathRewrite: app.NewPathSlashesStripper(1)}) router.StaticFile("/result", f.Name()) w := performRequest(router, consts.MethodGet, "/using_static/"+filename) w2 := performRequest(router, consts.MethodGet, "/result") assert.DeepEqual(t, w, w2) assert.DeepEqual(t, consts.StatusOK, w.Code) assert.DeepEqual(t, "Hertz Web Framework", w.Body.String()) assert.DeepEqual(t, "text/plain; charset=utf-8", w.Header().Get("Content-Type")) w3 := performRequest(router, consts.MethodHead, "/using_static/"+filename) w4 := 
performRequest(router, consts.MethodHead, "/result") assert.DeepEqual(t, w3, w4) assert.DeepEqual(t, consts.StatusOK, w3.Code) } // TestHandleStaticDir - ensure the root/sub dir handles properly func TestRouteStaticListingDir(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.StaticFS("/", &app.FS{Root: "./", GenerateIndexPages: true}) w := performRequest(router, consts.MethodGet, "/") assert.DeepEqual(t, consts.StatusOK, w.Code) assert.True(t, strings.Contains(w.Body.String(), "engine.go")) assert.DeepEqual(t, "text/html; charset=utf-8", w.Header().Get("Content-Type")) } // TestHandleHeadToDir - ensure the root/sub dir handles properly func TestRouteStaticNoListing(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.Static("/", "./") w := performRequest(router, consts.MethodGet, "/") assert.DeepEqual(t, http.StatusForbidden, w.Code) assert.False(t, strings.Contains(w.Body.String(), "engine.go")) } func TestRouterMiddlewareAndStatic(t *testing.T) { router := NewEngine(config.NewOptions(nil)) static := router.Group("/", func(c context.Context, ctx *app.RequestContext) { ctx.Response.Header.Set("Last-Modified", "Mon, 02 Jan 2006 15:04:05 MST") ctx.Response.Header.Set("Last-Modified", "Mon, 02 Jan 2006 15:04:05 MST") ctx.Response.Header.Set("Expires", "Mon, 02 Jan 2006 15:04:05 MST") ctx.Response.Header.Set("X-Hertz", "Hertz Framework") }) static.Static("/", "./") w := performRequest(router, consts.MethodGet, "/engine.go") assert.DeepEqual(t, consts.StatusOK, w.Code) assert.True(t, strings.Contains(w.Body.String(), "package route")) // when Go version <= 1.16, mime.TypeByExtension will return Content-Type='text/plain; charset=utf-8', // otherwise it will return Content-Type='text/x-go; charset=utf-8' assert.NotEqual(t, "", w.Header().Get("Content-Type")) assert.NotEqual(t, w.Header().Get("Last-Modified"), "Mon, 02 Jan 2006 15:04:05 MST") assert.DeepEqual(t, "Mon, 02 Jan 2006 15:04:05 MST", w.Header().Get("Expires")) assert.DeepEqual(t, 
"Hertz Framework", w.Header().Get("x-Hertz")) } func TestRouteNotAllowedEnabled(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.options.HandleMethodNotAllowed = true router.POST("/path", func(c context.Context, ctx *app.RequestContext) {}) w := performRequest(router, consts.MethodGet, "/path") assert.DeepEqual(t, consts.StatusMethodNotAllowed, w.Code) router.NoMethod(func(c context.Context, ctx *app.RequestContext) { ctx.String(http.StatusTeapot, "responseText") }) w = performRequest(router, consts.MethodGet, "/path") assert.DeepEqual(t, "responseText", w.Body.String()) assert.DeepEqual(t, http.StatusTeapot, w.Code) } func TestRouteNotAllowedEnabled2(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.options.HandleMethodNotAllowed = true // add one methodTree to trees router.addRoute(consts.MethodPost, "/", app.HandlersChain{func(_ context.Context, _ *app.RequestContext) {}}) router.GET("/path2", func(c context.Context, ctx *app.RequestContext) {}) w := performRequest(router, consts.MethodPost, "/path2") assert.DeepEqual(t, consts.StatusMethodNotAllowed, w.Code) } func TestRouteNotAllowedDisabled(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.options.HandleMethodNotAllowed = false router.POST("/path", func(c context.Context, ctx *app.RequestContext) {}) w := performRequest(router, consts.MethodGet, "/path") assert.DeepEqual(t, consts.StatusNotFound, w.Code) router.NoMethod(func(c context.Context, ctx *app.RequestContext) { ctx.String(http.StatusTeapot, "responseText") }) w = performRequest(router, consts.MethodGet, "/path") assert.DeepEqual(t, "Not Found", w.Body.String()) assert.DeepEqual(t, consts.StatusNotFound, w.Code) } func TestRouterNotFoundWithRemoveExtraSlash(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.options.RemoveExtraSlash = true router.GET("/path", func(c context.Context, ctx *app.RequestContext) {}) router.GET("/", func(c context.Context, ctx *app.RequestContext) {}) 
testRoutes := []struct { route string code int location string }{ {"/../path", consts.StatusOK, ""}, // CleanPath {"/nope", consts.StatusNotFound, ""}, // NotFound } for _, tr := range testRoutes { w := performRequest(router, "GET", tr.route) assert.DeepEqual(t, tr.code, w.Code) if w.Code != consts.StatusNotFound { assert.DeepEqual(t, tr.location, fmt.Sprint(w.Header().Get("Location"))) } } } func TestRouterNotFound(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.options.RedirectFixedPath = true router.GET("/path", func(c context.Context, ctx *app.RequestContext) {}) router.GET("/dir/", func(c context.Context, ctx *app.RequestContext) {}) router.GET("/", func(c context.Context, ctx *app.RequestContext) {}) testRoutes := []struct { route string code int location string }{ {"/path/", consts.StatusMovedPermanently, "/path"}, // TSR -/ {"/dir", consts.StatusMovedPermanently, "/dir/"}, // TSR +/ {"/PATH", consts.StatusMovedPermanently, "/path"}, // Fixed Case {"/DIR/", consts.StatusMovedPermanently, "/dir/"}, // Fixed Case {"/PATH/", consts.StatusMovedPermanently, "/path"}, // Fixed Case -/ {"/DIR", consts.StatusMovedPermanently, "/dir/"}, // Fixed Case +/ {"/../path/", consts.StatusMovedPermanently, "/path"}, // Without CleanPath {"/nope", consts.StatusNotFound, ""}, // NotFound } for _, tr := range testRoutes { w := performRequest(router, consts.MethodGet, tr.route) assert.DeepEqual(t, tr.code, w.Code) if w.Code != consts.StatusNotFound { assert.DeepEqual(t, tr.location, fmt.Sprint(w.Header().Get("Location"))) } } // Test custom not found handler var notFound bool router.NoRoute(func(c context.Context, ctx *app.RequestContext) { ctx.AbortWithStatus(consts.StatusNotFound) notFound = true }) w := performRequest(router, consts.MethodGet, "/nope") assert.DeepEqual(t, consts.StatusNotFound, w.Code) assert.True(t, notFound) // Test other method than GET (want 307 instead of 301) router.PATCH("/path", func(c context.Context, ctx *app.RequestContext) {}) w 
= performRequest(router, consts.MethodPatch, "/path/") assert.DeepEqual(t, consts.StatusTemporaryRedirect, w.Code) assert.DeepEqual(t, "map[Content-Type:[text/plain; charset=utf-8] Location:[/path]]", fmt.Sprint(w.Header())) // Test special case where no node for the prefix "/" exists router = NewEngine(config.NewOptions(nil)) router.GET("/a", func(c context.Context, ctx *app.RequestContext) {}) w = performRequest(router, consts.MethodGet, "/") assert.DeepEqual(t, consts.StatusNotFound, w.Code) } func TestRouterStaticFSNotFound(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.StaticFS("/", &app.FS{Root: "/thisreallydoesntexist/"}) router.NoRoute(func(c context.Context, ctx *app.RequestContext) { ctx.String(consts.StatusNotFound, "non existent") }) w := performRequest(router, consts.MethodGet, "/nonexistent") assert.DeepEqual(t, "Cannot open requested path", w.Body.String()) w = performRequest(router, consts.MethodHead, "/nonexistent") assert.DeepEqual(t, "Cannot open requested path", w.Body.String()) } func TestRouterStaticFSFileNotFound(t *testing.T) { router := NewEngine(config.NewOptions(nil)) router.StaticFS("/", &app.FS{Root: "."}) assert.NotPanic(t, func() { performRequest(router, consts.MethodGet, "/nonexistent") }) } func TestMiddlewareCalledOnceByRouterStaticFSNotFound(t *testing.T) { router := NewEngine(config.NewOptions(nil)) // Middleware must be called just only once by per request. 
middlewareCalledNum := 0 router.Use(func(c context.Context, ctx *app.RequestContext) { middlewareCalledNum++ }) router.StaticFS("/", &app.FS{Root: "/thisreallydoesntexist/"}) // First access performRequest(router, consts.MethodGet, "/nonexistent") assert.DeepEqual(t, 1, middlewareCalledNum) // Second access performRequest(router, consts.MethodHead, "/nonexistent") assert.DeepEqual(t, 2, middlewareCalledNum) } func TestRouteRawPath(t *testing.T) { route := NewEngine(config.NewOptions(nil)) route.options.UseRawPath = true route.POST("/project/:name/build/:num", func(c context.Context, ctx *app.RequestContext) { name := ctx.Params.ByName("name") num := ctx.Params.ByName("num") assert.DeepEqual(t, name, ctx.Param("name")) assert.DeepEqual(t, num, ctx.Param("num")) assert.DeepEqual(t, "Some/Other/Project", name) assert.DeepEqual(t, "222", num) }) w := performRequest(route, consts.MethodPost, "/project/Some%2FOther%2FProject/build/222") assert.DeepEqual(t, consts.StatusOK, w.Code) } func TestRouteRawPathNoUnescape(t *testing.T) { route := NewEngine(config.NewOptions(nil)) route.options.UseRawPath = true route.options.UnescapePathValues = false route.POST("/project/:name/build/:num", func(c context.Context, ctx *app.RequestContext) { name := ctx.Params.ByName("name") num := ctx.Params.ByName("num") assert.DeepEqual(t, name, ctx.Param("name")) assert.DeepEqual(t, num, ctx.Param("num")) assert.DeepEqual(t, "Some%2FOther%2FProject", name) assert.DeepEqual(t, "333", num) }) w := performRequest(route, consts.MethodPost, "/project/Some%2FOther%2FProject/build/333") assert.DeepEqual(t, consts.StatusOK, w.Code) } func TestRouteServeErrorWithWriteHeader(t *testing.T) { route := NewEngine(config.NewOptions(nil)) route.Use(func(c context.Context, ctx *app.RequestContext) { ctx.SetStatusCode(421) ctx.Next(c) }) w := performRequest(route, consts.MethodGet, "/NotFound") assert.DeepEqual(t, 421, w.Code) assert.DeepEqual(t, 0, w.Body.Len()) } func TestRouteContextHoldsFullPath(t 
*testing.T) { router := NewEngine(config.NewOptions(nil)) // Test routes routes := []string{ "/simple", "/project/:name", "/", "/news/home", "/news", "/simple-two/one", "/simple-two/one-two", "/project/:name/build/*params", "/project/:name/bui", "/user/:id/status", "/user/:id", "/user/:id/profile", } for _, route := range routes { actualRoute := route router.GET(route, func(c context.Context, ctx *app.RequestContext) { // For each defined route context should contain its full path assert.DeepEqual(t, actualRoute, ctx.FullPath()) ctx.AbortWithStatus(consts.StatusOK) }) } for _, route := range routes { w := performRequest(router, consts.MethodGet, route) assert.DeepEqual(t, consts.StatusOK, w.Code) } // Test not found router.Use(func(c context.Context, ctx *app.RequestContext) { // For not found routes full path is empty assert.DeepEqual(t, "", ctx.FullPath()) }) w := performRequest(router, consts.MethodGet, "/not-found") assert.DeepEqual(t, consts.StatusNotFound, w.Code) } func checkUnusedParamValues(t *testing.T, ctx *app.RequestContext, expectParam map[string]string) { for _, p := range ctx.Params { if expectParam == nil { t.Errorf("pValue '%+v' is set for param name '%v' but we are not expecting it", p.Value, p.Key) } else if val, ok := expectParam[p.Key]; !ok || val != p.Value { t.Errorf("'%+v' is set for param name '%v' but we are expecting it with expectParam '%+v'", p.Value, p.Key, val) } } } var handlerFunc = func(route string) app.HandlersChain { return app.HandlersChain{func(c context.Context, ctx *app.RequestContext) { ctx.Set("path", route) }} } var handlerHelper = func(route, key string, value int) app.HandlersChain { return app.HandlersChain{func(c context.Context, ctx *app.RequestContext) { ctx.Set(key, value) ctx.Set("path", route) }} } func getHelper(c *app.RequestContext, key string) interface{} { p, _ := c.Get(key) return p } func TestRouterStatic(t *testing.T) { e := NewEngine(config.NewOptions(nil)) path := "/folders/a/files/hertz.gif" 
e.addRoute(consts.MethodGet, path, handlerFunc(path)) c := e.NewContext() c.Request.SetRequestURI(path) c.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), c) assert.DeepEqual(t, path, getHelper(c, "path")) } func TestRouterParam(t *testing.T) { e := NewEngine(config.NewOptions(nil)) e.addRoute(consts.MethodGet, "/users/:id", handlerFunc("/users/:id")) testCases := []struct { name string whenURL string expectRoute interface{} expectParam map[string]string }{ { name: "route /users/1 to /users/:id", whenURL: "/users/1", expectRoute: "/users/:id", expectParam: map[string]string{"id": "1"}, }, { name: "route /users/1/ to /users/:id", whenURL: "/users/1/", expectRoute: nil, expectParam: nil, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { c := e.NewContext() c.Request.SetRequestURI(tc.whenURL) c.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), c) assert.DeepEqual(t, tc.expectRoute, getHelper(c, "path")) checkUnusedParamValues(t, c, tc.expectParam) }) } } func TestRouterTwoParam(t *testing.T) { e := NewEngine(config.NewOptions(nil)) e.addRoute(consts.MethodGet, "/users/:uid/files/:fid", handlerFunc("/users/:uid/files/:fid")) ctx := e.NewContext() ctx.Request.SetRequestURI("/users/1/files/1") ctx.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, "1", ctx.Param("uid")) assert.DeepEqual(t, "1", ctx.Param("fid")) } func TestRouterParamWithSlash(t *testing.T) { e := NewEngine(config.NewOptions(nil)) e.addRoute(consts.MethodGet, "/a/:b/c/d/:e", handlerFunc("/a/:b/c/d/:e")) e.addRoute(consts.MethodGet, "/a/:b/c/:d/:f", handlerFunc("/a/:b/c/:d/:f")) ctx := e.NewContext() ctx.Request.SetRequestURI("/a/1/c/d/2/3") ctx.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), ctx) assert.Nil(t, getHelper(ctx, "path")) assert.DeepEqual(t, consts.StatusNotFound, ctx.Response.StatusCode()) } func TestRouteMultiLevelBacktracking(t 
*testing.T) { testCases := []struct { name string whenURL string expectRoute interface{} expectParam map[string]string }{ { name: "route /a/c/df to /a/c/df", whenURL: "/a/c/df", expectRoute: "/a/c/df", }, { name: "route /a/x/df to /a/:b/c", whenURL: "/a/x/c", expectRoute: "/a/:b/c", expectParam: map[string]string{"b": "x"}, }, // { // name: "route /a/x/f to /a/*/f", // whenURL: "/a/x/f", // expectRoute: "/a/*/f", // expectParam: map[string]string{"x": "x/f"}, // NOTE: `x` would be probably more suitable // }, { name: "route /b/c/f to /:e/c/f", whenURL: "/b/c/f", expectRoute: "/:e/c/f", expectParam: map[string]string{"e": "b"}, }, { name: "route /b/c/c to /*x", whenURL: "/b/c/c", expectRoute: "/*x", expectParam: map[string]string{"x": "b/c/c"}, }, } e := NewEngine(config.NewOptions(nil)) e.addRoute(consts.MethodGet, "/a/:b/c", handlerHelper("/a/:b/c", "case", 1)) e.addRoute(consts.MethodGet, "/a/c/d", handlerHelper("/a/c/d", "case", 2)) e.addRoute(consts.MethodGet, "/a/c/df", handlerHelper("/a/c/df", "case", 3)) // e.addRoute(consts.MethodGet, "/a/*/f", handlerHelper("case", 4)) e.addRoute(consts.MethodGet, "/:e/c/f", handlerHelper("/:e/c/f", "case", 5)) e.addRoute(consts.MethodGet, "/*x", handlerHelper("/*x", "case", 6)) for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ctx := e.NewContext() ctx.Request.SetRequestURI(tc.whenURL) ctx.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path")) for param, expectedValue := range tc.expectParam { assert.DeepEqual(t, expectedValue, ctx.Param(param)) } checkUnusedParamValues(t, ctx, tc.expectParam) }) } } func TestRouteMultiLevelBacktracking2(t *testing.T) { e := NewEngine(config.NewOptions(nil)) e.addRoute(consts.MethodGet, "/a/:b/c", handlerFunc("/a/:b/c")) e.addRoute(consts.MethodGet, "/a/c/d", handlerFunc("/a/c/d")) e.addRoute(consts.MethodGet, "/a/c/df", handlerFunc("/a/c/df")) e.addRoute(consts.MethodGet, "/:e/c/f", 
handlerFunc("/:e/c/f")) e.addRoute(consts.MethodGet, "/*x", handlerFunc("/*x")) testCases := []struct { name string whenURL string expectRoute string expectParam map[string]string }{ { name: "route /a/c/df to /a/c/df", whenURL: "/a/c/df", expectRoute: "/a/c/df", }, { name: "route /a/x/df to /a/:b/c", whenURL: "/a/x/c", expectRoute: "/a/:b/c", expectParam: map[string]string{"b": "x"}, }, { name: "route /a/c/f to /:e/c/f", whenURL: "/a/c/f", expectRoute: "/:e/c/f", expectParam: map[string]string{"e": "a"}, }, { name: "route /b/c/f to /:e/c/f", whenURL: "/b/c/f", expectRoute: "/:e/c/f", expectParam: map[string]string{"e": "b"}, }, { name: "route /b/c/c to /*", whenURL: "/b/c/c", expectRoute: "/*x", expectParam: map[string]string{"x": "b/c/c"}, }, { // this traverses `/a/:b/c` and `/:e/c/f` branches and eventually backtracks to `/*` name: "route /a/c/cf to /*", whenURL: "/a/c/cf", expectRoute: "/*x", expectParam: map[string]string{"x": "a/c/cf"}, }, { name: "route /anyMatch to /*", whenURL: "/anyMatch", expectRoute: "/*x", expectParam: map[string]string{"x": "anyMatch"}, }, { name: "route /anyMatch/withSlash to /*", whenURL: "/anyMatch/withSlash", expectRoute: "/*x", expectParam: map[string]string{"x": "anyMatch/withSlash"}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ctx := e.NewContext() ctx.Request.SetRequestURI(tc.whenURL) ctx.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path")) for param, expectedValue := range tc.expectParam { assert.DeepEqual(t, expectedValue, ctx.Param(param)) } checkUnusedParamValues(t, ctx, tc.expectParam) }) } } func TestRouterBacktrackingFromMultipleParamKinds(t *testing.T) { e := NewEngine(config.NewOptions(nil)) e.addRoute(consts.MethodGet, "/*x", handlerFunc("/*x")) // this can match only path that does not have slash in it e.addRoute(consts.MethodGet, "/:1/second", handlerFunc("/:1/second")) e.addRoute(consts.MethodGet, 
"/:1/:2", handlerFunc("/:1/:2")) // this acts as match ANY for all routes that have at least one slash e.addRoute(consts.MethodGet, "/:1/:2/third", handlerFunc("/:1/:2/third")) e.addRoute(consts.MethodGet, "/:1/:2/:3/fourth", handlerFunc("/:1/:2/:3/fourth")) e.addRoute(consts.MethodGet, "/:1/:2/:3/:4/fifth", handlerFunc("/:1/:2/:3/:4/fifth")) testCases := []struct { name string whenURL string expectRoute string expectParam map[string]string }{ { name: "route /first to /*", whenURL: "/first", expectRoute: "/*x", expectParam: map[string]string{"x": "first"}, }, { name: "route /first/second to /:1/second", whenURL: "/first/second", expectRoute: "/:1/second", expectParam: map[string]string{"1": "first"}, }, { name: "route /first/second-new to /:1/:2", whenURL: "/first/second-new", expectRoute: "/:1/:2", expectParam: map[string]string{ "1": "first", "2": "second-new", }, }, { name: "route /first/second/ to /:1/:2", whenURL: "/first/second/", expectRoute: "/*x", // "/:1/:2", expectParam: map[string]string{"x": "first/second/"}, }, { name: "route /first/second/third/fourth/fifth/nope to /:1/:2", whenURL: "/first/second/third/fourth/fifth/nope", expectRoute: "/*x", expectParam: map[string]string{"x": "first/second/third/fourth/fifth/nope"}, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { ctx := e.NewContext() ctx.Request.SetRequestURI(tc.whenURL) ctx.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path")) for param, expectedValue := range tc.expectParam { assert.DeepEqual(t, expectedValue, ctx.Param(param)) } checkUnusedParamValues(t, ctx, tc.expectParam) }) } } func TestRouterParamStaticConflict(t *testing.T) { e := NewEngine(config.NewOptions(nil)) g := e.Group("/g") g.GET("/skills", handlerFunc("/g/skills")...) g.GET("/status", handlerFunc("/g/status")...) g.GET("/:name", handlerFunc("/g/:name")...) 
testCases := []struct { whenURL string expectRoute interface{} expectParam map[string]string }{ { whenURL: "/g/s", expectRoute: "/g/:name", expectParam: map[string]string{"name": "s"}, }, { whenURL: "/g/status", expectRoute: "/g/status", expectParam: map[string]string{"name": ""}, }, } for _, tc := range testCases { t.Run(tc.whenURL, func(t *testing.T) { ctx := e.NewContext() ctx.Request.SetRequestURI(tc.whenURL) ctx.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path")) for param, expectedValue := range tc.expectParam { assert.DeepEqual(t, expectedValue, ctx.Param(param)) } checkUnusedParamValues(t, ctx, tc.expectParam) }) } } func TestRouterMatchAny(t *testing.T) { e := NewEngine(config.NewOptions(nil)) // Routes e.addRoute(consts.MethodGet, "/", handlerFunc("/")) e.addRoute(consts.MethodGet, "/*x", handlerFunc("/*x")) e.addRoute(consts.MethodGet, "/users/*x", handlerFunc("/users/*x")) testCases := []struct { whenURL string expectRoute interface{} expectParam map[string]string }{ { whenURL: "/", expectRoute: "/", expectParam: map[string]string{"x": ""}, }, { whenURL: "/download", expectRoute: "/*x", expectParam: map[string]string{"x": "download"}, }, { whenURL: "/users/joe", expectRoute: "/users/*x", expectParam: map[string]string{"x": "joe"}, }, } for _, tc := range testCases { t.Run(tc.whenURL, func(t *testing.T) { ctx := e.NewContext() ctx.Request.SetRequestURI(tc.whenURL) ctx.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path")) for param, expectedValue := range tc.expectParam { assert.DeepEqual(t, expectedValue, ctx.Param(param)) } checkUnusedParamValues(t, ctx, tc.expectParam) }) } } // NOTE: This is to document current implementation. Last added route with `*` asterisk is always the match and no // backtracking or more precise matching is done to find more suitable match. 
// // Current behaviour might not be correct or expected. // But this is where we are without well defined requirements/rules how (multiple) asterisks work in route func TestRouterAnyMatchesLastAddedAnyRoute(t *testing.T) { e := NewEngine(config.NewOptions(nil)) e.addRoute(consts.MethodGet, "/users/*x", handlerHelper("/users/*x", "case", 1)) // e.addRoute(consts.MethodGet, "/users/*x/action*y", handlerHelper("/users/*x/action*y", "case", 2)) ctx := e.NewContext() ctx.Request.SetRequestURI("/users/xxx/action/sea") ctx.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, "/users/*x", getHelper(ctx, "path")) assert.DeepEqual(t, "xxx/action/sea", ctx.Param("x")) // if we add another route then it is the last added and so it is matched // e.addRoute(consts.MethodGet, "/users/*x/action/search", handlerHelper("/users/*x/action/search", "case", 3)) // c.Request.SetRequestURI("/users/xxx/action/sea") // c.Request.Header.SetMethod(consts.MethodGet) // e.ServeHTTP(context.Background(), c) // test.DeepEqual(t, "/users/*x/action/search", getHelper(c, "path")) // test.DeepEqual(t, "xxx/action/sea", ctx.Param("x")) } func TestRouterMatchAnyPrefixIssue(t *testing.T) { e := NewEngine(config.NewOptions(nil)) // Routes e.addRoute(consts.MethodGet, "/*x", handlerFunc("/*x")) e.addRoute(consts.MethodGet, "/users/*x", handlerFunc("/users/*x")) testCases := []struct { whenURL string expectRoute interface{} expectParam map[string]string }{ { whenURL: "/", expectRoute: "/*x", expectParam: map[string]string{"x": ""}, }, { whenURL: "/users", expectRoute: "/*x", expectParam: map[string]string{"x": "users"}, }, { whenURL: "/users/", expectRoute: "/users/*x", expectParam: map[string]string{"x": ""}, }, { whenURL: "/users_prefix", expectRoute: "/*x", expectParam: map[string]string{"x": "users_prefix"}, }, { whenURL: "/users_prefix/", expectRoute: "/*x", expectParam: map[string]string{"x": "users_prefix/"}, }, } for _, tc := range testCases { 
t.Run(tc.whenURL, func(t *testing.T) { ctx := e.NewContext() ctx.Request.SetRequestURI(tc.whenURL) ctx.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path")) for param, expectedValue := range tc.expectParam { assert.DeepEqual(t, expectedValue, ctx.Param(param)) } checkUnusedParamValues(t, ctx, tc.expectParam) }) } } // TestRouterMatchAnySlash shall verify finding the best route // for any routes with trailing slash requests func TestRouterMatchAnySlash(t *testing.T) { e := NewEngine(config.NewOptions(nil)) // Routes e.addRoute(consts.MethodGet, "/users", handlerFunc("/users")) e.addRoute(consts.MethodGet, "/users/*x", handlerFunc("/users/*x")) e.addRoute(consts.MethodGet, "/img/*x", handlerFunc("/img/*x")) e.addRoute(consts.MethodGet, "/img/load", handlerFunc("/img/load")) e.addRoute(consts.MethodGet, "/img/load/*x", handlerFunc("/img/load/*x")) e.addRoute(consts.MethodGet, "/assets/*x", handlerFunc("/assets/*x")) testCases := []struct { whenURL string expectRoute interface{} expectParam map[string]string expectError error }{ { whenURL: "/", expectRoute: nil, expectParam: map[string]string{"x": ""}, }, { // Test trailing slash request for simple any route (see #1526) whenURL: "/users/", expectRoute: "/users/*x", expectParam: map[string]string{"x": ""}, }, { whenURL: "/users/joe", expectRoute: "/users/*x", expectParam: map[string]string{"x": "joe"}, }, // Test trailing slash request for nested any route (see #1526) { whenURL: "/img/load", expectRoute: "/img/load", expectParam: map[string]string{"x": ""}, }, { whenURL: "/img/load/", expectRoute: "/img/load/*x", expectParam: map[string]string{"x": ""}, }, { whenURL: "/img/load/ben", expectRoute: "/img/load/*x", expectParam: map[string]string{"x": "ben"}, }, // Test /assets/*x any route { // ... without trailing slash must not match whenURL: "/assets", expectRoute: nil, expectParam: map[string]string{"x": ""}, }, { // ... 
with trailing slash must match whenURL: "/assets/", expectRoute: "/assets/*x", expectParam: map[string]string{"x": ""}, }, } for _, tc := range testCases { t.Run(tc.whenURL, func(t *testing.T) { ctx := e.NewContext() ctx.Request.SetRequestURI(tc.whenURL) ctx.Request.Header.SetMethod(consts.MethodGet) e.ServeHTTP(context.Background(), ctx) assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path")) for param, expectedValue := range tc.expectParam { assert.DeepEqual(t, expectedValue, ctx.Param(param)) } checkUnusedParamValues(t, ctx, tc.expectParam) }) } } func TestRouterMatchAnyMultiLevel(t *testing.T) { e := NewEngine(config.NewOptions(nil)) // Routes e.addRoute(consts.MethodGet, "/api/users/jack", handlerFunc("/api/users/jack")) e.addRoute(consts.MethodGet, "/api/users/jill", handlerFunc("/api/users/jill")) e.addRoute(consts.MethodGet, "/api/users/*x", handlerFunc("/api/users/*x")) e.addRoute(consts.MethodGet, "/api/*x", handlerFunc("/api/*x")) e.addRoute(consts.MethodGet, "/other/*x", handlerFunc("/other/*x")) e.addRoute(consts.MethodGet, "/*x", handlerFunc("/*x")) testCases := []struct { whenURL string expectRoute interface{} expectParam map[string]string expectError error }{ { whenURL: "/api/users/jack", expectRoute: "/api/users/jack", expectParam: map[string]string{"x": ""}, }, { whenURL: "/api/users/jill", expectRoute: "/api/users/jill", expectParam: map[string]string{"x": ""}, }, { whenURL: "/api/users/joe", expectRoute: "/api/users/*x", expectParam: map[string]string{"x": "joe"}, }, { whenURL: "/api/nousers/joe", expectRoute: "/api/*x", expectParam: map[string]string{"x": "nousers/joe"}, }, { whenURL: "/api/none", expectRoute: "/api/*x", expectParam: map[string]string{"x": "none"}, }, { whenURL: "/api/none", expectRoute: "/api/*x", expectParam: map[string]string{"x": "none"}, }, { whenURL: "/noapi/users/jim", expectRoute: "/*x", expectParam: map[string]string{"x": "noapi/users/jim"}, }, } for _, tc := range testCases { t.Run(tc.whenURL, func(t *testing.T) { 
ctx := e.NewContext()
			ctx.Request.SetRequestURI(tc.whenURL)
			ctx.Request.Header.SetMethod(consts.MethodGet)
			e.ServeHTTP(context.Background(), ctx)
			assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path"))
			for param, expectedValue := range tc.expectParam {
				assert.DeepEqual(t, expectedValue, ctx.Param(param))
			}
			checkUnusedParamValues(t, ctx, tc.expectParam)
		})
	}
}

// TestRouterMatchAnyMultiLevelWithPost registers two POST-only static routes
// under /api/auth plus Any catch-alls at /api/*x and /*x, and checks that a
// request resolves to the static route when it matches exactly and otherwise
// to the deepest catch-all (/api/*x), for both POST and GET.
func TestRouterMatchAnyMultiLevelWithPost(t *testing.T) {
	e := NewEngine(config.NewOptions(nil))

	// Routes
	e.POST("/api/auth/login", handlerFunc("/api/auth/login")...)
	e.POST("/api/auth/forgotPassword", handlerFunc("/api/auth/forgotPassword")...)
	e.Any("/api/*x", handlerFunc("/api/*x")...)
	e.Any("/*x", handlerFunc("/*x")...)

	testCases := []struct {
		whenMethod  string
		whenURL     string
		expectRoute interface{}
		expectParam map[string]string
		expectError error
	}{
		{ // POST /api/auth/login shall choose login method
			whenURL:     "/api/auth/login",
			whenMethod:  consts.MethodPost,
			expectRoute: "/api/auth/login",
			expectParam: map[string]string{"x": ""},
		},
		{ // POST /api/auth/logout shall choose nearest any route
			whenURL:     "/api/auth/logout",
			whenMethod:  consts.MethodPost,
			expectRoute: "/api/*x",
			expectParam: map[string]string{"x": "auth/logout"},
		},
		{ // POST to /api/other/test shall choose nearest any route
			whenURL:     "/api/other/test",
			whenMethod:  consts.MethodPost,
			expectRoute: "/api/*x",
			expectParam: map[string]string{"x": "other/test"},
		},
		{ // GET to /api/other/test shall choose nearest any route
			whenURL:     "/api/other/test",
			whenMethod:  consts.MethodGet,
			expectRoute: "/api/*x",
			expectParam: map[string]string{"x": "other/test"},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.whenURL, func(t *testing.T) {
			ctx := e.NewContext()
			ctx.Request.SetRequestURI(tc.whenURL)
			ctx.Request.Header.SetMethod(tc.whenMethod)
			e.ServeHTTP(context.Background(), ctx)
			assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path"))
			for param, expectedValue := range tc.expectParam {
				assert.DeepEqual(t, expectedValue, ctx.Param(param))
			}
			checkUnusedParamValues(t, ctx, tc.expectParam)
		})
	}
}

// TestRouterMicroParam checks that a route made only of consecutive param
// segments (/:a/:b/:c) captures each single-character segment separately.
func TestRouterMicroParam(t *testing.T) {
	e := NewEngine(config.NewOptions(nil))
	e.addRoute(consts.MethodGet, "/:a/:b/:c", handlerFunc("/:a/:b/:c"))
	ctx := e.NewContext()
	ctx.Request.SetRequestURI("/1/2/3")
	ctx.Request.Header.SetMethod(consts.MethodGet)
	e.ServeHTTP(context.Background(), ctx)
	assert.DeepEqual(t, "1", ctx.Param("a"))
	assert.DeepEqual(t, "2", ctx.Param("b"))
	assert.DeepEqual(t, "3", ctx.Param("c"))
}

// TestRouterMixParamMatchAny checks that a param segment followed by a
// catch-all (/users/:id/*x) still captures the param value.
func TestRouterMixParamMatchAny(t *testing.T) {
	e := NewEngine(config.NewOptions(nil))

	// Route
	e.addRoute(consts.MethodGet, "/users/:id/*x", handlerFunc("/users/:id/*x"))
	ctx := e.NewContext()
	ctx.Request.SetRequestURI("/users/joe/comments")
	ctx.Request.Header.SetMethod(consts.MethodGet)
	e.ServeHTTP(context.Background(), ctx)
	assert.DeepEqual(t, "joe", ctx.Param("id"))
}

// TestRouterMultiRoute checks that /users and /users/:id coexist and that a
// near-miss (/user) matches no route.
func TestRouterMultiRoute(t *testing.T) {
	e := NewEngine(config.NewOptions(nil))

	// Routes
	e.addRoute(consts.MethodGet, "/users", handlerFunc("/users"))
	e.addRoute(consts.MethodGet, "/users/:id", handlerFunc("/users/:id"))

	testCases := []struct {
		whenMethod  string
		whenURL     string
		expectRoute interface{}
		expectParam map[string]string
		expectError error
	}{
		{
			whenURL:     "/users",
			expectRoute: "/users",
			expectParam: map[string]string{"x": ""},
		},
		{
			whenURL:     "/users/1",
			expectRoute: "/users/:id",
			expectParam: map[string]string{"id": "1"},
		},
		{
			whenURL:     "/user",
			expectRoute: nil,
			expectParam: map[string]string{"x": ""},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.whenURL, func(t *testing.T) {
			ctx := e.NewContext()
			ctx.Request.SetRequestURI(tc.whenURL)
			ctx.Request.Header.SetMethod(consts.MethodGet)
			e.ServeHTTP(context.Background(), ctx)
			assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path"))
			for param, expectedValue := range tc.expectParam {
				assert.DeepEqual(t, expectedValue, ctx.Param(param))
			}
			checkUnusedParamValues(t, ctx, tc.expectParam)
		})
	}
}

// TestRouterPriority exercises the static > param > any matching priority,
// including cases that must backtrack out of a static prefix (e.g.
// /users/dew/someone falls back to /users/*x) and root-level catch-all
// fallbacks.
func TestRouterPriority(t *testing.T) {
	e := NewEngine(config.NewOptions(nil))

	// Routes
	e.addRoute(consts.MethodGet, "/users", handlerFunc("/users"))
	e.addRoute(consts.MethodGet, "/users/new", handlerFunc("/users/new"))
	e.addRoute(consts.MethodGet, "/users/:id", handlerFunc("/users/:id"))
	e.addRoute(consts.MethodGet, "/users/dew", handlerFunc("/users/dew"))
	e.addRoute(consts.MethodGet, "/users/:id/files", handlerFunc("/users/:id/files"))
	e.addRoute(consts.MethodGet, "/users/newsee", handlerFunc("/users/newsee"))
	e.addRoute(consts.MethodGet, "/users/*x", handlerFunc("/users/*x"))
	e.addRoute(consts.MethodGet, "/users/new/*x", handlerFunc("/users/new/*x"))
	e.addRoute(consts.MethodGet, "/*x", handlerFunc("/*x"))

	testCases := []struct {
		whenMethod  string
		whenURL     string
		expectRoute interface{}
		expectParam map[string]string
		expectError error
	}{
		{
			whenURL:     "/users",
			expectRoute: "/users",
		},
		{
			whenURL:     "/users/new",
			expectRoute: "/users/new",
		},
		{
			whenURL:     "/users/1",
			expectRoute: "/users/:id",
			expectParam: map[string]string{"id": "1"},
		},
		{
			whenURL:     "/users/dew",
			expectRoute: "/users/dew",
		},
		{
			whenURL:     "/users/1/files",
			expectRoute: "/users/:id/files",
			expectParam: map[string]string{"id": "1"},
		},
		{
			whenURL:     "/users/new",
			expectRoute: "/users/new",
		},
		{
			whenURL:     "/users/news",
			expectRoute: "/users/:id",
			expectParam: map[string]string{"id": "news"},
		},
		{
			whenURL:     "/users/newsee",
			expectRoute: "/users/newsee",
		},
		{
			whenURL:     "/users/joe/books",
			expectRoute: "/users/*x",
			expectParam: map[string]string{"x": "joe/books"},
		},
		{
			whenURL:     "/users/new/someone",
			expectRoute: "/users/new/*x",
			expectParam: map[string]string{"x": "someone"},
		},
		{
			whenURL:     "/users/dew/someone",
			expectRoute: "/users/*x",
			expectParam: map[string]string{"x": "dew/someone"},
		},
		{ // Route > /users/*x should be matched although /users/dew exists
			whenURL:     "/users/notexists/someone",
			expectRoute: "/users/*x",
			expectParam: map[string]string{"x": "notexists/someone"},
		},
		{
			whenURL:     "/nousers",
			expectRoute: "/*x",
			expectParam: map[string]string{"x": "nousers"},
		},
		{
			whenURL:     "/nousers/new",
			expectRoute: "/*x",
			expectParam: map[string]string{"x": "nousers/new"},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.whenURL, func(t *testing.T) {
			ctx := e.NewContext()
			ctx.Request.SetRequestURI(tc.whenURL)
			ctx.Request.Header.SetMethod(consts.MethodGet)
			e.ServeHTTP(context.Background(), ctx)
			assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path"))
			for param, expectedValue := range tc.expectParam {
				assert.DeepEqual(t, expectedValue, ctx.Param(param))
			}
			checkUnusedParamValues(t, ctx, tc.expectParam)
		})
	}
}

// TestRouterIssue1348 only registers /:lang/ followed by /:lang/dupa; it
// passes as long as registration does not panic (no requests are served).
func TestRouterIssue1348(t *testing.T) {
	e := NewEngine(config.NewOptions(nil))

	e.addRoute(consts.MethodGet, "/:lang/", handlerFunc("/:lang/"))
	e.addRoute(consts.MethodGet, "/:lang/dupa", handlerFunc("/:lang/dupa"))
}

// TestRouterPriorityNotFound checks that /abc/def, which shares the /a prefix
// with the registered routes but matches none of them, resolves to no route.
func TestRouterPriorityNotFound(t *testing.T) {
	e := NewEngine(config.NewOptions(nil))

	// Add
	e.addRoute(consts.MethodGet, "/a/foo", handlerFunc("/a/foo"))
	e.addRoute(consts.MethodGet, "/a/bar", handlerFunc("/a/bar"))

	testCases := []struct {
		whenMethod  string
		whenURL     string
		expectRoute interface{}
		expectParam map[string]string
		expectError error
	}{
		{
			whenURL:     "/a/foo",
			expectRoute: "/a/foo",
		},
		{
			whenURL:     "/a/bar",
			expectRoute: "/a/bar",
		},
		{
			whenURL:     "/abc/def",
			expectRoute: nil,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.whenURL, func(t *testing.T) {
			ctx := e.NewContext()
			ctx.Request.SetRequestURI(tc.whenURL)
			ctx.Request.Header.SetMethod(consts.MethodGet)
			e.ServeHTTP(context.Background(), ctx)
			assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path"))
			for param, expectedValue := range tc.expectParam {
				assert.DeepEqual(t, expectedValue, ctx.Param(param))
			}
			checkUnusedParamValues(t, ctx, tc.expectParam)
		})
	}
}

// TestRouterParamNames checks that param names come from the matched route:
// /users/:id and /users/:uid/files/:fid share a param position but use
// different names.
func TestRouterParamNames(t *testing.T) {
	e := NewEngine(config.NewOptions(nil))

	// Routes
	e.addRoute(consts.MethodGet, "/users", handlerFunc("/users"))
	e.addRoute(consts.MethodGet, "/users/:id", handlerFunc("/users/:id"))
	e.addRoute(consts.MethodGet, "/users/:uid/files/:fid", handlerFunc("/users/:uid/files/:fid"))

	testCases := []struct {
		whenMethod  string
		whenURL     string
		expectRoute interface{}
		expectParam map[string]string
		expectError error
	}{
		{
			whenURL:     "/users",
			expectRoute: "/users",
		},
		{
			whenURL:     "/users/1",
			expectRoute: "/users/:id",
			expectParam: map[string]string{"id": "1"},
		},
		{
			whenURL:     "/users/1/files/1",
			expectRoute: "/users/:uid/files/:fid",
			expectParam: map[string]string{
				"uid": "1",
				"fid": "1",
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.whenURL, func(t *testing.T) {
			ctx := e.NewContext()
			ctx.Request.SetRequestURI(tc.whenURL)
			ctx.Request.Header.SetMethod(consts.MethodGet)
			e.ServeHTTP(context.Background(), ctx)
			assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path"))
			for param, expectedValue := range tc.expectParam {
				assert.DeepEqual(t, expectedValue, ctx.Param(param))
			}
			checkUnusedParamValues(t, ctx, tc.expectParam)
		})
	}
}

// TestRouterStaticDynamicConflict checks that a static segment wins over a
// param at the same position (/dictionary/skills vs /dictionary/:name), while
// values that merely share the static prefix (skillsnot, new2) go to the param.
func TestRouterStaticDynamicConflict(t *testing.T) {
	e := NewEngine(config.NewOptions(nil))

	e.addRoute(consts.MethodGet, "/dictionary/skills", handlerHelper("/dictionary/skills", "a", 1))
	e.addRoute(consts.MethodGet, "/dictionary/:name", handlerHelper("/dictionary/:name", "b", 2))
	e.addRoute(consts.MethodGet, "/users/new", handlerHelper("/users/new", "d", 4))
	e.addRoute(consts.MethodGet, "/users/:name", handlerHelper("/users/:name", "e", 5))
	e.addRoute(consts.MethodGet, "/server", handlerHelper("/server", "c", 3))
	e.addRoute(consts.MethodGet, "/", handlerHelper("/", "f", 6))

	testCases := []struct {
		whenMethod  string
		whenURL     string
		expectRoute interface{}
		expectParam map[string]string
		expectError error
	}{
		{
			whenURL:     "/dictionary/skills",
			expectRoute: "/dictionary/skills",
			expectParam: map[string]string{"x": ""},
		},
		{
			whenURL:     "/dictionary/skillsnot",
			expectRoute: "/dictionary/:name",
			expectParam: map[string]string{"name": "skillsnot"},
		},
		{
			whenURL:     "/dictionary/type",
			expectRoute: "/dictionary/:name",
			expectParam: map[string]string{"name": "type"},
		},
		{
			whenURL:     "/server",
			expectRoute: "/server",
		},
		{
			whenURL:     "/users/new",
			expectRoute: "/users/new",
		},
		{
			whenURL:     "/users/new2",
			expectRoute: "/users/:name",
			expectParam: map[string]string{"name": "new2"},
		},
		{
			whenURL:     "/",
			expectRoute: "/",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.whenURL, func(t *testing.T) {
			ctx := e.NewContext()
			ctx.Request.SetRequestURI(tc.whenURL)
			ctx.Request.Header.SetMethod(consts.MethodGet)
			e.ServeHTTP(context.Background(), ctx)
			assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path"))
			for param, expectedValue := range tc.expectParam {
				assert.DeepEqual(t, expectedValue, ctx.Param(param))
			}
			checkUnusedParamValues(t, ctx, tc.expectParam)
		})
	}
}

// TestRouterParamBacktraceNotFound checks that after descending past
// /:param1 into a static child that does not match (/a/bbbbb), the router
// backtracks and reports no route instead of a wrong partial match.
func TestRouterParamBacktraceNotFound(t *testing.T) {
	e := NewEngine(config.NewOptions(nil))

	// Add
	e.addRoute(consts.MethodGet, "/:param1", handlerFunc("/:param1"))
	e.addRoute(consts.MethodGet, "/:param1/foo", handlerFunc("/:param1/foo"))
	e.addRoute(consts.MethodGet, "/:param1/bar", handlerFunc("/:param1/bar"))
	e.addRoute(consts.MethodGet, "/:param1/bar/:param2", handlerFunc("/:param1/bar/:param2"))

	testCases := []struct {
		name        string
		whenMethod  string
		whenURL     string
		expectRoute interface{}
		expectParam map[string]string
		expectError error
	}{
		{
			name:        "route /a to /:param1",
			whenURL:     "/a",
			expectRoute: "/:param1",
			expectParam: map[string]string{"param1": "a"},
		},
		{
			name:        "route /a/foo to /:param1/foo",
			whenURL:     "/a/foo",
			expectRoute: "/:param1/foo",
			expectParam: map[string]string{"param1": "a"},
		},
		{
			name:        "route /a/bar to /:param1/bar",
			whenURL:     "/a/bar",
			expectRoute: "/:param1/bar",
			expectParam: map[string]string{"param1": "a"},
		},
		{
			name:        "route /a/bar/b to /:param1/bar/:param2",
			whenURL:     "/a/bar/b",
			expectRoute: "/:param1/bar/:param2",
			expectParam: map[string]string{
				"param1": "a",
				"param2": "b",
			},
		},
		{
			name:        "route /a/bbbbb should return 404",
			whenURL:     "/a/bbbbb",
			expectRoute: nil,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ctx := e.NewContext()
			ctx.Request.SetRequestURI(tc.whenURL)
			ctx.Request.Header.SetMethod(consts.MethodGet)
			e.ServeHTTP(context.Background(), ctx)
			assert.DeepEqual(t, tc.expectRoute, getHelper(ctx, "path"))
			for param, expectedValue := range tc.expectParam {
				assert.DeepEqual(t, expectedValue, ctx.Param(param))
			}
			checkUnusedParamValues(t, ctx, tc.expectParam)
		})
	}
}

================================================ FILE: pkg/route/routes_timing_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * The MIT License (MIT) * * Copyright (c) 2014 Manuel Martínez-Almeida * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */

package route

import (
	"context"
	"testing"

	"github.com/cloudwego/hertz/pkg/app"
	"github.com/cloudwego/hertz/pkg/common/config"
	"github.com/cloudwego/hertz/pkg/protocol"
)

// Route holds a single path pattern used to populate the benchmark trees.
type Route struct {
	path string
}

// BenchmarkTree_FindStatic builds a GET tree containing only static paths and
// measures looking up every registered path once per iteration (unescape off).
func BenchmarkTree_FindStatic(b *testing.B) {
	tree := &router{method: "GET", root: &node{}}

	static := []*Route{
		{"/"},
		{"/cmd.html"},
		{"/code.html"},
		{"/contrib.html"},
		{"/contribute.html"},
		{"/debugging_with_gdb.html"},
		{"/docs.html"},
		{"/effective_go.html"},
		{"/files.log"},
		{"/gccgo_contribute.html"},
		{"/gccgo_install.html"},
		{"/go-logo-black.png"},
		{"/go-logo-blue.png"},
		{"/go-logo-white.png"},
		{"/go1.1.html"},
		{"/go1.2.html"},
		{"/go1.html"},
		{"/go1compat.html"},
		{"/go_faq.html"},
		{"/go_mem.html"},
		{"/go_spec.html"},
		{"/help.html"},
		{"/ie.css"},
		{"/install-source.html"},
		{"/install.html"},
		{"/logo-153x55.png"},
		{"/Makefile"},
		{"/root.html"},
		{"/share.png"},
		{"/sieve.gif"},
		{"/tos.html"},
		{"/articles/"},
		{"/articles/go_command.html"},
		{"/articles/index.html"},
		{"/articles/wiki/"},
		{"/articles/wiki/edit.html"},
		{"/articles/wiki/final-noclosure.go"},
		{"/articles/wiki/final-noerror.go"},
		{"/articles/wiki/final-parsetemplate.go"},
		{"/articles/wiki/final-template.go"},
		{"/articles/wiki/final.go"},
		{"/articles/wiki/get.go"},
		{"/articles/wiki/http-sample.go"},
		{"/articles/wiki/index.html"},
		{"/articles/wiki/Makefile"},
		{"/articles/wiki/notemplate.go"},
		{"/articles/wiki/part1-noerror.go"},
		{"/articles/wiki/part1.go"},
		{"/articles/wiki/part2.go"},
		{"/articles/wiki/part3-errorhandling.go"},
		{"/articles/wiki/part3.go"},
		{"/articles/wiki/test.bash"},
		{"/articles/wiki/test_edit.good"},
		{"/articles/wiki/test_Test.txt.good"},
		{"/articles/wiki/test_view.good"},
		{"/articles/wiki/view.html"},
		{"/codewalk/"},
		{"/codewalk/codewalk.css"},
		{"/codewalk/codewalk.js"},
		{"/codewalk/codewalk.xml"},
		{"/codewalk/functions.xml"},
		{"/codewalk/markov.go"},
		{"/codewalk/markov.xml"},
		{"/codewalk/pig.go"},
		{"/codewalk/popout.png"},
		{"/codewalk/run"},
		{"/codewalk/sharemem.xml"},
		{"/codewalk/urlpoll.go"},
		{"/devel/"},
		{"/devel/release.html"},
		{"/devel/weekly.html"},
		{"/gopher/"},
		{"/gopher/appenginegopher.jpg"},
		{"/gopher/appenginegophercolor.jpg"},
		{"/gopher/appenginelogo.gif"},
		{"/gopher/bumper.png"},
		{"/gopher/bumper192x108.png"},
		{"/gopher/bumper320x180.png"},
		{"/gopher/bumper480x270.png"},
		{"/gopher/bumper640x360.png"},
		{"/gopher/doc.png"},
		{"/gopher/frontpage.png"},
		{"/gopher/gopherbw.png"},
		{"/gopher/gophercolor.png"},
		{"/gopher/gophercolor16x16.png"},
		{"/gopher/help.png"},
		{"/gopher/pkg.png"},
		{"/gopher/project.png"},
		{"/gopher/ref.png"},
		{"/gopher/run.png"},
		{"/gopher/talks.png"},
		{"/gopher/pencil/"},
		{"/gopher/pencil/gopherhat.jpg"},
		{"/gopher/pencil/gopherhelmet.jpg"},
		{"/gopher/pencil/gophermega.jpg"},
		{"/gopher/pencil/gopherrunning.jpg"},
		{"/gopher/pencil/gopherswim.jpg"},
		{"/gopher/pencil/gopherswrench.jpg"},
		{"/play/"},
		{"/play/fib.go"},
		{"/play/hello.go"},
		{"/play/life.go"},
		{"/play/peano.go"},
		{"/play/pi.go"},
		{"/play/sieve.go"},
		{"/play/solitaire.go"},
		{"/play/tree.go"},
		{"/progs/"},
		{"/progs/cgo1.go"},
		{"/progs/cgo2.go"},
		{"/progs/cgo3.go"},
		{"/progs/cgo4.go"},
		{"/progs/defer.go"},
		{"/progs/defer.out"},
		{"/progs/defer2.go"},
		{"/progs/defer2.out"},
		{"/progs/eff_bytesize.go"},
		{"/progs/eff_bytesize.out"},
		{"/progs/eff_qr.go"},
		{"/progs/eff_sequence.go"},
		{"/progs/eff_sequence.out"},
		{"/progs/eff_unused1.go"},
		{"/progs/eff_unused2.go"},
		{"/progs/error.go"},
		{"/progs/error2.go"},
		{"/progs/error3.go"},
		{"/progs/error4.go"},
		{"/progs/go1.go"},
		{"/progs/gobs1.go"},
		{"/progs/gobs2.go"},
		{"/progs/image_draw.go"},
		{"/progs/image_package1.go"},
		{"/progs/image_package1.out"},
		{"/progs/image_package2.go"},
		{"/progs/image_package2.out"},
		{"/progs/image_package3.go"},
		{"/progs/image_package3.out"},
		{"/progs/image_package4.go"},
		{"/progs/image_package4.out"},
		{"/progs/image_package5.go"},
		{"/progs/image_package5.out"},
		{"/progs/image_package6.go"},
		{"/progs/image_package6.out"},
		{"/progs/interface.go"},
		{"/progs/interface2.go"},
		{"/progs/interface2.out"},
		{"/progs/json1.go"},
		{"/progs/json2.go"},
		{"/progs/json2.out"},
		{"/progs/json3.go"},
		{"/progs/json4.go"},
		{"/progs/json5.go"},
		{"/progs/run"},
		{"/progs/slices.go"},
		{"/progs/timeout1.go"},
		{"/progs/timeout2.go"},
		{"/progs/update.bash"},
	}
	for _, route := range static {
		tree.addRoute(route.path, fakeHandler(route.path))
	}

	ps := getParams()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, request := range static {
			tree.find(request.path, ps, false)
		}
	}
}

// BenchmarkTree_FindGithub builds a GET tree from a GitHub-API-like mix of
// static and param routes and measures looking up each registered pattern
// (the ':name' patterns themselves are used as request paths). Commented-out
// entries are alternate-method or duplicate routes that are not registered.
func BenchmarkTree_FindGithub(b *testing.B) {
	tree := &router{method: "GET", root: &node{}}

	static := []*Route{
		// OAuth Authorizations
		{"/authorizations"},
		{"/authorizations/:id"},
		//{"/authorizations"},
		//{"/authorizations/clients/:client_id"},
		//{"/authorizations/:id"},
		//{"/authorizations/:id"},
		{"/applications/:client_id/tokens/:access_token"},
		{"/applications/:client_id/tokens"},
		//{"/applications/:client_id/tokens/:access_token"},

		// Activity
		{"/events"},
		{"/repos/:owner/:repo/events"},
		{"/networks/:owner/:repo/events"},
		{"/orgs/:org/events"},
		{"/users/:user/received_events"},
		{"/users/:user/received_events/public"},
		{"/users/:user/events"},
		{"/users/:user/events/public"},
		{"/users/:user/events/orgs/:org"},
		{"/feeds"},
		//{"/notifications"},
		{"/repos/:owner/:repo/notifications"},
		{"/notifications"},
		//{"/repos/:owner/:repo/notifications"},
		{"/notifications/threads/:id"},
		//{"/notifications/threads/:id"},
		{"/notifications/threads/:id/subscription"},
		//{"/notifications/threads/:id/subscription"},
		//{"/notifications/threads/:id/subscription"},
		{"/repos/:owner/:repo/stargazers"},
		{"/users/:user/starred"},
		{"/user/starred"},
		{"/user/starred/:owner/:repo"},
		//{"/user/starred/:owner/:repo"},
		//{"/user/starred/:owner/:repo"},
		{"/repos/:owner/:repo/subscribers"},
		{"/users/:user/subscriptions"},
		{"/user/subscriptions"},
		{"/repos/:owner/:repo/subscription"},
		//{"/repos/:owner/:repo/subscription"},
		//{"/repos/:owner/:repo/subscription"},
		{"/user/subscriptions/:owner/:repo"},
		//{"PUT", "/user/subscriptions/:owner/:repo"},
		//{"DELETE", "/user/subscriptions/:owner/:repo"},

		// Gists
		{"/users/:user/gists"},
		{"/gists"},
		//{"GET", "/gists/public"},
		//{"GET", "/gists/starred"},
		{"/gists/:id"},
		//{"POST", "/gists"},
		//{"PATCH", "/gists/:id"},
		{"/gists/:id/star"},
		//{"DELETE", "/gists/:id/star"},
		//{"GET", "/gists/:id/star"},
		{"/gists/:id/forks"},
		//{"DELETE", "/gists/:id"},

		// Git Data
		{"/repos/:owner/:repo/git/blobs/:sha"},
		{"/repos/:owner/:repo/git/blobs"},
		{"/repos/:owner/:repo/git/commits/:sha"},
		{"/repos/:owner/:repo/git/commits"},
		//{"GET", "/repos/:owner/:repo/git/refs/*ref"},
		{"/repos/:owner/:repo/git/refs"},
		//{"POST", "/repos/:owner/:repo/git/refs"},
		//{"PATCH", "/repos/:owner/:repo/git/refs/*ref"},
		//{"DELETE", "/repos/:owner/:repo/git/refs/*ref"},
		{"/repos/:owner/:repo/git/tags/:sha"},
		{"/repos/:owner/:repo/git/tags"},
		{"/repos/:owner/:repo/git/trees/:sha"},
		{"/repos/:owner/:repo/git/trees"},
		{"/issues"},
		{"/user/issues"},
		{"/orgs/:org/issues"},
		{"/repos/:owner/:repo/issues"},
		{"/repos/:owner/:repo/issues/:number"},
		//{"POST", "/repos/:owner/:repo/issues"},
		//{"PATCH", "/repos/:owner/:repo/issues/:number"},
		{"/repos/:owner/:repo/assignees"},
		{"/repos/:owner/:repo/assignees/:assignee"},
		{"/repos/:owner/:repo/issues/:number/comments"},
		//{"GET", "/repos/:owner/:repo/issues/comments"},
		//{"GET", "/repos/:owner/:repo/issues/comments/:id"},
		//{"POST", "/repos/:owner/:repo/issues/:number/comments"},
		//{"PATCH", "/repos/:owner/:repo/issues/comments/:id"},
		//{"DELETE", "/repos/:owner/:repo/issues/comments/:id"},
		{"/repos/:owner/:repo/issues/:number/events"},
		//{"GET", "/repos/:owner/:repo/issues/events"},
		//{"GET", "/repos/:owner/:repo/issues/events/:id"},
		{"/repos/:owner/:repo/labels"},
		{"/repos/:owner/:repo/labels/:name"},
		//{"POST", "/repos/:owner/:repo/labels"},
		//{"PATCH", "/repos/:owner/:repo/labels/:name"},
		//{"DELETE", "/repos/:owner/:repo/labels/:name"},
		{"/repos/:owner/:repo/issues/:number/labels"},
		//{"POST", "/repos/:owner/:repo/issues/:number/labels"},
		//{"DELETE", "/repos/:owner/:repo/issues/:number/labels/:name"},
		//{"PUT", "/repos/:owner/:repo/issues/:number/labels"},
		//{"DELETE", "/repos/:owner/:repo/issues/:number/labels"},
		{"/repos/:owner/:repo/milestones/:number/labels"},
		{"/repos/:owner/:repo/milestones"},
		{"/repos/:owner/:repo/milestones/:number"},
		//{"POST", "/repos/:owner/:repo/milestones"},
		//{"PATCH", "/repos/:owner/:repo/milestones/:number"},
		//{"DELETE", "/repos/:owner/:repo/milestones/:number"},

		// Miscellaneous
		{"/emojis"},
		{"/gitignore/templates"},
		{"/gitignore/templates/:name"},
		{"/markdown"},
		{"/markdown/raw"},
		{"/meta"},
		{"/rate_limit"},

		// Organizations
		{"/users/:user/orgs"},
		{"/user/orgs"},
		{"/orgs/:org"},
		//{"PATCH", "/orgs/:org"},
		{"/orgs/:org/members"},
		{"/orgs/:org/members/:user"},
		//{"DELETE", "/orgs/:org/members/:user"},
		{"/orgs/:org/public_members"},
		{"/orgs/:org/public_members/:user"},
		//{"PUT", "/orgs/:org/public_members/:user"},
		//{"DELETE", "/orgs/:org/public_members/:user"},
		{"/orgs/:org/teams"},
		{"/teams/:id"},
		//{"POST", "/orgs/:org/teams"},
		//{"PATCH", "/teams/:id"},
		//{"DELETE", "/teams/:id"},
		{"/teams/:id/members"},
		{"/teams/:id/members/:user"},
		//{"PUT", "/teams/:id/members/:user"},
		//{"DELETE", "/teams/:id/members/:user"},
		{"/teams/:id/repos"},
		{"/teams/:id/repos/:owner/:repo"},
		//{"PUT", "/teams/:id/repos/:owner/:repo"},
		//{"DELETE", "/teams/:id/repos/:owner/:repo"},
		{"/user/teams"},

		// Pull Requests
		{"/repos/:owner/:repo/pulls"},
		{"/repos/:owner/:repo/pulls/:number"},
		//{"POST", "/repos/:owner/:repo/pulls"},
		//{"PATCH", "/repos/:owner/:repo/pulls/:number"},
		{"/repos/:owner/:repo/pulls/:number/commits"},
		{"/repos/:owner/:repo/pulls/:number/files"},
		{"/repos/:owner/:repo/pulls/:number/merge"},
		//{"PUT", "/repos/:owner/:repo/pulls/:number/merge"},
		{"/repos/:owner/:repo/pulls/:number/comments"},
		//{"GET", "/repos/:owner/:repo/pulls/comments"},
		//{"GET", "/repos/:owner/:repo/pulls/comments/:number"},
		//{"PUT", "/repos/:owner/:repo/pulls/:number/comments"},
		//{"PATCH", "/repos/:owner/:repo/pulls/comments/:number"},
		//{"DELETE", "/repos/:owner/:repo/pulls/comments/:number"},

		// Repositories
		{"/user/repos"},
		{"/users/:user/repos"},
		{"/orgs/:org/repos"},
		{"/repositories"},
		//{"POST", "/user/repos"},
		//{"POST", "/orgs/:org/repos"},
		{"/repos/:owner/:repo"},
		//{"PATCH", "/repos/:owner/:repo"},
		{"/repos/:owner/:repo/contributors"},
		{"/repos/:owner/:repo/languages"},
		{"/repos/:owner/:repo/teams"},
		{"/repos/:owner/:repo/tags"},
		{"/repos/:owner/:repo/branches"},
		{"/repos/:owner/:repo/branches/:branch"},
		//{"DELETE", "/repos/:owner/:repo"},
		{"/repos/:owner/:repo/collaborators"},
		{"/repos/:owner/:repo/collaborators/:user"},
		//{"PUT", "/repos/:owner/:repo/collaborators/:user"},
		//{"DELETE", "/repos/:owner/:repo/collaborators/:user"},
		{"/repos/:owner/:repo/comments"},
		{"/repos/:owner/:repo/commits/:sha/comments"},
		//{"POST", "/repos/:owner/:repo/commits/:sha/comments"},
		{"/repos/:owner/:repo/comments/:id"},
		//{"PATCH", "/repos/:owner/:repo/comments/:id"},
		//{"DELETE", "/repos/:owner/:repo/comments/:id"},
		{"/repos/:owner/:repo/commits"},
		{"/repos/:owner/:repo/commits/:sha"},
		{"/repos/:owner/:repo/readme"},
		//{"GET", "/repos/:owner/:repo/contents/*path"},
		//{"PUT", "/repos/:owner/:repo/contents/*path"},
		//{"DELETE", "/repos/:owner/:repo/contents/*path"},
		//{"GET", "/repos/:owner/:repo/:archive_format/:ref"},
		{"/repos/:owner/:repo/keys"},
		{"/repos/:owner/:repo/keys/:id"},
		//{"POST", "/repos/:owner/:repo/keys"},
		//{"PATCH", "/repos/:owner/:repo/keys/:id"},
		//{"DELETE", "/repos/:owner/:repo/keys/:id"},
		{"/repos/:owner/:repo/downloads"},
		{"/repos/:owner/:repo/downloads/:id"},
		//{"DELETE", "/repos/:owner/:repo/downloads/:id"},
		{"/repos/:owner/:repo/forks"},
		//{"POST", "/repos/:owner/:repo/forks"},
		{"/repos/:owner/:repo/hooks"},
		{"/repos/:owner/:repo/hooks/:id"},
		//{"POST", "/repos/:owner/:repo/hooks"},
		//{"PATCH", "/repos/:owner/:repo/hooks/:id"},
		//{"POST", "/repos/:owner/:repo/hooks/:id/tests"},
		//{"DELETE", "/repos/:owner/:repo/hooks/:id"},
		//{"POST", "/repos/:owner/:repo/merges"},
		{"/repos/:owner/:repo/releases"},
		{"/repos/:owner/:repo/releases/:id"},
		//{"POST", "/repos/:owner/:repo/releases"},
		//{"PATCH", "/repos/:owner/:repo/releases/:id"},
		//{"DELETE", "/repos/:owner/:repo/releases/:id"},
		{"/repos/:owner/:repo/releases/:id/assets"},
		{"/repos/:owner/:repo/stats/contributors"},
		{"/repos/:owner/:repo/stats/commit_activity"},
		{"/repos/:owner/:repo/stats/code_frequency"},
		{"/repos/:owner/:repo/stats/participation"},
		{"/repos/:owner/:repo/stats/punch_card"},
		{"/repos/:owner/:repo/statuses/:ref"},
		//{"POST", "/repos/:owner/:repo/statuses/:ref"},

		// Search
		{"/search/repositories"},
		{"/search/code"},
		{"/search/issues"},
		{"/search/users"},
		{"/legacy/issues/search/:owner/:repository/:state/:keyword"},
		{"/legacy/repos/search/:keyword"},
		{"/legacy/user/search/:keyword"},
		{"/legacy/user/email/:email"},

		// Users
		{"/users/:user"},
		{"/user"},
		//{"PATCH", "/user"},
		{"/users"},
		{"/user/emails"},
		//{"POST", "/user/emails"},
		//{"DELETE", "/user/emails"},
		{"/users/:user/followers"},
		{"/user/followers"},
		{"/users/:user/following"},
		{"/user/following"},
		{"/user/following/:user"},
		{"/users/:user/following/:target_user"},
		//{"PUT", "/user/following/:user"},
		//{"DELETE", "/user/following/:user"},
		{"/users/:user/keys"},
		{"/user/keys"},
		{"/user/keys/:id"},
		//{"POST", "/user/keys"},
		//{"PATCH", "/user/keys/:id"},
		//{"DELETE", "/user/keys/:id"},
	}
	for _, route := range static {
		tree.addRoute(route.path, fakeHandler(route.path))
	}

	ps := getParams()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, request := range static {
			tree.find(request.path, ps, false)
		}
	}
}

// BenchmarkTree_FindStaticTsr measures a lookup that misses only by a
// trailing slash (route registered with one, request without). Only
// request.path is used by the loop; the other testRequests fields are ignored.
func BenchmarkTree_FindStaticTsr(b *testing.B) {
	tree := &router{method: "GET", root: &node{}}

	routes := [...]string{
		"/doc/foo/go_faq.html/",
	}
	for _, route := range routes {
		tree.addRoute(route, fakeHandler(route))
	}

	tr := testRequests{
		{"/doc/foo/go_faq.html", false, "/doc/foo/go_faq.html", nil},
	}

	ps := getParams()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, request := range tr {
			tree.find(request.path, ps, false)
		}
	}
}

// BenchmarkTree_FindParam measures matching a route with two param segments.
// Only request.path is used; the other testRequests fields are ignored.
func BenchmarkTree_FindParam(b *testing.B) {
	tree := &router{method: "GET", root: &node{}}

	routes := [...]string{
		"/hi/:key1/foo/:key2",
	}
	for _, route := range routes {
		tree.addRoute(route, fakeHandler(route))
	}

	tr := testRequests{
		{"/hi/1/foo/2", false, "/hi", nil},
	}

	ps := getParams()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, request := range tr {
			tree.find(request.path, ps, false)
		}
	}
}

// BenchmarkTree_FindParamTsr is BenchmarkTree_FindParam against a route that
// was registered with a trailing slash, so the request misses by that slash.
func BenchmarkTree_FindParamTsr(b *testing.B) {
	tree := &router{method: "GET", root: &node{}}

	routes := [...]string{
		"/hi/:key1/foo/:key2/",
	}
	for _, route := range routes {
		tree.addRoute(route, fakeHandler(route))
	}

	tr := testRequests{
		{"/hi/1/foo/2", false, "/hi", nil},
	}

	ps := getParams()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, request := range tr {
			tree.find(request.path, ps, false)
		}
	}
}

// BenchmarkTree_FindAny measures matching a single catch-all route.
func BenchmarkTree_FindAny(b *testing.B) {
	tree := &router{method: "GET", root: &node{}}

	routes := [...]string{
		"/hi/*key1",
	}
	for _, route := range routes {
		tree.addRoute(route, fakeHandler(route))
	}

	tr := testRequests{
		{"/hi/foo", false, "/hi", nil},
	}

	ps := getParams()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, request := range tr {
			tree.find(request.path, ps, false)
		}
	}
}

// BenchmarkTree_FindAnyFallback measures the worst case where the request
// walks a long static prefix (/hi/a/b/c/d/), fails at the last segment and has
// to backtrack all the way to the root-level catch-all /*key2.
func BenchmarkTree_FindAnyFallback(b *testing.B) {
	tree := &router{method: "GET", root: &node{}}

	routes := [...]string{
		"/hi/a/b/c/d/e/*key1",
		"/*key2",
	}
	for _, route := range routes {
		tree.addRoute(route, fakeHandler(route))
	}

	tr := testRequests{
		{"/hi/a/b/c/d/f", false, "/*key2", nil},
	}

	ps := getParams()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, request := range tr {
			tree.find(request.path, ps, false)
		}
	}
}

// BenchmarkRouteStatic measures a full Engine.ServeHTTP dispatch to a static
// GET route. The same ctx is reused across iterations without being reset
// (the index reset is commented out in the loop).
func BenchmarkRouteStatic(b *testing.B) {
	r := NewEngine(config.NewOptions(nil))
	r.GET("/hi/foo", func(c context.Context, ctx *app.RequestContext) {})
	ctx := r.NewContext()
	req := protocol.NewRequest("GET", "/hi/foo", nil)
	req.CopyTo(&ctx.Request)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		r.ServeHTTP(context.Background(), ctx)
		// ctx.index = -1
	}
}

// BenchmarkRouteParam is BenchmarkRouteStatic for a single-param route.
func BenchmarkRouteParam(b *testing.B) {
	r := NewEngine(config.NewOptions(nil))
	r.GET("/hi/:user", func(c context.Context, ctx *app.RequestContext) {})
	ctx := r.NewContext()
	req := protocol.NewRequest("GET", "/hi/foo", nil)
	req.CopyTo(&ctx.Request)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		r.ServeHTTP(context.Background(), ctx)
		// ctx.index = -1
	}
}

// BenchmarkRouteAny is BenchmarkRouteStatic for a catch-all route.
func BenchmarkRouteAny(b *testing.B) {
	r := NewEngine(config.NewOptions(nil))
	r.GET("/hi/*user", func(c context.Context, ctx *app.RequestContext) {})
	ctx := r.NewContext()
	req := protocol.NewRequest("GET", "/hi/foo/dy", nil)
	req.CopyTo(&ctx.Request)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		r.ServeHTTP(context.Background(), ctx)
		// ctx.index = -1
	}
}

================================================ FILE: pkg/route/tree.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License.
* * The MIT License (MIT) * * Copyright (c) 2021 LabStack * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell * copies of the Software, and to permit persons to whom the Software is * furnished to do so, subject to the following conditions: * * The above copyright notice and this permission notice shall be included in * all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. 
*/

package route

import (
	"bytes"
	"fmt"
	"net/url"
	"strings"
	"unicode"

	"github.com/cloudwego/hertz/internal/bytesconv"
	"github.com/cloudwego/hertz/internal/bytestr"
	"github.com/cloudwego/hertz/pkg/app"
	"github.com/cloudwego/hertz/pkg/route/param"
)

// router is the radix routing tree for a single HTTP method.
type router struct {
	method string
	root   *node
}

// MethodTrees holds one routing tree per HTTP method.
type MethodTrees []*router

// get returns the tree registered for method, or nil if there is none.
// It is a linear scan over the trees.
func (trees MethodTrees) get(method string) *router {
	for _, tree := range trees {
		if tree.method == method {
			return tree
		}
	}
	return nil
}

// countParams counts the occurrences of bytestr.StrColon and bytestr.StrStar
// in path (presumably ":" and "*", i.e. the param and catch-all markers) and
// returns their sum.
func countParams(path string) uint16 {
	var n uint16
	s := bytesconv.S2b(path)
	n += uint16(bytes.Count(s, bytestr.StrColon))
	n += uint16(bytes.Count(s, bytestr.StrStar))
	return n
}

type (
	node struct {
		// kind is skind, pkind or akind.
		kind kind
		// label is the first byte of prefix; insert keeps it in sync.
		label byte
		// prefix is the path fragment this node matches; param and any nodes
		// are inserted with the single-byte ":" / "*" prefix by addRoute.
		prefix string
		parent *node
		// children holds static children only; param/any children live in
		// paramChild/anyChild.
		children children
		// original path
		ppath string
		// param names
		pnames     []string
		handlers   app.HandlersChain
		paramChild *node
		anyChild   *node
		// isLeaf indicates that node does not have child routes
		isLeaf bool
	}
	kind     uint8
	children []*node
)

const (
	// static kind
	// NOTE: find's backtracking computes the next kind to try as
	// previous.kind+1, so the order skind < pkind < akind is significant.
	skind kind = iota
	// param kind
	pkind
	// all kind
	akind

	paramLabel = byte(':')
	anyLabel   = byte('*')
	slash      = "/"
	nilString  = ""
)

// checkPathValid panics if path is not a valid route pattern: it must be
// non-empty and start with '/', every ':' param must have a non-empty name
// and be the only wildcard in its segment, and a '*' catch-all must be named,
// be preceded by '/', and be the last segment of the path.
func checkPathValid(path string) {
	if path == nilString {
		panic("empty path")
	}
	if path[0] != '/' {
		panic("path must begin with '/'")
	}
	// NOTE: the i++ and inner loops below only advance the local copy of the
	// range variable; the outer range loop still visits every byte. The inner
	// loops exist solely to scan ahead for violations.
	for i, c := range []byte(path) {
		switch c {
		case ':':
			if (i < len(path)-1 && path[i+1] == '/') || i == (len(path)-1) {
				panic("wildcards must be named with a non-empty name in path '" + path + "'")
			}
			i++
			// Scan the rest of this segment for a second wildcard.
			for ; i < len(path) && path[i] != '/'; i++ {
				if path[i] == ':' || path[i] == '*' {
					panic("only one wildcard per path segment is allowed, find multi in path '" + path + "'")
				}
			}
		case '*':
			if i == len(path)-1 {
				panic("wildcards must be named with a non-empty name in path '" + path + "'")
			}
			if i > 0 && path[i-1] != '/' {
				panic(" no / before wildcards in path " + path)
			}
			// Nothing may follow the catch-all name.
			for ; i < len(path); i++ {
				if path[i] == '/' {
					panic("catch-all routes are only allowed at the end of the path in path '" + path + "'")
				}
			}
		}
	}
}

// addRoute adds a node with the given handle to the path.
//
// It validates path (panicking on invalid patterns or a nil handler chain),
// then walks the pattern left to right. Every static prefix in front of a
// wildcard is inserted as a handler-less static node; each ":name" segment is
// collapsed to ":" in the working path and its name appended to pnames; a
// "*name" catch-all ends the walk. The final node receives h, the original
// pattern (ppath) and the collected param names.
func (r *router) addRoute(path string, h app.HandlersChain) {
	checkPathValid(path)

	var (
		pnames []string // Param names
		ppath  = path   // Pristine path
	)

	if h == nil {
		panic(fmt.Sprintf("Adding route without handler function: %v", path))
	}

	// Add the front static route part of a non-static route
	for i, lcpIndex := 0, len(path); i < lcpIndex; i++ {
		// param route
		if path[i] == paramLabel {
			j := i + 1

			r.insert(path[:i], nil, skind, nilString, nil)
			// Advance i to the end of the param name (next '/' or end of path).
			for ; i < lcpIndex && path[i] != '/'; i++ {
			}

			pnames = append(pnames, path[j:i])
			// Drop the name from the working path, e.g. "/users/:id/x" -> "/users/:/x",
			// and resume scanning right after the ':'.
			path = path[:j] + path[i:]
			i, lcpIndex = j, len(path)

			if i == lcpIndex {
				// path node is last fragment of route path. ie. `/users/:id`
				r.insert(path[:i], h, pkind, ppath, pnames)
				return
			} else {
				r.insert(path[:i], nil, pkind, nilString, pnames)
			}
		} else if path[i] == anyLabel {
			r.insert(path[:i], nil, skind, nilString, nil)
			// Everything after '*' is the catch-all name.
			pnames = append(pnames, path[i+1:])
			r.insert(path[:i+1], h, akind, ppath, pnames)
			return
		}
	}

	r.insert(path, h, skind, ppath, pnames)
}

// insert walks the tree from the root, following the longest common prefix
// (LCP) between the remaining search path and each node's prefix, and either
// fills an empty root, splits a node whose prefix only partially matches,
// descends into a child, appends a new child, or updates an existing node.
// Handlers, ppath and pnames are only attached when h is non-nil, except in
// the split case where the parent is the exact target. It panics if the root
// is nil or if handlers are registered twice for the same node.
func (r *router) insert(path string, h app.HandlersChain, t kind, ppath string, pnames []string) {
	currentNode := r.root
	if currentNode == nil {
		panic("hertz: invalid node")
	}
	search := path

	for {
		searchLen := len(search)
		prefixLen := len(currentNode.prefix)
		lcpLen := 0

		// NOTE: `max` shadows the predeclared max (Go 1.21+); harmless here
		// because the builtin is not used in this function.
		max := prefixLen
		if searchLen < max {
			max = searchLen
		}
		// Compute the length of the longest common prefix.
		for ; lcpLen < max && search[lcpLen] == currentNode.prefix[lcpLen]; lcpLen++ {
		}

		if lcpLen == 0 {
			// At root node
			currentNode.label = search[0]
			currentNode.prefix = search
			if h != nil {
				currentNode.kind = t
				currentNode.handlers = h
				currentNode.ppath = ppath
				currentNode.pnames = pnames
			}
			currentNode.isLeaf = currentNode.children == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
		} else if lcpLen < prefixLen {
			// Split node: move the non-shared tail of the prefix (with all of
			// this node's data and children) into a new child node.
			n := newNode(
				currentNode.kind,
				currentNode.prefix[lcpLen:],
				currentNode,
				currentNode.children,
				currentNode.handlers,
				currentNode.ppath,
				currentNode.pnames,
				currentNode.paramChild,
				currentNode.anyChild,
			)
			// Update parent path for all children to new node
			for _, child := range currentNode.children {
				child.parent = n
			}
			if currentNode.paramChild != nil {
				currentNode.paramChild.parent = n
			}
			if currentNode.anyChild != nil {
				currentNode.anyChild.parent = n
			}

			// Reset parent node
			currentNode.kind = skind
			currentNode.label = currentNode.prefix[0]
			currentNode.prefix = currentNode.prefix[:lcpLen]
			currentNode.children = nil
			currentNode.handlers = nil
			currentNode.ppath = nilString
			currentNode.pnames = nil
			currentNode.paramChild = nil
			currentNode.anyChild = nil
			currentNode.isLeaf = false

			// Only Static children could reach here
			currentNode.children = append(currentNode.children, n)

			if lcpLen == searchLen {
				// At parent node
				currentNode.kind = t
				currentNode.handlers = h
				currentNode.ppath = ppath
				currentNode.pnames = pnames
			} else {
				// Create child node
				n = newNode(t, search[lcpLen:], currentNode, nil, h, ppath, pnames, nil, nil)
				// Only Static children could reach here
				currentNode.children = append(currentNode.children, n)
			}
			currentNode.isLeaf = currentNode.children == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
		} else if lcpLen < searchLen {
			// The whole prefix matched; continue with the rest of the path.
			search = search[lcpLen:]
			c := currentNode.findChildWithLabel(search[0])
			if c != nil {
				// Go deeper
				currentNode = c
				continue
			}
			// Create child node
			n := newNode(t, search, currentNode, nil, h, ppath, pnames, nil, nil)
			switch t {
			case skind:
				currentNode.children = append(currentNode.children, n)
			case pkind:
				currentNode.paramChild = n
			case akind:
				currentNode.anyChild = n
			}
			currentNode.isLeaf = currentNode.children == nil && currentNode.paramChild == nil && currentNode.anyChild == nil
		} else {
			// Node already exists
			if currentNode.handlers != nil && h != nil {
				panic("handlers are already registered for path '" + ppath + "'")
			}

			if h != nil {
				currentNode.handlers = h
				currentNode.ppath = ppath
				currentNode.pnames = pnames
			}
		}
		return
	}
}

// find finds registered handler by method and path, parses URL params and puts params to context
func (r *router) find(path string, paramsPointer *param.Params, unescape bool) (res nodeValue) { var ( cn = r.root // current node search = path // current path searchIndex = 0 buf []byte paramIndex int ) backtrackToNextNodeKind := func(fromKind kind) (nextNodeKind kind, valid bool) { previous := cn cn = previous.parent valid = cn != nil // Next node type by priority if previous.kind == akind { nextNodeKind = skind } else { nextNodeKind = previous.kind + 1 } if fromKind == skind { // when backtracking is done from static kind block we did not change search so nothing to restore return } // restore search to value it was before we move to current node we are backtracking from. if previous.kind == skind { searchIndex -= len(previous.prefix) } else { paramIndex-- // for param/any node.prefix value is always `:` so we can not deduce searchIndex from that and must use pValue // for that index as it would also contain part of path we cut off before moving into node we are backtracking from searchIndex -= len((*paramsPointer)[paramIndex].Value) (*paramsPointer) = (*paramsPointer)[:paramIndex] } search = path[searchIndex:] return } // search order: static > param > any for { if cn.kind == skind { if len(search) >= len(cn.prefix) && cn.prefix == search[:len(cn.prefix)] { // Continue search search = search[len(cn.prefix):] searchIndex = searchIndex + len(cn.prefix) } else { // not equal if (len(cn.prefix) == len(search)+1) && (cn.prefix[len(search)]) == '/' && cn.prefix[:len(search)] == search && (cn.handlers != nil || cn.anyChild != nil) { res.tsr = true } // No matching prefix, let's backtrack to the first possible alternative node of the decision path nk, ok := backtrackToNextNodeKind(skind) if !ok { return // No other possibilities on the decision path } else if nk == pkind { goto Param } else { // Not found (this should never be possible for static node we are looking currently) break } } } if search == nilString && len(cn.handlers) != 0 { res.handlers = cn.handlers 
break } // Static node if search != nilString { // If it can execute that logic, there is handler registered on the current node and search is `/`. if search == "/" && cn.handlers != nil { res.tsr = true } if child := cn.findChild(search[0]); child != nil { cn = child continue } } if search == nilString { if cd := cn.findChild('/'); cd != nil && (cd.handlers != nil || cd.anyChild != nil) { res.tsr = true } } Param: // Param node if child := cn.paramChild; search != nilString && child != nil { cn = child i := strings.Index(search, slash) if i == -1 { i = len(search) } (*paramsPointer) = (*paramsPointer)[:(paramIndex + 1)] val := search[:i] if unescape { if v, err := url.QueryUnescape(search[:i]); err == nil { val = v } } (*paramsPointer)[paramIndex].Value = val paramIndex++ search = search[i:] searchIndex = searchIndex + i if search == nilString { if cd := cn.findChild('/'); cd != nil && (cd.handlers != nil || cd.anyChild != nil) { res.tsr = true } } continue } Any: // Any node if child := cn.anyChild; child != nil { // If any node is found, use remaining path for paramValues cn = child (*paramsPointer) = (*paramsPointer)[:(paramIndex + 1)] index := len(cn.pnames) - 1 val := search if unescape { if v, err := url.QueryUnescape(search); err == nil { val = v } } (*paramsPointer)[index].Value = bytesconv.B2s(append(buf, val...)) // update indexes/search in case we need to backtrack when no handler match is found paramIndex++ searchIndex += len(search) search = nilString res.handlers = cn.handlers break } // Let's backtrack to the first possible alternative node of the decision path nk, ok := backtrackToNextNodeKind(akind) if !ok { break // No other possibilities on the decision path } else if nk == pkind { goto Param } else if nk == akind { goto Any } else { // Not found break } } if cn != nil { res.fullPath = cn.ppath for i, name := range cn.pnames { (*paramsPointer)[i].Key = name } } return } func (n *node) findChild(l byte) *node { for _, c := range n.children { if 
c.label == l { return c } } return nil } func (n *node) findChildWithLabel(l byte) *node { for _, c := range n.children { if c.label == l { return c } } if l == paramLabel { return n.paramChild } if l == anyLabel { return n.anyChild } return nil } func newNode(t kind, pre string, p *node, child children, mh app.HandlersChain, ppath string, pnames []string, paramChildren, anyChildren *node) *node { return &node{ kind: t, label: pre[0], prefix: pre, parent: p, children: child, ppath: ppath, pnames: pnames, handlers: mh, paramChild: paramChildren, anyChild: anyChildren, isLeaf: child == nil && paramChildren == nil && anyChildren == nil, } } // nodeValue holds return values of (*Node).getValue method type nodeValue struct { handlers app.HandlersChain tsr bool fullPath string } // Makes a case-insensitive lookup of the given path and tries to find a handler. // It returns the case-corrected path and a bool indicating whether the lookup // was successful. func (n *node) findCaseInsensitivePath(path string, fixTrailingSlash bool) (ciPath []byte, found bool) { ciPath = make([]byte, 0, len(path)+1) // preallocate enough memory // Match paramKind. if n.label == paramLabel { end := 0 for end < len(path) && path[end] != '/' { end++ } ciPath = append(ciPath, path[:end]...) if end < len(path) { if len(n.children) > 0 { path = path[end:] goto loop } if fixTrailingSlash && len(path) == end+1 { return ciPath, true } return } if n.handlers != nil { return ciPath, true } if fixTrailingSlash && len(n.children) == 1 { // No handle found. Check if a handle for this path with(without) a trailing slash exists n = n.children[0] if n.prefix == "/" && n.handlers != nil { return append(ciPath, '/'), true } } return } // Match allKind. if n.label == anyLabel { return append(ciPath, path...), true } // Match static kind. if len(path) >= len(n.prefix) && strings.EqualFold(path[:len(n.prefix)], n.prefix) { path = path[len(n.prefix):] ciPath = append(ciPath, n.prefix...) 
if len(path) == 0 { if n.handlers != nil { return ciPath, true } // No handle found. // Try to fix the path by adding a trailing slash. if fixTrailingSlash { for i := 0; i < len(n.children); i++ { if n.children[i].label == '/' { n = n.children[i] if (len(n.prefix) == 1 && n.handlers != nil) || (n.prefix == "*" && n.children[0].handlers != nil) { return append(ciPath, '/'), true } return } } } return } } else if fixTrailingSlash { // Nothing found. // Try to fix the path by adding / removing a trailing slash. if path == "/" { return ciPath, true } if len(path)+1 == len(n.prefix) && n.prefix[len(path)] == '/' && strings.EqualFold(path, n.prefix[:len(path)]) && n.handlers != nil { return append(ciPath, n.prefix...), true } } loop: // First match static kind. for _, node := range n.children { if unicode.ToLower(rune(path[0])) == unicode.ToLower(rune(node.label)) { out, found := node.findCaseInsensitivePath(path, fixTrailingSlash) if found { return append(ciPath, out...), true } } } if n.paramChild != nil { out, found := n.paramChild.findCaseInsensitivePath(path, fixTrailingSlash) if found { return append(ciPath, out...), true } } if n.anyChild != nil { out, found := n.anyChild.findCaseInsensitivePath(path, fixTrailingSlash) if found { return append(ciPath, out...), true } } // Nothing found. We can recommend to redirect to the same URL // without a trailing slash if a leaf exists for that path found = fixTrailingSlash && path == "/" && n.handlers != nil return } ================================================ FILE: pkg/route/tree_test.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * Copyright 2013 Julien Schmidt. All rights reserved. * Use of this source code is governed by a BSD-style license that can be found * at https://github.com/julienschmidt/httprouter/blob/master/LICENSE * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors. */ package route import ( "context" "fmt" "strings" "testing" "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/route/param" ) // Used as a workaround since we can't compare functions or their addresses var fakeHandlerValue string func fakeHandler(val string) app.HandlersChain { return app.HandlersChain{func(c context.Context, ctx *app.RequestContext) { fakeHandlerValue = val }} } type testRequests []struct { path string nilHandler bool route string ps param.Params } func getParams() *param.Params { ps := make(param.Params, 0, 20) return &ps } func checkRequests(t *testing.T, tree *router, requests testRequests, unescapes ...bool) { unescape := false if len(unescapes) >= 1 { unescape = unescapes[0] } for _, request := range requests { params := getParams() value := tree.find(request.path, params, unescape) if value.handlers == nil { if !request.nilHandler { t.Errorf("handle mismatch for route '%s': Expected non-nil handle", request.path) } } else if request.nilHandler { t.Errorf("handle mismatch for route '%s': Expected nil handle", request.path) } else { value.handlers[0](context.Background(), nil) if fakeHandlerValue != request.route { t.Errorf("handle mismatch for route '%s': Wrong handle (%s != %s)", request.path, 
fakeHandlerValue, request.route) } } for _, item := range request.ps { if item.Value != (*params).ByName(item.Key) { t.Errorf("mismatch params. path: %s, key: %s, expected value: %s, actual value: %s", request.path, item.Key, item.Value, (*params).ByName(item.Key)) } } } } func TestCountParams(t *testing.T) { if countParams("/path/:param1/static/*catch-all") != 2 { t.Fail() } if countParams(strings.Repeat("/:param", 256)) != 256 { t.Fail() } } func TestEmptyPath(t *testing.T) { tree := &router{method: "GET", root: &node{}} routes := [...]string{ "", "user", ":user", "*user", } for _, route := range routes { recv := catchPanic(func() { tree.addRoute(route, nil) }) if recv == nil { t.Fatalf("no panic while inserting route with empty path '%s", route) } } } func TestTreeAddAndGet(t *testing.T) { tree := &router{method: "GET", root: &node{}} routes := [...]string{ "/hi", "/contact", "/co", "/c", "/a", "/ab", "/doc/", "/doc/go_faq.html", "/doc/go1.html", "/α", "/β", } for _, route := range routes { tree.addRoute(route, fakeHandler(route)) } checkRequests(t, tree, testRequests{ {"", true, "", nil}, {"a", true, "", nil}, {"/a", false, "/a", nil}, {"/", true, "", nil}, {"/hi", false, "/hi", nil}, {"/contact", false, "/contact", nil}, {"/co", false, "/co", nil}, {"/con", true, "", nil}, // key mismatch {"/cona", true, "", nil}, // key mismatch {"/no", true, "", nil}, // no matching child {"/ab", false, "/ab", nil}, {"/α", false, "/α", nil}, {"/β", false, "/β", nil}, }) } func TestTreeWildcard(t *testing.T) { tree := &router{method: "GET", root: &node{}} routes := [...]string{ "/", "/cmd/:tool/:sub", "/cmd/:tool/", "/cmd/xxx/", "/src/*filepath", "/search/", "/search/:query", "/user_:name", "/user_:name/about", "/files/:dir/*filepath", "/doc/", "/doc/go_faq.html", "/doc/go1.html", "/info/:user/public", "/info/:user/project/:project", "/a/b/:c", "/a/:b/c/d", "/a/*b", } for _, route := range routes { tree.addRoute(route, fakeHandler(route)) } checkRequests(t, tree, 
testRequests{ {"/", false, "/", nil}, {"/cmd/test/", false, "/cmd/:tool/", param.Params{param.Param{Key: "tool", Value: "test"}}}, {"/cmd/test", true, "", nil}, {"/cmd/test/3", false, "/cmd/:tool/:sub", param.Params{param.Param{Key: "tool", Value: "test"}, param.Param{Key: "sub", Value: "3"}}}, {"/src/", false, "/src/*filepath", param.Params{param.Param{Key: "filepath", Value: ""}}}, {"/src/some/file.png", false, "/src/*filepath", param.Params{param.Param{Key: "filepath", Value: "some/file.png"}}}, {"/search/", false, "/search/", nil}, {"/search/someth!ng+in+ünìcodé", false, "/search/:query", param.Params{param.Param{Key: "query", Value: "someth!ng+in+ünìcodé"}}}, {"/search/someth!ng+in+ünìcodé/", true, "", nil}, {"/user_gopher", false, "/user_:name", param.Params{param.Param{Key: "name", Value: "gopher"}}}, {"/user_gopher/about", false, "/user_:name/about", param.Params{param.Param{Key: "name", Value: "gopher"}}}, {"/files/js/inc/framework.js", false, "/files/:dir/*filepath", param.Params{param.Param{Key: "dir", Value: "js"}, param.Param{Key: "filepath", Value: "inc/framework.js"}}}, {"/info/gordon/public", false, "/info/:user/public", param.Params{param.Param{Key: "user", Value: "gordon"}}}, {"/info/gordon/project/go", false, "/info/:user/project/:project", param.Params{param.Param{Key: "user", Value: "gordon"}, param.Param{Key: "project", Value: "go"}}}, {"/a/b/c", false, "/a/b/:c", param.Params{param.Param{Key: "c", Value: "c"}}}, {"/a/b/c/d", false, "/a/:b/c/d", param.Params{param.Param{Key: "b", Value: "b"}}}, {"/a/b", false, "/a/*b", param.Params{param.Param{Key: "b", Value: "b"}}}, }) } func TestUnescapeParameters(t *testing.T) { tree := &router{method: "GET", root: &node{}} routes := [...]string{ "/", "/cmd/:tool/:sub", "/cmd/:tool/", "/src/*filepath", "/search/:query", "/files/:dir/*filepath", "/info/:user/project/:project", "/info/:user", } for _, route := range routes { tree.addRoute(route, fakeHandler(route)) } unescape := true checkRequests(t, tree, 
testRequests{ {"/", false, "/", nil}, {"/cmd/test/", false, "/cmd/:tool/", param.Params{param.Param{Key: "tool", Value: "test"}}}, {"/cmd/test", true, "", nil}, {"/src/some/file.png", false, "/src/*filepath", param.Params{param.Param{Key: "filepath", Value: "some/file.png"}}}, {"/src/some/file+test.png", false, "/src/*filepath", param.Params{param.Param{Key: "filepath", Value: "some/file test.png"}}}, {"/src/some/file++++%%%%test.png", false, "/src/*filepath", param.Params{param.Param{Key: "filepath", Value: "some/file++++%%%%test.png"}}}, {"/src/some/file%2Ftest.png", false, "/src/*filepath", param.Params{param.Param{Key: "filepath", Value: "some/file/test.png"}}}, {"/search/someth!ng+in+ünìcodé", false, "/search/:query", param.Params{param.Param{Key: "query", Value: "someth!ng in ünìcodé"}}}, {"/info/gordon/project/go", false, "/info/:user/project/:project", param.Params{param.Param{Key: "user", Value: "gordon"}, param.Param{Key: "project", Value: "go"}}}, {"/info/slash%2Fgordon", false, "/info/:user", param.Params{param.Param{Key: "user", Value: "slash/gordon"}}}, {"/info/slash%2Fgordon/project/Project%20%231", false, "/info/:user/project/:project", param.Params{param.Param{Key: "user", Value: "slash/gordon"}, param.Param{Key: "project", Value: "Project #1"}}}, {"/info/slash%%%%", false, "/info/:user", param.Params{param.Param{Key: "user", Value: "slash%%%%"}}}, {"/info/slash%%%%2Fgordon/project/Project%%%%20%231", false, "/info/:user/project/:project", param.Params{param.Param{Key: "user", Value: "slash%%%%2Fgordon"}, param.Param{Key: "project", Value: "Project%%%%20%231"}}}, }, unescape) } func catchPanic(testFunc func()) (recv interface{}) { defer func() { recv = recover() }() testFunc() return } type testRoute struct { path string conflict bool } func testRoutes(t *testing.T, routes []testRoute) { tree := &router{method: "GET", root: &node{}} for _, route := range routes { recv := catchPanic(func() { tree.addRoute(route.path, []app.HandlerFunc{ func(c 
context.Context, ctx *app.RequestContext) { fmt.Println("test") }, }) }) if route.conflict { if recv == nil { t.Errorf("no panic for conflicting route '%s'", route.path) } } else if recv != nil { t.Errorf("unexpected panic for route '%s': %v", route.path, recv) } } } func TestTreeWildcardConflict(t *testing.T) { routes := []testRoute{ {"/cmd/vet", false}, {"/cmd/:tool/:sub", false}, {"/src/*filepath", false}, {"/src/*filepathx", true}, {"/src/", false}, {"/src1/", false}, {"/src1/*filepath", false}, {"/src2*filepath", true}, {"/search/:query", false}, {"/search/invalid", false}, {"/user_:name", false}, {"/user_x", false}, {"/user_:name", true}, {"/id:id", false}, {"/id/:id", false}, } testRoutes(t, routes) } func TestTreeChildConflict(t *testing.T) { routes := []testRoute{ {"/cmd/vet", false}, {"/cmd/:tool/:sub", false}, {"/src/AUTHORS", false}, {"/src/*filepath", false}, {"/user_x", false}, {"/user_:name", false}, {"/id/:id", false}, {"/id:id", false}, {"/:id", false}, {"/*filepath", false}, } testRoutes(t, routes) } func TestTreeDuplicatePath(t *testing.T) { tree := &router{method: "GET", root: &node{}} routes := [...]string{ "/", "/doc/", "/src/*filepath", "/search/:query", "/user_:name", } for _, route := range routes { recv := catchPanic(func() { tree.addRoute(route, fakeHandler(route)) }) if recv != nil { t.Fatalf("panic inserting route '%s': %v", route, recv) } // Add again recv = catchPanic(func() { tree.addRoute(route, fakeHandler(route)) }) if recv == nil { t.Fatalf("no panic while inserting duplicate route '%s", route) } } checkRequests(t, tree, testRequests{ {"/", false, "/", nil}, {"/doc/", false, "/doc/", nil}, {"/src/some/file.png", false, "/src/*filepath", param.Params{param.Param{Key: "filepath", Value: "some/file.png"}}}, {"/search/someth!ng+in+ünìcodé", false, "/search/:query", param.Params{param.Param{Key: "query", Value: "someth!ng+in+ünìcodé"}}}, {"/user_gopher", false, "/user_:name", param.Params{param.Param{Key: "name", Value: "gopher"}}}, 
}) } func TestEmptyWildcardName(t *testing.T) { tree := &router{method: "GET", root: &node{}} routes := [...]string{ "/user:", "/user:/", "/cmd/:/", "/src/*", } for _, route := range routes { recv := catchPanic(func() { tree.addRoute(route, nil) }) if recv == nil { t.Fatalf("no panic while inserting route with empty wildcard name '%s", route) } } } func TestTreeCatchAllConflict(t *testing.T) { routes := []testRoute{ {"/src/*filepath/x", true}, {"/src2/", false}, {"/src2/*filepath/x", true}, {"/src3/*filepath", false}, {"/src3/*filepath/x", true}, } testRoutes(t, routes) } func TestTreeCatchMaxParams(t *testing.T) { tree := &router{method: "GET", root: &node{}} route := "/cmd/*filepath" tree.addRoute(route, fakeHandler(route)) } func TestTreeDoubleWildcard(t *testing.T) { const panicMsg = "only one wildcard per path segment is allowed" routes := [...]string{ "/:foo:bar", "/:foo:bar/", "/:foo*bar", } for _, route := range routes { tree := &router{method: "GET", root: &node{}} recv := catchPanic(func() { tree.addRoute(route, nil) }) if rs, ok := recv.(string); !ok || !strings.HasPrefix(rs, panicMsg) { t.Fatalf(`"Expected panic "%s" for route '%s', got "%v"`, panicMsg, route, recv) } } } func TestTreeTrailingSlashRedirect2(t *testing.T) { tree := &router{method: "GET", root: &node{}} routes := [...]string{ "/api/:version/seller/locales/get", "/api/v:version/seller/permissions/get", "/api/v:version/seller/university/entrance_knowledge_list/get", } for _, route := range routes { recv := catchPanic(func() { tree.addRoute(route, fakeHandler(route)) }) if recv != nil { t.Fatalf("panic inserting route '%s': %v", route, recv) } } v := make(param.Params, 0, 1) tsrRoutes := [...]string{ "/api/v:version/seller/permissions/get/", "/api/version/seller/permissions/get/", } for _, route := range tsrRoutes { value := tree.find(route, &v, false) if value.handlers != nil { t.Fatalf("non-nil handler for TSR route '%s", route) } else if !value.tsr { t.Errorf("expected TSR recommendation 
for route '%s'", route) } } noTsrRoutes := [...]string{ "/api/v:version/seller/permissions/get/a", } for _, route := range noTsrRoutes { value := tree.find(route, &v, false) if value.handlers != nil { t.Fatalf("non-nil handler for No-TSR route '%s", route) } else if value.tsr { t.Errorf("expected no TSR recommendation for route '%s'", route) } } } func TestTreeTrailingSlashRedirect(t *testing.T) { tree := &router{method: "GET", root: &node{}} routes := [...]string{ "/hi", "/b/", "/search/:query", "/cmd/:tool/", "/src/*filepath", "/x", "/x/y", "/y/", "/y/z", "/0/:id", "/0/:id/1", "/1/:id/", "/1/:id/2", "/aa", "/a/", "/admin", "/admin/:category", "/admin/:category/:page", "/doc", "/doc/go_faq.html", "/doc/go1.html", "/no/a", "/no/b", "/api/hello/:name", "/user/:name/*id", "/resource", "/r/*id", "/book/biz/:name", "/book/biz/abc", "/book/biz/abc/bar", "/book/:page/:name", "/book/hello/:name/biz/", } for _, route := range routes { recv := catchPanic(func() { tree.addRoute(route, fakeHandler(route)) }) if recv != nil { t.Fatalf("panic inserting route '%s': %v", route, recv) } } tsrRoutes := [...]string{ "/hi/", "/b", "/search/gopher/", "/cmd/vet", "/src", "/x/", "/y", "/0/go/", "/1/go", "/a", "/admin/", "/admin/config/", "/admin/config/permissions/", "/doc/", "/user/name", "/r", "/book/hello/a/biz", "/book/biz/foo/", "/book/biz/abc/bar/", } v := make(param.Params, 0, 10) for _, route := range tsrRoutes { value := tree.find(route, &v, false) if value.handlers != nil { t.Fatalf("non-nil handler for TSR route '%s", route) } else if !value.tsr { t.Errorf("expected TSR recommendation for route '%s'", route) } } noTsrRoutes := [...]string{ "/", "/no", "/no/", "/_", "/_/", "/api/world/abc", "/book", "/book/", "/book/hello/a/abc", "/book/biz/abc/biz", } for _, route := range noTsrRoutes { value := tree.find(route, &v, false) if value.handlers != nil { t.Fatalf("non-nil handler for No-TSR route '%s", route) } else if value.tsr { t.Errorf("expected no TSR recommendation for route 
'%s'", route) } } } func TestTreeRootTrailingSlashRedirect(t *testing.T) { tree := &router{method: "GET", root: &node{}} recv := catchPanic(func() { tree.addRoute("/:test", fakeHandler("/:test")) }) if recv != nil { t.Fatalf("panic inserting test route: %v", recv) } value := tree.find("/", nil, false) if value.handlers != nil { t.Fatalf("non-nil handler") } else if value.tsr { t.Errorf("expected no TSR recommendation") } } func TestTreeFindCaseInsensitivePath(t *testing.T) { tree := &router{method: "GET", root: &node{}} longPath := "/l" + strings.Repeat("o", 128) + "ng" lOngPath := "/l" + strings.Repeat("O", 128) + "ng/" routes := [...]string{ "/hi", "/b/", "/ABC/", "/search/:query", "/cmd/:tool/", "/src/*filepath", "/x", "/x/y", "/y/", "/y/z", "/0/:id", "/0/:id/1", "/1/:id/", "/1/:id/2", "/aa", "/a/", "/doc", "/doc/go_faq.html", "/doc/go1.html", "/doc/go/away", "/no/a", "/no/b", "/z/:id/2", "/z/:id/:age", "/x/:id/3/", "/x/:id/3/4", "/x/:id/:age/5", longPath, } for _, route := range routes { recv := catchPanic(func() { tree.addRoute(route, fakeHandler(route)) }) if recv != nil { t.Fatalf("panic inserting route '%s': %v", route, recv) } } // Check out == in for all registered routes // With fixTrailingSlash = true for _, route := range routes { out, found := tree.root.findCaseInsensitivePath(route, true) if !found { t.Errorf("Route '%s' not found!", route) } else if string(out) != route { t.Errorf("Wrong result for route '%s': %s", route, string(out)) } } // With fixTrailingSlash = false for _, route := range routes { out, found := tree.root.findCaseInsensitivePath(route, false) if !found { t.Errorf("Route '%s' not found!", route) } else if string(out) != route { t.Errorf("Wrong result for route '%s': %s", route, string(out)) } } tests := []struct { in string out string found bool slash bool }{ {"/HI", "/hi", true, false}, {"/HI/", "/hi", true, true}, {"/B", "/b/", true, true}, {"/B/", "/b/", true, false}, {"/abc", "/ABC/", true, true}, {"/abc/", "/ABC/", true, 
false}, {"/aBc", "/ABC/", true, true}, {"/aBc/", "/ABC/", true, false}, {"/abC", "/ABC/", true, true}, {"/abC/", "/ABC/", true, false}, {"/SEARCH/QUERY", "/search/QUERY", true, false}, {"/SEARCH/QUERY/", "/search/QUERY", true, true}, {"/CMD/TOOL/", "/cmd/TOOL/", true, false}, {"/CMD/TOOL", "/cmd/TOOL/", true, true}, {"/SRC/FILE/PATH", "/src/FILE/PATH", true, false}, {"/x/Y", "/x/y", true, false}, {"/x/Y/", "/x/y", true, true}, {"/X/y", "/x/y", true, false}, {"/X/y/", "/x/y", true, true}, {"/X/Y", "/x/y", true, false}, {"/X/Y/", "/x/y", true, true}, {"/Y/", "/y/", true, false}, {"/Y", "/y/", true, true}, {"/Y/z", "/y/z", true, false}, {"/Y/z/", "/y/z", true, true}, {"/Y/Z", "/y/z", true, false}, {"/Y/Z/", "/y/z", true, true}, {"/y/Z", "/y/z", true, false}, {"/y/Z/", "/y/z", true, true}, {"/Aa", "/aa", true, false}, {"/Aa/", "/aa", true, true}, {"/AA", "/aa", true, false}, {"/AA/", "/aa", true, true}, {"/aA", "/aa", true, false}, {"/aA/", "/aa", true, true}, {"/A/", "/a/", true, false}, {"/A", "/a/", true, true}, {"/DOC", "/doc", true, false}, {"/DOC/", "/doc", true, true}, {"/NO", "", false, true}, {"/DOC/GO", "", false, true}, {"/Z/1/2", "/z/1/2", true, false}, {"/Z/1/3", "/z/1/3", true, false}, {"/Z/1/2/", "/z/1/2", true, true}, {"/Z/1/3/", "/z/1/3", true, true}, {"/X/1/3", "/x/1/3/", true, true}, {"/X/1/3/5", "/x/1/3/5", true, false}, {lOngPath, longPath, true, true}, } // With fixTrailingSlash = true for _, test := range tests { out, found := tree.root.findCaseInsensitivePath(test.in, true) if found != test.found || (found && (string(out) != test.out)) { t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t", test.in, string(out), found, test.out, test.found) return } } // With fixTrailingSlash = false for _, test := range tests { out, found := tree.root.findCaseInsensitivePath(test.in, false) if test.slash { if found { // test needs a trailingSlash fix. It must not be found! 
t.Errorf("Found without fixTrailingSlash: %s; got %s", test.in, string(out)) } } else { if found != test.found || (found && (string(out) != test.out)) { t.Errorf("Wrong result for '%s': got %s, %t; want %s, %t", test.in, string(out), found, test.out, test.found) return } } } } func TestTreeParamNotOptimize(t *testing.T) { tree := &router{method: "GET", root: &node{}} routes := [...]string{ "/:parama/start", "/:paramb", } for _, route := range routes { tree.addRoute(route, fakeHandler(route)) } checkRequests(t, tree, testRequests{ {"/1", false, "/:paramb", param.Params{param.Param{Key: "paramb", Value: "1"}}}, {"/1/start", false, "/:parama/start", param.Params{param.Param{Key: "parama", Value: "1"}}}, }) // other sequence tree = &router{method: "GET", root: &node{}} routes = [...]string{ "/:paramb", "/:parama/start", } for _, route := range routes { tree.addRoute(route, fakeHandler(route)) } checkRequests(t, tree, testRequests{ {"/1/start", false, "/:parama/start", param.Params{param.Param{Key: "parama", Value: "1"}}}, {"/1", false, "/:paramb", param.Params{param.Param{Key: "paramb", Value: "1"}}}, }) } ================================================ FILE: scripts/.utils/check_go_mod.sh ================================================ #!/bin/bash # Copyright 2025 CloudWeGo Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
#
# Go Module Dependency Checker
#
# This script validates that all CloudWeGo dependencies in go.mod use proper
# semantic version tags (vA.B.C format) instead of commit hashes or branch names.
#
# Purpose:
#   - Ensures CloudWeGo dependencies follow proper versioning practices
#   - Validates go.mod contains only formal release versions
#   - Prevents release with non-standard dependency versions
#
# Usage:
#   ./scripts/.utils/check_go_mod.sh
#
# The script:
#   1. Searches go.mod for github.com/cloudwego/* dependencies
#   2. Checks that each dependency uses semantic versioning (vX.Y.Z)
#   3. Reports any dependencies using non-standard versions
#   4. Exits with error if invalid dependencies are found
#
# Called by:
#   - release.sh (before creating release tags)
#   - release-hotfix.sh (before creating hotfix tags)
#

GO_MOD_FILE="go.mod"

# Run from the repository root so the relative go.mod path resolves. Abort
# instead of silently checking another directory (an unquoted, empty
# `cd $(...)` would fall back to $HOME) when not inside a git work tree.
cd "$(git rev-parse --show-toplevel)" || exit 1

# Check go.mod for cloudwego dependencies using proper tag versions
echo "🔍 Checking go.mod for cloudwego dependencies..."
if [ -f "$GO_MOD_FILE" ]; then
    # Find cloudwego dependencies that don't use formal tag versions (vA.B.C format).
    # Lines containing "// indirect" or "module " are skipped; awk prints the
    # first two whitespace-separated fields of each remaining line.
    invalid_deps=$(grep "github.com/cloudwego/" "$GO_MOD_FILE" | \
        grep -v "// indirect" | grep -v "module " | \
        grep -v -E "v[0-9]+\.[0-9]+\.[0-9]+$" | awk '{print $1 " " $2}')

    if [ -n "$invalid_deps" ]; then
        echo "❌ Error: Found cloudwego dependencies not using formal tag versions:"
        echo
        # Indent every offending line, not only the first one.
        printf '%s\n' "$invalid_deps" | sed 's/^/ /'
        echo
        echo " All github.com/cloudwego/* dependencies must use formal tag versions like v0.1.2 format"
        echo " Please update go.mod to use proper tagged versions for these dependencies"
        exit 1
    else
        echo "✅ All cloudwego dependencies use formal tag versions"
    fi
else
    echo "⚠️ WARNING: go.mod not found"
fi

================================================
FILE: scripts/.utils/check_version.sh
================================================
#!/bin/bash
# Copyright 2025 CloudWeGo Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Version Consistency Checker
#
# This script validates that the version in version.go matches the intended
# release version to ensure consistency across the codebase.
#
# Purpose:
#   - Prevents releasing with mismatched version numbers
#   - Ensures version.go is updated before creating release tags
#   - Provides interactive confirmation for version mismatches
#
# Usage:
#   ./scripts/.utils/check_version.sh target_version
#
# Parameters:
#   target_version - The version being released (e.g., v1.2.3)
#
# The script:
#   1. Reads the Version constant from version.go
#   2. Compares it with the target release version
#   3. Warns if versions don't match and prompts for confirmation
#   4. Exits with error if user chooses not to continue
#
# Called by:
#   - release.sh (before creating release tags)
#   - release-hotfix.sh (before creating hotfix tags)
#

VERSION_FILE="version.go"
new_version=$1

if [ -z "$new_version" ]; then
    echo "❌ Error: Version parameter is required"
    echo "Usage: $0 target_version (e.g. v1.2.3)"
    exit 1
fi

cd "$(git rev-parse --show-toplevel)" || exit 1

if [ -f "$VERSION_FILE" ]; then
    # Extract the quoted value of the first `Version = "v..."` occurrence.
    version_go_version=$(grep -o 'Version = "v[^"]*"' "$VERSION_FILE" | sed 's/Version = "\([^"]*\)"/\1/')
    if [ "$version_go_version" != "$new_version" ]; then
        echo "⚠️ WARNING: $VERSION_FILE contains $version_go_version but you're releasing $new_version"
        echo " Please update $VERSION_FILE to match the release version."
        read -p " Continue anyway? (y/N): " continue_anyway
        if [ "$continue_anyway" != "y" ] && [ "$continue_anyway" != "Y" ]; then
            echo "❌ Release cancelled"
            exit 1
        fi
    else
        echo "✅ $VERSION_FILE matches release version: $new_version"
    fi
else
    echo "⚠️ WARNING: $VERSION_FILE not found"
fi

================================================
FILE: scripts/.utils/funcs.sh
================================================
#!/bin/bash
# Copyright 2025 CloudWeGo Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Shared utility functions for release scripts
#

# Function to show changes between two commits/tags
# Usage: show_changes from_ref to_ref [title]
# Returns 1 when from_ref..to_ref contains no commits or cannot be resolved,
# 0 after printing the commits grouped by author email.
show_changes() {
    local from_ref=$1
    local to_ref=$2
    local title=${3:-"Changes"}

    echo
    echo "📋 $title:"
    echo "$(printf '%*s' ${#title} | tr ' ' '=')"

    # Check if there are any commits between the references.
    # Declared separately so a git failure is not masked by `local`'s status.
    local commit_count
    if ! commit_count=$(git rev-list --count "$from_ref..$to_ref"); then
        echo "❌ Error: Unable to list commits in $from_ref..$to_ref"
        return 1
    fi
    if [ "$commit_count" -eq 0 ]; then
        echo "✅ No new commits. Nothing to release."
        return 1
    fi

    # Show changes grouped by author email
    git log --pretty=format:"%ae|%h %s" "$from_ref..$to_ref" | sort | awk -F'|' '{
        if ($1 != prev_email) {
            if (prev_email != "") print ""
            print "👤 " $1 ":"
            prev_email = $1
        }
        print " " $2
    }'

    return 0
}

# Function to show changes for first release (no previous version)
# Usage: show_first_release_changes commit [limit]
show_first_release_changes() {
    local commit=$1
    local limit=${2:-10}

    echo
    echo "📋 This is the first release - showing recent commits:"
    echo "===================================================="
    git log --pretty=format:"%ae|%h %s" -"$limit" "$commit" | sort | awk -F'|' '{
        if ($1 != prev_email) {
            if (prev_email != "") print ""
            print "👤 " $1 ":"
            prev_email = $1
        }
        print " " $2
    }'
}

# Function to ask user for confirmation of changes
# Usage: confirm_changes
# Returns 0 only when the user answers y/Y.
confirm_changes() {
    echo
    echo "📝 Please review the changes above."
    read -p "Do you want to proceed with these changes? (y/N): " confirm_changes
    if [ "$confirm_changes" != "y" ] && [ "$confirm_changes" != "Y" ]; then
        echo "❌ Release cancelled by user"
        return 1
    fi
    return 0
}

# Function to check if commit exists on origin
# Usage: check_commit_exists commit branch
# Exits the calling shell with status 1 if the commit is unknown locally or
# is not reachable from origin/branch.
check_commit_exists() {
    local commit=$1
    local branch=$2

    if ! git cat-file -e "$commit" 2>/dev/null; then
        echo "❌ Error: Commit $commit does not exist locally"
        exit 1
    fi

    # Exact ancestry check against the remote-tracking ref. Grepping the
    # output of `git branch -r --contains` treated the branch name as a
    # regex and could match other remote branches with a similar suffix.
    if ! git merge-base --is-ancestor "$commit" "origin/$branch" 2>/dev/null; then
        echo "❌ Error: Commit $commit does not exist on origin/$branch"
        exit 1
    fi
}

# Function to get latest version tag
# Usage: get_latest_version
# Prints the highest v*.*.* tag by version sort, or v0.0.0 when none exist.
get_latest_version() {
    local latest_tag=$(git tag -l "v*.*.*" | sort -V | tail -n1)
    if [ -z "$latest_tag" ]; then
        echo "v0.0.0"
    else
        echo "$latest_tag"
    fi
}

# Function to get latest patch version for a given minor version
# Usage: get_latest_patch_version minor_version   (e.g. v1.2 or v1.2.x)
get_latest_patch_version() {
    local minor_version=$1
    # Remove 'v' prefix and '.x' suffix if present
    minor_version=${minor_version#v}
    minor_version=${minor_version%.x}

    local latest_patch=$(git tag -l "v${minor_version}.*" | sort -V | tail -n1)
    if [ -z "$latest_patch" ]; then
        echo "No tags found for minor version v${minor_version}.x"
        return 1
    else
        echo "$latest_patch"
    fi
}

# Function to increment version
# Usage: increment_version version type
# where type is one of: major, minor, patch
# Any other type is reported on stderr and returns 1 instead of silently
# echoing the unchanged version.
increment_version() {
    local version=$1
    local type=$2

    # Remove 'v' prefix
    version=${version#v}

    IFS='.' read -r major minor patch <<< "$version"

    case $type in
        "major")
            major=$((major + 1))
            minor=0
            patch=0
            ;;
        "minor")
            minor=$((minor + 1))
            patch=0
            ;;
        "patch")
            patch=$((patch + 1))
            ;;
        *)
            echo "❌ Error: Unknown version increment type: $type (expected major, minor or patch)" >&2
            return 1
            ;;
    esac

    echo "v$major.$minor.$patch"
}

# Function to increment patch version (convenience wrapper)
# Usage: increment_patch_version version
increment_patch_version() {
    local version=$1
    increment_version "$version" "patch"
}

================================================
FILE: scripts/release-hotfix.sh
================================================
#!/bin/bash
# Copyright 2025 CloudWeGo Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Hertz Hotfix Release Script
#
# This script creates hotfix releases for specific minor versions.
# It performs the following steps:
#   1. Prompts for the minor version to create a hotfix for (e.g., v0.14.x)
#   2. Finds the latest patch version for that minor version
#   3. Creates or uses existing hotfix branch from the latest patch
#   4. Shows changes in the hotfix branch
#   5. Creates and pushes a new patch version tag
#
# Usage:
#   ./release-hotfix.sh [--dry-run]
#
# Prerequisites:
#   - The script fetches origin itself; hotfix commits must already be on origin
#   - Hotfix branch hotfix/<minor> (e.g. hotfix/v0.14.x) must exist on origin
#     (the script offers to create it if needed)
#
# The script will interactively prompt for:
#   - Minor version to hotfix (e.g., v0.14.x)
#   - Confirmation to create hotfix branch (if it doesn't exist)
#   - Confirmation before creating the hotfix tag
#
# Example workflow:
#   1. Run script and specify v0.14.x
#   2. Script creates hotfix/v0.14.x branch from latest v0.14.* tag
#   3. Create PR with hotfix changes to the hotfix branch
#   4. Run script again to create the hotfix release
#

set -e

# Parse arguments
DRY_RUN=false
while [[ $# -gt 0 ]]; do
    case $1 in
        --dry-run)
            DRY_RUN=true
            shift
            ;;
        -h|--help)
            echo "Usage: $0 [--dry-run]"
            echo " --dry-run Show what would be done without making changes"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            echo "Use --help for usage information"
            exit 1
            ;;
    esac
done

SCRIPTS_ROOT=$(git rev-parse --show-toplevel)/scripts
CHECK_VERSION_CMD=$SCRIPTS_ROOT/.utils/check_version.sh
CHECK_GO_MOD_CMD=$SCRIPTS_ROOT/.utils/check_go_mod.sh

# Source shared utility functions
source "$SCRIPTS_ROOT/.utils/funcs.sh"

echo "🔧 Hertz Hotfix Release Script"
echo "==============================="

if [ "$DRY_RUN" = true ]; then
    echo "🧪 DRY RUN MODE - No changes will be made"
    echo "========================================="
fi

# Fetch latest changes from origin
echo "📥 Fetching latest changes from origin..."
git fetch -p --force --tags origin

# 1. Ask which minor version to fix
echo
read -r -p "⌨️ Enter the minor version to create hotfix for (e.g., v0.14.x): " minor_version

if [ -z "$minor_version" ]; then
    echo "❌ Error: Minor version is required"
    exit 1
fi

# Normalize minor version format
if [[ ! "$minor_version" =~ ^v[0-9]+\.[0-9]+\.x$ ]]; then
    # Try to fix common formats
    if [[ "$minor_version" =~ ^v?([0-9]+)\.([0-9]+)$ ]]; then
        minor_version="v${BASH_REMATCH[1]}.${BASH_REMATCH[2]}.x"
    elif [[ "$minor_version" =~ ^v?([0-9]+)\.([0-9]+)\.x$ ]]; then
        minor_version="v${BASH_REMATCH[1]}.${BASH_REMATCH[2]}.x"
    else
        echo "❌ Error: Invalid minor version format. Please use format like 'v0.14.x'"
        exit 1
    fi
fi

echo "🎯 Target minor version: $minor_version"

# 2. Get the latest version of the given minor version.
# The assignment lives inside the `if` so that `set -e` does not abort the
# script before the helper's error message can be reported.
if ! latest_patch=$(get_latest_patch_version "$minor_version"); then
    echo "❌ Error: $latest_patch"
    exit 1
fi

echo "📌 Latest patch version: $latest_patch"

# 3. Check if hotfix branch exists
hotfix_branch="hotfix/${minor_version}"
if ! git show-ref --verify --quiet "refs/remotes/origin/$hotfix_branch"; then
    echo "❌ Hotfix branch 'origin/$hotfix_branch' does not exist"
    echo
    if [ "$DRY_RUN" = true ]; then
        read -r -p "🧪 DRY RUN: Would create hotfix branch '$hotfix_branch' from $latest_patch. Continue? (y/N): " create_branch
    else
        read -r -p "🔧 Create hotfix branch '$hotfix_branch' from $latest_patch? (y/N): " create_branch
    fi

    if [ "$create_branch" = "y" ] || [ "$create_branch" = "Y" ]; then
        # Release tags are annotated tag objects; git refuses to point a
        # branch at a non-commit object, so peel the tag to its commit.
        if [ "$DRY_RUN" = true ]; then
            echo "🧪 DRY RUN: Would execute:"
            echo " git push origin \"$latest_patch^{commit}:refs/heads/$hotfix_branch\""
            echo
            echo "✅ DRY RUN: Hotfix branch '$hotfix_branch' would be created"
            echo "🔗 In real mode, you would create a PR for your hotfix changes to this branch"
            exit 0
        else
            echo "🌿 Creating hotfix branch $hotfix_branch..."
            git push origin "$latest_patch^{commit}:refs/heads/$hotfix_branch"
            echo
            echo "✅ Hotfix branch '$hotfix_branch' created successfully!"
            echo "🔗 Please create a PR for your hotfix changes to this branch"
            echo " Then run this script again to create the hotfix release"
            exit 0
        fi
    else
        echo "❌ Cannot proceed without hotfix branch"
        exit 1
    fi
fi

echo "✅ Found hotfix branch: $hotfix_branch"

# 4. Show diff between hotfix branch and latest patch version
if ! show_changes "$latest_patch" "origin/$hotfix_branch" "Changes in hotfix branch since $latest_patch"; then
    exit 0
fi

# Ask user to confirm the changes before proceeding
if ! confirm_changes; then
    exit 1
fi

# 5. Ask user if they want to release the new patch version
new_patch_version=$(increment_patch_version "$latest_patch")
echo
echo "🔖 New patch version will be: $new_patch_version"

# Check version.go file
echo
"$CHECK_VERSION_CMD" "$new_patch_version"

# Check go.mod
echo
"$CHECK_GO_MOD_CMD"

# Final confirmation
echo
if [ "$DRY_RUN" = true ]; then
    read -r -p "DRY RUN: Would create hotfix release tag $new_patch_version from hotfix branch $hotfix_branch. Continue? (y/N): " confirm
else
    read -r -p "Create hotfix release tag $new_patch_version from hotfix branch $hotfix_branch? (y/N): " confirm
fi

if [ "$confirm" != "y" ] && [ "$confirm" != "Y" ]; then
    echo "❌ Release cancelled"
    exit 1
fi

# Get the latest commit from hotfix branch
hotfix_commit=$(git rev-parse "origin/$hotfix_branch")
check_commit_exists "$hotfix_commit" "$hotfix_branch"

# Create and push tag
if [ "$DRY_RUN" = true ]; then
    echo
    echo "🧪 DRY RUN: Would execute:"
    echo " git tag -a \"$new_patch_version\" \"$hotfix_commit\" -m \"Hotfix release $new_patch_version\""
    echo " git push origin \"$new_patch_version\""
else
    echo "🏷️ Creating tag $new_patch_version..."
    git tag -a "$new_patch_version" "$hotfix_commit" -m "Hotfix release $new_patch_version"
    echo "📤 Pushing tag to origin..."
    git push origin "$new_patch_version"
fi

echo
if [ "$DRY_RUN" = true ]; then
    echo "🧪 DRY RUN COMPLETE - No changes were made"
    echo "Tag: $new_patch_version"
    echo "Commit: $hotfix_commit"
    echo "Based on hotfix branch: $hotfix_branch"
else
    echo "🎉 Hotfix release $new_patch_version created successfully!"
    echo "Tag: $new_patch_version"
    echo "Commit: $hotfix_commit"
    echo "Based on hotfix branch: $hotfix_branch"
fi

================================================
FILE: scripts/release.sh
================================================
#!/bin/bash
# Copyright 2025 CloudWeGo Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# # Hertz Release Script # # This script creates a new release tag for the Hertz project. # It performs the following steps: # 1. Validates you're on the main branch and up-to-date with origin # 2. Prompts for the target commit to release # 3. Shows changes since the last version # 4. Allows you to choose version bump type (major/minor/patch) # 5. Validates version.go matches the new version # 6. Creates and pushes the release tag # # Usage: # ./release.sh # # Prerequisites: # - Must be on main branch # - Local repo must be up-to-date with origin # - Target commit must exist on origin # # The script will interactively prompt for: # - Target commit hash to release # - Version bump type (patch/minor/major) # - Confirmation before creating the tag # set -e # Parse arguments DRY_RUN=false while [[ $# -gt 0 ]]; do case $1 in --dry-run) DRY_RUN=true shift ;; -h|--help) echo "Usage: $0 [--dry-run]" echo " --dry-run Show what would be done without making changes" exit 0 ;; *) echo "Unknown option: $1" echo "Use --help for usage information" exit 1 ;; esac done DEFAULT_BRANCH=main GIT_ROOT=$(git rev-parse --show-toplevel) SCRIPTS_ROOT=$GIT_ROOT/scripts CHECK_VERSION_CMD=$SCRIPTS_ROOT/.utils/check_version.sh CHECK_GO_MOD_CMD=$SCRIPTS_ROOT/.utils/check_go_mod.sh # Source shared utility functions source $SCRIPTS_ROOT/.utils/funcs.sh echo "🚀 Hertz Release Script" echo "=======================" if [ "$DRY_RUN" = true ]; then echo "🧪 DRY RUN MODE - No changes will be made" echo "=========================================" fi # Fetch latest changes from origin echo "📥 Fetching latest changes from origin..." git fetch -p --force --tags origin # 1. Ask user which commit to release echo read -p "⌨️ Enter the commit hash you want to release: " target_commit if [ -z "$target_commit" ]; then echo "❌ Error: Commit hash is required" exit 1 fi # Resolve to full commit hash target_commit=$(git rev-parse "$target_commit") echo "📌 Target commit: $target_commit" echo # 2. 
Check if commit exists on origin check_commit_exists "$target_commit" "$DEFAULT_BRANCH" # 3. Get latest version tag latest_version=$(get_latest_version) # 4. Show diff between latest version and target commit if [ "$latest_version" = "v0.0.0" ]; then show_first_release_changes "$target_commit" 10 else if ! show_changes "$latest_version" "$target_commit" "Changes since latest version $latest_version"; then exit 0 fi fi # Ask user to confirm the changes before proceeding if ! confirm_changes; then exit 1 fi # 5. Ask user for version bump type # Calculate what each version would be major_version=$(increment_version "$latest_version" "major") minor_version=$(increment_version "$latest_version" "minor") patch_version=$(increment_version "$latest_version" "patch") echo echo "⚙️ Available version bump types:" echo " 1) patch - for bug fixes ($patch_version)" echo " 2) minor - for new features ($minor_version)" echo " 3) major - for breaking changes ($major_version)" echo while true; do read -p "Select version bump type (1/2/3): " bump_choice case $bump_choice in 1) bump_type="patch"; break;; 2) bump_type="minor"; break;; 3) bump_type="major"; break;; *) echo "Please enter 1, 2, or 3";; esac done new_version=$(increment_version "$latest_version" "$bump_type") echo echo "New version will be: $new_version" # Check version.go file echo $CHECK_VERSION_CMD "$new_version" # Check go.mod echo $CHECK_GO_MOD_CMD # Final confirmation echo if [ "$DRY_RUN" = true ]; then read -p "DRY RUN: Would create release tag $new_version for commit $target_commit. Continue? (y/N): " confirm else read -p "Create release tag $new_version for commit $target_commit? 
(y/N): " confirm fi if [ "$confirm" != "y" ] && [ "$confirm" != "Y" ]; then echo "❌ Release cancelled" exit 1 fi # Create and push tag if [ "$DRY_RUN" = true ]; then echo echo "🧪 DRY RUN: Would execute:" echo " git tag -a \"$new_version\" \"$target_commit\" -m \"Release $new_version\"" echo " git push origin \"$new_version\"" else echo "🏷️ Creating tag $new_version..." git tag -a "$new_version" "$target_commit" -m "Release $new_version" echo "📤 Pushing tag to origin..." git push origin "$new_version" fi echo if [ "$DRY_RUN" = true ]; then echo "🧪 DRY RUN COMPLETE - No changes were made" echo "Tag: $new_version" echo "Commit: $target_commit" else echo "🎉 Release $new_version created successfully!" echo "Tag: $new_version" echo "Commit: $target_commit" fi ================================================ FILE: version.go ================================================ /* * Copyright 2022 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package hertz // Name and Version info of this framework, used for statistics and debug const ( Name = "Hertz" Version = "v0.10.4" )