Repository: go-resty/resty Branch: v3 Commit: 0fadf8088bd3 Files: 51 Total size: 634.8 KB Directory structure: gitextract_43qx_89h/ ├── .github/ │ ├── FUNDING.yml │ └── workflows/ │ ├── ci.yml │ └── label-actions.yml ├── .gitignore ├── .testdata/ │ ├── cert.pem │ ├── key.pem │ ├── sample-root.pem │ └── text-file.txt ├── BUILD.bazel ├── LICENSE ├── README.md ├── WORKSPACE ├── benchmark_test.go ├── cert_watcher_test.go ├── circuit_breaker.go ├── circuit_breaker_test.go ├── client.go ├── client_test.go ├── context_test.go ├── curl.go ├── curl_test.go ├── debug.go ├── digest.go ├── digest_test.go ├── go.mod ├── go.sum ├── hedging.go ├── hedging_test.go ├── load_balancer.go ├── load_balancer_test.go ├── middleware.go ├── middleware_test.go ├── multipart.go ├── multipart_test.go ├── redirect.go ├── request.go ├── request_test.go ├── response.go ├── resty.go ├── resty_test.go ├── retry.go ├── retry_test.go ├── sse.go ├── sse_test.go ├── stream.go ├── stream_test.go ├── trace.go ├── transport_dial.go ├── transport_dial_wasm.go ├── util.go └── util_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/FUNDING.yml ================================================ github: [jeevatkm] custom: ["https://www.paypal.com/donate/?cmd=_donations&business=QWMZG74FW4QYC&lc=US&item_name=Resty+Library+for+Go&currency_code=USD"] ================================================ FILE: .github/workflows/ci.yml ================================================ name: CI on: push: branches: - v3 - v2 paths-ignore: - '**.md' - '**.bazel' - 'WORKSPACE' pull_request: branches: - main - v3 - v2 paths-ignore: - '**.md' - '**.bazel' - 'WORKSPACE' # Allows you to run this workflow manually from the Actions tab workflow_dispatch: jobs: build: name: Build strategy: matrix: go: [ 'stable', '1.23.x' ] os: [ ubuntu-latest ] runs-on: ${{ matrix.os }} steps: - name: Checkout uses: 
actions/checkout@v4 with: fetch-depth: 0 - name: Setup Go uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} cache: true cache-dependency-path: go.sum - name: Format run: diff -u <(echo -n) <(go fmt $(go list ./...)) - name: Test run: go run gotest.tools/gotestsum@latest -f testname -- ./... -race -count=1 -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... -shuffle=on - name: Upload coverage to Codecov if: ${{ matrix.os == 'ubuntu-latest' && matrix.go == 'stable' }} uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} file: ./coverage.txt flags: unittests ================================================ FILE: .github/workflows/label-actions.yml ================================================ name: 'Label' on: pull_request: types: [labeled] paths-ignore: - '**.md' - '**.bazel' - 'WORKSPACE' jobs: build: strategy: matrix: go: [ 'stable', '1.23.x' ] os: [ ubuntu-latest ] name: Run Build if: ${{ github.event.label.name == 'run-build' }} runs-on: ${{ matrix.os }} steps: - name: Checkout uses: actions/checkout@v4 with: fetch-depth: 0 - name: Setup Go uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} cache: true cache-dependency-path: go.sum - name: Format run: diff -u <(echo -n) <(go fmt $(go list ./...)) - name: Test run: go run gotest.tools/gotestsum@latest -f testname -- ./... -race -count=1 -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... 
-shuffle=on - name: Upload coverage to Codecov if: ${{ matrix.os == 'ubuntu-latest' && matrix.go == 'stable' }} uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} file: ./coverage.txt flags: unittests ================================================ FILE: .gitignore ================================================ # Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a *.so # Folders _obj _test # Architecture specific extensions/prefixes *.[568vq] [568vq].out *.cgo1.go *.cgo2.c _cgo_defun.c _cgo_gotypes.go _cgo_export.* _testmain.go *.exe *.test *.prof coverage.out coverage.txt # Exclude IDE folders .idea/* .vscode/* ================================================ FILE: .testdata/cert.pem ================================================ -----BEGIN CERTIFICATE----- MIIC+jCCAeKgAwIBAgIRAJce5ewsoW44j0qvSABmq7owDQYJKoZIhvcNAQELBQAw EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0yNTAxMDQwNzA3MTNaFw0yNjAxMDQwNzA3 MTNaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw ggEKAoIBAQCYkTN1g/0Z3KkS3w0lX9yhZkwiA0obXCeFs7hpRP0p4WlW3uADyXQ5 h2MaYx8OCA7oGU7/dWOPhtE3rgFEz7IwLxcP5d02ukLGlFD69D6KLyTXwCFmvOWQ 5fbOq4s73WTNDfYSTYNzeujDCjeu/Bk0OVhdxbyZdyrpdm+UBfH8uIDoGeCRXnji nqG9HNOQx6r/S6FqC5j/7PrVl1i66WlqRzKEJB94uejfujrHq8RjQm/wzEutU5df C39zEEEx75qQt7Jc0asm1AqAKSq34xn4rVajWrBZ/WudUUizHfaBDP61uPFvPyKW JDvTSdeoM9TPX0y0cjo6AwSrdLl7flrRAgMBAAGjSzBJMA4GA1UdDwEB/wQEAwIF oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuC CWxvY2FsaG9zdDANBgkqhkiG9w0BAQsFAAOCAQEAdHvPQe3EJ4/X6K/bklJUhIfM KBauH8VMBfri7xLawleKssm7GdiFivSA0g1pArkl8SALBlPqhrx7rwlyyivLTZaR VFvXaQ9eU0zGnSnDnKVz6CX/zn3TKfcgZPEBclayh0ldm7A8xSJWaWbRZ+s9e9x1 XcQTn2KkMZfBDMnGEWQ3KZrClvO5ZfkqSiyzEm9+eF0m0E7ujTyfSVMsPdyldA6U pHG8omQTyOzJl2I4z7DlS0AEsL0TJHV4iKr9rDei2xQz/wtful5qU/taYp2Y6zMH 8ytnDldJhmcCwmvtqvK5p6CbkatE7TFyw2CxQJHnQef+Y4W94sSZWg9CGRKDIQ== -----END CERTIFICATE----- ================================================ FILE: .testdata/key.pem 
================================================ -----BEGIN PRIVATE KEY----- MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCYkTN1g/0Z3KkS 3w0lX9yhZkwiA0obXCeFs7hpRP0p4WlW3uADyXQ5h2MaYx8OCA7oGU7/dWOPhtE3 rgFEz7IwLxcP5d02ukLGlFD69D6KLyTXwCFmvOWQ5fbOq4s73WTNDfYSTYNzeujD Cjeu/Bk0OVhdxbyZdyrpdm+UBfH8uIDoGeCRXnjinqG9HNOQx6r/S6FqC5j/7PrV l1i66WlqRzKEJB94uejfujrHq8RjQm/wzEutU5dfC39zEEEx75qQt7Jc0asm1AqA KSq34xn4rVajWrBZ/WudUUizHfaBDP61uPFvPyKWJDvTSdeoM9TPX0y0cjo6AwSr dLl7flrRAgMBAAECggEAJPTPNUEilxgncGXNZmdBJ2uDN536XoRFIpL1MbK/bFyo yp00QFaVK7ZK4EJwbFKxYbF3vFOwKT0sAsPIlOWGsTtG59fzbOVTdYzJzPBLEef3 kbd9n8hUB3RdA5T0Ji0r1Kv0FlzmYZu9NDmOYXm5lTfq2tQiKj5+i4zf3EhQZLng 4wVxBT7yQUQcstJv5K1L6HVzunSYtbHx8ZVxmw+tJ4lMCK23KPlvncZZTT8chWdT 3GOp5nYIHk9E5jQnBnj7p73sxZUCZlb8uhLtdcgAXc4scptEVO+7n5zOaXIv40Oz yfkESgHcZWAMDvnkxdySHlD38Z2LIKDGbqR6O9wcwQKBgQDBO6fFPXO41nsxdVCB nhCgL2hsGjaxJzGBLaOJNVMMFRASN3Yqvs4N1Hn7lawRI/FRRffxjLkZfNGEBSF2 OipdvX19Oe2hCZxvwHPoe5sb/Dh6KE7If1hRLOCXg/8E7ADBtAp94dam1WF4Kh6N Va6+n2YKif2rqye1YtRoUU46iQKBgQDKH/eMcMRUe9IySxHLogidOUwa0X7WrxF/ PkXGpPbHQtMOJF5cVzh+L+foUKXNM60lgmCH0438GKU7kirC/dVtD/bwE598/XFZ vnjPV7Adf9vBz9NN8cS/4uEfQYbvTRmrnrQK+ZhOe8hmwjapxqdWrVHNUtvx18vL qBwR4YjsCQKBgCycMx1MFJ1FludSKCXkcf4pM7hRTPMVE065VJnmn6eYbT9nYnZ3 2mZC+W5lnXXPkHSs7JLtZAZIVK5f6Nu8je9aQdBZQUz+RQlfquKvNp39WqSJDbcn /yGudKNGK+fc/Ee74vgw3Tdi57+wKaGDeHY1on8oYFHzj5VGnbb/nknRAoGBAK2Z hyQ4NmfZcU+A6mfbY0qmS5c9F5OMCZsgAQ374XiDDIK4+dKVlw/KVYRSwBTerXfp 4r7GFMzQ3hmsEM4o9YYWkCDiubjAdPp/fYOX7MtpZXWw6euoGzQzyObvgNVHgyTD yh8jAI1oA1c+t3RaCp+HfRq8b+vnTEI+wN0auF8BAoGBAJmw+GgHCZGpw2XPNu+X 8kuVGbQYAjTOXhBM4WzZyhfH1TWKLGn7C9YixhE2AW0UWKDvy+6OqPhe8q3KVms3 8YZ1W+vbUNEZNGE0XrB5ZMXfePiqisCz0jgP9OAuT+ii4aI3MAm3zgCEC6UTMvLq gNBu3Tcy6udxnUf7czzJDRtE -----END PRIVATE KEY----- ================================================ FILE: .testdata/sample-root.pem ================================================ -----BEGIN CERTIFICATE----- MIIEBDCCAuygAwIBAgIDAjppMA0GCSqGSIb3DQEBBQUAMEIxCzAJBgNVBAYTAlVT 
MRYwFAYDVQQKEw1HZW9UcnVzdCBJbmMuMRswGQYDVQQDExJHZW9UcnVzdCBHbG9i YWwgQ0EwHhcNMTMwNDA1MTUxNTU1WhcNMTUwNDA0MTUxNTU1WjBJMQswCQYDVQQG EwJVUzETMBEGA1UEChMKR29vZ2xlIEluYzElMCMGA1UEAxMcR29vZ2xlIEludGVy bmV0IEF1dGhvcml0eSBHMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB AJwqBHdc2FCROgajguDYUEi8iT/xGXAaiEZ+4I/F8YnOIe5a/mENtzJEiaB0C1NP VaTOgmKV7utZX8bhBYASxF6UP7xbSDj0U/ck5vuR6RXEz/RTDfRK/J9U3n2+oGtv h8DQUB8oMANA2ghzUWx//zo8pzcGjr1LEQTrfSTe5vn8MXH7lNVg8y5Kr0LSy+rE ahqyzFPdFUuLH8gZYR/Nnag+YyuENWllhMgZxUYi+FOVvuOAShDGKuy6lyARxzmZ EASg8GF6lSWMTlJ14rbtCMoU/M4iarNOz0YDl5cDfsCx3nuvRTPPuj5xt970JSXC DTWJnZ37DhF5iR43xa+OcmkCAwEAAaOB+zCB+DAfBgNVHSMEGDAWgBTAephojYn7 qwVkDBF9qn1luMrMTjAdBgNVHQ4EFgQUSt0GFhu89mi1dvWBtrtiGrpagS8wEgYD VR0TAQH/BAgwBgEB/wIBADAOBgNVHQ8BAf8EBAMCAQYwOgYDVR0fBDMwMTAvoC2g K4YpaHR0cDovL2NybC5nZW90cnVzdC5jb20vY3Jscy9ndGdsb2JhbC5jcmwwPQYI KwYBBQUHAQEEMTAvMC0GCCsGAQUFBzABhiFodHRwOi8vZ3RnbG9iYWwtb2NzcC5n ZW90cnVzdC5jb20wFwYDVR0gBBAwDjAMBgorBgEEAdZ5AgUBMA0GCSqGSIb3DQEB BQUAA4IBAQA21waAESetKhSbOHezI6B1WLuxfoNCunLaHtiONgaX4PCVOzf9G0JY /iLIa704XtE7JW4S615ndkZAkNoUyHgN7ZVm2o6Gb4ChulYylYbc3GrKBIxbf/a/ zG+FA1jDaFETzf3I93k9mTXwVqO94FntT0QJo544evZG0R0SnU++0ED8Vf4GXjza HFa9llF7b1cq26KqltyMdMKVvvBulRP/F/A8rLIQjcxz++iPAsbw+zOzlTvjwsto WHPbqCRiOwY1nQ2pM714A5AuTHhdUDqB1O6gyHA43LL5Z/qHQF1hwFGPa4NrzQU6 yuGnBXj8ytqU0CwIPX4WecigUCAkVDNx -----END CERTIFICATE----- ================================================ FILE: .testdata/text-file.txt ================================================ THIS IS TEXT FILE FOR MULTIPART UPLOAD TEST :) - go-resty ================================================ FILE: BUILD.bazel ================================================ load("@bazel_gazelle//:def.bzl", "gazelle") load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") # gazelle:prefix resty.dev/v3 # gazelle:go_naming_convention import_alias gazelle(name = "gazelle") go_library( name = "resty", srcs = [ "circuit_breaker.go", "client.go", "curl.go", "debug.go", "digest.go", "hedging.go", 
"load_balancer.go", "middleware.go", "multipart.go", "redirect.go", "request.go", "response.go", "resty.go", "retry.go", "sse.go", "stream.go", "trace.go", "transport_dial.go", "transport_dial_wasm.go", "util.go", ], importpath = "resty.dev/v3", visibility = ["//visibility:public"], deps = ["@org_golang_x_net//publicsuffix:go_default_library"], ) go_test( name = "resty_test", srcs = [ "benchmark_test.go", "cert_watcher_test.go", "client_test.go", "context_test.go", "curl_test.go", "digest_test.go", "hedging_test.go", "load_balancer_test.go", "middleware_test.go", "multipart_test.go", "request_test.go", "resty_test.go", "retry_test.go", "sse_test.go", "util_test.go", ], data = glob([".testdata/*"]), embed = [":resty"], ) alias( name = "go_default_library", actual = ":resty", visibility = ["//visibility:public"], ) ================================================ FILE: LICENSE ================================================ The MIT License (MIT) Copyright (c) 2015-present Jeevanandam M., https://myjeeva.com Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================

Resty Logo

Simple HTTP, REST, and SSE client library for Go

Resty Build Status Resty Code Coverage Go Report Card Resty GoDoc License Mentioned in Awesome Go

## Documentation

Go to https://resty.dev and refer to godoc.

## Minimum Go Version

Use `go1.23` and above.

## Support & Donate

* Sponsor via [GitHub](https://github.com/sponsors/jeevatkm)
* Donate via [PayPal](https://www.paypal.com/donate/?cmd=_donations&business=QWMZG74FW4QYC&lc=US&item_name=Resty+Library+for+Go&currency_code=USD)

## Versioning

Resty releases versions according to [Semantic Versioning](http://semver.org).

* Resty v3 provides the Go vanity URL `resty.dev/v3`.
* Resty v2 migrated away from the `gopkg.in` service to `github.com/go-resty/resty/v2`.
* Resty has fully adopted `go mod` capabilities since the `v1.10.0` release.
* Resty v1 series was using `gopkg.in` to provide versioning. `gopkg.in/resty.vX` points to appropriate tagged versions; `X` denotes version series number and it's a stable release for production use. For example, `gopkg.in/resty.v0`.

## Contribution

I would welcome your contribution!

* If you find any improvement or issue you want to fix, feel free to send a pull request.
* The pull requests must include test cases for feature/fix/enhancement with patch coverage of 100%.
* I have done my best to bring pretty good coverage. I would request contributors to do the same for their contribution.

I always look forward to hearing feedback, appreciation, and real-world usage stories from Resty users on [GitHub Discussions](https://github.com/go-resty/resty/discussions). It means a lot to me.

## Creator

[Jeevanandam M.](https://github.com/jeevatkm) (jeeva@myjeeva.com)

## Contributors

Have a look at the [Contributors](https://github.com/go-resty/resty/graphs/contributors) page.

## License Info

Resty is released under the MIT [LICENSE](LICENSE).

Resty [Documentation](https://github.com/go-resty/docs) and website are released under the Apache-2.0 [LICENSE](https://github.com/go-resty/docs/blob/main/LICENSE).
================================================ FILE: WORKSPACE ================================================ workspace(name = "resty") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") http_archive( name = "io_bazel_rules_go", sha256 = "80a98277ad1311dacd837f9b16db62887702e9f1d1c4c9f796d0121a46c8e184", urls = [ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.46.0/rules_go-v0.46.0.zip", "https://github.com/bazelbuild/rules_go/releases/download/v0.46.0/rules_go-v0.46.0.zip", ], ) http_archive( name = "bazel_gazelle", sha256 = "62ca106be173579c0a167deb23358fdfe71ffa1e4cfdddf5582af26520f1c66f", urls = [ "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.23.0/bazel-gazelle-v0.23.0.tar.gz", "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.23.0/bazel-gazelle-v0.23.0.tar.gz", ], ) load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies") go_rules_dependencies() go_register_toolchains(version = "1.21") load("@bazel_gazelle//:deps.bzl", "gazelle_dependencies") gazelle_dependencies() ================================================ FILE: benchmark_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. 
// SPDX-License-Identifier: MIT package resty import ( "bytes" "strings" "testing" "time" ) func Benchmark_parseRequestURL_PathParams(b *testing.B) { c := New().SetPathParams(map[string]string{ "foo": "1", "bar": "2", }).SetPathRawParams(map[string]string{ "foo": "3", "xyz": "4", }) r := c.R().SetPathParams(map[string]string{ "foo": "5", "qwe": "6", }).SetPathRawParams(map[string]string{ "foo": "7", "asd": "8", }) b.ResetTimer() for i := 0; i < b.N; i++ { r.URL = "https://example.com/{foo}/{bar}/{xyz}/{qwe}/{asd}" if err := parseRequestURL(c, r); err != nil { b.Errorf("parseRequestURL() error = %v", err) } } } func Benchmark_parseRequestURL_QueryParams(b *testing.B) { c := New().SetQueryParams(map[string]string{ "foo": "1", "bar": "2", }) r := c.R().SetQueryParams(map[string]string{ "foo": "5", "qwe": "6", }) b.ResetTimer() for i := 0; i < b.N; i++ { r.URL = "https://example.com/" if err := parseRequestURL(c, r); err != nil { b.Errorf("parseRequestURL() error = %v", err) } } } func Benchmark_parseRequestHeader(b *testing.B) { c := New() r := c.R() c.SetHeaders(map[string]string{ "foo": "1", // ignored, because of the same header in the request "bar": "2", }) r.SetHeaders(map[string]string{ "foo": "3", "xyz": "4", }) b.ResetTimer() for i := 0; i < b.N; i++ { if err := parseRequestHeader(c, r); err != nil { b.Errorf("parseRequestHeader() error = %v", err) } } } func Benchmark_parseRequestBody_string(b *testing.B) { c := New() r := c.R() r.SetBody("foo") b.ResetTimer() for i := 0; i < b.N; i++ { if err := parseRequestBody(c, r); err != nil { b.Errorf("parseRequestBody() error = %v", err) } } } func Benchmark_parseRequestBody_byte(b *testing.B) { c := New() r := c.R() r.SetBody([]byte("foo")) b.ResetTimer() for i := 0; i < b.N; i++ { if err := parseRequestBody(c, r); err != nil { b.Errorf("parseRequestBody() error = %v", err) } } } func Benchmark_parseRequestBody_reader(b *testing.B) { c := New() r := c.R() r.SetBody(bytes.NewBufferString("foo")) b.ResetTimer() for i 
:= 0; i < b.N; i++ { if err := parseRequestBody(c, r); err != nil { b.Errorf("parseRequestBody() error = %v", err) } } } func Benchmark_parseRequestBody_struct(b *testing.B) { type FooBar struct { Foo string `json:"foo"` Bar string `json:"bar"` } c := New() r := c.R() r.SetBody(FooBar{Foo: "1", Bar: "2"}).SetHeader(hdrContentTypeKey, jsonContentType) b.ResetTimer() for i := 0; i < b.N; i++ { if err := parseRequestBody(c, r); err != nil { b.Errorf("parseRequestBody() error = %v", err) } } } func Benchmark_parseRequestBody_struct_xml(b *testing.B) { type FooBar struct { Foo string `xml:"foo"` Bar string `xml:"bar"` } c := New() r := c.R() r.SetBody(FooBar{Foo: "1", Bar: "2"}).SetHeader(hdrContentTypeKey, "text/xml") b.ResetTimer() for i := 0; i < b.N; i++ { if err := parseRequestBody(c, r); err != nil { b.Errorf("parseRequestBody() error = %v", err) } } } func Benchmark_parseRequestBody_map(b *testing.B) { c := New() r := c.R() r.SetBody(map[string]string{ "foo": "1", "bar": "2", }).SetHeader(hdrContentTypeKey, jsonContentType) b.ResetTimer() for i := 0; i < b.N; i++ { if err := parseRequestBody(c, r); err != nil { b.Errorf("parseRequestBody() error = %v", err) } } } func Benchmark_parseRequestBody_slice(b *testing.B) { c := New() r := c.R() r.SetBody([]string{"1", "2"}).SetHeader(hdrContentTypeKey, jsonContentType) b.ResetTimer() for i := 0; i < b.N; i++ { if err := parseRequestBody(c, r); err != nil { b.Errorf("parseRequestBody() error = %v", err) } } } func Benchmark_parseRequestBody_FormData(b *testing.B) { c := New() r := c.R() c.SetFormData(map[string]string{"foo": "1", "bar": "2"}) r.SetFormData(map[string]string{"foo": "3", "baz": "4"}) b.ResetTimer() for i := 0; i < b.N; i++ { if err := parseRequestBody(c, r); err != nil { b.Errorf("parseRequestBody() error = %v", err) } } } func Benchmark_parseRequestBody_MultiPart(b *testing.B) { c := New() r := c.R() c.SetFormData(map[string]string{"foo": "1", "bar": "2"}) r.SetFormData(map[string]string{"foo": "3", 
"baz": "4"}). SetMultipartFormData(map[string]string{"foo": "5", "xyz": "6"}). SetFileReader("qwe", "qwe.txt", strings.NewReader("7")). SetMultipartFields( &MultipartField{ Name: "sdj", ContentType: "text/plain", Reader: strings.NewReader("8"), }, ). SetMethod(MethodPost) b.ResetTimer() for i := 0; i < b.N; i++ { if err := parseRequestBody(c, r); err != nil { b.Errorf("parseRequestBody() error = %v", err) } } } // benchmarkStringer implements fmt.Stringer for benchmarking type benchmarkStringer struct { value string } func (s benchmarkStringer) String() string { return s.value } // Tier 1: most common URL types func Benchmark_formatAnyToString_string(b *testing.B) { v := "hello world" for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_int(b *testing.B) { v := 12345 for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_bool(b *testing.B) { v := true for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_int64(b *testing.B) { v := int64(9223372036854775807) for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_stringSlice(b *testing.B) { v := []string{"a", "b", "c"} for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } // Tier 2: common stdlib types func Benchmark_formatAnyToString_time(b *testing.B) { v := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC) for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_byteSlice(b *testing.B) { v := []byte("binary data") for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_float64(b *testing.B) { v := 3.14159265359 for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } // Tier 3: less common integers (signed) func Benchmark_formatAnyToString_int32(b *testing.B) { v := int32(2147483647) for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_int16(b *testing.B) { v := int16(32767) for i 
:= 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_int8(b *testing.B) { v := int8(127) for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } // Tier 4: less common integers (unsigned) func Benchmark_formatAnyToString_uint64(b *testing.B) { v := uint64(18446744073709551615) for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_uint32(b *testing.B) { v := uint32(4294967295) for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_uint16(b *testing.B) { v := uint16(65535) for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_uint8(b *testing.B) { v := uint8(255) for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_uint(b *testing.B) { v := uint(12345) for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } // Tier 5: rare types and fallbacks func Benchmark_formatAnyToString_float32(b *testing.B) { v := float32(3.14) for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_stringer(b *testing.B) { v := benchmarkStringer{value: "custom value"} for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } func Benchmark_formatAnyToString_default(b *testing.B) { v := struct{ Name string }{Name: "test"} for i := 0; i < b.N; i++ { _ = formatAnyToString(v) } } ================================================ FILE: cert_watcher_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. 
// SPDX-License-Identifier: MIT package resty import ( "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "math/big" "net" "net/http" "os" "path/filepath" "strings" "testing" "time" ) type certPaths struct { RootCAKey string RootCACert string TLSKey string TLSCert string } func TestClient_SetRootCertificateWatcher(t *testing.T) { // For this test, we want to: // - Generate root CA // - Generate TLS cert signed with root CA // - Start a Test HTTPS server // - Create a Resty client with SetRootCertificateWatcher and SetClientRootCertificateWatcher // - Send multiple requests and re-generate the certs periodically to reproduce renewal certDir := t.TempDir() paths := certPaths{ RootCAKey: filepath.Join(certDir, "root-ca.key"), RootCACert: filepath.Join(certDir, "root-ca.crt"), TLSKey: filepath.Join(certDir, "tls.key"), TLSCert: filepath.Join(certDir, "tls.crt"), } generateCerts(t, paths) ts := createTestTLSServer(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }, paths.TLSCert, paths.TLSKey) defer ts.Close() poolingInterval := 100 * time.Millisecond client := NewWithTransportSettings(&TransportSettings{ // Make sure that TLS handshake happens for all request // (otherwise, test may succeed because 1st TLS session is re-used) DisableKeepAlives: true, }).SetRootCertificatesWatcher( &CertWatcherOptions{PoolInterval: poolingInterval}, paths.RootCACert, ).SetClientRootCertificatesWatcher( &CertWatcherOptions{PoolInterval: poolingInterval}, paths.RootCACert, ).SetDebug(false) url := strings.Replace(ts.URL, "127.0.0.1", "localhost", 1) t.Log("Test URL:", url) t.Run("Cert Watcher should handle certs rotation", func(t *testing.T) { for i := 0; i < 5; i++ { res, err := client.R().Get(url) if err != nil { t.Fatal(err) } assertEqual(t, res.StatusCode(), http.StatusOK) if i%2 == 1 { // Re-generate certs to simulate renewal scenario generateCerts(t, paths) time.Sleep(50 * time.Millisecond) } } }) t.Run("Cert Watcher should recover 
on failure", func(t *testing.T) { // Delete root cert and re-create it to ensure that cert watcher is able to recover // Re-generate certs to invalidate existing cert generateCerts(t, paths) // Delete root cert so that Cert Watcher will fail err := os.RemoveAll(paths.RootCACert) assertNil(t, err) // Reset TLS config to ensure that previous root cert is not re-used tr, err := client.HTTPTransport() assertNil(t, err) tr.TLSClientConfig = nil client.SetTransport(tr) time.Sleep(50 * time.Millisecond) _, err = client.R().Get(url) // We expect an error since root cert has been deleted assertNotNil(t, err) // Re-generate certs. We except cert watcher to reload the new root cert. generateCerts(t, paths) time.Sleep(50 * time.Millisecond) _, err = client.R().Get(url) assertNil(t, err) }) err := client.Close() assertNil(t, err) } func generateCerts(t *testing.T, paths certPaths) { rootKey, rootCert, err := generateRootCA(paths.RootCAKey, paths.RootCACert) if err != nil { t.Fatal(err) } if err := generateTLSCert(paths.TLSKey, paths.TLSCert, rootKey, rootCert); err != nil { t.Fatal(err) } } // Generate a Root Certificate Authority (CA) func generateRootCA(keyPath, certPath string) (*rsa.PrivateKey, []byte, error) { // Generate the key for the Root CA rootKey, err := generateKey() if err != nil { return nil, nil, err } // Define the maximum value you want for the random big integer max := new(big.Int).Lsh(big.NewInt(1), 256) // Example: 256 bits // Generate a random big.Int randomBigInt, err := rand.Int(rand.Reader, max) if err != nil { return nil, nil, err } // Create the root certificate template rootCertTemplate := &x509.Certificate{ SerialNumber: randomBigInt, Subject: pkix.Name{ Organization: []string{"YourOrg"}, Country: []string{"US"}, Province: []string{"State"}, Locality: []string{"City"}, CommonName: "YourRootCA", }, NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour * 10), KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, IsCA: true, 
BasicConstraintsValid: true, } // Self-sign the root certificate rootCert, err := x509.CreateCertificate(rand.Reader, rootCertTemplate, rootCertTemplate, &rootKey.PublicKey, rootKey) if err != nil { return nil, nil, err } // Save the Root CA key and certificate if err := savePEMKey(keyPath, rootKey); err != nil { return nil, nil, err } if err := savePEMCert(certPath, rootCert); err != nil { return nil, nil, err } return rootKey, rootCert, nil } // Generate a TLS Certificate signed by the Root CA func generateTLSCert(keyPath, certPath string, rootKey *rsa.PrivateKey, rootCert []byte) error { // Generate a key for the server serverKey, err := generateKey() if err != nil { return err } // Parse the Root CA certificate parsedRootCert, err := x509.ParseCertificate(rootCert) if err != nil { return err } // Create the server certificate template serverCertTemplate := &x509.Certificate{ SerialNumber: big.NewInt(2), Subject: pkix.Name{ Organization: []string{"YourOrg"}, CommonName: "localhost", }, NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour * 10), KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, DNSNames: []string{"localhost"}, } // Sign the server certificate with the Root CA serverCert, err := x509.CreateCertificate(rand.Reader, serverCertTemplate, parsedRootCert, &serverKey.PublicKey, rootKey) if err != nil { return err } // Save the server key and certificate if err := savePEMKey(keyPath, serverKey); err != nil { return err } if err := savePEMCert(certPath, serverCert); err != nil { return err } return nil } func generateKey() (*rsa.PrivateKey, error) { return rsa.GenerateKey(rand.Reader, 2048) } func savePEMKey(fileName string, key *rsa.PrivateKey) error { keyFile, err := os.Create(fileName) if err != nil { return err } defer keyFile.Close() privateKeyPEM := &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: 
x509.MarshalPKCS1PrivateKey(key), } return pem.Encode(keyFile, privateKeyPEM) } func savePEMCert(fileName string, cert []byte) error { certFile, err := os.Create(fileName) if err != nil { return err } defer certFile.Close() certPEM := &pem.Block{ Type: "CERTIFICATE", Bytes: cert, } return pem.Encode(certFile, certPEM) } ================================================ FILE: circuit_breaker.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "errors" "net/http" "sync" "sync/atomic" "time" ) // ErrCircuitBreakerOpen is returned when the circuit breaker is open. var ErrCircuitBreakerOpen = errors.New("resty: circuit breaker open") type ( // CircuitBreakerTriggerHook type is for reacting to circuit breaker trigger hooks. CircuitBreakerTriggerHook func(*Request, error) // CircuitBreakerStateChangeHook type is for reacting to circuit breaker state change hooks. CircuitBreakerStateChangeHook func(oldState, newState CircuitBreakerState) // CircuitBreakerState type represents the state of the circuit breaker. 
CircuitBreakerState uint32 ) // group is an interface for types that can be combined and inverted type group[T any] interface { op(T) T empty() T inverse() T } // totalAndFailures tracks total requests and failures type totalAndFailures struct { total int failures int } func (tf totalAndFailures) op(g totalAndFailures) totalAndFailures { tf.total += g.total tf.failures += g.failures return tf } func (tf totalAndFailures) empty() totalAndFailures { return totalAndFailures{} } func (tf totalAndFailures) inverse() totalAndFailures { tf.total = -tf.total tf.failures = -tf.failures return tf } // slidingWindow implements a time-based sliding window for tracking values type slidingWindow[G group[G]] struct { mutex sync.RWMutex total G values []G idx int lastStart time.Time interval time.Duration } func newSlidingWindow[G group[G]](empty func() G, interval time.Duration, buckets int) *slidingWindow[G] { return &slidingWindow[G]{ total: empty(), values: make([]G, buckets), idx: 0, lastStart: time.Now(), interval: interval, } } func (sw *slidingWindow[G]) Add(val G) { sw.mutex.Lock() defer sw.mutex.Unlock() now := time.Now() elapsed := now.Sub(sw.lastStart) bucketDuration := sw.interval / time.Duration(len(sw.values)) // Advance window if needed if elapsed >= bucketDuration { bucketsToAdvance := int(elapsed / bucketDuration) if bucketsToAdvance >= len(sw.values) { // Reset all buckets for i := range sw.values { sw.values[i] = sw.total.empty() } sw.total = sw.total.empty() sw.idx = 0 } else { // Remove old buckets for i := 0; i < bucketsToAdvance; i++ { sw.idx = (sw.idx + 1) % len(sw.values) sw.total = sw.total.op(sw.values[sw.idx].inverse()) sw.values[sw.idx] = sw.total.empty() } } sw.lastStart = now } // Add to current bucket sw.values[sw.idx] = sw.values[sw.idx].op(val) sw.total = sw.total.op(val) } func (sw *slidingWindow[G]) Get() G { sw.mutex.RLock() defer sw.mutex.RUnlock() return sw.total } func (sw *slidingWindow[G]) SetInterval(interval time.Duration) { 
sw.mutex.Lock() defer sw.mutex.Unlock() sw.interval = interval } const ( // CircuitBreakerStateClosed represents the closed state of the circuit breaker. CircuitBreakerStateClosed CircuitBreakerState = iota // CircuitBreakerStateOpen represents the open state of the circuit breaker. CircuitBreakerStateOpen // CircuitBreakerStateHalfOpen represents the half-open state of the circuit breaker. CircuitBreakerStateHalfOpen ) // CircuitBreaker struct implements a state machine to monitor and manage the // states of circuit breakers. The three states are: // - Closed: requests are allowed // - Open: requests are blocked // - Half-Open: a single request is allowed to determine // // Transitions // - To Closed State: when the success count reaches the success threshold. // - To Open State: when the failure count reaches the failure threshold. // - Half-Open Check: when the specified timeout reaches, a single request is allowed // to determine the transition state; if failed, it goes back to the open state. // // Use [NewCircuitBreakerWithCount] or [NewCircuitBreakerWithRatio] to create a new [CircuitBreaker] // instance accordingly. type CircuitBreaker struct { lock *sync.RWMutex policies []CircuitBreakerPolicy resetTimeout time.Duration state atomic.Value // CircuitBreakerState sw *slidingWindow[totalAndFailures] // Hooks triggerHooks []CircuitBreakerTriggerHook stateChangeHooks []CircuitBreakerStateChangeHook // Count-based failureThreshold uint64 successThreshold uint64 // Ratio-based isRatioBased bool failureRatio float64 // Threshold, e.g., 0.5 for 50% failure minRequests uint64 // Minimum number of requests to consider failure ratio } // NewCircuitBreakerWithCount method creates a new [CircuitBreaker] instance with Count settings. 
// // The default settings are: // - Policies: CircuitBreaker5xxPolicy func NewCircuitBreakerWithCount(failureThreshold uint64, successThreshold uint64, resetTimeout time.Duration, policies ...CircuitBreakerPolicy) *CircuitBreaker { cb := newCircuitBreaker(resetTimeout, policies...) cb.failureThreshold = failureThreshold cb.successThreshold = successThreshold return cb } // NewCircuitBreakerWithRatio method creates a new [CircuitBreaker] instance with Ratio settings. // // The default settings are: // - Policies: CircuitBreaker5xxPolicy func NewCircuitBreakerWithRatio(failureRatio float64, minRequests uint64, resetTimeout time.Duration, policies ...CircuitBreakerPolicy) *CircuitBreaker { cb := newCircuitBreaker(resetTimeout, policies...) cb.failureRatio = failureRatio cb.minRequests = minRequests cb.isRatioBased = true return cb } func newCircuitBreaker(resetTimeout time.Duration, policies ...CircuitBreakerPolicy) *CircuitBreaker { cb := &CircuitBreaker{ lock: &sync.RWMutex{}, resetTimeout: resetTimeout, policies: []CircuitBreakerPolicy{CircuitBreaker5xxPolicy}, } cb.state.Store(CircuitBreakerStateClosed) cb.sw = newSlidingWindow( func() totalAndFailures { return totalAndFailures{} }, resetTimeout, 10, ) if len(policies) > 0 { cb.policies = policies } return cb } // OnTrigger method adds a [CircuitBreakerTriggerHook] to the [CircuitBreaker] instance. func (cb *CircuitBreaker) OnTrigger(hooks ...CircuitBreakerTriggerHook) *CircuitBreaker { cb.lock.Lock() defer cb.lock.Unlock() cb.triggerHooks = append(cb.triggerHooks, hooks...) return cb } // onTriggerHooks method executes all registered trigger hooks. func (cb *CircuitBreaker) onTriggerHooks(req *Request, err error) { cb.lock.RLock() defer cb.lock.RUnlock() for _, h := range cb.triggerHooks { h(req, err) } } // OnStateChange method adds a [CircuitBreakerStateChangeHook] to the [CircuitBreaker] instance. 
func (cb *CircuitBreaker) OnStateChange(hooks ...CircuitBreakerStateChangeHook) *CircuitBreaker { cb.lock.Lock() defer cb.lock.Unlock() cb.stateChangeHooks = append(cb.stateChangeHooks, hooks...) return cb } // onStateChangeHooks method executes all registered state change hooks. func (cb *CircuitBreaker) onStateChangeHooks(oldState, newState CircuitBreakerState) { cb.lock.RLock() defer cb.lock.RUnlock() for _, h := range cb.stateChangeHooks { h(oldState, newState) } } // CircuitBreakerPolicy is a function type that determines whether a response should // trip the [CircuitBreaker]. type CircuitBreakerPolicy func(resp *http.Response) bool // CircuitBreaker5xxPolicy is a [CircuitBreakerPolicy] that trips the [CircuitBreaker] if // the response status code is 500 or greater. func CircuitBreaker5xxPolicy(resp *http.Response) bool { return resp.StatusCode > 499 } func (cb *CircuitBreaker) getState() CircuitBreakerState { return cb.state.Load().(CircuitBreakerState) } func (cb *CircuitBreaker) allow() error { if cb.getState() == CircuitBreakerStateOpen { return ErrCircuitBreakerOpen } return nil } func (cb *CircuitBreaker) applyPolicies(resp *http.Response) { failed := false for _, policy := range cb.policies { if policy(resp) { failed = true break } } if failed { cb.sw.Add(totalAndFailures{total: 1, failures: 1}) switch cb.getState() { case CircuitBreakerStateClosed: tf := cb.sw.Get() if cb.isRatioBased { if tf.total >= int(cb.minRequests) { currentFailureRatio := float64(tf.failures) / float64(tf.total) if currentFailureRatio >= cb.failureRatio { cb.open() } } } else { if tf.failures >= int(cb.failureThreshold) { cb.open() } } case CircuitBreakerStateHalfOpen: cb.open() } return } cb.sw.Add(totalAndFailures{total: 1, failures: 0}) switch cb.getState() { case CircuitBreakerStateClosed: return case CircuitBreakerStateHalfOpen: tf := cb.sw.Get() if tf.total-tf.failures >= int(cb.successThreshold) { cb.changeState(CircuitBreakerStateClosed) } } } func (cb *CircuitBreaker) 
open() { cb.changeState(CircuitBreakerStateOpen) go func() { time.Sleep(cb.resetTimeout) cb.changeState(CircuitBreakerStateHalfOpen) }() } func (cb *CircuitBreaker) changeState(state CircuitBreakerState) { oldState := cb.getState() cb.sw = newSlidingWindow( func() totalAndFailures { return totalAndFailures{} }, cb.resetTimeout, 10, ) cb.state.Store(state) if oldState != state { cb.onStateChangeHooks(oldState, state) } } ================================================ FILE: circuit_breaker_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "net/http" "sync" "sync/atomic" "testing" "time" ) var _ CircuitBreakerPolicy = CircuitBreaker5xxPolicy func TestCircuitBreakerCountBased(t *testing.T) { ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { t.Logf("Method: %v", r.Method) t.Logf("Path: %v", r.URL.Path) switch r.URL.Path { case "/200": w.WriteHeader(http.StatusOK) return case "/500": w.WriteHeader(http.StatusInternalServerError) return } }) defer ts.Close() failThreshold := uint64(2) successThreshold := uint64(1) resetTimeout := 100 * time.Millisecond cb := NewCircuitBreakerWithCount(failThreshold, successThreshold, resetTimeout) c := dcnl().SetCircuitBreaker(cb) for i := uint64(0); i < failThreshold; i++ { _, err := c.R().Get(ts.URL + "/500") assertNil(t, err) } resp, err := c.R().Get(ts.URL + "/500") assertErrorIs(t, ErrCircuitBreakerOpen, err) assertNil(t, resp) assertEqual(t, CircuitBreakerStateOpen, c.circuitBreaker.getState(), "expected open state after reaching failure threshold") time.Sleep(resetTimeout + 50*time.Millisecond) assertEqual(t, CircuitBreakerStateHalfOpen, c.circuitBreaker.getState(), "expected half-open state") _, err = c.R().Get(ts.URL + "/500") assertError(t, err) 
assertEqual(t, CircuitBreakerStateOpen, c.circuitBreaker.getState(), "expected open state after failure in half-open")

	time.Sleep(resetTimeout + 50*time.Millisecond)
	assertEqual(t, CircuitBreakerStateHalfOpen, c.circuitBreaker.getState(), "expected half-open state")

	for i := uint64(0); i < successThreshold; i++ {
		_, err := c.R().Get(ts.URL + "/200")
		assertNil(t, err)
	}
	assertEqual(t, CircuitBreakerStateClosed, c.circuitBreaker.getState(), "expected closed state after success threshold")

	resp, err = c.R().Get(ts.URL + "/200")
	assertNil(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())

	_, err = c.R().Get(ts.URL + "/500")
	assertError(t, err)
	assertEqual(t, 1, c.circuitBreaker.sw.Get().failures, "expected failure count to be 1 after single failure in closed state")

	// after a full window the earlier failure must have expired
	time.Sleep(resetTimeout)

	_, err = c.R().Get(ts.URL + "/500")
	assertError(t, err)
	assertEqual(t, 1, c.circuitBreaker.sw.Get().failures, "expected failure count to be 1 after single failure in closed state")
}

func TestCircuitBreaker5xxPolicy(t *testing.T) {
	tripped := CircuitBreaker5xxPolicy(&http.Response{StatusCode: 500})
	assertTrue(t, tripped, "expected true for 5xx status code")

	notTripped := CircuitBreaker5xxPolicy(&http.Response{StatusCode: 200})
	assertFalse(t, notTripped, "expected false for non-5xx status code")
}

func TestCircuitBreakerCountBasedOpensAndAllow(t *testing.T) {
	cb := NewCircuitBreakerWithCount(2, 1, 20*time.Millisecond)
	fail := &http.Response{StatusCode: 500}

	// a fresh breaker is closed and lets requests through
	err1 := cb.allow()
	assertNil(t, err1)
	assertEqual(t, 0, cb.sw.Get().failures, "expected allow when no failures initially")

	// one failure is below the threshold, the breaker stays closed
	cb.applyPolicies(fail)
	err2 := cb.allow()
	assertNil(t, err2)
	assertEqual(t, 1, cb.sw.Get().failures, "expected still closed after 1 failure")

	// the second failure reaches the threshold and opens the breaker
	cb.applyPolicies(fail)
	err3 := cb.allow()
	assertErrorIs(t, ErrCircuitBreakerOpen, err3, "expected open after reaching failure threshold")

	// wait past the reset timeout for the move to half-open
	time.Sleep(25 * time.Millisecond)
	assertEqual(t, CircuitBreakerStateHalfOpen, cb.getState(), "expected half-open state after reset timeout")

	// a failure in half-open sends the breaker back to open
	cb.applyPolicies(fail)
	assertEqual(t, CircuitBreakerStateOpen, cb.getState(), "expected open state after failure in half-open")

	// and an open breaker rejects requests
	err4 := cb.allow()
	assertErrorIs(t, ErrCircuitBreakerOpen, err4, "expected open state on allow after failure in half-open")
}

func TestCircuitBreakerCountBasedHalfOpenToClosedOnSuccess(t *testing.T) {
	cb := NewCircuitBreakerWithCount(1, 1, 30*time.Millisecond)
	fail := &http.Response{StatusCode: 500}
	ok := &http.Response{StatusCode: 200}

	// a single failure reaches the threshold of 1
	cb.applyPolicies(fail)
	err1 := cb.allow()
	assertErrorIs(t, ErrCircuitBreakerOpen, err1, "expected open after failing threshold")

	// poll until the reset timeout moved the breaker to half-open
	deadline := time.Now().Add(200 * time.Millisecond)
	for time.Now().Before(deadline) {
		if cb.getState() == CircuitBreakerStateHalfOpen {
			break
		}
		time.Sleep(5 * time.Millisecond)
	}

	// expected half-open state after reset timeout
	assertEqual(t, CircuitBreakerStateHalfOpen, cb.getState(), "expected half-open state after reset timeout")

	// on success in half-open, should move to closed
	cb.applyPolicies(ok)
	assertEqual(t, CircuitBreakerStateClosed, cb.getState(), "expected closed state after success in half-open")

	// expected allow when closed
	err := cb.allow()
	assertNil(t, err)
}

func TestCircuitBreakerRatioBasedOpenToClosed(t *testing.T) {
	cb := NewCircuitBreakerWithRatio(0.5, 2, 20*time.Millisecond)
	fail := &http.Response{StatusCode: 500}
	ok := &http.Response{StatusCode: 200}

	// the first failure does not meet minRequests yet, so the breaker stays closed
	cb.applyPolicies(fail)
	err1 := cb.allow()
	assertNil(t, err1)
	if err1 == ErrCircuitBreakerOpen {
		t.Errorf("expected still closed after 1 failure (minRequests not met)")
	}

	// the second failure opens it (2/2 = 1.0 >= 0.5)
	cb.applyPolicies(fail)
	err2 := cb.allow()
	assertErrorIs(t, ErrCircuitBreakerOpen, err2, "expected open after failures exceed ratio threshold")

	time.Sleep(25 * time.Millisecond)

	// expected half-open state after reset timeout
	assertEqual(t, CircuitBreakerStateHalfOpen, cb.getState(), "expected half-open state after reset timeout")

	// on success in half-open, should move to closed
	cb.applyPolicies(ok)
	assertEqual(t, CircuitBreakerStateClosed, cb.getState(), "expected closed state after success in half-open")
}

func TestCircuitBreakerNewStateAndPolicies(t *testing.T) {
	cb := NewCircuitBreakerWithCount(3, 2, 10*time.Millisecond, CircuitBreaker5xxPolicy)

	assertEqual(t, CircuitBreakerStateClosed, cb.getState())
	assertEqual(t, uint64(3), cb.failureThreshold)
	assertEqual(t, uint64(2), cb.successThreshold)
	assertEqual(t, 10*time.Millisecond, cb.resetTimeout)
	assertEqual(t, 1, len(cb.policies))
}

func TestCircuitBreakerChangeStateClearsCounts(t *testing.T) {
	cb := NewCircuitBreakerWithCount(2, 1, 10*time.Millisecond)
	fail := &http.Response{StatusCode: 500}

	cb.applyPolicies(fail)
	assertEqual(t, 1, cb.sw.Get().failures)

	cb.changeState(CircuitBreakerStateHalfOpen)
	assertEqual(t, CircuitBreakerStateHalfOpen, cb.getState())
	assertEqual(t, 0, cb.sw.Get().failures)
	assertEqual(t, 0, cb.sw.Get().total)
}

func TestCircuitBreakerAllowDuringHalfOpen(t *testing.T) {
	cb := NewCircuitBreakerWithCount(1, 1, 20*time.Millisecond)
	fail := &http.Response{StatusCode: 500}

	cb.applyPolicies(fail) // opens
	assertErrorIs(t, ErrCircuitBreakerOpen, cb.allow(), "expected open state")

	time.Sleep(25 * time.Millisecond) // wait to transition to half-open
	assertEqual(t, CircuitBreakerStateHalfOpen, cb.getState(), "expected half-open state")
	assertNil(t, cb.allow())
}

func TestCircuitBreakerOnTriggerHooks(t *testing.T) {
	cb := NewCircuitBreakerWithCount(1, 1, 10*time.Millisecond)

	called := false
	var gotErr error
	cb.OnTrigger(func(r *Request, e error) {
		called = true
		gotErr = e
	})

	cb.onTriggerHooks(nil, ErrCircuitBreakerOpen)
	assertTrue(t, called, "expected onTrigger hook to be called")
	assertEqual(t, ErrCircuitBreakerOpen, gotErr, "expected error to be passed to onTrigger hook")
}

func TestCircuitBreakerOnStateChangeHooks(t *testing.T) {
	cb := NewCircuitBreakerWithCount(1, 1, 10*time.Millisecond)

	called := false
	var oldState, newState CircuitBreakerState
	cb.OnStateChange(func(o, n CircuitBreakerState) {
		called = true
		oldState = o
		newState = n
	})

	cb.onStateChangeHooks(CircuitBreakerStateClosed, CircuitBreakerStateOpen)
	assertTrue(t, called)
	assertEqual(t, CircuitBreakerStateClosed, oldState, "expected old state to be passed to onStateChange hook")
	assertEqual(t, CircuitBreakerStateOpen, newState, "expected new state to be passed to onStateChange hook")
}

func TestCircuitBreakerMultipleHooksAreCalled(t *testing.T) {
	cb := NewCircuitBreakerWithCount(1, 1, 10*time.Millisecond)

	triggerCount := 0
	cb.OnTrigger(func(_ *Request, _ error) { triggerCount++ })
	cb.OnTrigger(func(_ *Request, _ error) { triggerCount++ })
	cb.onTriggerHooks(nil, ErrCircuitBreakerOpen)
	assertEqual(t, 2, triggerCount, "expected both trigger hooks to be called")

	stateCount := 0
	cb.OnStateChange(func(_, _ CircuitBreakerState) { stateCount++ })
	cb.OnStateChange(func(_, _ CircuitBreakerState) { stateCount++ })
	cb.onStateChangeHooks(CircuitBreakerStateClosed, CircuitBreakerStateHalfOpen)
	assertEqual(t, 2, stateCount, "expected both state change hooks to be called")
}

func TestCircuitBreakerConcurrentOnTriggerRegistration(t *testing.T) {
	cb := NewCircuitBreakerWithCount(1, 1, 10*time.Millisecond)

	var wg sync.WaitGroup
	var cnt int32
	n := 100

	// register n hooks from n goroutines at once
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func() {
			cb.OnTrigger(func(_ *Request, _ error) { atomic.AddInt32(&cnt, 1) })
			wg.Done()
		}()
	}
	wg.Wait()

	cb.onTriggerHooks(nil, ErrCircuitBreakerOpen)
	got := atomic.LoadInt32(&cnt)
	assertEqual(t, int32(n), got, "expected N hooks executed")
}

func TestCircuitBreakerConcurrentOnStateChangeRegistration(t *testing.T) {
	cb := NewCircuitBreakerWithCount(1, 1, 10*time.Millisecond)

	var wg sync.WaitGroup
	var cnt int32
	n := 100

	// register n hooks from n goroutines at once
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func() {
			cb.OnStateChange(func(_, _ CircuitBreakerState) { atomic.AddInt32(&cnt, 1) })
			wg.Done()
		}()
	}
	wg.Wait()

	cb.onStateChangeHooks(CircuitBreakerStateClosed, CircuitBreakerStateOpen)
	got := atomic.LoadInt32(&cnt)
	assertEqual(t, int32(n), got, "expected N state change hooks executed")
}

func TestCircuitBreakerSlidingWindow1SetInterval(t *testing.T) {
	cb := NewCircuitBreakerWithCount(2, 1, 100*time.Millisecond)

	// the window starts with the reset timeout as its interval
	assertEqual(t, 100*time.Millisecond, cb.sw.interval, "initial interval mismatch")

	// Change interval to a longer duration
	cb.sw.SetInterval(200 * time.Millisecond)

	// Verify interval was changed
	assertEqual(t, 200*time.Millisecond, cb.sw.interval, "interval not updated correctly")
}

func TestCircuitBreakerSlidingWindow2SetInterval(t *testing.T) {
	sw := newSlidingWindow(func() totalAndFailures { return totalAndFailures{} }, 100*time.Millisecond, 5)
	assertEqual(t, 100*time.Millisecond, sw.interval, "initial interval mismatch")

	sw.SetInterval(250 * time.Millisecond)
	assertEqual(t, 250*time.Millisecond, sw.interval, "interval not updated correctly")
}

func TestCircuitBreakerSlidingWindowConcurrentAddGet(t *testing.T) {
	sw := newSlidingWindow(func() totalAndFailures { return totalAndFailures{} }, 200*time.Millisecond, 10)

	var wg sync.WaitGroup
	n := 200
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func() {
			sw.Add(totalAndFailures{total: 1, failures: 0})
			wg.Done()
		}()
	}
	wg.Wait()

	got := sw.Get()
	assertEqual(t, n, got.total, "concurrent adds: expected total count mismatch")
}

func TestCircuitBreakerTotalAndFailuresOperations(t *testing.T) {
	a := totalAndFailures{total: 2, failures: 1}
	b := totalAndFailures{total: 3, failures: 2}

	c := a.op(b)
	assertEqual(t, 5, c.total, "op result incorrect, want total 5")
	assertEqual(t, 3, c.failures, "op result incorrect, want failures 3")

	inv := c.inverse()
	assertEqual(t, -5, inv.total, "inverse result incorrect, want total -5")
	assertEqual(t, -3, inv.failures, "inverse result incorrect, want failures -3")

	empty := c.empty()
	assertEqual(t, 0, empty.total, "empty result incorrect, want total 0")
	assertEqual(t, 0, empty.failures, "empty result incorrect, want failures 0")
}

func TestCircuitBreakerSlidingWindowResetWhenElapsedExceedsBuckets(t *testing.T) {
	interval := 100 * time.Millisecond
	sw := newSlidingWindow(func() totalAndFailures { return totalAndFailures{} }, interval, 4)

	// Pre-populate total and buckets to non-zero values
	sw.values[0] = totalAndFailures{total: 5, failures: 2}
	sw.values[1] = totalAndFailures{total: 3, failures: 1}
	sw.total = sw.values[0].op(sw.values[1]).op(sw.total)

	// Force lastStart far in the past so bucketsToAdvance >= len(values) path is taken
	sw.lastStart = sw.lastStart.Add(-time.Duration(10) * interval)

	// Add a new value; should reset buckets and only this value remains
	sw.Add(totalAndFailures{total: 1, failures: 1})

	got := sw.Get()
	assertEqual(t, 1, got.total, "after reset expected total=1")
	assertEqual(t, 1, got.failures, "after reset expected failures=1")
	assertEqual(t, 0, sw.idx, "expected idx reset to 0")
}

================================================
FILE: client.go
================================================
// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT package resty import ( "bytes" "context" "crypto/tls" "crypto/x509" "errors" "io" "maps" "net/http" "net/url" "os" "reflect" "slices" "strings" "sync" "time" ) const ( // MethodGet HTTP method MethodGet = "GET" // MethodPost HTTP method MethodPost = "POST" // MethodPut HTTP method MethodPut = "PUT" // MethodDelete HTTP method MethodDelete = "DELETE" // MethodPatch HTTP method MethodPatch = "PATCH" // MethodHead HTTP method MethodHead = "HEAD" // MethodOptions HTTP method MethodOptions = "OPTIONS" // MethodTrace HTTP method MethodTrace = "TRACE" ) const ( defaultWatcherPoolingInterval = 24 * time.Hour ) var ( ErrNotHttpTransportType = errors.New("resty: not a http.Transport type") ErrUnsupportedRequestBodyKind = errors.New("resty: unsupported request body kind") hdrUserAgentKey = http.CanonicalHeaderKey("User-Agent") hdrAcceptKey = http.CanonicalHeaderKey("Accept") hdrAcceptEncodingKey = http.CanonicalHeaderKey("Accept-Encoding") hdrContentTypeKey = http.CanonicalHeaderKey("Content-Type") hdrContentLengthKey = http.CanonicalHeaderKey("Content-Length") hdrContentEncodingKey = http.CanonicalHeaderKey("Content-Encoding") hdrContentDisposition = http.CanonicalHeaderKey("Content-Disposition") hdrAuthorizationKey = http.CanonicalHeaderKey("Authorization") hdrWwwAuthenticateKey = http.CanonicalHeaderKey("WWW-Authenticate") hdrRetryAfterKey = http.CanonicalHeaderKey("Retry-After") hdrCookieKey = http.CanonicalHeaderKey("Cookie") plainTextType = "text/plain; charset=utf-8" jsonContentType = "application/json" formContentType = "application/x-www-form-urlencoded" jsonKey = "json" xmlKey = "xml" defaultAuthScheme = "Bearer" hdrUserAgentValue = "go-resty/" + Version + " (https://resty.dev)" bufPool = &sync.Pool{New: func() any { return &bytes.Buffer{} }} ) type ( // RequestMiddleware type is for request middleware, called before a request is sent RequestMiddleware func(*Client, *Request) error // ResponseMiddleware type is for response 
middleware, called after a response has been received ResponseMiddleware func(*Client, *Response) error // ErrorHook type is for reacting to request errors, called after all retries were attempted ErrorHook func(*Request, error) // SuccessHook type is for reacting to request success SuccessHook func(*Client, *Response) // CloseHook type is for reacting to client closing CloseHook func() // RequestFunc type is for extended manipulation of the Request instance RequestFunc func(*Request) *Request // TLSClientConfiger interface is to configure TLS Client configuration on custom transport // implemented using [http.RoundTripper] TLSClientConfiger interface { TLSClientConfig() *tls.Config SetTLSClientConfig(*tls.Config) error } ) // TransportSettings struct is used to define custom dialer and transport // values for the Resty client. Please refer to individual // struct fields to know the default values. // // Also, refer to https://pkg.go.dev/net/http#Transport for more details. type TransportSettings struct { // DialerTimeout, default value is `30` seconds. DialerTimeout time.Duration // DialerKeepAlive, default value is `30` seconds. DialerKeepAlive time.Duration // IdleConnTimeout, default value is `90` seconds. IdleConnTimeout time.Duration // TLSHandshakeTimeout, default value is `10` seconds. TLSHandshakeTimeout time.Duration // ExpectContinueTimeout, default value is `1` seconds. ExpectContinueTimeout time.Duration // ResponseHeaderTimeout, added to provide ability to // set value. No default value in Resty, the Go // HTTP client default value applies. ResponseHeaderTimeout time.Duration // MaxIdleConns, default value is `100`. MaxIdleConns int // MaxIdleConnsPerHost, default value is `runtime.GOMAXPROCS(0) + 1`. MaxIdleConnsPerHost int // MaxConnsPerHost, default value is no limit. MaxConnsPerHost int // DisableKeepAlives, default value is `false`. DisableKeepAlives bool // MaxResponseHeaderBytes, added to provide ability to // set value. 
No default value in Resty, the Go // HTTP client default value applies. MaxResponseHeaderBytes int64 // WriteBufferSize, added to provide ability to // set value. No default value in Resty, the Go // HTTP client default value applies. WriteBufferSize int // ReadBufferSize, added to provide ability to // set value. No default value in Resty, the Go // HTTP client default value applies. ReadBufferSize int } // Client struct is used to create a Resty client with client-level settings, // these settings apply to all the requests raised from the client. // // Resty also provides an option to override most of the client settings // at [Request] level. type Client struct { lock *sync.RWMutex baseURL string queryParams url.Values formData url.Values pathParams map[string]string header http.Header credentials *credentials authToken string authScheme string cookies []*http.Cookie errorType reflect.Type debug bool disableWarn bool isMethodGetAllowPayload bool isMethodDeleteAllowPayload bool timeout time.Duration retryCount int retryWaitTime time.Duration retryMaxWaitTime time.Duration retryConditions []RetryConditionFunc retryHooks []RetryHookFunc retryDelayStrategy RetryDelayStrategyFunc isRetryDefaultConditions bool isRetryAllowNonIdempotent bool headerAuthorizationKey string responseBodyLimit int64 resBodyUnlimitedReads bool jsonEscapeHTML bool closeConnection bool isResponseDoNotParse bool isTrace bool debugBodyLimit int responseSaveDirectory string isResponseSaveToFile bool scheme string log Logger ctx context.Context httpClient *http.Client proxyURL *url.URL debugLogFormatter DebugLogFormatterFunc debugLogCallback DebugLogCallbackFunc isCurlCmdGenerate bool isCurlCmdDebugLog bool unescapeQueryParams bool loadBalancer LoadBalancer beforeRequest []RequestMiddleware afterResponse []ResponseMiddleware errorHooks []ErrorHook invalidHooks []ErrorHook panicHooks []ErrorHook successHooks []SuccessHook closeHooks []CloseHook contentTypeEncoders map[string]ContentTypeEncoder 
contentTypeDecoders map[string]ContentTypeDecoder contentDecompresserKeys []string contentDecompressers map[string]ContentDecompresser certWatcherStopChan chan bool circuitBreaker *CircuitBreaker hedging *Hedging } // CertWatcherOptions allows configuring a watcher that reloads dynamically TLS certs. type CertWatcherOptions struct { // PoolInterval is the frequency at which resty will check if the PEM file needs to be reloaded. // Default is 24 hours. PoolInterval time.Duration } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Client methods //___________________________________ // BaseURL method returns the Base URL value from the client instance. func (c *Client) BaseURL() string { c.lock.RLock() defer c.lock.RUnlock() return c.baseURL } // SetBaseURL method sets the Base URL in the client instance. It will be used with a request // raised from this client with a relative URL // // // Setting HTTP address // client.SetBaseURL("http://myjeeva.com") // // // Setting HTTPS address // client.SetBaseURL("https://myjeeva.com") func (c *Client) SetBaseURL(url string) *Client { c.lock.Lock() defer c.lock.Unlock() c.baseURL = strings.TrimRight(url, "/") return c } // LoadBalancer method returns the request load balancer instance from the client // instance. Otherwise returns nil. func (c *Client) LoadBalancer() LoadBalancer { c.lock.RLock() defer c.lock.RUnlock() return c.loadBalancer } // SetLoadBalancer method is used to set the new request load balancer into the client. func (c *Client) SetLoadBalancer(b LoadBalancer) *Client { c.lock.Lock() defer c.lock.Unlock() c.loadBalancer = b return c } // Header method returns the headers from the client instance. func (c *Client) Header() http.Header { c.lock.RLock() defer c.lock.RUnlock() return c.header } // SetHeader method sets a single header and its value in the client instance. // These headers will be applied to all requests raised from the client instance. // Also, it can be overridden by request-level header options. 
// // For Example: To set `Content-Type` and `Accept` as `application/json` // // client. // SetHeader("Content-Type", "application/json"). // SetHeader("Accept", "application/json") // // See [Request.SetHeader] or [Request.SetHeaders]. func (c *Client) SetHeader(header, value string) *Client { c.lock.Lock() defer c.lock.Unlock() c.header.Set(header, value) return c } // SetHeaderAny method sets a single header field and its value in the client instance // for all requests raised from the client. // // It is similar to [Client.SetHeader] but accepts any type as the value and converts // it to a string using predefined formatting rules (integers, bools, time.Time, etc.). // // For Example: To set `X-Request-Id` with an integer value // // client.SetHeaderAny("X-Request-Id", 12345) // // See [Request.SetHeaderAny] or [Client.SetHeader]. func (c *Client) SetHeaderAny(header string, value any) *Client { c.lock.Lock() defer c.lock.Unlock() strVal := formatAnyToString(value) c.header.Set(header, strVal) return c } // SetHeaders method sets multiple headers and their values at one go, and // these headers will be applied to all requests raised from the client instance. // Also, it can be overridden at request-level headers options. // // For Example: To set `Content-Type` and `Accept` as `application/json` // // client.SetHeaders(map[string]string{ // "Content-Type": "application/json", // "Accept": "application/json", // }) // // See [Request.SetHeaders] or [Request.SetHeader]. func (c *Client) SetHeaders(headers map[string]string) *Client { c.lock.Lock() defer c.lock.Unlock() for h, v := range headers { c.header.Set(h, v) } return c } // SetHeaderVerbatim method is used to set the HTTP header key and value verbatim in the current request. // It is typically helpful for legacy applications or servers that require HTTP headers in a certain way // // For Example: To set header key as `all_lowercase`, `UPPERCASE`, and `x-cloud-trace-id` // // client. 
// SetHeaderVerbatim("all_lowercase", "available"). // SetHeaderVerbatim("UPPERCASE", "available"). // SetHeaderVerbatim("x-cloud-trace-id", "798e94019e5fc4d57fbb8901eb4c6cae") // // See [Request.SetHeaderVerbatim]. func (c *Client) SetHeaderVerbatim(header, value string) *Client { c.lock.Lock() defer c.lock.Unlock() c.header[header] = []string{value} return c } // SetHeaderVerbatimAny method sets the HTTP header key and value verbatim in the client instance // for all requests raised from the client. // // It is similar to [Client.SetHeaderVerbatim] but accepts any type as the value and converts // it to a string using predefined formatting rules (integers, bools, time.Time, etc.). // // For Example: To set header key as `x-trace-id` with an integer value // // client.SetHeaderVerbatimAny("x-trace-id", 798940) // // See [Request.SetHeaderVerbatimAny] or [Client.SetHeaderVerbatim]. func (c *Client) SetHeaderVerbatimAny(header string, value any) *Client { c.lock.Lock() defer c.lock.Unlock() strVal := formatAnyToString(value) c.header[header] = []string{strVal} return c } // Context method returns the [context.Context] from the client instance. func (c *Client) Context() context.Context { c.lock.RLock() defer c.lock.RUnlock() return c.ctx } // SetContext method sets the given [context.Context] in the client instance and // it gets added to [Request] raised from this instance. func (c *Client) SetContext(ctx context.Context) *Client { c.lock.Lock() defer c.lock.Unlock() c.ctx = ctx return c } // CookieJar method returns the HTTP cookie jar instance from the underlying Go HTTP Client. func (c *Client) CookieJar() http.CookieJar { return c.Client().Jar } // SetCookieJar method sets custom [http.CookieJar] in the resty client. It's a way to override the default. // // For Example, sometimes we don't want to save cookies in API mode so that we can remove the default // CookieJar in resty client. 
// // client.SetCookieJar(nil) func (c *Client) SetCookieJar(jar http.CookieJar) *Client { c.lock.Lock() defer c.lock.Unlock() c.httpClient.Jar = jar return c } // Cookies method returns all cookies registered in the client instance. func (c *Client) Cookies() []*http.Cookie { c.lock.RLock() defer c.lock.RUnlock() return c.cookies } // SetCookie method appends a single cookie to the client instance. // These cookies will be added to all the requests from this client instance. // // client.SetCookie(&http.Cookie{ // Name:"go-resty", // Value:"This is cookie value", // }) func (c *Client) SetCookie(hc *http.Cookie) *Client { c.lock.Lock() defer c.lock.Unlock() c.cookies = append(c.cookies, hc) return c } // SetCookies method sets an array of cookies in the client instance. // These cookies will be added to all the requests from this client instance. // // cookies := []*http.Cookie{ // &http.Cookie{ // Name:"go-resty-1", // Value:"This is cookie 1 value", // }, // &http.Cookie{ // Name:"go-resty-2", // Value:"This is cookie 2 value", // }, // } // // // Setting a cookies into resty // client.SetCookies(cookies) func (c *Client) SetCookies(cs []*http.Cookie) *Client { c.lock.Lock() defer c.lock.Unlock() c.cookies = append(c.cookies, cs...) return c } // QueryParams method returns all query parameters and their values from the client instance. func (c *Client) QueryParams() url.Values { c.lock.RLock() defer c.lock.RUnlock() return c.queryParams } // SetQueryParam method sets a single parameter and its value in the client instance. // It will be formed as a query string for the request. // // For Example: `search=kitchen%20papers&size=large` // // In the URL after the `?` mark. These query params will be added to all the requests raised from // this client instance. Also, it can be overridden at the request level. // // See [Request.SetQueryParam] or [Request.SetQueryParams]. // // client. // SetQueryParam("search", "kitchen papers"). 
// SetQueryParam("size", "large") func (c *Client) SetQueryParam(param, value string) *Client { c.lock.Lock() defer c.lock.Unlock() c.queryParams.Set(param, value) return c } // SetQueryParamAny method sets a single query parameter and its value in the client instance. // It will be formed as a query string for the request. // // It is similar to [Client.SetQueryParam] but accepts any type as the value and converts // it to a string using predefined formatting rules (integers, bools, time.Time, etc.). // // For Example: To set `page` and `active` query parameters // // client. // SetQueryParamAny("page", 5). // SetQueryParamAny("active", true) // // See [Request.SetQueryParamAny] or [Client.SetQueryParam]. func (c *Client) SetQueryParamAny(param string, value any) *Client { c.lock.Lock() defer c.lock.Unlock() strVal := formatAnyToString(value) c.queryParams.Set(param, strVal) return c } // SetQueryParams method sets multiple parameters and their values at one go in the client instance. // It will be formed as a query string for the request. // // For Example: `search=kitchen%20papers&size=large` // // In the URL after the `?` mark. These query params will be added to all the requests raised from this // client instance. Also, it can be overridden at the request level. // // See [Request.SetQueryParams] or [Request.SetQueryParam]. // // client.SetQueryParams(map[string]string{ // "search": "kitchen papers", // "size": "large", // }) func (c *Client) SetQueryParams(params map[string]string) *Client { // Do not lock here since there is potential deadlock. for p, v := range params { c.SetQueryParam(p, v) } return c } // FormData method returns the form parameters and their values from the client instance. func (c *Client) FormData() url.Values { c.lock.RLock() defer c.lock.RUnlock() return c.formData } // SetFormData method sets Form parameters and their values in the client instance. // The request content type would be set as `application/x-www-form-urlencoded`. 
// The client-level form data gets added to all the requests. Also, it can be // overridden at the request level. // // See [Request.SetFormData]. // // client.SetFormData(map[string]string{ // "access_token": "BC594900-518B-4F7E-AC75-BD37F019E08F", // "user_id": "3455454545", // }) func (c *Client) SetFormData(data map[string]string) *Client { c.lock.Lock() defer c.lock.Unlock() for k, v := range data { c.formData.Set(k, v) } return c } // SetBasicAuth method sets the basic authentication header in the HTTP request. For Example: // // Authorization: Basic // // For Example: To set the header for username "go-resty" and password "welcome" // // client.SetBasicAuth("go-resty", "welcome") // // This basic auth information is added to all requests from this client instance. // It can also be overridden at the request level. // // See [Request.SetBasicAuth]. func (c *Client) SetBasicAuth(username, password string) *Client { c.lock.Lock() defer c.lock.Unlock() c.credentials = &credentials{Username: username, Password: password} return c } // AuthToken method returns the auth token value registered in the client instance. func (c *Client) AuthToken() string { c.lock.RLock() defer c.lock.RUnlock() return c.authToken } // HeaderAuthorizationKey method returns the HTTP header name for Authorization from the client instance. func (c *Client) HeaderAuthorizationKey() string { c.lock.RLock() defer c.lock.RUnlock() return c.headerAuthorizationKey } // SetHeaderAuthorizationKey method sets the given HTTP header name for Authorization in the client instance. // // It can be overridden at the request level; see [Request.SetHeaderAuthorizationKey]. // // client.SetHeaderAuthorizationKey("X-Custom-Authorization") func (c *Client) SetHeaderAuthorizationKey(k string) *Client { c.lock.Lock() defer c.lock.Unlock() c.headerAuthorizationKey = k return c } // SetAuthToken method sets the auth token of the `Authorization` header for all HTTP requests. 
// The default auth scheme is `Bearer`; it can be customized with the method [Client.SetAuthScheme]. For Example: // // Authorization: // // For Example: To set auth token BC594900518B4F7EAC75BD37F019E08FBC594900518B4F7EAC75BD37F019E08F // // client.SetAuthToken("BC594900518B4F7EAC75BD37F019E08FBC594900518B4F7EAC75BD37F019E08F") // // This auth token gets added to all the requests raised from this client instance. // Also, it can be overridden at the request level. // // See [Request.SetAuthToken]. func (c *Client) SetAuthToken(token string) *Client { c.lock.Lock() defer c.lock.Unlock() c.authToken = token return c } // AuthScheme method returns the auth scheme name set in the client instance. // // See [Client.SetAuthScheme], [Request.SetAuthScheme]. func (c *Client) AuthScheme() string { c.lock.RLock() defer c.lock.RUnlock() return c.authScheme } // SetAuthScheme method sets the auth scheme type in the HTTP request. For Example: // // Authorization: // // For Example: To set the scheme to use OAuth // // client.SetAuthScheme("OAuth") // // This auth scheme gets added to all the requests raised from this client instance. // Also, it can be overridden at the request level. // // Information about auth schemes can be found in [RFC 7235], IANA [HTTP Auth schemes]. // // See [Request.SetAuthScheme]. // // [RFC 7235]: https://tools.ietf.org/html/rfc7235 // [HTTP Auth schemes]: https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml#authschemes func (c *Client) SetAuthScheme(scheme string) *Client { c.lock.Lock() defer c.lock.Unlock() c.authScheme = scheme return c } // SetDigestAuth method sets the Digest Auth transport with provided credentials in the client. // If a server responds with 401 and sends a Digest challenge in the header `WWW-Authenticate`, // the request will be resent with the appropriate digest `Authorization` header. 
// // For Example: To set the Digest scheme with user "Mufasa" and password "Circle Of Life" // // client.SetDigestAuth("Mufasa", "Circle Of Life") // // Information about Digest Access Authentication can be found in [RFC 7616]. // // NOTE: // - On the QOP `auth-int` scenario, the request body is read into memory to // compute the body hash that increases memory usage. // - Create a dedicated client instance to use digest auth, // as it does digest auth for all the requests raised by the client. // // [RFC 7616]: https://datatracker.ietf.org/doc/html/rfc7616 func (c *Client) SetDigestAuth(username, password string) *Client { dt := &digestTransport{ credentials: &credentials{username, password}, transport: c.Transport(), } c.SetTransport(dt) return c } // R method creates a new request instance; it's used for Get, Post, Put, Delete, Patch, Head, Options, etc. func (c *Client) R() *Request { c.lock.RLock() defer c.lock.RUnlock() r := &Request{ QueryParams: url.Values{}, FormData: url.Values{}, Header: http.Header{}, Cookies: make([]*http.Cookie, 0), PathParams: make(map[string]string), Timeout: c.timeout, IsDebug: c.debug, IsTrace: c.isTrace, IsResponseSaveToFile: c.isResponseSaveToFile, AuthScheme: c.authScheme, AuthToken: c.authToken, RetryCount: c.retryCount, RetryWaitTime: c.retryWaitTime, RetryMaxWaitTime: c.retryMaxWaitTime, RetryDelayStrategy: c.retryDelayStrategy, IsRetryDefaultConditions: c.isRetryDefaultConditions, IsCloseConnection: c.closeConnection, IsResponseDoNotParse: c.isResponseDoNotParse, DebugBodyLimit: c.debugBodyLimit, ResponseBodyLimit: c.responseBodyLimit, IsResponseBodyUnlimitedReads: c.resBodyUnlimitedReads, IsMethodGetAllowPayload: c.isMethodGetAllowPayload, IsMethodDeleteAllowPayload: c.isMethodDeleteAllowPayload, IsRetryAllowNonIdempotent: c.isRetryAllowNonIdempotent, HeaderAuthorizationKey: c.headerAuthorizationKey, mu: new(sync.Mutex), client: c, baseURL: c.baseURL, multipartFields: make([]*MultipartField, 0), jsonEscapeHTML: 
c.jsonEscapeHTML, log: c.log, isCurlCmdGenerate: c.isCurlCmdGenerate, isCurlCmdDebugLog: c.isCurlCmdDebugLog, unescapeQueryParams: c.unescapeQueryParams, credentials: c.credentials, } if c.ctx != nil { r.ctx = context.WithoutCancel(c.ctx) // refer to godoc for more info about this function } return r } // NewRequest method is an alias for method `R()`. func (c *Client) NewRequest() *Request { return c.R() } // SetRequestMiddlewares method allows Resty users to override the default request // middlewares sequence // // client.SetRequestMiddlewares( // Custom1RequestMiddleware, // Custom2RequestMiddleware, // resty.PrepareRequestMiddleware, // after this, `Request.RawRequest` instance is available // Custom3RequestMiddleware, // Custom4RequestMiddleware, // ) // // See, [Client.AddRequestMiddleware] // // NOTE: // - It overwrites the existing request middleware list. // - Be sure to include Resty request middlewares in the request chain at the appropriate spot. func (c *Client) SetRequestMiddlewares(middlewares ...RequestMiddleware) *Client { c.lock.Lock() defer c.lock.Unlock() c.beforeRequest = middlewares return c } // SetResponseMiddlewares method allows Resty users to override the default response // middlewares sequence // // client.SetResponseMiddlewares( // Custom1ResponseMiddleware, // Custom2ResponseMiddleware, // resty.AutoParseResponseMiddleware, // before this, the body is not read except on the debug flow // Custom3ResponseMiddleware, // resty.SaveToFileResponseMiddleware, // See, Request.SetOutputFileName, Request.SetSaveResponse // Custom4ResponseMiddleware, // Custom5ResponseMiddleware, // ) // // See, [Client.AddResponseMiddleware] // // NOTE: // - It overwrites the existing response middleware list. // - Be sure to include Resty response middlewares in the response chain at the appropriate spot. 
func (c *Client) SetResponseMiddlewares(middlewares ...ResponseMiddleware) *Client { c.lock.Lock() defer c.lock.Unlock() c.afterResponse = middlewares return c } func (c *Client) requestMiddlewares() []RequestMiddleware { c.lock.RLock() defer c.lock.RUnlock() return c.beforeRequest } // AddRequestMiddleware method appends a request middleware to the before request chain. // After all requests, middlewares are applied, and the request is sent to the host server. // // client.AddRequestMiddleware(func(c *resty.Client, r *resty.Request) error { // // Now you have access to the Client and Request instance // // manipulate it as per your need // // return nil // if its successful otherwise return error // }) func (c *Client) AddRequestMiddleware(m RequestMiddleware) *Client { c.lock.Lock() defer c.lock.Unlock() idx := len(c.beforeRequest) - 1 c.beforeRequest = slices.Insert(c.beforeRequest, idx, m) return c } func (c *Client) responseMiddlewares() []ResponseMiddleware { c.lock.RLock() defer c.lock.RUnlock() return c.afterResponse } // AddResponseMiddleware method appends response middleware to the after-response chain. // All the response middlewares are applied; once we receive a response // from the host server. // // client.AddResponseMiddleware(func(c *resty.Client, r *resty.Response) error { // // Now you have access to the Client and Response instance // // Also, you could access request via Response.Request i.e., r.Request // // manipulate it as per your need // // return nil // if its successful otherwise return error // }) func (c *Client) AddResponseMiddleware(m ResponseMiddleware) *Client { c.lock.Lock() defer c.lock.Unlock() c.afterResponse = append(c.afterResponse, m) return c } // OnError method adds a callback that will be run whenever a request execution fails. // This is called after all retries have been attempted (if any). 
// If there was a response from the server, the error will be wrapped in [ResponseError] // which has the last response received from the server. // // client.OnError(func(req *resty.Request, err error) { // if v, ok := err.(*resty.ResponseError); ok { // // Do something with v.Response // } // // Log the error, increment a metric, etc... // }) // // Out of the [Client.OnSuccess], [Client.OnError], [Client.OnInvalid], [Client.OnPanic] // callbacks, exactly one set will be invoked for each call to [Request.Execute] that completes. // // NOTE: // - Do not use [Client] setter methods within OnError hooks; deadlock will happen. func (c *Client) OnError(hooks ...ErrorHook) *Client { c.lock.Lock() defer c.lock.Unlock() c.errorHooks = append(c.errorHooks, hooks...) return c } // OnSuccess method adds a callback that will be run whenever a request execution // succeeds. This is called after all retries have been attempted (if any). // // Out of the [Client.OnSuccess], [Client.OnError], [Client.OnInvalid], [Client.OnPanic] // callbacks, exactly one set will be invoked for each call to [Request.Execute] that completes. // // NOTE: // - Do not use [Client] setter methods within OnSuccess hooks; deadlock will happen. func (c *Client) OnSuccess(hooks ...SuccessHook) *Client { c.lock.Lock() defer c.lock.Unlock() c.successHooks = append(c.successHooks, hooks...) return c } // OnInvalid method adds a callback that will be run whenever a request execution // fails before it starts because the request is invalid. // // Out of the [Client.OnSuccess], [Client.OnError], [Client.OnInvalid], [Client.OnPanic] // callbacks, exactly one set will be invoked for each call to [Request.Execute] that completes. // // NOTE: // - Do not use [Client] setter methods within OnInvalid hooks; deadlock will happen. func (c *Client) OnInvalid(hooks ...ErrorHook) *Client { c.lock.Lock() defer c.lock.Unlock() c.invalidHooks = append(c.invalidHooks, hooks...) 
return c } // OnPanic method adds a callback that will be run whenever a request execution // panics. // // Out of the [Client.OnSuccess], [Client.OnError], [Client.OnInvalid], [Client.OnPanic] // callbacks, exactly one set will be invoked for each call to [Request.Execute] that completes. // // If an [Client.OnSuccess], [Client.OnError], or [Client.OnInvalid] callback panics, // then exactly one rule can be violated. // // NOTE: // - Do not use [Client] setter methods within OnPanic hooks; deadlock will happen. func (c *Client) OnPanic(hooks ...ErrorHook) *Client { c.lock.Lock() defer c.lock.Unlock() c.panicHooks = append(c.panicHooks, hooks...) return c } // OnClose method adds a callback that will be run whenever the client is closed. // The hooks are executed in the order they were registered. func (c *Client) OnClose(hooks ...CloseHook) *Client { c.lock.Lock() defer c.lock.Unlock() c.closeHooks = append(c.closeHooks, hooks...) return c } // ContentTypeEncoders method returns all the registered content type encoders. func (c *Client) ContentTypeEncoders() map[string]ContentTypeEncoder { c.lock.RLock() defer c.lock.RUnlock() return c.contentTypeEncoders } // AddContentTypeEncoder method adds the user-provided Content-Type encoder into a client. // // NOTE: It overwrites the encoder function if the given Content-Type key already exists. func (c *Client) AddContentTypeEncoder(ct string, e ContentTypeEncoder) *Client { c.lock.Lock() defer c.lock.Unlock() c.contentTypeEncoders[strings.ToLower(ct)] = e return c } func (c *Client) inferContentTypeEncoder(ct ...string) (ContentTypeEncoder, bool) { c.lock.RLock() defer c.lock.RUnlock() for _, v := range ct { if d, f := c.contentTypeEncoders[v]; f { return d, f } } return nil, false } // ContentTypeDecoders method returns all the registered content type decoders. 
func (c *Client) ContentTypeDecoders() map[string]ContentTypeDecoder { c.lock.RLock() defer c.lock.RUnlock() return c.contentTypeDecoders } // AddContentTypeDecoder method adds the user-provided Content-Type decoder into a client. // // NOTE: It overwrites the decoder function if the given Content-Type key already exists. func (c *Client) AddContentTypeDecoder(ct string, d ContentTypeDecoder) *Client { c.lock.Lock() defer c.lock.Unlock() c.contentTypeDecoders[strings.ToLower(ct)] = d return c } func (c *Client) inferContentTypeDecoder(ct ...string) (ContentTypeDecoder, bool) { c.lock.RLock() defer c.lock.RUnlock() for _, v := range ct { if d, f := c.contentTypeDecoders[v]; f { return d, f } } return nil, false } // ContentDecompressers method returns all the registered content-encoding Decompressers. func (c *Client) ContentDecompressers() map[string]ContentDecompresser { c.lock.RLock() defer c.lock.RUnlock() return c.contentDecompressers } // AddContentDecompresser method adds the user-provided Content-Encoding ([RFC 9110]) Decompresser // and directive into a client. // // NOTE: It overwrites the Decompresser function if the given Content-Encoding directive already exists. // // [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110 func (c *Client) AddContentDecompresser(k string, d ContentDecompresser) *Client { c.lock.Lock() defer c.lock.Unlock() lk := strings.ToLower(k) if !slices.Contains(c.contentDecompresserKeys, lk) { c.contentDecompresserKeys = slices.Insert(c.contentDecompresserKeys, 0, lk) } c.contentDecompressers[lk] = d return c } // ContentDecompresserKeys method returns all the registered content-encoding Decompressers // keys as comma-separated string. func (c *Client) ContentDecompresserKeys() string { c.lock.RLock() defer c.lock.RUnlock() return strings.Join(c.contentDecompresserKeys, ", ") } // SetContentDecompresserKeys method sets given Content-Encoding ([RFC 9110]) directives into the client instance. 
// // It checks the given Content-Encoding exists in the [ContentDecompresser] list before assigning it, // if it does not exist, it will skip that directive. // // Use this method to overwrite the default order. If a new content Decompresser is added, // that directive will be the first. // // [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110 func (c *Client) SetContentDecompresserKeys(keys []string) *Client { result := make([]string, 0) decoders := c.ContentDecompressers() for _, k := range keys { k = strings.ToLower(k) if _, f := decoders[k]; f { result = append(result, k) } } c.lock.Lock() defer c.lock.Unlock() c.contentDecompresserKeys = result return c } // SetCircuitBreaker method sets the Circuit Breaker instance into the client. // It is used to prevent the client from sending requests that are likely to fail. // For Example: To use the default Circuit Breaker: // // client.SetCircuitBreaker(NewCircuitBreaker()) func (c *Client) SetCircuitBreaker(b *CircuitBreaker) *Client { c.lock.Lock() defer c.lock.Unlock() c.circuitBreaker = b return c } // IsDebug method returns `true` if the client is in debug mode; otherwise, it is `false`. func (c *Client) IsDebug() bool { c.lock.RLock() defer c.lock.RUnlock() return c.debug } // SetDebug method is used to turn on/off the debug mode on the Resty client instance. It logs details // of every request and response when enabled. // // client.SetDebug(true) // // Also, it can be enabled at the request level for a particular request; see [Request.SetDebug]. // - For [Request], it logs information such as HTTP verb, Relative URL path, // Host, Headers, and Body if it has one. // - For [Response], it logs information such as Status, Response Time, Headers, // and Body if it has one. 
func (c *Client) SetDebug(d bool) *Client { c.lock.Lock() defer c.lock.Unlock() c.debug = d return c } // DebugBodyLimit method returns the debug body limit value set on the client instance func (c *Client) DebugBodyLimit() int { c.lock.RLock() defer c.lock.RUnlock() return c.debugBodyLimit } // SetDebugBodyLimit sets the maximum size in bytes for which the response and // request body will be logged in debug mode. // // client.SetDebugBodyLimit(1000000) func (c *Client) SetDebugBodyLimit(sl int) *Client { c.lock.Lock() defer c.lock.Unlock() c.debugBodyLimit = sl return c } func (c *Client) debugLogCallbackFunc() DebugLogCallbackFunc { c.lock.RLock() defer c.lock.RUnlock() return c.debugLogCallback } // OnDebugLog method sets the debug log callback function to the client instance. // Registered callback gets called before the Resty logs the information. func (c *Client) OnDebugLog(dlc DebugLogCallbackFunc) *Client { c.lock.Lock() defer c.lock.Unlock() if c.debugLogCallback != nil { c.log.Warnf("Overwriting an existing on-debug-log callback from=%s to=%s", functionName(c.debugLogCallback), functionName(dlc)) } c.debugLogCallback = dlc return c } func (c *Client) debugLogFormatterFunc() DebugLogFormatterFunc { c.lock.RLock() defer c.lock.RUnlock() return c.debugLogFormatter } // SetDebugLogFormatter method sets the Resty debug log formatter to the client instance. func (c *Client) SetDebugLogFormatter(df DebugLogFormatterFunc) *Client { c.lock.Lock() defer c.lock.Unlock() c.debugLogFormatter = df return c } // IsDisableWarn method returns `true` if the warning message is disabled; otherwise, it is `false`. func (c *Client) IsDisableWarn() bool { c.lock.RLock() defer c.lock.RUnlock() return c.disableWarn } // SetLoggerWarnLevel method disables the warning log message on the Resty client. // // For example, Resty warns users when BasicAuth is used in non-TLS mode. 
// // client.SetLoggerWarnLevel(true) func (c *Client) SetLoggerWarnLevel(d bool) *Client { c.lock.Lock() defer c.lock.Unlock() c.disableWarn = d return c } // IsMethodGetAllowPayload method returns `true` if the client is enabled to allow // payload with GET method; otherwise, it is `false`. func (c *Client) IsMethodGetAllowPayload() bool { c.lock.RLock() defer c.lock.RUnlock() return c.isMethodGetAllowPayload } // SetMethodGetAllowPayload method allows the GET method with payload on the Resty client. // By default, Resty does not allow. // // client.SetMethodGetAllowPayload(true) // // It can be overridden at the request level. See [Request.SetMethodGetAllowPayload] func (c *Client) SetMethodGetAllowPayload(allow bool) *Client { c.lock.Lock() defer c.lock.Unlock() c.isMethodGetAllowPayload = allow return c } // IsMethodDeleteAllowPayload method returns `true` if the client is enabled to allow // payload with DELETE method; otherwise, it is `false`. // // More info, refer to GH#881 func (c *Client) IsMethodDeleteAllowPayload() bool { c.lock.RLock() defer c.lock.RUnlock() return c.isMethodDeleteAllowPayload } // SetMethodDeleteAllowPayload method allows the DELETE method with payload on the Resty client. // By default, Resty does not allow. // // client.SetMethodDeleteAllowPayload(true) // // More info, refer to GH#881 // // It can be overridden at the request level. See [Request.SetMethodDeleteAllowPayload] func (c *Client) SetMethodDeleteAllowPayload(allow bool) *Client { c.lock.Lock() defer c.lock.Unlock() c.isMethodDeleteAllowPayload = allow return c } // Logger method returns the logger instance used by the client instance. func (c *Client) Logger() Logger { c.lock.RLock() defer c.lock.RUnlock() return c.log } // SetLogger method sets given writer for logging Resty request and response details. 
// // Compliant to interface [resty.Logger] func (c *Client) SetLogger(l Logger) *Client { c.lock.Lock() defer c.lock.Unlock() c.log = l return c } // Timeout method returns the timeout duration value from the client func (c *Client) Timeout() time.Duration { c.lock.RLock() defer c.lock.RUnlock() return c.timeout } // SetTimeout method is used to set a timeout for a request raised by the client. // // client.SetTimeout(1 * time.Minute) // // It can be overridden at the request level. See [Request.SetTimeout] // // NOTE: Resty uses [context.WithTimeout] on the request, it does not use [http.Client].Timeout func (c *Client) SetTimeout(timeout time.Duration) *Client { c.lock.Lock() defer c.lock.Unlock() c.timeout = timeout return c } // ResultError method returns the global or client common `ResultError` object // type registered in the client instance. func (c *Client) ResultError() reflect.Type { c.lock.RLock() defer c.lock.RUnlock() return c.errorType } // SetResultError method registers the global or client common `ResultError` // object type into the client instance. It is used for automatic unmarshalling if // the response status code is greater than 399 and the content type is JSON or XML. // It can be a pointer or a non-pointer. // // client.SetResultError(&LoginErrorResponse{}) // // OR // client.SetResultError(LoginErrorResponse{}) func (c *Client) SetResultError(v any) *Client { c.lock.Lock() defer c.lock.Unlock() c.errorType = inferType(v) return c } func (c *Client) newErrorInterface() any { e := c.ResultError() if e == nil { return e } return reflect.New(e).Interface() } // SetRedirectPolicy method sets the redirect policy for the client. Resty provides ready-to-use // redirect policies. Wanna create one for yourself, refer to `redirect.go`. 
// // client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(20)) // // // Need multiple redirect policies together // client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(20), resty.DomainCheckRedirectPolicy("host1.com", "host2.net")) // // NOTE: It overwrites the previous redirect policies in the client instance. func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client { c.lock.Lock() defer c.lock.Unlock() c.httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { for _, p := range policies { if err := p.Apply(req, via); err != nil { return err } } return nil // looks good, go ahead } return c } // RetryCount method returns the retry count value from the client instance. func (c *Client) RetryCount() int { c.lock.RLock() defer c.lock.RUnlock() return c.retryCount } // SetRetryCount method enables retry on Resty client and allows you // to set no. of retry count. // // first attempt + retry count = total attempts // // See [Request.SetRetryDelayStrategy] // // NOTE: // - By default, Resty only does retry on idempotent HTTP verb, [RFC 9110 Section 9.2.2], [RFC 9110 Section 18.2] // // [RFC 9110 Section 9.2.2]: https://datatracker.ietf.org/doc/html/rfc9110.html#name-idempotent-methods // [RFC 9110 Section 18.2]: https://datatracker.ietf.org/doc/html/rfc9110.html#name-method-registration func (c *Client) SetRetryCount(count int) *Client { c.lock.Lock() defer c.lock.Unlock() c.retryCount = count return c } // RetryWaitTime method returns the retry wait time that is used to sleep before // retrying the request. func (c *Client) RetryWaitTime() time.Duration { c.lock.RLock() defer c.lock.RUnlock() return c.retryWaitTime } // SetRetryWaitTime method sets the default wait time for sleep before retrying // // Default is 100 milliseconds. 
func (c *Client) SetRetryWaitTime(waitTime time.Duration) *Client { c.lock.Lock() defer c.lock.Unlock() c.retryWaitTime = waitTime return c } // RetryMaxWaitTime method returns the retry max wait time that is used to sleep // before retrying the request. func (c *Client) RetryMaxWaitTime() time.Duration { c.lock.RLock() defer c.lock.RUnlock() return c.retryMaxWaitTime } // SetRetryMaxWaitTime method sets the max wait time for sleep before retrying // // Default is 2 seconds. func (c *Client) SetRetryMaxWaitTime(maxWaitTime time.Duration) *Client { c.lock.Lock() defer c.lock.Unlock() c.retryMaxWaitTime = maxWaitTime return c } // RetryDelayStrategy method returns the retry delay strategy function; // otherwise, it is nil. // // See [Client.SetRetryDelayStrategy] func (c *Client) RetryDelayStrategy() RetryDelayStrategyFunc { c.lock.RLock() defer c.lock.RUnlock() return c.retryDelayStrategy } // SetRetryDelayStrategy method used to set the custom Retry delay strategy // into Resty client, it is used to get wait time before each retry. // It can be overridden at request level, see [Request.SetRetryDelayStrategy] // // By default, Resty employs the capped exponential backoff with a jitter delay strategy. 
func (c *Client) SetRetryDelayStrategy(rs RetryDelayStrategyFunc) *Client { c.lock.Lock() defer c.lock.Unlock() c.retryDelayStrategy = rs return c } // IsRetryDefaultConditions method returns true if Resty's default retry conditions // are enabled otherwise false // // Default value is `true` func (c *Client) IsRetryDefaultConditions() bool { c.lock.RLock() defer c.lock.RUnlock() return c.isRetryDefaultConditions } // SetRetryDefaultConditions method is used to enable/disable the Resty's default // retry conditions // // It can be overridden at request level, see [Request.SetRetryDefaultConditions] func (c *Client) SetRetryDefaultConditions(b bool) *Client { c.lock.Lock() defer c.lock.Unlock() c.isRetryDefaultConditions = b return c } // IsRetryAllowNonIdempotent method returns true if the client is enabled to allow // non-idempotent HTTP methods retry; otherwise, it is `false` // // Default value is `false` func (c *Client) IsRetryAllowNonIdempotent() bool { c.lock.RLock() defer c.lock.RUnlock() return c.isRetryAllowNonIdempotent } // SetRetryAllowNonIdempotent method is used to enable/disable non-idempotent HTTP // methods retry. By default, Resty only allows idempotent HTTP methods, see // [RFC 9110 Section 9.2.2], [RFC 9110 Section 18.2] // // It can be overridden at request level, see [Request.SetRetryAllowNonIdempotent] // // [RFC 9110 Section 9.2.2]: https://datatracker.ietf.org/doc/html/rfc9110.html#name-idempotent-methods // [RFC 9110 Section 18.2]: https://datatracker.ietf.org/doc/html/rfc9110.html#name-method-registration func (c *Client) SetRetryAllowNonIdempotent(b bool) *Client { c.lock.Lock() defer c.lock.Unlock() c.isRetryAllowNonIdempotent = b return c } // RetryConditions method returns all the retry condition functions. func (c *Client) RetryConditions() []RetryConditionFunc { c.lock.RLock() defer c.lock.RUnlock() return c.retryConditions } // AddRetryConditions method adds one or more retry condition functions into the request. 
// These retry conditions are executed to determine if the request can be retried. // The request will retry if any functions return `true`, otherwise return `false`. // // NOTE: // - Retry conditions are executed on each retry attempt. // - Default retry conditions are executed first. // - Client-level retry conditions are applied to all requests. // - Request-level retry conditions are executed before client-level retry conditions. // See [Request.AddRetryConditions], [Request.SetRetryConditions] // - Once a retry condition returns true, the remaining retry conditions are not executed. // - Retry conditions are executed in the order in which they are added. func (c *Client) AddRetryConditions(conditions ...RetryConditionFunc) *Client { c.lock.Lock() defer c.lock.Unlock() c.retryConditions = append(c.retryConditions, conditions...) return c } // RetryHooks method returns all the retry hook functions. func (c *Client) RetryHooks() []RetryHookFunc { c.lock.RLock() defer c.lock.RUnlock() return c.retryHooks } // AddRetryHooks method adds one or more side-effecting retry hooks to an array // of hooks that will be executed on each retry. // // NOTE: // - Retry hooks are executed on each retry attempt. // - The request-level retry hooks are executed first before client-level hooks. // See [Request.AddRetryHooks], [Request.SetRetryHooks] // - Retry hooks are executed in the order in which they are added. func (c *Client) AddRetryHooks(hooks ...RetryHookFunc) *Client { c.lock.Lock() defer c.lock.Unlock() c.retryHooks = append(c.retryHooks, hooks...) return c } // isHedgingEnabled method returns true if hedging is enabled. func (c *Client) isHedgingEnabled() bool { c.lock.RLock() defer c.lock.RUnlock() return c.hedging != nil } // Hedging method returns the hedging configuration of the client. // If nil is returned, it means hedging is disabled. 
func (c *Client) Hedging() *Hedging { c.lock.RLock() defer c.lock.RUnlock() return c.hedging } // SetHedging method sets the hedging instance into client. If nil is passed, it disables hedging. // // See [NewHedging] for more details about the Hedging configuration. func (c *Client) SetHedging(h *Hedging) *Client { c.lock.Lock() defer c.lock.Unlock() // if nil is passed, we disable hedging // by reverting the transport instance if h == nil { if ht, ok := c.httpClient.Transport.(*Hedging); ok { c.httpClient.Transport = ht.transport c.hedging = h } return c } // enable hedging if its not already enabled currentTransport := c.httpClient.Transport if currentTransport == nil { currentTransport = createTransport(nil, nil) } // If current transport is already a Hedging instance, unwrap it // to avoid double-wrapping (e.g., when SetHedging is called multiple times) if hedging, ok := currentTransport.(*Hedging); ok { currentTransport = hedging.transport } // Disable retry by default when hedging is enabled. // Users can re-enable retry if they want it as a fallback mechanism. if c.retryCount > 0 { c.log.Warnf("Disabling retry (count: %d) as hedging is now enabled."+ " You can re-enable retry with SetRetryCount() if you really want it as a fallback."+ " otherwise, hedging and retry requests can overwhelm the server.", c.retryCount) c.retryCount = 0 } h.transport = currentTransport c.httpClient.Transport = h c.hedging = h return c } // TLSClientConfig method returns the [tls.Config] from underlying client transport // otherwise returns nil func (c *Client) TLSClientConfig() *tls.Config { cfg, err := c.tlsConfig() if err != nil { c.Logger().Errorf("%v", err) } return cfg } // SetTLSClientConfig method sets TLSClientConfig for underlying client Transport. // // Values supported by https://pkg.go.dev/crypto/tls#Config can be configured. 
//
//	// Disable SSL cert verification for local development
//	client.SetTLSClientConfig(&tls.Config{
//		InsecureSkipVerify: true,
//	})
//
// If the transport implements [TLSClientConfiger], the config is handed to it;
// otherwise the transport must be an [http.Transport]. In any other case an
// error is logged and the config is not applied.
//
// NOTE: This method overwrites existing [http.Transport.TLSClientConfig]
func (c *Client) SetTLSClientConfig(tlsConfig *tls.Config) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()

	// TLSClientConfiger interface handling
	if tc, ok := c.httpClient.Transport.(TLSClientConfiger); ok {
		if err := tc.SetTLSClientConfig(tlsConfig); err != nil {
			c.log.Errorf("%v", err)
		}
		return c
	}

	// default standard transport handling
	transport, ok := c.httpClient.Transport.(*http.Transport)
	if !ok {
		c.log.Errorf("SetTLSClientConfig: %v", ErrNotHttpTransportType)
		return c
	}
	transport.TLSClientConfig = tlsConfig

	return c
}

// ProxyURL method returns the proxy URL if set otherwise nil.
func (c *Client) ProxyURL() *url.URL {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.proxyURL
}

// SetProxy method sets the Proxy URL and Port for the Resty client.
//
//	// HTTP/HTTPS proxy
//	client.SetProxy("http://proxyserver:8888")
//
//	// SOCKS5 Proxy
//	client.SetProxy("socks5://127.0.0.1:1080")
//
// If [Client.HTTPTransport] returns an error or the given URL cannot be parsed,
// the error is logged and the client is returned unchanged.
//
// OR you could also set Proxy via environment variable, refer to [http.ProxyFromEnvironment]
func (c *Client) SetProxy(proxyURL string) *Client {
	transport, err := c.HTTPTransport()
	if err != nil {
		c.Logger().Errorf("%v", err)
		return c
	}

	pURL, err := url.Parse(proxyURL)
	if err != nil {
		c.Logger().Errorf("%v", err)
		return c
	}

	// The lock is taken only here because HTTPTransport and Logger are
	// called above without the write lock held.
	c.lock.Lock()
	c.proxyURL = pURL
	transport.Proxy = http.ProxyURL(c.proxyURL)
	c.lock.Unlock()
	return c
}

// RemoveProxy method removes the proxy configuration from the Resty client
//
//	client.RemoveProxy()
//
// If [Client.HTTPTransport] returns an error, it is logged and nothing is changed.
func (c *Client) RemoveProxy() *Client {
	transport, err := c.HTTPTransport()
	if err != nil {
		c.Logger().Errorf("%v", err)
		return c
	}

	c.lock.Lock()
	defer c.lock.Unlock()
	c.proxyURL = nil
	transport.Proxy = nil
	return c
}

// SetCertificateFromFile method helps to set client certificates into Resty
// from cert and key files to perform SSL client authentication
//
//	client.SetCertificateFromFile("certs/client.pem", "certs/client.key")
//
// A certificate/key loading error is logged and no certificate is added.
func (c *Client) SetCertificateFromFile(certFilePath, certKeyFilePath string) *Client {
	cert, err := tls.LoadX509KeyPair(certFilePath, certKeyFilePath)
	if err != nil {
		c.Logger().Errorf("client certificate/key parsing error: %v", err)
		return c
	}
	c.SetCertificates(cert)
	return c
}

// SetCertificateFromString method helps to set client certificates into Resty
// from string to perform SSL client authentication
//
//	myClientCertStr := `-----BEGIN CERTIFICATE-----
//	... cert content ...
//	-----END CERTIFICATE-----`
//
//	myClientCertKeyStr := `-----BEGIN PRIVATE KEY-----
//	... cert key content ...
//	-----END PRIVATE KEY-----`
//
//	client.SetCertificateFromString(myClientCertStr, myClientCertKeyStr)
//
// A certificate/key parsing error is logged and no certificate is added.
func (c *Client) SetCertificateFromString(certStr, certKeyStr string) *Client {
	cert, err := tls.X509KeyPair([]byte(certStr), []byte(certKeyStr))
	if err != nil {
		c.Logger().Errorf("client certificate/key parsing error: %v", err)
		return c
	}
	c.SetCertificates(cert)
	return c
}

// SetCertificates method helps to conveniently set a slice of client certificates
// into Resty to perform SSL client authentication
//
//	cert, err := tls.LoadX509KeyPair("certs/client.pem", "certs/client.key")
//	if err != nil {
//		log.Printf("ERROR client certificate/key parsing error: %v", err)
//		return
//	}
//
//	client.SetCertificates(cert)
//
// NOTE: The given certificates are appended to the certificates already present
// in the TLS config; existing ones are not replaced.
func (c *Client) SetCertificates(certs ...tls.Certificate) *Client {
	config, err := c.tlsConfig()
	if err != nil {
		c.Logger().Errorf("%v", err)
		return c
	}

	c.lock.Lock()
	defer c.lock.Unlock()
	config.Certificates = append(config.Certificates, certs...)
	return c
}

// SetRootCertificates method helps to add one or more root certificate files
// into the Resty client
//
//	// one pem file path
//	client.SetRootCertificates("/path/to/root/pemFile.pem")
//
//	// one or more pem file path(s)
//	client.SetRootCertificates(
//		"/path/to/root/pemFile1.pem",
//		"/path/to/root/pemFile2.pem",
//		"/path/to/root/pemFile3.pem",
//	)
//
//	// if you happen to have string slices
//	client.SetRootCertificates(certs...)
//
// NOTE: Files are processed in order; on the first read error the error is
// logged and the remaining files are not processed.
func (c *Client) SetRootCertificates(pemFilePaths ...string) *Client {
	for _, fp := range pemFilePaths {
		rootPemData, err := os.ReadFile(fp)
		if err != nil {
			c.Logger().Errorf("%v", err)
			return c
		}
		c.handleCAs("root", rootPemData)
	}
	return c
}

// SetRootCertificatesWatcher method enables dynamic reloading of one or more root certificate files.
// It is designed for scenarios involving long-running Resty clients where certificates may be renewed.
//
//	client.SetRootCertificatesWatcher(
//		&resty.CertWatcherOptions{
//			PoolInterval: 24 * time.Hour,
//		},
//		"root-ca.pem",
//	)
//
// The files are loaded once immediately and then one watcher is started per file.
func (c *Client) SetRootCertificatesWatcher(options *CertWatcherOptions, pemFilePaths ...string) *Client {
	c.SetRootCertificates(pemFilePaths...)
	for _, fp := range pemFilePaths {
		c.initCertWatcher(fp, "root", options)
	}
	return c
}

// SetRootCertificateFromString method helps to add root certificate from the string
// into the Resty client
//
//	myRootCertStr := `-----BEGIN CERTIFICATE-----
//	... cert content ...
// -----END CERTIFICATE-----` // // client.SetRootCertificateFromString(myRootCertStr) func (c *Client) SetRootCertificateFromString(pemCerts string) *Client { c.handleCAs("root", []byte(pemCerts)) return c } // SetClientRootCertificates method helps to add one or more client root // certificate files into the Resty client // // // one pem file path // client.SetClientRootCertificates("/path/to/client-root/pemFile.pem") // // // one or more pem file path(s) // client.SetClientRootCertificates( // "/path/to/client-root/pemFile1.pem", // "/path/to/client-root/pemFile2.pem" // "/path/to/client-root/pemFile3.pem" // ) // // // if you happen to have string slices // client.SetClientRootCertificates(certs...) func (c *Client) SetClientRootCertificates(pemFilePaths ...string) *Client { for _, fp := range pemFilePaths { pemData, err := os.ReadFile(fp) if err != nil { c.Logger().Errorf("%v", err) return c } c.handleCAs("client-root", pemData) } return c } // SetClientRootCertificatesWatcher method enables dynamic reloading of one or more client root certificate files. // It is designed for scenarios involving long-running Resty clients where certificates may be renewed. // // client.SetClientRootCertificatesWatcher( // &resty.CertWatcherOptions{ // PoolInterval: 24 * time.Hour, // }, // "client-root-ca.pem", // ) func (c *Client) SetClientRootCertificatesWatcher(options *CertWatcherOptions, pemFilePaths ...string) *Client { c.SetClientRootCertificates(pemFilePaths...) for _, fp := range pemFilePaths { c.initCertWatcher(fp, "client-root", options) } return c } // SetClientRootCertificateFromString method helps to add a client root certificate // from the string into the Resty client // // myClientRootCertStr := `-----BEGIN CERTIFICATE----- // ... cert content ... 
// -----END CERTIFICATE-----` // // client.SetClientRootCertificateFromString(myClientRootCertStr) func (c *Client) SetClientRootCertificateFromString(pemCerts string) *Client { c.handleCAs("client-root", []byte(pemCerts)) return c } func (c *Client) handleCAs(scope string, permCerts []byte) { config, err := c.tlsConfig() if err != nil { c.Logger().Errorf("%v", err) return } c.lock.Lock() defer c.lock.Unlock() switch scope { case "root": if config.RootCAs == nil { config.RootCAs = x509.NewCertPool() } config.RootCAs.AppendCertsFromPEM(permCerts) case "client-root": if config.ClientCAs == nil { config.ClientCAs = x509.NewCertPool() } config.ClientCAs.AppendCertsFromPEM(permCerts) } } func (c *Client) initCertWatcher(pemFilePath, scope string, options *CertWatcherOptions) { tickerDuration := defaultWatcherPoolingInterval if options != nil && options.PoolInterval > 0 { tickerDuration = options.PoolInterval } go func() { ticker := time.NewTicker(tickerDuration) st, err := os.Stat(pemFilePath) if err != nil { c.Logger().Errorf("%v", err) return } modTime := st.ModTime().UTC() for { select { case <-c.certWatcherStopChan: ticker.Stop() return case <-ticker.C: c.debugf("Checking if cert %s has changed...", pemFilePath) st, err = os.Stat(pemFilePath) if err != nil { c.Logger().Errorf("%v", err) continue } newModTime := st.ModTime().UTC() if modTime.Equal(newModTime) { c.debugf("Cert %s hasn't changed.", pemFilePath) continue } modTime = newModTime c.debugf("Reloading cert %s ...", pemFilePath) switch scope { case "root": c.SetRootCertificates(pemFilePath) case "client-root": c.SetClientRootCertificates(pemFilePath) } c.debugf("Cert %s reloaded.", pemFilePath) } } }() } // ResponseSaveDirectory method returns the output directory value from the client. func (c *Client) ResponseSaveDirectory() string { c.lock.RLock() defer c.lock.RUnlock() return c.responseSaveDirectory } // SetResponseSaveDirectory method sets the output directory for saving HTTP responses in a file. 
// Resty creates one if the output directory does not exist. This setting is optional,
// if you plan to use the absolute path in [Request.SetResponseSaveFileName] and can be used together.
//
//	client.SetResponseSaveDirectory("/save/http/response/here")
func (c *Client) SetResponseSaveDirectory(dirPath string) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.responseSaveDirectory = dirPath
	return c
}

// IsResponseSaveToFile method returns true if the save response is set to true; otherwise, false
func (c *Client) IsResponseSaveToFile() bool {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.isResponseSaveToFile
}

// SetResponseSaveToFile method used to enable the save response option at the client level for
// all requests
//
//	client.SetResponseSaveToFile(true)
//
// Resty determines the save filename in the following order -
//   - [Request.SetResponseSaveFileName]
//   - Content-Disposition header
//   - Request URL using [path.Base]
//   - Request URL hostname if path is empty or "/"
//
// It can be overridden at request level, see [Request.SetResponseSaveToFile]
func (c *Client) SetResponseSaveToFile(save bool) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.isResponseSaveToFile = save
	return c
}

// HTTPTransport method does type assertion and returns [http.Transport]
// from the client instance, if type assertion fails it returns
// [ErrNotHttpTransportType].
func (c *Client) HTTPTransport() (*http.Transport, error) {
	c.lock.RLock()
	defer c.lock.RUnlock()
	if transport, ok := c.httpClient.Transport.(*http.Transport); ok {
		return transport, nil
	}
	return nil, ErrNotHttpTransportType
}

// Transport method returns underlying client transport reference as-is
// i.e., [http.RoundTripper]
func (c *Client) Transport() http.RoundTripper {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.httpClient.Transport
}

// SetTransport method sets custom [http.Transport] or any [http.RoundTripper]
// compatible interface implementation in the Resty client.
//
//	transport := &http.Transport{
//		// something like Proxying to httptest.Server, etc...
//		Proxy: func(req *http.Request) (*url.URL, error) {
//			return url.Parse(server.URL)
//		},
//	}
//	client.SetTransport(transport)
//
// NOTE:
//   - If transport is not the type of [http.Transport], you may lose the
//     ability to set a few Resty client settings. However, if you implement
//     [TLSClientConfiger] interface, then TLS client config is possible to set.
//   - It overwrites the Resty client transport instance and its configurations.
//   - A nil transport is ignored; the current transport is kept.
func (c *Client) SetTransport(transport http.RoundTripper) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	if transport != nil {
		c.httpClient.Transport = transport
	}
	return c
}

// Scheme method returns custom scheme value from the client.
//
//	scheme := client.Scheme()
func (c *Client) Scheme() string {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.scheme
}

// SetScheme method sets a custom scheme for the Resty client. It's a way to override the default.
//
//	client.SetScheme("http")
//
// A value that isStringEmpty reports as empty is ignored; otherwise the value
// is stored with surrounding whitespace trimmed.
func (c *Client) SetScheme(scheme string) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	if !isStringEmpty(scheme) {
		c.scheme = strings.TrimSpace(scheme)
	}
	return c
}

// SetCloseConnection method sets variable `Close` in HTTP request struct with the given
// value. More info: https://golang.org/src/net/http/request.go
//
// It can be overridden at the request level, see [Request.SetCloseConnection]
func (c *Client) SetCloseConnection(close bool) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.closeConnection = close
	return c
}

// SetResponseDoNotParse method instructs Resty not to parse the response body automatically.
//
// Resty exposes the raw response body as [io.ReadCloser]. If you use it, do not
// forget to close the body, otherwise, you might get into connection leaks, and connection
// reuse may not happen.
//
// NOTE: The default [Response] middlewares are not executed when using this option. User
// takes over the control of handling response body from Resty.
func (c *Client) SetResponseDoNotParse(notParse bool) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.isResponseDoNotParse = notParse
	return c
}

// PathParams method returns the path parameters from the client.
//
//	pathParams := client.PathParams()
//
// NOTE: The client's own map is returned as-is (no copy is made).
func (c *Client) PathParams() map[string]string {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.pathParams
}

// SetPathParam method sets a single URL path key-value pair in the
// Resty client instance.
//
//	client.SetPathParam("userId", "sample@sample.com")
//
//	Result:
//	   URL - /v1/users/{userId}/details
//	   Composed URL - /v1/users/sample@sample.com/details
//
// It replaces the value of the key while composing the request URL.
// The value will be escaped using [url.PathEscape] function.
//
// It can be overridden at the request level,
// see [Request.SetPathParam] or [Request.SetPathParams]
func (c *Client) SetPathParam(param, value string) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	// The value is stored already escaped.
	c.pathParams[param] = url.PathEscape(value)
	return c
}

// SetPathParamAny method sets a single URL path key-value pair in the
// Resty client instance.
//
// It is similar to [Client.SetPathParam] but accepts any type as the value and converts
// it to a string using predefined formatting rules (integers, bools, time.Time, etc.).
//
//	client.SetPathParamAny("userId", 12345)
//
//	Result:
//	   URL - /v1/users/{userId}/details
//	   Composed URL - /v1/users/12345/details
//
// It replaces the value of the key while composing the request URL.
// The value will be escaped using [url.PathEscape] function.
// // It can be overridden at the request level, // see [Request.SetPathParamAny] or [Request.SetPathParams] func (c *Client) SetPathParamAny(param string, value any) *Client { c.lock.Lock() defer c.lock.Unlock() strVal := formatAnyToString(value) c.pathParams[param] = url.PathEscape(strVal) return c } // SetPathParams method sets multiple URL path key-value pairs at one go in the // Resty client instance. // // client.SetPathParams(map[string]string{ // "userId": "sample@sample.com", // "subAccountId": "100002", // "path": "groups/developers", // }) // // Result: // URL - /v1/users/{userId}/{subAccountId}/{path}/details // Composed URL - /v1/users/sample@sample.com/100002/groups%2Fdevelopers/details // // It replaces the value of the key while composing the request URL. // The values will be escaped using [url.PathEscape] function. // // It can be overridden at the request level, // see [Request.SetPathParam] or [Request.SetPathParams] func (c *Client) SetPathParams(params map[string]string) *Client { for p, v := range params { c.SetPathParam(p, v) } return c } // SetPathRawParam method sets a single URL path key-value pair in the // Resty client instance without path escape. // // client.SetPathRawParam("path", "groups/developers") // // Result: // URL - /v1/users/{path}/details // Composed URL - /v1/users/groups/developers/details // // It replaces the value of the key while composing the request URL. // The value will be used as-is, no path escape applied. // // It can be overridden at the request level, // see [Request.SetPathRawParam] or [Request.SetPathRawParams] func (c *Client) SetPathRawParam(param, value string) *Client { c.lock.Lock() defer c.lock.Unlock() c.pathParams[param] = value return c } // SetPathRawParamAny method sets a single URL path key-value pair in the // Resty client instance without path escape. 
// // It is similar to [Client.SetPathRawParam] but accepts any type as the value and converts // it to a string using predefined formatting rules (integers, bools, time.Time, etc.). // // client.SetPathRawParamAny("userId", 12345) // // Result: // URL - /v1/users/{userId}/details // Composed URL - /v1/users/12345/details // // It replaces the value of the key while composing the request URL. // The value will be used as-is, no path escape applied. // // It can be overridden at the request level, // see [Request.SetPathRawParamAny] or [Request.SetPathRawParams] func (c *Client) SetPathRawParamAny(param string, value any) *Client { c.lock.Lock() defer c.lock.Unlock() strVal := formatAnyToString(value) c.pathParams[param] = strVal return c } // SetPathRawParams method sets multiple URL path key-value pairs at one go in the // Resty client instance without path escape. // // client.SetPathRawParams(map[string]string{ // "userId": "sample@sample.com", // "subAccountId": "100002", // "path": "groups/developers", // }) // // Result: // URL - /v1/users/{userId}/{subAccountId}/{path}/details // Composed URL - /v1/users/sample@sample.com/100002/groups/developers/details // // It replaces the value of the key while composing the request URL. // The value will be used as-is, no path escape applied. // // It can be overridden at the request level, // see [Request.SetPathRawParam] or [Request.SetPathRawParams] func (c *Client) SetPathRawParams(params map[string]string) *Client { for p, v := range params { c.SetPathRawParam(p, v) } return c } // SetJSONEscapeHTML method enables or disables the HTML escape on JSON marshal. // By default, escape HTML is `true`. // // NOTE: This option only applies to the standard JSON Marshaller used by Resty. 
//
// It can be overridden at the request level, see [Request.SetJSONEscapeHTML]
func (c *Client) SetJSONEscapeHTML(b bool) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.jsonEscapeHTML = b
	return c
}

// ResponseBodyLimit method returns the maximum response body size limit, in
// bytes, from the client instance.
func (c *Client) ResponseBodyLimit() int64 {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.responseBodyLimit
}

// SetResponseBodyLimit method sets a maximum body size limit in bytes on response,
// to avoid reading too much data into memory.
//
// Client will return [resty.ErrResponseBodyTooLarge] if the size of the body
// in the uncompressed response is larger than the limit.
// Body size limit will not be enforced in the following cases:
//   - ResponseBodyLimit <= 0, which is the default behavior.
//   - [Request.SetResponseSaveFileName] is called to save response data to the file.
//   - "DoNotParseResponse" is set for client or request.
//
// It can be overridden at the request level; see [Request.SetResponseBodyLimit]
func (c *Client) SetResponseBodyLimit(v int64) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.responseBodyLimit = v
	return c
}

// IsTrace method returns true if the trace is enabled on the client instance; otherwise, it returns false.
func (c *Client) IsTrace() bool {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.isTrace
}

// SetTrace method is used to turn on/off the trace capability in the Resty client instance.
// It provides an insight into the request lifecycle using [httptrace.ClientTrace].
//
//	client := resty.New().SetTrace(true)
//
//	resp, err := client.R().Get("https://httpbin.org/get")
//	fmt.Println("error:", err)
//	fmt.Println("Trace Info:", resp.Request.TraceInfo())
//
// The method [Request.SetTrace] is also available to get trace info for a single request.
func (c *Client) SetTrace(t bool) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.isTrace = t
	return c
}

// SetCurlCmdGenerate method is used to turn on/off the generate curl command at the
// client instance level.
//
// By default, Resty does not log the curl command in the debug log since it has the potential
// to leak sensitive data unless explicitly enabled via [Client.SetCurlCmdDebugLog] or
// [Request.SetCurlCmdDebugLog].
//
// NOTE: Use with care.
//   - Potential to leak sensitive data from [Request] and [Response] in the debug log
//     when the debug log option is enabled.
//   - Additional memory usage since the request body is reread.
//   - curl body is not generated for [io.Reader] and multipart request flow.
//
// It can be overridden at the request level; see [Request.SetCurlCmdGenerate]
func (c *Client) SetCurlCmdGenerate(b bool) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.isCurlCmdGenerate = b
	return c
}

// SetCurlCmdDebugLog method enables the curl command to be logged in the debug log.
//
// It can be overridden at the request level; see [Request.SetCurlCmdDebugLog]
func (c *Client) SetCurlCmdDebugLog(b bool) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.isCurlCmdDebugLog = b
	return c
}

// SetQueryParamsUnescape method sets the choice of unescape query parameters for the request URL.
// To prevent broken URL, Resty replaces space (" ") with "+" in the query parameters.
//
// See [Request.SetQueryParamsUnescape]
//
// NOTE: Request failure is possible due to non-standard usage of Unescaped Query Parameters.
func (c *Client) SetQueryParamsUnescape(unescape bool) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.unescapeQueryParams = unescape
	return c
}

// ResponseBodyUnlimitedReads method returns true if enabled. Otherwise, it returns false
func (c *Client) ResponseBodyUnlimitedReads() bool {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.resBodyUnlimitedReads
}

// SetResponseBodyUnlimitedReads method is to turn on/off the response body in memory
// that provides an ability to do unlimited reads.
//
// It can be overridden at the request level; see [Request.SetResponseBodyUnlimitedReads]
//
// Unlimited reads are possible in a few scenarios, even without enabling it.
//   - When debug mode is enabled
//
// NOTE: Use with care
//   - Turning on this feature keeps the response body in memory, which might cause additional memory usage.
func (c *Client) SetResponseBodyUnlimitedReads(b bool) *Client {
	c.lock.Lock()
	defer c.lock.Unlock()
	c.resBodyUnlimitedReads = b
	return c
}

// IsProxySet method returns true if the proxy is set on the Resty client, i.e.
// [Client.ProxyURL] is non-nil; otherwise false. By default, the proxy is set
// from the environment variable; refer to [http.ProxyFromEnvironment].
func (c *Client) IsProxySet() bool {
	return c.ProxyURL() != nil
}

// Client method returns the underlying Go [http.Client] used by the Resty.
func (c *Client) Client() *http.Client {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.httpClient
}

// Clone method returns a clone of the original client.
//
// NOTE: Use with care:
//   - Interface values are not deeply cloned. Thus, both the original and the
//     clone will use the same value.
//   - It is not safe for concurrent use. You should only use this method
//     when you are sure that any other concurrent process is not using the client
//     or client instance is protected by a mutex.
func (c *Client) Clone(ctx context.Context) *Client { cc := new(Client) // dereference the pointer and copy the value *cc = *c cc.ctx = ctx cc.queryParams = cloneURLValues(c.queryParams) cc.formData = cloneURLValues(c.formData) cc.header = c.header.Clone() cc.pathParams = maps.Clone(c.pathParams) if c.credentials != nil { cc.credentials = c.credentials.Clone() } cc.contentTypeEncoders = maps.Clone(c.contentTypeEncoders) cc.contentTypeDecoders = maps.Clone(c.contentTypeDecoders) cc.contentDecompressers = maps.Clone(c.contentDecompressers) copy(cc.contentDecompresserKeys, c.contentDecompresserKeys) if c.proxyURL != nil { cc.proxyURL, _ = url.Parse(c.proxyURL.String()) } // clone cookies if l := len(c.cookies); l > 0 { cc.cookies = make([]*http.Cookie, 0, l) for _, cookie := range c.cookies { cc.cookies = append(cc.cookies, cloneCookie(cookie)) } } // certain values need to be reset cc.lock = &sync.RWMutex{} return cc } // Close method performs cleanup and closure activities on the client instance func (c *Client) Close() error { // Execute close hooks first c.onCloseHooks() if c.LoadBalancer() != nil { silently(c.LoadBalancer().Close()) } close(c.certWatcherStopChan) return nil } func (c *Client) executeRequestMiddlewares(req *Request) (err error) { for _, f := range c.requestMiddlewares() { if err = f(c, req); err != nil { return err } } return nil } // Executes method executes the given `Request` object and returns // response or error. 
func (c *Client) execute(req *Request) (*Response, error) {
	// The circuit breaker, when configured, can reject the request before
	// any middleware runs; its trigger hooks are notified in that case.
	if c.circuitBreaker != nil {
		if err := c.circuitBreaker.allow(); err != nil {
			c.circuitBreaker.onTriggerHooks(req, err)
			return nil, err
		}
	}

	if err := c.executeRequestMiddlewares(req); err != nil {
		return nil, err
	}

	// An explicit "Host" header overrides the host of the raw request.
	if hostHeader := req.Header.Get("Host"); hostHeader != "" {
		req.RawRequest.Host = hostHeader
	}

	prepareRequestDebugInfo(c, req)

	req.StartTime = time.Now()
	resp, err := c.Client().Do(req.withTimeout())

	// Cancel multipart context for io.Copy to stop reading/writing further
	if req.isMultiPart && req.multipartCancelFunc != nil {
		req.multipartCancelFunc()
	}

	// The response wrapper is built even when Do failed so that the caller
	// still receives the request and the received-at time with the error.
	response := &Response{Request: req, RawResponse: resp}
	response.setReceivedAt()
	if err != nil {
		return response, err
	}
	if req.isMultiPart && req.multipartErrChan != nil {
		// read all multipart errors from channel
		for err = range req.multipartErrChan {
			response.CascadeError = wrapErrors(err, response.CascadeError)
		}
	}
	if resp != nil {
		if c.circuitBreaker != nil {
			c.circuitBreaker.applyPolicies(resp)
		}

		response.Body = resp.Body
		if err = response.wrapContentDecompresser(); err != nil {
			return response, response.wrapError(err, false)
		}

		response.wrapLimitReadCloser()
		if !req.IsResponseDoNotParse {
			// The body is read up front only for unlimited reads or debug mode.
			if req.IsResponseBodyUnlimitedReads || req.IsDebug {
				response.wrapCopyReadCloser()

				if err = response.readAll(); err != nil {
					return response, response.wrapError(err, false)
				}
			}
		}
	}

	debugLogger(c, response)

	// Apply Response middleware
	// A failing middleware does not stop the chain; its error is accumulated
	// into CascadeError and the remaining middlewares still run.
	for _, f := range c.responseMiddlewares() {
		if err = f(c, response); err != nil {
			response.CascadeError = wrapErrors(err, response.CascadeError)
		}
	}

	return response, response.wrapError(nil, false)
}

// getting TLS client config if not exists then create one
//
// A transport implementing TLSClientConfiger supplies its own config. Otherwise
// the transport must be an *http.Transport, whose TLSClientConfig is created
// when nil; any other transport type yields ErrNotHttpTransportType.
func (c *Client) tlsConfig() (*tls.Config, error) {
	c.lock.Lock()
	defer c.lock.Unlock()

	if tc, ok := c.httpClient.Transport.(TLSClientConfiger); ok {
		return tc.TLSClientConfig(), nil
	}

	transport, ok := c.httpClient.Transport.(*http.Transport)
	if !ok {
		return nil, ErrNotHttpTransportType
	}

	if transport.TLSClientConfig == nil {
		transport.TLSClientConfig = &tls.Config{}
	}
	return transport.TLSClientConfig, nil
}

// just an internal helper method
//
// It redirects the output of the client logger to w. The logger is
// type-asserted to *logger without a check, so any other logger type panics.
func (c *Client) outputLogTo(w io.Writer) *Client {
	c.Logger().(*logger).l.SetOutput(w)
	return c
}

// ResponseError is a wrapper that includes the server response with an error.
// Neither the err nor the response should be nil.
type ResponseError struct {
	Response *Response
	Err      error
}

// Error returns the message of the wrapped error.
func (e *ResponseError) Error() string {
	return e.Err.Error()
}

// Unwrap returns the wrapped error, for use with [errors.Is] and [errors.As].
func (e *ResponseError) Unwrap() error {
	return e.Err
}

// Helper to run errorHooks hooks.
// It wraps the error in a [ResponseError] if the resp is not nil
// so hooks can access it.
//
// When err is nil, the success hooks are run instead of the error hooks.
// Hooks run while the client's read lock is held.
func (c *Client) onErrorHooks(req *Request, res *Response, err error) {
	c.lock.RLock()
	defer c.lock.RUnlock()
	if err != nil {
		if res != nil { // wrap with ResponseError
			err = &ResponseError{Response: res, Err: err}
		}
		for _, h := range c.errorHooks {
			h(req, err)
		}
	} else {
		for _, h := range c.successHooks {
			h(c, res)
		}
	}
}

// Helper to run panicHooks hooks.
// Hooks run in registration order while the client's read lock is held.
func (c *Client) onPanicHooks(req *Request, err error) {
	c.lock.RLock()
	defer c.lock.RUnlock()
	for _, h := range c.panicHooks {
		h(req, err)
	}
}

// Helper to run invalidHooks hooks.
// Hooks run in registration order while the client's read lock is held.
func (c *Client) onInvalidHooks(req *Request, err error) {
	c.lock.RLock()
	defer c.lock.RUnlock()
	for _, h := range c.invalidHooks {
		h(req, err)
	}
}

// Helper to run closeHooks hooks.
// Hooks run in registration order while the client's read lock is held.
func (c *Client) onCloseHooks() {
	c.lock.RLock()
	defer c.lock.RUnlock()
	for _, h := range c.closeHooks {
		h()
	}
}

// debugf writes a formatted debug log entry only when debug mode is enabled
// on the client.
func (c *Client) debugf(format string, v ...any) {
	if c.IsDebug() {
		c.Logger().Debugf(format, v...)
	}
}

================================================
FILE: client_test.go
================================================
// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT package resty import ( "bytes" "compress/gzip" "compress/lzw" "context" cryprand "crypto/rand" "crypto/tls" "errors" "fmt" "io" "log" "math" "math/rand" "net" "net/http" "net/url" "os" "path/filepath" "strconv" "strings" "sync" "testing" "time" ) func TestClientBasicAuth(t *testing.T) { ts := createAuthServer(t) defer ts.Close() c := dcnl() c.SetBasicAuth("myuser", "basicauth"). SetBaseURL(ts.URL). SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) resp, err := c.R(). SetResult(&AuthSuccess{}). Post("/login") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) t.Logf("Result Success: %q", resp.Result().(*AuthSuccess)) logResponse(t, resp) } func TestClientAuthToken(t *testing.T) { ts := createAuthServer(t) defer ts.Close() c := dcnl() c.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}). SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF"). SetBaseURL(ts.URL + "/") resp, err := c.R().Get("/profile") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) } func TestClientAuthScheme(t *testing.T) { ts := createAuthServer(t) defer ts.Close() c := dcnl() // Ensure default Bearer c.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}). SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF"). SetBaseURL(ts.URL + "/") resp, err := c.R().Get("/profile") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) // Ensure setting the scheme works as well c.SetAuthScheme("Bearer") assertEqual(t, "Bearer", c.AuthScheme()) resp2, err2 := c.R().Get("/profile") assertError(t, err2) assertEqual(t, http.StatusOK, resp2.StatusCode()) } func TestClientResponseMiddleware(t *testing.T) { ts := createGenericServer(t) defer ts.Close() c := dcnl() c.AddResponseMiddleware(func(c *Client, res *Response) error { t.Logf("Request sent at: %v", res.Request.StartTime) t.Logf("Response Received at: %v", res.ReceivedAt()) return nil }) resp, err := c.R(). 
SetBody("ResponseMiddleware: This is plain text body to server"). Put(ts.URL + "/plaintext") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "TestPut: plain text response", resp.String()) } func TestClientRedirectPolicy(t *testing.T) { ts := createRedirectServer(t) defer ts.Close() c := dcnl().SetRedirectPolicy(RedirectFlexiblePolicy(20), RedirectDomainCheckPolicy("127.0.0.1")) res, err := c.R(). SetHeader("Name1", "Value1"). SetHeader("Name2", "Value2"). SetHeader("Name3", "Value3"). Get(ts.URL + "/redirect-1") assertTrue(t, err.Error() == "Get \"/redirect-21\": resty: stopped after 20 redirects") redirects := res.RedirectHistory() assertEqual(t, 20, len(redirects)) finalReq := redirects[0] assertEqual(t, 307, finalReq.StatusCode) assertEqual(t, ts.URL+"/redirect-20", finalReq.URL) c.SetRedirectPolicy(RedirectNoPolicy()) res, err = c.R().Get(ts.URL + "/redirect-1") assertNil(t, err) assertEqual(t, http.StatusTemporaryRedirect, res.StatusCode()) assertEqual(t, `Temporary Redirect.`, res.String()) } func TestClientTimeout(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl().SetTimeout(200 * time.Millisecond) _, err := c.R().Get(ts.URL + "/set-timeout-test") assertErrorIs(t, context.DeadlineExceeded, err) } func TestClientTimeoutWithinThreshold(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl().SetTimeout(200 * time.Millisecond) resp, err := c.R().Get(ts.URL + "/set-timeout-test-with-sequence") assertError(t, err) seq1, _ := strconv.ParseInt(resp.String(), 10, 32) resp, err = c.R().Get(ts.URL + "/set-timeout-test-with-sequence") assertError(t, err) seq2, _ := strconv.ParseInt(resp.String(), 10, 32) assertEqual(t, seq1+1, seq2) } func TestClientTimeoutInternalError(t *testing.T) { c := dcnl().SetTimeout(time.Second * 1) _, _ = c.R().Get("http://localhost:9000/set-timeout-test") } func TestClientProxy(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl() c.SetTimeout(1 * 
time.Second) c.SetProxy("http://sampleproxy:8888") resp, err := c.R().Get(ts.URL) assertNotNil(t, resp) assertNotNil(t, err) // error c.SetProxy("//not.a.user@%66%6f%6f.com:8888") resp, err = c.R(). Get(ts.URL) assertNotNil(t, err) assertNotNil(t, resp) } func TestClientSetCertificates(t *testing.T) { certFile := filepath.Join(getTestDataPath(), "cert.pem") keyFile := filepath.Join(getTestDataPath(), "key.pem") t.Run("client cert from file", func(t *testing.T) { c := dcnl() c.SetCertificateFromFile(certFile, keyFile) assertEqual(t, 1, len(c.TLSClientConfig().Certificates)) }) t.Run("error-client cert from file", func(t *testing.T) { c := dcnl() c.SetCertificateFromFile(certFile+"no", keyFile+"no") assertEqual(t, 0, len(c.TLSClientConfig().Certificates)) }) t.Run("client cert from string", func(t *testing.T) { certPemData, _ := os.ReadFile(certFile) keyPemData, _ := os.ReadFile(keyFile) c := dcnl() c.SetCertificateFromString(string(certPemData), string(keyPemData)) assertEqual(t, 1, len(c.TLSClientConfig().Certificates)) }) t.Run("error-client cert from string", func(t *testing.T) { c := dcnl() c.SetCertificateFromString(string("empty"), string("empty")) assertEqual(t, 0, len(c.TLSClientConfig().Certificates)) }) } func TestClientSetRootCertificate(t *testing.T) { t.Run("root cert", func(t *testing.T) { client := dcnl() client.SetRootCertificates(filepath.Join(getTestDataPath(), "sample-root.pem")) transport, err := client.HTTPTransport() assertNil(t, err) assertNotNil(t, transport.TLSClientConfig.RootCAs) }) t.Run("root cert not exists", func(t *testing.T) { client := dcnl() client.SetRootCertificates(filepath.Join(getTestDataPath(), "not-exists-sample-root.pem")) transport, err := client.HTTPTransport() assertNil(t, err) assertNil(t, transport.TLSClientConfig) }) t.Run("root cert from string", func(t *testing.T) { client := dcnl() rootPemData, err := os.ReadFile(filepath.Join(getTestDataPath(), "sample-root.pem")) assertNil(t, err) 
client.SetRootCertificateFromString(string(rootPemData)) transport, err := client.HTTPTransport() assertNil(t, err) assertNotNil(t, transport.TLSClientConfig.RootCAs) }) } type CustomRoundTripper1 struct{} // RoundTrip just for test func (rt *CustomRoundTripper1) RoundTrip(_ *http.Request) (*http.Response, error) { return &http.Response{}, nil } func TestClientCACertificateFromStringErrorTls(t *testing.T) { t.Run("root cert string", func(t *testing.T) { client := NewWithClient(&http.Client{}) client.outputLogTo(io.Discard) rootPemData, err := os.ReadFile(filepath.Join(getTestDataPath(), "sample-root.pem")) assertNil(t, err) rt := &CustomRoundTripper1{} client.SetTransport(rt) transport, err := client.HTTPTransport() client.SetRootCertificateFromString(string(rootPemData)) assertNotNil(t, rt) assertNotNil(t, err) assertNil(t, transport) }) t.Run("client cert string", func(t *testing.T) { client := NewWithClient(&http.Client{}) client.outputLogTo(io.Discard) rootPemData, err := os.ReadFile(filepath.Join(getTestDataPath(), "sample-root.pem")) assertNil(t, err) rt := &CustomRoundTripper1{} client.SetTransport(rt) transport, err := client.HTTPTransport() client.SetClientRootCertificateFromString(string(rootPemData)) assertNotNil(t, rt) assertNotNil(t, err) assertNil(t, transport) }) } // CustomRoundTripper2 just for test type CustomRoundTripper2 struct { http.RoundTripper TLSClientConfiger tlsConfig *tls.Config returnErr bool } // RoundTrip just for test func (rt *CustomRoundTripper2) RoundTrip(_ *http.Request) (*http.Response, error) { if rt.returnErr { return nil, errors.New("test req mock error") } return &http.Response{}, nil } func (rt *CustomRoundTripper2) TLSClientConfig() *tls.Config { return rt.tlsConfig } func (rt *CustomRoundTripper2) SetTLSClientConfig(tlsConfig *tls.Config) error { if rt.returnErr { return errors.New("test mock error") } rt.tlsConfig = tlsConfig return nil } func TestClientTLSConfigerInterface(t *testing.T) { t.Run("assert transport and 
custom roundtripper", func(t *testing.T) { c := dcnl() assertNotNil(t, c.Transport()) assertEqual(t, "http.Transport", inferType(c.Transport()).String()) ct := &CustomRoundTripper2{} c.SetTransport(ct) assertNotNil(t, c.Transport()) assertEqual(t, "resty.CustomRoundTripper2", inferType(c.Transport()).String()) }) t.Run("get and set tls config", func(t *testing.T) { c := dcnl() ct := &CustomRoundTripper2{} c.SetTransport(ct) tlsConfig := &tls.Config{InsecureSkipVerify: true} c.SetTLSClientConfig(tlsConfig) assertEqual(t, tlsConfig, c.TLSClientConfig()) }) t.Run("get tls config error", func(t *testing.T) { c := dcnl() ct := &CustomRoundTripper1{} c.SetTransport(ct) assertNil(t, c.TLSClientConfig()) }) t.Run("set tls config error", func(t *testing.T) { c := dcnl() ct := &CustomRoundTripper2{returnErr: true} c.SetTransport(ct) tlsConfig := &tls.Config{InsecureSkipVerify: true} c.SetTLSClientConfig(tlsConfig) assertNil(t, c.TLSClientConfig()) }) } func TestClientSetClientRootCertificate(t *testing.T) { client := dcnl() client.SetClientRootCertificates(filepath.Join(getTestDataPath(), "sample-root.pem")) transport, err := client.HTTPTransport() assertNil(t, err) assertNotNil(t, transport.TLSClientConfig.ClientCAs) } func TestClientSetClientRootCertificateNotExists(t *testing.T) { client := dcnl() client.SetClientRootCertificates(filepath.Join(getTestDataPath(), "not-exists-sample-root.pem")) transport, err := client.HTTPTransport() assertNil(t, err) assertNil(t, transport.TLSClientConfig) } func TestClientSetClientRootCertificateWatcher(t *testing.T) { t.Run("Cert exists", func(t *testing.T) { client := dcnl() client.SetClientRootCertificatesWatcher( &CertWatcherOptions{PoolInterval: time.Second * 1}, filepath.Join(getTestDataPath(), "sample-root.pem"), ) transport, err := client.HTTPTransport() assertNil(t, err) assertNotNil(t, transport.TLSClientConfig.ClientCAs) }) t.Run("Cert does not exist", func(t *testing.T) { client := dcnl() 
client.SetClientRootCertificatesWatcher(nil, filepath.Join(getTestDataPath(), "not-exists-sample-root.pem")) transport, err := client.HTTPTransport() assertNil(t, err) assertNil(t, transport.TLSClientConfig) }) } func TestClientSetClientRootCertificateFromString(t *testing.T) { client := dcnl() rootPemData, err := os.ReadFile(filepath.Join(getTestDataPath(), "sample-root.pem")) assertNil(t, err) client.SetClientRootCertificateFromString(string(rootPemData)) transport, err := client.HTTPTransport() assertNil(t, err) assertNotNil(t, transport.TLSClientConfig.ClientCAs) } func TestClientRequestMiddlewareModification(t *testing.T) { tc := dcnl() tc.AddRequestMiddleware(func(c *Client, r *Request) error { r.SetAuthToken("This is test auth token") return nil }) ts := createGetServer(t) defer ts.Close() resp, err := tc.R().Get(ts.URL + "/") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, "TestGet: text response", resp.String()) logResponse(t, resp) } func TestClientSetHeaderVerbatim(t *testing.T) { ts := createPostServer(t) defer ts.Close() c := dcnl(). SetHeaderVerbatim("header-lowercase", "value_lowercase"). SetHeader("header-lowercase", "value_standard") //lint:ignore SA1008 valid one, so ignore this! unConventionHdrValue := strings.Join(c.Header()["header-lowercase"], "") assertEqual(t, "value_lowercase", unConventionHdrValue) assertEqual(t, "value_standard", c.Header().Get("Header-Lowercase")) } func TestClientSetHeaderAny(t *testing.T) { c := dcnl(). SetHeaderAny("X-Int-Value", 42). SetHeaderAny("X-String-Value", "hello") assertEqual(t, "42", c.Header().Get("X-Int-Value")) assertEqual(t, "hello", c.Header().Get("X-String-Value")) } func TestClientSetHeaderVerbatimAny(t *testing.T) { c := dcnl(). SetHeaderVerbatimAny("header-lowercase", 123) //lint:ignore SA1008 valid one, so ignore this! 
unConventionHdrValue := strings.Join(c.Header()["header-lowercase"], "") assertEqual(t, "123", unConventionHdrValue) } func TestClientSetQueryParamAny(t *testing.T) { c := dcnl(). SetQueryParamAny("page", 5). SetQueryParamAny("active", true) assertEqual(t, "5", c.QueryParams().Get("page")) assertEqual(t, "true", c.QueryParams().Get("active")) } func TestClientSetPathParamAny(t *testing.T) { c := dcnl(). SetPathParamAny("userId", 42). SetPathParamAny("name", "john doe") assertEqual(t, "42", c.PathParams()["userId"]) assertEqual(t, "john%20doe", c.PathParams()["name"]) } func TestClientSetRawPathParamAny(t *testing.T) { c := dcnl(). SetPathRawParamAny("userId", 42). SetPathRawParamAny("name", "john doe") assertEqual(t, "42", c.PathParams()["userId"]) assertEqual(t, "john doe", c.PathParams()["name"]) } func TestClientSetTransport(t *testing.T) { ts := createGetServer(t) defer ts.Close() client := dcnl() transport := &http.Transport{ // something like Proxying to httptest.Server, etc... Proxy: func(req *http.Request) (*url.URL, error) { return url.Parse(ts.URL) }, } client.SetTransport(transport) transportInUse, err := client.HTTPTransport() assertNil(t, err) assertTrue(t, transport == transportInUse, "HTTP Transport should be of same type") } func TestClientSetScheme(t *testing.T) { client := dcnl() client.SetScheme("http") assertEqual(t, "http", client.scheme, "Scheme should be 'http'") } func TestClientSetCookieJar(t *testing.T) { client := dcnl() backupJar := client.httpClient.Jar client.SetCookieJar(nil) assertNil(t, client.httpClient.Jar, "CookieJar should be nil") client.SetCookieJar(backupJar) assertTrue(t, client.httpClient.Jar == backupJar, "CookieJar should be set back to original jar") } // This test methods exist for test coverage purpose // to validate the getter and setter func TestClientSettingsCoverage(t *testing.T) { c := dcnl() assertNotNil(t, c.CookieJar()) assertNotNil(t, c.ContentTypeEncoders()) assertNotNil(t, c.ContentTypeDecoders()) 
assertFalse(t, c.IsDebug()) assertEqual(t, math.MaxInt32, c.DebugBodyLimit()) assertNotNil(t, c.Logger()) assertEqual(t, 0, c.RetryCount()) assertEqual(t, time.Millisecond*100, c.RetryWaitTime()) assertEqual(t, time.Second*2, c.RetryMaxWaitTime()) assertFalse(t, c.IsTrace()) assertEqual(t, 0, len(c.RetryConditions())) authToken := "sample auth token value" c.SetAuthToken(authToken) assertEqual(t, authToken, c.AuthToken()) customAuthHeader := "X-Custom-Authorization" c.SetHeaderAuthorizationKey(customAuthHeader) assertEqual(t, customAuthHeader, c.HeaderAuthorizationKey()) c.SetCloseConnection(true) c.SetDebug(false) assertTrue(t, c.IsRetryDefaultConditions()) c.SetRetryDefaultConditions(false) assertFalse(t, c.IsRetryDefaultConditions()) c.SetRetryDefaultConditions(true) assertTrue(t, c.IsRetryDefaultConditions()) nr := nopReader{} n, err1 := nr.Read(nil) assertEqual(t, 0, n) assertEqual(t, io.EOF, err1) b, err1 := nr.ReadByte() assertEqual(t, byte(0), b) assertEqual(t, io.EOF, err1) // [Start] Custom Transport scenario ct := dcnl() ct.SetTransport(&CustomRoundTripper1{}) _, err := ct.HTTPTransport() assertNotNil(t, err) assertEqual(t, ErrNotHttpTransportType, err) ct.SetProxy("http://localhost:8080") ct.RemoveProxy() ct.SetCertificates(tls.Certificate{}) ct.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) ct.SetRootCertificateFromString("root cert") ct.outputLogTo(io.Discard) // [End] Custom Transport scenario // Response - for now stay here resp := &Response{Request: &Request{}} s := resp.fmtBodyString(0) assertEqual(t, "***** NO CONTENT *****", s) } func TestContentLengthWhenBodyIsNil(t *testing.T) { client := dcnl() fnPreRequestMiddleware1 := func(c *Client, r *Request) error { // validate assertEqual(t, int64(0), r.contentLength) assertEqual(t, int64(0), r.RawRequest.ContentLength) return nil } client.SetRequestMiddlewares( MiddlewareRequestCreate, fnPreRequestMiddleware1, ) client.R().SetBody(nil).Get("http://localhost") } func 
TestClientPreRequestMiddlewares(t *testing.T) { client := dcnl() fnPreRequestMiddleware1 := func(c *Client, r *Request) error { c.Logger().Debugf("I'm in Pre-Request Hook") return nil } fnPreRequestMiddleware2 := func(c *Client, r *Request) error { c.Logger().Debugf("I'm Overwriting existing Pre-Request Hook") // Reading Request `N` no of times for i := 0; i < 5; i++ { b, _ := r.RawRequest.GetBody() rb, _ := io.ReadAll(b) c.Logger().Debugf("%s %v", string(rb), len(rb)) assertTrue(t, len(rb) >= 45) } return nil } client.SetRequestMiddlewares( MiddlewareRequestCreate, fnPreRequestMiddleware1, fnPreRequestMiddleware2, ) ts := createPostServer(t) defer ts.Close() // Regular bodybuf use case resp, _ := client.R(). SetBody(map[string]any{"username": "testuser", "password": "testpass"}). Post(ts.URL + "/login") assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, `{ "id": "success", "message": "login successful" }`, resp.String()) // io.Reader body use case resp, _ = client.R(). SetHeader(hdrContentTypeKey, jsonContentType). SetBody(bytes.NewReader([]byte(`{"username":"testuser", "password":"testpass"}`))). 
Post(ts.URL + "/login") assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, `{ "id": "success", "message": "login successful" }`, resp.String()) } func TestClientPreRequestMiddlewareError(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl() fnPreRequestMiddleware1 := func(c *Client, r *Request) error { return errors.New("error from PreRequestMiddleware") } c.SetRequestMiddlewares( MiddlewareRequestCreate, fnPreRequestMiddleware1, ) resp, err := c.R().Get(ts.URL) assertNotNil(t, err) assertEqual(t, "error from PreRequestMiddleware", err.Error()) assertNil(t, resp) } func TestClientAllowMethodGetPayload(t *testing.T) { ts := createGetServer(t) defer ts.Close() t.Run("method GET allow string payload at client level", func(t *testing.T) { c := dcnl() c.SetMethodGetAllowPayload(true) assertTrue(t, c.IsMethodGetAllowPayload()) payload := "test-payload" resp, err := c.R().SetBody(payload).Get(ts.URL + "/get-method-payload-test") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode(), "Status code should be 200 OK") assertEqual(t, payload, resp.String(), "Response payload should be same as request payload") }) t.Run("method GET allow io.Reader payload at client level", func(t *testing.T) { c := dcnl() c.SetMethodGetAllowPayload(true) assertTrue(t, c.IsMethodGetAllowPayload()) payload := "test-payload" body := bytes.NewReader([]byte(payload)) resp, err := c.R().SetBody(body).Get(ts.URL + "/get-method-payload-test") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, payload, resp.String(), "Response payload should be same as request payload") }) t.Run("method GET disallow payload at client level", func(t *testing.T) { c := dcnl() c.SetMethodGetAllowPayload(false) assertFalse(t, c.IsMethodGetAllowPayload()) payload := bytes.NewReader([]byte("test-payload")) resp, err := c.R().SetBody(payload).Get(ts.URL + "/get-method-payload-test") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) 
assertEqual(t, "", resp.String()) }) } func TestClientAllowMethodDeletePayload(t *testing.T) { ts := createGenericServer(t) defer ts.Close() t.Run("method DELETE allow string payload at client level", func(t *testing.T) { c := dcnl().SetBaseURL(ts.URL) c.SetMethodDeleteAllowPayload(true) assertTrue(t, c.IsMethodDeleteAllowPayload()) payload := "test-payload" resp, err := c.R().SetBody(payload).Delete("/delete") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, payload, resp.String()) }) t.Run("method DELETE allow io.Reader payload at client level", func(t *testing.T) { c := dcnl().SetBaseURL(ts.URL) c.SetMethodDeleteAllowPayload(true) assertTrue(t, c.IsMethodDeleteAllowPayload()) payload := "test-payload" body := bytes.NewReader([]byte(payload)) resp, err := c.R().SetBody(body).Delete("/delete") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, payload, resp.String()) }) t.Run("method DELETE disallow payload at client level", func(t *testing.T) { c := dcnl().SetBaseURL(ts.URL) c.SetMethodDeleteAllowPayload(false) assertFalse(t, c.IsMethodDeleteAllowPayload()) payload := bytes.NewReader([]byte("test-payload")) resp, err := c.R().SetBody(payload).Delete("/delete") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "", resp.String(), "Response payload should be empty") }) } func TestClientRoundTripper(t *testing.T) { c := NewWithClient(&http.Client{}) c.outputLogTo(io.Discard) rt := &CustomRoundTripper2{} c.SetTransport(rt) ct, err := c.HTTPTransport() assertNotNil(t, err) assertNil(t, ct) assertEqual(t, ErrNotHttpTransportType, err) } func TestClientNewRequest(t *testing.T) { c := New() request := c.NewRequest() assertNotNil(t, request) } func TestClientDebugBodySizeLimit(t *testing.T) { ts := createGetServer(t) defer ts.Close() c, lb := dcldb() c.SetDebugBodyLimit(30) testcases := []struct{ url, want string }{ // Text, does not exceed limit. 
{url: ts.URL, want: "TestGet: text response"}, // Empty response. {url: ts.URL + "/no-content", want: "***** NO CONTENT *****"}, // JSON, does not exceed limit. {url: ts.URL + "/json", want: "{\n \"TestGet\": \"JSON response\"\n}"}, // Invalid JSON, does not exceed limit. {url: ts.URL + "/json-invalid", want: "DebugLog: Response.fmtBodyString: invalid character 'T' looking for beginning of value"}, // Text, exceeds limit. {url: ts.URL + "/long-text", want: "RESPONSE TOO LARGE"}, // JSON, exceeds limit. {url: ts.URL + "/long-json", want: "RESPONSE TOO LARGE"}, } for _, tc := range testcases { _, err := c.R().Get(tc.url) if tc.want != "" { assertError(t, err) debugLog := lb.String() if !strings.Contains(debugLog, tc.want) { t.Errorf("Expected logs to contain [%v], got [\n%v]", tc.want, debugLog) } lb.Reset() } } } func TestGzipCompress(t *testing.T) { ts := createGenericServer(t) defer ts.Close() c := dcnl() testcases := []struct{ url, want string }{ {ts.URL + "/gzip-test", "This is Gzip response testing"}, {ts.URL + "/gzip-test-gziped-empty-body", ""}, {ts.URL + "/gzip-test-no-gziped-body", ""}, } for _, tc := range testcases { resp, err := c.R().Get(tc.url) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, tc.want, resp.String()) logResponse(t, resp) } } func TestDeflateCompress(t *testing.T) { ts := createGenericServer(t) defer ts.Close() c := dcnl() testcases := []struct{ url, want string }{ {ts.URL + "/deflate-test", "This is Deflate response testing"}, {ts.URL + "/deflate-test-empty-body", ""}, {ts.URL + "/deflate-test-no-body", ""}, } for _, tc := range testcases { resp, err := c.R().Get(tc.url) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, tc.want, resp.String()) logResponse(t, resp) } } type lzwReader struct { s io.ReadCloser r io.ReadCloser } func (l *lzwReader) Read(p []byte) (n int, err error) { return 
l.r.Read(p) } func (l *lzwReader) Close() error { closeq(l.r) closeq(l.s) return nil } func TestLzwCompress(t *testing.T) { ts := createGenericServer(t) defer ts.Close() c := dcnl() // Not found scenario _, err := c.R().Get(ts.URL + "/lzw-test") assertNotNil(t, err) assertEqual(t, ErrContentDecompresserNotFound, err) // Register LZW content decoder c.AddContentDecompresser("ComPreSs", func(r io.ReadCloser) (io.ReadCloser, error) { l := &lzwReader{ s: r, r: lzw.NewReader(r, lzw.LSB, 8), } return l, nil }) c.SetContentDecompresserKeys([]string{"compress"}) testcases := []struct{ url, want string }{ {ts.URL + "/lzw-test", "This is LZW response testing"}, {ts.URL + "/lzw-test-empty-body", ""}, {ts.URL + "/lzw-test-no-body", ""}, } for _, tc := range testcases { resp, err := c.R().Get(tc.url) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, tc.want, resp.String()) logResponse(t, resp) } } func TestClientLogCallbacks(t *testing.T) { ts := createAuthServer(t) defer ts.Close() c, lb := dcldb() c.OnDebugLog(func(dl *DebugLog) { // request // masking authorization header dl.Request.Header.Set("Authorization", "Bearer *******************************") // response dl.Response.Header.Add("X-Debug-Response-Log", "Modified :)") dl.Response.Body += "\nModified the response body content" }) c.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}). SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF") resp, err := c.R(). SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF-Request"). 
Get(ts.URL + "/profile") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) // Validating debug log updates logInfo := lb.String() assertTrue(t, strings.Contains(logInfo, "Bearer *******************************")) assertTrue(t, strings.Contains(logInfo, "X-Debug-Response-Log")) assertTrue(t, strings.Contains(logInfo, "Modified the response body content")) // overwrite scenario c.OnDebugLog(func(dl *DebugLog) { // overwrite debug log }) resp, err = c.R(). SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF-Request"). Get(ts.URL + "/profile") assertNil(t, err) assertNotNil(t, resp) assertEqual(t, int64(66), resp.Size()) assertTrue(t, strings.Contains(lb.String(), "Overwriting an existing on-debug-log callback from=resty.dev/v3.TestClientLogCallbacks.func1 to=resty.dev/v3.TestClientLogCallbacks.func2")) } func TestDebugLogSimultaneously(t *testing.T) { ts := createGetServer(t) c := dcnl(). SetDebug(true). SetBaseURL(ts.URL) t.Cleanup(ts.Close) for i := 0; i < 50; i++ { t.Run(fmt.Sprint(i), func(t *testing.T) { t.Parallel() resp, err := c.R(). SetBody([]int{1, 2, 3}). SetHeader(hdrContentTypeKey, "application/json; charset=utf-8"). 
Post("/") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) }) } } func TestCustomTransportSettings(t *testing.T) { ts := createGetServer(t) defer ts.Close() customTransportSettings := &TransportSettings{ DialerTimeout: 30 * time.Second, DialerKeepAlive: 15 * time.Second, IdleConnTimeout: 120 * time.Second, TLSHandshakeTimeout: 20 * time.Second, ExpectContinueTimeout: 1 * time.Second, MaxIdleConns: 50, MaxIdleConnsPerHost: 3, MaxConnsPerHost: 100, ResponseHeaderTimeout: 10 * time.Second, MaxResponseHeaderBytes: 1 << 10, WriteBufferSize: 2 << 10, ReadBufferSize: 2 << 10, } client := NewWithTransportSettings(customTransportSettings) client.SetBaseURL(ts.URL) resp, err := client.R().Get("/") assertNil(t, err) assertEqual(t, "TestGet: text response", resp.String()) } func TestDefaultDialerTransportSettings(t *testing.T) { ts := createGetServer(t) defer ts.Close() t.Run("transport-default", func(t *testing.T) { client := NewWithTransportSettings(nil) client.SetBaseURL(ts.URL) resp, err := client.R().Get("/") assertNil(t, err) assertEqual(t, "TestGet: text response", resp.String()) }) t.Run("dialer-transport-default", func(t *testing.T) { client := NewWithDialerAndTransportSettings(nil, nil) client.SetBaseURL(ts.URL) resp, err := client.R().Get("/") assertNil(t, err) assertEqual(t, "TestGet: text response", resp.String()) }) } func TestNewWithDialer(t *testing.T) { ts := createGetServer(t) defer ts.Close() dialer := &net.Dialer{ Timeout: 15 * time.Second, KeepAlive: 15 * time.Second, } client := NewWithDialer(dialer) client.SetBaseURL(ts.URL) resp, err := client.R().Get("/") assertNil(t, err) assertEqual(t, "TestGet: text response", resp.String()) } func TestNewWithLocalAddr(t *testing.T) { ts := createGetServer(t) defer ts.Close() localAddress, _ := net.ResolveTCPAddr("tcp", "127.0.0.1") client := NewWithLocalAddr(localAddress) client.SetBaseURL(ts.URL) resp, err := client.R().Get("/") assertNil(t, err) assertEqual(t, "TestGet: text response", 
resp.String()) } func TestClientOnResponseFailure(t *testing.T) { tests := []struct { name string setup func(*Client) isError bool hasResponse bool panics bool }{ { name: "successful_request", }, { name: "http_status_failure", setup: func(client *Client) { client.SetAuthToken("BAD") }, }, { name: "before_request_failure", setup: func(client *Client) { client.AddRequestMiddleware(func(client *Client, request *Request) error { return fmt.Errorf("before request") }) }, isError: true, }, { name: "before_request_failure_retry", setup: func(client *Client) { client.SetRetryCount(3).AddRequestMiddleware(func(client *Client, request *Request) error { return fmt.Errorf("before request") }) }, isError: true, }, { name: "after_response_failure", setup: func(client *Client) { client.AddResponseMiddleware(func(client *Client, response *Response) error { return fmt.Errorf("after response") }) }, isError: true, hasResponse: true, }, { name: "after_response_failure_retry", setup: func(client *Client) { client.SetRetryCount(3).AddResponseMiddleware(func(client *Client, response *Response) error { return fmt.Errorf("after response") }) }, isError: true, hasResponse: true, }, { name: "panic with error", setup: func(client *Client) { client.AddRequestMiddleware(func(client *Client, request *Request) error { panic(fmt.Errorf("before request")) }) }, isError: false, hasResponse: false, panics: true, }, { name: "panic with string", setup: func(client *Client) { client.AddRequestMiddleware(func(client *Client, request *Request) error { panic("before request") }) }, isError: false, hasResponse: false, panics: true, }, } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { t.Parallel() ts := createAuthServer(t) defer ts.Close() var assertErrorHook = func(r *Request, err error) { assertNotNil(t, r) v, ok := err.(*ResponseError) assertEqual(t, test.hasResponse, ok) if ok { assertNotNil(t, v.Response) assertNotNil(t, v.Err) } } var errorHook1, errorHook2, 
successHook1, successHook2, panicHook1, panicHook2 int defer func() { if rec := recover(); rec != nil { assertTrue(t, test.panics, "expected to panic") assertEqual(t, 0, errorHook1) assertEqual(t, 0, successHook1) assertEqual(t, 1, panicHook1) assertEqual(t, 1, panicHook2) } }() c := dcnl(). SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}). SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF"). SetRetryCount(0). SetRetryMaxWaitTime(time.Microsecond). AddRetryConditions(func(response *Response, err error) bool { if err != nil { return true } return response.IsStatusFailure() }). OnError(func(r *Request, err error) { assertErrorHook(r, err) errorHook1++ }). OnError(func(r *Request, err error) { assertErrorHook(r, err) errorHook2++ }). OnPanic(func(r *Request, err error) { assertErrorHook(r, err) panicHook1++ }). OnPanic(func(r *Request, err error) { assertErrorHook(r, err) panicHook2++ }). OnSuccess(func(c *Client, resp *Response) { assertNotNil(t, c) assertNotNil(t, resp) successHook1++ }). 
OnSuccess(func(c *Client, resp *Response) { assertNotNil(t, c) assertNotNil(t, resp) successHook2++ }) if test.setup != nil { test.setup(c) } _, err := c.R().Get(ts.URL + "/profile") if test.isError { assertNotNil(t, err) assertEqual(t, 1, errorHook1) assertEqual(t, 1, errorHook2) assertEqual(t, 0, successHook1) assertEqual(t, 0, panicHook1) } else { assertError(t, err) assertEqual(t, 0, errorHook1) assertEqual(t, 1, successHook1) assertEqual(t, 1, successHook2) assertEqual(t, 0, panicHook1) } }) } } func TestResponseError(t *testing.T) { err := errors.New("error message") re := &ResponseError{ Response: &Response{}, Err: err, } assertNotNil(t, re.Unwrap()) assertEqual(t, err.Error(), re.Error()) } func TestHostURLForGH318AndGH407(t *testing.T) { ts := createPostServer(t) defer ts.Close() targetURL, _ := url.Parse(ts.URL) t.Log("ts.URL:", ts.URL) t.Log("targetURL.Host:", targetURL.Host) // Sample output // ts.URL: http://127.0.0.1:55967 // targetURL.Host: 127.0.0.1:55967 // Unable use the local http test server for this // use case testing // // using `targetURL.Host` value or test case yield to ERROR // "parse "127.0.0.1:55967": first path segment in URL cannot contain colon" // test the functionality with httpbin.org locally // will figure out later c := dcnl() // c.SetScheme("http") // c.SetHostURL(targetURL.Host + "/") // t.Log("with leading `/`") // resp, err := c.R().Post("/login") // assertNil(t, err) // assertNotNil(t, resp) // t.Log("\nwithout leading `/`") // resp, err = c.R().Post("login") // assertNil(t, err) // assertNotNil(t, resp) t.Log("with leading `/` on request & with trailing `/` on host url") c.SetBaseURL(ts.URL + "/") resp, err := c.R(). SetBody(map[string]any{"username": "testuser", "password": "testpass"}). 
Post("/login") assertNil(t, err) assertNotNil(t, resp) } func TestPostRedirectWithBody(t *testing.T) { ts := createPostServer(t) defer ts.Close() mu := sync.Mutex{} rnd := rand.New(rand.NewSource(time.Now().UnixNano())) c := dcnl().SetBaseURL(ts.URL) totalRequests := 4000 wg := sync.WaitGroup{} wg.Add(totalRequests) for i := 0; i < totalRequests; i++ { if i%50 == 0 { time.Sleep(20 * time.Millisecond) // to prevent test server socket exhaustion } go func() { defer wg.Done() mu.Lock() randNumber := rnd.Int() mu.Unlock() resp, err := c.R(). SetBody([]byte(strconv.Itoa(randNumber))). Post("/redirect-with-body") assertError(t, err) assertNotNil(t, resp) }() } wg.Wait() } func TestUnixSocket(t *testing.T) { unixSocketAddr := createUnixSocketEchoServer(t) defer os.Remove(unixSocketAddr) // Create a Go's http.Transport so we can set it in resty. transport := http.Transport{ Dial: func(_, _ string) (net.Conn, error) { return net.Dial("unix", unixSocketAddr) }, } // Create a Resty Client client := New() // Set the previous transport that we created, set the scheme of the communication to the // socket and set the unixSocket as the HostURL. client.SetTransport(&transport).SetScheme("http").SetBaseURL(unixSocketAddr) // No need to write the host's URL on the request, just the path. 
res, err := client.R().Get("http://localhost/") assertNil(t, err) assertEqual(t, "Hi resty client from a server running on Unix domain socket!", res.String()) res, err = client.R().Get("http://localhost/hello") assertNil(t, err) assertEqual(t, "Hello resty client from a server running on endpoint /hello!", res.String()) } func TestClientClone(t *testing.T) { parent := New() // set a non-interface field parent.SetBaseURL("http://localhost") parent.SetBasicAuth("parent", "") parent.SetProxy("http://localhost:8080") parent.SetCookie(&http.Cookie{ Name: "go-resty-1", Value: "This is cookie 1 value", }) parent.SetCookies([]*http.Cookie{ { Name: "go-resty-2", Value: "This is cookie 2 value", }, { Name: "go-resty-3", Value: "This is cookie 3 value", }, }) clone := parent.Clone(context.Background()) // update value of non-interface type - change will only happen on clone clone.SetBaseURL("https://local.host") clone.SetBasicAuth("clone", "clone") // assert non-interface type assertEqual(t, "http://localhost", parent.BaseURL()) assertEqual(t, "https://local.host", clone.BaseURL()) assertEqual(t, "parent", parent.credentials.Username) assertEqual(t, "clone", clone.credentials.Username) // assert interface/pointer type assertEqual(t, parent.Client(), clone.Client()) // assert cookies parentCookies := parent.Cookies() cloneCookies := clone.Cookies() assertEqual(t, len(parentCookies), len(cloneCookies)) for i := range parentCookies { assertEqual(t, parentCookies[i].Name, cloneCookies[i].Name) assertEqual(t, parentCookies[i].Value, cloneCookies[i].Value) } } func TestResponseBodyLimit(t *testing.T) { ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { io.CopyN(w, cryprand.Reader, 100*800) }) defer ts.Close() t.Run("client body limit", func(t *testing.T) { resBodyLimit := int64(1024) c := dcnl().SetResponseBodyLimit(resBodyLimit) assertEqual(t, resBodyLimit, c.ResponseBodyLimit()) resp, err := c.R().Get(ts.URL + "/") assertNotNil(t, err) assertErrorIs(t, 
ErrReadExceedsThresholdLimit, err) assertTrue(t, resp.Size() == resBodyLimit) }) t.Run("request body limit", func(t *testing.T) { resBodyLimit := int64(1024) c := dcnl() resp, err := c.R().SetResponseBodyLimit(resBodyLimit).Get(ts.URL + "/") assertNotNil(t, err) assertErrorIs(t, ErrReadExceedsThresholdLimit, err) assertTrue(t, resp.Size() == resBodyLimit) }) t.Run("body less than limit", func(t *testing.T) { c := dcnl() res, err := c.R().SetResponseBodyLimit(800*100 + 10).Get(ts.URL + "/") assertNil(t, err) assertEqual(t, 800*100, len(res.Bytes())) assertEqual(t, int64(800*100), res.Size()) }) t.Run("no body limit", func(t *testing.T) { c := dcnl() res, err := c.R().Get(ts.URL + "/") assertNil(t, err) assertEqual(t, 800*100, len(res.Bytes())) assertEqual(t, int64(800*100), res.Size()) }) t.Run("read error", func(t *testing.T) { tse := createTestServer(func(w http.ResponseWriter, r *http.Request) { w.Header().Set(hdrContentEncodingKey, "gzip") var buf [1024]byte w.Write(buf[:]) }) defer tse.Close() c := dcnl() _, err := c.R().SetResponseBodyLimit(10240).Get(tse.URL + "/") assertErrorIs(t, gzip.ErrHeader, err) }) } func TestClient_executeReadAllError(t *testing.T) { ts := createGetServer(t) defer ts.Close() ioReadAll = func(_ io.Reader) ([]byte, error) { return nil, errors.New("test case error") } t.Cleanup(func() { ioReadAll = io.ReadAll }) c := dcnld() resp, err := c.R(). SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)). 
Get(ts.URL + "/json") assertNotNil(t, err) assertEqual(t, "test case error", err.Error()) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "", resp.String()) } func TestClientDebugf(t *testing.T) { t.Run("Debug mode enabled", func(t *testing.T) { var b bytes.Buffer c := New().SetLogger(&logger{l: log.New(&b, "", 0)}).SetDebug(true) c.debugf("hello") assertEqual(t, "DEBUG RESTY hello\n", b.String()) }) t.Run("Debug mode disabled", func(t *testing.T) { var b bytes.Buffer c := New().SetLogger(&logger{l: log.New(&b, "", 0)}) c.debugf("hello") assertEqual(t, "", b.String()) }) } func TestClientOnClose(t *testing.T) { var hookExecuted bool c := dcnl() c.OnClose(func() { hookExecuted = true }) err := c.Close() assertNil(t, err) assertTrue(t, hookExecuted, "OnClose hook should be executed") } func TestClientOnCloseMultipleHooks(t *testing.T) { var executionOrder []string c := dcnl() c.OnClose(func() { executionOrder = append(executionOrder, "first") }) c.OnClose(func() { executionOrder = append(executionOrder, "second") }) c.OnClose(func() { executionOrder = append(executionOrder, "third") }) err := c.Close() assertNil(t, err) assertEqual(t, []string{"first", "second", "third"}, executionOrder) } func TestClientHedgingMutualExclusionWithRetry(t *testing.T) { c := dcnl() // Set retry first c.SetRetryCount(2) assertEqual(t, 2, c.RetryCount()) // Enable hedging should disable retry by default h := NewHedging(). SetDelay(50 * time.Millisecond). SetMaxRequest(3). 
SetMaxRequestPerSecond(0) c.SetHedging(h) assertEqual(t, 0, c.RetryCount()) // But user can re-enable retry as fallback c.SetRetryCount(1) assertEqual(t, 1, c.RetryCount()) assertEqual(t, true, c.isHedgingEnabled()) // Disable hedging c.SetHedging(nil) assertEqual(t, false, c.isHedgingEnabled()) assertEqual(t, 1, c.RetryCount()) // Retry count should remain } ================================================ FILE: context_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // 2016 Andrew Grigorev (https://github.com/ei-grad) // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "context" "errors" "net/http" "sync/atomic" "testing" "time" ) func TestClientSetContext(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl() assertNil(t, c.Context()) c.SetContext(context.Background()) resp, err := c.R().Get(ts.URL + "/") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, "TestGet: text response", resp.String()) logResponse(t, resp) } func TestRequestSetContext(t *testing.T) { ts := createGetServer(t) defer ts.Close() resp, err := dcnl().R(). SetContext(context.Background()). Get(ts.URL + "/") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, "TestGet: text response", resp.String()) logResponse(t, resp) } func TestSetContextWithError(t *testing.T) { ts := createGetServer(t) defer ts.Close() resp, err := dcnlr(). SetContext(context.Background()). 
Get(ts.URL + "/mypage") assertError(t, err) assertEqual(t, http.StatusBadRequest, resp.StatusCode(), "expected bad request status code") assertEqual(t, "", resp.String(), "expected empty response body on bad request") logResponse(t, resp) } func TestSetContextCancel(t *testing.T) { ch := make(chan struct{}) ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { defer func() { ch <- struct{}{} // tell test request is finished }() t.Logf("Server: %v %v", r.Method, r.URL.Path) ch <- struct{}{} <-ch // wait for client to finish request n, err := w.Write([]byte("TestSetContextCancel: response")) // FIXME? test server doesn't handle request cancellation t.Logf("Server: wrote %d bytes", n) t.Logf("Server: err is %v ", err) }) defer ts.Close() ctx, cancel := context.WithCancel(context.Background()) go func() { <-ch // wait for server to start request handling cancel() }() _, err := dcnl().R(). SetContext(ctx). Get(ts.URL + "/") ch <- struct{}{} // tell server to continue request handling <-ch // wait for server to finish request handling t.Logf("Error: %v", err) if !errIsContextCanceled(err) { t.Errorf("Got unexpected error: %v", err) } } func TestSetContextCancelRetry(t *testing.T) { reqCount := 0 ch := make(chan struct{}) ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { reqCount++ defer func() { ch <- struct{}{} // tell test request is finished }() t.Logf("Server: %v %v", r.Method, r.URL.Path) ch <- struct{}{} <-ch // wait for client to finish request n, err := w.Write([]byte("TestSetContextCancel: response")) // FIXME? test server doesn't handle request cancellation t.Logf("Server: wrote %d bytes", n) t.Logf("Server: err is %v ", err) }) defer ts.Close() ctx, cancel := context.WithCancel(context.Background()) go func() { <-ch // wait for server to start request handling cancel() }() c := dcnl(). SetTimeout(time.Second * 3). SetRetryCount(3) _, err := c.R(). SetContext(ctx). 
Get(ts.URL + "/") ch <- struct{}{} // tell server to continue request handling <-ch // wait for server to finish request handling t.Logf("Error: %v", err) if !errIsContextCanceled(err) { t.Errorf("Got unexpected error: %v", err) } if reqCount != 1 { t.Errorf("Request was retried %d times instead of 1", reqCount) } } func TestSetContextCancelWithError(t *testing.T) { ch := make(chan struct{}) ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { defer func() { ch <- struct{}{} // tell test request is finished }() t.Logf("Server: %v %v", r.Method, r.URL.Path) t.Log("Server: sending StatusBadRequest response") w.WriteHeader(http.StatusBadRequest) ch <- struct{}{} <-ch // wait for client to finish request n, err := w.Write([]byte("TestSetContextCancelWithError: response")) // FIXME? test server doesn't handle request cancellation t.Logf("Server: wrote %d bytes", n) t.Logf("Server: err is %v ", err) }) defer ts.Close() ctx, cancel := context.WithCancel(context.Background()) go func() { <-ch // wait for server to start request handling cancel() }() _, err := dcnl().R(). SetContext(ctx). Get(ts.URL + "/") ch <- struct{}{} // tell server to continue request handling <-ch // wait for server to finish request handling t.Logf("Error: %v", err) if !errIsContextCanceled(err) { t.Errorf("Got unexpected error: %v", err) } } func TestClientRetryWithSetContext(t *testing.T) { var attempt int32 ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { t.Logf("Method: %v", r.Method) t.Logf("Path: %v", r.URL.Path) if atomic.AddInt32(&attempt, 1) <= 4 { time.Sleep(100 * time.Millisecond) } _, _ = w.Write([]byte("TestClientRetry page")) }) defer ts.Close() c := dcnl(). SetTimeout(50 * time.Millisecond). SetRetryCount(3) _, err := c.R(). SetContext(context.Background()). 
Get(ts.URL + "/") assertNotNil(t, ts) assertNotNil(t, err) assertErrorIs(t, context.DeadlineExceeded, err, "expected context deadline exceeded error") } func TestRequestContext(t *testing.T) { client := dcnl() r := client.NewRequest() assertNotNil(t, r.Context(), "expected default context to be non-nil") r.SetContext(context.Background()) assertNotNil(t, r.Context(), "expected context to be set") } func errIsContextCanceled(err error) bool { return errors.Is(err, context.Canceled) } ================================================ FILE: curl.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "bytes" "io" "net/http" "regexp" "net/url" "strings" ) func buildCurlCmd(req *Request) string { // generate curl raw headers var curl = "curl -X " + req.Method + " " headers := dumpCurlHeaders(req.RawRequest) for _, kv := range *headers { curl += "-H " + cmdQuote(kv[0]+": "+kv[1]) + " " } // generate curl cookies if cookieJar := req.client.CookieJar(); cookieJar != nil { if cookies := cookieJar.Cookies(req.RawRequest.URL); len(cookies) > 0 { curl += "-H " + cmdQuote(dumpCurlCookies(cookies)) + " " } } // generate curl body except for io.Reader and multipart request flow if req.RawRequest.GetBody != nil { body, err := req.RawRequest.GetBody() if err == nil { buf, _ := io.ReadAll(body) curl += "-d " + cmdQuote(string(bytes.TrimRight(buf, "\n"))) + " " } else { req.log.Errorf("curl: %v", err) curl += "-d ''" } } urlString := cmdQuote(req.RawRequest.URL.String()) if urlString == "''" { urlString = "'http://unexecuted-request'" } curl += urlString return curl } // dumpCurlCookies dumps cookies to curl format func dumpCurlCookies(cookies []*http.Cookie) string { sb := strings.Builder{} sb.WriteString("Cookie: ") for _, cookie := range 
cookies { sb.WriteString(cookie.Name + "=" + url.QueryEscape(cookie.Value) + "&") } return strings.TrimRight(sb.String(), "&") } // dumpCurlHeaders dumps headers to curl format func dumpCurlHeaders(req *http.Request) *[][2]string { headers := [][2]string{} for k, vs := range req.Header { for _, v := range vs { headers = append(headers, [2]string{k, v}) } } n := len(headers) for i := 0; i < n; i++ { for j := n - 1; j > i; j-- { jj := j - 1 h1, h2 := headers[j], headers[jj] if h1[0] < h2[0] { headers[jj], headers[j] = headers[j], headers[jj] } } } return &headers } var regexCmdQuote = regexp.MustCompile(`[^\w@%+=:,./-]`) // cmdQuote method to escape arbitrary strings for a safe use as // command line arguments in the most common POSIX shells. // // The original Python package which this work was inspired by can be found // at https://pypi.python.org/pypi/shellescape. func cmdQuote(s string) string { if len(s) == 0 { return "''" } if regexCmdQuote.MatchString(s) { return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'" } return s } ================================================ FILE: curl_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "bytes" "errors" "io" "net/http" "net/http/cookiejar" "strings" "testing" ) func TestCurlGenerateUnexecutedRequest(t *testing.T) { req := dcnldr(). SetBody(map[string]string{ "name": "Resty", }). SetCookies( []*http.Cookie{ {Name: "count", Value: "1"}, }, ). 
SetMethod(MethodPost) assertEqual(t, "", req.CurlCmd()) curlCmdUnexecuted := req.SetCurlCmdGenerate(true).CurlCmd() req.SetCurlCmdGenerate(false) if !strings.Contains(curlCmdUnexecuted, "Cookie: count=1") || !strings.Contains(curlCmdUnexecuted, "curl -X POST") || !strings.Contains(curlCmdUnexecuted, `-d '{"name":"Resty"}'`) { t.Fatal("Incomplete curl:", curlCmdUnexecuted) } else { t.Log("curlCmdUnexecuted: \n", curlCmdUnexecuted) } } func TestCurlGenerateExecutedRequest(t *testing.T) { ts := createPostServer(t) defer ts.Close() data := map[string]string{ "name": "Resty", } c := dcnl().SetDebug(true) req := c.R(). SetBody(data). SetCookies( []*http.Cookie{ {Name: "count", Value: "1"}, }, ) url := ts.URL + "/curl-cmd-post" resp, err := req. SetCurlCmdGenerate(true). Post(url) if err != nil { t.Fatal(err) } curlCmdExecuted := resp.Request.CurlCmd() c.SetCurlCmdGenerate(false) req.SetCurlCmdGenerate(false) if !strings.Contains(curlCmdExecuted, "Cookie: count=1") || !strings.Contains(curlCmdExecuted, "curl -X POST") || !strings.Contains(curlCmdExecuted, `-d '{"name":"Resty"}'`) || !strings.Contains(curlCmdExecuted, url) { t.Fatal("Incomplete curl:", curlCmdExecuted) } else { t.Log("curlCmdExecuted: \n", curlCmdExecuted) } } func TestCurlCmdDebugMode(t *testing.T) { ts := createPostServer(t) defer ts.Close() c, logBuf := dcldb() c.SetCurlCmdGenerate(true). SetCurlCmdDebugLog(true) // Build request req := c.R(). SetBody(map[string]string{ "name": "Resty", }). SetCookies( []*http.Cookie{ {Name: "count", Value: "1"}, }, ). 
SetCurlCmdDebugLog(true) // Execute request: set debug mode url := ts.URL + "/curl-cmd-post" _, err := req.SetDebug(true).Post(url) if err != nil { t.Fatal(err) } c.SetCurlCmdGenerate(false) req.SetCurlCmdGenerate(false) // test logContent curl cmd logContent := logBuf.String() if !strings.Contains(logContent, "Cookie: count=1") || !strings.Contains(logContent, `-d '{"name":"Resty"}'`) { t.Fatal("Incomplete debug curl info:", logContent) } } func TestCurl_buildCurlCmd(t *testing.T) { tests := []struct { name string method string url string headers map[string]string body string cookies []*http.Cookie expected string }{ { name: "With Headers", method: "GET", url: "http://example.com", headers: map[string]string{"Content-Type": "application/json", "Authorization": "Bearer token"}, expected: "curl -X GET -H 'Authorization: Bearer token' -H 'Content-Type: application/json' http://example.com", }, { name: "With Body", method: "POST", url: "http://example.com", headers: map[string]string{"Content-Type": "application/json"}, body: `{"key":"value"}`, expected: "curl -X POST -H 'Content-Type: application/json' -d '{\"key\":\"value\"}' http://example.com", }, { name: "With Empty Body", method: "POST", url: "http://example.com", headers: map[string]string{"Content-Type": "application/json"}, expected: "curl -X POST -H 'Content-Type: application/json' http://example.com", }, { name: "With Query Params", method: "GET", url: "http://example.com?param1=value1¶m2=value2", expected: "curl -X GET 'http://example.com?param1=value1¶m2=value2'", }, { name: "With Special Characters in URL", method: "GET", url: "http://example.com/path with spaces", expected: "curl -X GET http://example.com/path%20with%20spaces", }, { name: "With Cookies", method: "GET", url: "http://example.com", cookies: []*http.Cookie{{Name: "session_id", Value: "abc123"}}, expected: "curl -X GET -H 'Cookie: session_id=abc123' http://example.com", }, { name: "Without Cookies", method: "GET", url: "http://example.com", 
expected: "curl -X GET http://example.com", }, { name: "With Multiple Cookies", method: "GET", url: "http://example.com", cookies: []*http.Cookie{{Name: "session_id", Value: "abc123"}, {Name: "user_id", Value: "user456"}}, expected: "curl -X GET -H 'Cookie: session_id=abc123&user_id=user456' http://example.com", }, { name: "With Empty Cookie Jar", method: "GET", url: "http://example.com", expected: "curl -X GET http://example.com", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := dcnl() req := c.R().SetMethod(tt.method).SetURL(tt.url) if !isStringEmpty(tt.body) { req.SetBody(bytes.NewBufferString(tt.body)) } for k, v := range tt.headers { req.SetHeader(k, v) } err := createRawRequest(c, req) assertNil(t, err) if len(tt.cookies) > 0 { cookieJar, _ := cookiejar.New(nil) cookieJar.SetCookies(req.RawRequest.URL, tt.cookies) c.SetCookieJar(cookieJar) } curlCmd := buildCurlCmd(req) assertEqual(t, tt.expected, curlCmd) }) } } func TestCurlRequestGetBodyError(t *testing.T) { c := dcnl(). SetDebug(true). SetRequestMiddlewares( MiddlewareRequestCreate, func(_ *Client, r *Request) error { r.RawRequest.GetBody = func() (io.ReadCloser, error) { return nil, errors.New("test case error") } return nil }, ) req := c.R(). SetBody(map[string]string{ "name": "Resty", }). SetCookies( []*http.Cookie{ {Name: "count", Value: "1"}, }, ). SetMethod(MethodPost) assertEqual(t, "", req.CurlCmd()) curlCmdUnexecuted := req.SetCurlCmdGenerate(true).CurlCmd() req.SetCurlCmdGenerate(false) if !strings.Contains(curlCmdUnexecuted, "Cookie: count=1") || !strings.Contains(curlCmdUnexecuted, "curl -X POST") || !strings.Contains(curlCmdUnexecuted, `-d ''`) { t.Fatal("Incomplete curl:", curlCmdUnexecuted) } else { t.Log("curlCmdUnexecuted: \n", curlCmdUnexecuted) } } func TestCurlRequestMiddlewaresError(t *testing.T) { errMsg := "middleware error" c := dcnl().SetDebug(true). 
SetRequestMiddlewares( func(c *Client, r *Request) error { return errors.New(errMsg) }, MiddlewareRequestCreate, ) curlCmdUnexecuted := c.R().SetCurlCmdGenerate(true).CurlCmd() assertEqual(t, "", curlCmdUnexecuted) } func TestCurlMiscTestCoverage(t *testing.T) { cookieStr := dumpCurlCookies([]*http.Cookie{ {Name: "count", Value: "1"}, }) assertEqual(t, "Cookie: count=1", cookieStr) } ================================================ FILE: debug.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "fmt" "net/http" "time" ) type ( // DebugLogCallbackFunc function type is for request and response debug log callback purposes. // It gets called before Resty logs it DebugLogCallbackFunc func(*DebugLog) // DebugLogFormatterFunc function type is used to implement debug log formatting. // See out of the box [DebugLogStringFormatter], [DebugLogJSONFormatter] DebugLogFormatterFunc func(*DebugLog) string // DebugLog struct is used to collect details from Resty request and response // for debug logging callback purposes. DebugLog struct { Request *DebugLogRequest `json:"request"` Response *DebugLogResponse `json:"response"` TraceInfo *TraceInfo `json:"trace_info"` } // DebugLogRequest type used to capture debug info about the [Request]. DebugLogRequest struct { CorrelationID string `json:"correlation_id"` Host string `json:"host"` URI string `json:"uri"` Method string `json:"method"` Proto string `json:"proto"` Header http.Header `json:"header"` CurlCmd string `json:"curl_cmd"` Attempt int `json:"attempt"` Body string `json:"body"` } // DebugLogResponse type used to capture debug info about the [Response]. 
DebugLogResponse struct { StatusCode int `json:"status_code"` Status string `json:"status"` Proto string `json:"proto"` ReceivedAt time.Time `json:"received_at"` Duration time.Duration `json:"duration"` Size int64 `json:"size"` Header http.Header `json:"header"` Body string `json:"body"` } ) // DebugLogFormatter function formats the given debug log info in human readable // format. // // This is the default debug log formatter in the Resty. func DebugLogFormatter(dl *DebugLog) string { debugLog := "\n==============================================================================\n" req := dl.Request if len(req.CurlCmd) > 0 { debugLog += "~~~ REQUEST(CURL) ~~~\n" + fmt.Sprintf(" %v\n", req.CurlCmd) } debugLog += "~~~ REQUEST ~~~\n" + fmt.Sprintf("CORRELATION ID: %s\n", req.CorrelationID) + fmt.Sprintf("%s %s %s\n", req.Method, req.URI, req.Proto) + fmt.Sprintf("HOST : %s\n", req.Host) + fmt.Sprintf("HEADERS:\n%s\n", composeHeaders(req.Header)) + fmt.Sprintf("BODY :\n%v\n", req.Body) + fmt.Sprintf("ATTEMPT : %d\n", req.Attempt) + "------------------------------------------------------------------------------\n" res := dl.Response debugLog += "~~~ RESPONSE ~~~\n" + fmt.Sprintf("STATUS : %s\n", res.Status) + fmt.Sprintf("PROTO : %s\n", res.Proto) + fmt.Sprintf("RECEIVED AT : %v\n", res.ReceivedAt.Format(time.RFC3339Nano)) + fmt.Sprintf("DURATION : %v\n", res.Duration) + "HEADERS :\n" + composeHeaders(res.Header) + "\n" + fmt.Sprintf("BODY :\n%v\n", res.Body) if dl.TraceInfo != nil { debugLog += "------------------------------------------------------------------------------\n" debugLog += fmt.Sprintf("%v\n", dl.TraceInfo) } debugLog += "==============================================================================\n" return debugLog } // DebugLogJSONFormatter function formats the given debug log info in JSON format. 
func DebugLogJSONFormatter(dl *DebugLog) string { return toJSON(dl) } func debugLogger(c *Client, res *Response) { req := res.Request if !req.IsDebug { return } rdl := &DebugLogResponse{ StatusCode: res.StatusCode(), Status: res.Status(), Proto: res.Proto(), ReceivedAt: res.ReceivedAt(), Duration: res.Duration(), Size: res.Size(), Header: sanitizeHeaders(res.Header().Clone()), Body: res.fmtBodyString(res.Request.DebugBodyLimit), } dl := &DebugLog{ Request: req.values[debugRequestLogKey].(*DebugLogRequest), Response: rdl, } if res.Request.IsTrace { ti := req.TraceInfo() dl.TraceInfo = &ti } dblCallback := c.debugLogCallbackFunc() if dblCallback != nil { dblCallback(dl) } formatterFunc := c.debugLogFormatterFunc() if formatterFunc != nil { debugLog := formatterFunc(dl) req.log.Debugf("%s", debugLog) } } const debugRequestLogKey = "__restyDebugRequestLog" func prepareRequestDebugInfo(c *Client, r *Request) { if !r.IsDebug { return } rr := r.RawRequest rh := rr.Header.Clone() if c.Client().Jar != nil { for _, cookie := range c.Client().Jar.Cookies(r.RawRequest.URL) { s := fmt.Sprintf("%s=%s", cookie.Name, cookie.Value) if c := rh.Get(hdrCookieKey); isStringEmpty(c) { rh.Set(hdrCookieKey, s) } else { rh.Set(hdrCookieKey, c+"; "+s) } } } rdl := &DebugLogRequest{ CorrelationID: r.CorrelationID, Host: rr.URL.Host, URI: rr.URL.RequestURI(), Method: r.Method, Proto: rr.Proto, Header: sanitizeHeaders(rh), Attempt: r.Attempt, Body: r.fmtBodyString(r.DebugBodyLimit), } if r.isCurlCmdGenerate && r.isCurlCmdDebugLog { rdl.CurlCmd = r.curlCmdString } r.initValuesMap() r.values[debugRequestLogKey] = rdl } ================================================ FILE: digest.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. 
// 2023 Segev Dagan (https://github.com/segevda) // 2024 Philipp Wolfer (https://github.com/phw) // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "bytes" "crypto/md5" "crypto/rand" "crypto/sha256" "crypto/sha512" "encoding/hex" "errors" "fmt" "hash" "io" "net/http" "strconv" "strings" ) var ( ErrDigestBadChallenge = errors.New("resty: digest: challenge is bad") ErrDigestInvalidCharset = errors.New("resty: digest: invalid charset") ErrDigestAlgNotSupported = errors.New("resty: digest: algorithm is not supported") ErrDigestQopNotSupported = errors.New("resty: digest: qop is not supported") ) // Reference: https://datatracker.ietf.org/doc/html/rfc7616#section-6.1 var digestHashFuncs = map[string]func() hash.Hash{ "": md5.New, "MD5": md5.New, "MD5-sess": md5.New, "SHA-256": sha256.New, "SHA-256-sess": sha256.New, "SHA-512": sha512.New, "SHA-512-sess": sha512.New, "SHA-512-256": sha512.New512_256, "SHA-512-256-sess": sha512.New512_256, } const ( qopAuth = "auth" qopAuthInt = "auth-int" ) type digestTransport struct { *credentials transport http.RoundTripper } func (dt *digestTransport) RoundTrip(req *http.Request) (*http.Response, error) { // first request without body for all HTTP verbs req1 := dt.cloneReq(req, true) // make a request to get the 401 that contains the challenge. 
res, err := dt.transport.RoundTrip(req1) if err != nil { return nil, err } if res.StatusCode != http.StatusUnauthorized { return res, nil } _, _ = ioCopy(io.Discard, res.Body) closeq(res.Body) chaHdrValue := strings.TrimSpace(res.Header.Get(hdrWwwAuthenticateKey)) if chaHdrValue == "" { return nil, ErrDigestBadChallenge } cha, err := dt.parseChallenge(chaHdrValue) if err != nil { return nil, err } // prepare second request req2 := dt.cloneReq(req, false) cred, err := dt.createCredentials(cha, req2) if err != nil { return nil, err } auth, err := cred.digest(cha) if err != nil { return nil, err } req2.Header.Set(hdrAuthorizationKey, auth) return dt.transport.RoundTrip(req2) } func (dt *digestTransport) cloneReq(r *http.Request, first bool) *http.Request { r1 := r.Clone(r.Context()) if first { r1.Body = http.NoBody r1.ContentLength = 0 r1.GetBody = nil } return r1 } func (dt *digestTransport) parseChallenge(input string) (*digestChallenge, error) { const ws = " \n\r\t" s := strings.Trim(input, ws) if !strings.HasPrefix(s, "Digest ") { return nil, ErrDigestBadChallenge } s = strings.Trim(s[7:], ws) c := &digestChallenge{} b := strings.Builder{} key := "" quoted := false for _, r := range s { switch r { case '"': quoted = !quoted case ',': if quoted { b.WriteRune(r) } else { val := strings.Trim(b.String(), ws) b.Reset() if err := c.setValue(key, val); err != nil { return nil, err } key = "" } case '=': if quoted { b.WriteRune(r) } else { key = strings.Trim(b.String(), ws) b.Reset() } default: b.WriteRune(r) } } key = strings.TrimSpace(key) if quoted || (key == "" && b.Len() > 0) { return nil, ErrDigestBadChallenge } if key != "" { val := strings.Trim(b.String(), ws) if err := c.setValue(key, val); err != nil { return nil, err } } return c, nil } func (dt *digestTransport) createCredentials(cha *digestChallenge, req *http.Request) (*digestCredentials, error) { cred := &digestCredentials{ username: dt.Username, password: dt.Password, uri: req.URL.RequestURI(), method: 
req.Method, realm: cha.realm, nonce: cha.nonce, nc: cha.nc, algorithm: cha.algorithm, sessAlgorithm: strings.HasSuffix(cha.algorithm, "-sess"), opaque: cha.opaque, userHash: cha.userHash, } if cha.isQopSupported(qopAuthInt) { if err := dt.prepareBody(req); err != nil { return nil, fmt.Errorf("resty: digest: failed to prepare body for auth-int: %w", err) } body, err := req.GetBody() if err != nil { return nil, fmt.Errorf("resty: digest: failed to get body for auth-int: %w", err) } if body != http.NoBody { defer closeq(body) h := newHashFunc(cha.algorithm) if _, err := ioCopy(h, body); err != nil { return nil, err } cred.bodyHash = hex.EncodeToString(h.Sum(nil)) } } return cred, nil } func (dt *digestTransport) prepareBody(req *http.Request) error { if req.GetBody != nil { return nil } if req.Body == nil || req.Body == http.NoBody { req.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil } return nil } b, err := ioReadAll(req.Body) if err != nil { return err } closeq(req.Body) req.Body = io.NopCloser(bytes.NewReader(b)) req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(b)), nil } return nil } type digestChallenge struct { realm string domain string nonce string opaque string stale string algorithm string qop []string nc int userHash string } func (dc *digestChallenge) isQopSupported(qop string) bool { for _, v := range dc.qop { if v == qop { return true } } return false } func (dc *digestChallenge) setValue(k, v string) error { switch k { case "realm": dc.realm = v case "domain": dc.domain = v case "nonce": dc.nonce = v case "opaque": dc.opaque = v case "stale": dc.stale = v case "algorithm": dc.algorithm = v case "qop": if !isStringEmpty(v) { dc.qop = strings.Split(v, ",") } case "charset": if strings.ToUpper(v) != "UTF-8" { return ErrDigestInvalidCharset } case "nc": nc, err := strconv.ParseInt(v, 16, 32) if err != nil { return fmt.Errorf("resty: digest: invalid nc: %w", err) } dc.nc = int(nc) case "userhash": 
dc.userHash = v default: return ErrDigestBadChallenge } return nil } type digestCredentials struct { username string password string userHash string method string uri string realm string nonce string algorithm string sessAlgorithm bool cnonce string opaque string qop string nc int response string bodyHash string } func (dc *digestCredentials) parseQop(cha *digestChallenge) error { if len(cha.qop) == 0 { return nil } if cha.isQopSupported(qopAuth) { dc.qop = qopAuth return nil } if cha.isQopSupported(qopAuthInt) { dc.qop = qopAuthInt return nil } return ErrDigestQopNotSupported } func (dc *digestCredentials) h(data string) string { h := newHashFunc(dc.algorithm) _, _ = h.Write([]byte(data)) return hex.EncodeToString(h.Sum(nil)) } func (dc *digestCredentials) digest(cha *digestChallenge) (string, error) { if _, ok := digestHashFuncs[dc.algorithm]; !ok { return "", ErrDigestAlgNotSupported } if err := dc.parseQop(cha); err != nil { return "", err } dc.nc++ b := make([]byte, 16) _, _ = io.ReadFull(rand.Reader, b) dc.cnonce = hex.EncodeToString(b) ha1 := dc.ha1() ha2 := dc.ha2() var resp string switch dc.qop { case "": resp = fmt.Sprintf("%s:%s:%s", ha1, dc.nonce, ha2) case qopAuth, qopAuthInt: resp = fmt.Sprintf("%s:%s:%08x:%s:%s:%s", ha1, dc.nonce, dc.nc, dc.cnonce, dc.qop, ha2) } dc.response = dc.h(resp) return "Digest " + dc.String(), nil } // https://datatracker.ietf.org/doc/html/rfc7616#section-3.4.2 func (dc *digestCredentials) ha1() string { a1 := dc.h(fmt.Sprintf("%s:%s:%s", dc.username, dc.realm, dc.password)) if dc.sessAlgorithm { return dc.h(fmt.Sprintf("%s:%s:%s", a1, dc.nonce, dc.cnonce)) } return a1 } // https://datatracker.ietf.org/doc/html/rfc7616#section-3.4.3 func (dc *digestCredentials) ha2() string { if dc.qop == qopAuthInt { return dc.h(fmt.Sprintf("%s:%s:%s", dc.method, dc.uri, dc.bodyHash)) } return dc.h(fmt.Sprintf("%s:%s", dc.method, dc.uri)) } func (dc *digestCredentials) String() string { sl := make([]string, 0, 10) // 
https://datatracker.ietf.org/doc/html/rfc7616#section-3.4.4 if dc.userHash == "true" { dc.username = dc.h(fmt.Sprintf("%s:%s", dc.username, dc.realm)) } sl = append(sl, fmt.Sprintf(`username="%s"`, dc.username)) sl = append(sl, fmt.Sprintf(`realm="%s"`, dc.realm)) sl = append(sl, fmt.Sprintf(`nonce="%s"`, dc.nonce)) sl = append(sl, fmt.Sprintf(`uri="%s"`, dc.uri)) if dc.algorithm != "" { sl = append(sl, fmt.Sprintf(`algorithm=%s`, dc.algorithm)) } if dc.opaque != "" { sl = append(sl, fmt.Sprintf(`opaque="%s"`, dc.opaque)) } if dc.qop != "" { sl = append(sl, fmt.Sprintf("qop=%s", dc.qop)) sl = append(sl, fmt.Sprintf("nc=%08x", dc.nc)) sl = append(sl, fmt.Sprintf(`cnonce="%s"`, dc.cnonce)) } sl = append(sl, fmt.Sprintf(`userhash=%s`, dc.userHash)) sl = append(sl, fmt.Sprintf(`response="%s"`, dc.response)) return strings.Join(sl, ", ") } func newHashFunc(algorithm string) hash.Hash { hf := digestHashFuncs[algorithm] h := hf() h.Reset() return h } ================================================ FILE: digest_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "errors" "io" "net/http" "strings" "testing" ) type digestServerConfig struct { realm, qop, nonce, opaque, algo, uri, charset, username, password, nc string } func defaultDigestServerConf() *digestServerConfig { return &digestServerConfig{ realm: "testrealm@host.com", qop: "auth", nonce: "dcd98b7102dd2f0e8b11d0f600bfb0c093", opaque: "5ccc069c403ebaf9f0171e9517f40e41", algo: "MD5", uri: "/dir/index.html", charset: "utf-8", username: "Mufasa", password: "Circle Of Life", nc: "00000001", } } func TestClientDigestAuth(t *testing.T) { conf := *defaultDigestServerConf() ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl(). SetBaseURL(ts.URL+"/"). 
SetDigestAuth(conf.username, conf.password) resp, err := c.R(). SetResult(&AuthSuccess{}). Get(conf.uri) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) } func TestClientDigestAuthSession(t *testing.T) { conf := *defaultDigestServerConf() conf.algo = "MD5-sess" conf.qop = "auth, auth-int" ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl(). SetBaseURL(ts.URL+"/"). SetDigestAuth(conf.username, conf.password) resp, err := c.R(). SetResult(&AuthSuccess{}). Get(conf.uri) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) } func TestClientDigestAuthErrors(t *testing.T) { type test struct { mutateConf func(*digestServerConfig) expect error } tests := []test{ {mutateConf: func(c *digestServerConfig) { c.algo = "BAD_ALGO" }, expect: ErrDigestAlgNotSupported}, {mutateConf: func(c *digestServerConfig) { c.qop = "bad-qop" }, expect: ErrDigestQopNotSupported}, {mutateConf: func(c *digestServerConfig) { c.charset = "utf-16" }, expect: ErrDigestInvalidCharset}, {mutateConf: func(c *digestServerConfig) { c.uri = "/bad" }, expect: ErrDigestBadChallenge}, {mutateConf: func(c *digestServerConfig) { c.uri = "/unknown_param" }, expect: ErrDigestBadChallenge}, {mutateConf: func(c *digestServerConfig) { c.uri = "/missing_value" }, expect: ErrDigestBadChallenge}, {mutateConf: func(c *digestServerConfig) { c.uri = "/unclosed_quote" }, expect: ErrDigestBadChallenge}, {mutateConf: func(c *digestServerConfig) { c.uri = "/no_challenge" }, expect: ErrDigestBadChallenge}, {mutateConf: func(c *digestServerConfig) { c.uri = "/status_500" }, expect: nil}, } for _, tc := range tests { conf := *defaultDigestServerConf() tc.mutateConf(&conf) ts := createDigestServer(t, &conf) c := dcnl(). SetBaseURL(ts.URL+"/"). 
SetDigestAuth(conf.username, conf.password) _, err := c.R().Get(conf.uri) assertErrorIs(t, tc.expect, err) ts.Close() } } func TestClientDigestAuthWithBody(t *testing.T) { conf := *defaultDigestServerConf() ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl().SetDigestAuth(conf.username, conf.password) resp, err := c.R(). SetResult(&AuthSuccess{}). SetHeader(hdrContentTypeKey, "application/json"). SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). Post(ts.URL + conf.uri) resObj := resp.Result().(*AuthSuccess) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, resObj.ID, "success") assertEqual(t, resObj.Message, "login successful") } func TestClientDigestAuthWithBodyQopAuthInt(t *testing.T) { conf := *defaultDigestServerConf() conf.qop = "auth-int" ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl().SetDigestAuth(conf.username, conf.password) resp, err := c.R(). SetResult(&AuthSuccess{}). SetHeader(hdrContentTypeKey, "application/json"). SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). Post(ts.URL + conf.uri) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) } func TestClientDigestAuthWithBodyQopAuthIntIoCopyError(t *testing.T) { conf := *defaultDigestServerConf() conf.qop = "auth-int" ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl().SetDigestAuth(conf.username, conf.password) errCopyMsg := "test copy error" ioCopy = func(dst io.Writer, src io.Reader) (written int64, err error) { return 0, errors.New(errCopyMsg) } t.Cleanup(func() { ioCopy = io.Copy }) resp, err := c.R(). SetResult(&AuthSuccess{}). SetHeader(hdrContentTypeKey, "application/json"). SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). 
Post(ts.URL + conf.uri) assertNotNil(t, err) assertTrue(t, strings.Contains(err.Error(), errCopyMsg), "expected io copy error") assertEqual(t, 0, resp.StatusCode(), "expected response status code to be zero on error") } func TestClientDigestAuthRoundTripError(t *testing.T) { conf := *defaultDigestServerConf() ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl().SetTransport(&CustomRoundTripper2{returnErr: true}) c.SetDigestAuth(conf.username, conf.password) _, err := c.R(). SetResult(&AuthSuccess{}). SetHeader(hdrContentTypeKey, "application/json"). SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). Post(ts.URL + conf.uri) assertNotNil(t, err) assertTrue(t, strings.Contains(err.Error(), "test req mock error"), "expected round trip error") } func TestClientDigestAuthWithBodyQopAuthIntGetBodyNil(t *testing.T) { conf := *defaultDigestServerConf() conf.qop = "auth-int" ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl().SetDigestAuth(conf.username, conf.password) c.SetRequestMiddlewares( MiddlewareRequestCreate, func(c *Client, r *Request) error { r.RawRequest.GetBody = nil return nil }, ) resp, err := c.R(). SetResult(&AuthSuccess{}). SetHeader(hdrContentTypeKey, "application/json"). SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). Post(ts.URL + conf.uri) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) } func TestClientDigestAuthWithGetBodyError(t *testing.T) { conf := *defaultDigestServerConf() conf.qop = "auth-int" ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl().SetDigestAuth(conf.username, conf.password) c.SetRequestMiddlewares( MiddlewareRequestCreate, func(c *Client, r *Request) error { r.RawRequest.GetBody = func() (_ io.ReadCloser, _ error) { return nil, errors.New("get body test error") } return nil }, ) resp, err := c.R(). SetResult(&AuthSuccess{}). SetHeader(hdrContentTypeKey, "application/json"). 
SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). Post(ts.URL + conf.uri) assertNotNil(t, err) assertTrue(t, strings.Contains(err.Error(), "resty: digest: failed to get body for auth-int: get body test error"), "expected get body error") assertEqual(t, 0, resp.StatusCode(), "expected response status code to be zero on error") } func TestClientDigestAuthWithGetBodyNilReadError(t *testing.T) { conf := *defaultDigestServerConf() conf.qop = "auth-int" ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl().SetDigestAuth(conf.username, conf.password) c.SetRequestMiddlewares( MiddlewareRequestCreate, func(c *Client, r *Request) error { r.RawRequest.GetBody = nil return nil }, ) resp, err := c.R(). SetResult(&AuthSuccess{}). SetHeader(hdrContentTypeKey, "application/json"). SetBody(&brokenReadCloser{}). Post(ts.URL + conf.uri) assertNotNil(t, err) assertTrue(t, strings.Contains(err.Error(), "resty: digest: failed to prepare body for auth-int: read error"), "expected read error") assertEqual(t, 0, resp.StatusCode(), "expected response status code to be zero on error") } func TestClientDigestAuthWithNoBodyQopAuthInt(t *testing.T) { conf := *defaultDigestServerConf() conf.qop = "auth-int" ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl().SetDigestAuth(conf.username, conf.password) resp, err := c.R().Get(ts.URL + conf.uri) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) } func TestClientDigestAuthNoQop(t *testing.T) { conf := *defaultDigestServerConf() conf.qop = "" ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl().SetDigestAuth(conf.username, conf.password) resp, err := c.R(). SetResult(&AuthSuccess{}). SetHeader(hdrContentTypeKey, "application/json"). SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). 
Post(ts.URL + conf.uri) assertNil(t, err) assertEqual(t, "200 OK", resp.Status()) } func TestClientDigestAuthWithIncorrectNcValue(t *testing.T) { conf := *defaultDigestServerConf() conf.nc = "1234567890" ts := createDigestServer(t, &conf) defer ts.Close() c := dcnl().SetDigestAuth(conf.username, conf.password) resp, err := c.R(). SetResult(&AuthSuccess{}). SetHeader(hdrContentTypeKey, "application/json"). SetBody(map[string]any{"zip_code": "00000", "city": "Los Angeles"}). Post(ts.URL + conf.uri) assertNotNil(t, err) assertTrue(t, strings.Contains(err.Error(), `parsing "1234567890": value out of range`), "expected nc value out of range error") assertEqual(t, "", resp.Status(), "expected empty response status on error") } ================================================ FILE: go.mod ================================================ module resty.dev/v3 go 1.23.0 require golang.org/x/net v0.43.0 ================================================ FILE: go.sum ================================================ golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= ================================================ FILE: hedging.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // 2025 Ahmet Demir (https://github.com/ahmet2mir) // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty // This hedging implementation draws inspiration from the reference provided here: https://github.com/cristalhq/hedgedhttp. import ( "context" "net/http" "sync" "time" ) // NewHedging creates a new Hedging instance with default configuration. 
// NewHedging creates a new Hedging instance with default configuration.
// The default values are:
//   - 50ms delay between requests
//   - Maximum 3 requests
//   - Maximum 3 requests per second
//   - Only read-only methods are hedged
//
// You can customize these settings using the corresponding setter methods.
// For example:
//
//	hedging := NewHedging().
//		SetDelay(100 * time.Millisecond).
//		SetMaxRequest(5).
//		SetMaxRequestPerSecond(10)
//
//	// Assign the hedging instance to the Resty client
//	client := resty.New().
//		SetHedging(hedging)
//
//	defer client.Close()
func NewHedging() *Hedging {
	h := &Hedging{
		lock:                 new(sync.RWMutex),
		delay:                50 * time.Millisecond, // delay between requests
		maxRequest:           3,                     // max requests
		maxRequestPerSecond:  3,                     // max requests per second
		isNonReadOnlyAllowed: false,                 // only hedge read-only methods by default
	}
	// Derive rateDelay from the default maxRequestPerSecond.
	h.calculateRateDelay()
	return h
}

// Hedging struct implements the http.RoundTripper interface to perform hedged HTTP requests.
// It sends multiple requests in parallel, staggered by a configurable delay, and returns the
// result of whichever request completes first (response or error). Hedging is particularly
// useful for improving latency and reliability in scenarios where requests may occasionally
// fail or experience high latency.
//
// By default only read-only HTTP methods (GET, HEAD, OPTIONS, TRACE) are hedged to avoid unintended
// side effects on the server, unless [Hedging.SetNonReadOnlyAllowed] is used to allow non-read-only
// methods, in which case all HTTP methods will be hedged.
//
// NOTE:
//   - Hedging should be used with caution, especially for non-read-only methods, as it can lead to
//     unintended consequences if multiple requests are processed by the server.
//   - Ensure that the server can safely handle multiple concurrent requests when using hedging,
//     as otherwise, hedging requests can overwhelm the server.
//
// For more information on hedging and its use cases, refer to the following resources:
//   - [The Tail at Scale]
//
// [The Tail at Scale]: https://research.google/pubs/the-tail-at-scale/
type Hedging struct {
	lock                 *sync.RWMutex     // guards the configuration fields below
	transport            http.RoundTripper // underlying transport that performs each request
	delay                time.Duration     // delay between hedged requests
	maxRequest           int               // maximum number of requests per round trip
	maxRequestPerSecond  float64           // rate limit used to derive rateDelay
	rateDelay            time.Duration     // delay between requests based on maxPerSecond
	isNonReadOnlyAllowed bool              // whether non-read-only methods are hedged
}

// Delay method returns the configured hedging delay.
func (h *Hedging) Delay() time.Duration {
	h.lock.RLock()
	defer h.lock.RUnlock()
	return h.delay
}

// SetDelay method sets the delay between hedged requests.
// It returns the receiver to allow call chaining.
func (h *Hedging) SetDelay(delay time.Duration) *Hedging {
	h.lock.Lock()
	defer h.lock.Unlock()
	h.delay = delay
	return h
}

// MaxRequest method returns the maximum concurrent requests.
func (h *Hedging) MaxRequest() int {
	h.lock.RLock()
	defer h.lock.RUnlock()
	return h.maxRequest
}

// SetMaxRequest method sets maximum concurrent hedged requests.
// A value of 1 or less makes RoundTrip send a single, non-hedged request.
// It returns the receiver to allow call chaining.
func (h *Hedging) SetMaxRequest(count int) *Hedging {
	h.lock.Lock()
	defer h.lock.Unlock()
	h.maxRequest = count
	return h
}

// MaxRequestPerSecond method returns the hedging rate limit.
func (h *Hedging) MaxRequestPerSecond() float64 {
	h.lock.RLock()
	defer h.lock.RUnlock()
	return h.maxRequestPerSecond
}

// SetMaxRequestPerSecond method sets rate limit for hedged requests and
// recomputes the derived delay between requests. A value of 0 or less
// disables the rate delay. It returns the receiver to allow call chaining.
func (h *Hedging) SetMaxRequestPerSecond(count float64) *Hedging {
	h.lock.Lock()
	defer h.lock.Unlock()
	h.maxRequestPerSecond = count
	h.calculateRateDelay()
	return h
}

// IsNonReadOnlyAllowed method returns true if hedging is enabled for non-read-only
// HTTP methods.
func (h *Hedging) IsNonReadOnlyAllowed() bool {
	h.lock.RLock()
	defer h.lock.RUnlock()
	return h.isNonReadOnlyAllowed
}
func (h *Hedging) SetNonReadOnlyAllowed(allow bool) *Hedging { h.lock.Lock() defer h.lock.Unlock() h.isNonReadOnlyAllowed = allow return h } // calculateRateDelay method calculates the delay between requests based on the maxPerSecond setting. // If maxPerSecond is greater than 0, it sets rateDelay to 1 second divided by maxPerSecond. // Otherwise, it sets rateDelay to 0 (no delay). // // NOTE: It should be called within lock region. func (h *Hedging) calculateRateDelay() { if h.maxRequestPerSecond > 0 { // Calculate rate delay: if maxPerSecond is 10, delay is 100ms (1s / 10) h.rateDelay = time.Duration(float64(time.Second) / h.maxRequestPerSecond) } else { h.rateDelay = 0 // no delay if maxPerSecond is 0 or negative } } func (ht *Hedging) RoundTrip(req *http.Request) (*http.Response, error) { if !ht.isNonReadOnlyAllowed && !isReadOnlyMethod(req.Method) { return ht.transport.RoundTrip(req) } if ht.MaxRequest() <= 1 { return ht.transport.RoundTrip(req) } ctx := req.Context() deadline, hasDeadline := ctx.Deadline() // Derive hedgeCtx from the original request context to respect cancellations var ( hedgeCtx context.Context cancel context.CancelFunc ) if hasDeadline { // Use original deadline for the race (first to complete wins) remaining := time.Until(deadline) if remaining > 0 { hedgeCtx, cancel = context.WithTimeout(ctx, remaining) } else { // Deadline already expired, use context with cancel hedgeCtx, cancel = context.WithCancel(ctx) } } else { // No deadline in original context, create cancellable context from it hedgeCtx, cancel = context.WithCancel(ctx) } // defer cancel() ensures cleanup on all paths (timeout, cancellation, or normal return) // cancel() may also be called inside once.Do() when a request wins, but calling it // multiple times is safe and ensures the context is canceled as soon as any goroutine completes defer cancel() type result struct { resp *http.Response err error } ht.lock.RLock() maxReq := ht.maxRequest delay := ht.delay rateDelay := 
ht.rateDelay ht.lock.RUnlock() resultCh := make(chan result, maxReq) var once sync.Once for i := range maxReq { if i > 0 { if delay > 0 { select { case <-time.After(delay): case <-hedgeCtx.Done(): break } } // Rate limiting: add delay between requests based on maxPerSecond // to prevent overwhelming the server. if rateDelay > 0 { select { case <-time.After(rateDelay): case <-hedgeCtx.Done(): break } } } go func() { hedgedReq := req.Clone(hedgeCtx) resp, err := ht.transport.RoundTrip(hedgedReq) won := false once.Do(func() { won = true resultCh <- result{resp: resp, err: err} // Cancel inside once.Do() to stop other goroutines immediately when a request wins // defer cancel() ensures cleanup even if no request completes successfully cancel() }) if !won && resp != nil && resp.Body != nil { drainReadCloser(resp.Body) } }() } res := <-resultCh close(resultCh) return res.resp, res.err } // isReadOnlyMethod verifies if the HTTP method is read-only (safe for hedging) func isReadOnlyMethod(method string) bool { switch method { case MethodGet, MethodHead, MethodOptions, MethodTrace: return true default: return false } } ================================================ FILE: hedging_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. 
// SPDX-License-Identifier: MIT package resty import ( "context" "fmt" "net/http" "net/http/httptest" "sync" "sync/atomic" "testing" "time" ) func createHedgingTestServer(t *testing.T, attemptCount *int32) *httptest.Server { timeouts := [5]time.Duration{800 * time.Millisecond, 400 * time.Millisecond, 10 * time.Millisecond, 5 * time.Millisecond, 1 * time.Millisecond} return createTestServer(func(w http.ResponseWriter, r *http.Request) { attempt := atomic.AddInt32(attemptCount, 1) time.Sleep(timeouts[attempt-1]) w.Header().Set("X-Attempt", fmt.Sprintf("%d", attempt)) _, _ = fmt.Fprintf(w, "Attempt %d", attempt) }) } func TestHedgingBasic(t *testing.T) { var attemptCount int32 ts := createHedgingTestServer(t, &attemptCount) defer ts.Close() const maxRequests = 3 h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(3). SetMaxRequestPerSecond(0) c := dcnl().SetHedging(h) resp, err := c.R().Get(ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, int32(maxRequests), atomic.LoadInt32(&attemptCount), "total attempts should match max requests") } func TestHedgingSecondWins(t *testing.T) { var attemptCount int32 winnerAttempt := atomic.Int32{} timeouts := [2]time.Duration{400 * time.Millisecond, 20 * time.Millisecond} ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { attempt := atomic.AddInt32(&attemptCount, 1) time.Sleep(timeouts[attempt-1]) winnerAttempt.CompareAndSwap(0, attempt) w.Header().Set("X-Attempt", fmt.Sprintf("%d", attempt)) w.WriteHeader(http.StatusOK) fmt.Fprintf(w, "Attempt %d", attempt) }) defer ts.Close() h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(2). 
SetMaxRequestPerSecond(0) c := dcnl().SetHedging(h) resp, err := c.R().Get(ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) winnerRequest := winnerAttempt.Load() assertEqual(t, fmt.Sprintf("Attempt %d", winnerRequest), resp.String(), "expected second attempt to win") assertEqual(t, int32(2), winnerRequest, "expected second request to win") assertEqual(t, int32(2), atomic.LoadInt32(&attemptCount), "total attempts should be 2") } func TestHedgingTimeout(t *testing.T) { var attemptCount int32 requestTimes := make([]time.Time, 0, 3) var timesLock atomic.Value timesLock.Store(requestTimes) ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { attempt := atomic.AddInt32(&attemptCount, 1) now := time.Now() times := timesLock.Load().([]time.Time) times = append(times, now) timesLock.Store(times) if attempt == 1 { time.Sleep(300 * time.Millisecond) } w.WriteHeader(http.StatusOK) fmt.Fprintf(w, "Attempt %d", attempt) }) defer ts.Close() delay := 50 * time.Millisecond h := NewHedging(). SetDelay(delay). SetMaxRequest(3). SetMaxRequestPerSecond(0) c := dcnl().SetHedging(h) resp, err := c.R().Get(ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) time.Sleep(200 * time.Millisecond) times := timesLock.Load().([]time.Time) if len(times) >= 2 { diff := times[1].Sub(times[0]) if diff < delay || diff > delay+30*time.Millisecond { t.Logf("Expected delay between requests to be ~%v, got %v", delay, diff) } } } func TestHedgingReadOnlyMethodsOnly(t *testing.T) { var attemptCount int32 ts := createHedgingTestServer(t, &attemptCount) defer ts.Close() h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(3). 
SetMaxRequestPerSecond(0) c := dcnl().SetHedging(h) testCases := []struct { method string expectHedging bool requestFunc func(*Client, string) (*Response, error) }{ {MethodGet, true, func(c *Client, url string) (*Response, error) { return c.R().Get(url) }}, {MethodHead, true, func(c *Client, url string) (*Response, error) { return c.R().Head(url) }}, {MethodOptions, true, func(c *Client, url string) (*Response, error) { return c.R().Options(url) }}, {MethodPost, false, func(c *Client, url string) (*Response, error) { return c.R().Post(url) }}, {MethodPut, false, func(c *Client, url string) (*Response, error) { return c.R().Put(url) }}, {MethodPatch, false, func(c *Client, url string) (*Response, error) { return c.R().Patch(url) }}, {MethodDelete, false, func(c *Client, url string) (*Response, error) { return c.R().Delete(url) }}, } for _, tc := range testCases { t.Run(tc.method, func(t *testing.T) { atomic.StoreInt32(&attemptCount, 0) resp, err := tc.requestFunc(c, ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) time.Sleep(20 * time.Millisecond) count := atomic.LoadInt32(&attemptCount) if tc.expectHedging { assertNotEqual(t, 1, count, fmt.Sprintf("%s: expected hedging with multiple requests, got %d request(s)", tc.method, count)) } else { assertEqual(t, int32(1), count, fmt.Sprintf("%s: no hedging 1 request only", tc.method)) } }) } } func TestHedgingRateLimit(t *testing.T) { var attemptCount int32 ts := createHedgingTestServer(t, &attemptCount) defer ts.Close() h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(10). SetMaxRequestPerSecond(5.0) c := dcnl().SetHedging(h) start := time.Now() resp, err := c.R().Get(ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) duration := time.Since(start) if duration < 200*time.Millisecond { t.Logf("Rate limiting may have limited hedged requests. 
Duration: %v, Attempts: %d", duration, atomic.LoadInt32(&attemptCount)) } } func TestHedgingWithRetryFallback(t *testing.T) { c := dcnl() // Set retry first c.SetRetryCount(2) assertEqual(t, 2, c.RetryCount()) h := NewHedging(). SetDelay(50 * time.Millisecond). SetMaxRequest(3). SetMaxRequestPerSecond(0) // Enable hedging should disable retry by default c.SetHedging(h) assertEqual(t, 0, c.RetryCount()) // But user can re-enable retry as fallback c.SetRetryCount(1) assertEqual(t, 1, c.RetryCount()) assertEqual(t, true, c.isHedgingEnabled()) // Disable hedging c.SetHedging(nil) assertEqual(t, false, c.isHedgingEnabled()) assertEqual(t, 1, c.RetryCount()) // Retry count should remain } func TestHedgingDisable(t *testing.T) { var attemptCount int32 ts := createHedgingTestServer(t, &attemptCount) defer ts.Close() h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(3). SetMaxRequestPerSecond(0) c := dcnl() c.SetHedging(h) assertEqual(t, true, c.isHedgingEnabled()) c.SetHedging(nil) assertEqual(t, false, c.isHedgingEnabled()) atomic.StoreInt32(&attemptCount, 0) resp, err := c.R().Get(ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) time.Sleep(50 * time.Millisecond) assertEqual(t, int32(1), atomic.LoadInt32(&attemptCount)) } func TestHedgingContextCancellation(t *testing.T) { attemptCount := atomic.Int32{} startedCount := atomic.Int32{} ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { startedCount.Add(1) time.Sleep(200 * time.Millisecond) attemptCount.Add(1) w.WriteHeader(http.StatusOK) }) defer ts.Close() h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(3). 
SetMaxRequestPerSecond(0) c := dcnl().SetHedging(h) ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) defer cancel() _, err := c.R().SetContext(ctx).Get(ts.URL) assertErrorIs(t, context.DeadlineExceeded, err) time.Sleep(50 * time.Millisecond) started := startedCount.Load() completed := attemptCount.Load() assertTrue(t, started > 1, "expected multiple hedged request to start") assertEqual(t, int32(0), completed, "context cancellation should have prevented completion") } func TestHedgingConfiguration(t *testing.T) { h := NewHedging(). SetDelay(50 * time.Millisecond). SetMaxRequest(3). SetMaxRequestPerSecond(10.0) assertEqual(t, 50*time.Millisecond, h.Delay()) assertEqual(t, 3, h.MaxRequest()) assertEqual(t, 10.0, h.MaxRequestPerSecond()) // Now we can update individual settings h.SetDelay(100 * time.Millisecond) assertEqual(t, 100*time.Millisecond, h.Delay()) h.SetMaxRequest(5) assertEqual(t, 5, h.MaxRequest()) h.SetMaxRequestPerSecond(20.0) assertEqual(t, 20.0, h.MaxRequestPerSecond()) } func TestHedgingConfigurationViaClient(t *testing.T) { c := dcnl() // Setters require hedging to be enabled first assertEqual(t, false, c.isHedgingEnabled()) h := NewHedging(). SetDelay(50 * time.Millisecond). SetMaxRequest(3). 
SetMaxRequestPerSecond(10.0) c.SetHedging(h) assertEqual(t, true, c.isHedgingEnabled()) assertEqual(t, 50*time.Millisecond, c.Hedging().Delay()) assertEqual(t, 3, c.Hedging().MaxRequest()) assertEqual(t, 10.0, c.Hedging().MaxRequestPerSecond()) // Now we can update individual settings c.Hedging().SetDelay(100 * time.Millisecond) assertEqual(t, 100*time.Millisecond, c.Hedging().Delay()) c.Hedging().SetMaxRequest(5) assertEqual(t, 5, c.Hedging().MaxRequest()) c.Hedging().SetMaxRequestPerSecond(20.0) assertEqual(t, 20.0, c.Hedging().MaxRequestPerSecond()) } func TestHedgingWithCustomTransport(t *testing.T) { var attemptCount int32 ts := createHedgingTestServer(t, &attemptCount) defer ts.Close() customTransport := &http.Transport{} c := NewWithClient(&http.Client{Transport: customTransport}) h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(3). SetMaxRequestPerSecond(0) c.SetHedging(h) resp, err := c.R().Get(ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, int32(3), atomic.LoadInt32(&attemptCount), "Expected 3 attempts with hedging enabled") // disable hedging and verify transport is unwrapped c.SetHedging(nil) _, ok := c.httpClient.Transport.(*Hedging) assertFalse(t, ok, "transport should be unwrapped after disabling hedging") } func TestHedgingSingleRequest(t *testing.T) { var attemptCount int32 ts := createHedgingTestServer(t, &attemptCount) defer ts.Close() h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(1). SetMaxRequestPerSecond(0) c := dcnl().SetHedging(h) resp, err := c.R().Get(ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, int32(1), atomic.LoadInt32(&attemptCount)) } func TestHedgingAllowNonReadOnly(t *testing.T) { var attemptCount int32 ts := createHedgingTestServer(t, &attemptCount) defer ts.Close() h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(3). 
SetMaxRequestPerSecond(0) c := dcnl().SetHedging(h) // By default, non-read-only methods should not be hedged assertEqual(t, false, c.Hedging().IsNonReadOnlyAllowed()) // Test POST without allowing non-read-only atomic.StoreInt32(&attemptCount, 0) resp, err := c.R().Post(ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, int32(1), atomic.LoadInt32(&attemptCount), "no hedging for POST without allow flag") // Enable non-read-only methods c.Hedging().SetNonReadOnlyAllowed(true) assertEqual(t, true, c.Hedging().IsNonReadOnlyAllowed()) // Test POST with allowing non-read-only atomic.StoreInt32(&attemptCount, 0) resp, err = c.R().Post(ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, int32(3), atomic.LoadInt32(&attemptCount), "hedging for POST with allow flag") } func TestHedgingWithNilTransport(t *testing.T) { var attemptCount int32 ts := createHedgingTestServer(t, &attemptCount) defer ts.Close() // Create client with nil transport c := NewWithClient(&http.Client{Transport: nil}) h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(3). SetMaxRequestPerSecond(0) c.SetHedging(h) resp, err := c.R().Get(ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, int32(3), atomic.LoadInt32(&attemptCount), "hedging with nil transport should still work") } func TestHedgingEnableMultipleTimes(t *testing.T) { var attemptCount int32 ts := createHedgingTestServer(t, &attemptCount) defer ts.Close() h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(3). SetMaxRequestPerSecond(0) c := dcnl() // Enable hedging first time c.SetHedging(h) assertEqual(t, true, c.isHedgingEnabled()) // Enable hedging again without disabling - should handle already wrapped transport nh := NewHedging(). SetDelay(30 * time.Millisecond). SetMaxRequest(5). 
SetMaxRequestPerSecond(10.0) c.SetHedging(nh) assertEqual(t, true, c.isHedgingEnabled()) assertEqual(t, 30*time.Millisecond, c.Hedging().Delay()) assertEqual(t, 5, c.Hedging().MaxRequest()) assertEqual(t, 10.0, c.Hedging().MaxRequestPerSecond()) // Verify hedging still works resp, err := c.R().Get(ts.URL) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, int32(3), atomic.LoadInt32(&attemptCount), "expected hedging after re-enabling") } func TestHedgingWrapWithDisabledHedging(t *testing.T) { c := dcnl() h := NewHedging(). SetDelay(20 * time.Millisecond). SetMaxRequest(3). SetMaxRequestPerSecond(0) // Enable and then disable hedging c.SetHedging(h) assertEqual(t, true, c.isHedgingEnabled()) c.SetHedging(nil) assertEqual(t, false, c.isHedgingEnabled()) // Verify transport is not a hedgingTransport _, ok := c.httpClient.Transport.(*Hedging) assertFalse(t, ok, "transport should not be hedging transport") } func TestHedgingRateDelayBetweenRequests(t *testing.T) { requestTimes := make([]time.Time, 0, 3) var mu sync.Mutex ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { mu.Lock() requestTimes = append(requestTimes, time.Now()) mu.Unlock() // Slow response to ensure multiple hedged requests are sent time.Sleep(500 * time.Millisecond) w.WriteHeader(http.StatusOK) }) defer ts.Close() c := dcnl() // delay=10ms, maxRequest=3, maxRequestPerSecond=5.0 (rateDelay = 200ms) // Expected timing: req1 at 0, req2 at ~10ms + 200ms = ~210ms, req3 at ~420ms h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(3). 
SetMaxRequestPerSecond(5.0) c.SetHedging(h) _, err := c.R().Get(ts.URL) assertError(t, err) // Wait for all requests to be recorded time.Sleep(600 * time.Millisecond) mu.Lock() times := make([]time.Time, len(requestTimes)) copy(times, requestTimes) mu.Unlock() assertEqual(t, 3, len(times), "expected 3 hedged requests to be sent") // Verify rate delay was applied between requests // With maxPerSecond=5.0, rateDelay should be 200ms // The gap between requests should be at least rateDelay (200ms) expectedRateDelay := 200 * time.Millisecond tolerance := 50 * time.Millisecond for i := 1; i < len(times); i++ { gap := times[i].Sub(times[i-1]) // Gap should be >= (delay + rateDelay) - tolerance minExpectedGap := expectedRateDelay - tolerance if gap < minExpectedGap { t.Errorf("Gap between request %d and %d was %v, expected at least %v (rate delay should be ~%v)", i-1, i, gap, minExpectedGap, expectedRateDelay) } } } func TestHedgingNoDoubleWrap(t *testing.T) { h1 := NewHedging().SetDelay(50 * time.Millisecond) h2 := NewHedging().SetDelay(100 * time.Millisecond) c := dcnl() // Enable hedging first time c.SetHedging(h1) _, ok := c.httpClient.Transport.(*Hedging) assertTrue(t, ok, "Hedging transport") // Enable different hedging without disabling first c.SetHedging(h2) // Both should be Hedging hedging2, ok := c.httpClient.Transport.(*Hedging) assertTrue(t, ok, "Hedging transport") // The wrapped transport should NOT be another Hedging _, isHedging := hedging2.transport.(*Hedging) assertFalse(t, isHedging, "Double-wrapped hedging detected - transport should be unwrapped") // Verify transport chain depth, should only have one Hedging layer if hedging, ok := c.httpClient.Transport.(*Hedging); ok { _, isHedging := hedging.transport.(*Hedging) assertFalse(t, isHedging, "Double-wrapped hedging detected") } // Verify the configuration is the new one assertEqual(t, hedging2.Delay(), 100*time.Millisecond, "Expected 100ms delay") } func TestHedgingRoundTripDeadlineExpired(t 
*testing.T) { var attemptCount int32 ts := createHedgingTestServer(t, &attemptCount) defer ts.Close() h := NewHedging(). SetDelay(10 * time.Millisecond). SetMaxRequest(3). SetMaxRequestPerSecond(0) c := dcnl().SetHedging(h) ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Millisecond)) defer cancel() _, err := c.R().SetContext(ctx).Get(ts.URL) assertErrorIs(t, context.DeadlineExceeded, err, "Expected context deadline expired error") time.Sleep(50 * time.Millisecond) assertEqual(t, int32(0), atomic.LoadInt32(&attemptCount)) } ================================================ FILE: load_balancer.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "context" "errors" "fmt" "net" "net/url" "strings" "sync" "time" ) // ErrNoBaseURLs error returned when no base URLs are found var ErrNoBaseURLs = errors.New("resty: no base URLs found") // LoadBalancer is the interface that wraps the HTTP client load-balancing // algorithm that returns the "Next" Base URL for the request to target type LoadBalancer interface { NextWithContext(ctx context.Context) (string, error) Feedback(*RequestFeedback) Close() error } // RequestFeedback struct is used to send the request feedback to load balancing // algorithm type RequestFeedback struct { BaseURL string Success bool Attempt int } // NewRoundRobin method creates the new Round-Robin(RR) request load balancer // instance with given base URLs func NewRoundRobin(baseURLs ...string) (*RoundRobin, error) { if len(baseURLs) == 0 { return nil, ErrNoBaseURLs } rr := &RoundRobin{lock: new(sync.Mutex)} if err := rr.Refresh(baseURLs...); err != nil { return rr, err } return rr, nil } var _ LoadBalancer = (*RoundRobin)(nil) // RoundRobin struct used to implement the 
Round-Robin(RR) request // load balancer algorithm type RoundRobin struct { lock *sync.Mutex baseURLs []string current int } // NextWithContext method returns the next Base URL based on the Round-Robin(RR) algorithm // with context support for cancellation func (rr *RoundRobin) NextWithContext(ctx context.Context) (string, error) { select { case <-ctx.Done(): return "", ctx.Err() default: } rr.lock.Lock() defer rr.lock.Unlock() if len(rr.baseURLs) == 0 { return "", ErrNoBaseURLs } baseURL := rr.baseURLs[rr.current] rr.current = (rr.current + 1) % len(rr.baseURLs) return baseURL, nil } // Feedback method does nothing in Round-Robin(RR) request load balancer func (rr *RoundRobin) Feedback(_ *RequestFeedback) {} // Close method does nothing in Round-Robin(RR) request load balancer func (rr *RoundRobin) Close() error { return nil } // Refresh method reset the existing Base URLs with the given Base URLs slice to refresh it func (rr *RoundRobin) Refresh(baseURLs ...string) error { rr.lock.Lock() defer rr.lock.Unlock() result := make([]string, 0) for _, u := range baseURLs { baseURL, err := extractBaseURL(u) if err != nil { return err } result = append(result, baseURL) } // after processing, assign the updates rr.baseURLs = result return nil } // Host struct used to represent the host information and its weight // to load balance the requests type Host struct { // BaseURL represents the targeted host base URL // https://resty.dev BaseURL string // Weight represents the host weight to determine // the percentage of requests to send Weight int // MaxFailures represents the value to mark the host as // not usable until it reaches the Recovery duration // Default value is 5 MaxFailures int state HostState currentWeight int failedRequests int } func (h *Host) addWeight() { h.currentWeight += h.Weight } func (h *Host) resetWeight(totalWeight int) { h.currentWeight -= totalWeight } type HostState int // Host transition states const ( HostStateInActive HostState = iota 
HostStateActive ) // HostStateChangeFunc type provides feedback on host state transitions type HostStateChangeFunc func(baseURL string, from, to HostState) // ErrNoActiveHost error returned when all hosts are inactive on the load balancer var ErrNoActiveHost = errors.New("resty: no active host") // NewWeightedRoundRobin method creates the new Weighted Round-Robin(WRR) // request load balancer instance with given recovery duration and hosts slice func NewWeightedRoundRobin(recovery time.Duration, hosts ...*Host) (*WeightedRoundRobin, error) { if recovery == 0 { recovery = 120 * time.Second // defaults to 120 seconds } wrr := &WeightedRoundRobin{ lock: new(sync.RWMutex), hosts: make([]*Host, 0), tick: time.NewTicker(recovery), recovery: recovery, } err := wrr.Refresh(hosts...) go wrr.ticker() return wrr, err } var _ LoadBalancer = (*WeightedRoundRobin)(nil) // WeightedRoundRobin struct used to represent the host details for // Weighted Round-Robin(WRR) algorithm implementation type WeightedRoundRobin struct { lock *sync.RWMutex hosts []*Host totalWeight int tick *time.Ticker onStateChange HostStateChangeFunc // Recovery duration is used to set the timer to put // the host back in the pool for the next turn and // reset the failed request count for the segment recovery time.Duration } // NextWithContext method returns the next Base URL based on Weighted Round-Robin(WRR) // with context support for cancellation func (wrr *WeightedRoundRobin) NextWithContext(ctx context.Context) (string, error) { select { case <-ctx.Done(): return "", ctx.Err() default: } wrr.lock.Lock() defer wrr.lock.Unlock() var best *Host total := 0 for _, h := range wrr.hosts { if h.state == HostStateInActive { continue } h.addWeight() total += h.Weight if best == nil || h.currentWeight > best.currentWeight { best = h } } if best == nil { return "", ErrNoActiveHost } best.resetWeight(total) return best.BaseURL, nil } // Feedback method process the request feedback for Weighted Round-Robin(WRR) // 
request load balancer func (wrr *WeightedRoundRobin) Feedback(f *RequestFeedback) { if f == nil { return } wrr.lock.Lock() defer wrr.lock.Unlock() for _, host := range wrr.hosts { if host.BaseURL == f.BaseURL { if !f.Success { host.failedRequests++ } if host.failedRequests >= host.MaxFailures { host.state = HostStateInActive if wrr.onStateChange != nil { wrr.onStateChange(host.BaseURL, HostStateActive, HostStateInActive) } } break } } } // Close method does the cleanup by stopping the [time.Ticker] on // Weighted Round-Robin(WRR) request load balancer func (wrr *WeightedRoundRobin) Close() error { wrr.lock.Lock() defer wrr.lock.Unlock() wrr.tick.Stop() return nil } // Refresh method reset the existing values with the given [Host] slice to refresh it func (wrr *WeightedRoundRobin) Refresh(hosts ...*Host) error { if hosts == nil { return nil } wrr.lock.Lock() defer wrr.lock.Unlock() newTotalWeight := 0 for _, h := range hosts { baseURL, err := extractBaseURL(h.BaseURL) if err != nil { return err } h.BaseURL = baseURL h.state = HostStateActive newTotalWeight += h.Weight // assign defaults if not provided if h.MaxFailures == 0 { h.MaxFailures = 5 // default value is 5 } } // after processing, assign the updates wrr.hosts = hosts wrr.totalWeight = newTotalWeight return nil } // SetOnStateChange method used to set a callback for the host transition state func (wrr *WeightedRoundRobin) SetOnStateChange(fn HostStateChangeFunc) { wrr.lock.Lock() defer wrr.lock.Unlock() wrr.onStateChange = fn } // SetRecoveryDuration method is used to change the existing recovery duration for the host func (wrr *WeightedRoundRobin) SetRecoveryDuration(d time.Duration) { wrr.lock.Lock() defer wrr.lock.Unlock() wrr.recovery = d wrr.tick.Reset(d) } func (wrr *WeightedRoundRobin) ticker() { for range wrr.tick.C { wrr.lock.Lock() hosts := make([]*Host, len(wrr.hosts)) copy(hosts, wrr.hosts) wrr.lock.Unlock() for _, host := range hosts { if host.state == HostStateInActive { host.state = 
HostStateActive host.failedRequests = 0 if wrr.onStateChange != nil { wrr.onStateChange(host.BaseURL, HostStateInActive, HostStateActive) } } } } } // NewSRVWeightedRoundRobin method creates a new Weighted Round-Robin(WRR) load balancer instance // with given SRV values func NewSRVWeightedRoundRobin(service, proto, domainName, httpScheme string) (*SRVWeightedRoundRobin, error) { if isStringEmpty(proto) { proto = "tcp" } if isStringEmpty(httpScheme) { httpScheme = "https" } wrr, _ := NewWeightedRoundRobin(0) // with this input error will not occur swrr := &SRVWeightedRoundRobin{ Service: service, Proto: proto, DomainName: domainName, HttpScheme: httpScheme, wrr: wrr, tick: time.NewTicker(180 * time.Second), // default is 180 seconds lock: new(sync.Mutex), lookupSRV: func() ([]*net.SRV, error) { _, addrs, err := net.LookupSRV(service, proto, domainName) return addrs, err }, } err := swrr.Refresh() go swrr.ticker() return swrr, err } var _ LoadBalancer = (*SRVWeightedRoundRobin)(nil) // SRVWeightedRoundRobin struct used to implement SRV Weighted Round-Robin(RR) algorithm type SRVWeightedRoundRobin struct { Service string Proto string DomainName string HttpScheme string wrr *WeightedRoundRobin tick *time.Ticker lock *sync.Mutex lookupSRV func() ([]*net.SRV, error) } // NextWithContext method returns the next SRV Base URL based on Weighted Round-Robin(RR) // with context support for cancellation func (swrr *SRVWeightedRoundRobin) NextWithContext(ctx context.Context) (string, error) { return swrr.wrr.NextWithContext(ctx) } // Feedback method does nothing in SRV Base URL based on Weighted Round-Robin(WRR) // request load balancer func (swrr *SRVWeightedRoundRobin) Feedback(f *RequestFeedback) { swrr.wrr.Feedback(f) } // Close method does the cleanup by stopping the [time.Ticker] SRV Base URL based // on Weighted Round-Robin(WRR) request load balancer func (swrr *SRVWeightedRoundRobin) Close() error { swrr.lock.Lock() defer swrr.lock.Unlock() swrr.wrr.Close() 
// extractBaseURL reduces the given URL to its base form, i.e. scheme, user
// info and host, for example
//
//	https://resty.dev/welcome/page.html?a=1#top => https://resty.dev
//
// It returns the [url.Parse] error when the input cannot be parsed.
func extractBaseURL(u string) (string, error) {
	baseURL, err := url.Parse(u)
	if err != nil {
		return "", err
	}

	// we only require base URL LB
	baseURL.Path = ""
	baseURL.RawPath = ""
	baseURL.RawQuery = ""
	// a bare trailing '?' would otherwise be re-emitted by String()
	baseURL.ForceQuery = false
	// a fragment is never part of a base URL; keeping it would produce
	// values such as "https://host#frag" that break path concatenation
	baseURL.Fragment = ""
	baseURL.RawFragment = ""

	return strings.TrimRight(baseURL.String(), "/"), nil
}
// SPDX-License-Identifier: MIT package resty import ( "context" "errors" "net" "net/http" "net/url" "sync/atomic" "testing" "time" ) func TestRoundRobin(t *testing.T) { t.Run("2 base urls", func(t *testing.T) { rr, err := NewRoundRobin("https://example1.com", "https://example2.com") assertNil(t, err) runCount := 5 var result []string ctx := context.Background() for i := 0; i < runCount; i++ { baseURL, _ := rr.NextWithContext(ctx) result = append(result, baseURL) } expected := []string{ "https://example1.com", "https://example2.com", "https://example1.com", "https://example2.com", "https://example1.com", } assertEqual(t, runCount, len(expected)) assertEqual(t, runCount, len(result)) assertEqual(t, expected, result) rr.Feedback(&RequestFeedback{}) rr.Close() }) t.Run("5 base urls", func(t *testing.T) { input := []string{"https://example1.com", "https://example2.com", "https://example3.com", "https://example4.com", "https://example5.com"} rr, err := NewRoundRobin(input...) assertNil(t, err) runCount := 30 var result []string ctx := context.Background() for i := 0; i < runCount; i++ { baseURL, _ := rr.NextWithContext(ctx) result = append(result, baseURL) } var expected []string for i := 0; i < runCount/len(input); i++ { expected = append(expected, input...) 
} assertEqual(t, runCount, len(expected)) assertEqual(t, runCount, len(result)) assertEqual(t, expected, result) rr.Feedback(&RequestFeedback{}) rr.Close() }) t.Run("2 base urls with refresh", func(t *testing.T) { rr, err := NewRoundRobin("https://example1.com", "https://example2.com") assertNil(t, err) err = rr.Refresh("https://example3.com", "https://example4.com") assertNil(t, err) runCount := 5 var result []string ctx := context.Background() for i := 0; i < runCount; i++ { baseURL, _ := rr.NextWithContext(ctx) result = append(result, baseURL) } expected := []string{ "https://example3.com", "https://example4.com", "https://example3.com", "https://example4.com", "https://example3.com", } assertEqual(t, runCount, len(expected)) assertEqual(t, runCount, len(result)) assertEqual(t, expected, result) rr.Feedback(&RequestFeedback{}) rr.Close() }) t.Run("NextWithContext context cancellation", func(t *testing.T) { rr, _ := NewRoundRobin("https://example.com") ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := rr.NextWithContext(ctx) assertErrorIs(t, context.Canceled, err) }) t.Run("NextWithContext normal operation", func(t *testing.T) { rr, _ := NewRoundRobin("https://example1.com", "https://example2.com") ctx := context.Background() url1, err := rr.NextWithContext(ctx) assertNil(t, err) url2, err := rr.NextWithContext(ctx) assertNil(t, err) assertNotEqual(t, url1, url2) }) } func TestRoundRobinNoBaseURLs(t *testing.T) { t.Run("new round robin no base urls", func(t *testing.T) { rr, err := NewRoundRobin() assertErrorIs(t, ErrNoBaseURLs, err) assertNil(t, rr) }) t.Run("new round robin no base urls on next with context", func(t *testing.T) { rr, err := NewRoundRobin("https://example1.com") assertNil(t, err) assertNotNil(t, rr) rr.Refresh() ctx := context.Background() _, err = rr.NextWithContext(ctx) assertErrorIs(t, ErrNoBaseURLs, err) }) } func TestWeightedRoundRobin(t *testing.T) { t.Run("3 hosts with weight {5,2,1}", func(t *testing.T) { hosts 
:= []*Host{ {BaseURL: "https://example1.com", Weight: 5}, {BaseURL: "https://example2.com", Weight: 2}, {BaseURL: "https://example3.com", Weight: 1}, } wrr, err := NewWeightedRoundRobin(200*time.Millisecond, hosts...) assertNil(t, err) defer wrr.Close() runCount := 5 var result []string ctx := context.Background() for i := 0; i < runCount; i++ { baseURL, err := wrr.NextWithContext(ctx) assertNil(t, err) result = append(result, baseURL) } expected := []string{ "https://example1.com", "https://example2.com", "https://example1.com", "https://example1.com", "https://example3.com", } assertEqual(t, runCount, len(expected)) assertEqual(t, runCount, len(result)) assertEqual(t, expected, result) wrr.Feedback(nil) }) t.Run("3 hosts with weight {2,1,10}", func(t *testing.T) { hosts := []*Host{ {BaseURL: "https://example1.com", Weight: 2}, {BaseURL: "https://example2.com", Weight: 1}, {BaseURL: "https://example3.com", Weight: 10, MaxFailures: 3}, } wrr, err := NewWeightedRoundRobin(200*time.Millisecond, hosts...) 
assertNil(t, err) defer wrr.Close() var stateChangeCalled int32 wrr.SetOnStateChange(func(baseURL string, from, to HostState) { atomic.AddInt32(&stateChangeCalled, 1) }) runCount := 10 var result []string ctx := context.Background() for i := 0; i < runCount; i++ { baseURL, err := wrr.NextWithContext(ctx) assertNil(t, err) result = append(result, baseURL) if baseURL == "https://example3.com" && i%2 != 0 { wrr.Feedback(&RequestFeedback{BaseURL: baseURL, Success: false, Attempt: 1}) } else { wrr.Feedback(&RequestFeedback{BaseURL: baseURL, Success: true, Attempt: 1}) } } expected := []string{ "https://example3.com", "https://example3.com", "https://example1.com", "https://example3.com", "https://example3.com", "https://example3.com", "https://example2.com", "https://example2.com", "https://example1.com", "https://example1.com", } assertEqual(t, int32(1), stateChangeCalled) assertEqual(t, runCount, len(expected)) assertEqual(t, runCount, len(result)) assertEqual(t, expected, result) }) t.Run("2 hosts with weight {5,5} and refresh", func(t *testing.T) { wrr, err := NewWeightedRoundRobin( 200*time.Millisecond, &Host{BaseURL: "https://example1.com", Weight: 5}, &Host{BaseURL: "https://example2.com", Weight: 5}, ) assertNil(t, err) defer wrr.Close() err = wrr.Refresh( &Host{BaseURL: "https://example3.com", Weight: 5}, &Host{BaseURL: "https://example4.com", Weight: 5}, ) assertNil(t, err) runCount := 5 var result []string ctx := context.Background() for i := 0; i < runCount; i++ { baseURL, err := wrr.NextWithContext(ctx) assertNil(t, err) result = append(result, baseURL) } expected := []string{ "https://example3.com", "https://example4.com", "https://example3.com", "https://example4.com", "https://example3.com", } assertEqual(t, runCount, len(expected)) assertEqual(t, runCount, len(result)) assertEqual(t, expected, result) }) t.Run("no active hosts error", func(t *testing.T) { wrr, err := NewWeightedRoundRobin(200 * time.Millisecond) assertNil(t, err) defer wrr.Close() _, 
err = wrr.NextWithContext(context.Background()) assertErrorIs(t, ErrNoActiveHost, err) }) t.Run("NextWithContext context cancellation", func(t *testing.T) { wrr, _ := NewWeightedRoundRobin(0, &Host{BaseURL: "https://example.com", Weight: 1}) ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := wrr.NextWithContext(ctx) assertErrorIs(t, context.Canceled, err) }) t.Run("NextWithContext normal operation", func(t *testing.T) { hosts := []*Host{ {BaseURL: "https://example1.com", Weight: 1}, {BaseURL: "https://example2.com", Weight: 1}, } wrr, _ := NewWeightedRoundRobin(0, hosts...) ctx := context.Background() url1, err := wrr.NextWithContext(ctx) assertNil(t, err) url2, err := wrr.NextWithContext(ctx) assertNil(t, err) assertNotEqual(t, url1, url2) }) } func TestSRVWeightedRoundRobin(t *testing.T) { t.Run("3 records with weight {50,30,20}", func(t *testing.T) { srv, err := NewSRVWeightedRoundRobin("_sample-server", "", "example.com", "") assertNotNil(t, err) assertNotNil(t, srv) var dnsErr *net.DNSError assertTrue(t, errors.As(err, &dnsErr), "expected net.DNSError type") // mock net.LookupSRV call srv.lookupSRV = func() ([]*net.SRV, error) { return []*net.SRV{ {Target: "service1.example.com.", Port: 443, Priority: 10, Weight: 50}, {Target: "service2.example.com.", Port: 443, Priority: 20, Weight: 30}, {Target: "service3.example.com.", Port: 443, Priority: 20, Weight: 20}, }, nil } err = srv.Refresh() assertNil(t, err) srv.SetRecoveryDuration(200 * time.Millisecond) runCount := 5 var result []string ctx := context.Background() for i := 0; i < runCount; i++ { baseURL, err := srv.NextWithContext(ctx) assertNil(t, err) result = append(result, baseURL) } expected := []string{ "https://service1.example.com:443", "https://service2.example.com:443", "https://service3.example.com:443", "https://service1.example.com:443", "https://service1.example.com:443", } assertEqual(t, runCount, len(expected)) assertEqual(t, runCount, len(result)) assertEqual(t, 
expected, result) }) t.Run("2 records with weight {50,50}", func(t *testing.T) { srv, err := NewSRVWeightedRoundRobin("_sample-server", "", "example.com", "") assertNotNil(t, err) assertNotNil(t, srv) var dnsErr *net.DNSError assertTrue(t, errors.As(err, &dnsErr), "expected net.DNSError type") // mock net.LookupSRV call srv.lookupSRV = func() ([]*net.SRV, error) { return []*net.SRV{ {Target: "service1.example.com.", Port: 443, Priority: 10, Weight: 50}, {Target: "service2.example.com.", Port: 443, Priority: 20, Weight: 50}, }, nil } err = srv.Refresh() assertNil(t, err) srv.SetRecoveryDuration(200 * time.Millisecond) runCount := 5 var result []string ctx := context.Background() for i := 0; i < runCount; i++ { baseURL, err := srv.NextWithContext(ctx) assertNil(t, err) result = append(result, baseURL) } expected := []string{ "https://service1.example.com:443", "https://service2.example.com:443", "https://service1.example.com:443", "https://service2.example.com:443", "https://service1.example.com:443", } assertEqual(t, runCount, len(expected)) assertEqual(t, runCount, len(result)) assertEqual(t, expected, result) }) t.Run("3 records with weight {60,20,20}", func(t *testing.T) { srv, err := NewSRVWeightedRoundRobin("_sample-server", "", "example.com", "") assertNotNil(t, err) assertNotNil(t, srv) var dnsErr *net.DNSError assertTrue(t, errors.As(err, &dnsErr), "expected net.DNSError type") // mock net.LookupSRV call srv.lookupSRV = func() ([]*net.SRV, error) { return []*net.SRV{ {Target: "service1.example.com.", Port: 443, Priority: 10, Weight: 60}, {Target: "service2.example.com.", Port: 443, Priority: 20, Weight: 20}, {Target: "service3.example.com.", Port: 443, Priority: 20, Weight: 20}, }, nil } err = srv.Refresh() assertNil(t, err) var stateChangeCalled int32 srv.SetOnStateChange(func(baseURL string, from, to HostState) { atomic.AddInt32(&stateChangeCalled, 1) }) srv.SetRecoveryDuration(200 * time.Millisecond) runCount := 20 var result []string ctx := 
context.Background() for i := 0; i < runCount; i++ { baseURL, err := srv.NextWithContext(ctx) assertNil(t, err) result = append(result, baseURL) if baseURL == "https://service1.example.com:443" { srv.Feedback(&RequestFeedback{BaseURL: baseURL, Success: false, Attempt: 1}) } else { srv.Feedback(&RequestFeedback{BaseURL: baseURL, Success: true, Attempt: 1}) } } expected := []string{ "https://service1.example.com:443", "https://service2.example.com:443", "https://service1.example.com:443", "https://service3.example.com:443", "https://service1.example.com:443", "https://service1.example.com:443", "https://service2.example.com:443", "https://service1.example.com:443", "https://service3.example.com:443", "https://service3.example.com:443", "https://service3.example.com:443", "https://service2.example.com:443", "https://service3.example.com:443", "https://service2.example.com:443", "https://service3.example.com:443", "https://service2.example.com:443", "https://service3.example.com:443", "https://service2.example.com:443", "https://service3.example.com:443", "https://service2.example.com:443", } assertEqual(t, runCount, len(expected)) assertEqual(t, runCount, len(result)) assertEqual(t, expected, result) }) t.Run("srv record with refresh duration 100ms", func(t *testing.T) { srv, err := NewSRVWeightedRoundRobin("_sample-server", "", "example.com", "") assertNotNil(t, err) assertNotNil(t, srv) var dnsErr *net.DNSError assertTrue(t, errors.As(err, &dnsErr), "expected net.DNSError type") // mock net.LookupSRV call srv.lookupSRV = func() ([]*net.SRV, error) { return []*net.SRV{ {Target: "service1.example.com.", Port: 443, Priority: 10, Weight: 50}, {Target: "service2.example.com.", Port: 443, Priority: 20, Weight: 50}, }, nil } err = srv.Refresh() assertNil(t, err) srv.SetRecoveryDuration(200 * time.Millisecond) go func() { for i := 0; i < 10; i++ { baseURL, _ := srv.NextWithContext(context.Background()) assertNotNil(t, baseURL) time.Sleep(15 * time.Millisecond) } }() 
srv.SetRefreshDuration(150 * time.Millisecond) time.Sleep(320 * time.Millisecond) srv.Close() }) t.Run("srv record with error on default lookupSRV", func(t *testing.T) { srv, err := NewSRVWeightedRoundRobin("_sample-server", "", "example.com", "") assertNotNil(t, err) assertNotNil(t, srv) var dnsErr *net.DNSError assertTrue(t, errors.As(err, &dnsErr), "expected net.DNSError type") // default error flow err = srv.Refresh() assertNotNil(t, err) assertTrue(t, errors.As(err, &dnsErr), "expected net.DNSError type") // replace with mock error flow errMockTest := errors.New("network error") srv.lookupSRV = func() ([]*net.SRV, error) { return nil, errMockTest } err = srv.Refresh() assertNotNil(t, err) assertErrorIs(t, errMockTest, err, "expected network error type") }) } func TestLoadBalancerRequest(t *testing.T) { ts1 := createGetServer(t) defer ts1.Close() ts2 := createGetServer(t) defer ts2.Close() rr, err := NewRoundRobin(ts1.URL, ts2.URL) assertNil(t, err) c := dcnl() defer c.Close() c.SetLoadBalancer(rr) ts1URL, ts2URL := 0, 0 for i := 0; i < 20; i++ { resp, err := c.R().Get("/") assertNil(t, err) switch resp.Request.baseURL { case ts1.URL: ts1URL++ case ts2.URL: ts2URL++ } } assertEqual(t, ts1URL, ts2URL) } func TestLoadBalancerRequestFlowError(t *testing.T) { t.Run("obtain next url error", func(t *testing.T) { wrr, err := NewWeightedRoundRobin(0) assertNil(t, err) c := dcnl() defer c.Close() c.SetLoadBalancer(wrr) resp, err := c.R().Get("/") assertErrorIs(t, ErrNoActiveHost, err) assertNil(t, resp) }) t.Run("round-robin invalid url input", func(t *testing.T) { rr, err := NewRoundRobin("://example.com") assertType(t, url.Error{}, err) assertNotNil(t, rr) wrr, err := NewWeightedRoundRobin(0, &Host{BaseURL: "://example.com"}) assertType(t, url.Error{}, err) assertNotNil(t, wrr) }) t.Run("weighted round-robin invalid url input", func(t *testing.T) { wrr, err := NewWeightedRoundRobin(0, &Host{BaseURL: "://example.com"}) assertType(t, url.Error{}, err) assertNotNil(t, 
wrr) }) } func Test_extractBaseURL(t *testing.T) { for _, tt := range []struct { name string inputURL string expectedURL string expectedErr error }{ { name: "simple relative path", inputURL: "https://resty.dev/welcome", expectedURL: "https://resty.dev", }, { name: "longer relative path with file extension", inputURL: "https://resty.dev/welcome/path/to/remove.html", expectedURL: "https://resty.dev", }, { name: "longer relative path with file extension and query params", inputURL: "https://resty.dev/welcome/path/to/remove.html?a=1&b=2", expectedURL: "https://resty.dev", }, { name: "invalid url input", inputURL: "://resty.dev/welcome", expectedURL: "", expectedErr: &url.Error{Op: "parse", URL: "://resty.dev/welcome", Err: errors.New("missing protocol scheme")}, }, } { t.Run(tt.name, func(t *testing.T) { outputURL, err := extractBaseURL(tt.inputURL) if tt.expectedErr != nil { assertEqual(t, tt.expectedErr, err) } assertEqual(t, tt.expectedURL, outputURL) }) } } func TestLoadBalancerRequestFailures(t *testing.T) { ts1 := createGetServer(t) ts1.Close() ts2 := createGetServer(t) defer ts2.Close() rr, err := NewWeightedRoundRobin(200*time.Millisecond, &Host{BaseURL: ts1.URL, Weight: 50, MaxFailures: 3}, &Host{BaseURL: ts2.URL, Weight: 50}) assertNil(t, err) c := dcnl() defer c.Close() c.SetLoadBalancer(rr) ts1URL, ts2URL := 0, 0 for i := 0; i < 10; i++ { resp, _ := c.R().Get("/") switch resp.Request.baseURL { case ts1.URL: ts1URL++ case ts2.URL: assertError(t, err) ts2URL++ } } assertEqual(t, 3, ts1URL) assertEqual(t, 7, ts2URL) } type mockTimeoutErr struct{} func (e *mockTimeoutErr) Error() string { return "i/o timeout" } func (e *mockTimeoutErr) Timeout() bool { return true } func TestLoadBalancerCoverage(t *testing.T) { t.Run("mock net op timeout error", func(t *testing.T) { wrr, err := NewWeightedRoundRobin(0) assertNil(t, err) c := dcnl() defer c.Close() c.SetLoadBalancer(wrr) req := c.R() netOpErr := &net.OpError{Op: "mock", Net: "mock", Err: &mockTimeoutErr{}} 
req.sendLoadBalancerFeedback(&Response{}, netOpErr) req.sendLoadBalancerFeedback(&Response{RawResponse: &http.Response{ StatusCode: http.StatusInternalServerError, }}, nil) }) } ================================================ FILE: middleware.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "bytes" "context" "fmt" "io" "mime" "mime/multipart" "net/http" "net/textproto" "net/url" "path" "path/filepath" "reflect" "strings" ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Request Middleware(s) //_______________________________________________________________________ // MiddlewareRequestCreate method is used to prepare HTTP requests using the // user-provided request values. It performs the following operations - // - Parse the request URL with path params and query params // - Parse the request headers from client and request level // - Parse the request body based on the content type and body type // - Create the underlying [http.Request] object // - Add credentials such as Basic Auth and Token Auth into the request // // Returns an error if request preparation fails. 
func MiddlewareRequestCreate(c *Client, r *Request) (err error) { if err = parseRequestURL(c, r); err != nil { return err } // no error returned parseRequestHeader(c, r) if err = parseRequestBody(c, r); err != nil { return err } // at this point, possible error from `http.NewRequestWithContext` // is URL-related, and those get caught up in the `parseRequestURL` createRawRequest(c, r) addCredentials(c, r) _ = r.generateCurlCommand() return nil } func parseRequestURL(c *Client, r *Request) error { if len(c.PathParams())+len(r.PathParams) > 0 { // GitHub #103 Path Params, #663 Raw Path Params for p, v := range c.PathParams() { if _, ok := r.PathParams[p]; ok { continue } r.PathParams[p] = v } var prev int buf := acquireBuffer() defer releaseBuffer(buf) // search for the next or first opened curly bracket for curr := strings.Index(r.URL, "{"); curr == 0 || curr > prev; curr = prev + strings.Index(r.URL[prev:], "{") { // write everything from the previous position up to the current if curr > prev { buf.WriteString(r.URL[prev:curr]) } // search for the closed curly bracket from current position next := curr + strings.Index(r.URL[curr:], "}") // if not found, then write the remainder and exit if next < curr { buf.WriteString(r.URL[curr:]) prev = len(r.URL) break } // special case for {}, without parameter's name if next == curr+1 { buf.WriteString("{}") } else { // check for the replacement key := r.URL[curr+1 : next] value, ok := r.PathParams[key] // keep the original string if the replacement not found if !ok { value = r.URL[curr : next+1] } buf.WriteString(value) } // set the previous position after the closed curly bracket prev = next + 1 if prev >= len(r.URL) { break } } if buf.Len() > 0 { // write remainder if prev < len(r.URL) { buf.WriteString(r.URL[prev:]) } r.URL = buf.String() } } // Parsing request URL reqURL, err := url.Parse(r.URL) if err != nil { return &invalidRequestError{Err: err} } // If [Request.URL] is a relative path, then the following // gets 
evaluated in the order // 1. [Client.LoadBalancer] is used to obtain the base URL if not nil // 2. [Client.BaseURL] is used to obtain the base URL // 3. Otherwise [Request.URL] is used as-is if !reqURL.IsAbs() { r.URL = reqURL.String() if len(r.URL) > 0 && r.URL[0] != '/' { r.URL = "/" + r.URL } if r.client.LoadBalancer() != nil { r.baseURL, err = r.client.LoadBalancer().NextWithContext(r.Context()) if err != nil { return &invalidRequestError{Err: err} } } reqURL, err = url.Parse(r.baseURL + r.URL) if err != nil { return &invalidRequestError{Err: err} } } // GH #407 && #318 if reqURL.Scheme == "" && len(c.Scheme()) > 0 { reqURL.Scheme = c.Scheme() } // Adding Query Param if len(c.QueryParams())+len(r.QueryParams) > 0 { for k, v := range c.QueryParams() { if _, ok := r.QueryParams[k]; ok { continue } r.QueryParams[k] = v[:] } // GitHub #123 Preserve query string order partially. // Since not feasible in `SetQuery*` resty methods, because // standard package `url.Encode(...)` sorts the query params // alphabetically if isStringEmpty(reqURL.RawQuery) { reqURL.RawQuery = r.QueryParams.Encode() } else { reqURL.RawQuery = reqURL.RawQuery + "&" + r.QueryParams.Encode() } } // GH#797 Unescape query parameters (non-standard - not recommended) if r.unescapeQueryParams && len(reqURL.RawQuery) > 0 { // at this point, all errors caught up in the above operations // so ignore the return error on query unescape; I realized // while writing the unit test unescapedQuery, _ := url.QueryUnescape(reqURL.RawQuery) reqURL.RawQuery = strings.ReplaceAll(unescapedQuery, " ", "+") // otherwise request becomes bad request } r.URL = reqURL.String() return nil } func parseRequestHeader(c *Client, r *Request) error { for k, v := range c.Header() { if _, ok := r.Header[k]; ok { continue } r.Header[k] = v[:] } if !r.isHeaderExists(hdrUserAgentKey) { r.Header.Set(hdrUserAgentKey, hdrUserAgentValue) } if !r.isHeaderExists(hdrAcceptEncodingKey) { r.Header.Set(hdrAcceptEncodingKey, 
r.client.ContentDecompresserKeys()) } return nil } func parseRequestBody(c *Client, r *Request) error { if r.isMultiPart && !(r.Method == MethodPost || r.Method == MethodPut || r.Method == MethodPatch) { err := fmt.Errorf("resty: multipart is not allowed in HTTP verb: %v", r.Method) return &invalidRequestError{Err: err} } if r.isPayloadSupported() { switch { case r.isMultiPart: // Handling Multipart if err := handleMultipart(c, r); err != nil { return &invalidRequestError{Err: err} } case len(c.FormData()) > 0 || len(r.FormData) > 0: // Handling Form Data handleFormData(c, r) case r.Body != nil: // Handling Request body if err := handleRequestBody(c, r); err != nil { return &invalidRequestError{Err: err} } } } else { r.Body = nil // if the payload is not supported by HTTP verb, set explicit nil } return nil } func createRawRequest(c *Client, r *Request) (err error) { // init client trace if enabled r.initTraceIfEnabled() if r.bodyBuf == nil { if reader, ok := r.Body.(io.Reader); ok { r.RawRequest, err = http.NewRequestWithContext(r.Context(), r.Method, r.URL, reader) } else { r.RawRequest, err = http.NewRequestWithContext(r.Context(), r.Method, r.URL, nil) } } else { r.RawRequest, err = http.NewRequestWithContext(r.Context(), r.Method, r.URL, r.bodyBuf) } if err != nil { return &invalidRequestError{Err: err} } // get the context reference back from underlying RawRequest r.SetContext(r.RawRequest.Context()) // Assign close connection option r.RawRequest.Close = r.IsCloseConnection // Add headers into http request r.RawRequest.Header = r.Header.Clone() // Add cookies from client instance into http request for _, cookie := range c.Cookies() { r.RawRequest.AddCookie(cookie) } // Add cookies from request instance into http request for _, cookie := range r.Cookies { r.RawRequest.AddCookie(cookie) } // Set given content length value into the request if r.isContentLengthSet { r.RawRequest.ContentLength = r.contentLength } else { r.contentLength = r.RawRequest.ContentLength 
} return } func addCredentials(c *Client, r *Request) error { credentialsAdded := false // Basic Auth if r.credentials != nil { credentialsAdded = true r.RawRequest.SetBasicAuth(r.credentials.Username, r.credentials.Password) } // Build the token Auth header if !isStringEmpty(r.AuthToken) { credentialsAdded = true r.RawRequest.Header.Set(r.HeaderAuthorizationKey, strings.TrimSpace(r.AuthScheme+" "+r.AuthToken)) } if !c.IsDisableWarn() && credentialsAdded { if r.RawRequest.URL.Scheme == "http" { r.log.Warnf("Using sensitive credentials in HTTP mode is not secure. Use HTTPS") } } return nil } var multipartWriteField = func(w *multipart.Writer, name, value string) error { return w.WriteField(name, value) } var multipartWriteFormData = func(w *multipart.Writer, r *Request) error { for k, v := range r.FormData { for _, iv := range v { if err := multipartWriteField(w, k, iv); err != nil { return err } } } return nil } var multipartCreatePart = func(w *multipart.Writer, h textproto.MIMEHeader) (io.Writer, error) { return w.CreatePart(h) } var multipartSetBoundary = func(w *multipart.Writer, r *Request) error { if isStringEmpty(r.multipartBoundary) { return nil } return w.SetBoundary(r.multipartBoundary) } func handleMultipartFormData(r *Request) error { r.bodyBuf = acquireBuffer() mw := multipart.NewWriter(r.bodyBuf) defer mw.Close() // set custom multipart boundary if exists if err := multipartSetBoundary(mw, r); err != nil { return err } r.Header.Set(hdrContentTypeKey, mw.FormDataContentType()) return multipartWriteFormData(mw, r) } func handleMultipart(c *Client, r *Request) error { for k, v := range c.FormData() { if _, ok := r.FormData[k]; ok { continue } r.FormData[k] = v[:] } if len(r.multipartFields) == 0 { return handleMultipartFormData(r) } // pre-process multipart fields to catch possible errors for _, mf := range r.multipartFields { if mf.isValues() { continue } if err := mf.openFile(); err != nil { return err } if err := mf.detectContentType(); err != nil { 
return err } } // multipart streaming br, bw := io.Pipe() mw := multipart.NewWriter(bw) r.Body = br // set custom multipart boundary if exists if err := multipartSetBoundary(mw, r); err != nil { closeq(bw) return err } r.Header.Set(hdrContentTypeKey, mw.FormDataContentType()) r.multipartErrChan = make(chan error, 1) go func() { defer close(r.multipartErrChan) defer func() { if err := mw.Close(); err != nil { r.multipartErrChan <- err } if err := bw.Close(); err != nil { r.multipartErrChan <- err } }() if err := multipartWriteFormData(mw, r); err != nil { r.multipartErrChan <- err return } ctx, cancel := context.WithCancel(r.Context()) r.multipartCancelFunc = cancel for _, mf := range r.multipartFields { if mf.isValues() { for _, v := range mf.Values { if err := multipartWriteField(mw, mf.Name, v); err != nil { r.multipartErrChan <- err return } } continue } partWriter, err := multipartCreatePart(mw, mf.createHeader()) if err != nil { r.multipartErrChan <- err return } partWriter = mf.wrapProgressCallbackIfPresent(partWriter) if len(mf.tempBuf) > 0 { if _, err = partWriter.Write(mf.tempBuf); err != nil { r.multipartErrChan <- err return } } reader := &gracefulStopReader{ctx: ctx, r: mf.Reader} if _, err = ioCopy(partWriter, reader); err != nil { r.multipartErrChan <- err return } } }() return nil } func handleFormData(c *Client, r *Request) { for k, v := range c.FormData() { if _, ok := r.FormData[k]; ok { continue } r.FormData[k] = v[:] } r.bodyBuf = acquireBuffer() r.bodyBuf.WriteString(r.FormData.Encode()) r.Header.Set(hdrContentTypeKey, formContentType) r.isFormData = true } func handleRequestBody(c *Client, r *Request) error { contentType := strings.ToLower(r.Header.Get(hdrContentTypeKey)) if isStringEmpty(contentType) { // it is highly recommended that the user provide a request content-type // so that we can minimize memory allocation and compute. 
contentType = detectContentType(r.Body) } if !r.isHeaderExists(hdrContentTypeKey) { r.Header.Set(hdrContentTypeKey, contentType) } r.bodyBuf = acquireBuffer() switch body := r.Body.(type) { case io.Reader: // Resty v3 onwards io.Reader used as-is with the request body. releaseBuffer(r.bodyBuf) r.bodyBuf = nil // enable multiple reads if body is *bytes.Buffer if b, ok := r.Body.(*bytes.Buffer); ok { v := b.Bytes() r.Body = bytes.NewReader(v) } // do seek start for retry attempt if io.ReadSeeker // interface supported if r.Attempt > 1 { if rs, ok := r.Body.(io.ReadSeeker); ok { _, _ = rs.Seek(0, io.SeekStart) } } return nil case []byte: r.bodyBuf.Write(body) case string: r.bodyBuf.Write([]byte(body)) default: encKey := inferContentTypeMapKey(contentType) if jsonKey == encKey { if !r.jsonEscapeHTML { return encodeJSONEscapeHTML(r.bodyBuf, r.Body, r.jsonEscapeHTML) } } else if xmlKey == encKey { if inferKind(r.Body) != reflect.Struct { releaseBuffer(r.bodyBuf) r.bodyBuf = nil return ErrUnsupportedRequestBodyKind } } // user registered encoders with resty fallback key encFunc, found := c.inferContentTypeEncoder(contentType, encKey) if !found { releaseBuffer(r.bodyBuf) r.bodyBuf = nil return fmt.Errorf("resty: content-type encoder not found for %s", contentType) } if err := encFunc(r.bodyBuf, r.Body); err != nil { releaseBuffer(r.bodyBuf) r.bodyBuf = nil return err } } return nil } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Response Middleware(s) //_______________________________________________________________________ // MiddlewareResponseAutoParse method is used to parse the response body automatically // based on the registered HTTP response `Content-Type` decoder, see [Client.AddContentTypeDecoder]; // if [Request.SetResult], [Request.SetResultError], or [Client.SetResultError] is used, it performs // the auto unmarshalling into the respective object. 
func MiddlewareResponseAutoParse(c *Client, res *Response) (err error) { if (res.CascadeError != nil && (res.Request.isMultiPart && res.StatusCode() == 0)) || res.Request.IsResponseDoNotParse { return // move on } if res.StatusCode() == http.StatusNoContent { res.Request.ResultError = nil return } rct := strings.ToLower(firstNonEmpty( res.Request.ResponseForceContentType, res.Header().Get(hdrContentTypeKey), res.Request.ResponseExpectContentType, )) decKey := inferContentTypeMapKey(rct) decFunc, found := c.inferContentTypeDecoder(rct, decKey) if !found { // the Content-Type decoder is not found; just read all the body bytes err = res.readAll() return } // HTTP status code > 199 and < 300, considered as Result if res.IsStatusSuccess() && res.Request.Result != nil { res.Request.ResultError = nil defer closeq(res.Body) err = decFunc(res.Body, res.Request.Result) res.IsRead = true return } // HTTP status code > 399, considered as Error if res.IsStatusFailure() { // global error type registered at client-instance if res.Request.ResultError == nil { res.Request.ResultError = c.newErrorInterface() } if res.Request.ResultError != nil { defer closeq(res.Body) err = decFunc(res.Body, res.Request.ResultError) res.IsRead = true return } } return } var hostnameReplacer = strings.NewReplacer(":", "_", ".", "_") // MiddlewareResponseSaveToFile method used to write HTTP response body into // file. 
The filename is determined in the following order - // - [Request.SetResponseSaveFileName] // - Content-Disposition header // - Request URL using [path.Base] func MiddlewareResponseSaveToFile(c *Client, res *Response) error { if res.CascadeError != nil || !res.Request.IsResponseSaveToFile { return nil } file := res.Request.ResponseSaveFileName if isStringEmpty(file) { cntDispositionValue := res.Header().Get(hdrContentDisposition) if len(cntDispositionValue) > 0 { if _, params, err := mime.ParseMediaType(cntDispositionValue); err == nil { file = params["filename"] } } if isStringEmpty(file) { rURL, _ := url.Parse(res.Request.URL) if isStringEmpty(rURL.Path) || rURL.Path == "/" { file = hostnameReplacer.Replace(rURL.Host) } else { file = path.Base(rURL.Path) } } } if len(c.ResponseSaveDirectory()) > 0 && !filepath.IsAbs(file) { file = filepath.Join(c.ResponseSaveDirectory(), string(filepath.Separator), file) } file = filepath.Clean(file) if err := createDirectory(filepath.Dir(file)); err != nil { return err } outFile, err := createFile(file) if err != nil { return err } defer func() { closeq(outFile) closeq(res.Body) }() // io.Copy reads maximum 32kb size, it is perfect for large file download too res.size, err = ioCopy(outFile, res.Body) return err } ================================================ FILE: middleware_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. 
// SPDX-License-Identifier: MIT package resty import ( "bytes" "encoding/json" "errors" "io" "mime" "mime/multipart" "net/http" "net/textproto" "net/url" "os" "path/filepath" "reflect" "strings" "sync" "testing" ) func Test_parseRequestURL(t *testing.T) { for _, tt := range []struct { name string initClient func(c *Client) initRequest func(r *Request) expectedURL string }{ { name: "apply client path parameters", initClient: func(c *Client) { c.SetPathParams(map[string]string{ "foo": "1", "bar": "2/3", }) }, initRequest: func(r *Request) { r.URL = "https://example.com/{foo}/{bar}" }, expectedURL: "https://example.com/1/2%2F3", }, { name: "apply request path parameters", initRequest: func(r *Request) { r.SetPathParams(map[string]string{ "foo": "4", "bar": "5/6", }) r.URL = "https://example.com/{foo}/{bar}" }, expectedURL: "https://example.com/4/5%2F6", }, { name: "apply request and client path parameters", initClient: func(c *Client) { c.SetPathParams(map[string]string{ "foo": "1", // ignored, because of the request's "foo" "bar": "2/3", }) }, initRequest: func(r *Request) { r.SetPathParams(map[string]string{ "foo": "4/5", }) r.URL = "https://example.com/{foo}/{bar}" }, expectedURL: "https://example.com/4%2F5/2%2F3", }, { name: "apply client raw path parameters", initClient: func(c *Client) { c.SetPathRawParams(map[string]string{ "foo": "1/2", "bar": "3", }) }, initRequest: func(r *Request) { r.URL = "https://example.com/{foo}/{bar}" }, expectedURL: "https://example.com/1/2/3", }, { name: "apply request raw path parameters", initRequest: func(r *Request) { r.SetPathRawParams(map[string]string{ "foo": "4", "bar": "5/6", }) r.URL = "https://example.com/{foo}/{bar}" }, expectedURL: "https://example.com/4/5/6", }, { name: "apply request and client raw path parameters", initClient: func(c *Client) { c.SetPathRawParams(map[string]string{ "foo": "1", // ignored, because of the request's "foo" "bar": "2/3", }) }, initRequest: func(r *Request) { 
r.SetPathRawParams(map[string]string{ "foo": "4/5", }) r.URL = "https://example.com/{foo}/{bar}" }, expectedURL: "https://example.com/4/5/2/3", }, { name: "apply request path and raw path parameters", initRequest: func(r *Request) { r.SetPathParams(map[string]string{ "foo": "4/5", }).SetPathRawParams(map[string]string{ "foo": "4/5", // it gets overwritten since same key name "bar": "6/7", }) r.URL = "https://example.com/{foo}/{bar}" }, expectedURL: "https://example.com/4/5/6/7", }, { name: "empty path parameter in URL", initRequest: func(r *Request) { r.SetPathParams(map[string]string{ "bar": "4", }) r.URL = "https://example.com/{}/{bar}" }, expectedURL: "https://example.com/%7B%7D/4", }, { name: "not closed path parameter in URL", initRequest: func(r *Request) { r.SetPathParams(map[string]string{ "foo": "4", }) r.URL = "https://example.com/{foo}/{bar/1" }, expectedURL: "https://example.com/4/%7Bbar/1", }, { name: "extra path parameter in URL", initRequest: func(r *Request) { r.SetPathParams(map[string]string{ "foo": "1", }) r.URL = "https://example.com/{foo}/{bar}" }, expectedURL: "https://example.com/1/%7Bbar%7D", }, { name: " path parameter with remainder", initRequest: func(r *Request) { r.SetPathParams(map[string]string{ "foo": "1", }) r.URL = "https://example.com/{foo}/2" }, expectedURL: "https://example.com/1/2", }, { name: "using base url with path param at index 0", initClient: func(c *Client) { c.SetBaseURL("https://example.com/prefix") }, initRequest: func(r *Request) { r.SetPathParam("first", "1"). 
SetPathParam("second", "2") r.URL = "{first}/{second}" }, expectedURL: "https://example.com/prefix/1/2", }, { name: "using BaseURL with absolute URL in request", initClient: func(c *Client) { c.SetBaseURL("https://foo.bar") // ignored }, initRequest: func(r *Request) { r.URL = "https://example.com/" }, expectedURL: "https://example.com/", }, { name: "using BaseURL with relative path in request URL without leading slash", initClient: func(c *Client) { c.SetBaseURL("https://example.com") }, initRequest: func(r *Request) { r.URL = "foo/bar" }, expectedURL: "https://example.com/foo/bar", }, { name: "using BaseURL with relative path in request URL with leading slash", initClient: func(c *Client) { c.SetBaseURL("https://example.com") }, initRequest: func(r *Request) { r.URL = "/foo/bar" }, expectedURL: "https://example.com/foo/bar", }, { name: "using deprecated HostURL with relative path in request URL", initClient: func(c *Client) { c.SetBaseURL("https://example.com") }, initRequest: func(r *Request) { r.URL = "foo/bar" }, expectedURL: "https://example.com/foo/bar", }, { name: "request URL without scheme", initRequest: func(r *Request) { r.URL = "example.com/foo/bar" }, expectedURL: "/example.com/foo/bar", }, { name: "BaseURL without scheme", initClient: func(c *Client) { c.SetBaseURL("example.com") }, initRequest: func(r *Request) { r.URL = "foo/bar" }, expectedURL: "example.com/foo/bar", }, { name: "using SetScheme and BaseURL without scheme", initClient: func(c *Client) { c.SetBaseURL("example.com"). 
SetScheme("https") }, initRequest: func(r *Request) { r.URL = "foo/bar" }, expectedURL: "https://example.com/foo/bar", }, { name: "adding query parameters by client", initClient: func(c *Client) { c.SetQueryParams(map[string]string{ "foo": "1", "bar": "2", }) }, initRequest: func(r *Request) { r.URL = "https://example.com/" }, expectedURL: "https://example.com/?foo=1&bar=2", }, { name: "adding query parameters by request", initRequest: func(r *Request) { r.SetQueryParams(map[string]string{ "foo": "1", "bar": "2", }) r.URL = "https://example.com/" }, expectedURL: "https://example.com/?foo=1&bar=2", }, { name: "adding query parameters by client and request", initClient: func(c *Client) { c.SetQueryParams(map[string]string{ "foo": "1", // ignored, because of the "foo" parameter in request "bar": "2", }) }, initRequest: func(r *Request) { r.SetQueryParams(map[string]string{ "foo": "3", }) r.URL = "https://example.com/" }, expectedURL: "https://example.com/?foo=3&bar=2", }, { name: "adding query parameters by request to URL with existent", initRequest: func(r *Request) { r.SetQueryParams(map[string]string{ "bar": "2", }) r.URL = "https://example.com/?foo=1" }, expectedURL: "https://example.com/?foo=1&bar=2", }, { name: "adding query parameters by request with multiple values", initRequest: func(r *Request) { r.QueryParams.Add("foo", "1") r.QueryParams.Add("foo", "2") r.URL = "https://example.com/" }, expectedURL: "https://example.com/?foo=1&foo=2", }, { name: "unescape query params", initClient: func(c *Client) { c.SetBaseURL("https://example.com/"). SetQueryParamsUnescape(true). // this line is just code coverage; I will restructure this test in v3 for the client and request the respective init method SetQueryParam("fromclient", "hey unescape"). 
SetQueryParam("initone", "cáfe") }, initRequest: func(r *Request) { r.SetQueryParamsUnescape(true) // this line takes effect r.SetQueryParams( map[string]string{ "registry": "nacos://test:6801", // GH #797 }, ) }, expectedURL: "https://example.com?initone=cáfe&fromclient=hey+unescape®istry=nacos://test:6801", }, } { t.Run(tt.name, func(t *testing.T) { c := New() if tt.initClient != nil { tt.initClient(c) } r := c.R() if tt.initRequest != nil { tt.initRequest(r) } if err := parseRequestURL(c, r); err != nil { t.Errorf("parseRequestURL() error = %v", err) } // compare URLs without query parameters first // then compare query parameters, because the order of the items in a map is not guarantied expectedURL, _ := url.Parse(tt.expectedURL) expectedQuery := expectedURL.Query() expectedURL.RawQuery = "" actualURL, _ := url.Parse(r.URL) actualQuery := actualURL.Query() actualURL.RawQuery = "" if expectedURL.String() != actualURL.String() { t.Errorf("r.URL = %q does not match expected %q", r.URL, tt.expectedURL) } if !reflect.DeepEqual(expectedQuery, actualQuery) { t.Errorf("r.URL = %q does not match expected %q", r.URL, tt.expectedURL) } }) } } func Test_parseRequestHeader(t *testing.T) { for _, tt := range []struct { name string init func(c *Client, r *Request) expectedHeader http.Header }{ { name: "headers in request", init: func(c *Client, r *Request) { r.SetHeaders(map[string]string{ "foo": "1", "bar": "2", }) }, expectedHeader: http.Header{ http.CanonicalHeaderKey("foo"): []string{"1"}, http.CanonicalHeaderKey("bar"): []string{"2"}, hdrUserAgentKey: []string{hdrUserAgentValue}, }, }, { name: "headers in client", init: func(c *Client, r *Request) { c.SetHeaders(map[string]string{ "foo": "1", "bar": "2", }) }, expectedHeader: http.Header{ http.CanonicalHeaderKey("foo"): []string{"1"}, http.CanonicalHeaderKey("bar"): []string{"2"}, hdrUserAgentKey: []string{hdrUserAgentValue}, }, }, { name: "headers in client and request", init: func(c *Client, r *Request) { 
c.SetHeaders(map[string]string{ "foo": "1", // ignored, because of the same header in the request "bar": "2", }) r.SetHeaders(map[string]string{ "foo": "3", "xyz": "4", }) }, expectedHeader: http.Header{ http.CanonicalHeaderKey("foo"): []string{"3"}, http.CanonicalHeaderKey("bar"): []string{"2"}, http.CanonicalHeaderKey("xyz"): []string{"4"}, hdrUserAgentKey: []string{hdrUserAgentValue}, }, }, { name: "no headers", init: func(c *Client, r *Request) {}, expectedHeader: http.Header{ hdrUserAgentKey: []string{hdrUserAgentValue}, }, }, { name: "user agent", init: func(c *Client, r *Request) { c.SetHeader(hdrUserAgentKey, "foo bar") }, expectedHeader: http.Header{ http.CanonicalHeaderKey(hdrUserAgentKey): []string{"foo bar"}, }, }, { name: "json content type", init: func(c *Client, r *Request) { c.SetHeader(hdrContentTypeKey, "application/json") }, expectedHeader: http.Header{ hdrContentTypeKey: []string{"application/json"}, hdrUserAgentKey: []string{hdrUserAgentValue}, }, }, } { t.Run(tt.name, func(t *testing.T) { c := New() r := c.R() tt.init(c, r) // add common expected headers from client into expectedHeader tt.expectedHeader.Set(hdrAcceptEncodingKey, c.ContentDecompresserKeys()) if err := parseRequestHeader(c, r); err != nil { t.Errorf("parseRequestHeader() error = %v", err) } if !reflect.DeepEqual(tt.expectedHeader, r.Header) { t.Errorf("r.Header = %#+v does not match expected %#+v", r.Header, tt.expectedHeader) } }) } } func TestParseRequestBody(t *testing.T) { for _, tt := range []struct { name string initClient func(c *Client) initRequest func(r *Request) expectedBodyBuf []byte expectedContentLength int64 expectedContentType string wantErr bool }{ { name: "empty body", }, { name: "empty body with SetContentLength by request", expectedContentLength: 0, }, { name: "string body", initRequest: func(r *Request) { r.SetMethod(MethodPost). 
SetBody("foo") }, expectedBodyBuf: []byte("foo"), expectedContentType: plainTextType, expectedContentLength: 3, }, { name: "string body with GET method", initRequest: func(r *Request) { r.SetBody("foo") r.Method = http.MethodGet }, }, { name: "string body with GET method and AllowMethodGetPayload by client", initClient: func(c *Client) { c.SetMethodGetAllowPayload(true) }, initRequest: func(r *Request) { r.SetBody("foo") r.Method = http.MethodGet }, expectedBodyBuf: []byte("foo"), expectedContentType: plainTextType, expectedContentLength: 3, }, { name: "string body with GET method and AllowMethodGetPayload by request", initRequest: func(r *Request) { r.SetMethodGetAllowPayload(true) r.SetBody("foo") r.Method = http.MethodGet }, expectedBodyBuf: []byte("foo"), expectedContentType: plainTextType, expectedContentLength: 3, }, { name: "string body with HEAD method", initRequest: func(r *Request) { r.SetBody("foo") r.Method = http.MethodHead }, }, { name: "string body with OPTIONS method", initRequest: func(r *Request) { r.SetBody("foo") r.Method = http.MethodOptions }, }, { name: "string body with POST method", initRequest: func(r *Request) { r.SetBody("foo") r.Method = http.MethodPost }, expectedBodyBuf: []byte("foo"), expectedContentType: plainTextType, expectedContentLength: 3, }, { name: "string body with PATCH method", initRequest: func(r *Request) { r.SetBody("foo") r.Method = http.MethodPatch }, expectedBodyBuf: []byte("foo"), expectedContentType: plainTextType, expectedContentLength: 3, }, { name: "string body with PUT method", initRequest: func(r *Request) { r.SetBody("foo") r.Method = http.MethodPut }, expectedBodyBuf: []byte("foo"), expectedContentType: plainTextType, expectedContentLength: 3, }, { name: "string body with DELETE method", initRequest: func(r *Request) { r.SetBody("foo") r.Method = http.MethodDelete }, expectedBodyBuf: nil, expectedContentType: "", }, { name: "string body with DELETE method with AllowMethodDeletePayload by request", 
initRequest: func(r *Request) { r.SetMethodDeleteAllowPayload(true) r.SetBody("foo") r.Method = http.MethodDelete }, expectedBodyBuf: []byte("foo"), expectedContentType: plainTextType, expectedContentLength: 3, }, { name: "string body with CONNECT method", initRequest: func(r *Request) { r.SetBody("foo") r.Method = http.MethodConnect }, expectedBodyBuf: nil, expectedContentType: "", }, { name: "string body with TRACE method", initRequest: func(r *Request) { r.SetBody("foo") r.Method = http.MethodTrace }, expectedBodyBuf: nil, expectedContentType: "", }, { name: "byte body with method post", initRequest: func(r *Request) { r.SetMethod(MethodPost). SetBody([]byte("foo")) }, expectedBodyBuf: []byte("foo"), expectedContentType: plainTextType, expectedContentLength: 3, }, { name: "io.Reader body, no bodyBuf with method put", initRequest: func(r *Request) { r.SetMethod(MethodPut). SetBody(bytes.NewBufferString("foo")) }, expectedContentType: jsonContentType, }, { name: "form data by request with method post", initRequest: func(r *Request) { r.SetMethod(MethodPost). SetFormData(map[string]string{ "foo": "1", "bar": "2", }) }, expectedBodyBuf: []byte("foo=1&bar=2"), expectedContentType: formContentType, expectedContentLength: 11, }, { name: "form data by client with method patch", initClient: func(c *Client) { c.SetFormData(map[string]string{ "foo": "1", "bar": "2", }) }, initRequest: func(r *Request) { r.SetMethod(MethodPatch) }, expectedBodyBuf: []byte("foo=1&bar=2"), expectedContentType: formContentType, expectedContentLength: 11, }, { name: "form data by client and request", initClient: func(c *Client) { c.SetFormData(map[string]string{ "foo": "1", "bar": "2", }) }, initRequest: func(r *Request) { r.SetMethod(MethodPatch). 
SetFormData(map[string]string{ "foo": "3", "baz": "4", }) }, expectedBodyBuf: []byte("foo=3&bar=2&baz=4"), expectedContentType: formContentType, expectedContentLength: 17, }, { name: "json from struct", initRequest: func(r *Request) { r.SetMethod(MethodPut) r.SetBody(struct { Foo string `json:"foo"` Bar string `json:"bar"` }{ Foo: "1", Bar: "2", }) }, expectedBodyBuf: append([]byte(`{"foo":"1","bar":"2"}`), '\n'), expectedContentType: jsonContentType, expectedContentLength: 22, }, { name: "json from slice", initRequest: func(r *Request) { r.SetMethod(MethodPost). SetBody([]string{"foo", "bar"}) }, expectedBodyBuf: append([]byte(`["foo","bar"]`), '\n'), expectedContentType: jsonContentType, expectedContentLength: 14, }, { name: "json from map", initRequest: func(r *Request) { r.SetMethod(MethodPost). SetBody(map[string]any{ "foo": "1", "bar": []int{1, 2, 3}, "baz": map[string]string{ "qux": "4", }, "xyz": nil, }) }, expectedBodyBuf: append([]byte(`{"bar":[1,2,3],"baz":{"qux":"4"},"foo":"1","xyz":null}`), '\n'), expectedContentType: jsonContentType, expectedContentLength: 55, }, { name: "json from map", initRequest: func(r *Request) { r.SetMethod(MethodPut). SetBody(map[string]any{ "foo": "1", "bar": []int{1, 2, 3}, "baz": map[string]string{ "qux": "4", }, "xyz": nil, }) }, expectedBodyBuf: append([]byte(`{"bar":[1,2,3],"baz":{"qux":"4"},"foo":"1","xyz":null}`), '\n'), expectedContentType: jsonContentType, expectedContentLength: 55, }, { name: "json from map", initRequest: func(r *Request) { r.SetMethod(MethodPost). SetBody(map[string]any{ "foo": "1", "bar": []int{1, 2, 3}, "baz": map[string]string{ "qux": "4", }, "xyz": nil, }) }, expectedBodyBuf: append([]byte(`{"bar":[1,2,3],"baz":{"qux":"4"},"foo":"1","xyz":null}`), '\n'), expectedContentType: jsonContentType, expectedContentLength: 55, }, { name: "xml from struct", initRequest: func(r *Request) { type FooBar struct { Foo string `xml:"foo"` Bar string `xml:"bar"` } r.SetMethod(MethodPatch). 
SetBody(FooBar{ Foo: "1", Bar: "2", }). SetHeader(hdrContentTypeKey, "text/xml") }, expectedBodyBuf: []byte(`12`), expectedContentType: "text/xml", expectedContentLength: 41, }, { name: "unsupported type", initRequest: func(r *Request) { r.SetMethod(MethodPost). SetBody(1) }, wantErr: true, }, { name: "unsupported xml", initRequest: func(r *Request) { r.SetMethod(MethodPut). SetBody(struct { Foo string `xml:"foo"` Bar string `xml:"bar"` }{ Foo: "1", Bar: "2", }). SetHeader(hdrContentTypeKey, "text/xml") }, wantErr: true, }, } { t.Run(tt.name, func(t *testing.T) { c := New() if tt.initClient != nil { tt.initClient(c) } r := c.R() if tt.initRequest != nil { tt.initRequest(r) } if err := parseRequestBody(c, r); err != nil { if tt.wantErr { return } t.Errorf("parseRequestBody() error = %v", err) } else if tt.wantErr { t.Errorf("wanted error, but got nil") } // obtain value, since this is only parse request body method test if r.bodyBuf != nil { r.contentLength = int64(r.bodyBuf.Len()) } switch { case r.bodyBuf == nil && tt.expectedBodyBuf != nil: t.Errorf("bodyBuf is nil, but expected: %s", string(tt.expectedBodyBuf)) case r.bodyBuf != nil && tt.expectedBodyBuf == nil: t.Errorf("bodyBuf is not nil, but expected nil: %s", r.bodyBuf.String()) case r.bodyBuf != nil && tt.expectedBodyBuf != nil: var actual, expected any = r.bodyBuf.Bytes(), tt.expectedBodyBuf if r.isFormData { var err error actual, err = url.ParseQuery(r.bodyBuf.String()) if err != nil { t.Errorf("ParseQuery(r.bodyBuf) error = %v", err) } expected, err = url.ParseQuery(string(tt.expectedBodyBuf)) if err != nil { t.Errorf("ParseQuery(tt.expectedBodyBuf) error = %v", err) } } else if r.isMultiPart { _, params, err := mime.ParseMediaType(r.Header.Get(hdrContentTypeKey)) if err != nil { t.Errorf("ParseMediaType(hdrContentTypeKey) error = %v", err) } boundary, ok := params["boundary"] if !ok { t.Errorf("boundary not found in Content-Type header") } reader := multipart.NewReader(r.bodyBuf, boundary) body := 
make(map[string]any)
					// Collect every multipart part into a name -> content map so
					// it can be compared against the JSON-encoded expectation.
					for part, perr := reader.NextPart(); perr != io.EOF; part, perr = reader.NextPart() {
						if perr != nil {
							t.Errorf("NextPart() error = %v", perr)
						}
						name := part.FormName()
						if name == "" {
							name = part.FileName()
						}
						data, err := io.ReadAll(part)
						if err != nil {
							t.Errorf("ReadAll(part) error = %v", err)
						}
						body[name] = string(data)
					}
					actual = body
					expected = nil
					if err := json.Unmarshal(tt.expectedBodyBuf, &expected); err != nil {
						t.Errorf("json.Unmarshal(tt.expectedBodyBuf) error = %v", err)
					}
					t.Logf(`in case of an error, the expected body should be set as json for object: %#+v`, actual)
				}
				if !reflect.DeepEqual(actual, expected) {
					t.Errorf("bodyBuf = %q does not match expected %q", r.bodyBuf.String(), string(tt.expectedBodyBuf))
				}
			}
			if tt.expectedContentLength != r.contentLength {
				t.Errorf("Content length value = %v does not match expected %v", r.contentLength, tt.expectedContentLength)
			}
			if ct := r.Header.Get(hdrContentTypeKey); !((tt.expectedContentType == "" && ct != "") || strings.Contains(ct, tt.expectedContentType)) {
				t.Errorf("Content-Type header = %q does not match expected %q", r.Header.Get(hdrContentTypeKey), tt.expectedContentType)
			}
		})
	}
}

// TestMiddlewareSaveToFileErrorCases stubs the package-level mkdirAll and
// createFile hooks so that both fail, then checks that
// MiddlewareResponseSaveToFile returns the directory-creation error and the
// file-creation error respectively. The real os functions are restored on
// cleanup.
func TestMiddlewareSaveToFileErrorCases(t *testing.T) {
	c := dcnl()
	tempDir := t.TempDir()

	errDirMsg := "test dir error"
	mkdirAll = func(_ string, _ os.FileMode) error {
		return errors.New(errDirMsg)
	}
	errFileMsg := "test file error"
	createFile = func(_ string) (*os.File, error) {
		return nil, errors.New(errFileMsg)
	}
	t.Cleanup(func() {
		mkdirAll = os.MkdirAll
		createFile = os.Create
	})

	// dir create error
	req1 := c.R()
	req1.SetResponseSaveFileName(filepath.Join(tempDir, "new-res-dir", "sample.txt"))
	err1 := MiddlewareResponseSaveToFile(c, &Response{Request: req1})
	assertEqual(t, errDirMsg, err1.Error())

	// file create error
	req2 := c.R()
	req2.SetResponseSaveFileName(filepath.Join(tempDir, "sample.txt"))
	err2 := MiddlewareResponseSaveToFile(c, &Response{Request: req2})
	assertEqual(t, errFileMsg,
err2.Error())
}

// TestMiddlewareSaveToFileCopyError replaces the package-level ioCopy hook
// with a failing stub and checks that MiddlewareResponseSaveToFile returns
// that copy error. io.Copy is restored on cleanup.
func TestMiddlewareSaveToFileCopyError(t *testing.T) {
	c := dcnl()
	tempDir := t.TempDir()

	errCopyMsg := "test copy error"
	ioCopy = func(dst io.Writer, src io.Reader) (written int64, err error) {
		return 0, errors.New(errCopyMsg)
	}
	t.Cleanup(func() {
		ioCopy = io.Copy
	})

	// copy error
	req1 := c.R()
	req1.SetResponseSaveFileName(filepath.Join(tempDir, "new-res-dir", "sample.txt"))
	err1 := MiddlewareResponseSaveToFile(c, &Response{Request: req1, Body: io.NopCloser(bytes.NewBufferString("Test context"))})
	assertEqual(t, errCopyMsg, err1.Error())
}

// TestRequestURL_GH797 sends query parameters with unescaping enabled on both
// the client and the request (regression test for GH #797, where the value
// contains "://") and expects the test server to accept the query string.
func TestRequestURL_GH797(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()

	c := dcnl().
		SetBaseURL(ts.URL).
		SetQueryParamsUnescape(true). // this line is just code coverage; I will restructure this test in v3 for the client and request the respective init method
		SetQueryParam("fromclient", "hey unescape").
		SetQueryParam("initone", "cáfe")

	resp, err := c.R().
		SetQueryParamsUnescape(true). // this line takes effect
		SetQueryParams(
			map[string]string{
				"registry": "nacos://test:6801", // GH #797
			},
		).
Get("/unescape-query-params")

	assertError(t, err)
	assertEqual(t, "query params looks good", resp.String())
}

// TestMiddleware_multipartWriteFormData stubs multipartWriteFormData to fail
// and checks that handleMultipart itself returns nil while the failure is
// delivered on req.multipartErrChan.
func TestMiddleware_multipartWriteFormData(t *testing.T) {
	c := dcnl()

	oldFunc := multipartWriteFormData
	errMsg := "test write form data error"
	multipartWriteFormData = func(*multipart.Writer, *Request) error {
		return errors.New(errMsg)
	}
	t.Cleanup(func() {
		multipartWriteFormData = oldFunc
	})

	req := &Request{
		Header:      http.Header{},
		isMultiPart: true,
		multipartFields: []*MultipartField{
			{
				Name:   "field1",
				Values: []string{"field1value1", "field1value2"},
			},
		},
	}
	err := handleMultipart(c, req)
	assertNil(t, err)

	// The write error is reported on the request's error channel.
	err = <-req.multipartErrChan
	assertNotNil(t, err)
	assertEqual(t, errMsg, err.Error())
}

// TestMiddleware_multipartWriteField stubs multipartWriteField to fail and
// checks that the failure is delivered on req.multipartErrChan.
func TestMiddleware_multipartWriteField(t *testing.T) {
	c := dcnl()

	oldFunc := multipartWriteField
	errMsg := "test write field error"
	multipartWriteField = func(w *multipart.Writer, name, value string) error {
		return errors.New(errMsg)
	}
	t.Cleanup(func() {
		multipartWriteField = oldFunc
	})

	req := &Request{
		mu:          new(sync.Mutex),
		Header:      http.Header{},
		isMultiPart: true,
		multipartFields: []*MultipartField{
			{
				Name:   "field1",
				Values: []string{"field1value1", "field1value2"},
			},
		},
	}
	err := handleMultipart(c, req)
	assertNil(t, err)

	err = <-req.multipartErrChan
	assertNotNil(t, err)
	assertEqual(t, errMsg, err.Error())
}

// TestMiddleware_multipartCreatePart stubs multipartCreatePart to fail and
// checks that the failure is delivered on req.multipartErrChan.
func TestMiddleware_multipartCreatePart(t *testing.T) {
	c := dcnl()

	oldFunc := multipartCreatePart
	errMsg := "test create part error"
	multipartCreatePart = func(w *multipart.Writer, h textproto.MIMEHeader) (io.Writer, error) {
		return nil, errors.New(errMsg)
	}
	t.Cleanup(func() {
		multipartCreatePart = oldFunc
	})

	jsonStr1 := `{"input": {"name": "Uploaded document 1", "_filename" : ["file1.txt"]}}`

	req := &Request{
		mu:          new(sync.Mutex),
		Header:      http.Header{},
		isMultiPart: true,
		multipartFields: []*MultipartField{
			{
				Name:        "uploadManifest1",
				FileName:    "upload-file-1.json",
				ContentType: "application/json",
				Reader:      bytes.NewBufferString(jsonStr1),
			},
		},
	}
	err :=
handleMultipart(c, req)
	assertNil(t, err)

	err = <-req.multipartErrChan
	assertNotNil(t, err)
	assertEqual(t, errMsg, err.Error())
}

// TestMiddleware_multipartCreatePart_WriteError makes multipartCreatePart
// return a writer whose Write always fails (mpWriterError) and checks that
// this write error is delivered on req.multipartErrChan.
func TestMiddleware_multipartCreatePart_WriteError(t *testing.T) {
	c := dcnl()

	oldFunc := multipartCreatePart
	multipartCreatePart = func(w *multipart.Writer, h textproto.MIMEHeader) (io.Writer, error) {
		return &mpWriterError{}, nil
	}
	t.Cleanup(func() {
		multipartCreatePart = oldFunc
	})

	jsonStr1 := `{"input": {"name": "Uploaded document 1", "_filename" : ["file1.txt"]}}`

	req := &Request{
		mu:          new(sync.Mutex),
		Header:      http.Header{},
		isMultiPart: true,
		multipartFields: []*MultipartField{
			{
				Name:        "uploadManifest1",
				FileName:    "upload-file-1.json",
				ContentType: "application/json",
				Reader:      bytes.NewBufferString(jsonStr1),
				// Pre-populated so there are buffered bytes for the failing
				// writer to receive.
				tempBuf: []byte("test data"),
			},
		},
	}
	err := handleMultipart(c, req)
	assertNil(t, err)

	err = <-req.multipartErrChan
	assertNotNil(t, err)
	assertEqual(t, "multipart write error", err.Error())
}

// TestMiddlewareCoverage checks that createRawRequest rejects a URL that
// contains a space in the host with an "invalid character" error.
func TestMiddlewareCoverage(t *testing.T) {
	c := dcnl()

	req1 := c.R()
	req1.URL = "//invalid-url  .local"
	err1 := createRawRequest(c, req1)
	assertTrue(t, strings.Contains(err1.Error(), "invalid character"), "invalid URL error expected")
}

================================================
FILE: multipart.go
================================================
// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT

package resty

import (
	"fmt"
	"io"
	"net/http"
	"net/textproto"
	"os"
	"path/filepath"
	"strings"
)

// quoteEscaper backslash-escapes backslashes and double quotes so a value can
// be embedded inside a quoted Content-Disposition parameter.
var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"")

// escapeQuotes returns s with every backslash and double quote escaped.
func escapeQuotes(s string) string {
	return quoteEscaper.Replace(s)
}

// MultipartField struct represents a multipart field used to compose
// any [io.Reader] capable input for a multipart form request.
type MultipartField struct {
	// Name is the multipart field name that the server expects.
	Name string

	// FileName is the file name sent to the server.
	FileName string

	// ContentType is the multipart file content-type value. It is highly
	// recommended to set it when the content type is known so that Resty
	// doesn't have to do additional work to auto-detect it (Optional)
	ContentType string

	// Reader is an input of [io.Reader] for multipart upload. It
	// is optional if you set the FilePath value
	Reader io.Reader

	// FilePath is a file path for multipart upload. It
	// is optional if you set the Reader value
	FilePath string

	// FileSize in bytes is used only for informational purposes; it is
	// shared via [MultipartFieldCallbackFunc] (Optional)
	FileSize int64

	// ProgressCallback function is used to provide live progress details
	// during a multipart upload (Optional)
	//
	// NOTE: It is recommended to set the FileSize value when using `MultipartField.Reader`
	// with the `ProgressCallback` feature so that Resty sends the FileSize
	// value via [MultipartFieldProgress]
	ProgressCallback MultipartFieldCallbackFunc

	// Values field is used to provide form field values. (Optional, unless it's a form-data field)
	//
	// It is primarily added for ordered multipart form-data field use cases
	Values []string

	// tempBuf preserves the byte(s) read from the input while detecting the
	// content type, and lets any read error surface early.
	tempBuf []byte
}

// Clone method returns a copy of mf; the [io.Reader] is shared, not duplicated.
func (mf *MultipartField) Clone() *MultipartField { mf2 := new(MultipartField) *mf2 = *mf return mf2 } func (mf *MultipartField) resetReader() error { if rs, ok := mf.Reader.(io.ReadSeeker); ok { _, err := rs.Seek(0, io.SeekStart) return err } return nil } func (mf *MultipartField) isValues() bool { return len(mf.Values) > 0 } func (mf *MultipartField) close() { closeq(mf.Reader) } func (mf *MultipartField) createHeader() textproto.MIMEHeader { h := make(textproto.MIMEHeader) if isStringEmpty(mf.FileName) { h.Set(hdrContentDisposition, fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(mf.Name))) } else { h.Set(hdrContentDisposition, fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeQuotes(mf.Name), escapeQuotes(mf.FileName))) } if !isStringEmpty(mf.ContentType) { h.Set(hdrContentTypeKey, mf.ContentType) } return h } func (mf *MultipartField) openFile() error { if isStringEmpty(mf.FilePath) || mf.Reader != nil { return nil } file, err := os.Open(mf.FilePath) if err != nil { return err } if isStringEmpty(mf.FileName) { mf.FileName = filepath.Base(mf.FilePath) } // if file open is success, stat will succeed fileStat, _ := file.Stat() mf.Reader = file mf.FileSize = fileStat.Size() return nil } func (mf *MultipartField) detectContentType() error { if !isStringEmpty(mf.ContentType) || mf.Reader == nil { return nil } p := make([]byte, 512) size, err := mf.Reader.Read(p) if err != nil && err != io.EOF { return err } mf.tempBuf = p[:size] mf.ContentType = http.DetectContentType(mf.tempBuf) return nil } func (mf *MultipartField) wrapProgressCallbackIfPresent(pw io.Writer) io.Writer { if mf.ProgressCallback == nil { return pw } return &multipartProgressWriter{ w: pw, f: func(pb int64) { mf.ProgressCallback(MultipartFieldProgress{ Name: mf.Name, FileName: mf.FileName, FileSize: mf.FileSize, Written: pb, }) }, } } // MultipartFieldCallbackFunc function used to transmit live multipart upload // progress in bytes count type MultipartFieldCallbackFunc 
func(MultipartFieldProgress)

// MultipartFieldProgress struct used to provide multipart field upload progress
// details via callback function
type MultipartFieldProgress struct {
	Name     string
	FileName string
	FileSize int64
	Written  int64
}

// String method creates the string representation of [MultipartFieldProgress]
func (mfp MultipartFieldProgress) String() string {
	return fmt.Sprintf("FieldName: %s, FileName: %s, FileSize: %v, Written: %v",
		mfp.Name, mfp.FileName, mfp.FileSize, mfp.Written)
}

// multipartProgressWriter forwards writes to w, accumulates the number of
// bytes written in pb, and reports the running total through f.
type multipartProgressWriter struct {
	w  io.Writer
	pb int64
	f  func(int64)
}

// Write writes p to the underlying writer. When at least one byte was
// written it adds the count to the running total and calls f with that total;
// f is not called when nothing was written.
func (mpw *multipartProgressWriter) Write(p []byte) (n int, err error) {
	n, err = mpw.w.Write(p)
	if n <= 0 {
		return
	}
	mpw.pb += int64(n)
	mpw.f(mpw.pb)
	return
}

================================================
FILE: multipart_test.go
================================================
// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT

package resty

import (
	"bytes"
	"context"
	"errors"
	"io/fs"
	"mime/multipart"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"
)

// TestMultipartFormDataAndUpload uploads a file together with client-level
// form data, and then with additional request-level form data, and expects
// the uploaded file name to be echoed in the response.
func TestMultipartFormDataAndUpload(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	c := dcnl()
	c.SetFormData(map[string]string{"zip_code": "00001", "city": "Los Angeles"})

	t.Run("form data and upload", func(t *testing.T) {
		resp, err := c.R().
			SetFile("profile_img", filepath.Join(getTestDataPath(), "test-img.png")).
			Post(ts.URL + "/upload")

		assertError(t, err)
		assertEqual(t, http.StatusOK, resp.StatusCode())
		assertTrue(t, strings.Contains(resp.String(), "test-img.png"))
	})

	t.Run("request form data and upload", func(t *testing.T) {
		resp, err := c.R().
			SetFormData(map[string]string{
				"welcome1": "welcome value 1",
				"welcome2": "welcome value 2",
				"welcome3": "welcome value 3",
			}).
SetFile("profile_img", filepath.Join(getTestDataPath(), "test-img.png")).
			Post(ts.URL + "/upload")

		assertError(t, err)
		assertEqual(t, http.StatusOK, resp.StatusCode())
		assertTrue(t, strings.Contains(resp.String(), "test-img.png"))
	})
}

// TestMultipartFormDataAndUploadMethodPatch performs a multipart file upload
// with form data using the PATCH verb.
func TestMultipartFormDataAndUploadMethodPatch(t *testing.T) {
	ts := createFormPatchServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	c := dcnl()
	c.SetFormData(map[string]string{"zip_code": "00001", "city": "Los Angeles"})

	resp, err := c.R().
		SetFormData(map[string]string{"zip_code": "00002", "city": "Los Angeles"}).
		SetFile("profile_img", filepath.Join(getTestDataPath(), "test-img.png")).
		Patch(ts.URL + "/upload")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertTrue(t, strings.Contains(resp.String(), "test-img.png"))
}

// TestMultipartUploadError expects a nil response and an error wrapping
// fs.ErrNotExist when the file to upload does not exist.
func TestMultipartUploadError(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	c := dcnl()
	c.SetFormData(map[string]string{"zip_code": "00001", "city": "Los Angeles"})

	resp, err := c.R().
		SetFile("profile_img", filepath.Join(getTestDataPath(), "test-img-not-exists.png")).
		Post(ts.URL + "/upload")

	assertNotNil(t, err)
	assertNil(t, resp)
	assertEqual(t, true, errors.Is(err, fs.ErrNotExist))
}

// TestMultipartUploadFiles uploads two files by path together with form
// values and expects both file names in the response.
func TestMultipartUploadFiles(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	basePath := getTestDataPath()

	c := dcnld()

	r := c.R().
		SetFormDataFromValues(url.Values{
			"first_name": []string{"Jeevanandam"},
			"last_name":  []string{"M"},
		}).
SetFiles(map[string]string{
			"profile_img": filepath.Join(basePath, "test-img.png"),
			"notes":       filepath.Join(basePath, "text-file.txt"),
		})
	resp, err := r.Post(ts.URL + "/upload")

	responseStr := resp.String()

	// Cloning after the request has completed must not fail.
	_ = r.Clone(context.Background())

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertTrue(t, strings.Contains(responseStr, "test-img.png"))
	assertTrue(t, strings.Contains(responseStr, "text-file.txt"))
}

// TestMultipartFilesAndFormDataEmptyGH1046 uploads files without setting any
// form data (GH #1046) and expects both file names in the response.
func TestMultipartFilesAndFormDataEmptyGH1046(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	basePath := getTestDataPath()

	c := dcnld()

	resp, err := c.R().
		SetFiles(map[string]string{
			"profile_img": filepath.Join(basePath, "test-img.png"),
			"notes":       filepath.Join(basePath, "text-file.txt"),
		}).
		Post(ts.URL + "/upload")

	responseStr := resp.String()

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertTrue(t, strings.Contains(responseStr, "test-img.png"))
	assertTrue(t, strings.Contains(responseStr, "text-file.txt"))
}

// TestMultipartIoReaderFiles uploads two in-memory files through
// SetFileReader and expects both file names in the response.
func TestMultipartIoReaderFiles(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	basePath := getTestDataPath()
	profileImgBytes, _ := os.ReadFile(filepath.Join(basePath, "test-img.png"))
	notesBytes, _ := os.ReadFile(filepath.Join(basePath, "text-file.txt"))

	// Just info values
	// file := File{
	// 	Name:      "test_file_name.jpg",
	// 	ParamName: "test_param",
	// 	Reader:    bytes.NewBuffer([]byte("test bytes")),
	// }
	// t.Logf("File Info: %v", file.String())

	c := dcnld()

	r := c.R().
		SetFormData(map[string]string{"first_name": "Jeevanandam", "last_name": "M"}).
		SetFileReader("profile_img", "test-img.png", bytes.NewReader(profileImgBytes)).
SetFileReader("notes", "text-file.txt", bytes.NewReader(notesBytes))
	resp, err := r.Post(ts.URL + "/upload")

	responseStr := resp.String()

	_ = r.Clone(context.Background())

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertTrue(t, strings.Contains(responseStr, "test-img.png"))
	assertTrue(t, strings.Contains(responseStr, "text-file.txt"))
}

// TestMultipartUploadFileNotOnGetOrDelete expects multipart requests to be
// rejected for the GET, DELETE and HEAD verbs, and every registered OnInvalid
// hook to be called exactly once with that error.
func TestMultipartUploadFileNotOnGetOrDelete(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	basePath := getTestDataPath()

	_, err := dcnldr().
		SetFile("profile_img", filepath.Join(basePath, "test-img.png")).
		Get(ts.URL + "/upload")

	assertEqual(t, "resty: multipart is not allowed in HTTP verb: GET", err.Error())

	_, err = dcnldr().
		SetFile("profile_img", filepath.Join(basePath, "test-img.png")).
		Delete(ts.URL + "/upload")

	assertEqual(t, "resty: multipart is not allowed in HTTP verb: DELETE", err.Error())

	var hook1Count int
	var hook2Count int
	_, err = dcnl().
		OnInvalid(func(r *Request, err error) {
			assertEqual(t, "resty: multipart is not allowed in HTTP verb: HEAD", err.Error())
			assertNotNil(t, r)
			hook1Count++
		}).
		OnInvalid(func(r *Request, err error) {
			assertEqual(t, "resty: multipart is not allowed in HTTP verb: HEAD", err.Error())
			assertNotNil(t, r)
			hook2Count++
		}).
		R().
		SetFile("profile_img", filepath.Join(basePath, "test-img.png")).
		Head(ts.URL + "/upload")

	assertEqual(t, "resty: multipart is not allowed in HTTP verb: HEAD", err.Error())
	assertEqual(t, 1, hook1Count)
	assertEqual(t, 1, hook2Count)
}

// TestMultipartFormData posts plain multipart form data (no files) with basic
// auth and expects the "Success" response.
func TestMultipartFormData(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()

	resp, err := dcnldr().
		SetMultipartFormData(map[string]string{"first_name": "Jeevanandam", "last_name": "M", "zip_code": "00001"}).
		SetBasicAuth("myuser", "mypass").
Post(ts.URL + "/profile")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, "Success", resp.String())
}

// TestMultipartFormDataFields posts value-only multipart fields, including a
// repeated field name, and expects the "Success" response.
func TestMultipartFormDataFields(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()

	fields := []*MultipartField{
		{
			Name:   "field1",
			Values: []string{"field1value1", "field1value2"},
		},
		{
			Name:   "field1",
			Values: []string{"field1value3", "field1value4"},
		},
		{
			Name:   "field3",
			Values: []string{"field3value1", "field3value2"},
		},
		{
			Name:   "field4",
			Values: []string{"field4value1", "field4value2"},
		},
	}

	resp, err := dcnldr().
		SetMultipartFields(fields...).
		SetBasicAuth("myuser", "mypass").
		Post(ts.URL + "/profile")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, "Success", resp.String())
}

// TestMultipartField uploads a single JSON document through
// SetMultipartField and expects its file name in the response.
func TestMultipartField(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	jsonBytes := []byte(`{"input": {"name": "Uploaded document", "_filename" : ["file.txt"]}}`)

	c := dcnld()

	r := c.R().
		SetFormDataFromValues(url.Values{
			"first_name": []string{"Jeevanandam"},
			"last_name":  []string{"M"},
		}).
SetMultipartField("uploadManifest", "upload-file.json", "application/json", bytes.NewReader(jsonBytes))
	resp, err := r.Post(ts.URL + "/upload")

	responseStr := resp.String()

	_ = r.Clone(context.Background())

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertTrue(t, strings.Contains(responseStr, "upload-file.json"))
}

// TestMultipartFields uploads several reader-backed multipart fields (one of
// them without a file name) alongside form data and expects the named files
// in the response.
func TestMultipartFields(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	jsonStr1 := `{"input": {"name": "Uploaded document 1", "_filename" : ["file1.txt"]}}`
	jsonStr2 := `{"input": {"name": "Uploaded document 2", "_filename" : ["file2.txt"]}}`

	fields := []*MultipartField{
		{
			Name:        "uploadManifest1",
			FileName:    "upload-file-1.json",
			ContentType: "application/json",
			Reader:      bytes.NewBufferString(jsonStr1),
		},
		{
			Name:        "uploadManifest2",
			FileName:    "upload-file-2.json",
			ContentType: "application/json",
			Reader:      bytes.NewBufferString(jsonStr2),
		},
		{
			Name:        "uploadManifest3",
			ContentType: "application/json",
			Reader:      bytes.NewBufferString(jsonStr2),
		},
	}

	c := dcnld()

	r := c.R().
		SetFormData(map[string]string{"first_name": "Jeevanandam", "last_name": "M"}).
		SetMultipartFields(fields...)
	resp, err := r.Post(ts.URL + "/upload")

	responseStr := resp.String()

	_ = r.Clone(context.Background())

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertTrue(t, strings.Contains(responseStr, "upload-file-1.json"))
	assertTrue(t, strings.Contains(responseStr, "upload-file-2.json"))
}

// TestMultipartCustomBoundary checks that a boundary containing quote
// characters is rejected and that a well-formed custom boundary is accepted.
func TestMultipartCustomBoundary(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	t.Run("incorrect custom boundary", func(t *testing.T) {
		_, err := dcnldr().
			SetMultipartFormData(map[string]string{"first_name": "Jeevanandam", "last_name": "M", "zip_code": "00001"}).
			SetMultipartBoundary(`"custom-boundary"`).
			SetBasicAuth("myuser", "mypass").
Post(ts.URL + "/profile")

		assertEqual(t, "mime: invalid boundary character", err.Error())
	})

	t.Run("correct custom boundary", func(t *testing.T) {
		resp, err := dcnldr().
			SetMultipartFormData(map[string]string{"first_name": "Jeevanandam", "last_name": "M", "zip_code": "00001"}).
			SetMultipartBoundary("custom-boundary-" + strconv.FormatInt(time.Now().Unix(), 10)).
			Post(ts.URL + "/profile")

		assertError(t, err)
		assertEqual(t, http.StatusOK, resp.StatusCode())
		assertEqual(t, "Success", resp.String())
	})
}

// TestMultipartLargeFile uploads large files with and without an explicit
// content type and custom boundary, and checks the size reported back by the
// upload server.
func TestMultipartLargeFile(t *testing.T) {
	ts := createFileUploadServer(t)
	defer ts.Close()

	t.Run("upload a 2+mb image file with content-type and custom boundary", func(t *testing.T) {
		c := dcnl()
		resp, err := c.R().
			SetFile("file", filepath.Join(getTestDataPath(), "test-img.png")).
			SetMultipartBoundary("custom-boundary-" + strconv.FormatInt(time.Now().Unix(), 10)).
			SetContentType("image/png").
			Post(ts.URL + "/upload")
		assertNil(t, err)
		assertNotNil(t, resp)
		assertTrue(t, strings.Contains(resp.String(), "File Uploaded successfully, file size: 2579629")) // 2579697
	})

	t.Run("upload a 2+mb image file with content-type and incorrect custom boundary", func(t *testing.T) {
		c := dcnl()
		_, err := c.R().
			SetFile("file", filepath.Join(getTestDataPath(), "test-img.png")).
			SetMultipartBoundary(`"custom-boundary-"` + strconv.FormatInt(time.Now().Unix(), 10)).
			SetContentType("image/png").
			Post(ts.URL + "/upload")
		assertNotNil(t, err)
		assertEqual(t, "mime: invalid boundary character", err.Error())
	})

	t.Run("upload a 2+mb image file without content-type", func(t *testing.T) {
		c := dcnl()
		resp, err := c.R().
			SetFile("file", filepath.Join(getTestDataPath(), "test-img.png")).
Post(ts.URL + "/upload")
		assertNil(t, err)
		assertNotNil(t, resp)
		assertTrue(t, strings.Contains(resp.String(), "File Uploaded successfully, file size: 2579697"))
	})

	t.Run("upload a 50+mb binary file", func(t *testing.T) {
		fp := createBinFile("50mbfile.bin", 50<<20)
		defer cleanupFiles(fp)
		c := dcnl()
		resp, err := c.R().
			SetFile("file", fp).
			Post(ts.URL + "/upload")
		assertNil(t, err)
		assertNotNil(t, resp)
		assertTrue(t, strings.Contains(resp.String(), "File Uploaded successfully, file size: 52429044"))
	})
}

// TestMultipartFieldProgressCallback uploads four files (one by path, three
// through open file readers) with a progress callback that logs each
// MultipartFieldProgress, and expects all file names in the response.
func TestMultipartFieldProgressCallback(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	file1, _ := os.Open(filepath.Join(getTestDataPath(), "test-img.png"))
	file1Stat, _ := file1.Stat()

	fileName2 := "50mbfile.bin"
	filePath2 := createBinFile(fileName2, 50<<20)
	defer cleanupFiles(filePath2)
	file2, _ := os.Open(filePath2)
	file2Stat, _ := file2.Stat()

	fileName3 := "100mbfile.bin"
	filePath3 := createBinFile(fileName3, 100<<20)
	defer cleanupFiles(filePath3)
	file3, _ := os.Open(filePath3)
	file3Stat, _ := file3.Stat()

	progressCallback := func(mp MultipartFieldProgress) {
		t.Logf("%s\n", mp)
	}

	fields := []*MultipartField{
		{
			Name:             "test-image",
			FilePath:         filepath.Join(getTestDataPath(), "test-img.png"),
			ProgressCallback: progressCallback,
		},
		{
			Name:             "test-image-1",
			FileName:         "test-image-1.png",
			ContentType:      "image/png",
			Reader:           file1,
			FileSize:         file1Stat.Size(),
			ProgressCallback: progressCallback,
		},
		{
			Name:             "50mbfile",
			FileName:         fileName2,
			Reader:           file2,
			FileSize:         file2Stat.Size(),
			ProgressCallback: progressCallback,
		},
		{
			Name:             "100mbfile",
			FileName:         fileName3,
			Reader:           file3,
			FileSize:         file3Stat.Size(),
			ProgressCallback: progressCallback,
		},
	}

	c := dcnld()

	r := c.R().
		SetFormData(map[string]string{"first_name": "Jeevanandam", "last_name": "M"}).
		SetMultipartFields(fields...)
resp, err := r.Post(ts.URL + "/upload")

	responseStr := resp.String()

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertTrue(t, strings.Contains(responseStr, "test-image-1.png"))
	assertTrue(t, strings.Contains(responseStr, "test-img.png"))
	assertTrue(t, strings.Contains(responseStr, "50mbfile.bin"))
	assertTrue(t, strings.Contains(responseStr, "100mbfile.bin"))
}

// TestMultipartOrderedFormData interleaves value-only fields and
// reader-backed file fields in a single request and expects the named files
// in the response.
func TestMultipartOrderedFormData(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	jsonStr1 := `{"input": {"name": "Uploaded document 1", "_filename" : ["file1.txt"]}}`
	jsonStr2 := `{"input": {"name": "Uploaded document 2", "_filename" : ["file2.txt"]}}`

	fields := []*MultipartField{
		{
			Name:   "field1",
			Values: []string{"field1value1", "field1value2"},
		},
		{
			Name:   "field2",
			Values: []string{"field2value1", "field2value2"},
		},
		{
			Name:        "uploadManifest1",
			FileName:    "upload-file-1.json",
			ContentType: "application/json",
			Reader:      bytes.NewBufferString(jsonStr1),
		},
		{
			Name:   "field3",
			Values: []string{"field3value1", "field3value2"},
		},
		{
			Name:        "uploadManifest2",
			FileName:    "upload-file-2.json",
			ContentType: "application/json",
			Reader:      bytes.NewBufferString(jsonStr2),
		},
		{
			Name:   "field4",
			Values: []string{"field4value1", "field4value2"},
		},
		{
			Name:        "uploadManifest3",
			ContentType: "application/json",
			Reader:      bytes.NewBufferString(jsonStr2),
		},
	}

	c := dcnld().SetBaseURL(ts.URL)

	resp, err := c.R().
		SetMultipartOrderedFormData("first_name", []string{"Jeevanandam"}).
		SetMultipartOrderedFormData("last_name", []string{"M"}).
		SetMultipartFields(fields...).
Post("/upload")

	responseStr := resp.String()

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertTrue(t, strings.Contains(responseStr, "upload-file-1.json"))
	assertTrue(t, strings.Contains(responseStr, "upload-file-2.json"))
}

// errTestErrorReader is the sentinel error returned by errorReader.
var errTestErrorReader = errors.New("fake")

// errorReader is an io.Reader whose Read always fails with
// errTestErrorReader.
type errorReader struct{}

// Read reads nothing and returns errTestErrorReader.
func (errorReader) Read(p []byte) (n int, err error) {
	return 0, errTestErrorReader
}

// TestMultipartReaderErrors checks how reader and file-open failures are
// reported for multipart requests.
func TestMultipartReaderErrors(t *testing.T) {
	ts := createFileUploadServer(t)
	defer ts.Close()

	c := dcnl().SetBaseURL(ts.URL)

	t.Run("multipart fields with errorReader", func(t *testing.T) {
		// ContentType is set, so the failing read is not needed for content
		// type detection; here the error comes back together with a non-nil
		// response.
		resp, err := c.R().
			SetMultipartFields(&MultipartField{
				Name:        "foo",
				ContentType: "text/plain",
				Reader:      &errorReader{},
			}).
			Post("/upload")

		assertNotNil(t, err)
		assertEqual(t, errTestErrorReader, err)
		assertNotNil(t, resp)

		err = resp.wrapError(errors.New("test error"), true)
		assertNil(t, err)
		assertEqual(t, "test error", resp.CascadeError.Error())
	})

	t.Run("multipart files with errorReader", func(t *testing.T) {
		resp, err := c.R().
			SetFileReader("foo", "foo.txt", &errorReader{}).
			Post("/upload")

		assertNotNil(t, err)
		assertEqual(t, errTestErrorReader, err)
		assertNil(t, resp)
	})

	t.Run("multipart with file not found", func(t *testing.T) {
		resp, err := c.R().
			SetFile("foo", "foo.txt").
Post("/upload")

		assertNotNil(t, err)
		assertEqual(t, true, errors.Is(err, fs.ErrNotExist))
		assertNil(t, resp)
	})
}

// mpWriterError is an io.Writer whose Write always fails.
type mpWriterError struct{}

// Write writes nothing and returns a "multipart write error" error.
func (mwe *mpWriterError) Write(p []byte) (int, error) {
	return 0, errors.New("multipart write error")
}

// TestMultipartRequest_Errors checks that multipartWriteFormData returns the
// error of the underlying writer.
func TestMultipartRequest_Errors(t *testing.T) {
	mw := multipart.NewWriter(&mpWriterError{})

	c := dcnl()
	req1 := c.R().SetFormData(map[string]string{
		"name1": "value1",
		"name2": "value2",
	})

	t.Run("writeFormData", func(t *testing.T) {
		err1 := multipartWriteFormData(mw, req1)
		assertNotNil(t, err1)
		assertEqual(t, "multipart write error", err1.Error())
	})
}

// TestMultipartUploadFailAutoErrorParse uploads a file to a server that
// always answers 403 with a JSON body and expects that body to be parsed into
// the registered ResultError value, for a single request and for concurrent
// requests.
func TestMultipartUploadFailAutoErrorParse(t *testing.T) {
	type ErrorResponse struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	}

	ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set(hdrContentTypeKey, "application/json")
		w.WriteHeader(http.StatusForbidden)
		_, _ = w.Write([]byte(`{ "code": 403, "message": "forbidden error message" }`))
	})
	defer ts.Close()

	c := dcnl()

	t.Run("single request", func(t *testing.T) {
		res, err := c.R().
			SetFile("profile_img", filepath.Join(getTestDataPath(), "test-img.png")).
			SetResultError(&ErrorResponse{}).
			Post(ts.URL)

		assertNil(t, err)
		assertEqual(t, http.StatusForbidden, res.StatusCode())

		er := res.ResultError().(*ErrorResponse)
		assertEqual(t, 403, er.Code)
		assertEqual(t, "forbidden error message", er.Message)
	})

	t.Run("concurrent requests", func(t *testing.T) {
		concurrencyCount := 50
		wg := sync.WaitGroup{}
		for i := 0; i < concurrencyCount; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				res, _ := c.R().
					SetFile("profile_img", filepath.Join(getTestDataPath(), "test-img.png")).
					SetResultError(&ErrorResponse{}).
Post(ts.URL)
				er := res.ResultError().(*ErrorResponse)
				assertEqual(t, http.StatusForbidden, res.StatusCode())
				assertEqual(t, 403, er.Code)
				assertEqual(t, "forbidden error message", er.Message)
			}()
		}
		wg.Wait()
	})
}

// TestMultipartConcurrentRequests issues 100 multipart uploads concurrently
// from one shared client and expects each of them to succeed.
func TestMultipartConcurrentRequests(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/upload")

	c := dcnl()
	c.SetFormData(map[string]string{"zip_code": "00001", "city": "Los Angeles"})

	concurrencyCount := 100
	wg := sync.WaitGroup{}
	for i := 0; i < concurrencyCount; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			res, err := c.R().
				SetFormData(map[string]string{
					"welcome1": "welcome value 1",
					"welcome2": "welcome value 2",
					"welcome3": "welcome value 3",
				}).
				SetFile("profile_img", filepath.Join(getTestDataPath(), "test-img.png")).
				Post(ts.URL + "/upload")

			assertError(t, err)
			assertEqual(t, http.StatusOK, res.StatusCode())
			assertEqual(t, true, strings.Contains(res.String(), "test-img.png"))
		}()
	}
	wg.Wait()
}

// returnValueTestWriter is an io.Writer that reports zero bytes written and
// no error.
type returnValueTestWriter struct {
}

// Write discards p and returns (0, nil).
func (z *returnValueTestWriter) Write(p []byte) (n int, err error) {
	return 0, nil
}

// TestMultipartCornerCoverage covers two corner cases: resetReader on a
// reader that is not an io.ReadSeeker, and multipartProgressWriter.Write when
// the underlying writer reports zero bytes written.
func TestMultipartCornerCoverage(t *testing.T) {
	mf := &MultipartField{
		Name:   "foo",
		Reader: bytes.NewBufferString("I have no seek capability"),
	}
	err := mf.resetReader()
	assertNil(t, err)

	// wrap test writer to return 0 written value
	mpw := multipartProgressWriter{w: &returnValueTestWriter{}}
	n, err := mpw.Write([]byte("test return value"))
	assertNil(t, err)
	assertEqual(t, 0, n)
}

================================================
FILE: redirect.go
================================================
// Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved.
// resty source code and usage is governed by a MIT style
// license that can be found in the LICENSE file.
// SPDX-License-Identifier: MIT

package resty

import (
	"errors"
	"fmt"
	"net"
	"net/http"
	"strings"
)

type (
	// RedirectPolicy to regulate the redirects in the Resty client.
// Objects implementing the [RedirectPolicy] interface can be registered as
	// the redirect policy via SetRedirectPolicy (see the usage examples on the
	// policy constructors in this file).
	//
	// Apply function should return nil to continue the redirect journey; otherwise
	// return error to stop the redirect.
	RedirectPolicy interface {
		Apply(*http.Request, []*http.Request) error
	}

	// The [RedirectPolicyFunc] type is an adapter to allow the use of ordinary
	// functions as [RedirectPolicy]. If `f` is a function with the appropriate
	// signature, RedirectPolicyFunc(f) is a RedirectPolicy object that calls `f`.
	RedirectPolicyFunc func(*http.Request, []*http.Request) error

	// RedirectInfo struct is used to capture the URL and status code for the redirect history
	RedirectInfo struct {
		URL        string
		StatusCode int
	}
)

// Apply calls f(req, via).
func (f RedirectPolicyFunc) Apply(req *http.Request, via []*http.Request) error {
	return f(req, via)
}

// RedirectNoPolicy is used to disable the redirects in the Resty client.
// The returned policy always answers with [http.ErrUseLastResponse].
//
//	resty.SetRedirectPolicy(resty.RedirectNoPolicy())
func RedirectNoPolicy() RedirectPolicy {
	return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
		return http.ErrUseLastResponse
	})
}

// RedirectFlexiblePolicy method returns a policy that stops the redirect
// journey with an error once the number of requests already made (len(via))
// reaches noOfRedirect. While the limit is not reached, the headers of the
// first request are carried over when the redirect stays on the same host.
//
//	resty.SetRedirectPolicy(RedirectFlexiblePolicy(20))
func RedirectFlexiblePolicy(noOfRedirect int) RedirectPolicy {
	return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error {
		if len(via) >= noOfRedirect {
			return fmt.Errorf("resty: stopped after %d redirects", noOfRedirect)
		}
		// via[0] is the oldest request made so far.
		checkHostAndAddHeaders(req, via[0])
		return nil
	})
}

// RedirectDomainCheckPolicy method is convenient for defining domain name redirect rules in Resty clients.
// Redirect is allowed only for the host mentioned in the policy.
// // resty.SetRedirectPolicy(resty.RedirectDomainCheckPolicy("host1.com", "host2.org", "host3.net")) func RedirectDomainCheckPolicy(hostnames ...string) RedirectPolicy { hosts := make(map[string]bool) for _, h := range hostnames { hosts[strings.ToLower(h)] = true } return RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { if ok := hosts[getHostname(req.URL.Host)]; !ok { return errors.New("redirect is not allowed as per DomainCheckRedirectPolicy") } checkHostAndAddHeaders(req, via[0]) return nil }) } func getHostname(host string) (hostname string) { if strings.Index(host, ":") > 0 { host, _, _ = net.SplitHostPort(host) } hostname = strings.ToLower(host) return } // By default, Golang will not redirect request headers. // After reading through the various discussion comments from the thread - // https://github.com/golang/go/issues/4800 // Resty will add all the headers during a redirect for the same host and // adds library user-agent if the Host is different. func checkHostAndAddHeaders(cur *http.Request, pre *http.Request) { curHostname := getHostname(cur.URL.Host) preHostname := getHostname(pre.URL.Host) if strings.EqualFold(curHostname, preHostname) { for key, val := range pre.Header { cur.Header[key] = val } } } ================================================ FILE: request.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. 
// SPDX-License-Identifier: MIT

package resty

import (
	"bytes"
	"context"
	"encoding/json"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"maps"
	"net"
	"net/http"
	"net/url"
	"path/filepath"
	"reflect"
	"strings"
	"sync"
	"syscall"
	"time"
)

//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾
// Request struct and methods
//_______________________________________________________________________

// Request struct is used to compose and fire individual requests from
// Resty client. The [Request] provides an option to override client-level
// settings and also an option for the request composition.
//
// Most exported fields have a matching Set* method on [Request]; the
// unexported fields are internal state.
type Request struct {
	// CorrelationID used to track/relate requests.
	// By default, Resty sets a GUID as the correlation ID for requests with retry count > 0.
	CorrelationID string

	URL                          string
	Method                       string
	AuthToken                    string
	AuthScheme                   string
	QueryParams                  url.Values
	FormData                     url.Values
	PathParams                   map[string]string
	Header                       http.Header
	StartTime                    time.Time
	Body                         any
	Result                       any
	ResultError                  any
	RawRequest                   *http.Request
	Cookies                      []*http.Cookie
	IsDebug                      bool
	IsCloseConnection            bool
	IsResponseDoNotParse         bool
	ResponseSaveFileName         string
	ResponseExpectContentType    string
	ResponseForceContentType     string
	DebugBodyLimit               int
	ResponseBodyLimit            int64
	IsResponseBodyUnlimitedReads bool
	IsTrace                      bool
	IsMethodGetAllowPayload      bool
	IsMethodDeleteAllowPayload   bool
	IsDone                       bool
	IsResponseSaveToFile         bool
	Timeout                      time.Duration
	HeaderAuthorizationKey       string
	RetryCount                   int
	RetryWaitTime                time.Duration
	RetryMaxWaitTime             time.Duration
	RetryDelayStrategy           RetryDelayStrategyFunc
	IsRetryDefaultConditions     bool
	IsRetryAllowNonIdempotent    bool

	// Attempt provides insights into no. of attempts
	// Resty made.
	//
	//	first attempt + retry count = total attempts
	Attempt int

	// mu guards ctx; see Context and SetContext.
	mu                   *sync.Mutex
	credentials          *credentials
	isMultiPart          bool
	isFormData           bool
	isContentLengthSet   bool
	contentLength        int64
	jsonEscapeHTML       bool
	ctx                  context.Context
	ctxCancelFunc        context.CancelFunc
	values               map[string]any
	client               *Client
	bodyBuf              *bytes.Buffer
	trace                *clientTrace
	log                  Logger
	baseURL              string
	multipartBoundary    string
	multipartFields      []*MultipartField
	retryConditions      []RetryConditionFunc
	isSetRetryConditions bool
	retryHooks           []RetryHookFunc
	isSetRetryHooks      bool
	curlCmdString        string
	isCurlCmdGenerate    bool
	isCurlCmdDebugLog    bool
	unescapeQueryParams  bool
	multipartErrChan     chan error
	multipartCancelFunc  context.CancelFunc
}

// SetCorrelationID method is used to set the correlation ID for the request
//
// By default, Resty sets a GUID as the correlation ID for requests with retry count > 0.
func (r *Request) SetCorrelationID(id string) *Request {
	r.CorrelationID = id
	return r
}

// SetMethod method used to set the HTTP verb for the request
func (r *Request) SetMethod(m string) *Request {
	r.Method = m
	return r
}

// SetURL method used to set the request URL for the request
func (r *Request) SetURL(url string) *Request {
	r.URL = url
	return r
}

// Context method returns the request's [context.Context]. To change the context, use
// [Request.Clone] or [Request.WithContext].
//
// The returned context is always non-nil; it defaults to the
// background context.
func (r *Request) Context() context.Context {
	// The read of r.ctx is done under the request mutex.
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.ctx == nil {
		return context.Background()
	}
	return r.ctx
}
// // See [Request.WithContext], [Request.Clone] func (r *Request) SetContext(ctx context.Context) *Request { r.mu.Lock() defer r.mu.Unlock() r.ctx = ctx return r } // WithContext method returns a shallow copy of r with its context changed // to ctx. The provided ctx must be non-nil. It does not // affect the [Request].RawRequest that was already created. // // If you want this method to take effect, use this method before invoking // [Request.Send] or [Request].HTTPVerb methods. // // See [Request.SetContext], [Request.Clone] func (r *Request) WithContext(ctx context.Context) *Request { if ctx == nil { panic("resty: Request.WithContext nil context") } rr := new(Request) *rr = *r rr.ctx = ctx return rr } // SetContentType method is a convenient way to set the header Content-Type in the request // // client.R().SetContentType("application/json") func (r *Request) SetContentType(ct string) *Request { r.SetHeader(hdrContentTypeKey, ct) return r } // SetHeader method sets a single header field and its value in the current request. // // For Example: To set `Content-Type` and `Accept` as `application/json`. // // client.R(). // SetHeader("Content-Type", "application/json"). // SetHeader("Accept", "application/json") // // It overrides the header value set at the client instance level. func (r *Request) SetHeader(header, value string) *Request { r.Header.Set(header, value) return r } // SetHeaderAny method sets a single header field and its value in the current request. // // It is similar to [Request.SetHeader] but accepts any type as the value and converts // it to a string using predefined formatting rules (integers, bools, time.Time, etc.). // // For Example: To set `X-Request-Id` with an integer value // // client.R().SetHeaderAny("X-Request-Id", 12345) // // It overrides the header value set at the client instance level. // // See [Client.SetHeaderAny]. 
func (r *Request) SetHeaderAny(header string, value any) *Request { strVal := formatAnyToString(value) r.Header.Set(header, strVal) return r } // SetHeaders method sets multiple header fields and their values at one go in the current request. // // For Example: To set `Content-Type` and `Accept` as `application/json` // // client.R(). // SetHeaders(map[string]string{ // "Content-Type": "application/json", // "Accept": "application/json", // }) // // It overrides the header value set at the client instance level. func (r *Request) SetHeaders(headers map[string]string) *Request { for h, v := range headers { r.SetHeader(h, v) } return r } // SetHeaderMultiValues sets multiple header fields and their values as a list of strings in the current request. // // For Example: To set `Accept` as `text/html, application/xhtml+xml, application/xml;q=0.9, image/webp, */*;q=0.8` // // client.R(). // SetHeaderMultiValues(map[string][]string{ // "Accept": []string{"text/html", "application/xhtml+xml", "application/xml;q=0.9", "image/webp", "*/*;q=0.8"}, // }) // // It overrides the header value set at the client instance level. func (r *Request) SetHeaderMultiValues(headers map[string][]string) *Request { for key, values := range headers { r.SetHeader(key, strings.Join(values, ", ")) } return r } // SetHeaderVerbatim method is used to set the HTTP header key and value verbatim in the current request. // It is typically helpful for legacy applications or servers that require HTTP headers in a certain way // // For Example: To set header key as `all_lowercase`, `UPPERCASE`, and `x-cloud-trace-id` // // client.R(). // SetHeaderVerbatim("all_lowercase", "available"). // SetHeaderVerbatim("UPPERCASE", "available"). // SetHeaderVerbatim("x-cloud-trace-id", "798e94019e5fc4d57fbb8901eb4c6cae") // // It overrides the header value set at the client instance level. 
func (r *Request) SetHeaderVerbatim(header, value string) *Request { r.Header[header] = []string{value} return r } // SetHeaderVerbatimAny method sets the HTTP header key and value verbatim in the current request. // // It is similar to [Request.SetHeaderVerbatim] but accepts any type as the value and converts // it to a string using predefined formatting rules (integers, bools, time.Time, etc.). // // For Example: To set header key as `x-trace-id` with an integer value // // client.R().SetHeaderVerbatimAny("x-trace-id", 798940) // // It overrides the header value set at the client instance level. // // See [Client.SetHeaderVerbatimAny]. func (r *Request) SetHeaderVerbatimAny(header string, value any) *Request { strVal := formatAnyToString(value) r.Header[header] = []string{strVal} return r } // SetQueryParam method sets a single parameter and its value in the current request. // It will be formed as a query string for the request. // // For Example: `search=kitchen%20papers&size=large` in the URL after the `?` mark. // // client.R(). // SetQueryParam("search", "kitchen papers"). // SetQueryParam("size", "large") // // It overrides the query parameter value set at the client instance level. func (r *Request) SetQueryParam(param, value string) *Request { r.QueryParams.Set(param, value) return r } // SetQueryParamAny method sets a single query parameter and its value in the current request. // It will be formed as a query string for the request. // // It is similar to [Request.SetQueryParam] but accepts any type as the value and converts // it to a string using predefined formatting rules (integers, bools, time.Time, etc.). // // For Example: To set `page` and `active` query parameters // // client.R(). // SetQueryParamAny("page", 5). // SetQueryParamAny("active", true) // // It overrides the query parameter value set at the client instance level. // // See [Client.SetQueryParamAny]. 
func (r *Request) SetQueryParamAny(param string, value any) *Request { strVal := formatAnyToString(value) r.QueryParams.Set(param, strVal) return r } // SetQueryParams method sets multiple parameters and their values at one go in the current request. // It will be formed as a query string for the request. // // For Example: `search=kitchen%20papers&size=large` in the URL after the `?` mark. // // client.R(). // SetQueryParams(map[string]string{ // "search": "kitchen papers", // "size": "large", // }) // // It overrides the query parameter value set at the client instance level. func (r *Request) SetQueryParams(params map[string]string) *Request { for p, v := range params { r.SetQueryParam(p, v) } return r } // SetQueryParamsFromValues method appends multiple parameters with multi-value // ([url.Values]) at one go in the current request. It will be formed as // query string for the request. // // For Example: `status=pending&status=approved&status=open` in the URL after the `?` mark. // // client.R(). // SetQueryParamsFromValues(url.Values{ // "status": []string{"pending", "approved", "open"}, // }) // // It overrides the query parameter value set at the client instance level. func (r *Request) SetQueryParamsFromValues(params url.Values) *Request { for p, v := range params { for _, pv := range v { r.QueryParams.Add(p, pv) } } return r } // SetQueryString method provides the ability to use string as an input to set URL query string for the request. // // client.R(). // SetQueryString("productId=232&template=fresh-sample&cat=resty&source=google&kw=buy a lot more") // // It overrides the query parameter value set at the client instance level. 
func (r *Request) SetQueryString(query string) *Request { params, err := url.ParseQuery(strings.TrimSpace(query)) if err == nil { for p, v := range params { for _, pv := range v { r.QueryParams.Add(p, pv) } } } else { r.log.Errorf("%v", err) } return r } // SetFormData method sets form parameters and their values in the current request. // The request content type would be set as `application/x-www-form-urlencoded`. // // client.R(). // SetFormData(map[string]string{ // "access_token": "BC594900-518B-4F7E-AC75-BD37F019E08F", // "user_id": "3455454545", // }) // // It overrides the form data value set at the client instance level. // // See [Request.SetFormDataFromValues] for the same field name with multiple values. func (r *Request) SetFormData(data map[string]string) *Request { for k, v := range data { r.FormData.Set(k, v) } return r } // SetFormDataFromValues method appends multiple form parameters with multi-value // ([url.Values]) at one go in the current request. // // client.R(). // SetFormDataFromValues(url.Values{ // "search_criteria": []string{"book", "glass", "pencil"}, // }) // // It overrides the form data value set at the client instance level. func (r *Request) SetFormDataFromValues(data url.Values) *Request { for k, v := range data { for _, kv := range v { r.FormData.Add(k, kv) } } return r } // SetBody method sets the request body for the request. It supports various practical needs as easy. // It's quite handy and powerful. Supported request body data types are `string`, // `[]byte`, `struct`, `map`, `slice` and [io.Reader]. // // Body value can be pointer or non-pointer. Automatic marshalling for JSON and XML content type, if it is `struct`, `map`, or `slice`. // // NOTE: [io.Reader] is processed in bufferless mode while sending a request. // // For Example: // // `struct` gets marshaled based on the request header `Content-Type`. // // client.R(). 
// SetBody method sets the request body for the request. It covers a wide
// range of practical needs with little effort.
// Supported request body data types are `string`,
// `[]byte`, `struct`, `map`, `slice` and [io.Reader].
//
// Body value can be pointer or non-pointer. Automatic marshalling for JSON and XML content type, if it is `struct`, `map`, or `slice`.
//
// NOTE: [io.Reader] is processed in bufferless mode while sending a request.
//
// For Example:
//
// `struct` gets marshaled based on the request header `Content-Type`.
//
//	client.R().
//		SetBody(User{
//			Username: "jeeva@myjeeva.com",
//			Password: "welcome2resty",
//		})
//
// `map` gets marshaled based on the request header `Content-Type`.
//
//	client.R().
//		SetBody(map[string]any{
//			"username": "jeeva@myjeeva.com",
//			"password": "welcome2resty",
//			"address": &Address{
//				Address1: "1111 This is my street",
//				Address2: "Apt 201",
//				City: "My City",
//				State: "My State",
//				ZipCode: 00000,
//			},
//		})
//
// `string` as a body input. Suitable for any need as a string input.
//
//	client.R().
//		SetBody(`{
//			"username": "jeeva@getrightcare.com",
//			"password": "admin"
//		}`)
//
// `[]byte` as a body input. Suitable for raw requests such as file upload, serialize & deserialize, etc.
//
//	client.R().
//		SetBody([]byte("This is my raw request, sent as-is"))
//
// and so on.
func (r *Request) SetBody(body any) *Request {
	r.Body = body
	return r
}

// SetResult method registers the response `Result` object type for automatic
// unmarshalling of the HTTP response if the response status code is
// between 200 and 299, and the content type is either JSON or XML.
//
// Note: [Request.SetResult] input can be a pointer or non-pointer.
//
// The pointer with handle
//
//	authToken := &AuthToken{}
//	client.R().SetResult(authToken)
//
//	// Can be accessed via -
//	fmt.Println(authToken) OR fmt.Println(response.Result().(*AuthToken))
//
// OR -
//
// The pointer without handle or non-pointer
//
//	client.R().SetResult(&AuthToken{})
//	// OR
//	client.R().SetResult(AuthToken{})
//
//	// Can be accessed via -
//	fmt.Println(response.Result().(*AuthToken))
func (r *Request) SetResult(v any) *Request {
	// The value is stored through getPointer, so Result is what that helper returns.
	r.Result = getPointer(v)
	return r
}
// // client.R().SetResultError(&AuthError{}) // // OR // client.R().SetResultError(AuthError{}) // // Accessing an unmarshalled error object from response instance. // // response.ResultError().(*AuthError) // // If this request ResultError object is nil, it will use the client-level error object // type if it is set. func (r *Request) SetResultError(err any) *Request { r.ResultError = getPointer(err) return r } // SetFile method sets a single file field name and its path for multipart upload. // // Resty provides an optional multipart live upload progress callback; // see method [Request.SetMultipartFields] // // client.R(). // SetFile("my_file", "/Users/jeeva/Gas Bill - Sep.pdf") func (r *Request) SetFile(fieldName, filePath string) *Request { r.isMultiPart = true r.multipartFields = append(r.multipartFields, &MultipartField{ Name: fieldName, FileName: filepath.Base(filePath), FilePath: filePath, }) return r } // SetFiles method sets multiple file field names and their paths for multipart uploads. // // Resty provides an optional multipart live upload progress callback; // see method [Request.SetMultipartFields] // // client.R(). // SetFiles(map[string]string{ // "my_file1": "/Users/jeeva/Gas Bill - Sep.pdf", // "my_file2": "/Users/jeeva/Electricity Bill - Sep.pdf", // "my_file3": "/Users/jeeva/Water Bill - Sep.pdf", // }) func (r *Request) SetFiles(files map[string]string) *Request { r.isMultiPart = true for f, fp := range files { r.multipartFields = append(r.multipartFields, &MultipartField{ Name: f, FileName: filepath.Base(fp), FilePath: fp, }) } return r } // SetFileReader method is to set a file using [io.Reader] for multipart upload. // // Resty provides an optional multipart live upload progress callback; // see method [Request.SetMultipartFields] // // client.R(). // SetFileReader("profile_img", "my-profile-img.png", bytes.NewReader(profileImgBytes)). 
// SetFileReader("notes", "user-notes.txt", bytes.NewReader(notesBytes)) func (r *Request) SetFileReader(fieldName, fileName string, reader io.Reader) *Request { r.SetMultipartField(fieldName, fileName, "", reader) return r } // SetMultipartFormData method allows simple form data to be attached to the request // as `multipart:form-data` func (r *Request) SetMultipartFormData(data map[string]string) *Request { r.isMultiPart = true for k, v := range data { r.FormData.Set(k, v) } return r } // SetMultipartOrderedFormData method allows add ordered form data to be attached to the request // as `multipart:form-data` func (r *Request) SetMultipartOrderedFormData(name string, values []string) *Request { r.isMultiPart = true r.multipartFields = append(r.multipartFields, &MultipartField{ Name: name, Values: values, }) return r } // SetMultipartField method sets custom data with Content-Type using [io.Reader] for multipart upload. // // Resty provides an optional multipart live upload progress callback; // see method [Request.SetMultipartFields] func (r *Request) SetMultipartField(fieldName, fileName, contentType string, reader io.Reader) *Request { r.isMultiPart = true r.multipartFields = append(r.multipartFields, &MultipartField{ Name: fieldName, FileName: fileName, ContentType: contentType, Reader: reader, }) return r } // SetMultipartFields method sets multiple data fields using [io.Reader] for multipart upload. 
// SetMultipartFields method sets multiple data fields using [io.Reader] for multipart upload.
//
// Resty provides an optional multipart live upload progress count in bytes; see
// [MultipartField].ProgressCallback and [MultipartFieldProgress]
//
// For Example:
//
//	client.R().SetMultipartFields(
//		&resty.MultipartField{
//			Name:        "uploadManifest1",
//			FileName:    "upload-file-1.json",
//			ContentType: "application/json",
//			Reader:      strings.NewReader(`{"input": {"name": "Uploaded document 1", "_filename" : ["file1.txt"]}}`),
//		},
//		&resty.MultipartField{
//			Name:        "uploadManifest2",
//			FileName:    "upload-file-2.json",
//			ContentType: "application/json",
//			FilePath:    "/path/to/upload-file-2.json",
//		},
//		&resty.MultipartField{
//			Name:        "image-file1",
//			FileName:    "image-file1.png",
//			ContentType: "image/png",
//			Reader:      bytes.NewReader(fileBytes),
//			ProgressCallback: func(mp MultipartFieldProgress) {
//				// use the progress details
//			},
//		},
//		&resty.MultipartField{
//			Name:        "image-file2",
//			FileName:    "image-file2.png",
//			ContentType: "image/png",
//			Reader:      imageFile2, // instance of *os.File
//			ProgressCallback: func(mp MultipartFieldProgress) {
//				// use the progress details
//			},
//		})
//
// If you have a `slice` of fields already, then call-
//
//	client.R().SetMultipartFields(fields...)
func (r *Request) SetMultipartFields(fields ...*MultipartField) *Request {
	// The request is marked as multipart even when no field is passed.
	r.isMultiPart = true
	r.multipartFields = append(r.multipartFields, fields...)
	return r
}

// SetMultipartBoundary method sets the custom multipart boundary for the multipart request.
// Typically, the `mime/multipart` package generates a random multipart boundary if not provided.
func (r *Request) SetMultipartBoundary(boundary string) *Request {
	r.multipartBoundary = boundary
	return r
}
// SetContentLength method sets the given content length value in the HTTP request.
// By default, Resty won't set `Content-Length`.
//
//	client.R().SetContentLength(3486547657)
func (r *Request) SetContentLength(v int64) *Request {
	r.contentLength = v
	// Records that a length was given explicitly.
	r.isContentLengthSet = true
	return r
}

// SetBasicAuth method sets the basic authentication header in the current HTTP request.
//
// For Example:
//
//	Authorization: Basic <base64-encoded-value>
//
// To set the header for username "go-resty" and password "welcome"
//
//	client.R().SetBasicAuth("go-resty", "welcome")
//
// It overrides the credentials set by method [Client.SetBasicAuth].
func (r *Request) SetBasicAuth(username, password string) *Request {
	r.credentials = &credentials{Username: username, Password: password}
	return r
}

// SetAuthToken method sets the auth token header(Default Scheme: Bearer) in the current HTTP request. Header example:
//
//	Authorization: Bearer <auth-token-value>
//
// For Example: To set auth token BC594900518B4F7EAC75BD37F019E08FBC594900518B4F7EAC75BD37F019E08F
//
//	client.R().SetAuthToken("BC594900518B4F7EAC75BD37F019E08FBC594900518B4F7EAC75BD37F019E08F")
//
// It overrides the Auth token set by method [Client.SetAuthToken].
func (r *Request) SetAuthToken(authToken string) *Request {
	r.AuthToken = authToken
	return r
}

// SetAuthScheme method sets the auth token scheme type in the HTTP request.
//
// Example Header value structure:
//
//	Authorization: <auth-scheme-value> <auth-token-value>
//
// For Example: To set the scheme to use OAuth
//
//	client.R().SetAuthScheme("OAuth")
//
//	// The outcome will be -
//	Authorization: OAuth <auth-token-value>
//
// Information about Auth schemes can be found in [RFC 7235], IANA [HTTP Auth schemes]
//
// It overrides the `Authorization` scheme set by method [Client.SetAuthScheme].
//
// [RFC 7235]: https://tools.ietf.org/html/rfc7235
// [HTTP Auth schemes]: https://www.iana.org/assignments/http-authschemes/http-authschemes.xhtml#authschemes
func (r *Request) SetAuthScheme(scheme string) *Request {
	r.AuthScheme = scheme
	return r
}
// // It overrides the `Authorization` header name set by method [Client.SetHeaderAuthorizationKey]. // // client.R().SetHeaderAuthorizationKey("X-Custom-Authorization") func (r *Request) SetHeaderAuthorizationKey(k string) *Request { r.HeaderAuthorizationKey = k return r } // SetResponseSaveFileName method sets the output file for the current HTTP request. The current // HTTP response will be saved in the given file. It is similar to the `curl -o` flag. // // Absolute path or relative path can be used. // // If it is a relative path, then the output file goes under the output directory, as mentioned // in the [Client.SetResponseSaveDirectory]. // // client.R(). // SetResponseSaveFileName("/Users/jeeva/Downloads/ReplyWithHeader-v5.1-beta.zip"). // Get("http://bit.ly/1LouEKr") // // NOTE: In this scenario // - [Response.BodyBytes] might be nil. // - [Response].Body might have been already read. func (r *Request) SetResponseSaveFileName(file string) *Request { r.ResponseSaveFileName = file r.SetResponseSaveToFile(true) return r } // SetResponseSaveToFile method used to enable the save response option for the current requests // // client.R().SetResponseSaveToFile(true) // // Resty determines the save filename in the following order - // - [Request.SetResponseSaveFileName] // - Content-Disposition header // - Request URL using [path.Base] // - Request URL hostname if path is empty or "/" // // It overrides the value set at the client instance level, see [Client.SetResponseSaveToFile] func (r *Request) SetResponseSaveToFile(save bool) *Request { r.IsResponseSaveToFile = save return r } // SetCloseConnection method sets variable `Close` in HTTP request struct with the given // value. 
// SetCloseConnection method sets variable `Close` in HTTP request struct with the given
// value. More info: https://golang.org/src/net/http/request.go
//
// It overrides the value set at the client instance level, see [Client.SetCloseConnection]
func (r *Request) SetCloseConnection(close bool) *Request {
	r.IsCloseConnection = close
	return r
}

// SetResponseDoNotParse method instructs Resty not to parse the response body automatically.
//
// Resty exposes the raw response body as [io.ReadCloser]. If you use it, do not
// forget to close the body; otherwise, you might run into connection leaks, and connection
// reuse may not happen.
//
// NOTE: The default [Response] middlewares are not executed when using this option. User
// takes over the control of handling response body from Resty.
func (r *Request) SetResponseDoNotParse(notParse bool) *Request {
	r.IsResponseDoNotParse = notParse
	return r
}

// SetResponseBodyLimit method sets a maximum body size limit in bytes on response,
// to avoid reading too much data into memory.
//
// Client will return [resty.ErrResponseBodyTooLarge] if the body size of the body
// in the uncompressed response is larger than the limit.
// Body size limit will not be enforced in the following cases:
//   - ResponseBodyLimit <= 0, which is the default behavior.
//   - [Request.SetResponseSaveFileName] is called to save response data to the file.
//   - "DoNotParseResponse" is set for client or request.
//
// It overrides the value set at the client instance level, see [Client.SetResponseBodyLimit]
func (r *Request) SetResponseBodyLimit(v int64) *Request {
	r.ResponseBodyLimit = v
	return r
}
// - When debug mode is enabled // // NOTE: Use with care // - Turning on this feature keeps the response body in memory, which might cause additional memory usage. func (r *Request) SetResponseBodyUnlimitedReads(b bool) *Request { r.IsResponseBodyUnlimitedReads = b return r } // SetPathParam method sets a single URL path key-value pair in the // Resty current request instance. // // client.R().SetPathParam("userId", "sample@sample.com") // // Result: // URL - /v1/users/{userId}/details // Composed URL - /v1/users/sample@sample.com/details // // client.R().SetPathParam("path", "groups/developers") // // Result: // URL - /v1/users/{path}/details // Composed URL - /v1/users/groups%2Fdevelopers/details // // It replaces the value of the key while composing the request URL. // The values will be escaped using function [url.PathEscape]. // // It overrides the path parameter set at the client instance level. func (r *Request) SetPathParam(param, value string) *Request { r.PathParams[param] = url.PathEscape(value) return r } // SetPathParamAny method sets a single URL path key-value pair in the // current request instance. // // It is similar to [Request.SetPathParam] but accepts any type as the value and converts // it to a string using predefined formatting rules (integers, bools, time.Time, etc.). // // client.R().SetPathParamAny("userId", 12345) // // Result: // URL - /v1/users/{userId}/details // Composed URL - /v1/users/12345/details // // It replaces the value of the key while composing the request URL. // The value will be escaped using [url.PathEscape] function. // // It overrides the path parameter set at the client instance level. // // See [Client.SetPathParamAny]. func (r *Request) SetPathParamAny(param string, value any) *Request { strVal := formatAnyToString(value) r.PathParams[param] = url.PathEscape(strVal) return r } // SetPathParams method sets multiple URL path key-value pairs at one go in the // Resty current request instance. 
// // client.R().SetPathParams(map[string]string{ // "userId": "sample@sample.com", // "subAccountId": "100002", // "path": "groups/developers", // }) // // Result: // URL - /v1/users/{userId}/{subAccountId}/{path}/details // Composed URL - /v1/users/sample@sample.com/100002/groups%2Fdevelopers/details // // It replaces the value of the key while composing the request URL. // The values will be escaped using function [url.PathEscape]. // // It overrides the path parameter set at the client instance level. func (r *Request) SetPathParams(params map[string]string) *Request { for p, v := range params { r.SetPathParam(p, v) } return r } // SetPathRawParam method sets a single URL path key-value pair in the // Resty current request instance without path escape. // // client.R().SetPathRawParam("userId", "sample@sample.com") // // Result: // URL - /v1/users/{userId}/details // Composed URL - /v1/users/sample@sample.com/details // // client.R().SetPathRawParam("path", "groups/developers") // // Result: // URL - /v1/users/{path}/details // Composed URL - /v1/users/groups/developers/details // // It replaces the value of the key while composing the request URL. // The value will be used as-is, no path escape applied. // // It overrides the raw path parameter set at the client instance level. func (r *Request) SetPathRawParam(param, value string) *Request { r.PathParams[param] = value return r } // SetPathRawParamAny method sets a single URL path key-value pair in the // current request instance without path escape. // // It is similar to [Request.SetPathRawParam] but accepts any type as the value and converts // it to a string using predefined formatting rules (integers, bools, time.Time, etc.). // // client.R().SetPathRawParamAny("userId", 12345) // // Result: // URL - /v1/users/{userId}/details // Composed URL - /v1/users/12345/details // // It replaces the value of the key while composing the request URL. // The value will be used as-is, no path escape applied. 
// // It overrides the raw path parameter set at the client instance level. // // See [Client.SetPathRawParamAny]. func (r *Request) SetPathRawParamAny(param string, value any) *Request { strVal := formatAnyToString(value) r.PathParams[param] = strVal return r } // SetPathRawParams method sets multiple URL path key-value pairs at one go in the // Resty current request instance without path escape. // // client.R().SetPathParams(map[string]string{ // "userId": "sample@sample.com", // "subAccountId": "100002", // "path": "groups/developers", // }) // // Result: // URL - /v1/users/{userId}/{subAccountId}/{path}/details // Composed URL - /v1/users/sample@sample.com/100002/groups/developers/details // // It replaces the value of the key while composing the request URL. // The value will be used as-is, no path escape applied. // // It overrides the raw path parameter set at the client instance level. func (r *Request) SetPathRawParams(params map[string]string) *Request { for p, v := range params { r.SetPathRawParam(p, v) } return r } // SetResponseExpectContentType method allows to provide fallback `Content-Type` // for automatic unmarshalling when the `Content-Type` response header is unavailable. func (r *Request) SetResponseExpectContentType(contentType string) *Request { r.ResponseExpectContentType = contentType return r } // SetResponseForceContentType method provides a strong sense of response `Content-Type` for // automatic unmarshalling. Resty gives this a higher priority than the `Content-Type` // response header. // // This means that if both [Request.SetResponseForceContentType] is set and // the response `Content-Type` is available, `SetResponseForceContentType` value will win. func (r *Request) SetResponseForceContentType(contentType string) *Request { r.ResponseForceContentType = contentType return r } // SetJSONEscapeHTML method enables or disables the HTML escape on JSON marshal. // By default, escape HTML is `true`. 
// // NOTE: This option only applies to the standard JSON Marshaller used by Resty. // // It overrides the value set at the client instance level, see [Client.SetJSONEscapeHTML] func (r *Request) SetJSONEscapeHTML(b bool) *Request { r.jsonEscapeHTML = b return r } // SetCookie method appends a single cookie in the current request instance. // // client.R().SetCookie(&http.Cookie{ // Name:"go-resty", // Value:"This is cookie value", // }) // // NOTE: Method appends the Cookie value into existing Cookie even if its already existing. func (r *Request) SetCookie(hc *http.Cookie) *Request { r.Cookies = append(r.Cookies, hc) return r } // SetCookies method sets an array of cookies in the current request instance. // // cookies := []*http.Cookie{ // &http.Cookie{ // Name:"go-resty-1", // Value:"This is cookie 1 value", // }, // &http.Cookie{ // Name:"go-resty-2", // Value:"This is cookie 2 value", // }, // } // // // Setting a cookies into resty's current request // client.R().SetCookies(cookies) // // NOTE: Method appends the Cookie value into existing Cookie even if its already existing. func (r *Request) SetCookies(rs []*http.Cookie) *Request { r.Cookies = append(r.Cookies, rs...) return r } // SetTimeout method is used to set a timeout for the current request // // client.R().SetTimeout(1 * time.Minute) // // It overrides the timeout set at the client instance level, See [Client.SetTimeout] // // NOTE: Resty uses [context.WithTimeout] on the request, it does not use [http.Client.Timeout] func (r *Request) SetTimeout(timeout time.Duration) *Request { r.Timeout = timeout return r } // SetLogger method sets given writer for logging Resty request and response details. // By default, requests and responses inherit their logger from the client. // // Compliant to interface [resty.Logger]. // // It overrides the logger value set at the client instance level. 
func (r *Request) SetLogger(l Logger) *Request {
	r.log = l
	return r
}

// SetDebug method enables or disables the debug mode on the current request.
// When enabled, it logs the details of the current request and response.
//
//	client.R().SetDebug(true)
//
// It overrides the debug value set at the client instance level.
//   - For [Request], it logs information such as HTTP verb, Relative URL path,
//     Host, Headers, and Body if it has one.
//   - For [Response], it logs information such as Status, Response Time, Headers,
//     and Body if it has one.
func (r *Request) SetDebug(d bool) *Request {
	r.IsDebug = d
	return r
}

// AddRetryConditions method adds one or more retry condition functions into the request.
// These retry conditions are executed to determine if the request can be retried.
// The request is retried if any function returns `true`; otherwise it is not retried.
//
// NOTE:
//   - Retry conditions are executed on each retry attempt.
//   - Default retry conditions are executed first.
//   - Client-level retry conditions are applied to all requests.
//   - Request-level retry conditions are executed before client-level retry conditions.
//     See [Client.AddRetryConditions], [Request.SetRetryConditions]
//   - Once a retry condition returns true, the remaining retry conditions are not executed.
//   - Retry conditions are executed in the order in which they are added.
func (r *Request) AddRetryConditions(conditions ...RetryConditionFunc) *Request {
	// Appends to, rather than replaces, the conditions already on the request.
	r.retryConditions = append(r.retryConditions, conditions...)
	return r
}

// SetRetryConditions method overwrites the retry conditions in the request.
// These retry conditions are executed to determine if the request can be retried.
// The request is retried if any function returns `true`; otherwise it is not retried.
//
// NOTE:
//   - It overwrites the existing retry conditions.
//   - See [Request.AddRetryConditions] method for more details.
func (r *Request) SetRetryConditions(conditions ...RetryConditionFunc) *Request { r.retryConditions = conditions r.isSetRetryConditions = true return r } // AddRetryHooks method adds one or more side-effecting retry hooks in the request. // // NOTE: // - Retry hooks are executed on each retry attempt. // - The request-level retry hooks are executed first before client-level hooks. // See [Client.AddRetryHooks] // - Retry hooks are executed in the order in which they are added. func (r *Request) AddRetryHooks(hooks ...RetryHookFunc) *Request { r.retryHooks = append(r.retryHooks, hooks...) return r } // SetRetryHooks method overwrites side-effecting retry hooks in the request. // // NOTE: // - It overwrites the existing retry hooks. // - See [Request.AddRetryHooks] method for more details. func (r *Request) SetRetryHooks(hooks ...RetryHookFunc) *Request { r.retryHooks = hooks r.isSetRetryHooks = true return r } // SetRetryCount method enables retry on Resty client and allows you // to set no. of retry count. // // first attempt + retry count = total attempts // // See [Request.SetRetryDelayStrategy] // // NOTE: // - By default, Resty only does retry on idempotent HTTP verb, [RFC 9110 Section 9.2.2], [RFC 9110 Section 18.2] // // [RFC 9110 Section 9.2.2]: https://datatracker.ietf.org/doc/html/rfc9110.html#name-idempotent-methods // [RFC 9110 Section 18.2]: https://datatracker.ietf.org/doc/html/rfc9110.html#name-method-registration func (r *Request) SetRetryCount(count int) *Request { r.RetryCount = count return r } // SetRetryWaitTime method sets the default wait time for sleep before retrying // // Default is 100 milliseconds. func (r *Request) SetRetryWaitTime(waitTime time.Duration) *Request { r.RetryWaitTime = waitTime return r } // SetRetryMaxWaitTime method sets the max wait time for sleep before retrying // // Default is 2 seconds. 
func (r *Request) SetRetryMaxWaitTime(maxWaitTime time.Duration) *Request {
	r.RetryMaxWaitTime = maxWaitTime
	return r
}

// SetRetryDelayStrategy method is used to set a custom retry delay strategy on the request;
// the strategy is used to get the wait time before each retry. It overrides the retry delay
// strategy set at the client instance level, see [Client.SetRetryDelayStrategy]
//
// By default, Resty employs the capped exponential backoff with a jitter delay strategy.
func (r *Request) SetRetryDelayStrategy(rs RetryDelayStrategyFunc) *Request {
	r.RetryDelayStrategy = rs
	return r
}

// SetRetryDefaultConditions method is used to enable/disable the Resty's default
// retry conditions on request level
//
// It overrides the value set at the client instance level, see [Client.SetRetryDefaultConditions]
func (r *Request) SetRetryDefaultConditions(b bool) *Request {
	r.IsRetryDefaultConditions = b
	return r
}

// SetRetryAllowNonIdempotent method is used to enable/disable non-idempotent HTTP
// methods retry. By default, Resty only allows idempotent HTTP methods, see
// [RFC 9110 Section 9.2.2], [RFC 9110 Section 18.2]
//
// It overrides the value set at the client instance level, see [Client.SetRetryAllowNonIdempotent]
//
// [RFC 9110 Section 9.2.2]: https://datatracker.ietf.org/doc/html/rfc9110.html#name-idempotent-methods
// [RFC 9110 Section 18.2]: https://datatracker.ietf.org/doc/html/rfc9110.html#name-method-registration
func (r *Request) SetRetryAllowNonIdempotent(b bool) *Request {
	r.IsRetryAllowNonIdempotent = b
	return r
}

// SetTrace method is used to turn on/off the trace capability at the request level.
// It provides an insight into the request lifecycle using [httptrace.ClientTrace].
//
//	client := resty.New()
//	defer client.Close()
//
//	resp, err := client.R().
//		SetTrace(true).
// Get("https://httpbin.org/get") // fmt.Println("Error:", err) // fmt.Println("Trace Info:", resp.Request.TraceInfo()) // // See [Client.SetTrace] func (r *Request) SetTrace(t bool) *Request { r.IsTrace = t return r } // SetCurlCmdGenerate method is used to turn on/off the generate curl command for the current request. // // By default, Resty does not log the curl command in the debug log since it has the potential // to leak sensitive data unless explicitly enabled via [Request.SetCurlCmdDebugLog] or // [Client.SetCurlCmdDebugLog]. // // It overrides the options set by the [Client.SetCurlCmdGenerate] // // NOTE: Use with care. // - Potential to leak sensitive data from [Request] and [Response] in the debug log // when the debug log option is enabled. // - Additional memory usage since the request body was reread. // - curl body is not generated for [io.Reader] and multipart request flow. func (r *Request) SetCurlCmdGenerate(b bool) *Request { r.isCurlCmdGenerate = b return r } // SetCurlCmdDebugLog method enables the curl command to be logged in the debug log // for the current request. // // It can be overridden at the request level; see [Client.SetCurlCmdDebugLog] func (r *Request) SetCurlCmdDebugLog(b bool) *Request { r.isCurlCmdDebugLog = b return r } // CurlCmd method generates the curl command for the request. func (r *Request) CurlCmd() string { return r.generateCurlCommand() } func (r *Request) generateCurlCommand() string { if !r.isCurlCmdGenerate { return "" } if len(r.curlCmdString) > 0 { return r.curlCmdString } if r.RawRequest == nil { if err := r.client.executeRequestMiddlewares(r); err != nil { r.log.Errorf("%v", err) return "" } } r.curlCmdString = buildCurlCmd(r) return r.curlCmdString } // SetQueryParamsUnescape method sets the choice of unescape query parameters for the request URL. // To prevent broken URL, Resty replaces space (" ") with "+" in the query parameters. 
//
// This method overrides the value set by [Client.SetQueryParamsUnescape]
//
// NOTE: Request failure is possible due to non-standard usage of Unescaped Query Parameters.
func (r *Request) SetQueryParamsUnescape(unescape bool) *Request {
	r.unescapeQueryParams = unescape
	return r
}

// SetMethodGetAllowPayload method allows the GET method with payload on the request level.
// By default, Resty does not allow it.
//
//	client.R().SetMethodGetAllowPayload(true)
//
// It overrides the option set by the [Client.SetMethodGetAllowPayload]
func (r *Request) SetMethodGetAllowPayload(allow bool) *Request {
	r.IsMethodGetAllowPayload = allow
	return r
}

// SetMethodDeleteAllowPayload method allows the DELETE method with payload on the request level.
// By default, Resty does not allow it.
//
//	client.R().SetMethodDeleteAllowPayload(true)
//
// More info, refer to GH#881
//
// It overrides the option set by the [Client.SetMethodDeleteAllowPayload]
func (r *Request) SetMethodDeleteAllowPayload(allow bool) *Request {
	r.IsMethodDeleteAllowPayload = allow
	return r
}

// TraceInfo method returns the trace info for the request.
// If tracing has not been enabled via [Client.SetTrace] or [Request.SetTrace]
// before the request is made, an empty [resty.TraceInfo] object is returned.
func (r *Request) TraceInfo() TraceInfo { ct := r.trace if ct == nil { return TraceInfo{} } ct.lock.RLock() defer ct.lock.RUnlock() ti := TraceInfo{ DNSLookup: 0, TCPConnTime: 0, ServerTime: 0, IsConnReused: ct.gotConnInfo.Reused, IsConnWasIdle: ct.gotConnInfo.WasIdle, ConnIdleTime: ct.gotConnInfo.IdleTime, RequestAttempt: r.Attempt, } if !ct.dnsStart.IsZero() && !ct.dnsDone.IsZero() { ti.DNSLookup = ct.dnsDone.Sub(ct.dnsStart) } if !ct.tlsHandshakeDone.IsZero() && !ct.tlsHandshakeStart.IsZero() { ti.TLSHandshake = ct.tlsHandshakeDone.Sub(ct.tlsHandshakeStart) } if !ct.gotFirstResponseByte.IsZero() && !ct.gotConn.IsZero() { ti.ServerTime = ct.gotFirstResponseByte.Sub(ct.gotConn) } // Calculate the total time accordingly when connection is reused, // and DNS start and get conn time may be zero if the request is invalid. // See issue #1016. requestStartTime := r.StartTime if ct.gotConnInfo.Reused && !ct.getConn.IsZero() { requestStartTime = ct.getConn } else if !ct.dnsStart.IsZero() { requestStartTime = ct.dnsStart } ti.TotalTime = ct.endTime.Sub(requestStartTime) // Only calculate on successful connections if !ct.connectDone.IsZero() { ti.TCPConnTime = ct.connectDone.Sub(ct.dnsDone) } // Only calculate on successful connections if !ct.gotConn.IsZero() { ti.ConnTime = ct.gotConn.Sub(ct.getConn) } // Only calculate on successful connections if !ct.gotFirstResponseByte.IsZero() { ti.ResponseTime = ct.endTime.Sub(ct.gotFirstResponseByte) } // Capture remote address info when connection is non-nil if ct.gotConnInfo.Conn != nil { ti.RemoteAddr = ct.gotConnInfo.Conn.RemoteAddr().String() } return ti } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // HTTP verb method starts here //_______________________________________________________________________ // Get method does GET HTTP request. It's defined in section 9.3.1 of [RFC 9110]. 
//
// It delegates to [Request.Execute] with [MethodGet].
//
// [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110.html#section-9.3.1
func (r *Request) Get(url string) (*Response, error) {
	return r.Execute(MethodGet, url)
}

// Head method does HEAD HTTP request. It's defined in section 9.3.2 of [RFC 9110].
//
// It delegates to [Request.Execute] with [MethodHead].
//
// [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110.html#section-9.3.2
func (r *Request) Head(url string) (*Response, error) {
	return r.Execute(MethodHead, url)
}

// Post method does POST HTTP request. It's defined in section 9.3.3 of [RFC 9110].
//
// It delegates to [Request.Execute] with [MethodPost].
//
// [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110.html#section-9.3.3
func (r *Request) Post(url string) (*Response, error) {
	return r.Execute(MethodPost, url)
}

// Put method does PUT HTTP request. It's defined in section 9.3.4 of [RFC 9110].
//
// It delegates to [Request.Execute] with [MethodPut].
//
// [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110.html#section-9.3.4
func (r *Request) Put(url string) (*Response, error) {
	return r.Execute(MethodPut, url)
}

// Patch method does PATCH HTTP request. It's defined in section 2 of [RFC 5789].
//
// It delegates to [Request.Execute] with [MethodPatch].
//
// [RFC 5789]: https://datatracker.ietf.org/doc/html/rfc5789.html#section-2
func (r *Request) Patch(url string) (*Response, error) {
	return r.Execute(MethodPatch, url)
}

// Delete method does DELETE HTTP request. It's defined in section 9.3.5 of [RFC 9110].
//
// It delegates to [Request.Execute] with [MethodDelete].
//
// [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110.html#section-9.3.5
func (r *Request) Delete(url string) (*Response, error) {
	return r.Execute(MethodDelete, url)
}

// Options method does OPTIONS HTTP request. It's defined in section 9.3.7 of [RFC 9110].
//
// It delegates to [Request.Execute] with [MethodOptions].
//
// [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110.html#section-9.3.7
func (r *Request) Options(url string) (*Response, error) {
	return r.Execute(MethodOptions, url)
}

// Trace method does TRACE HTTP request. It's defined in section 9.3.8 of [RFC 9110].
// // [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110.html#section-9.3.8 func (r *Request) Trace(url string) (*Response, error) { return r.Execute(MethodTrace, url) } // Send method performs the HTTP request using the method and URL already defined // for current [Request]. // // res, err := client.R(). // SetMethod(resty.MethodGet). // SetURL("http://httpbin.org/get"). // Send() func (r *Request) Send() (*Response, error) { return r.Execute(r.Method, r.URL) } // Execute method performs the HTTP request with the given HTTP method and URL // for current [Request]. // // resp, err := client.R().Execute(resty.MethodGet, "http://httpbin.org/get") func (r *Request) Execute(method, url string) (res *Response, err error) { defer func() { if rec := recover(); rec != nil { if err, ok := rec.(error); ok { r.client.onPanicHooks(r, err) } else { r.client.onPanicHooks(r, fmt.Errorf("panic %v", rec)) } panic(rec) } }() r.Method = method if r.RetryCount < 0 { r.RetryCount = 0 // default behavior is no retry } isIdempotent := r.isIdempotent() var backoff *backoffWithJitter if r.RetryCount > 0 && isIdempotent { backoff = newBackoffWithJitter(r.RetryWaitTime, r.RetryMaxWaitTime) r.SetCorrelationID(newGUID()) } retryConditions := append(r.retryConditions, r.client.retryConditions...) if r.isSetRetryConditions { retryConditions = r.retryConditions } retryHooks := append(r.retryHooks, r.client.retryHooks...) 
if r.isSetRetryHooks { retryHooks = r.retryHooks } isInvalidRequestErr := false // first attempt + retry count = total attempts for i := 0; i <= r.RetryCount; i++ { r.Attempt++ err = nil r.URL = url res, err = r.client.execute(r) if err != nil { if irErr, ok := err.(*invalidRequestError); ok { err = irErr.Err isInvalidRequestErr = true break } if r.Context().Err() != nil { if r.ctxCancelFunc != nil { r.ctxCancelFunc() r.ctxCancelFunc = nil } if !errors.Is(err, context.DeadlineExceeded) { err = wrapErrors(r.Context().Err(), err) break } } } // we have reached the maximum no. of requests // or request method is not an idempotent if r.Attempt-1 == r.RetryCount || !isIdempotent { break } if backoff != nil { needsRetry, isCtxDone := false, false // apply default retry conditions if r.IsRetryDefaultConditions { needsRetry = applyRetryDefaultConditions(res, err) } // apply user-defined retry conditions if default one // is still false if !needsRetry && res != nil { // run user-defined retry conditions for _, retryCondition := range retryConditions { if needsRetry = retryCondition(res, err); needsRetry { break } } } // retry not required stop here if !needsRetry { break } // by default reset file readers if err = r.resetFileReaders(); err != nil { // if any error in reset readers, stop here break } // run user-defined retry hooks for _, retryHookFunc := range retryHooks { retryHookFunc(res, err) } // let's drain the response body, before retry wait drainBody(res) waitDuration, waitErr := backoff.NextWaitDuration(r.client, res, err, r.Attempt) if waitErr != nil { // if any error in retry strategy, stop here err = wrapErrors(waitErr, err) break } timer := time.NewTimer(waitDuration) select { case <-r.Context().Done(): isCtxDone = true err = wrapErrors(r.Context().Err(), err) break case <-timer.C: } timer.Stop() if isCtxDone { break } } } if r.isMultiPart { for _, mf := range r.multipartFields { mf.close() } } r.IsDone = true if isInvalidRequestErr { 
r.client.onInvalidHooks(r, err) } else { r.client.onErrorHooks(r, res, err) } r.sendLoadBalancerFeedback(res, err) backToBufPool(r.bodyBuf) return } // Clone returns a deep copy of r with its context changed to ctx. // It does clone appropriate fields, reset, and reinitialize, so // [Request] can be used again. // // The body is not copied, but it's a reference to the original body. // // req := client.R(). // SetBody("body"). // SetHeader("header", "value") // clonedRequest := req.Clone(context.Background()) func (r *Request) Clone(ctx context.Context) *Request { if ctx == nil { panic("resty: Request.Clone nil context") } rr := new(Request) *rr = *r // set new context rr.ctx = ctx // RawRequest should not copied, since its created on request execution flow. rr.RawRequest = nil // clone values rr.Header = r.Header.Clone() rr.FormData = cloneURLValues(r.FormData) rr.QueryParams = cloneURLValues(r.QueryParams) rr.PathParams = maps.Clone(r.PathParams) // reset content length if not set by user if !r.isContentLengthSet { rr.contentLength = 0 } // clone basic auth if r.credentials != nil { rr.credentials = r.credentials.Clone() } // clone cookies if l := len(r.Cookies); l > 0 { rr.Cookies = make([]*http.Cookie, l) for _, cookie := range r.Cookies { rr.Cookies = append(rr.Cookies, cloneCookie(cookie)) } } // create new interface for result and error rr.Result = newInterface(r.Result) rr.ResultError = newInterface(r.ResultError) // clone multipart fields if l := len(r.multipartFields); l > 0 { rr.multipartFields = make([]*MultipartField, l) for i, mf := range r.multipartFields { rr.multipartFields[i] = mf.Clone() } } // reset values rr.StartTime = time.Time{} rr.Attempt = 0 rr.initTraceIfEnabled() r.values = make(map[string]any) r.multipartErrChan = nil r.ctxCancelFunc = nil // copy bodyBuf if r.bodyBuf != nil { rr.bodyBuf = acquireBuffer() rr.bodyBuf.Write(r.bodyBuf.Bytes()) } return rr } // Funcs method gets executed on request composition that passes the // current 
request instance to provided [RequestFunc], which could be // used to apply common/reusable logic to the given request instance. // // func addRequestContentType(r *Request) *Request { // return r.SetHeader("Content-Type", "application/json"). // SetHeader("Accept", "application/json") // } // // func addRequestQueryParams(page, size int) func(r *Request) *Request { // return func(r *Request) *Request { // return r.SetQueryParam("page", strconv.Itoa(page)). // SetQueryParam("size", strconv.Itoa(size)). // SetQueryParam("request_no", strconv.Itoa(int(time.Now().Unix()))) // } // } // // client.R(). // Funcs(addRequestContentType, addRequestQueryParams(1, 100)). // Get("https://localhost:8080/foobar") func (r *Request) Funcs(funcs ...RequestFunc) *Request { for _, f := range funcs { r = f(r) } return r } func (r *Request) fmtBodyString(sl int) (body string) { body = "***** NO CONTENT *****" if !r.isPayloadSupported() { return } if _, ok := r.Body.(io.Reader); ok { body = "***** BODY IS io.Reader *****" return } // multipart or form-data if r.isMultiPart || r.isFormData { bodySize := r.bodyBuf.Len() if bodySize > sl { body = fmt.Sprintf("***** REQUEST TOO LARGE (size - %d) *****", bodySize) return } body = r.bodyBuf.String() return } // request body data if r.Body == nil { return } var prtBodyBytes []byte var err error contentType := r.Header.Get(hdrContentTypeKey) ctKey := inferContentTypeMapKey(contentType) kind := inferKind(r.Body) if jsonKey == ctKey && (kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice) { buf := acquireBuffer() defer releaseBuffer(buf) if err = encodeJSONEscapeHTMLIndent(buf, &r.Body, false, " "); err == nil { prtBodyBytes = buf.Bytes() } } else if xmlKey == ctKey && kind == reflect.Struct { prtBodyBytes, err = xml.MarshalIndent(&r.Body, "", " ") } else { switch b := r.Body.(type) { case string: prtBodyBytes = []byte(b) if jsonKey == ctKey { prtBodyBytes = jsonIndent(prtBodyBytes) } case []byte: body = fmt.Sprintf("***** BODY 
IS byte(s) (size - %d) *****", len(b)) return } } bodySize := len(prtBodyBytes) if bodySize > sl { body = fmt.Sprintf("***** REQUEST TOO LARGE (size - %d) *****", bodySize) return } if prtBodyBytes != nil && err == nil { body = string(prtBodyBytes) } return } func (r *Request) initValuesMap() { if r.values == nil { r.values = make(map[string]any) } } func (r *Request) initTraceIfEnabled() { if r.IsTrace { r.trace = new(clientTrace) r.ctx = r.trace.createContext(r.Context()) } } func (r *Request) isHeaderExists(k string) bool { _, f := r.Header[k] return f } func (r *Request) isPayloadSupported() bool { if r.Method == "" { r.Method = MethodGet } if r.Method == MethodGet && r.IsMethodGetAllowPayload { return true } // More info, refer to GH#881 if r.Method == MethodDelete && r.IsMethodDeleteAllowPayload { return true } if r.Method == MethodPost || r.Method == MethodPut || r.Method == MethodPatch { return true } return false } func (r *Request) sendLoadBalancerFeedback(res *Response, err error) { if r.client.LoadBalancer() == nil { return } success := true // load balancer feedback mainly focuses on connection // failures and status code >= 500 // so that we can prevent sending the request to // that server which may fail if err != nil { var noe *net.OpError if errors.As(err, &noe) { success = !errors.Is(noe.Err, syscall.ECONNREFUSED) || noe.Timeout() } } if success && res != nil && (res.StatusCode() >= 500 && res.StatusCode() != http.StatusNotImplemented) { success = false } r.client.LoadBalancer().Feedback(&RequestFeedback{ BaseURL: r.baseURL, Success: success, Attempt: r.Attempt, }) } func (r *Request) resetFileReaders() error { for _, f := range r.multipartFields { if err := f.resetReader(); err != nil { return err } } return nil } // https://datatracker.ietf.org/doc/html/rfc9110.html#name-idempotent-methods // https://datatracker.ietf.org/doc/html/rfc9110.html#name-method-registration var idempotentMethods = map[string]struct{}{ MethodDelete: {}, MethodGet: {}, 
MethodHead: {}, MethodOptions: {}, MethodPut: {}, MethodTrace: {}, } func (r *Request) isIdempotent() bool { _, found := idempotentMethods[r.Method] return found || r.IsRetryAllowNonIdempotent } func (r *Request) withTimeout() *http.Request { if _, found := r.Context().Deadline(); found { return r.RawRequest } if r.Timeout > 0 { ctx, ctxCancelFunc := context.WithTimeout(r.Context(), r.Timeout) r.ctxCancelFunc = ctxCancelFunc return r.RawRequest.WithContext(ctx) } return r.RawRequest } func jsonIndent(v []byte) []byte { buf := acquireBuffer() defer releaseBuffer(buf) if err := json.Indent(buf, v, "", " "); err != nil { return v } return buf.Bytes() } ================================================ FILE: request_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "bytes" "context" "crypto/tls" "encoding/xml" "errors" "io" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "regexp" "strconv" "strings" "sync" "testing" "time" ) type AuthSuccess struct { ID string `xml:"Id"` Message string `xml:"Message"` } type AuthError struct { ID, Message string } func TestGet(t *testing.T) { ts := createGetServer(t) defer ts.Close() resp, err := dcnl().R(). SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)). Get(ts.URL + "/") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "HTTP/1.1", resp.Proto()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, "TestGet: text response", resp.String()) logResponse(t, resp) } func TestGetGH524(t *testing.T) { ts := createGetServer(t) defer ts.Close() resp, err := dcnl().R(). SetPathParams((map[string]string{ "userId": "sample@sample.com", "subAccountId": "100002", "path": "groups/developers", })). 
		SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)).
		SetDebug(true).
		Get(ts.URL + "/v1/users/{userId}/{subAccountId}/{path}/details")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, resp.Request.Header.Get("Content-Type"), "") // unable to reproduce reported issue
}

// TestRequestNegativeRetryCount checks that a negative client retry count
// still produces a successful single request.
func TestRequestNegativeRetryCount(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()

	resp, err := dcnl().SetRetryCount(-1).R().Get(ts.URL + "/")

	assertNil(t, err)
	assertNotNil(t, resp)
	assertEqual(t, "TestGet: text response", resp.String())
}

// TestGetCustomUserAgent sends a GET with a request-level User-Agent header.
func TestGetCustomUserAgent(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()

	resp, err := dcnlr().
		SetHeader(hdrUserAgentKey, "Test Custom User agent").
		SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)).
		Get(ts.URL + "/")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, "HTTP/1.1", resp.Proto())
	assertEqual(t, "200 OK", resp.Status())
	assertEqual(t, "TestGet: text response", resp.String())

	logResponse(t, resp)
}

// TestGetClientParamRequestParam combines client-level and request-level
// query parameters on one GET request.
func TestGetClientParamRequestParam(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetQueryParam("client_param", "true").
		SetQueryParams(map[string]string{"req_1": "jeeva", "req_3": "jeeva3"}).
		SetDebug(true)
	c.outputLogTo(io.Discard)

	resp, err := c.R().
		SetQueryParams(map[string]string{"req_1": "req 1 value", "req_2": "req 2 value"}).
		SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)).
		SetHeader(hdrUserAgentKey, "Test Custom User agent").
		Get(ts.URL + "/")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, "HTTP/1.1", resp.Proto())
	assertEqual(t, "200 OK", resp.Status())
	assertEqual(t, "TestGet: text response", resp.String())

	logResponse(t, resp)
}

// TestGetRelativePath requests a relative path against the client's base URL.
func TestGetRelativePath(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetBaseURL(ts.URL)

	resp, err := c.R().Get("mypage2")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, "TestGet: text response from mypage2", resp.String())

	logResponse(t, resp)
}

// TestGet400Error expects a 400 status with an empty body from "/mypage".
func TestGet400Error(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()

	resp, err := dcnlr().Get(ts.URL + "/mypage")

	assertError(t, err)
	assertEqual(t, http.StatusBadRequest, resp.StatusCode())
	assertEqual(t, "", resp.String())

	logResponse(t, resp)
}

// TestPostJSONStringSuccess posts a JSON string body and then a malformed
// JSON string body, expecting 200 and 400 respectively.
func TestPostJSONStringSuccess(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetHeader(hdrContentTypeKey, "application/json; charset=utf-8").
		SetHeaders(map[string]string{hdrUserAgentKey: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_5) go-resty v0.1", hdrAcceptKey: "application/json; charset=utf-8"})

	resp, err := c.R().
		SetBody(`{"username":"testuser", "password":"testpass"}`).
		Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())

	logResponse(t, resp)

	// PostJSONStringError
	resp, err = c.R().
		SetBody(`{"username":"testuser" "password":"testpass"}`).
		Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusBadRequest, resp.StatusCode())

	logResponse(t, resp)
}

// TestPostJSONBytesSuccess posts a JSON body supplied as a byte slice.
func TestPostJSONBytesSuccess(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetHeader(hdrContentTypeKey, "application/json; charset=utf-8").
		SetHeaders(map[string]string{hdrUserAgentKey: "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_10_5) go-resty v0.7", hdrAcceptKey: "application/json; charset=utf-8"})

	resp, err := c.R().
		SetBody([]byte(`{"username":"testuser", "password":"testpass"}`)).
		Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())

	logResponse(t, resp)
}

// TestPostJSONBytesIoReader posts a JSON body supplied as an io.Reader.
func TestPostJSONBytesIoReader(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetHeader(hdrContentTypeKey, "application/json; charset=utf-8")

	bodyBytes := []byte(`{"username":"testuser", "password":"testpass"}`)

	resp, err := c.R().
		SetBody(bytes.NewReader(bodyBytes)).
		Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())

	logResponse(t, resp)
}

// TestPostJSONStructSuccess posts a struct body as JSON and also exercises
// Request.Clone on the already executed request.
func TestPostJSONStructSuccess(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	user := &credentials{Username: "testuser", Password: "testpass"}
	assertEqual(t, "Username: **********, Password: **********", user.String())

	c := dcnl().SetJSONEscapeHTML(false)
	r := c.R().
		SetHeader(hdrContentTypeKey, "application/json; charset=utf-8").
		SetBody(user).
		SetResult(&AuthSuccess{})

	rr := r.WithContext(context.Background())
	resp, err := rr.Post(ts.URL + "/login")
	_ = rr.Clone(context.Background())

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, int64(50), resp.Size())

	t.Logf("Result Success: %q", resp.Result().(*AuthSuccess))

	logResponse(t, resp)
}

// TestPostJSONRPCStructSuccess posts a struct body with the
// "application/json-rpc" content type.
func TestPostJSONRPCStructSuccess(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	user := &credentials{Username: "testuser", Password: "testpass"}
	assertEqual(t, "Username: **********, Password: **********", user.String())

	c := dcnl().SetJSONEscapeHTML(false)
	r := c.R().
		SetHeader(hdrContentTypeKey, "application/json-rpc").
		SetBody(user).
		SetResult(&AuthSuccess{}).
		SetQueryParam("ct", "rpc")

	rr := r.WithContext(context.Background())
	resp, err := rr.Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, int64(50), resp.Size())

	t.Logf("Result Success: %q", resp.Result().(*AuthSuccess))

	logResponse(t, resp)
}

// TestPostJSONStructInvalidLogin expects a 401 whose body is decoded into
// the configured result-error type.
func TestPostJSONStructInvalidLogin(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetDebug(false)

	resp, err := c.R().
		SetHeader(hdrContentTypeKey, "application/json; charset=utf-8").
		SetBody(credentials{Username: "testuser", Password: "testpass1"}).
		SetResultError(AuthError{}).
		SetJSONEscapeHTML(false).
		Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusUnauthorized, resp.StatusCode())

	authError := resp.ResultError().(*AuthError)
	assertEqual(t, "unauthorized", authError.ID)
	assertEqual(t, "Invalid credentials", authError.Message)
	t.Logf("Result Error: %q", resp.ResultError().(*AuthError))

	logResponse(t, resp)
}

// TestPostJSONErrorRFC7807 is the invalid-login case with the "ct=problem"
// query parameter on the login URL.
func TestPostJSONErrorRFC7807(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	c := dcnl()
	resp, err := c.R().
		SetHeader(hdrContentTypeKey, "application/json; charset=utf-8").
		SetBody(credentials{Username: "testuser", Password: "testpass1"}).
		SetResultError(AuthError{}).
		Post(ts.URL + "/login?ct=problem")

	assertError(t, err)
	assertEqual(t, http.StatusUnauthorized, resp.StatusCode())

	authError := resp.ResultError().(*AuthError)
	assertEqual(t, "unauthorized", authError.ID)
	assertEqual(t, "Invalid credentials", authError.Message)
	t.Logf("Result Error: %q", resp.ResultError().(*AuthError))

	logResponse(t, resp)
}

// TestPostJSONMapSuccess posts a map body and decodes the success result.
func TestPostJSONMapSuccess(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetDebug(false)

	resp, err := c.R().
		SetBody(map[string]any{"username": "testuser", "password": "testpass"}).
		SetResult(AuthSuccess{}).
		Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())

	t.Logf("Result Success: %q", resp.Result().(*AuthSuccess))

	logResponse(t, resp)
}

// TestPostJSONMapInvalidResponseJson expects a JSON decode error and an
// empty result when the server replies with invalid JSON.
func TestPostJSONMapInvalidResponseJson(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	resp, err := dcnldr().
		SetBody(map[string]any{"username": "testuser", "password": "invalidjson"}).
		SetResult(&AuthSuccess{}).
		Post(ts.URL + "/login")

	assertEqual(t, "invalid character '}' looking for beginning of object key string", err.Error())
	assertEqual(t, http.StatusOK, resp.StatusCode())

	authSuccess := resp.Result().(*AuthSuccess)
	assertEqual(t, "", authSuccess.ID)
	assertEqual(t, "", authSuccess.Message)

	t.Logf("Result Success: %q", resp.Result().(*AuthSuccess))

	logResponse(t, resp)
}

// brokenMarshalJSON is a body type whose JSON marshalling always fails.
type brokenMarshalJSON struct{}

// MarshalJSON always returns the error "b0rk3d".
func (b brokenMarshalJSON) MarshalJSON() ([]byte, error) {
	return nil, errors.New("b0rk3d")
}

// TestPostJSONMarshalError expects the body's marshal error to be surfaced
// by the request.
func TestPostJSONMarshalError(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	b := brokenMarshalJSON{}
	exp := "b0rk3d"

	_, err := dcnldr().
		SetHeader(hdrContentTypeKey, "application/json").
		SetBody(b).
		Post(ts.URL + "/login")
	if err == nil {
		t.Fatalf("expected error but got %v", err)
	}

	if !strings.Contains(err.Error(), exp) {
		t.Errorf("expected error string %q to contain %q", err, exp)
	}
}

// TestForceContentTypeForGH276andGH240 forces the response content type to
// JSON for an endpoint that replies with an incorrect content type.
func TestForceContentTypeForGH276andGH240(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	retried := 0
	c := dcnl()
	c.SetDebug(false)

	resp, err := c.R().
		SetBody(map[string]any{"username": "testuser", "password": "testpass"}).
		SetResult(AuthSuccess{}).
		SetResponseForceContentType("application/json").
		Post(ts.URL + "/login-json-html")

	assertNil(t, err) // JSON response comes with incorrect content-type, we correct it with ForceContentType
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, 0, retried)
	assertEqual(t, int64(50), resp.Size())

	t.Logf("Result Success: %q", resp.Result().(*AuthSuccess))

	logResponse(t, resp)
}

// TestPostXMLStringSuccess posts a string body with the XML content type.
func TestPostXMLStringSuccess(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetDebug(false)

	resp, err := c.R().
		SetHeader(hdrContentTypeKey, "application/xml").
		SetBody(`testusertestpass`).
		SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)).
		Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, int64(116), resp.Size())

	logResponse(t, resp)
}

// brokenMarshalXML is a body type whose XML marshalling always fails.
type brokenMarshalXML struct{}

// MarshalXML always returns the error "b0rk3d".
func (b brokenMarshalXML) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
	return errors.New("b0rk3d")
}

// TestPostXMLMarshalError expects the body's marshal error to be surfaced
// by the request.
func TestPostXMLMarshalError(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	b := brokenMarshalXML{}
	exp := "b0rk3d"

	_, err := dcnldr().
		SetHeader(hdrContentTypeKey, "application/xml").
		SetBody(b).
		Post(ts.URL + "/login")
	if err == nil {
		t.Fatalf("expected error but got %v", err)
	}

	if !strings.Contains(err.Error(), exp) {
		t.Errorf("expected error string %q to contain %q", err, exp)
	}
}

// TestPostXMLStringError expects a 400 response for the posted string body.
func TestPostXMLStringError(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	resp, err := dcnldr().
		SetHeader(hdrContentTypeKey, "application/xml").
		SetBody(`testusertestpass`).
		Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusBadRequest, resp.StatusCode())
	assertEqual(t, `bad_requestUnable to read user info`, resp.String())

	logResponse(t, resp)
}

func TestPostXMLBytesSuccess(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetDebug(false)

	resp, err := c.R().
		SetHeader(hdrContentTypeKey, "application/xml").
		SetBody([]byte(`testusertestpass`)).
SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)).
		Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())

	logResponse(t, resp)
}

// TestPostXMLStructSuccess posts a credentials struct with the XML content
// type and checks for a 200 OK response.
func TestPostXMLStructSuccess(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	resp, err := dcnldr().
		SetHeader(hdrContentTypeKey, "application/xml").
		SetBody(credentials{Username: "testuser", Password: "testpass"}).
		SetResult(&AuthSuccess{}).
		Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())

	t.Logf("Result Success: %q", resp.Result().(*AuthSuccess))

	logResponse(t, resp)
}

// TestPostXMLStructInvalidLogin posts wrong credentials as XML with a
// client-level result error type, and checks for 401 plus the
// Www-Authenticate response header.
func TestPostXMLStructInvalidLogin(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetResultError(&AuthError{})

	resp, err := c.R().
		SetHeader(hdrContentTypeKey, "application/xml").
		SetBody(credentials{Username: "testuser", Password: "testpass1"}).
		Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusUnauthorized, resp.StatusCode())
	// Expected value first, matching every other assertEqual call in this
	// file, so a failure reports expected/actual the right way round.
	assertEqual(t, "Protected Realm", resp.Header().Get("Www-Authenticate"))

	t.Logf("Result Error: %q", resp.ResultError().(*AuthError))

	logResponse(t, resp)
}

// TestPostXMLStructInvalidResponseXml checks that a malformed XML response
// yields the decoder's syntax error while the status stays 200.
func TestPostXMLStructInvalidResponseXml(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	resp, err := dcnldr().
		SetHeader(hdrContentTypeKey, "application/xml").
		SetBody(credentials{Username: "testuser", Password: "invalidxml"}).
		SetResult(&AuthSuccess{}).
		Post(ts.URL + "/login")

	assertEqual(t, "XML syntax error on line 1: element closed by ", err.Error())
	assertEqual(t, http.StatusOK, resp.StatusCode())

	t.Logf("Result Success: %q", resp.Result().(*AuthSuccess))

	logResponse(t, resp)
}

// TestPostXMLMapNotSupported checks that a map body sent with the XML content
// type is rejected with ErrUnsupportedRequestBodyKind.
func TestPostXMLMapNotSupported(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	_, err := dcnldr().
		SetHeader(hdrContentTypeKey, "application/xml").
		SetBody(map[string]any{"Username": "testuser", "Password": "testpass"}).
Post(ts.URL + "/login")

	assertErrorIs(t, ErrUnsupportedRequestBodyKind, err)
}

// TestRequestBasicAuth logs in with request-level basic auth against the
// auth server and checks for a 200 OK response.
func TestRequestBasicAuth(t *testing.T) {
	srv := createAuthServer(t)
	defer srv.Close()

	client := dcnl()
	client.SetBaseURL(srv.URL).
		SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})

	req := client.R().
		SetBasicAuth("myuser", "basicauth").
		SetResult(&AuthSuccess{})

	res, err := req.Post("/login")
	assertError(t, err)
	assertEqual(t, http.StatusOK, res.StatusCode())

	t.Logf("Result Success: %q", res.Result().(*AuthSuccess))
	logResponse(t, res)
}

// TestRequestBasicAuthWithBody is the basic-auth login with an additional
// []string request body.
func TestRequestBasicAuthWithBody(t *testing.T) {
	srv := createAuthServer(t)
	defer srv.Close()

	client := dcnl()
	client.SetBaseURL(srv.URL).
		SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})

	body := []string{strings.Repeat("hello", 25)}
	req := client.R().
		SetBasicAuth("myuser", "basicauth").
		SetBody(body).
		SetResult(&AuthSuccess{})

	res, err := req.Post("/login")
	assertError(t, err)
	assertEqual(t, http.StatusOK, res.StatusCode())

	t.Logf("Result Success: %q", res.Result().(*AuthSuccess))
	logResponse(t, res)
}

// TestRequestInsecureBasicAuth sends basic auth to a server created with
// createAuthServerTLSOptional(t, false) and checks that the request-level
// logger captured the insecure-credentials warning.
func TestRequestInsecureBasicAuth(t *testing.T) {
	srv := createAuthServerTLSOptional(t, false)
	defer srv.Close()

	// Route the request-level logger into a buffer so its output can be inspected.
	captured := new(bytes.Buffer)
	logger := createLogger()
	logger.l.SetOutput(captured)

	client := dcnl()
	client.SetBaseURL(srv.URL)

	req := client.R().
		SetBasicAuth("myuser", "basicauth").
		SetResult(&AuthSuccess{}).
		SetLogger(logger)

	res, err := req.Post("/login")
	assertError(t, err)
	assertEqual(t, http.StatusOK, res.StatusCode())

	warned := strings.Contains(captured.String(), "WARN RESTY Using sensitive credentials in HTTP mode is not secure. Use HTTPS")
	assertTrue(t, warned)

	t.Logf("Result Success: %q", res.Result().(*AuthSuccess))
	logResponse(t, res)
	t.Logf("captured request-level logs: %s", captured.String())
}

// TestRequestBasicAuthFail logs in with a wrong basic-auth password and
// checks for a 401 response.
func TestRequestBasicAuthFail(t *testing.T) {
	ts := createAuthServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}).
		SetResultError(AuthError{})

	resp, err := c.R().
		SetBasicAuth("myuser", "basicauth1").
Post(ts.URL + "/login")

	assertError(t, err)
	assertEqual(t, http.StatusUnauthorized, resp.StatusCode())

	t.Logf("Result Error: %q", resp.ResultError().(*AuthError))
	logResponse(t, resp)
}

// TestRequestAuthToken sets an auth token on the client and a different one
// on the request, then checks that GET /profile returns 200 OK.
func TestRequestAuthToken(t *testing.T) {
	srv := createAuthServer(t)
	defer srv.Close()

	client := dcnl()
	client.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}).
		SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF")

	req := client.R().
		SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF-Request")

	res, err := req.Get(srv.URL + "/profile")
	assertError(t, err)
	assertEqual(t, http.StatusOK, res.StatusCode())
}

// TestRequestAuthScheme exercises auth scheme and token combinations set at
// client and request level, including empty schemes (GH954, GH959).
func TestRequestAuthScheme(t *testing.T) {
	ts := createAuthServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}).
		SetAuthScheme("OAuth").
		SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF")

	t.Run("override auth scheme", func(t *testing.T) {
		req := c.R().
			SetAuthScheme("Bearer").
			SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF-Request")

		res, err := req.Get(ts.URL + "/profile")
		assertError(t, err)
		assertEqual(t, http.StatusOK, res.StatusCode())
	})

	t.Run("empty auth scheme at client level GH954", func(t *testing.T) {
		token := "004DDB79-6801-4587-B976-F093E6AC44FF"

		// set client level
		c.SetAuthScheme("").
			SetAuthToken(token)

		res, err := c.R().Get(ts.URL + "/profile")
		assertError(t, err)
		assertEqual(t, http.StatusOK, res.StatusCode())
		// The Request's own header map is expected to be empty while the
		// raw request carries the bare token.
		assertEqual(t, "", res.Request.Header.Get(hdrAuthorizationKey))
		assertEqual(t, token, res.Request.RawRequest.Header.Get(hdrAuthorizationKey))
	})

	t.Run("empty auth scheme at request level GH954", func(t *testing.T) {
		tokenValue := "004DDB79-6801-4587-B976-F093E6AC44FF"

		// set client level
		c := dcnl().
			SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}).
			SetAuthToken(tokenValue)

		resp, err := c.R().
			SetAuthScheme("").
Get(ts.URL + "/profile")

		assertError(t, err)
		assertEqual(t, "", resp.Request.Header.Get(hdrAuthorizationKey))
		assertEqual(t, tokenValue, resp.Request.RawRequest.Header.Get(hdrAuthorizationKey))
	})

	t.Run("only client level auth token GH959", func(t *testing.T) {
		token := "004DDB79-6801-4587-B976-F093E6AC44FF"

		client := dcnl().
			SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}).
			SetAuthToken(token)

		res, err := client.R().Get(ts.URL + "/profile")
		assertError(t, err)
		assertEqual(t, http.StatusOK, res.StatusCode())
		// With no scheme configured the raw request is expected to carry "Bearer <token>".
		assertEqual(t, "", res.Request.Header.Get(hdrAuthorizationKey))
		assertEqual(t, "Bearer "+token, res.Request.RawRequest.Header.Get(hdrAuthorizationKey))
	})
}

// TestFormData posts client-level and request-level form data with basic
// auth to /profile and checks for a 200 "Success" response.
func TestFormData(t *testing.T) {
	srv := createFormPostServer(t)
	defer srv.Close()

	client := dcnl()
	client.SetFormData(map[string]string{"zip_code": "00000", "city": "Los Angeles"}).
		SetDebug(true)
	client.outputLogTo(io.Discard)

	req := client.R().
		SetFormData(map[string]string{"first_name": "Jeevanandam", "last_name": "M", "zip_code": "00001"}).
		SetBasicAuth("myuser", "mypass")

	res, err := req.Post(srv.URL + "/profile")
	assertError(t, err)
	assertEqual(t, http.StatusOK, res.StatusCode())
	assertEqual(t, "Success", res.String())
}

// TestMultiValueFormData posts to /search with a multi-valued
// "search_criteria" set through SetQueryParamsFromValues and checks for a
// 200 "Success" response.
func TestMultiValueFormData(t *testing.T) {
	srv := createFormPostServer(t)
	defer srv.Close()

	criteria := url.Values{
		"search_criteria": []string{"book", "glass", "pencil"},
	}

	client := dcnl()
	client.SetDebug(true)
	client.outputLogTo(io.Discard)

	req := client.R().SetQueryParamsFromValues(criteria)

	res, err := req.Post(srv.URL + "/search")
	assertError(t, err)
	assertEqual(t, http.StatusOK, res.StatusCode())
	assertEqual(t, "Success", res.String())
}

// TestFormDataDisableWarn repeats the form-data post with
// SetLoggerWarnLevel(true) on the client and debug enabled on the request.
func TestFormDataDisableWarn(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetFormData(map[string]string{"zip_code": "00000", "city": "Los Angeles"}).
		SetLoggerWarnLevel(true)
	c.outputLogTo(io.Discard)

	resp, err := c.R().
		SetDebug(true).
		SetFormData(map[string]string{"first_name": "Jeevanandam", "last_name": "M", "zip_code": "00001"}).
SetBasicAuth("myuser", "mypass").
		Post(ts.URL + "/profile")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, "Success", resp.String())
}

// TestGetWithCookie sends one client-level cookie and request-level cookies
// (set via SetCookie and SetCookies) to "mypage2" and checks the response.
func TestGetWithCookie(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetBaseURL(ts.URL)

	c.SetCookie(&http.Cookie{
		Name:  "go-resty-1",
		Value: "This is cookie 1 value",
	})

	r := c.R().
		SetCookie(&http.Cookie{
			Name:  "go-resty-2",
			Value: "This is cookie 2 value",
		}).
		SetCookies([]*http.Cookie{
			{
				Name:  "go-resty-1",
				Value: "This is cookie 1 value additional append",
			},
		})

	resp, err := r.Get("mypage2")
	// Clone is invoked after the request was sent; its result is not inspected.
	_ = r.Clone(context.Background())

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, "TestGet: text response from mypage2", resp.String())

	logResponse(t, resp)
}

// TestGetWithCookies seeds the underlying client's cookie jar, then combines
// client-level and request-level cookies across two GETs of "mypage2".
func TestGetWithCookies(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetBaseURL(ts.URL).SetDebug(true)

	// Stop here on a parse failure: a nil URL must not reach Jar.SetCookies.
	tu, err := url.Parse(ts.URL)
	if err != nil {
		t.Fatalf("parse test server URL %q: %v", ts.URL, err)
	}
	c.Client().Jar.SetCookies(tu, []*http.Cookie{
		{
			Name:  "jar-go-resty-1",
			Value: "From Jar - This is cookie 1 value",
		},
		{
			Name:  "jar-go-resty-2",
			Value: "From Jar - This is cookie 2 value",
		},
	})

	resp, err := c.R().SetHeader(hdrCookieKey, "").Get("mypage2")
	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())

	// Client cookies
	c.SetCookies([]*http.Cookie{
		{
			Name:  "go-resty-1",
			Value: "This is cookie 1 value",
		},
		{
			Name:  "go-resty-2",
			Value: "This is cookie 2 value",
		},
	})

	r := c.R().
		SetCookie(&http.Cookie{
			Name:  "req-go-resty-1",
			Value: "This is request cookie 1 value additional append",
		})
	resp, err = r.Get("mypage2")
	_ = r.Clone(context.Background())

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, "TestGet: text response from mypage2", resp.String())

	logResponse(t, resp)
}

// TestPutPlainString PUTs a plain string body to /plaintext and checks the
// status and response text.
func TestPutPlainString(t *testing.T) {
	ts := createGenericServer(t)
	defer ts.Close()

	resp, err := dcnl().R().
		SetBody("This is plain text body to server").
Put(ts.URL + "/plaintext")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, "TestPut: plain text response", resp.String())
}

// TestPutJSONString PUTs a JSON string body to /json through a client that
// has two header-setting request middlewares, and checks the JSON reply.
func TestPutJSONString(t *testing.T) {
	srv := createGenericServer(t)
	defer srv.Close()

	cl := dcnl()

	cl.AddRequestMiddleware(func(_ *Client, req *Request) error {
		req.SetHeader("X-Custom-Request-Middleware", "Request middleware")
		return nil
	})
	cl.AddRequestMiddleware(func(_ *Client, req *Request) error {
		req.SetHeader("X-ContentLength", "Request middleware ContentLength set")
		return nil
	})

	cl.SetDebug(true)
	cl.outputLogTo(io.Discard)

	headers := map[string]string{
		hdrContentTypeKey: "application/json; charset=utf-8",
		hdrAcceptKey:      "application/json; charset=utf-8",
	}
	req := cl.R().
		SetHeaders(headers).
		SetBody(`{"content":"json content sending to server"}`)

	res, err := req.Put(srv.URL + "/json")
	assertError(t, err)
	assertEqual(t, http.StatusOK, res.StatusCode())
	assertEqual(t, `{"response":"json response"}`, res.String())
}

// TestPutXMLString PUTs a string body with XML content-type and accept
// headers to /xml and checks the response text.
func TestPutXMLString(t *testing.T) {
	srv := createGenericServer(t)
	defer srv.Close()

	headers := map[string]string{
		hdrContentTypeKey: "application/xml",
		hdrAcceptKey:      "application/xml",
	}
	req := dcnl().R().
		SetHeaders(headers).
		SetBody(`XML Content sending to server`)

	res, err := req.Put(srv.URL + "/xml")
	assertError(t, err)
	assertEqual(t, http.StatusOK, res.StatusCode())
	assertEqual(t, `XML response`, res.String())
}

// TestRequestMiddleware registers two header-setting request middlewares and
// PUTs a plain text body to /plaintext.
func TestRequestMiddleware(t *testing.T) {
	ts := createGenericServer(t)
	defer ts.Close()

	c := dcnl()
	c.AddRequestMiddleware(func(c *Client, r *Request) error {
		r.SetHeader("X-Custom-Request-Middleware", "Request middleware")
		return nil
	})
	c.AddRequestMiddleware(func(c *Client, r *Request) error {
		r.SetHeader("X-ContentLength", "Request middleware ContentLength set")
		return nil
	})

	resp, err := c.R().
		SetBody("RequestMiddleware: This is plain text body to server").
Put(ts.URL + "/plaintext")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, "TestPut: plain text response", resp.String())
}

// TestHTTPAutoRedirectUpTo10 follows the redirect chain starting at
// /redirect-1 and checks that it stops after 10 redirects with the
// "stopped after 10 redirects" error.
func TestHTTPAutoRedirectUpTo10(t *testing.T) {
	ts := createRedirectServer(t)
	defer ts.Close()

	res, err := dcnl().R().Get(ts.URL + "/redirect-1")
	// Fail with a message instead of panicking on err.Error() below.
	if err == nil {
		t.Fatal("expected the redirect limit to produce an error, got nil")
	}

	redirects := res.RedirectHistory()
	assertEqual(t, 10, len(redirects))
	// Guard the index access so an empty history is reported as a failure.
	if len(redirects) == 0 {
		t.Fatal("redirect history is empty")
	}

	finalReq := redirects[0]
	assertEqual(t, 307, finalReq.StatusCode)
	assertEqual(t, ts.URL+"/redirect-10", finalReq.URL)

	// Both the unquoted and the quoted URL form of the message are accepted.
	assertTrue(t, (err.Error() == "Get /redirect-11: stopped after 10 redirects" ||
		err.Error() == "Get \"/redirect-11\": stopped after 10 redirects"))
}

// TestHostCheckRedirectPolicy checks that a redirect is refused when the
// client only allows the domain "127.0.0.1".
func TestHostCheckRedirectPolicy(t *testing.T) {
	ts := createRedirectServer(t)
	defer ts.Close()

	c := dcnl().
		SetRedirectPolicy(RedirectDomainCheckPolicy("127.0.0.1"))

	_, err := c.R().Get(ts.URL + "/redirect-host-check-1")

	assertNotNil(t, err)
	assertTrue(t, strings.Contains(err.Error(), "redirect is not allowed as per DomainCheckRedirectPolicy"))
}

// TestHttpMethods exercises the Head, Options, Patch and Trace request
// helpers against the generic server.
func TestHttpMethods(t *testing.T) {
	ts := createGenericServer(t)
	defer ts.Close()

	t.Run("head method", func(t *testing.T) {
		resp, err := dcnldr().Head(ts.URL + "/")

		assertError(t, err)
		assertEqual(t, http.StatusOK, resp.StatusCode())
	})

	t.Run("options method", func(t *testing.T) {
		resp, err := dcnldr().Options(ts.URL + "/options")

		assertError(t, err)
		assertEqual(t, http.StatusOK, resp.StatusCode())
		// Expected value first, as in the "send-options" case of TestSendMethod.
		assertEqual(t, "x-go-resty-id", resp.Header().Get("Access-Control-Expose-Headers"))
	})

	t.Run("patch method", func(t *testing.T) {
		resp, err := dcnldr().Patch(ts.URL + "/patch")

		assertError(t, err)
		assertEqual(t, http.StatusOK, resp.StatusCode())
		assertEqual(t, "", resp.String())
	})

	t.Run("trace method", func(t *testing.T) {
		resp, err := dcnldr().Trace(ts.URL + "/trace")

		assertError(t, err)
		assertEqual(t, http.StatusOK, resp.StatusCode())
		assertEqual(t, "", resp.String())
	})
}

// TestSendMethod exercises Request.Send with the method and URL set directly
// on the request.
func TestSendMethod(t *testing.T) {
	ts := createGenericServer(t)
	defer ts.Close()

	t.Run("send-get-implicit", func(t
*testing.T) {
		req := dcnldr()
		req.URL = ts.URL + "/gzip-test"

		resp, err := req.Send()

		assertError(t, err)
		assertEqual(t, http.StatusOK, resp.StatusCode())
		assertEqual(t, "This is Gzip response testing", resp.String())
	})

	// Explicit GET against the gzip endpoint.
	t.Run("send-get", func(t *testing.T) {
		r := dcnldr()
		r.SetMethod(MethodGet)
		r.URL = ts.URL + "/gzip-test"

		res, err := r.Send()

		assertError(t, err)
		assertEqual(t, http.StatusOK, res.StatusCode())
		assertEqual(t, "This is Gzip response testing", res.String())
	})

	// OPTIONS: empty body plus the exposed-headers response header.
	t.Run("send-options", func(t *testing.T) {
		r := dcnldr()
		r.SetMethod(MethodOptions)
		r.URL = ts.URL + "/options"

		res, err := r.Send()

		assertError(t, err)
		assertEqual(t, http.StatusOK, res.StatusCode())
		assertEqual(t, "", res.String())
		assertEqual(t, "x-go-resty-id", res.Header().Get("Access-Control-Expose-Headers"))
	})

	// PATCH: empty body expected.
	t.Run("send-patch", func(t *testing.T) {
		r := dcnldr()
		r.SetMethod(MethodPatch)
		r.URL = ts.URL + "/patch"

		res, err := r.Send()

		assertError(t, err)
		assertEqual(t, http.StatusOK, res.StatusCode())
		assertEqual(t, "", res.String())
	})

	// PUT against the plaintext endpoint.
	t.Run("send-put", func(t *testing.T) {
		r := dcnldr()
		r.SetMethod(MethodPut)
		r.URL = ts.URL + "/plaintext"

		res, err := r.Send()

		assertError(t, err)
		assertEqual(t, http.StatusOK, res.StatusCode())
		assertEqual(t, "TestPut: plain text response", res.String())
	})
}

// TestRawFileUploadByBody PUTs the bytes of test-img.png as the request body
// to /raw-upload with an auth token.
func TestRawFileUploadByBody(t *testing.T) {
	ts := createFormPostServer(t)
	defer ts.Close()

	fileBytes, err := os.ReadFile(filepath.Join(getTestDataPath(), "test-img.png"))
	assertNil(t, err)

	resp, err := dcnldr().
		SetBody(fileBytes).
		SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF").
Put(ts.URL + "/raw-upload")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertEqual(t, "image/png", resp.Request.Header.Get(hdrContentTypeKey))
}

// TestProxySetting walks the client through setting, replacing and removing
// a proxy, checking IsProxySet and the transport's Proxy field at each step.
func TestProxySetting(t *testing.T) {
	client := dcnl()

	tr, err := client.HTTPTransport()

	assertNil(t, err)

	// Nothing configured yet.
	assertFalse(t, client.IsProxySet())
	assertNotNil(t, tr.Proxy)

	client.SetProxy("http://sampleproxy:8888")
	assertTrue(t, client.IsProxySet())
	assertNotNil(t, tr.Proxy)

	// After this proxy value the client is still expected to report a proxy as set.
	client.SetProxy("//not.a.user@%66%6f%6f.com:8888")
	assertTrue(t, client.IsProxySet())
	assertNotNil(t, tr.Proxy)

	client.SetProxy("http://sampleproxy:8888")
	assertTrue(t, client.IsProxySet())

	// Removal is expected to clear both the proxy URL and the transport's Proxy.
	client.RemoveProxy()
	assertNil(t, client.ProxyURL())
	assertNil(t, tr.Proxy)
}

// TestGetClient checks that the underlying client returned by Client() is
// non-nil and distinct from http.DefaultClient.
func TestGetClient(t *testing.T) {
	first := New()
	second := New()
	underlying := second.Client()

	assertNotNil(t, underlying)
	assertNotEqual(t, first, http.DefaultClient)
	assertNotEqual(t, underlying, http.DefaultClient)
	assertNotEqual(t, first, underlying)
}

// TestIncorrectURL checks that an unparsable URL, given directly or through
// the base URL, produces a "parse ..." error.
func TestIncorrectURL(t *testing.T) {
	// Both the unquoted and the quoted URL form of the message are accepted.
	isParseFailure := func(msg string) bool {
		return (strings.Contains(msg, "parse //not.a.user@%66%6f%6f.com/just/a/path/also") ||
			strings.Contains(msg, "parse \"//not.a.user@%66%6f%6f.com/just/a/path/also\""))
	}

	client := dcnl()

	_, directErr := client.R().Get("//not.a.user@%66%6f%6f.com/just/a/path/also")
	assertTrue(t, isParseFailure(directErr.Error()))

	client.SetBaseURL("//not.a.user@%66%6f%6f.com")

	_, baseErr := client.R().Get("/just/a/path/also")
	assertTrue(t, isParseFailure(baseErr.Error()))
}

// TestDetectContentTypeForPointer posts a pointer-to-struct body without an
// explicit content type header.
func TestDetectContentTypeForPointer(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	user := &credentials{Username: "testuser", Password: "testpass"}

	resp, err := dcnldr().
		SetBody(user).
		SetResult(AuthSuccess{}).
Post(ts.URL + "/login") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) t.Logf("Result Success: %q", resp.Result().(*AuthSuccess)) logResponse(t, resp) } type ExampleUser struct { FirstName string `json:"first_name"` LastName string `json:"last_name"` ZipCode string `json:"zip_code"` } func TestDetectContentTypeForPointerWithSlice(t *testing.T) { ts := createPostServer(t) defer ts.Close() users := &[]ExampleUser{ {FirstName: "firstname1", LastName: "lastname1", ZipCode: "10001"}, {FirstName: "firstname2", LastName: "lastname3", ZipCode: "10002"}, {FirstName: "firstname3", LastName: "lastname3", ZipCode: "10003"}, } resp, err := dcnldr(). SetBody(users). Post(ts.URL + "/users") assertError(t, err) assertEqual(t, http.StatusAccepted, resp.StatusCode()) t.Logf("Result Success: %q", resp) logResponse(t, resp) } func TestDetectContentTypeForPointerWithSliceMap(t *testing.T) { ts := createPostServer(t) defer ts.Close() usersmap := map[string]any{ "user1": ExampleUser{FirstName: "firstname1", LastName: "lastname1", ZipCode: "10001"}, "user2": &ExampleUser{FirstName: "firstname2", LastName: "lastname3", ZipCode: "10002"}, "user3": ExampleUser{FirstName: "firstname3", LastName: "lastname3", ZipCode: "10003"}, } var users []map[string]any users = append(users, usersmap) resp, err := dcnldr(). SetBody(&users). Post(ts.URL + "/usersmap") assertError(t, err) assertEqual(t, http.StatusAccepted, resp.StatusCode()) t.Logf("Result Success: %q", resp) logResponse(t, resp) } func TestDetectContentTypeForSlice(t *testing.T) { ts := createPostServer(t) defer ts.Close() users := []ExampleUser{ {FirstName: "firstname1", LastName: "lastname1", ZipCode: "10001"}, {FirstName: "firstname2", LastName: "lastname3", ZipCode: "10002"}, {FirstName: "firstname3", LastName: "lastname3", ZipCode: "10003"}, } resp, err := dcnldr(). SetBody(users). 
Post(ts.URL + "/users")

	assertError(t, err)
	assertEqual(t, http.StatusAccepted, resp.StatusCode())

	t.Logf("Result Success: %q", resp)

	logResponse(t, resp)
}

// TestMultiParamsQueryString checks which "status" query values end up in the
// request URL when the parameter is set at client level, overridden at
// request level, and supplied as multiple values.
func TestMultiParamsQueryString(t *testing.T) {
	firstSrv := createGetServer(t)
	defer firstSrv.Close()

	cl := dcnl()
	first := cl.R()

	cl.SetQueryParam("status", "open")

	_, _ = first.SetQueryParam("status", "pending").
		Get(firstSrv.URL)

	assertTrue(t, strings.Contains(first.URL, "status=pending"))
	// pending overrides open
	assertFalse(t, strings.Contains(first.URL, "status=open"))

	_, _ = first.SetQueryParam("status", "approved").
		Get(firstSrv.URL)

	assertTrue(t, strings.Contains(first.URL, "status=approved"))
	// approved overrides pending
	assertFalse(t, strings.Contains(first.URL, "status=pending"))

	secondSrv := createGetServer(t)
	defer secondSrv.Close()

	second := cl.R()

	statuses := url.Values{
		"status": []string{"pending", "approved", "reject"},
	}

	_, _ = second.SetQueryParamsFromValues(statuses).Get(secondSrv.URL)

	for _, want := range []string{"status=pending", "status=approved", "status=reject"} {
		assertTrue(t, strings.Contains(second.URL, want))
	}

	// because it's removed by key
	assertFalse(t, strings.Contains(second.URL, "status=open"))
}

// TestSetQueryStringTypical sends a well-formed and a malformed raw query
// string and expects the same 200 response for both.
func TestSetQueryStringTypical(t *testing.T) {
	srv := createGetServer(t)
	defer srv.Close()

	queries := []string{
		"productId=232&template=fresh-sample&cat=resty&source=google&kw=buy a lot more",
		"&%%amp;",
	}

	for _, query := range queries {
		res, err := dcnldr().
			SetQueryString(query).
			Get(srv.URL)

		assertError(t, err)
		assertEqual(t, http.StatusOK, res.StatusCode())
		assertEqual(t, "200 OK", res.Status())
		assertEqual(t, "TestGet: text response", res.String())
	}
}

// TestSetHeaderVerbatim checks that SetHeaderVerbatim keeps the header key
// exactly as given while SetHeader stores it under the canonical key.
func TestSetHeaderVerbatim(t *testing.T) {
	ts := createPostServer(t)
	defer ts.Close()

	r := dcnldr().
		SetHeaderVerbatim("header-lowercase", "value_lowercase").
SetHeader("header-lowercase", "value_standard")

	//lint:ignore SA1008 valid one ignore this!
	assertEqual(t, "value_lowercase", strings.Join(r.Header["header-lowercase"], ""))
	assertEqual(t, "value_standard", r.Header.Get("Header-Lowercase"))
}

// TestSetHeaderMultipleValue checks that multi-valued headers are read back
// as one comma-separated value.
func TestSetHeaderMultipleValue(t *testing.T) {
	srv := createPostServer(t)
	defer srv.Close()

	multi := map[string][]string{
		"Content":       {"text/*", "text/html", "*"},
		"Authorization": {"Bearer xyz"},
	}
	req := dcnldr().SetHeaderMultiValues(multi)

	assertEqual(t, "text/*, text/html, *", req.Header.Get("content"))
	assertEqual(t, "Bearer xyz", req.Header.Get("authorization"))
}

// TestRequestSetHeaderAny checks that SetHeaderAny stores int and string
// values as their string form.
func TestRequestSetHeaderAny(t *testing.T) {
	req := dcnldr().
		SetHeaderAny("X-Int-Value", 42).
		SetHeaderAny("X-String-Value", "hello")

	hdr := req.Header
	assertEqual(t, "42", hdr.Get("X-Int-Value"))
	assertEqual(t, "hello", hdr.Get("X-String-Value"))
}

// TestRequestSetHeaderVerbatimAny checks that SetHeaderVerbatimAny stores the
// value under the exact, non-canonical key.
func TestRequestSetHeaderVerbatimAny(t *testing.T) {
	req := dcnldr().
		SetHeaderVerbatimAny("header-lowercase", 123)

	//lint:ignore SA1008 valid one ignore this!
	assertEqual(t, "123", strings.Join(req.Header["header-lowercase"], ""))
}

// TestRequestSetQueryParamAny checks that SetQueryParamAny stores int and
// bool values as their string form.
func TestRequestSetQueryParamAny(t *testing.T) {
	req := dcnldr().
		SetQueryParamAny("page", 5).
		SetQueryParamAny("active", true)

	params := req.QueryParams
	assertEqual(t, "5", params.Get("page"))
	assertEqual(t, "true", params.Get("active"))
}

// TestRequestSetPathParamAny checks that SetPathParamAny stores the value
// escaped ("john doe" becomes "john%20doe").
func TestRequestSetPathParamAny(t *testing.T) {
	req := dcnldr().
		SetPathParamAny("userId", 42).
		SetPathParamAny("name", "john doe")

	assertEqual(t, "42", req.PathParams["userId"])
	assertEqual(t, "john%20doe", req.PathParams["name"])
}

// TestRequestSetRawPathParamAny checks that SetPathRawParamAny stores the
// value without escaping.
func TestRequestSetRawPathParamAny(t *testing.T) {
	req := dcnldr().
		SetPathRawParamAny("userId", 42).
		SetPathRawParamAny("name", "john doe")

	assertEqual(t, "42", req.PathParams["userId"])
	assertEqual(t, "john doe", req.PathParams["name"])
}

// TestOutputFileWithBaseDirAndRelativePath saves a response to a relative
// file name under a configured save directory.
func TestOutputFileWithBaseDirAndRelativePath(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()
	defer cleanupFiles(".testdata/dir-sample")

	baseOutputDir := filepath.Join(getTestDataPath(), "dir-sample")
	client := dcnl().
SetRedirectPolicy(RedirectFlexiblePolicy(10)).
		SetResponseSaveDirectory(baseOutputDir).
		SetDebug(true)

	outputFilePath := "go-resty/test-img-success.png"
	resp, err := client.R().
		SetResponseSaveFileName(outputFilePath).
		Get(ts.URL + "/my-image.png")

	assertError(t, err)
	assertTrue(t, resp.Size() != 0)
	assertTrue(t, resp.Duration() > 0)

	f, err1 := os.Open(filepath.Join(baseOutputDir, outputFilePath))
	defer closeq(f)
	assertError(t, err1)
}

// TestOutputFileWithBaseDirError only constructs a client whose save
// directory name is the raw string `go-resty\0`; no request is sent and
// nothing is asserted.
func TestOutputFileWithBaseDirError(t *testing.T) {
	_ = dcnl().
		SetRedirectPolicy(RedirectFlexiblePolicy(10)).
		SetResponseSaveDirectory(filepath.Join(getTestDataPath(), `go-resty\0`))
}

// TestOutputPathDirNotExists saves a response under the "not-exists-dir"
// save directory and checks that the response has a non-zero size and duration.
func TestOutputPathDirNotExists(t *testing.T) {
	srv := createGetServer(t)
	defer srv.Close()
	defer cleanupFiles(filepath.Join(".testdata", "not-exists-dir"))

	saveDir := filepath.Join(getTestDataPath(), "not-exists-dir")
	cl := dcnl().
		SetRedirectPolicy(RedirectFlexiblePolicy(10)).
		SetResponseSaveDirectory(saveDir)

	req := cl.R().SetResponseSaveFileName("test-img-success.png")

	res, err := req.Get(srv.URL + "/my-image.png")
	assertError(t, err)
	assertTrue(t, res.Size() != 0)
	assertTrue(t, res.Duration() > 0)
}

// TestOutputFileAbsPath saves a response to a file name built from the test
// data path and checks the size and that the file exists afterwards.
func TestOutputFileAbsPath(t *testing.T) {
	srv := createGetServer(t)
	defer srv.Close()
	defer cleanupFiles(filepath.Join(".testdata", "go-resty"))

	target := filepath.Join(getTestDataPath(), "go-resty", "test-img-success-2.png")

	req := dcnlr().SetResponseSaveFileName(target)

	resp, err := req.Get(srv.URL + "/my-image.png")
	assertError(t, err)
	assertEqual(t, int64(2579468), resp.Size())

	_, err = os.Stat(target)
	assertNil(t, err)
}

// TestRequestSaveResponse covers saving responses to file with the file name
// taken from the content disposition or from the URL path.
func TestRequestSaveResponse(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()
	defer cleanupFiles(filepath.Join(".testdata", "go-resty"))

	c := dcnl().
		SetResponseSaveToFile(true).
SetResponseSaveDirectory(filepath.Join(getTestDataPath(), "go-resty"))
	assertTrue(t, c.IsResponseSaveToFile())

	// File name supplied via the content-disposition query parameters.
	t.Run("content-disposition save response request", func(t *testing.T) {
		savedPath := filepath.Join(getTestDataPath(), "go-resty", "test-img-success-2.png")

		// Turn the client-level flag off so only the request-level flag applies.
		c.SetResponseSaveToFile(false)
		assertFalse(t, c.IsResponseSaveToFile())

		req := c.R().SetResponseSaveToFile(true)
		resp, err := req.Get(ts.URL + "/my-image.png?content-disposition=true&filename=test-img-success-2.png")

		assertError(t, err)
		assertEqual(t, int64(2579468), resp.Size())

		_, err = os.Stat(savedPath)
		assertNil(t, err)
	})

	// File name derived from the URL path.
	t.Run("use filename from path", func(t *testing.T) {
		savedPath := filepath.Join(getTestDataPath(), "go-resty", "my-image.png")

		c.SetResponseSaveToFile(false)
		assertFalse(t, c.IsResponseSaveToFile())

		req := c.R().SetResponseSaveToFile(true)
		resp, err := req.Get(ts.URL + "/my-image.png")

		assertError(t, err)
		assertEqual(t, int64(2579468), resp.Size())

		_, err = os.Stat(savedPath)
		assertNil(t, err)
	})

	// URL without any path component.
	t.Run("empty path", func(t *testing.T) {
		req := c.R().SetResponseSaveToFile(true)
		_, err := req.Get(ts.URL)
		assertError(t, err)
	})
}

// TestContextInternal sends a GET carrying the current Unix time as the
// request_no query parameter and checks for 200 OK.
func TestContextInternal(t *testing.T) {
	srv := createGetServer(t)
	defer srv.Close()

	requestNo := strconv.FormatInt(time.Now().Unix(), 10)
	req := dcnl().R().SetQueryParam("request_no", requestNo)

	res, err := req.Get(srv.URL + "/")

	assertError(t, err)
	assertEqual(t, http.StatusOK, res.StatusCode())
}

// TestRequestDoNotParseResponse covers responses fetched with
// SetResponseDoNotParse(true) at client and request level.
func TestRequestDoNotParseResponse(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()

	t.Run("do not parse response 1", func(t *testing.T) {
		cl := dcnl().SetResponseDoNotParse(true)

		res, err := cl.R().
			SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)).
			Get(ts.URL + "/")

		assertError(t, err)

		// The test reads and closes the body itself.
		raw, err := io.ReadAll(res.Body)
		_ = res.Body.Close()

		assertError(t, err)
		assertEqual(t, "TestGet: text response", string(raw))
	})

	t.Run("manual reset raw response - do not parse response 2", func(t *testing.T) {
		resp, err := dcnl().R().
			SetResponseDoNotParse(true).
Get(ts.URL + "/") assertError(t, err) resp.RawResponse = nil assertEqual(t, 0, resp.StatusCode()) assertEqual(t, "", resp.String()) }) } func TestRequestDoNotParseResponseDebugLog(t *testing.T) { ts := createGetServer(t) defer ts.Close() t.Run("do not parse response debug log client level", func(t *testing.T) { c := dcnl(). SetResponseDoNotParse(true). SetDebug(true) var lgr bytes.Buffer c.outputLogTo(&lgr) _, err := c.R(). SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)). Get(ts.URL + "/") assertError(t, err) assertTrue(t, strings.Contains(lgr.String(), "***** DO NOT PARSE RESPONSE - Enabled *****")) }) t.Run("do not parse response debug log request level", func(t *testing.T) { c := dcnl() var lgr bytes.Buffer c.outputLogTo(&lgr) _, err := c.R(). SetDebug(true). SetResponseDoNotParse(true). SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)). Get(ts.URL + "/") assertError(t, err) assertTrue(t, strings.Contains(lgr.String(), "***** DO NOT PARSE RESPONSE - Enabled *****")) }) } type noCtTest struct { Response string `json:"response"` } func TestRequestExpectContentTypeTest(t *testing.T) { ts := createGenericServer(t) defer ts.Close() c := dcnl() resp, err := c.R(). SetResult(noCtTest{}). SetResponseExpectContentType("application/json"). Get(ts.URL + "/json-no-set") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertNotNil(t, resp.Result()) assertEqual(t, "json response no content type set", resp.Result().(*noCtTest).Response) assertEqual(t, "", firstNonEmpty("", "")) } func TestGetPathParamAndPathParams(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl(). SetBaseURL(ts.URL). SetPathParam("userId", "sample@sample.com") assertEqual(t, "sample@sample.com", c.PathParams()["userId"]) resp, err := c.R().SetPathParam("subAccountId", "100002"). 
Get("/v1/users/{userId}/{subAccountId}/details")

	assertError(t, err)
	assertEqual(t, http.StatusOK, resp.StatusCode())
	assertTrue(t, strings.Contains(resp.String(), "TestGetPathParams: text response"))
	assertTrue(t, strings.Contains(resp.String(), "/v1/users/sample@sample.com/100002/details"))

	logResponse(t, resp)
}

// TestReportMethodSupportsPayload sends a body with the REPORT method through
// Execute and checks for 200 OK.
func TestReportMethodSupportsPayload(t *testing.T) {
	srv := createGenericServer(t)
	defer srv.Close()

	req := dcnl().R().SetBody("body")

	res, err := req.Execute("REPORT", srv.URL+"/report")

	assertError(t, err)
	assertEqual(t, http.StatusOK, res.StatusCode())
}

// TestRequestQueryStringOrder sends a request that has query parameters both
// in SetQueryString and in the URL itself, and checks the response.
func TestRequestQueryStringOrder(t *testing.T) {
	srv := createGetServer(t)
	defer srv.Close()

	req := New().R().
		SetQueryString("productId=232&template=fresh-sample&cat=resty&source=google&kw=buy a lot more")

	res, err := req.Get(srv.URL + "/?UniqueId=ead1d0ed-XXX-XXX-XXX-abb7612b3146&Translate=false&tempauth=eyJ0eXAiOiJKV1QiLC...HZEhwVnJ1d0NSUGVLaUpSaVNLRG5scz0&ApiVersion=2.0")

	assertError(t, err)
	assertEqual(t, http.StatusOK, res.StatusCode())
	assertEqual(t, "200 OK", res.Status())
	assertEqual(t, "TestGet: text response", res.String())

	logResponse(t, res)
}

// TestRequestOverridesClientAuthorizationHeader sets an Authorization header
// on the client and a different one on the request.
func TestRequestOverridesClientAuthorizationHeader(t *testing.T) {
	ts := createAuthServer(t)
	defer ts.Close()

	c := dcnl()
	c.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}).
		SetHeader("Authorization", "some token").
		SetBaseURL(ts.URL + "/")

	resp, err := c.R().
		SetHeader("Authorization", "Bearer 004DDB79-6801-4587-B976-F093E6AC44FF").
Get("/profile") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) } func TestRequestFileUploadAsReader(t *testing.T) { ts := createFileUploadServer(t) defer ts.Close() file, _ := os.Open(filepath.Join(getTestDataPath(), "test-img.png")) defer file.Close() fi, _ := file.Stat() c := dcnl() c.SetRequestMiddlewares( MiddlewareRequestCreate, func(c *Client, r *Request) error { // validate content length values assertTrue(t, r.isContentLengthSet) assertTrue(t, r.contentLength == fi.Size()) assertTrue(t, r.RawRequest.ContentLength == fi.Size()) assertEqual(t, r.contentLength, r.RawRequest.ContentLength) return nil }, ) resp, err := c.R(). SetBody(file). SetContentType("image/png"). SetContentLength(fi.Size()). Post(ts.URL + "/upload") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertTrue(t, strings.Contains(resp.String(), "File Uploaded successfully")) file, _ = os.Open(filepath.Join(getTestDataPath(), "test-img.png")) defer file.Close() resp, err = dcnldr(). SetBody(file). SetContentType("image/png"). Post(ts.URL + "/upload") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertTrue(t, strings.Contains(resp.String(), "File Uploaded successfully")) } func TestHostHeaderOverride(t *testing.T) { ts := createGetServer(t) defer ts.Close() resp, err := dcnl().R(). SetHeader("Host", "myhostname"). Get(ts.URL + "/host-header") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, "myhostname", resp.String()) logResponse(t, resp) } type HTTPErrorResponse struct { Error string `json:"error,omitempty"` } func TestNotFoundWithError(t *testing.T) { var httpError HTTPErrorResponse ts := createGetServer(t) defer ts.Close() resp, err := dcnl().R(). SetHeader(hdrContentTypeKey, "application/json"). SetResultError(&httpError). 
Get(ts.URL + "/not-found-with-error") assertError(t, err) assertEqual(t, http.StatusNotFound, resp.StatusCode()) assertEqual(t, "404 Not Found", resp.Status()) assertNotNil(t, httpError) assertEqual(t, "Not found", httpError.Error) logResponse(t, resp) } func TestNotFoundWithoutError(t *testing.T) { var httpError HTTPErrorResponse ts := createGetServer(t) defer ts.Close() c := dcnl().outputLogTo(os.Stdout) resp, err := c.R(). SetResultError(&httpError). SetHeader(hdrContentTypeKey, "application/json"). Get(ts.URL + "/not-found-no-error") assertError(t, err) assertEqual(t, http.StatusNotFound, resp.StatusCode()) assertEqual(t, "404 Not Found", resp.Status()) assertNotNil(t, httpError) assertEqual(t, "", httpError.Error) logResponse(t, resp) } func TestPathParamURLInput(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl(). SetBaseURL(ts.URL). SetPathParams(map[string]string{ "userId": "sample@sample.com", "path": "users/developers", }) resp, err := c.R(). SetDebug(true). SetPathParams(map[string]string{ "subAccountId": "100002", "website": "https://example.com", }).Get("/v1/users/{userId}/{subAccountId}/{path}/{website}") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertTrue(t, strings.Contains(resp.String(), "TestPathParamURLInput: text response")) assertTrue(t, strings.Contains(resp.String(), "/v1/users/sample@sample.com/100002/users%2Fdevelopers/https:%2F%2Fexample.com")) logResponse(t, resp) } func TestRawPathParamURLInput(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl(). SetBaseURL(ts.URL). SetPathRawParams(map[string]string{ "userId": "sample@sample.com", "path": "users/developers", }) assertEqual(t, "sample@sample.com", c.PathParams()["userId"]) assertEqual(t, "users/developers", c.PathParams()["path"]) resp, err := c.R().SetDebug(true). 
// TestTraceInfo exercises request tracing against a local test server with
// tracing enabled on the client, on the request, for an invalid URL scheme,
// and together with debug logging (default and JSON formatter).
//
// This test case is kind of pass always
func TestTraceInfo(t *testing.T) {
	ts := createGetServer(t)
	defer ts.Close()

	// Everything after the last "/" of the server URL, compared below with
	// TraceInfo.RemoteAddr.
	serverAddr := ts.URL[strings.LastIndex(ts.URL, "/")+1:]

	// Shared by the first three subtests; the first one toggles trace on the
	// client and switches it off again before returning.
	client := dcnl()

	t.Run("enable trace on client", func(t *testing.T) {
		client.SetBaseURL(ts.URL).SetTrace(true)
		for _, u := range []string{"/", "/json", "/long-text", "/long-json"} {
			resp, err := client.R().Get(u)
			assertNil(t, err)
			assertNotNil(t, resp)

			tr := resp.Request.TraceInfo()
			assertTrue(t, tr.DNSLookup >= 0)
			assertTrue(t, tr.ConnTime >= 0)
			assertTrue(t, tr.TLSHandshake >= 0)
			assertTrue(t, tr.ServerTime >= 0)
			assertTrue(t, tr.ResponseTime >= 0)
			assertTrue(t, tr.TotalTime >= 0)
			assertTrue(t, tr.TotalTime < time.Hour)
			assertTrue(t, tr.TotalTime == resp.Duration())
			assertEqual(t, tr.RemoteAddr, serverAddr)

			assertNotNil(t, tr.Clone())
		}

		client.SetTrace(false)
	})

	t.Run("enable trace on request", func(t *testing.T) {
		// Client-level trace was switched off above; it is enabled per request here.
		for _, u := range []string{"/", "/json", "/long-text", "/long-json"} {
			resp, err := client.R().SetTrace(true).Get(u)
			assertNil(t, err)
			assertNotNil(t, resp)

			tr := resp.Request.TraceInfo()
			assertTrue(t, tr.DNSLookup >= 0)
			assertTrue(t, tr.ConnTime >= 0)
			assertTrue(t, tr.TLSHandshake >= 0)
			assertTrue(t, tr.ServerTime >= 0)
			assertTrue(t, tr.ResponseTime >= 0)
			assertTrue(t, tr.TotalTime >= 0)
			assertTrue(t, tr.TotalTime == resp.Duration())
			assertEqual(t, tr.RemoteAddr, serverAddr)
		}
	})

	t.Run("enable trace on invalid request, issue #1016", func(t *testing.T) {
		// The request is expected to fail; only TotalTime should be non-zero.
		resp, err := client.R().SetTrace(true).Get("unknown://url.com")
		assertNotNil(t, err)

		tr := resp.Request.TraceInfo()
		assertTrue(t, tr.DNSLookup == 0)
		assertTrue(t, tr.ConnTime == 0)
		assertTrue(t, tr.TLSHandshake == 0)
		assertTrue(t, tr.ServerTime == 0)
		assertTrue(t, tr.ResponseTime == 0)
		assertTrue(t, tr.TotalTime > 0 && tr.TotalTime < time.Second)
	})

	t.Run("enable trace and debug on request", func(t *testing.T) {
		c, logBuf := dcldb()
		c.SetBaseURL(ts.URL)

		requestURLs := []string{"/", "/json", "/long-text", "/long-json"}
		for _, u := range requestURLs {
			resp, err := c.R().SetTrace(true).SetDebug(true).Get(u)
			assertNil(t, err)
			assertNotNil(t, resp)
			jsonStr := resp.Request.TraceInfo().JSON()
			assertTrue(t, strings.Contains(jsonStr, serverAddr))
		}

		// Expect one "TRACE INFO:" header in the debug log per request made.
		logContent := logBuf.String()
		regexTraceInfoHeader := regexp.MustCompile("TRACE INFO:")
		matches := regexTraceInfoHeader.FindAllStringIndex(logContent, -1)
		assertEqual(t, len(requestURLs), len(matches))
	})

	t.Run("enable trace and debug on request json formatter", func(t *testing.T) {
		c, logBuf := dcldb()
		c.SetBaseURL(ts.URL)
		c.SetDebugLogFormatter(DebugLogJSONFormatter)

		requestURLs := []string{"/", "/json", "/long-text", "/long-json"}
		for _, u := range requestURLs {
			resp, err := c.R().SetTrace(true).SetDebug(true).Get(u)
			assertNil(t, err)
			assertNotNil(t, resp)
		}

		// Expect one "trace_info" object in the JSON debug log per request made.
		logContent := logBuf.String()
		regexTraceInfoHeader := regexp.MustCompile(`"trace_info":{"`)
		matches := regexTraceInfoHeader.FindAllStringIndex(logContent, -1)
		assertEqual(t, len(requestURLs), len(matches))
	})

	// for sake of hook funcs
	// NOTE(review): this reaches an external host and both results are
	// discarded, so the outcome never affects the test result.
	_, _ = client.R().SetTrace(true).Get("https://httpbin.org/get")
}
tr.ResponseTime == 0) assertTrue(t, tr.TotalTime == 0) } } func TestTraceInfoOnTimeout(t *testing.T) { client := NewWithTransportSettings(&TransportSettings{ DialerTimeout: 100 * time.Millisecond, }). SetBaseURL("http://resty-nowhere.local"). SetTrace(true) resp, err := client.R().Get("/") assertNotNil(t, err) assertNotNil(t, resp) tr := resp.Request.TraceInfo() assertTrue(t, tr.DNSLookup >= 0) assertTrue(t, tr.ConnTime == 0) assertTrue(t, tr.TLSHandshake == 0) assertTrue(t, tr.TCPConnTime == 0) assertTrue(t, tr.ServerTime == 0) assertTrue(t, tr.ResponseTime == 0) assertTrue(t, tr.TotalTime > 0) assertTrue(t, tr.TotalTime == resp.Duration()) } func TestTraceInfoOnTimeoutWithSetTimeout(t *testing.T) { t.Run("timeout with very short timeout", func(t *testing.T) { client := New(). SetTimeout(1 * time.Millisecond). SetBaseURL("http://resty-nowhere.local"). SetTrace(true) resp, err := client.R().Get("/") assertNotNil(t, err) assertNotNil(t, resp) tr := resp.Request.TraceInfo() assertTrue(t, tr.DNSLookup == 0) assertTrue(t, tr.ConnTime == 0) assertTrue(t, tr.TLSHandshake == 0) assertTrue(t, tr.TCPConnTime == 0) assertTrue(t, tr.ServerTime == 0) assertTrue(t, tr.ResponseTime == 0) assertTrue(t, tr.TotalTime > 0) assertTrue(t, tr.TotalTime == resp.Duration()) }) t.Run("successful request with SetTimeout", func(t *testing.T) { ts := createGetServer(t) defer ts.Close() client := New(). SetTimeout(5 * time.Second). SetBaseURL(ts.URL). 
SetTrace(true) resp, err := client.R().Get("/") assertNil(t, err) assertNotNil(t, resp) tr := resp.Request.TraceInfo() assertTrue(t, tr.DNSLookup >= 0) assertTrue(t, tr.ConnTime >= 0) assertTrue(t, tr.TLSHandshake >= 0) assertTrue(t, tr.TCPConnTime >= 0) assertTrue(t, tr.ServerTime >= 0) assertTrue(t, tr.ResponseTime >= 0) assertTrue(t, tr.TotalTime > 0) assertTrue(t, tr.TotalTime == resp.Duration()) }) t.Run("HTTPS request with TLS handshake", func(t *testing.T) { ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("OK")) })) defer ts.Close() client := New(). SetTimeout(5 * time.Second). SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}). SetTrace(true) resp, err := client.R().Get(ts.URL) assertNil(t, err) assertNotNil(t, resp) tr := resp.Request.TraceInfo() assertTrue(t, tr.TLSHandshake > 0) assertTrue(t, tr.DNSLookup >= 0) assertTrue(t, tr.ConnTime >= 0) assertTrue(t, tr.TCPConnTime >= 0) assertTrue(t, tr.ServerTime >= 0) assertTrue(t, tr.ResponseTime >= 0) assertTrue(t, tr.TotalTime > 0) assertTrue(t, tr.TotalTime == resp.Duration()) }) } func TestDebugLoggerRequestBodyTooLarge(t *testing.T) { formTs := createFormPostServer(t) defer formTs.Close() debugBodySizeLimit := 512 t.Run("post form with more than 512 bytes data", func(t *testing.T) { output := bytes.NewBufferString("") resp, err := New().SetDebug(true).outputLogTo(output).SetDebugBodyLimit(debugBodySizeLimit).R(). SetFormData(map[string]string{ "first_name": "Alex", "last_name": strings.Repeat("C", int(debugBodySizeLimit)), "zip_code": "00001", }). SetBasicAuth("myuser", "mypass"). 
Post(formTs.URL + "/profile") assertNil(t, err) assertNotNil(t, resp) assertTrue(t, strings.Contains(output.String(), "REQUEST TOO LARGE")) }) t.Run("post form with no more than 512 bytes data", func(t *testing.T) { output := bytes.NewBufferString("") resp, err := New().outputLogTo(output).SetDebugBodyLimit(debugBodySizeLimit).R(). SetDebug(true). SetFormData(map[string]string{ "first_name": "Alex", "last_name": "C", "zip_code": "00001", }). SetBasicAuth("myuser", "mypass"). Post(formTs.URL + "/profile") assertNil(t, err) assertNotNil(t, resp) assertTrue(t, strings.Contains(output.String(), "Alex")) }) t.Run("post string with more than 512 bytes data", func(t *testing.T) { output := bytes.NewBufferString("") resp, err := New().SetDebug(true).outputLogTo(output).SetDebugBodyLimit(debugBodySizeLimit).R(). SetBody(`{ "first_name": "Alex", "last_name": "`+strings.Repeat("C", int(debugBodySizeLimit))+`C", "zip_code": "00001"}`). SetBasicAuth("myuser", "mypass"). Post(formTs.URL + "/profile") assertNil(t, err) assertNotNil(t, resp) assertTrue(t, strings.Contains(output.String(), "REQUEST TOO LARGE")) }) t.Run("post string slice with more than 512 bytes data", func(t *testing.T) { output := bytes.NewBufferString("") resp, err := New().outputLogTo(output).SetDebugBodyLimit(debugBodySizeLimit).R(). SetDebug(true). SetBody([]string{strings.Repeat("hello", debugBodySizeLimit)}). SetBasicAuth("myuser", "mypass"). Post(formTs.URL + "/profile") assertNil(t, err) assertNotNil(t, resp) assertTrue(t, strings.Contains(output.String(), "REQUEST TOO LARGE")) }) } func TestPostMapTemporaryRedirect(t *testing.T) { ts := createPostServer(t) defer ts.Close() c := dcnl() resp, err := c.R().SetBody(map[string]string{"username": "testuser", "password": "testpass"}). 
// brokenReadCloser is a test double whose Read always fails and whose Close
// always succeeds.
type brokenReadCloser struct{}

// Read reports zero bytes read together with a fresh "read error" error.
func (brokenReadCloser) Read([]byte) (int, error) {
	return 0, errors.New("read error")
}

// Close does nothing and never fails.
func (brokenReadCloser) Close() error {
	return nil
}
clone.bodyBuf.WriteString("clone") // assert non-interface type assertEqual(t, "http://localhost.clone", clone.URL) assertEqual(t, ts.URL, parent.URL) assertEqual(t, "clone", clone.PathParams["name"]) assertEqual(t, "parent", parent.PathParams["name"]) // assert http header assertEqual(t, "parent", parent.Header.Get("X-Header")) assertEqual(t, "clone", clone.Header.Get("X-Header")) // assert interface type assertEqual(t, "parent", parent.credentials.Username) assertEqual(t, "clone", clone.credentials.Username) assertEqual(t, "", parent.bodyBuf.String()) assertEqual(t, "clone", clone.bodyBuf.String()) // parent request should have raw request while clone should not assertNil(t, clone.RawRequest) assertNotNil(t, parent.RawRequest) assertNotEqual(t, parent.RawRequest, clone.RawRequest) } func TestResponseBodyUnlimitedReads(t *testing.T) { ts := createPostServer(t) defer ts.Close() user := &credentials{Username: "testuser", Password: "testpass"} c := dcnl(). SetJSONEscapeHTML(false). SetResponseBodyUnlimitedReads(true) assertTrue(t, c.ResponseBodyUnlimitedReads()) resp, err := c.R(). SetHeader(hdrContentTypeKey, "application/json; charset=utf-8"). SetBody(user). SetResult(&AuthSuccess{}). Post(ts.URL + "/login") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, int64(50), resp.Size()) t.Logf("Result Success: %q", resp.Result().(*AuthSuccess)) for i := 1; i <= 5; i++ { b, err := io.ReadAll(resp.Body) assertNil(t, err) assertEqual(t, `{ "id": "success", "message": "login successful" }`, string(b)) } logResponse(t, resp) } func TestRequestAllowPayload(t *testing.T) { c := dcnl() t.Run("default method is GET", func(t *testing.T) { r := c.R() result1 := r.isPayloadSupported() assertFalse(t, result1) r.SetMethodGetAllowPayload(true) result2 := r.isPayloadSupported() assertTrue(t, result2) }) t.Run("method GET", func(t *testing.T) { r := c.R(). 
SetMethod(MethodGet) result1 := r.isPayloadSupported() assertFalse(t, result1) r.SetMethodGetAllowPayload(true) result2 := r.isPayloadSupported() assertTrue(t, result2) }) t.Run("method POST", func(t *testing.T) { r := c.R(). SetMethod(MethodPost) result1 := r.isPayloadSupported() assertTrue(t, result1) }) t.Run("method PUT", func(t *testing.T) { r := c.R(). SetMethod(MethodPut) result1 := r.isPayloadSupported() assertTrue(t, result1) }) t.Run("method PATCH", func(t *testing.T) { r := c.R(). SetMethod(MethodPatch) result1 := r.isPayloadSupported() assertTrue(t, result1) }) t.Run("method DELETE", func(t *testing.T) { r := c.R(). SetMethod(MethodDelete) result1 := r.isPayloadSupported() assertFalse(t, result1) r.SetMethodDeleteAllowPayload(true) result2 := r.isPayloadSupported() assertTrue(t, result2) }) t.Run("method HEAD", func(t *testing.T) { r := c.R(). SetMethod(MethodHead) result1 := r.isPayloadSupported() assertFalse(t, result1) }) t.Run("method OPTIONS", func(t *testing.T) { r := c.R(). SetMethod(MethodOptions) result1 := r.isPayloadSupported() assertFalse(t, result1) }) t.Run("method TRACE", func(t *testing.T) { r := c.R(). SetMethod(MethodTrace) result1 := r.isPayloadSupported() assertFalse(t, result1) }) } func TestRequestNoRetryOnNonIdempotentMethod(t *testing.T) { ts := createFileUploadServer(t) defer ts.Close() str := "test" buf := []byte(str) bufReader := bytes.NewReader(buf) bufCpy := make([]byte, len(buf)) c := dcnl(). SetTimeout(time.Second * 3). AddRetryHooks( func(response *Response, _ error) { read, err := bufReader.Read(bufCpy) assertNil(t, err) assertEqual(t, len(buf), read) assertEqual(t, str, string(bufCpy)) }, ) req := c.R(). SetRetryCount(3). 
SetFileReader("name", "filename", bufReader) resp, err := req.Post(ts.URL + "/set-reset-multipart-readers-test") assertNil(t, err) assertEqual(t, 1, resp.Request.Attempt) assertEqual(t, 500, resp.StatusCode()) } func TestRequestContextTimeout(t *testing.T) { ts := createGetServer(t) defer ts.Close() t.Run("use client set timeout", func(t *testing.T) { c := dcnl().SetTimeout(200 * time.Millisecond) assertTrue(t, c.Timeout() > 0) req := c.R() assertTrue(t, req.Timeout > 0) _, err := req.Get(ts.URL + "/set-timeout-test") assertTrue(t, errors.Is(err, context.DeadlineExceeded)) }) t.Run("use request set timeout", func(t *testing.T) { c := dcnl() assertTrue(t, c.Timeout() == 0) _, err := c.R(). SetTimeout(200 * time.Millisecond). Get(ts.URL + "/set-timeout-test") assertTrue(t, errors.Is(err, context.DeadlineExceeded)) }) t.Run("use external context for timeout", func(t *testing.T) { ctx, ctxCancelFunc := context.WithTimeout(context.Background(), 200*time.Millisecond) defer ctxCancelFunc() c := dcnl() _, err := c.R(). SetContext(ctx). Get(ts.URL + "/set-timeout-test") assertTrue(t, errors.Is(err, context.DeadlineExceeded)) }) } func TestRequestPanicContext(t *testing.T) { defer func() { if r := recover(); r == nil { t.Errorf("The code did not panic") } }() c := dcnl() //lint:ignore SA1012 test case nil check _ = c.R().WithContext(nil) } func TestRequestSetResultAndSetOutputFile(t *testing.T) { ts := createPostServer(t) defer ts.Close() outputFile := filepath.Join(getTestDataPath(), "login-success.txt") defer cleanupFiles(outputFile) c := dcnl().SetBaseURL(ts.URL) res, err := c.R(). SetHeader(hdrContentTypeKey, "application/json; charset=utf-8"). SetBody(&credentials{Username: "testuser", Password: "testpass"}). SetResponseBodyUnlimitedReads(true). SetResult(&AuthSuccess{}). SetResponseSaveFileName(outputFile). 
Post("/login") assertError(t, err) assertEqual(t, http.StatusOK, res.StatusCode()) assertEqual(t, int64(50), res.Size()) loginResult := res.Result().(*AuthSuccess) assertEqual(t, "success", loginResult.ID) assertEqual(t, "login successful", loginResult.Message) fileContent, _ := os.ReadFile(outputFile) assertEqual(t, `{ "id": "success", "message": "login successful" }`, string(fileContent)) } func TestRequestBodyContentLengthValidation(t *testing.T) { ts := createGenericServer(t) defer ts.Close() c := dcnl().SetBaseURL(ts.URL) c.SetRequestMiddlewares( MiddlewareRequestCreate, func(c *Client, r *Request) error { // validate content length assertTrue(t, r.contentLength > 0) assertTrue(t, r.RawRequest.ContentLength > 0) assertEqual(t, r.contentLength, r.RawRequest.ContentLength) return nil }, ) buf := bytes.NewBuffer([]byte(`{"content":"json content sending to server"}`)) res, err := c.R(). SetHeader(hdrContentTypeKey, "application/json"). SetBody(buf). Put("/json") assertError(t, err) assertEqual(t, http.StatusOK, res.StatusCode()) assertEqual(t, `{"response":"json response"}`, res.String()) assertEqual(t, int64(44), res.Request.RawRequest.ContentLength) } func TestRequestFuncs(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl(). SetQueryParam("client_param", "true"). SetQueryParams(map[string]string{"req_1": "value1", "req_3": "value3"}). SetDebug(true) addRequestQueryParams := func(page, size int) func(r *Request) *Request { return func(r *Request) *Request { return r.SetQueryParam("page", strconv.Itoa(page)). SetQueryParam("size", strconv.Itoa(size)). SetQueryParam("request_no", strconv.Itoa(int(time.Now().Unix()))) } } addRequestHeaders := func(r *Request) *Request { return r.SetHeader(hdrAcceptKey, "application/json"). SetHeader(hdrUserAgentKey, "my-client/v1.0") } resp, err := c.R(). Funcs(addRequestQueryParams(1, 100), addRequestHeaders). SetHeader(hdrUserAgentKey, "Test Custom User agent"). 
Get(ts.URL + "/") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "HTTP/1.1", resp.Proto()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, "TestGet: text response", resp.String()) } func TestHTTPWarnGH970(t *testing.T) { lookupText := "Using sensitive credentials in HTTP mode is not secure. Use HTTPS" t.Run("SSL used", func(t *testing.T) { ts := createAuthServerTLSOptional(t, true) defer ts.Close() c, lb := dcldb() c.SetBaseURL(ts.URL). SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) res, err := c.R(). SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF"). Get("/profile") assertNil(t, err) assertTrue(t, strings.Contains(res.String(), "profile fetch successful")) assertFalse(t, strings.Contains(lb.String(), lookupText)) }) t.Run("non-SSL used", func(t *testing.T) { ts := createAuthServerTLSOptional(t, false) defer ts.Close() c, lb := dcldb() c.SetBaseURL(ts.URL) res, err := c.R(). SetAuthToken("004DDB79-6801-4587-B976-F093E6AC44FF"). Get("/profile") assertNil(t, err) assertTrue(t, strings.Contains(res.String(), "profile fetch successful")) assertTrue(t, strings.Contains(lb.String(), lookupText)) }) } // This test methods exist for test coverage purpose // to validate the getter and setter func TestRequestSettingsCoverage(t *testing.T) { c := dcnl() r1 := c.R() assertFalse(t, r1.IsCloseConnection) r1.SetCloseConnection(true) assertTrue(t, r1.IsCloseConnection) r2 := c.R() assertFalse(t, r2.IsTrace) r2.SetTrace(true) assertTrue(t, r2.IsTrace) r2.SetTrace(false) assertFalse(t, r2.IsTrace) r3 := c.R() assertFalse(t, r3.IsResponseBodyUnlimitedReads) r3.SetResponseBodyUnlimitedReads(true) assertTrue(t, r3.IsResponseBodyUnlimitedReads) r3.SetResponseBodyUnlimitedReads(false) assertFalse(t, r3.IsResponseBodyUnlimitedReads) r4 := c.R() assertFalse(t, r4.IsDebug) r4.SetDebug(true) assertTrue(t, r4.IsDebug) r4.SetDebug(false) assertFalse(t, r4.IsDebug) r5 := c.R() assertTrue(t, r5.IsRetryDefaultConditions) 
r5.SetRetryDefaultConditions(false) assertFalse(t, r5.IsRetryDefaultConditions) r5.SetRetryDefaultConditions(true) assertTrue(t, r5.IsRetryDefaultConditions) r6 := c.R() customAuthHeader := "X-Custom-Authorization" r6.SetHeaderAuthorizationKey(customAuthHeader) assertEqual(t, customAuthHeader, r6.HeaderAuthorizationKey) invalidJsonBytes := []byte(`{\" \": "value here"}`) result := jsonIndent(invalidJsonBytes) assertEqual(t, string(invalidJsonBytes), string(result)) res := &Response{} assertNil(t, res.RedirectHistory()) defer func() { if rec := recover(); rec != nil { if err, ok := rec.(error); ok { assertTrue(t, strings.Contains(err.Error(), "resty: Request.Clone nil context")) } } }() rc := c.R() //lint:ignore SA1012 test case nil check rc2 := rc.Clone(nil) assertEqual(t, nil, rc2.ctx) } func TestRequestDataRace(t *testing.T) { ts := createPostServer(t) defer ts.Close() usersmap := map[string]any{ "user1": ExampleUser{FirstName: "firstname1", LastName: "lastname1", ZipCode: "10001"}, "user2": &ExampleUser{FirstName: "firstname2", LastName: "lastname3", ZipCode: "10002"}, "user3": ExampleUser{FirstName: "firstname3", LastName: "lastname3", ZipCode: "10003"}, } var users []map[string]any users = append(users, usersmap) c := dcnl().SetBaseURL(ts.URL) totalRequests := 4000 wg := sync.WaitGroup{} wg.Add(totalRequests) for i := 0; i < totalRequests; i++ { if i%100 == 0 { time.Sleep(20 * time.Millisecond) // to prevent test server socket exhaustion } go func() { defer wg.Done() res, err := c.R().SetContext(context.Background()).SetBody(users).Post("/usersmap") assertError(t, err) assertEqual(t, http.StatusAccepted, res.StatusCode()) }() } wg.Wait() } ================================================ FILE: response.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. 
// SPDX-License-Identifier: MIT package resty import ( "bytes" "encoding/json" "fmt" "io" "net/http" "strings" "time" ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Response struct and methods //_______________________________________________________________________ // Response struct holds response values of executed requests. type Response struct { Request *Request Body io.ReadCloser RawResponse *http.Response IsRead bool // CascadeError field used to cascade the response processing and // middleware execution errors CascadeError error bodyBytes []byte size int64 receivedAt time.Time } // Status method returns the HTTP status string for the executed request. // // Example: 200 OK func (r *Response) Status() string { if r.RawResponse == nil { return "" } return r.RawResponse.Status } // StatusCode method returns the HTTP status code for the executed request. // // Example: 200 func (r *Response) StatusCode() int { if r.RawResponse == nil { return 0 } return r.RawResponse.StatusCode } // Proto method returns the HTTP response protocol used for the request. func (r *Response) Proto() string { if r.RawResponse == nil { return "" } return r.RawResponse.Proto } // Result method returns the unmarshalled result response object if it exists, // otherwise nil. // // client := resty.New() // defer client.Close() // // res, err := client.R(). // SetBody(User{ // Username: "testuser", // Password: "testpass", // }). // SetResult(&LoginResponse{}). // or SetResult(LoginResponse{}). // SetResultError(&LoginErrorResponse{}). // or SetResultError(LoginErrorResponse{}). // Post("https://myapp.com/login") // // fmt.Println(err, res) // fmt.Println(res.Result().(*LoginResponse)) // fmt.Println(res.ResultError().(*LoginErrorResponse)) // // See [Request.SetResult] func (r *Response) Result() any { return r.Request.Result } // ResultError method returns the unmarshalled result error object if it exists, // otherwise nil. 
// // client := resty.New() // defer client.Close() // // res, err := client.R(). // SetBody(User{ // Username: "testuser", // Password: "testpass", // }). // SetResult(&LoginResponse{}). // or SetResult(LoginResponse{}). // SetResultError(&LoginErrorResponse{}). // or SetResultError(LoginErrorResponse{}). // Post("https://myapp.com/login") // // fmt.Println(err, res) // fmt.Println(res.Result().(*LoginResponse)) // fmt.Println(res.ResultError().(*LoginErrorResponse)) // // See [Request.SetResultError], [Client.SetResultError] func (r *Response) ResultError() any { return r.Request.ResultError } // Header method returns the response headers func (r *Response) Header() http.Header { if r.RawResponse == nil { return http.Header{} } return r.RawResponse.Header } // Cookies method to returns all the response cookies func (r *Response) Cookies() []*http.Cookie { if r.RawResponse == nil { return make([]*http.Cookie, 0) } return r.RawResponse.Cookies() } // String method returns the body of the HTTP response as a `string`. // It returns an empty string if it is nil or the body is zero length. // // NOTE: // - Returns an empty string on auto-unmarshal scenarios, unless // [Client.SetResponseBodyUnlimitedReads] or [Request.SetResponseBodyUnlimitedReads] set. // - Returns an empty string when [Client.SetResponseDoNotParse] or [Request.SetResponseDoNotParse] set. func (r *Response) String() string { r.readIfRequired() return strings.TrimSpace(string(r.bodyBytes)) } // Bytes method returns the body of the HTTP response as a byte slice. // It returns an empty byte slice if it is nil or the body is zero length. // // NOTE: // - Returns an empty byte slice on auto-unmarshal scenarios, unless // [Client.SetResponseBodyUnlimitedReads] or [Request.SetResponseBodyUnlimitedReads] set. // - Returns an empty byte slice when [Client.SetResponseDoNotParse] or [Request.SetResponseDoNotParse] set. 
func (r *Response) Bytes() []byte { r.readIfRequired() return r.bodyBytes } // Duration method returns the duration of HTTP response time from the request we sent // and received a request. // // See [Response.ReceivedAt] to know when the client received a response and see // `Response.Request.Time` to know when the client sent a request. func (r *Response) Duration() time.Duration { if r.Request.trace != nil { return r.Request.TraceInfo().TotalTime } return r.receivedAt.Sub(r.Request.StartTime) } // ReceivedAt method returns the time we received a response from the server for the request. func (r *Response) ReceivedAt() time.Time { return r.receivedAt } // Size method returns the HTTP response size in bytes. Yeah, you can rely on HTTP `Content-Length` // header, however it won't be available for chucked transfer/compressed response. // Since Resty captures response size details when processing the response body // when possible. So that users get the actual size of response bytes. func (r *Response) Size() int64 { r.readIfRequired() return r.size } // IsStatusSuccess method returns true if HTTP status `code >= 200 and <= 299` otherwise false. // // Example: 200, 201, 204, etc. func (r *Response) IsStatusSuccess() bool { return r.StatusCode() > 199 && r.StatusCode() < 300 } // IsStatusFailure method returns true if HTTP status `code >= 400` otherwise false. // // Example: 400, 500, etc. 
func (r *Response) IsStatusFailure() bool { return r.StatusCode() > 399 } // RedirectHistory method returns a redirect history slice with the URL and status code func (r *Response) RedirectHistory() []*RedirectInfo { if r.RawResponse == nil { return nil } redirects := make([]*RedirectInfo, 0) res := r.RawResponse for res != nil { req := res.Request redirects = append(redirects, &RedirectInfo{ StatusCode: res.StatusCode, URL: req.URL.String(), }) res = req.Response } return redirects } func (r *Response) setReceivedAt() { r.receivedAt = time.Now() if r.Request.trace != nil { r.Request.trace.endTime = r.receivedAt } } func (r *Response) fmtBodyString(sl int) string { if r.Request.IsResponseDoNotParse { return "***** DO NOT PARSE RESPONSE - Enabled *****" } if r.Request.IsResponseSaveToFile { return "***** RESPONSE WRITTEN INTO FILE *****" } bl := len(r.bodyBytes) if r.IsRead && bl == 0 { return "***** RESPONSE BODY IS ALREADY READ - see Response.{Result()/Error()} *****" } if bl > 0 { if bl > sl { return fmt.Sprintf("***** RESPONSE TOO LARGE (size - %d) *****", bl) } ct := r.Header().Get(hdrContentTypeKey) ctKey := inferContentTypeMapKey(ct) if jsonKey == ctKey { out := acquireBuffer() defer releaseBuffer(out) err := json.Indent(out, r.bodyBytes, "", " ") if err != nil { r.Request.log.Errorf("DebugLog: Response.fmtBodyString: %v", err) return "" } return out.String() } return r.String() } return "***** NO CONTENT *****" } func (r *Response) readIfRequired() { if len(r.bodyBytes) == 0 && !r.Request.IsResponseDoNotParse { _ = r.readAll() } } var ioReadAll = io.ReadAll // auto-unmarshal didn't happen, so fallback to // old behavior of reading response as body bytes func (r *Response) readAll() (err error) { if r.Body == nil || r.IsRead { return nil } if _, ok := r.Body.(*copyReadCloser); ok { _, err = ioReadAll(r.Body) } else { r.bodyBytes, err = ioReadAll(r.Body) closeq(r.Body) r.Body = &nopReadCloser{r: bytes.NewReader(r.bodyBytes), resetOnEOF: true} } if err == 
io.ErrUnexpectedEOF { // content-encoding scenario's - empty/no response body from server err = nil } r.IsRead = true return } func (r *Response) wrapLimitReadCloser() { r.Body = &limitReadCloser{ r: r.Body, l: r.Request.ResponseBodyLimit, f: func(s int64) { r.size = s }, } } func (r *Response) wrapCopyReadCloser() { r.Body = ©ReadCloser{ s: r.Body, t: acquireBuffer(), f: func(b *bytes.Buffer) { r.bodyBytes = append([]byte{}, b.Bytes()...) closeq(r.Body) r.Body = &nopReadCloser{r: bytes.NewReader(r.bodyBytes), resetOnEOF: true} releaseBuffer(b) }, } } func (r *Response) wrapContentDecompresser() error { ce := r.Header().Get(hdrContentEncodingKey) if isStringEmpty(ce) { return nil } if decFunc, f := r.Request.client.ContentDecompressers()[strings.ToLower(ce)]; f { dec, err := decFunc(r.Body) if err != nil { if err == io.EOF { // empty/no response body from server err = nil } return err } r.Body = dec r.Header().Del(hdrContentEncodingKey) r.Header().Del(hdrContentLengthKey) r.RawResponse.ContentLength = -1 } else { return ErrContentDecompresserNotFound } return nil } func (r *Response) wrapError(err error, preserve bool) error { r.CascadeError = wrapErrors(err, r.CascadeError) if preserve { return nil } e := r.CascadeError r.CascadeError = nil return e } ================================================ FILE: resty.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT // Package resty provides Simple HTTP, REST, and SSE client library for Go. package resty // import "resty.dev/v3" import ( "math" "net" "net/http" "net/http/cookiejar" "net/url" "runtime" "sync" "time" "golang.org/x/net/publicsuffix" ) // Version # of resty const Version = "3.0.0-beta.6" // New method creates a new Resty client. 
func New() *Client { return NewWithTransportSettings(nil) } // NewWithTransportSettings method creates a new Resty client with provided // timeout values. func NewWithTransportSettings(transportSettings *TransportSettings) *Client { return NewWithDialerAndTransportSettings(nil, transportSettings) } // NewWithClient method creates a new Resty client with given [http.Client]. func NewWithClient(hc *http.Client) *Client { return createClient(hc) } // NewWithDialer method creates a new Resty client with given Local Address // to dial from. func NewWithDialer(dialer *net.Dialer) *Client { return NewWithDialerAndTransportSettings(dialer, nil) } // NewWithLocalAddr method creates a new Resty client with the given Local Address. func NewWithLocalAddr(localAddr net.Addr) *Client { return NewWithDialerAndTransportSettings( &net.Dialer{LocalAddr: localAddr}, nil, ) } // NewWithDialerAndTransportSettings method creates a new Resty client with given Local Address // to dial from. func NewWithDialerAndTransportSettings(dialer *net.Dialer, transportSettings *TransportSettings) *Client { return createClient(&http.Client{ Jar: createCookieJar(), Transport: createTransport(dialer, transportSettings), }) } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Unexported methods //_______________________________________________________________________ func createTransport(dialer *net.Dialer, transportSettings *TransportSettings) *http.Transport { if transportSettings == nil { transportSettings = &TransportSettings{} } // Dialer if dialer == nil { dialer = &net.Dialer{} } if transportSettings.DialerTimeout > 0 { dialer.Timeout = transportSettings.DialerTimeout } else { dialer.Timeout = 30 * time.Second } if transportSettings.DialerKeepAlive > 0 { dialer.KeepAlive = transportSettings.DialerKeepAlive } else { dialer.KeepAlive = 30 * time.Second } // Transport t := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: transportDialContext(dialer), 
DisableKeepAlives: transportSettings.DisableKeepAlives, DisableCompression: true, // Resty handles it, see [Client.AddContentDecoder] ForceAttemptHTTP2: true, } if transportSettings.IdleConnTimeout > 0 { t.IdleConnTimeout = transportSettings.IdleConnTimeout } else { t.IdleConnTimeout = 90 * time.Second } if transportSettings.TLSHandshakeTimeout > 0 { t.TLSHandshakeTimeout = transportSettings.TLSHandshakeTimeout } else { t.TLSHandshakeTimeout = 10 * time.Second } if transportSettings.ExpectContinueTimeout > 0 { t.ExpectContinueTimeout = transportSettings.ExpectContinueTimeout } else { t.ExpectContinueTimeout = 1 * time.Second } if transportSettings.MaxIdleConns > 0 { t.MaxIdleConns = transportSettings.MaxIdleConns } else { t.MaxIdleConns = 100 } if transportSettings.MaxIdleConnsPerHost > 0 { t.MaxIdleConnsPerHost = transportSettings.MaxIdleConnsPerHost } else { t.MaxIdleConnsPerHost = runtime.GOMAXPROCS(0) + 1 } if transportSettings.MaxConnsPerHost > 0 { t.MaxConnsPerHost = transportSettings.MaxConnsPerHost } // // No default value in Resty for following settings, added to // provide ability to set value otherwise the Go HTTP client // default value applies. 
// if transportSettings.ResponseHeaderTimeout > 0 { t.ResponseHeaderTimeout = transportSettings.ResponseHeaderTimeout } if transportSettings.MaxResponseHeaderBytes > 0 { t.MaxResponseHeaderBytes = transportSettings.MaxResponseHeaderBytes } if transportSettings.WriteBufferSize > 0 { t.WriteBufferSize = transportSettings.WriteBufferSize } if transportSettings.ReadBufferSize > 0 { t.ReadBufferSize = transportSettings.ReadBufferSize } return t } func createCookieJar() *cookiejar.Jar { cookieJar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) return cookieJar } func createClient(hc *http.Client) *Client { c := &Client{ // not setting language default values lock: &sync.RWMutex{}, queryParams: url.Values{}, formData: url.Values{}, header: http.Header{}, authScheme: defaultAuthScheme, cookies: make([]*http.Cookie, 0), retryWaitTime: defaultWaitTime, retryMaxWaitTime: defaultMaxWaitTime, isRetryDefaultConditions: true, pathParams: make(map[string]string), headerAuthorizationKey: hdrAuthorizationKey, jsonEscapeHTML: true, httpClient: hc, debugBodyLimit: math.MaxInt32, contentTypeEncoders: make(map[string]ContentTypeEncoder), contentTypeDecoders: make(map[string]ContentTypeDecoder), contentDecompresserKeys: make([]string, 0), contentDecompressers: make(map[string]ContentDecompresser), certWatcherStopChan: make(chan bool), } // Logger c.SetLogger(createLogger()) c.SetDebugLogFormatter(DebugLogFormatter) c.AddContentTypeEncoder(jsonKey, encodeJSON) c.AddContentTypeEncoder(xmlKey, encodeXML) c.AddContentTypeDecoder(jsonKey, decodeJSON) c.AddContentTypeDecoder(xmlKey, decodeXML) // Order matter, giving priority to gzip c.AddContentDecompresser("deflate", decompressDeflate) c.AddContentDecompresser("gzip", decompressGzip) // request middlewares c.SetRequestMiddlewares( MiddlewareRequestCreate, ) // response middlewares c.SetResponseMiddlewares( MiddlewareResponseAutoParse, MiddlewareResponseSaveToFile, ) return c } 
================================================ FILE: resty_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "bytes" "compress/flate" "compress/gzip" "compress/lzw" "crypto/tls" "encoding/base64" "encoding/hex" "encoding/json" "encoding/xml" "errors" "fmt" "io" "net" "net/http" "net/http/httptest" "net/url" "os" "path/filepath" "reflect" "strconv" "strings" "sync/atomic" "testing" "time" ) var ( hdrLocationKey = http.CanonicalHeaderKey("Location") ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Testing Unexported methods //___________________________________ func getTestDataPath() string { pwd, _ := os.Getwd() return filepath.Join(pwd, ".testdata") } func createGetServer(t *testing.T) *httptest.Server { var attempt int32 var sequence int32 var lastRequest time.Time ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { t.Logf("Method: %v", r.Method) t.Logf("Path: %v", r.URL.Path) if r.Method == MethodGet { switch r.URL.Path { case "/": _, _ = w.Write([]byte("TestGet: text response")) case "/no-content": _, _ = w.Write([]byte("")) case "/json": w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"TestGet": "JSON response"}`)) case "/json-invalid": w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte("TestGet: Invalid JSON")) case "/long-text": _, _ = w.Write([]byte("TestGet: text response with size > 30")) case "/long-json": w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"TestGet": "JSON response with size > 30"}`)) case "/mypage": w.WriteHeader(http.StatusBadRequest) case "/mypage2": _, _ = w.Write([]byte("TestGet: text response from mypage2")) case "/set-retrycount-test": attp := atomic.AddInt32(&attempt, 1) if attp <= 4 { 
time.Sleep(time.Millisecond * 150) } _, _ = w.Write([]byte("TestClientRetry page")) case "/set-retrywaittime-test": // Returns time.Duration since last request here // or 0 for the very first request if atomic.LoadInt32(&attempt) == 0 { lastRequest = time.Now() _, _ = fmt.Fprint(w, "0") } else { now := time.Now() sinceLastRequest := now.Sub(lastRequest) lastRequest = now _, _ = fmt.Fprintf(w, "%d", uint64(sinceLastRequest)) } atomic.AddInt32(&attempt, 1) case "/set-retry-error-recover": w.Header().Set(hdrContentTypeKey, "application/json; charset=utf-8") if atomic.LoadInt32(&attempt) == 0 { w.WriteHeader(http.StatusTooManyRequests) _, _ = w.Write([]byte(`{ "message": "too many" }`)) } else { _, _ = w.Write([]byte(`{ "message": "hello" }`)) } atomic.AddInt32(&attempt, 1) case "/set-timeout-test-with-sequence": seq := atomic.AddInt32(&sequence, 1) time.Sleep(100 * time.Millisecond) _, _ = fmt.Fprintf(w, "%d", seq) case "/set-timeout-test": time.Sleep(400 * time.Millisecond) _, _ = w.Write([]byte("TestClientTimeout page")) case "/my-image.png": fileBytes, _ := os.ReadFile(filepath.Join(getTestDataPath(), "test-img.png")) w.Header().Set("Content-Type", "image/png") w.Header().Set("Content-Length", strconv.Itoa(len(fileBytes))) if r.URL.Query().Get("content-disposition") == "true" { filename := r.URL.Query().Get("filename") w.Header().Set(hdrContentDisposition, "inline; filename=\""+filename+"\"") } _, _ = w.Write(fileBytes) case "/get-method-payload-test": body, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Error: could not read get body: %s", err.Error()) } _, _ = w.Write(body) case "/host-header": _, _ = w.Write([]byte(r.Host)) case "/not-found-with-error": w.Header().Set(hdrContentTypeKey, "application/json") w.WriteHeader(http.StatusNotFound) _, _ = w.Write([]byte(`{"error": "Not found"}`)) case "/not-found-no-error": w.Header().Set(hdrContentTypeKey, "application/json") w.WriteHeader(http.StatusNotFound) case "/retry-after-delay": 
w.Header().Set(hdrContentTypeKey, "application/json; charset=utf-8") if atomic.LoadInt32(&attempt) == 0 { w.Header().Set(hdrRetryAfterKey, "1") w.WriteHeader(http.StatusTooManyRequests) _, _ = w.Write([]byte(`{ "message": "too many" }`)) } else { _, _ = w.Write([]byte(`{ "message": "hello" }`)) } atomic.AddInt32(&attempt, 1) case "/unescape-query-params": initOne := r.URL.Query().Get("initone") fromClient := r.URL.Query().Get("fromclient") registry := r.URL.Query().Get("registry") assertEqual(t, "cáfe", initOne) assertEqual(t, "hey unescape", fromClient) assertEqual(t, "nacos://test:6801", registry) _, _ = w.Write([]byte(`query params looks good`)) } switch { case strings.HasPrefix(r.URL.Path, "/v1/users/sample@sample.com/100002"): if strings.HasSuffix(r.URL.Path, "details") { _, _ = w.Write([]byte("TestGetPathParams: text response: " + r.URL.String())) } else { _, _ = w.Write([]byte("TestPathParamURLInput: text response: " + r.URL.String())) } } } }) return ts } func handleLoginEndpoint(t *testing.T, w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/login" { user := &credentials{} // JSON if isJSONContentType(r.Header.Get(hdrContentTypeKey)) { jd := json.NewDecoder(r.Body) err := jd.Decode(user) if r.URL.Query().Get("ct") == "problem" { w.Header().Set(hdrContentTypeKey, "application/problem+json; charset=utf-8") } else if r.URL.Query().Get("ct") == "rpc" { w.Header().Set(hdrContentTypeKey, "application/json-rpc") } else { w.Header().Set(hdrContentTypeKey, "AppLicAtioN/jsON") } if err != nil { t.Logf("Error: %#v", err) w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte(`{ "id": "bad_request", "message": "Unable to read user info" }`)) return } if user.Username == "testuser" && user.Password == "testpass" { _, _ = w.Write([]byte(`{ "id": "success", "message": "login successful" }`)) } else if user.Username == "testuser" && user.Password == "invalidjson" { _, _ = w.Write([]byte(`{ "id": "success", "message": "login successful", }`)) } else { 
w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte(`{ "id": "unauthorized", "message": "Invalid credentials" }`)) } return } // XML if isXMLContentType(r.Header.Get(hdrContentTypeKey)) { xd := xml.NewDecoder(r.Body) err := xd.Decode(user) w.Header().Set(hdrContentTypeKey, "application/xml") if err != nil { t.Logf("Error: %v", err) w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte(``)) _, _ = w.Write([]byte(`bad_requestUnable to read user info`)) return } if user.Username == "testuser" && user.Password == "testpass" { _, _ = w.Write([]byte(``)) _, _ = w.Write([]byte(`successlogin successful`)) } else if user.Username == "testuser" && user.Password == "invalidxml" { _, _ = w.Write([]byte(``)) _, _ = w.Write([]byte(`successlogin successful`)) } else { w.Header().Set("Www-Authenticate", "Protected Realm") w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte(``)) _, _ = w.Write([]byte(`unauthorizedInvalid credentials`)) } return } } } func handleUsersEndpoint(t *testing.T, w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/users" { // JSON if isJSONContentType(r.Header.Get(hdrContentTypeKey)) { var users []ExampleUser jd := json.NewDecoder(r.Body) err := jd.Decode(&users) w.Header().Set(hdrContentTypeKey, "application/json") if err != nil { t.Logf("Error: %v", err) w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte(`{ "id": "bad_request", "message": "Unable to read user info" }`)) return } // logic check, since we are excepting to reach 3 records if len(users) != 3 { t.Log("Error: Excepted count of 3 records") w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte(`{ "id": "bad_request", "message": "Expected record count doesn't match" }`)) return } eu := users[2] if eu.FirstName == "firstname3" && eu.ZipCode == "10003" { w.WriteHeader(http.StatusAccepted) _, _ = w.Write([]byte(`{ "message": "Accepted" }`)) } return } } } func createPostServer(t *testing.T) *httptest.Server { ts := createTestServer(func(w 
http.ResponseWriter, r *http.Request) { if r.Method == MethodPost { handleLoginEndpoint(t, w, r) handleUsersEndpoint(t, w, r) switch r.URL.Path { case "/login-json-html": w.Header().Set(hdrContentTypeKey, "text/html") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{ "id": "success", "message": "login successful" }`)) return case "/usersmap": // JSON if isJSONContentType(r.Header.Get(hdrContentTypeKey)) { if r.URL.Query().Get("status") == "500" { body, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Error: could not read post body: %s", err.Error()) } t.Logf("Got query param: status=500 so we're returning the post body as response and a 500 status code. body: %s", string(body)) w.Header().Set(hdrContentTypeKey, "application/json; charset=utf-8") w.WriteHeader(http.StatusInternalServerError) _, _ = w.Write(body) return } var users []map[string]any jd := json.NewDecoder(r.Body) err := jd.Decode(&users) w.Header().Set(hdrContentTypeKey, "application/json; charset=utf-8") if err != nil { t.Logf("Error: %v", err) w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte(`{ "id": "bad_request", "message": "Unable to read user info" }`)) return } // logic check, since we are excepting to reach 1 map records if len(users) != 1 { t.Log("Error: Excepted count of 1 map records") w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte(`{ "id": "bad_request", "message": "Expected record count doesn't match" }`)) return } w.WriteHeader(http.StatusAccepted) _, _ = w.Write([]byte(`{ "message": "Accepted" }`)) return } case "/redirect": w.Header().Set(hdrLocationKey, "/login") w.WriteHeader(http.StatusTemporaryRedirect) case "/redirect-with-body": body, _ := io.ReadAll(r.Body) query := url.Values{} query.Add("body", string(body)) w.Header().Set(hdrLocationKey, "/redirected-with-body?"+query.Encode()) w.WriteHeader(http.StatusTemporaryRedirect) case "/redirected-with-body": body, _ := io.ReadAll(r.Body) assertEqual(t, r.URL.Query().Get("body"), string(body)) 
w.WriteHeader(http.StatusOK) case "/curl-cmd-post": cookie := http.Cookie{ Name: "testserver", Domain: "localhost", Path: "/", Expires: time.Now().AddDate(0, 0, 1), Value: "yes", } http.SetCookie(w, &cookie) w.WriteHeader(http.StatusOK) case "/204-response": w.WriteHeader(http.StatusNoContent) } } }) return ts } func createFormPostServer(t *testing.T) *httptest.Server { ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { t.Logf("Content-Type: %v", r.Header.Get(hdrConnectionKey)) if r.Method == MethodPost { _ = r.ParseMultipartForm(10e6) if r.URL.Path == "/profile" { if r.MultipartForm == nil { values := r.Form t.Log(values) } else { values := r.MultipartForm.Value t.Log(values) } _, _ = w.Write([]byte("Success")) return } else if r.URL.Path == "/search" { formEncodedData := r.Form.Encode() t.Logf("Received Form Encoded values: %v", formEncodedData) assertTrue(t, strings.Contains(formEncodedData, "search_criteria=pencil"), "expected search_criteria=pencil") assertTrue(t, strings.Contains(formEncodedData, "search_criteria=glass"), "expected search_criteria=glass") _, _ = w.Write([]byte("Success")) return } else if r.URL.Path == "/upload" { t.Logf("FirstName: %v", r.FormValue("first_name")) t.Logf("LastName: %v", r.FormValue("last_name")) targetPath := filepath.Join(getTestDataPath(), "upload") _ = os.MkdirAll(targetPath, 0700) values := r.MultipartForm.Value t.Logf("%v", values) for _, fhdrs := range r.MultipartForm.File { for _, hdr := range fhdrs { t.Logf("Name: %v", hdr.Filename) t.Logf("Header: %v", hdr.Header) dotPos := strings.LastIndex(hdr.Filename, ".") fname := fmt.Sprintf("%s-%v%s", hdr.Filename[:dotPos], time.Now().Unix(), hdr.Filename[dotPos:]) t.Logf("Write name: %v", fname) infile, _ := hdr.Open() f, err := os.OpenFile(filepath.Join(targetPath, fname), os.O_WRONLY|os.O_CREATE, 0666) if err != nil { t.Logf("Error: %v", err) return } defer func() { _ = f.Close() }() size, _ := io.Copy(f, infile) _, _ = w.Write([]byte(fmt.Sprintf("File: 
%v, uploaded as: %v, size: %v\n", hdr.Filename, fname, size))) } } return } } if r.Method == MethodPut { if r.URL.Path == "/raw-upload" { body, _ := io.ReadAll(r.Body) bl, _ := strconv.Atoi(r.Header.Get("Content-Length")) assertEqual(t, len(body), bl) w.WriteHeader(http.StatusOK) } } }) return ts } func createFormPatchServer(t *testing.T) *httptest.Server { ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { t.Logf("Method: %v", r.Method) t.Logf("Path: %v", r.URL.Path) t.Logf("Content-Type: %v", r.Header.Get(hdrContentTypeKey)) if r.Method == MethodPatch { _ = r.ParseMultipartForm(10e6) if r.URL.Path == "/upload" { t.Logf("FirstName: %v", r.FormValue("first_name")) t.Logf("LastName: %v", r.FormValue("last_name")) targetPath := filepath.Join(getTestDataPath(), "upload") _ = os.MkdirAll(targetPath, 0700) for _, fhdrs := range r.MultipartForm.File { for _, hdr := range fhdrs { t.Logf("Name: %v", hdr.Filename) t.Logf("Header: %v", hdr.Header) dotPos := strings.LastIndex(hdr.Filename, ".") fname := fmt.Sprintf("%s-%v%s", hdr.Filename[:dotPos], time.Now().Unix(), hdr.Filename[dotPos:]) t.Logf("Write name: %v", fname) infile, _ := hdr.Open() f, err := os.OpenFile(filepath.Join(targetPath, fname), os.O_WRONLY|os.O_CREATE, 0666) if err != nil { t.Logf("Error: %v", err) return } defer func() { _ = f.Close() }() _, _ = io.Copy(f, infile) _, _ = w.Write([]byte(fmt.Sprintf("File: %v, uploaded as: %v\n", hdr.Filename, fname))) } } return } } }) return ts } func createFileUploadServer(t *testing.T) *httptest.Server { ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { t.Logf("Method: %v", r.Method) t.Logf("Path: %v", r.URL.Path) t.Logf("Content-Type: %v", r.Header.Get(hdrContentTypeKey)) if r.Method != MethodPost && r.Method != MethodPut { t.Log("createFileUploadServer:: Not a POST or PUT request") w.WriteHeader(http.StatusBadRequest) fmt.Fprint(w, http.StatusText(http.StatusBadRequest)) return } targetPath := filepath.Join(getTestDataPath(), 
"upload-large") _ = os.MkdirAll(targetPath, 0700) defer cleanupFiles(targetPath) switch r.URL.Path { case "/upload": f, err := os.OpenFile(filepath.Join(targetPath, "large-file.png"), os.O_WRONLY|os.O_CREATE, 0666) if err != nil { t.Logf("Error: %v", err) return } defer func() { _ = f.Close() }() size, _ := io.Copy(f, r.Body) fmt.Fprintf(w, "File Uploaded successfully, file size: %v", size) case "/set-reset-multipart-readers-test": w.Header().Set(hdrContentTypeKey, "application/json; charset=utf-8") w.WriteHeader(http.StatusInternalServerError) _, _ = fmt.Fprintf(w, `{ "message": "error" }`) } }) return ts } func createAuthServer(t *testing.T) *httptest.Server { return createAuthServerTLSOptional(t, true) } func createAuthServerTLSOptional(t *testing.T, useTLS bool) *httptest.Server { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Logf(`createAuthServerTLSOptional: Method: %v, Path: %v, Content-Type: %v`, r.Method, r.URL.Path, r.Header.Get(hdrContentTypeKey)) if r.Method == MethodGet { if r.URL.Path == "/profile" { // 004DDB79-6801-4587-B976-F093E6AC44FF auth := r.Header.Get("Authorization") t.Logf("Bearer Auth: %v", auth) w.Header().Set(hdrContentTypeKey, "application/json; charset=utf-8") if strings.HasPrefix(auth, "Basic ") { w.Header().Set("Www-Authenticate", "Protected Realm") w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte(`{ "id": "unauthorized", "message": "Invalid credentials" }`)) return } if strings.Contains(auth, "004DDB79-6801-4587-B976-F093E6AC44FF") { _, _ = w.Write([]byte(`{ "username": "auth_test", "message": "profile fetch successful" }`)) } } return } if r.Method == MethodPost { if r.URL.Path == "/login" { auth := r.Header.Get("Authorization") t.Logf("Basic Auth: %v", auth) _, _ = io.ReadAll(r.Body) w.Header().Set(hdrContentTypeKey, "application/json; charset=utf-8") password, err := base64.StdEncoding.DecodeString(auth[6:]) if err != nil || string(password) != "myuser:basicauth" { 
w.Header().Set("Www-Authenticate", "Protected Realm") w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte(`{ "id": "unauthorized", "message": "Invalid credentials" }`)) return } _, _ = w.Write([]byte(`{ "id": "success", "message": "login successful" }`)) } return } }) if useTLS { return httptest.NewTLSServer(handler) } return httptest.NewServer(handler) } func createGenericServer(t *testing.T) *httptest.Server { ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { t.Logf("Method: %v", r.Method) t.Logf("Path: %v", r.URL.Path) if r.Method == MethodGet { switch r.URL.Path { case "/json-no-set": // Set empty header value for testing, since Go server sets to // text/plain; charset=utf-8 w.Header().Set(hdrContentTypeKey, "") _, _ = w.Write([]byte(`{"response":"json response no content type set"}`)) // Gzip case "/gzip-test": w.Header().Set(hdrContentTypeKey, plainTextType) w.Header().Set(hdrContentEncodingKey, "gzip") zw := gzip.NewWriter(w) _, _ = zw.Write([]byte("This is Gzip response testing")) zw.Close() case "/gzip-test-gziped-empty-body": w.Header().Set(hdrContentTypeKey, plainTextType) w.Header().Set(hdrContentEncodingKey, "gzip") zw := gzip.NewWriter(w) // write gziped empty body _, _ = zw.Write([]byte("")) zw.Close() case "/gzip-test-no-gziped-body": w.Header().Set(hdrContentTypeKey, plainTextType) w.Header().Set(hdrContentEncodingKey, "gzip") // don't write body // Deflate case "/deflate-test": w.Header().Set(hdrContentTypeKey, plainTextType) w.Header().Set(hdrContentEncodingKey, "deflate") zw, _ := flate.NewWriter(w, flate.BestSpeed) _, _ = zw.Write([]byte("This is Deflate response testing")) zw.Close() case "/deflate-test-empty-body": w.Header().Set(hdrContentTypeKey, plainTextType) w.Header().Set(hdrContentEncodingKey, "deflate") zw, _ := flate.NewWriter(w, flate.BestSpeed) // write deflate empty body _, _ = zw.Write([]byte("")) zw.Close() case "/deflate-test-no-body": w.Header().Set(hdrContentTypeKey, plainTextType) 
w.Header().Set(hdrContentEncodingKey, "deflate") // don't write body // LZW case "/lzw-test": w.Header().Set(hdrContentTypeKey, plainTextType) w.Header().Set(hdrContentEncodingKey, "coMpReSs") zw := lzw.NewWriter(w, lzw.LSB, 8) _, _ = zw.Write([]byte("This is LZW response testing")) zw.Close() case "/lzw-test-empty-body": w.Header().Set(hdrContentTypeKey, plainTextType) w.Header().Set(hdrContentEncodingKey, "compress") zw := lzw.NewWriter(w, lzw.LSB, 8) // write lzw empty body _, _ = zw.Write([]byte("")) zw.Close() case "/lzw-test-no-body": w.Header().Set(hdrContentTypeKey, plainTextType) w.Header().Set(hdrContentEncodingKey, "compress") // don't write body } return } if r.Method == MethodPut { if r.URL.Path == "/plaintext" { _, _ = w.Write([]byte("TestPut: plain text response")) } else if r.URL.Path == "/json" { w.Header().Set(hdrContentTypeKey, "application/json; charset=utf-8") _, _ = w.Write([]byte(`{"response":"json response"}`)) } else if r.URL.Path == "/xml" { w.Header().Set(hdrContentTypeKey, "application/xml") _, _ = w.Write([]byte(`XML response`)) } return } if r.Method == MethodOptions && r.URL.Path == "/options" { w.Header().Set("Access-Control-Allow-Origin", "localhost") w.Header().Set("Access-Control-Allow-Methods", "PUT, PATCH") w.Header().Set("Access-Control-Expose-Headers", "x-go-resty-id") w.WriteHeader(http.StatusOK) return } if r.Method == MethodPatch && r.URL.Path == "/patch" { w.WriteHeader(http.StatusOK) return } if r.Method == "REPORT" && r.URL.Path == "/report" { body, _ := io.ReadAll(r.Body) if len(body) == 0 { w.WriteHeader(http.StatusOK) } return } if r.Method == MethodTrace && r.URL.Path == "/trace" { w.WriteHeader(http.StatusOK) return } if r.Method == MethodDelete && r.URL.Path == "/delete" { body, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Error: could not read get body: %s", err.Error()) } _, _ = w.Write(body) return } }) return ts } func createRedirectServer(t *testing.T) *httptest.Server { ts := createTestServer(func(w 
http.ResponseWriter, r *http.Request) { t.Logf("Method: %v", r.Method) t.Logf("Path: %v", r.URL.Path) if r.Method == MethodGet { if strings.HasPrefix(r.URL.Path, "/redirect-host-check-") { cntStr := strings.SplitAfter(r.URL.Path, "-")[3] cnt, _ := strconv.Atoi(cntStr) if cnt != 7 { // Testing hard stop via logical if cnt >= 5 { http.Redirect(w, r, "http://httpbin.org/get", http.StatusTemporaryRedirect) } else { http.Redirect(w, r, fmt.Sprintf("/redirect-host-check-%d", cnt+1), http.StatusTemporaryRedirect) } } } else if strings.HasPrefix(r.URL.Path, "/redirect-") { cntStr := strings.SplitAfter(r.URL.Path, "-")[1] cnt, _ := strconv.Atoi(cntStr) http.Redirect(w, r, fmt.Sprintf("/redirect-%d", cnt+1), http.StatusTemporaryRedirect) } } }) return ts } func createUnixSocketEchoServer(t *testing.T) string { socketPath := filepath.Join(os.TempDir(), strconv.FormatInt(time.Now().Unix(), 10)) + ".sock" // Create a Unix domain socket and listen for incoming connections. socket, err := net.Listen("unix", socketPath) if err != nil { t.Fatal(err) } m := http.NewServeMux() m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hi resty client from a server running on Unix domain socket!\n")) }) m.HandleFunc("/hello", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("Hello resty client from a server running on endpoint /hello!\n")) }) go func(t *testing.T) { server := http.Server{Handler: m} if err := server.Serve(socket); err != nil { t.Error(err) } }(t) return socketPath } func createDigestServer(t *testing.T, conf *digestServerConfig) *httptest.Server { if conf == nil { conf = defaultDigestServerConf() } setWWWAuthHeader := func(w http.ResponseWriter, v string) { w.Header().Set("WWW-Authenticate", v) w.WriteHeader(http.StatusUnauthorized) } ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { t.Logf("Method: %v", r.Method) t.Logf("Path: %v", r.URL.Path) switch r.URL.Path { case "/bad": setWWWAuthHeader(w, "Bad Challenge") 
return case "/unknown_param": setWWWAuthHeader(w, "Digest unknown_param=true") return case "/missing_value": setWWWAuthHeader(w, `Digest realm="hello", domain`) return case "/unclosed_quote": setWWWAuthHeader(w, `Digest realm="hello, qop=auth`) return case "/no_challenge": setWWWAuthHeader(w, "") return case "/status_500": w.WriteHeader(http.StatusInternalServerError) return } w.Header().Set(hdrContentTypeKey, "application/json; charset=utf-8") if authorizationHeaderValid(t, r, conf) { if r.URL.Path == "/dir/index.html" && r.Method == MethodPost { body, err := io.ReadAll(r.Body) assertNil(t, err) assertEqual(t, `{"city":"Los Angeles","zip_code":"00000"}`, strings.TrimSpace(string(body))) } w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{ "id": "success", "message": "login successful" }`)) } else { setWWWAuthHeader(w, fmt.Sprintf(`Digest realm="%s", domain="%s", qop="%s", algorithm=%s, nonce="%s", opaque="%s", userhash=true, charset=%s, stale=FALSE, nc=%s`, conf.realm, conf.uri, conf.qop, conf.algo, conf.nonce, conf.opaque, conf.charset, conf.nc)) _, _ = w.Write([]byte(`{ "id": "unauthorized", "message": "Invalid credentials" }`)) } }) return ts } func authorizationHeaderValid(t *testing.T, r *http.Request, conf *digestServerConfig) bool { input := r.Header.Get(hdrAuthorizationKey) if input == "" { return false } const ws = " \n\r\t" const qs = `"` s := strings.Trim(input, ws) assertTrue(t, strings.HasPrefix(s, "Digest "), "Digest auth header prefix expected") s = strings.Trim(s[7:], ws) sl := strings.Split(s, ", ") pairs := make(map[string]string, len(sl)) for i := range sl { pair := strings.SplitN(sl[i], "=", 2) pairs[pair[0]] = strings.Trim(pair[1], qs) } assertEqual(t, conf.algo, pairs["algorithm"]) h := func(data string) string { h := newHashFunc(pairs["algorithm"]) _, _ = h.Write([]byte(data)) return hex.EncodeToString(h.Sum(nil)) } assertEqual(t, conf.opaque, pairs["opaque"]) assertEqual(t, "true", pairs["userhash"]) userHash := h(fmt.Sprintf("%s:%s", 
conf.username, conf.realm)) assertEqual(t, userHash, pairs["username"]) ha1 := h(fmt.Sprintf("%s:%s:%s", conf.username, conf.realm, conf.password)) if strings.HasSuffix(conf.algo, "-sess") { ha1 = h(fmt.Sprintf("%s:%s:%s", ha1, pairs["nonce"], pairs["cnonce"])) } ha2 := h(fmt.Sprintf("%s:%s", r.Method, conf.uri)) qop := pairs["qop"] if qop == "" { kd := h(fmt.Sprintf("%s:%s:%s", ha1, pairs["nonce"], ha2)) return kd == pairs["response"] } nonceCount, err := strconv.Atoi(pairs["nc"]) assertError(t, err) // auth scenario if qop == qopAuth { kd := h(fmt.Sprintf("%s:%s", ha1, fmt.Sprintf("%s:%08x:%s:%s:%s", pairs["nonce"], nonceCount, pairs["cnonce"], pairs["qop"], ha2))) return kd == pairs["response"] } // auth-int scenario body, err := io.ReadAll(r.Body) r.Body.Close() r.Body = io.NopCloser(bytes.NewReader(body)) assertError(t, err) bodyHash := "" if len(body) > 0 { bodyHash = h(string(body)) } ha2 = h(fmt.Sprintf("%s:%s:%s", r.Method, conf.uri, bodyHash)) kd := h(fmt.Sprintf("%s:%s", ha1, fmt.Sprintf("%s:%08x:%s:%s:%s", pairs["nonce"], nonceCount, pairs["cnonce"], pairs["qop"], ha2))) return kd == pairs["response"] } func createTestServer(fn func(w http.ResponseWriter, r *http.Request)) *httptest.Server { return httptest.NewServer(http.HandlerFunc(fn)) } func createTestTLSServer(fn func(w http.ResponseWriter, r *http.Request), certPath, certKeyPath string) *httptest.Server { ts := httptest.NewUnstartedServer(http.HandlerFunc(fn)) ts.TLS = &tls.Config{ GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { cert, err := tls.LoadX509KeyPair(certPath, certKeyPath) if err != nil { return nil, err } return &cert, nil }, } ts.StartTLS() return ts } func dcnl() *Client { c := New(). outputLogTo(io.Discard) return c } func dcnld() *Client { return dcnl().SetDebug(true) } func dcldb() (*Client, *bytes.Buffer) { logBuf := acquireBuffer() c := New(). SetDebug(true). 
// assertNil fails the test (without stopping it) when v is not nil.
// Typed nil pointers, maps, slices, channels, funcs and interfaces count as nil.
func assertNil(t *testing.T, v any, failureMsgs ...string) {
	t.Helper()
	if !isNil(v) {
		t.Errorf("[%v] was expected to be nil. Message: %v", v, strings.Join(failureMsgs, " "))
	}
}

// assertNotNil fails the test (without stopping it) when v is nil.
func assertNotNil(t *testing.T, v any, failureMsgs ...string) {
	t.Helper()
	if isNil(v) {
		t.Errorf("[%v] was expected to be non-nil. Message: %v", v, strings.Join(failureMsgs, " "))
	}
}

// assertType fails the test when the dynamic type of v differs from the
// dynamic type of typ.
func assertType(t *testing.T, typ, v any, failureMsgs ...string) {
	t.Helper()
	// Report a mismatch, not a match, and print the types with %T.
	if reflect.TypeOf(typ) != reflect.TypeOf(v) {
		t.Errorf("Expected type %T, got %T. Message: %v", typ, v, strings.Join(failureMsgs, " "))
	}
}

// assertError fails the test when err is non-nil.
func assertError(t *testing.T, err error, failureMsgs ...string) {
	t.Helper()
	if err != nil {
		t.Errorf("Error occurred [%v]. Message: %v", err, strings.Join(failureMsgs, " "))
	}
}

// assertErrorIs fails the test unless errors.Is(g, e) holds, i.e. the
// expected error e appears in the chain of the error g that was obtained.
// It returns true when the assertion held.
func assertErrorIs(t *testing.T, e, g error, failureMsgs ...string) (r bool) {
	t.Helper()
	if !errors.Is(g, e) {
		t.Errorf("Expected [%v], got [%v]. Message: %v", e, g, strings.Join(failureMsgs, " "))
		return false
	}
	return true
}

// assertTrue fails the test unless g is the boolean value true.
// It returns true when the assertion held.
func assertTrue(t *testing.T, g any, failureMsgs ...string) (r bool) {
	t.Helper()
	if !equal(true, g) {
		t.Errorf("Expected `true`, got [%v]. Message: %v", g, strings.Join(failureMsgs, " "))
		return false
	}
	return true
}

// assertFalse fails the test unless g is the boolean value false.
// It returns true when the assertion held.
func assertFalse(t *testing.T, g any, failureMsgs ...string) (r bool) {
	t.Helper()
	if !equal(false, g) {
		t.Errorf("Expected `false`, got [%v]. Message: %v", g, strings.Join(failureMsgs, " "))
		return false
	}
	return true
}

// assertEqual fails the test unless e (expected) and g (got) are deeply equal.
// It returns true when the assertion held.
func assertEqual(t *testing.T, e, g any, failureMsgs ...string) (r bool) {
	t.Helper()
	if !equal(e, g) {
		t.Errorf("Expected [%v], got [%v]. Message: %v", e, g, strings.Join(failureMsgs, " "))
		return false
	}
	return true
}

// assertNotEqual fails the test when e and g are deeply equal.
// It returns true when the assertion held.
func assertNotEqual(t *testing.T, e, g any, failureMsgs ...string) (r bool) {
	t.Helper()
	if equal(e, g) {
		t.Errorf("Expected [%v], got [%v]. Message: %v", e, g, strings.Join(failureMsgs, " "))
		return false
	}
	return true
}

// equal reports whether expected and got are deeply equal.
func equal(expected, got any) bool {
	return reflect.DeepEqual(expected, got)
}

// isNil reports whether v is an untyped nil or a nil value of a nilable kind
// (chan, func, interface, map, pointer or slice).
func isNil(v any) bool {
	if v == nil {
		return true
	}

	rv := reflect.ValueOf(v)
	kind := rv.Kind()
	// reflect.Chan..reflect.Slice is the contiguous range of nilable kinds.
	if kind >= reflect.Chan && kind <= reflect.Slice && rv.IsNil() {
		return true
	}

	return false
}
// By default, Resty employs the capped exponential backoff with a jitter delay strategy. RetryDelayStrategyFunc func(*Response, error) (time.Duration, error) ) // RetryConstantDelayStrategy returns a RetryDelayStrategyFunc that always returns the specified delay duration. func RetryConstantDelayStrategy(delay time.Duration) RetryDelayStrategyFunc { return func(*Response, error) (time.Duration, error) { return delay, nil } } var ( regexErrTooManyRedirects = regexp.MustCompile(`stopped after \d+ redirects\z`) regexErrScheme = regexp.MustCompile("unsupported protocol scheme") regexErrInvalidHeader = regexp.MustCompile("invalid header") ) func applyRetryDefaultConditions(res *Response, err error) bool { // no retry on TLS error if _, ok := err.(*tls.CertificateVerificationError); ok { return false } // validate url error, so we can decide to retry or not if u, ok := err.(*url.Error); ok { if regexErrTooManyRedirects.MatchString(u.Error()) { return false } if regexErrScheme.MatchString(u.Error()) { return false } if regexErrInvalidHeader.MatchString(u.Error()) { return false } return u.Temporary() // possible retry if it's true } if res == nil { return false } // certain HTTP status codes are temporary so that we can retry // - 429 Too Many Requests // - 500 or above (it's better to ignore 501 Not Implemented) // - 0 No status code received if res.StatusCode() == http.StatusTooManyRequests || (res.StatusCode() >= 500 && res.StatusCode() != http.StatusNotImplemented) || res.StatusCode() == 0 { return true } return false } func newBackoffWithJitter(min, max time.Duration) *backoffWithJitter { if min <= 0 { min = defaultWaitTime } if max == 0 { max = defaultMaxWaitTime } return &backoffWithJitter{ lock: new(sync.Mutex), rnd: rand.New(rand.NewSource(time.Now().UnixNano())), min: min, max: max, } } type backoffWithJitter struct { lock *sync.Mutex rnd *rand.Rand min time.Duration max time.Duration } func (b *backoffWithJitter) NextWaitDuration(c *Client, res *Response, err 
error, attempt int) (time.Duration, error) { if res != nil { if res.StatusCode() == http.StatusTooManyRequests || res.StatusCode() == http.StatusServiceUnavailable { if delay, ok := parseRetryAfterHeader(res.Header().Get(hdrRetryAfterKey)); ok { return delay, nil } } } const maxInt = 1<<31 - 1 // max int for arch 386 if b.max < 0 { b.max = maxInt } if res == nil || res.Request.RetryDelayStrategy == nil { return b.balanceMinMax(b.defaultDelayStrategy(attempt)), nil } // invoke custom retry delay strategy return res.Request.RetryDelayStrategy(res, err) } // Return capped exponential backoff with jitter // https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ func (b *backoffWithJitter) defaultDelayStrategy(attempt int) time.Duration { temp := math.Min(float64(b.max), float64(b.min)*math.Exp2(float64(attempt))) ri := time.Duration(temp / 2) if ri <= 0 { ri = time.Nanosecond } return b.randDuration(ri) } func (b *backoffWithJitter) randDuration(center time.Duration) time.Duration { b.lock.Lock() defer b.lock.Unlock() var ri = int64(center) var jitter = b.rnd.Int63n(ri) return time.Duration(math.Abs(float64(ri + jitter))) } func (b *backoffWithJitter) balanceMinMax(delay time.Duration) time.Duration { if delay <= 0 || b.max < delay { return b.max } if delay < b.min { return b.min } return delay } var timeNow = time.Now // parseRetryAfterHeader parses the Retry-After header and returns the // delay duration according to the spec: https://httpwg.org/specs/rfc7231.html#header.retry-after // The bool returned will be true if the header was successfully parsed. // Otherwise, the header was either not present, or was not parseable according to the spec. 
// // Retry-After headers come in two flavors: Seconds or HTTP-Date // // Examples: // - Retry-After: Fri, 31 Dec 1999 23:59:59 GMT // - Retry-After: 120 func parseRetryAfterHeader(v string) (time.Duration, bool) { if isStringEmpty(v) { return 0, false } // Retry-After: 120 if delay, err := strconv.ParseInt(v, 10, 64); err == nil { if delay < 0 { // a negative delay doesn't make sense return 0, false } return time.Second * time.Duration(delay), true } // Retry-After: Fri, 31 Dec 1999 23:59:59 GMT retryTime, err := time.Parse(time.RFC1123, v) if err != nil { return 0, false } if until := retryTime.Sub(timeNow()); until > 0 { return until, true } // date is in the past return 0, true } ================================================ FILE: retry_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "bytes" "context" "crypto/tls" "encoding/json" "errors" "fmt" "io" "net/http" "net/http/httptest" "reflect" "strconv" "strings" "testing" "time" ) // Check to make sure the functions added to add conditionals work func TestRetryConditionalGet(t *testing.T) { ts := createGetServer(t) defer ts.Close() attemptCount := 1 externalCounter := 0 // This check should pass on first run, and let the response through check := RetryConditionFunc(func(*Response, error) bool { externalCounter++ return attemptCount != externalCounter }) client := dcnl() resp, err := client.R(). AddRetryConditions(check). SetRetryCount(2). SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)). 
Get(ts.URL + "/") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, "TestGet: text response", resp.String()) assertEqual(t, externalCounter, attemptCount) logResponse(t, resp) } func TestRequestConditionalGet(t *testing.T) { ts := createGetServer(t) defer ts.Close() externalCounter := 0 // This check should pass on first run, and let the response through check := RetryConditionFunc(func(r *Response, _ error) bool { externalCounter++ return false }) // Clear the default client. c, lb := dcldb() resp, err := c.R(). SetDebug(true). AddRetryConditions(check). SetRetryCount(1). SetRetryWaitTime(50*time.Millisecond). SetRetryMaxWaitTime(1*time.Second). SetQueryParam("request_no", strconv.FormatInt(time.Now().Unix(), 10)). Get(ts.URL + "/") assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "200 OK", resp.Status()) assertEqual(t, "TestGet: text response", resp.String()) assertEqual(t, 1, resp.Request.Attempt) assertEqual(t, 1, externalCounter) assertTrue(t, strings.Contains(lb.String(), "CORRELATION ID:"), "expected debug log with correlation ID") logResponse(t, resp) } func TestClientRetryGetWithTimeout(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl(). SetTimeout(50 * time.Millisecond). 
SetRetryCount(3) resp, err := c.R().Get(ts.URL + "/set-retrycount-test") assertEqual(t, "", resp.Status()) assertEqual(t, "", resp.Proto()) assertEqual(t, 0, resp.StatusCode()) assertEqual(t, 0, len(resp.Cookies())) assertEqual(t, 0, len(resp.Header())) assertErrorIs(t, context.DeadlineExceeded, err, "expected context deadline exceeded error") } func TestClientRetryWithMinAndMaxWaitTime(t *testing.T) { ts := createGetServer(t) defer ts.Close() retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones retryWaitTime := 10 * time.Millisecond retryMaxWaitTime := 100 * time.Millisecond c, lb := dcldb() c.SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) res, _ := c.R().SetDebug(true).Get(ts.URL + "/set-retrywaittime-test") retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) // retryCount+1 == attempts were made assertEqual(t, retryCount+1, res.Request.Attempt) assertTrue(t, strings.Contains(lb.String(), "CORRELATION ID:"), "expected debug log with correlation ID") // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) for i := 1; i < len(retryIntervals); i++ { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests if slept < retryWaitTime-5*time.Millisecond { t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) } if slept > retryMaxWaitTime+5*time.Millisecond { t.Logf("Client has slept %f seconds which is s > max (%f) before retry %d", slept.Seconds(), retryMaxWaitTime.Seconds(), i) } } } func TestClientRetryWaitMaxInfinite(t *testing.T) { ts := createGetServer(t) defer 
ts.Close() retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones retryWaitTime := time.Duration(10) * time.Millisecond retryMaxWaitTime := time.Duration(-1.0) // negative value c := dcnl(). SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) res, _ := c.R().Get(ts.URL + "/set-retrywaittime-test") retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) // retryCount+1 == attempts were made assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) for i := 1; i < len(retryIntervals); i++ { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests if slept < retryWaitTime-5*time.Millisecond { t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) } } } func TestClientRetryWaitMaxMinimum(t *testing.T) { ts := createGetServer(t) defer ts.Close() const retryMaxWaitTime = time.Nanosecond // minimal duration value c := dcnl(). SetRetryCount(1). SetRetryMaxWaitTime(retryMaxWaitTime). 
AddRetryConditions(func(*Response, error) bool { return true }) _, err := c.R().Get(ts.URL + "/set-retrywaittime-test") assertError(t, err) } func TestClientRetryDelayStrategyFuncError(t *testing.T) { ts := createGetServer(t) defer ts.Close() attempt := 0 retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones retryWaitTime := 50 * time.Millisecond retryMaxWaitTime := 150 * time.Millisecond retryDelayStrategyFunc := func(res *Response, err error) (time.Duration, error) { return 0, errors.New("quota exceeded") } c := dcnl(). SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). SetRetryDelayStrategy(retryDelayStrategyFunc). AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[attempt] = parseTimeSleptFromResponse(r.String()) attempt++ return true }, ) _, err := c.R().Get(ts.URL + "/set-retrywaittime-test") // 1 attempts were made assertEqual(t, 1, attempt) // non-nil error was returned assertNotNil(t, err) } func TestClientRetryDelayStrategyFunc(t *testing.T) { ts := createGetServer(t) defer ts.Close() retryCount := 10 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times to constant delay retryWaitTime := 50 * time.Millisecond retryMaxWaitTime := 50 * time.Millisecond c := dcnl(). SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). SetRetryDelayStrategy(RetryConstantDelayStrategy(50 * time.Microsecond)). 
AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) res, _ := c.R().Get(ts.URL + "/set-retrywaittime-test") retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) // retryCount+1 == attempts were made assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) for i := 1; i < len(retryIntervals); i++ { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests if slept < retryWaitTime-5*time.Millisecond { t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) } if retryMaxWaitTime+5*time.Millisecond < slept { t.Logf("Client has slept %f seconds which is max < s (%f) before retry %d", slept.Seconds(), retryMaxWaitTime.Seconds(), i) } } } func TestRequestRetryDelayStrategyFunc(t *testing.T) { ts := createGetServer(t) defer ts.Close() retryCount := 10 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times to constant delay retryWaitTime := 50 * time.Millisecond retryMaxWaitTime := 50 * time.Millisecond c := dcnl() res, _ := c.R(). SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). SetRetryDelayStrategy(RetryConstantDelayStrategy(50 * time.Microsecond)). AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ). 
Get(ts.URL + "/set-retrywaittime-test") retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) // retryCount+1 == attempts were made assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) for i := 1; i < len(retryIntervals); i++ { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests if slept < retryWaitTime-5*time.Millisecond { t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) } if retryMaxWaitTime+5*time.Millisecond < slept { t.Logf("Client has slept %f seconds which is max < s (%f) before retry %d", slept.Seconds(), retryMaxWaitTime.Seconds(), i) } } } func TestClientRetryDelayStrategyWaitTooShort(t *testing.T) { ts := createGetServer(t) defer ts.Close() retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones retryWaitTime := 50 * time.Millisecond retryMaxWaitTime := 150 * time.Millisecond c := dcnl(). SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). SetRetryDelayStrategy(RetryConstantDelayStrategy(10 * time.Microsecond)). 
AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) res, _ := c.R().Get(ts.URL + "/set-retrywaittime-test") retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) // retryCount+1 == attempts were made assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) for i := 1; i < len(retryIntervals); i++ { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests if slept < retryWaitTime-5*time.Millisecond { t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) } if retryWaitTime+5*time.Millisecond < slept { t.Logf("Client has slept %f seconds which is min < s (%f) before retry %d", slept.Seconds(), retryWaitTime.Seconds(), i) } } } func TestClientRetryDelayStrategyWaitTooLong(t *testing.T) { ts := createGetServer(t) defer ts.Close() retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones retryWaitTime := 10 * time.Millisecond retryMaxWaitTime := 50 * time.Millisecond c := dcnl(). SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). SetRetryDelayStrategy(RetryConstantDelayStrategy(1 * time.Second)). 
AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) res, _ := c.R().Get(ts.URL + "/set-retrywaittime-test") retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) // retryCount+1 == attempt attempts were made assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) for i := 1; i < len(retryIntervals); i++ { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests if slept < retryMaxWaitTime-5*time.Millisecond { t.Logf("Client has slept %f seconds which is s < max (%f) before retry %d", slept.Seconds(), retryMaxWaitTime.Seconds(), i) } if retryMaxWaitTime+5*time.Millisecond < slept { t.Logf("Client has slept %f seconds which is max < s (%f) before retry %d", slept.Seconds(), retryMaxWaitTime.Seconds(), i) } } } func TestClientRetryCancel(t *testing.T) { ts := createGetServer(t) defer ts.Close() retryCount := 5 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times that do not intersect with default ones retryWaitTime := 100 * time.Millisecond retryMaxWaitTime := 200 * time.Millisecond c := dcnl(). SetRetryCount(retryCount). SetRetryWaitTime(retryWaitTime). SetRetryMaxWaitTime(retryMaxWaitTime). 
AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) timeout := 100 * time.Millisecond ctx, cancelFunc := context.WithTimeout(context.Background(), timeout) req := c.R().SetContext(ctx) _, _ = req.Get(ts.URL + "/set-retrywaittime-test") // 1 attempts were made assertEqual(t, 1, req.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) // Second attempt should be interrupted on context timeout if time.Duration(retryIntervals[1]) > timeout { t.Errorf("Client didn't awake on context cancel") } cancelFunc() } func TestClientRetryPost(t *testing.T) { ts := createPostServer(t) defer ts.Close() usersmap := map[string]any{ "user1": map[string]any{"FirstName": "firstname1", "LastName": "lastname1", "ZipCode": "10001"}, } var users []map[string]any users = append(users, usersmap) c := dcnl() c.SetRetryCount(3) c.AddRetryConditions(RetryConditionFunc(func(r *Response, _ error) bool { return r.StatusCode() >= http.StatusInternalServerError })) resp, _ := c.R(). SetBody(&users). Post(ts.URL + "/usersmap?status=500") if resp != nil { if resp.StatusCode() == http.StatusInternalServerError { t.Logf("Got response body: %s", resp.String()) var usersResponse []map[string]any err := json.Unmarshal(resp.Bytes(), &usersResponse) assertError(t, err) if !reflect.DeepEqual(users, usersResponse) { t.Errorf("Expected request body to be echoed back as response body. Instead got: %s", resp.String()) } return } t.Errorf("Got unexpected response code: %d with body: %s", resp.StatusCode(), resp.String()) } } func TestClientRetryErrorRecover(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl(). SetRetryCount(2). SetResultError(AuthError{}). 
AddRetryConditions( func(r *Response, _ error) bool { err, ok := r.ResultError().(*AuthError) retry := ok && r.StatusCode() == 429 && err.Message == "too many" return retry }, ) resp, err := c.R(). SetHeader(hdrContentTypeKey, "application/json; charset=utf-8"). SetJSONEscapeHTML(false). SetResult(AuthSuccess{}). Get(ts.URL + "/set-retry-error-recover") assertError(t, err) authSuccess := resp.Result().(*AuthSuccess) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "hello", authSuccess.Message) assertNil(t, resp.ResultError()) } func TestClientRetryCountWithTimeout(t *testing.T) { ts := createGetServer(t) defer ts.Close() attempt := 0 c := dcnl(). SetTimeout(50 * time.Millisecond). SetRetryCount(1). AddRetryConditions( func(r *Response, _ error) bool { attempt++ return true }, ) resp, err := c.R().Get(ts.URL + "/set-retrycount-test") assertEqual(t, "", resp.Status()) assertEqual(t, "", resp.Proto()) assertEqual(t, 0, resp.StatusCode()) assertEqual(t, 0, len(resp.Cookies())) assertEqual(t, 0, len(resp.Header())) assertEqual(t, 2, resp.Request.Attempt) assertErrorIs(t, context.DeadlineExceeded, err, "expected context deadline exceeded error") } func TestClientRetryTooManyRequestsAndRecover(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl(). SetTimeout(time.Second * 1). SetRetryCount(2) resp, err := c.R(). SetHeader(hdrContentTypeKey, "application/json; charset=utf-8"). SetJSONEscapeHTML(false). SetResult(AuthSuccess{}). SetTimeout(10 * time.Millisecond). Get(ts.URL + "/set-retry-error-recover") assertError(t, err) authSuccess := resp.Result().(*AuthSuccess) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "hello", authSuccess.Message) assertNil(t, resp.ResultError()) } func TestClientRetryHookWithTimeout(t *testing.T) { ts := createGetServer(t) defer ts.Close() hookCalledCount := 0 retryHook := func(r *Response, _ error) { hookCalledCount++ } retryCount := 3 c := dcnl(). SetRetryCount(retryCount). 
// errSeekFailure is the error failingSeeker returns when asked to rewind.
var errSeekFailure = fmt.Errorf("failing seek test")

// failingSeeker wraps a bytes.Reader and refuses to seek back to the start,
// while delegating every other operation to the wrapped reader.
type failingSeeker struct {
	reader *bytes.Reader
}

// Read reads from the wrapped reader.
func (fs failingSeeker) Read(b []byte) (n int, err error) {
	n, err = fs.reader.Read(b)
	return n, err
}

// Seek fails with errSeekFailure for a rewind to offset 0 from the start and
// forwards any other seek to the wrapped reader.
func (fs failingSeeker) Seek(offset int64, whence int) (int64, error) {
	isRewind := whence == io.SeekStart && offset == 0
	if !isRewind {
		return fs.reader.Seek(offset, whence)
	}
	return 0, errSeekFailure
}
AddRetryHooks( func(response *Response, _ error) { read, err := bufReader.Read(bufCpy) assertNil(t, err) assertEqual(t, len(buf), read) assertEqual(t, str, string(bufCpy)) }, ) resp, err := c.R(). SetFileReader("name", "filename", bufReader). Put(ts.URL + "/set-reset-multipart-readers-test") assertEqual(t, 500, resp.StatusCode()) assertNil(t, err) } func TestRequestResetMultipartReaders(t *testing.T) { ts := createFileUploadServer(t) defer ts.Close() str := "test" buf := []byte(str) bufReader := bytes.NewReader(buf) bufCpy := make([]byte, len(buf)) c := dcnl(). SetTimeout(time.Second * 3). AddRetryHooks( func(response *Response, _ error) { read, err := bufReader.Read(bufCpy) assertNil(t, err) assertEqual(t, len(buf), read) assertEqual(t, str, string(bufCpy)) }, ) req := c.R(). SetRetryCount(2). SetFileReader("name", "filename", bufReader) resp, err := req.Put(ts.URL + "/set-reset-multipart-readers-test") assertEqual(t, 500, resp.StatusCode()) assertNil(t, err) } func TestParseRetryAfterHeader(t *testing.T) { testStaticTime(t) tests := []struct { name string header string sleep time.Duration ok bool }{ {"seconds", "2", time.Second * 2, true}, {"date", "Fri, 31 Dec 1999 23:59:59 GMT", time.Second * 2, true}, {"past-date", "Fri, 31 Dec 1999 23:59:00 GMT", 0, true}, {"two-headers", "3", time.Second * 3, true}, {"empty", "", 0, false}, {"negative", "-2", 0, false}, {"bad-date", "Fri, 32 Dec 1999 23:59:59 GMT", 0, false}, {"bad-date-format", "badbadbad", 0, false}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { sleep, ok := parseRetryAfterHeader(test.header) if ok != test.ok { t.Errorf("expected ok=%t, got ok=%t", test.ok, ok) } if sleep != test.sleep { t.Errorf("expected sleep=%v, got sleep=%v", test.sleep, sleep) } }) } } func TestRequestRetryTooManyRequestsHeaderRetryAfter(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl() resp, err := c.R(). SetRetryCount(2). SetHeader(hdrContentTypeKey, "application/json; charset=utf-8"). 
SetResult(AuthSuccess{}). Get(ts.URL + "/retry-after-delay") assertError(t, err) authSuccess := resp.Result().(*AuthSuccess) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, "hello", authSuccess.Message) assertNil(t, resp.ResultError()) } func TestRetryDefaultConditions(t *testing.T) { t.Run("redirect error", func(t *testing.T) { ts := createRedirectServer(t) defer ts.Close() _, err := dcnl().R(). SetRetryCount(2). Get(ts.URL + "/redirect-1") assertNotNil(t, err) assertTrue(t, (err.Error() == `Get "/redirect-11": stopped after 10 redirects`)) }) t.Run("invalid scheme error", func(t *testing.T) { ts := createGetServer(t) defer ts.Close() c := dcnl().SetBaseURL(strings.Replace(ts.URL, "http", "ftp", 1)) _, err := c.R(). SetRetryCount(2). Get("/") assertNotNil(t, err) assertTrue(t, strings.Contains(err.Error(), `unsupported protocol scheme "ftp"`), "expected unsupported protocol scheme error") }) t.Run("invalid header error", func(t *testing.T) { ts := createGetServer(t) defer ts.Close() _, err := dcnl().R(). SetRetryCount(2). SetHeader("Header-Name", "bad header value \033"). Get(ts.URL + "/") assertNotNil(t, err) assertTrue(t, strings.Contains(err.Error(), "net/http: invalid header field value"), "expected invalid header field value error") _, err = dcnl().R(). SetRetryCount(2). SetHeader("Header-Name\033", "bad header value"). Get(ts.URL + "/") assertNotNil(t, err) assertTrue(t, strings.Contains(err.Error(), "net/http: invalid header field name"), "expected invalid header field name error") }) t.Run("nil values", func(t *testing.T) { result := applyRetryDefaultConditions(nil, nil) assertFalse(t, result) }) } func TestRequestRetryPutIoReadSeekerForBuffer(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { b, err := io.ReadAll(r.Body) assertError(t, err) assertEqual(t, 12, len(b)) assertEqual(t, "body content", string(b)) w.WriteHeader(http.StatusInternalServerError) })) c := dcnl(). 
AddRetryConditions( func(r *Response, err error) bool { return err != nil || r.StatusCode() > 499 }, ). SetRetryCount(3). SetRetryAllowNonIdempotent(true) assertTrue(t, c.IsRetryAllowNonIdempotent(), "expected AllowNonIdempotentRetry to be true") buf := bytes.NewBuffer([]byte("body content")) resp, err := c.R(). SetBody(buf). SetMethodGetAllowPayload(false). Put(srv.URL) assertNil(t, err) assertEqual(t, 4, resp.Request.Attempt) assertEqual(t, http.StatusInternalServerError, resp.StatusCode()) assertEqual(t, "", resp.String()) } func TestRequestRetryPostIoReadSeeker(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { b, err := io.ReadAll(r.Body) assertError(t, err) assertEqual(t, 12, len(b)) assertEqual(t, "body content", string(b)) w.WriteHeader(http.StatusInternalServerError) })) c := dcnl(). AddRetryConditions( func(r *Response, err error) bool { return err != nil || r.StatusCode() > 499 }, ). SetRetryCount(3). SetRetryAllowNonIdempotent(false) assertFalse(t, c.IsRetryAllowNonIdempotent()) resp, err := c.R(). SetBody([]byte("body content")). SetRetryAllowNonIdempotent(true). Post(srv.URL) assertNil(t, err) assertEqual(t, 4, resp.Request.Attempt) assertEqual(t, http.StatusInternalServerError, resp.StatusCode()) assertEqual(t, "", resp.String()) } func TestRequestRetryHooks(t *testing.T) { ts := createGetServer(t) defer ts.Close() hookFunc := func(msg string) RetryHookFunc { return func(res *Response, err error) { res.Request.log.Debugf(msg) } } c, lb := dcldb() c.AddRetryConditions(func(r *Response, err error) bool { return true }). AddRetryHooks( hookFunc("This is client hook1"), hookFunc("This is client hook2"), ) _, _ = c.R(). SetRetryCount(1). AddRetryHooks(hookFunc("This is request hook1")). SetRetryHooks(hookFunc("This is request overwrite hook1")). 
Get("/set-retrycount-test") debugLog := lb.String() assertFalse(t, strings.Contains(debugLog, "This is client hook1")) assertFalse(t, strings.Contains(debugLog, "This is client hook2")) assertFalse(t, strings.Contains(debugLog, "This is request hook1")) assertTrue(t, strings.Contains(debugLog, "This is request overwrite hook1"), "expected to find request overwrite hook log") } func TestRequestSetRetryConditions(t *testing.T) { ts := createGetServer(t) defer ts.Close() condFunc := func(fn func() bool) RetryConditionFunc { return func(r *Response, err error) bool { return fn() } } c := dcnl(). AddRetryConditions( condFunc(func() bool { return true }), condFunc(func() bool { return true }), ) res, _ := c.R(). SetRetryCount(2). SetRetryConditions(condFunc(func() bool { return false })). // disable retry with overwrite condition Get("/set-retrycount-test") assertEqual(t, 1, res.Request.Attempt) } func TestRequestRetryQueryParamsGH938(t *testing.T) { ts := createGetServer(t) defer ts.Close() expectedQueryParams := "foo=baz&foo=bar&foo=bar" c := dcnl(). SetBaseURL(ts.URL). SetRetryCount(5). SetRetryWaitTime(10 * time.Millisecond). SetRetryMaxWaitTime(20 * time.Millisecond). AddRetryConditions( func(r *Response, _ error) bool { assertEqual(t, expectedQueryParams, r.Request.RawRequest.URL.RawQuery) return true // always retry }, ) _, _ = c.R(). SetQueryParamsFromValues(map[string][]string{ "foo": { "baz", "bar", "bar", }, }). 
Get("/set-retrycount-test") } func TestRetryConstantDelayStrategyReturnsGivenDelay(t *testing.T) { d := 250 * time.Millisecond strat := RetryConstantDelayStrategy(d) got, err := strat(nil, nil) assertNil(t, err) assertEqual(t, d, got) } func TestRetryConstantDelayStrategyZeroAndNegative(t *testing.T) { // zero duration strategyZero := RetryConstantDelayStrategy(0) d, err := strategyZero(nil, nil) assertNil(t, err) assertEqual(t, time.Duration(0), d) // negative duration (function should faithfully return what was provided) neg := -5 * time.Second strategyNeg := RetryConstantDelayStrategy(neg) d, err = strategyNeg(nil, nil) assertNil(t, err) assertEqual(t, neg, d) } func TestRetryConstantDelayUsingMinAndMaxWaitTime(t *testing.T) { ts := createGetServer(t) defer ts.Close() retryCount := 10 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times to constant delay constantDelay := 20 * time.Millisecond c := dcnl(). SetRetryCount(retryCount). SetRetryWaitTime(constantDelay). SetRetryMaxWaitTime(constantDelay). AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) res, _ := c.R(). 
Get(ts.URL + "/set-retrywaittime-test") retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) assertNil(t, c.RetryDelayStrategy()) // retryCount+1 == attempts were made assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) for i := 1; i < len(retryIntervals); i++ { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests if slept < constantDelay-5*time.Millisecond { t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), constantDelay.Seconds(), i) } if constantDelay+5*time.Millisecond < slept { t.Logf("Client has slept %f seconds which is max < s (%f) before retry %d", slept.Seconds(), constantDelay.Seconds(), i) } } } func TestRetryConstantDelayUsingStrategy(t *testing.T) { ts := createGetServer(t) defer ts.Close() retryCount := 10 retryIntervals := make([]uint64, retryCount+1) // Set retry wait times to constant delay constantDelay := 20 * time.Millisecond c := dcnl(). SetRetryCount(retryCount). AddRetryConditions( func(r *Response, _ error) bool { retryIntervals[r.Request.Attempt-1] = parseTimeSleptFromResponse(r.String()) return true }, ) res, _ := c.R(). SetRetryDelayStrategy(RetryConstantDelayStrategy(constantDelay)). 
Get(ts.URL + "/set-retrywaittime-test") retryIntervals[res.Request.Attempt-1] = parseTimeSleptFromResponse(res.String()) assertNil(t, c.RetryDelayStrategy()) // retryCount+1 == attempts were made assertEqual(t, retryCount+1, res.Request.Attempt) // Initial attempt has 0 time slept since last request assertEqual(t, retryIntervals[0], uint64(0)) for i := 1; i < len(retryIntervals); i++ { slept := time.Duration(retryIntervals[i]) // Ensure that client has slept some duration between // waitTime and maxWaitTime for consequent requests if slept < constantDelay-5*time.Millisecond { t.Logf("Client has slept %f seconds which is s < min (%f) before retry %d", slept.Seconds(), constantDelay.Seconds(), i) } if constantDelay+5*time.Millisecond < slept { t.Logf("Client has slept %f seconds which is max < s (%f) before retry %d", slept.Seconds(), constantDelay.Seconds(), i) } } } func TestRetryCoverage(t *testing.T) { t.Run("apply retry default min and max value", func(t *testing.T) { backoff := newBackoffWithJitter(0, 0) assertEqual(t, defaultWaitTime, backoff.min) assertEqual(t, defaultMaxWaitTime, backoff.max) dur1 := backoff.balanceMinMax(0) assertEqual(t, 2*time.Second, dur1) dur2 := backoff.balanceMinMax(4 * time.Second) assertEqual(t, 2*time.Second, dur2) }) t.Run("mock tls cert error", func(t *testing.T) { certError := tls.CertificateVerificationError{} result1 := applyRetryDefaultConditions(nil, &certError) assertFalse(t, result1, "expected no retry for tls.CertificateVerificationError") }) } func parseTimeSleptFromResponse(v string) uint64 { timeSlept, _ := strconv.ParseUint(v, 10, 64) return timeSlept } func testStaticTime(t *testing.T) { timeNow = func() time.Time { now, err := time.Parse(time.RFC1123, "Fri, 31 Dec 1999 23:59:57 GMT") if err != nil { panic(err) } return now } t.Cleanup(func() { timeNow = time.Now }) } ================================================ FILE: sse.go ================================================ // Copyright (c) 2015-present 
Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "bufio" "bytes" "crypto/tls" "errors" "fmt" "io" "net/http" "slices" "strconv" "strings" "sync" "time" ) // Spec: https://html.spec.whatwg.org/multipage/server-sent-events.html var ( defaultSseMaxBufSize = 1 << 15 // 32kb defaultEventName = "message" defaultHTTPMethod = MethodGet headerID = []byte("id:") headerData = []byte("data:") headerEvent = []byte("event:") headerRetry = []byte("retry:") hdrCacheControlKey = http.CanonicalHeaderKey("Cache-Control") hdrConnectionKey = http.CanonicalHeaderKey("Connection") hdrLastEvevntID = http.CanonicalHeaderKey("Last-Event-ID") ) type ( // SSEOpenFunc is a callback function type used to receive notification // when Resty establishes a connection with the server for the // Server-Sent Events(SSE) SSEOpenFunc func(url string, respHdr http.Header) // SSEMessageFunc is a callback function type used to receive event details // from the Server-Sent Events(SSE) stream SSEMessageFunc func(any) // SSEErrorFunc is a callback function type used to receive notification // when an error occurs with [SSESource] processing SSEErrorFunc func(error) // SSERequestFailureFunc is a callback function type used to receive event // details from the Server-Sent Events(SSE) request failure SSERequestFailureFunc func(err error, res *http.Response) // SSE struct represents the event details from the Server-Sent Events(SSE) stream SSE struct { ID string Name string Data string } // SSESource struct implements the Server-Sent Events(SSE) [specification] to receive // stream from the server // // [specification]: https://html.spec.whatwg.org/multipage/server-sent-events.html SSESource struct { lock *sync.RWMutex url string method string header http.Header bodyBytes []byte lastEventID string retryCount int retryWaitTime time.Duration 
retryMaxWaitTime time.Duration serverSentRetry time.Duration maxBufSize int onOpen SSEOpenFunc onError SSEErrorFunc onRequestFailure SSERequestFailureFunc onEvent map[string]*callback log Logger closed bool httpClient *http.Client } callback struct { Func SSEMessageFunc Result any } ) // NewSSESource method creates a new instance of [SSESource] // with default values for Server-Sent Events(SSE) // // sse := NewSSESource(). // SetURL("https://sse.dev/test"). // OnMessage( // func(e any) { // event := e.(*resty.SSE) // fmt.Println(event) // }, // nil, // see method godoc // ) // // err := sse.Connect() // fmt.Println(err) // // See [SSESource.OnMessage], [SSESource.AddEventListener] func NewSSESource() *SSESource { sse := &SSESource{ lock: new(sync.RWMutex), header: make(http.Header), retryCount: 3, retryWaitTime: defaultWaitTime, retryMaxWaitTime: defaultMaxWaitTime, maxBufSize: defaultSseMaxBufSize, onEvent: make(map[string]*callback), httpClient: &http.Client{ Jar: createCookieJar(), Transport: createTransport(nil, nil), }, } return sse } // SetURL method sets a [SSESource] connection URL in the instance // // sse.SetURL("https://sse.dev/test") func (sse *SSESource) SetURL(url string) *SSESource { sse.url = url return sse } // SetMethod method sets a [SSESource] connection HTTP method in the instance // // sse.SetMethod("POST"), or sse.SetMethod(resty.MethodPost) func (sse *SSESource) SetMethod(method string) *SSESource { sse.method = method return sse } // SetHeader method sets a header and its value to the [SSESource] instance. // It overwrites the header value if the key already exists. These headers will be sent in // the request while establishing a connection to the event source // // sse.SetHeader("Authorization", "token here"). 
// SetHeader("X-Header", "value") func (sse *SSESource) SetHeader(header, value string) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() sse.header.Set(header, value) return sse } // SetBody method sets body value to the [SSESource] instance // // Example: // sse.SetBody(bytes.NewReader([]byte(`{"test":"put_data"}`))) func (sse *SSESource) SetBody(body io.Reader) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() if body == nil { sse.bodyBytes = nil return sse } sse.bodyBytes = nil bodyBytes, err := ioReadAll(body) if err != nil { sse.log.Errorf("resty:sse: unable to read body, error: %v", err) return sse } sse.bodyBytes = bodyBytes return sse } // TLSClientConfig method returns the [tls.Config] from underlying client transport // otherwise returns nil func (sse *SSESource) TLSClientConfig() *tls.Config { cfg, err := sse.tlsConfig() if err != nil { sse.Logger().Errorf("%v", err) } return cfg } // SetTLSClientConfig method sets TLSClientConfig for underlying client Transport. // // Values supported by https://pkg.go.dev/crypto/tls#Config can be configured. 
// // // Disable SSL cert verification for local development // sse.SetTLSClientConfig(&tls.Config{ // InsecureSkipVerify: true // }) // // NOTE: This method overwrites existing [http.Transport.TLSClientConfig] func (sse *SSESource) SetTLSClientConfig(tlsConfig *tls.Config) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() // TLSClientConfiger interface handling if tc, ok := sse.httpClient.Transport.(TLSClientConfiger); ok { if err := tc.SetTLSClientConfig(tlsConfig); err != nil { sse.log.Errorf("%v", err) } return sse } // default standard transport handling if transport, ok := sse.httpClient.Transport.(*http.Transport); ok { transport.TLSClientConfig = tlsConfig } return sse } // getting TLS client config if not exists then create one func (sse *SSESource) tlsConfig() (*tls.Config, error) { sse.lock.Lock() defer sse.lock.Unlock() if tc, ok := sse.httpClient.Transport.(TLSClientConfiger); ok { return tc.TLSClientConfig(), nil } transport, ok := sse.httpClient.Transport.(*http.Transport) if !ok { return nil, ErrNotHttpTransportType } if transport.TLSClientConfig == nil { transport.TLSClientConfig = &tls.Config{} } return transport.TLSClientConfig, nil } // AddHeader method adds a header and its value to the [SSESource] instance. // If the header key already exists, it appends. These headers will be sent in // the request while establishing a connection to the event source // // sse.AddHeader("Authorization", "token here"). 
// AddHeader("X-Header", "value") func (sse *SSESource) AddHeader(header, value string) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() sse.header.Add(header, value) return sse } // SetRetryCount method enables retry attempts on the SSE client while establishing // connection with the server // // first attempt + retry count = total attempts // // Default is 3 // // sse.SetRetryCount(10) func (sse *SSESource) SetRetryCount(count int) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() sse.retryCount = count return sse } // SetRetryWaitTime method sets the default wait time for sleep before retrying // the request // // Default is 100 milliseconds. // // NOTE: The server-sent retry value takes precedence if present. // // sse.SetRetryWaitTime(1 * time.Second) func (sse *SSESource) SetRetryWaitTime(waitTime time.Duration) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() sse.retryWaitTime = waitTime return sse } // SetRetryMaxWaitTime method sets the max wait time for sleep before retrying // the request // // Default is 2 seconds. // // NOTE: The server-sent retry value takes precedence if present. // // sse.SetRetryMaxWaitTime(3 * time.Second) func (sse *SSESource) SetRetryMaxWaitTime(maxWaitTime time.Duration) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() sse.retryMaxWaitTime = maxWaitTime return sse } // SetSizeMaxBuffer method sets the given buffer size into the SSE client // // Default is 32kb // // sse.SetSizeMaxBuffer(64 * 1024) // 64kb func (sse *SSESource) SetSizeMaxBuffer(bufSize int) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() sse.maxBufSize = bufSize return sse } // Logger method returns the logger instance used by the event source instance. 
func (sse *SSESource) Logger() Logger { sse.lock.RLock() defer sse.lock.RUnlock() return sse.log } // SetLogger method sets given writer for logging // // Compliant to interface [resty.Logger] func (sse *SSESource) SetLogger(l Logger) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() sse.log = l return sse } // just an internal helper method for test case func (sse *SSESource) outputLogTo(w io.Writer) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() sse.log.(*logger).l.SetOutput(w) return sse } // OnOpen registered callback gets triggered when the connection is // established with the server // // sse.OnOpen(func(url string, resHdr http.Header) { // fmt.Println("I'm connected:", url, resHdr) // }) func (sse *SSESource) OnOpen(ef SSEOpenFunc) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() if sse.onOpen != nil { sse.log.Warnf("Overwriting an existing OnOpen callback from=%s to=%s", functionName(sse.onOpen), functionName(ef)) } sse.onOpen = ef return sse } // OnError registered callback gets triggered when the error occurred // in the process // // sse.OnError(func(err error) { // fmt.Println("Error occurred:", err) // }) func (sse *SSESource) OnError(ef SSEErrorFunc) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() if sse.onError != nil { sse.log.Warnf("Overwriting an existing OnError callback from=%s to=%s", functionName(sse.onError), functionName(ef)) } sse.onError = ef return sse } // OnRequestFailure registered callback gets triggered when the HTTP request // failure while establishing a SSE connection. // // sse.OnRequestFailure(func(err error, res *http.Response) { // fmt.Println("Error and response:", err, res) // }) // // NOTE: // - Do not forget to close the HTTP response body. // - HTTP response may be nil. 
func (sse *SSESource) OnRequestFailure(ef SSERequestFailureFunc) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() if sse.onRequestFailure != nil { sse.log.Warnf("Overwriting an existing OnRequestFailure callback from=%s to=%s", functionName(sse.onRequestFailure), functionName(ef)) } sse.onRequestFailure = ef return sse } // OnMessage method registers a callback to emit every SSE event message // from the server. The second result argument is optional; it can be used // to register the data type for JSON data. // // sse.OnMessage( // func(e any) { // event := e.(*resty.SSE) // fmt.Println("Event message", event) // }, // nil, // ) // // // Receiving JSON data from the server, you can set result type // // to do auto-unmarshal // sse.OnMessage( // func(e any) { // event := e.(*MyData) // fmt.Println(event) // }, // MyData{}, // ) func (sse *SSESource) OnMessage(ef SSEMessageFunc, result any) *SSESource { return sse.AddEventListener(defaultEventName, ef, result) } // AddEventListener method registers a callback to consume a specific event type // messages from the server. The second result argument is optional; it can be used // to register the data type for JSON data. 
// // sse.AddEventListener( // "friend_logged_in", // func(e any) { // event := e.(*resty.SSE) // fmt.Println(event) // }, // nil, // ) // // // Receiving JSON data from the server, you can set result type // // to do auto-unmarshal // sse.AddEventListener( // "friend_logged_in", // func(e any) { // event := e.(*UserLoggedIn) // fmt.Println(event) // }, // UserLoggedIn{}, // ) func (sse *SSESource) AddEventListener(eventName string, ef SSEMessageFunc, result any) *SSESource { sse.lock.Lock() defer sse.lock.Unlock() if e, found := sse.onEvent[eventName]; found { sse.log.Warnf("Overwriting an existing OnEvent callback from=%s to=%s", functionName(e), functionName(ef)) } cb := &callback{Func: ef, Result: nil} if result != nil { cb.Result = getPointer(result) } sse.onEvent[eventName] = cb return sse } // Get method establishes the connection with the server. // // sse := NewSSE(). // SetURL("https://sse.dev/test"). // OnMessage( // func(e any) { // event := e.(*resty.SSE) // fmt.Println(event) // }, // nil, // see method godoc // ) // // err := sse.Get() // fmt.Println(err) func (sse *SSESource) Get() error { // Validate required values if isStringEmpty(sse.url) { return fmt.Errorf("resty:sse: event source URL is required") } if isStringEmpty(sse.method) { // It is up to the user to choose which http method to use, depending on the specific code implementation. No restrictions are imposed here. 
// Ensure compatibility, use GET as default http method sse.method = defaultHTTPMethod } if len(sse.onEvent) == 0 { return fmt.Errorf("resty:sse: At least one OnMessage/AddEventListener func is required") } // reset to begin sse.enableConnect() for { if sse.isClosed() { return nil } res, err := sse.connect() if err != nil { return err } sse.triggerOnOpen(res.Header.Clone()) if err := sse.listenStream(res); err != nil { return err } } } // Close method used to close SSE connection explicitly func (sse *SSESource) Close() { sse.lock.Lock() defer sse.lock.Unlock() sse.closed = true } func (sse *SSESource) enableConnect() { sse.lock.Lock() defer sse.lock.Unlock() sse.closed = false } func (sse *SSESource) isClosed() bool { sse.lock.RLock() defer sse.lock.RUnlock() return sse.closed } func (sse *SSESource) triggerOnOpen(hdr http.Header) { sse.lock.RLock() defer sse.lock.RUnlock() if sse.onOpen != nil { sse.onOpen(strings.Clone(sse.url), hdr) } } func (sse *SSESource) triggerOnError(err error) { sse.lock.RLock() defer sse.lock.RUnlock() if sse.onError != nil { sse.onError(err) } } func (sse *SSESource) triggerOnRequestFailure(err error, res *http.Response) { sse.lock.RLock() defer sse.lock.RUnlock() if sse.onRequestFailure != nil { sse.onRequestFailure(err, res) } } func (sse *SSESource) createRequest() (*http.Request, error) { var reqBody io.Reader if sse.bodyBytes != nil { // create reader from bytes on each request reqBody = bytes.NewReader(sse.bodyBytes) } req, err := http.NewRequest(sse.method, sse.url, reqBody) if err != nil { return nil, err } req.Header = sse.header.Clone() req.Header.Set(hdrAcceptKey, "text/event-stream") req.Header.Set(hdrCacheControlKey, "no-cache") req.Header.Set(hdrConnectionKey, "keep-alive") if len(sse.lastEventID) > 0 { req.Header.Set(hdrLastEvevntID, sse.lastEventID) } return req, nil } func (sse *SSESource) connect() (*http.Response, error) { sse.lock.RLock() defer sse.lock.RUnlock() var backoff *backoffWithJitter if sse.serverSentRetry 
> 0 { backoff = newBackoffWithJitter(sse.serverSentRetry, sse.serverSentRetry) } else { backoff = newBackoffWithJitter(sse.retryWaitTime, sse.retryMaxWaitTime) } var ( err error attempt int ) for i := 0; i <= sse.retryCount; i++ { attempt++ req, reqErr := sse.createRequest() if reqErr != nil { err = reqErr break } resp, doErr := sse.httpClient.Do(req) if resp != nil && resp.StatusCode == http.StatusOK { // successful connection, return response to listenStream return resp, nil } // we have reached the maximum no. of requests // first attempt + retry count = total attempts if attempt-1 == sse.retryCount { err = doErr break } rRes := wrapResponse(resp, req) needsRetry := applyRetryDefaultConditions(rRes, doErr) // retry not required stop here if !needsRetry { if rRes != nil { err = wrapErrors(fmt.Errorf("resty:sse: %v", rRes.Status()), doErr) } else { err = doErr } if err != nil { sse.triggerOnRequestFailure(err, resp) } break } // let's drain the response body, before retry wait drainBody(rRes) waitDuration, _ := backoff.NextWaitDuration(nil, rRes, doErr, attempt) timer := time.NewTimer(waitDuration) <-timer.C timer.Stop() } if err != nil { return nil, err } return nil, fmt.Errorf("resty:sse: unable to connect stream") } func (sse *SSESource) listenStream(res *http.Response) error { defer closeq(res.Body) scanner := bufio.NewScanner(res.Body) scanner.Buffer(make([]byte, slices.Min([]int{4096, sse.maxBufSize})), sse.maxBufSize) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil } if i := bytes.Index(data, []byte{'\n', '\n'}); i >= 0 { // We have a full double newline-terminated line. return i + 1, data[0:i], nil } // If we're at EOF, we have a final, non-terminated line. Return it. if atEOF { return len(data), data, nil } // Request more data. 
return 0, nil, nil }) for { if sse.isClosed() { return nil } if err := sse.processEvent(scanner); err != nil { return err } } } func (sse *SSESource) processEvent(scanner *bufio.Scanner) error { e, err := readEvent(scanner) if err != nil { if err == io.EOF { return err } sse.triggerOnError(err) return err } ed, err := parseEvent(e) if err != nil { sse.triggerOnError(err) return nil // parsing errors, will not return error. } defer putRawEvent(ed) if len(ed.ID) > 0 { sse.lock.Lock() sse.lastEventID = string(ed.ID) sse.lock.Unlock() } if len(ed.Retry) > 0 { if retry, err := strconv.Atoi(string(ed.Retry)); err == nil { sse.lock.Lock() sse.serverSentRetry = time.Millisecond * time.Duration(retry) sse.lock.Unlock() } else { sse.triggerOnError(err) } } if len(ed.Data) > 0 { sse.handleCallback(&SSE{ ID: string(ed.ID), Name: string(ed.Event), Data: string(ed.Data), }) } return nil } func (sse *SSESource) handleCallback(e *SSE) { eventName := e.Name if len(eventName) == 0 { eventName = defaultEventName } sse.lock.RLock() cb, found := sse.onEvent[eventName] sse.lock.RUnlock() if found { if cb.Result == nil { cb.Func(e) return } r := newInterface(cb.Result) if err := decodeJSON(strings.NewReader(e.Data), r); err != nil { sse.triggerOnError(err) return } cb.Func(r) } } var readEvent = readEventFunc func readEventFunc(scanner *bufio.Scanner) ([]byte, error) { if scanner.Scan() { event := scanner.Bytes() return event, nil } if err := scanner.Err(); err != nil { return nil, err } return nil, io.EOF } func wrapResponse(res *http.Response, req *http.Request) *Response { if res == nil { return nil } return &Response{RawResponse: res, Request: &Request{RawRequest: req}} } type rawSSE struct { ID []byte Data []byte Event []byte Retry []byte } var parseEvent = parseEventFunc // event value parsing logic obtained and modified for Resty processing flow. 
// https://github.com/r3labs/sse/blob/c6d5381ee3ca63828b321c16baa008fd6c0b4564/client.go#L322 func parseEventFunc(msg []byte) (*rawSSE, error) { if len(msg) < 1 { return nil, errors.New("resty:sse: event message was empty") } e := newRawEvent() // Split the line by "\n" for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' }) { switch { case bytes.HasPrefix(line, headerID): e.ID = append([]byte(nil), trimHeader(len(headerID), line)...) case bytes.HasPrefix(line, headerData): // The spec allows for multiple data fields per event, concatenated them with "\n" e.Data = append(e.Data[:], append(trimHeader(len(headerData), line), byte('\n'))...) // The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body. case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))): e.Data = append(e.Data, byte('\n')) case bytes.HasPrefix(line, headerEvent): e.Event = append([]byte(nil), trimHeader(len(headerEvent), line)...) case bytes.HasPrefix(line, headerRetry): e.Retry = append([]byte(nil), trimHeader(len(headerRetry), line)...) 
default: // Ignore anything that doesn't match the header } } // Trim the last "\n" per the spec e.Data = bytes.TrimSuffix(e.Data, []byte("\n")) return e, nil } func trimHeader(size int, data []byte) []byte { if data == nil || len(data) < size { return data } data = data[size:] if len(data) > 0 && data[0] == ' ' { data = data[1:] } if len(data) > 0 && data[len(data)-1] == '\n' { data = data[:len(data)-1] } return data } var rawEventPool = &sync.Pool{New: func() any { return new(rawSSE) }} func newRawEvent() *rawSSE { e := rawEventPool.Get().(*rawSSE) e.ID = e.ID[:0] e.Data = e.Data[:0] e.Event = e.Event[:0] e.Retry = e.Retry[:0] return e } func putRawEvent(e *rawSSE) { rawEventPool.Put(e) } ================================================ FILE: sse_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. 
// SPDX-License-Identifier: MIT package resty import ( "bufio" "bytes" "crypto/tls" "errors" "fmt" "io" "net/http" "net/http/httptest" "strconv" "strings" "testing" "time" ) func TestSSESourceSimpleFlow(t *testing.T) { es := createSSESource(t, "", nil, nil) messageCounter := 0 messageFunc := func(e any) { event := e.(*SSE) assertEqual(t, strconv.Itoa(messageCounter), event.ID) assertTrue(t, strings.HasPrefix(event.Data, "The time is")) messageCounter++ if messageCounter == 100 { es.Close() } } es.OnMessage(messageFunc, nil) counter := 0 ts := createSSETestServer( t, 10*time.Millisecond, func(w io.Writer) error { if counter == 100 { return fmt.Errorf("stop sending events") } _, err := fmt.Fprintf(w, "id: %v\ndata: The time is %s\n\n", counter, time.Now().Format(time.UnixDate)) counter++ return err }, ) defer ts.Close() es.SetURL(ts.URL) es.SetMethod(MethodPost) err := es.Get() assertNil(t, err) assertEqual(t, counter, messageCounter) } func TestSSESourceMultipleEventTypes(t *testing.T) { type userEvent struct { UserName string `json:"username"` Message string `json:"msg"` Time time.Time `json:"time"` } tm := time.Now().Add(-1 * time.Minute) userConnectCounter := 0 userConnectFunc := func(e any) { data := e.(*userEvent) assertEqual(t, "username"+strconv.Itoa(userConnectCounter), data.UserName) assertTrue(t, data.Time.After(tm)) userConnectCounter++ } userMessageCounter := 0 userMessageFunc := func(e any) { data := e.(*userEvent) assertEqual(t, "username"+strconv.Itoa(userConnectCounter), data.UserName) assertEqual(t, "Hello, how are you?", data.Message) assertTrue(t, data.Time.After(tm)) userMessageCounter++ } counter := 0 es := createSSESource(t, "", func(any) {}, nil) ts := createSSETestServer( t, 10*time.Millisecond, func(w io.Writer) error { if counter == 100 { es.Close() return fmt.Errorf("stop sending events") } id := counter / 2 if counter%2 == 0 { event := fmt.Sprintf("id: %v\n"+ "event: user_message\n"+ `data: {"username": "%v", "time": "%v", "msg": "Hello, 
how are you?"}`+"\n\n", id, "username"+strconv.Itoa(id), time.Now().Format(time.RFC3339), ) fmt.Fprint(w, event) } else { event := fmt.Sprintf("id: %v\n"+ "event: user_connect\n"+ `data: {"username": "%v", "time": "%v"}`+"\n\n", int(id), "username"+strconv.Itoa(int(id)), time.Now().Format(time.RFC3339), ) fmt.Fprint(w, event) } counter++ return nil }, ) defer ts.Close() es.SetURL(ts.URL). SetMethod(MethodPost). AddEventListener("user_connect", userConnectFunc, userEvent{}). AddEventListener("user_message", userMessageFunc, userEvent{}) err := es.Get() assertNil(t, err) assertEqual(t, userConnectCounter, userMessageCounter) } func TestSSESourceOverwriteFuncs(t *testing.T) { messageFunc1 := func(e any) { assertNotNil(t, e) } es := createSSESource(t, "", messageFunc1, nil) message2Counter := 0 messageFunc2 := func(e any) { event := e.(*SSE) assertEqual(t, strconv.Itoa(message2Counter), event.ID) assertTrue(t, strings.HasPrefix(event.Data, "The time is")) message2Counter++ if message2Counter == 50 { es.Close() } } counter := 0 ts := createSSETestServer( t, 10*time.Millisecond, func(w io.Writer) error { if counter == 50 { return fmt.Errorf("stop sending events") } _, err := fmt.Fprintf(w, "id: %v\ndata: The time is %s\n\n", counter, time.Now().Format(time.UnixDate)) counter++ return err }, ) defer ts.Close() lb := new(bytes.Buffer) es.outputLogTo(lb) es.SetURL(ts.URL). OnMessage(messageFunc2, nil). OnOpen(func(url string, respHdr http.Header) { t.Log("from overwrite func", url, respHdr) }). 
OnError(func(err error) { t.Log("from overwrite func", err) }) err := es.Get() assertNil(t, err) assertEqual(t, counter, message2Counter) logLines := lb.String() assertTrue(t, strings.Contains(logLines, "Overwriting an existing OnEvent callback")) assertTrue(t, strings.Contains(logLines, "Overwriting an existing OnOpen callback")) assertTrue(t, strings.Contains(logLines, "Overwriting an existing OnError callback")) } func TestSSESourceRetry(t *testing.T) { es := createSSESource(t, "", nil, nil) messageCounter := 2 // 0 & 1 connection failure messageFunc := func(e any) { event := e.(*SSE) assertEqual(t, strconv.Itoa(messageCounter), event.ID) assertTrue(t, strings.HasPrefix(event.Data, "The time is")) messageCounter++ if messageCounter == 15 { es.Close() } } es.OnMessage(messageFunc, nil) counter := 0 ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { if counter == 1 && r.URL.Query().Get("reconnect") == "1" { w.WriteHeader(http.StatusTooManyRequests) counter++ return } if counter < 2 || counter == 7 { w.WriteHeader(http.StatusTooManyRequests) counter++ return } w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") // for local testing allow it w.Header().Set("Access-Control-Allow-Origin", "*") // Create a channel for client disconnection clientGone := r.Context().Done() rc := http.NewResponseController(w) tick := time.NewTicker(10 * time.Millisecond) defer tick.Stop() for { select { case <-clientGone: t.Log("Client disconnected") return case <-tick.C: if counter == 5 { fmt.Fprintf(w, "id: %v\nretry: abc\ndata: The time is %s\n\n", counter, time.Now().Format(time.UnixDate)) counter++ return } if counter == 15 { es.Close() return // stop sending events } fmt.Fprintf(w, "id: %v\nretry: 1\ndata: The time is %s\ndata\n\n", counter, time.Now().Format(time.UnixDate)) counter++ if err := rc.Flush(); err != nil { t.Log(err) return } } } }) defer ts.Close() // first round 
es.SetURL(ts.URL) err1 := es.Get() assertNotNil(t, err1) // second round counter = 0 messageCounter = 2 es.SetRetryCount(1). SetURL(ts.URL + "?reconnect=1") err2 := es.Get() assertNotNil(t, err2) } func TestSSESourceRetryReusesRequestBody(t *testing.T) { const payload = `{"test":"retry-body"}` attempt := 0 bodies := make([]string, 0, 2) ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) assertNil(t, err) bodies = append(bodies, string(body)) attempt++ if attempt == 1 { w.WriteHeader(http.StatusTooManyRequests) return } w.WriteHeader(http.StatusOK) }) defer ts.Close() es := NewSSESource(). SetURL(ts.URL). SetMethod(MethodPost). SetRetryCount(1). SetRetryWaitTime(1 * time.Millisecond). SetRetryMaxWaitTime(1 * time.Millisecond) es.SetBody(bytes.NewBufferString(payload)) resp, err := es.connect() assertNil(t, err) assertNotNil(t, resp) if resp != nil { closeq(resp.Body) } assertEqual(t, 2, attempt, "expected one retry attempt") assertEqual(t, 2, len(bodies), "expected request body on both attempts") assertEqual(t, payload, bodies[0], "expected first attempt body to match") assertEqual(t, payload, bodies[1], "expected retry attempt body to match") } func TestSSESourceTLSConfigerInterface(t *testing.T) { t.Run("set and get tls config", func(t *testing.T) { es := createSSESource(t, "", func(any) {}, nil) tc, err := es.tlsConfig() assertNil(t, err) assertNotNil(t, tc) tlsConfig := &tls.Config{InsecureSkipVerify: true} es.SetTLSClientConfig(tlsConfig) assertEqual(t, tlsConfig, es.TLSClientConfig()) }) t.Run("get tls config error", func(t *testing.T) { es := createSSESource(t, "", func(any) {}, nil) ct := &CustomRoundTripper1{} es.httpClient.Transport = ct assertNil(t, es.TLSClientConfig()) }) t.Run("set tls config", func(t *testing.T) { es := createSSESource(t, "", func(any) {}, nil) ct := &CustomRoundTripper2{} es.httpClient.Transport = ct tlsConfig := &tls.Config{InsecureSkipVerify: true} es.SetTLSClientConfig(tlsConfig) 
assertNotNil(t, es.TLSClientConfig()) }) t.Run("set tls config error", func(t *testing.T) { es := createSSESource(t, "", func(any) {}, nil) ct := &CustomRoundTripper2{returnErr: true} es.httpClient.Transport = ct tlsConfig := &tls.Config{InsecureSkipVerify: true} es.SetTLSClientConfig(tlsConfig) assertNil(t, es.TLSClientConfig()) }) } func TestSSESourceNoRetryRequired(t *testing.T) { es := createSSESource(t, "", func(any) {}, nil) ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) }) defer ts.Close() es.SetURL(ts.URL) err := es.Get() fmt.Println(err) assertTrue(t, strings.Contains(err.Error(), "400 Bad Request")) } func TestGH1044TrimHeader(t *testing.T) { t.Run("data is nil", func(t *testing.T) { result := trimHeader(0, nil) assertNil(t, result) }) t.Run("data has double whitespace", func(t *testing.T) { data := []byte("data: double whitespace message") result := trimHeader(5, data) assertTrue(t, result[0] == ' ') }) t.Run("data has newline", func(t *testing.T) { data := []byte("data: newline message\n") result := trimHeader(5, data) assertTrue(t, result[len(result)-1] != '\n') }) } func TestGH1041RequestFailureWithResponseBody(t *testing.T) { es := createSSESource(t, "", func(any) {}, nil) ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { w.Header().Set(hdrContentTypeKey, jsonContentType) w.WriteHeader(http.StatusBadRequest) _, _ = w.Write([]byte(`{ "id": "bad_request", "message": "Unable to establish connection" }`)) }) defer ts.Close() rfFunc := func(err error, res *http.Response) { defer res.Body.Close() resBytes, _ := io.ReadAll(res.Body) assertNotNil(t, err) assertEqual(t, "resty:sse: 400 Bad Request", err.Error()) assertEqual(t, `{ "id": "bad_request", "message": "Unable to establish connection" }`, string(resBytes)) } es.SetURL(ts.URL).OnRequestFailure(rfFunc) es.OnRequestFailure(rfFunc) err := es.Get() assertNotNil(t, err) assertEqual(t, "resty:sse: 400 Bad Request", err.Error()) 
} func TestSSESourceHTTPError(t *testing.T) { es := createSSESource(t, "", func(any) {}, nil) ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "http://local host", http.StatusTemporaryRedirect) }) defer ts.Close() es.SetURL(ts.URL) err := es.Get() assertTrue(t, strings.Contains(err.Error(), `invalid character " " in host name`)) } func TestSSESourceParseAndReadError(t *testing.T) { type data struct{} counter := 0 es := createSSESource(t, "", func(any) {}, data{}) ts := createSSETestServer( t, 5*time.Millisecond, func(w io.Writer) error { if counter == 5 { es.Close() return fmt.Errorf("stop sending events") } _, err := fmt.Fprintf(w, "id: %v\n"+ `data: The time is %s\n\n`+"\n\n", counter, time.Now().Format(time.UnixDate)) counter++ return err }, ) defer ts.Close() es.SetURL(ts.URL) err := es.Get() assertNil(t, err) // parse error parseEvent = func(_ []byte) (*rawSSE, error) { return nil, errors.New("test error") } counter = 0 err = es.Get() assertNil(t, err) t.Cleanup(func() { parseEvent = parseEventFunc }) } func TestSSESourceReadError(t *testing.T) { es := createSSESource(t, "", func(any) {}, nil) ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) defer ts.Close() // read error readEvent = func(_ *bufio.Scanner) ([]byte, error) { return nil, errors.New("read event test error") } t.Cleanup(func() { readEvent = readEventFunc }) es.SetURL(ts.URL) err := es.Get() assertNotNil(t, err) assertTrue(t, strings.Contains(err.Error(), "read event test error")) } func TestSSESourceWithDifferentMethods(t *testing.T) { testCases := []struct { name string method string body []byte }{ { name: "GET Method", method: MethodGet, body: nil, }, { name: "POST Method", method: MethodPost, body: []byte(`{"test":"post_data"}`), }, { name: "PUT Method", method: MethodPut, body: []byte(`{"test":"put_data"}`), }, { name: "DELETE Method", method: MethodDelete, body: nil, }, { name: "PATCH Method", 
method: MethodPatch, body: []byte(`{"test":"patch_data"}`), }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { es := createSSESource(t, "", nil, nil) messageCounter := 0 messageFunc := func(e any) { event := e.(*SSE) assertEqual(t, strconv.Itoa(messageCounter), event.ID) assertTrue(t, strings.HasPrefix(event.Data, fmt.Sprintf("%s method test:", tc.method))) messageCounter++ if messageCounter == 20 { es.Close() } } es.OnMessage(messageFunc, nil) counter := 0 methodVerified := false bodyVerified := false ts := createMethodVerifyingSSETestServer( t, 10*time.Millisecond, tc.method, tc.body, &methodVerified, &bodyVerified, func(w io.Writer) error { if counter == 20 { return fmt.Errorf("stop sending events") } _, err := fmt.Fprintf(w, "id: %v\ndata: %s method test: %s\n\n", counter, tc.method, time.Now().Format(time.RFC3339)) counter++ return err }, ) defer ts.Close() es.SetURL(ts.URL) es.SetMethod(tc.method) // set body if tc.body != nil { es.SetBody(bytes.NewBuffer(tc.body)) } err := es.Get() assertNil(t, err) // check the message count assertEqual(t, counter, messageCounter) // check if server receive correct method and body assertTrue(t, methodVerified) if tc.body != nil { assertTrue(t, bodyVerified) } }) } } func TestSSESource_readEventFunc(t *testing.T) { t.Run("successful scan", func(t *testing.T) { input := "event: test\ndata: test data\n\n" scanner := bufio.NewScanner(strings.NewReader(input)) event, err := readEventFunc(scanner) assertNil(t, err) assertNotNil(t, event) assertEqual(t, "event: test", string(event)) }) t.Run("scanner error", func(t *testing.T) { // Create a custom reader that returns an error scanner := bufio.NewScanner(&errorReader{}) event, err := readEventFunc(scanner) assertNotNil(t, err) assertNil(t, event) assertEqual(t, "fake", err.Error()) }) t.Run("EOF error", func(t *testing.T) { // Empty reader will immediately return EOF scanner := bufio.NewScanner(strings.NewReader("")) event, err := readEventFunc(scanner) 
assertEqual(t, io.EOF, err) assertNil(t, event) }) t.Run("multiple lines", func(t *testing.T) { input := "line1\nline2\nline3\n" scanner := bufio.NewScanner(strings.NewReader(input)) // First call should return the first line event1, err1 := readEventFunc(scanner) assertNil(t, err1) assertEqual(t, "line1", string(event1)) // Second call should return the second line event2, err2 := readEventFunc(scanner) assertNil(t, err2) assertEqual(t, "line2", string(event2)) // Third call should return the third line event3, err3 := readEventFunc(scanner) assertNil(t, err3) assertEqual(t, "line3", string(event3)) // Fourth call should return EOF event4, err4 := readEventFunc(scanner) assertEqual(t, io.EOF, err4) assertNil(t, event4) }) } func TestSSESourceCoverage(t *testing.T) { es := NewSSESource() err1 := es.Get() assertEqual(t, "resty:sse: event source URL is required", err1.Error()) es.SetURL("https://sse.dev/test") err2 := es.Get() assertEqual(t, "resty:sse: At least one OnMessage/AddEventListener func is required", err2.Error()) es.OnMessage(func(a any) {}, nil) es.SetURL("//res%20ty.dev") err3 := es.Get() assertTrue(t, strings.Contains(err3.Error(), `invalid URL escape "%20"`)) wrapResponse(nil, nil) trimHeader(2, nil) parseEvent([]byte{}) } func TestSSESetBody(t *testing.T) { t.Run("nil input", func(t *testing.T) { es := createSSESource(t, "", nil, nil) es.SetBody(nil) assertNil(t, es.bodyBytes) }) t.Run("read error", func(t *testing.T) { es := createSSESource(t, "", nil, nil) es.SetBody(&errorReader{}) assertNil(t, es.bodyBytes) }) } func createSSESource(t *testing.T, url string, fn SSEMessageFunc, rt any) *SSESource { es := NewSSESource(). SetURL(url). SetMethod(MethodGet). AddHeader("X-Test-Header-1", "test header 1"). SetHeader("X-Test-Header-2", "test header 2"). SetRetryCount(2). SetRetryWaitTime(200 * time.Millisecond). SetRetryMaxWaitTime(1000 * time.Millisecond). SetSizeMaxBuffer(1 << 14). // 16kb SetLogger(createLogger()). 
OnOpen(func(url string, respHdr http.Header) { t.Log("I'm connected:", url, respHdr) }). OnError(func(err error) { t.Log("Error occurred:", err) }) if fn != nil { es.OnMessage(fn, rt) } return es } func createSSETestServer(t *testing.T, ticker time.Duration, fn func(io.Writer) error) *httptest.Server { return createTestServer(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") // for local testing allow it w.Header().Set("Access-Control-Allow-Origin", "*") // Create a channel for client disconnection clientGone := r.Context().Done() rc := http.NewResponseController(w) tick := time.NewTicker(ticker) defer tick.Stop() for { select { case <-clientGone: t.Log("Client disconnected") return case <-tick.C: if err := fn(w); err != nil { t.Log(err) return } if err := rc.Flush(); err != nil { t.Log(err) return } } } }) } // almost like create server before but add verifying method and body func createMethodVerifyingSSETestServer( t *testing.T, ticker time.Duration, expectedMethod string, expectedBody []byte, methodVerified *bool, bodyVerified *bool, fn func(io.Writer) error, ) *httptest.Server { return createTestServer(func(w http.ResponseWriter, r *http.Request) { // validate method if r.Method == expectedMethod { *methodVerified = true } else { t.Errorf("Expected method %s, got %s", expectedMethod, r.Method) } // validate body if expectedBody != nil { body, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Failed to read request body: %v", err) } else if string(body) == string(expectedBody) { *bodyVerified = true } else { t.Errorf("Expected body %s, got %s", string(expectedBody), string(body)) } } // same as createSSETestServer w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") clientGone := 
r.Context().Done() rc := http.NewResponseController(w) tick := time.NewTicker(ticker) defer tick.Stop() for { select { case <-clientGone: t.Log("Client disconnected") return case <-tick.C: if err := fn(w); err != nil { t.Log(err) return } if err := rc.Flush(); err != nil { t.Log(err) return } } } }) } ================================================ FILE: stream.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. // SPDX-License-Identifier: MIT package resty import ( "bytes" "compress/flate" "compress/gzip" "context" "encoding/json" "encoding/xml" "errors" "fmt" "io" "sync" ) var ( ErrContentDecompresserNotFound = errors.New("resty: content decoder not found") // It's good to have decode limit; let's start with object size // limit as 1M objects, which should be more than enough for most // use cases if users need more, they can always implement their // own decoder and set it in Client.SetContentDecoder // // Max 1 million objects, +1 to detect if we exceed the limit without EOF maxDecodeObjects = 1000001 ) type ( // ContentTypeEncoder type is for encoding the request body based on header Content-Type ContentTypeEncoder func(io.Writer, any) error // ContentTypeDecoder type is for decoding the response body based on header Content-Type ContentTypeDecoder func(io.Reader, any) error // ContentDecompresser type is for decompressing response body based on header Content-Encoding // ([RFC 9110]) // // For example, gzip, deflate, etc. 
// // [RFC 9110]: https://datatracker.ietf.org/doc/html/rfc9110 ContentDecompresser func(io.ReadCloser) (io.ReadCloser, error) ) func encodeJSON(w io.Writer, v any) error { return encodeJSONEscapeHTML(w, v, true) } func encodeJSONEscapeHTML(w io.Writer, v any, esc bool) error { enc := json.NewEncoder(w) enc.SetEscapeHTML(esc) return enc.Encode(v) } func encodeJSONEscapeHTMLIndent(w io.Writer, v any, esc bool, indent string) error { enc := json.NewEncoder(w) enc.SetEscapeHTML(esc) enc.SetIndent("", indent) return enc.Encode(v) } func decodeJSON(r io.Reader, v any) error { dec := json.NewDecoder(r) // Handle nopReadCloser specially to support multiple JSON objects // while preventing infinite loops if nrc, ok := r.(*nopReadCloser); ok { // Temporarily disable auto-reset to prevent infinite loops originalReset := nrc.resetOnEOF nrc.resetOnEOF = false defer func() { nrc.resetOnEOF = originalReset }() if err := doDecodeJSON(dec, v); err != nil { return err } // After decoding, reset for future reads nrc.Reset() return nil } // For other readers, decode multiple JSON objects as intended return doDecodeJSON(dec, v) } func doDecodeJSON(dec *json.Decoder, v any) error { // Decode all JSON objects in the data for range maxDecodeObjects { if err := dec.Decode(v); err != nil { if err == io.EOF { return nil } return err } } return fmt.Errorf("resty: JSON decode exceeded %d objects without EOF", maxDecodeObjects) } func encodeXML(w io.Writer, v any) error { return xml.NewEncoder(w).Encode(v) } func decodeXML(r io.Reader, v any) error { dec := xml.NewDecoder(r) for range maxDecodeObjects { if err := dec.Decode(v); err != nil { if err == io.EOF { return nil } return err } } return fmt.Errorf("resty: XML decode exceeded %d objects without EOF", maxDecodeObjects) } // gzipReaderPool pools actual *gzip.Reader objects for reuse via Reset(). // This avoids the allocation cost of gzip.NewReader for each decompression. 
// Thread-safety is ensured by the gzipReaderWrapper's mutex which guards access.
var gzipReaderPool = sync.Pool{
	New: func() any {
		// Return nil; let's create reader on first use or get them from pool
		return nil
	},
}

// gzipReaderWrapper wraps a pooled gzip.Reader with a mutex for safe concurrent access.
// The mutex ensures exclusive access to the reader during Read() and state transitions.
type gzipReaderWrapper struct {
	mu *sync.Mutex   // serialises access to gr and r
	r  io.ReadCloser // compressed source the wrapper was created for
	gr *gzip.Reader  // decompressor reading from r
}

// acquireGzipReader gets a gzip.Reader from the pool or creates one.
// It resets the reader for the new stream using the provided io.ReadCloser.
//
// On failure it returns a nil wrapper and the error from reading the gzip
// header; r is not closed here.
func acquireGzipReader(r io.ReadCloser) (*gzipReaderWrapper, error) {
	// No locking is needed while setting up: the wrapper is not reachable by
	// any other goroutine until this function returns it.
	gr, _ := gzipReaderPool.Get().(*gzip.Reader)
	if gr == nil {
		// Pool is empty, create a new reader
		fresh, err := gzip.NewReader(r)
		if err != nil {
			return nil, err
		}
		gr = fresh
	} else if err := gr.Reset(r); err != nil {
		// Detach the reader from the rejected source before pooling it, so
		// the pool does not keep r reachable. Resetting onto an empty input
		// fails as well (there is no header to read); that error is expected
		// and carries no information.
		_ = gr.Reset(bytes.NewReader(nil))
		gzipReaderPool.Put(gr)
		return nil, err
	}

	return &gzipReaderWrapper{mu: new(sync.Mutex), r: r, gr: gr}, nil
}

// releaseGzipReader returns the gzip reader to the pool for reuse,
// and closes the underlying source.
func releaseGzipReader(w *gzipReaderWrapper) { w.mu.Lock() defer w.mu.Unlock() if w.gr != nil { w.gr.Reset(nopReader{}) // clear reference to the closed source before pooling gzipReaderPool.Put(w.gr) w.gr = nil } if w.r != nil { closeq(w.r) w.r = nil } } func decompressGzip(r io.ReadCloser) (io.ReadCloser, error) { return acquireGzipReader(r) } // Implement io.ReadCloser for gzipReaderWrapper func (w *gzipReaderWrapper) Read(p []byte) (n int, err error) { // Hold the lock during Read to ensure exclusive access to the gzip reader w.mu.Lock() defer w.mu.Unlock() if w.gr == nil { return 0, io.EOF } return w.gr.Read(p) } func (w *gzipReaderWrapper) Close() error { releaseGzipReader(w) return nil } // flateReaderPool pools io.ReadCloser (flate.Reader) objects for reuse via Reset(). // This avoids the allocation cost of flate.NewReader for each decompression. // Thread-safety is ensured by the deflateReaderWrapper's mutex which guards access. var flateReaderPool = sync.Pool{ New: func() any { // Return nil; let's create reader on first use or get them from pool return nil }, } // deflateReaderWrapper wraps a pooled flate.Reader with a mutex for safe concurrent access. // The mutex ensures exclusive access to the reader during Read() and state transitions. type deflateReaderWrapper struct { mu *sync.Mutex r io.ReadCloser fr io.ReadCloser } // acquireDeflateReader gets a flate.Reader from the pool or creates one. // It resets the reader for the new stream using the provided io.ReadCloser. 
func acquireDeflateReader(r io.ReadCloser) (*deflateReaderWrapper, error) { w := &deflateReaderWrapper{ mu: new(sync.Mutex), r: r, } w.mu.Lock() defer w.mu.Unlock() // Try to get a cached reader from the pool if cached := flateReaderPool.Get(); cached != nil { w.fr = cached.(io.ReadCloser) // Reset the pooled reader for the new stream; flate.Resetter.Reset never errors w.fr.(flate.Resetter).Reset(r, nil) } else { // Pool is empty, create a new reader w.fr = flate.NewReader(r) } return w, nil } // releaseDeflateReader returns the flate reader to the pool for reuse, // and closes the underlying source. func releaseDeflateReader(w *deflateReaderWrapper) { w.mu.Lock() defer w.mu.Unlock() if w.fr != nil { w.fr.(flate.Resetter).Reset(nopReader{}, nil) flateReaderPool.Put(w.fr) w.fr = nil } if w.r != nil { closeq(w.r) w.r = nil } } func decompressDeflate(r io.ReadCloser) (io.ReadCloser, error) { return acquireDeflateReader(r) } // Implement io.ReadCloser for deflateReaderWrapper func (w *deflateReaderWrapper) Read(p []byte) (n int, err error) { // Hold the lock during Read to ensure exclusive access to the flate reader w.mu.Lock() defer w.mu.Unlock() if w.fr == nil { return 0, io.EOF } return w.fr.Read(p) } func (w *deflateReaderWrapper) Close() error { releaseDeflateReader(w) return nil } // ErrReadExceedsThresholdLimit is returned when the read operation exceeds the defined threshold limit. var ErrReadExceedsThresholdLimit = errors.New("resty: read exceeds the threshold limit") var _ io.ReadCloser = (*limitReadCloser)(nil) var _ resetter = (*limitReadCloser)(nil) // resetter is an interface that defines a Reset method for resetting the reader state. 
type resetter interface { Reset() error } const unlimitedRead = 0 type limitReadCloser struct { r io.Reader l int64 // Limit (0 or <0 - unlimited, >0 limit) t int64 // Total bytes read f func(s int64) } func (l *limitReadCloser) Read(p []byte) (n int, err error) { switch { case l.l <= unlimitedRead: n, err = l.r.Read(p) l.t += int64(n) l.f(l.t) return n, err default: remaining := l.l - l.t if remaining <= 0 { return 0, ErrReadExceedsThresholdLimit } if remaining < int64(len(p)) { p = p[:remaining] } n, err = l.r.Read(p) l.t += int64(n) l.f(l.t) return n, err } } func (l *limitReadCloser) Close() error { if c, ok := l.r.(io.Closer); ok { return c.Close() } return nil } func (l *limitReadCloser) Reset() error { l.t = 0 // Reset total bytes read to zero return nil } var _ io.ReadCloser = (*copyReadCloser)(nil) type copyReadCloser struct { s io.Reader t *bytes.Buffer c bool f func(*bytes.Buffer) } func (r *copyReadCloser) Read(p []byte) (int, error) { n, err := r.s.Read(p) if n > 0 { _, _ = r.t.Write(p[:n]) } if err == io.EOF || err == ErrReadExceedsThresholdLimit { if !r.c { r.f(r.t) r.c = true } } return n, err } func (r *copyReadCloser) Close() error { if c, ok := r.s.(io.Closer); ok { return c.Close() } return nil } var _ io.ReadCloser = (*nopReadCloser)(nil) type nopReadCloser struct { r io.Reader resetOnEOF bool // Whether to reset on EOF } func (r *nopReadCloser) Read(p []byte) (int, error) { n, err := r.r.Read(p) if err == io.EOF && r.resetOnEOF { r.Reset() } return n, err } func (r *nopReadCloser) Close() error { return nil } // Reset allows manual reset of the reader position func (r *nopReadCloser) Reset() { // If the underlying reader supports seeking, reset to the beginning if seeker, ok := r.r.(io.Seeker); ok { seeker.Seek(0, io.SeekStart) } // Also try to reset underlying layer if ur, ok := r.r.(resetter); ok { _ = ur.Reset() } } var _ flate.Reader = (*nopReader)(nil) type nopReader struct{} func (nopReader) Read([]byte) (int, error) { return 0, io.EOF } 
func (nopReader) ReadByte() (byte, error) { return 0, io.EOF } type gracefulStopReader struct { ctx context.Context r io.Reader } func (gsr *gracefulStopReader) Read(p []byte) (n int, err error) { if err := gsr.ctx.Err(); err != nil { // Return io.EOF to stop io.Copy gracefully without an error. return 0, io.EOF } return gsr.r.Read(p) } ================================================ FILE: stream_test.go ================================================ package resty import ( "bytes" "compress/flate" "compress/gzip" "io" "net/http" "net/http/httptest" "strings" "sync" "testing" ) func TestDecodeJSONWhenResponseBodyIsNull(t *testing.T) { r := &Response{ Body: io.NopCloser(bytes.NewReader([]byte("null"))), } r.wrapCopyReadCloser() err := r.readAll() assertNil(t, err) var result map[int]int err = decodeJSON(r.Body, &result) assertNil(t, err) assertNil(t, result, "expected result to be nil map when JSON is null") } func TestGetMethodWhenResponseIsNull(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("null")) })) client := New().SetRetryCount(3).SetCurlCmdGenerate(true) var x any resp, err := client.R().SetBody("{}"). SetHeader("Content-Type", "application/json; charset=utf-8"). SetResponseForceContentType("application/json"). SetMethodGetAllowPayload(true). SetResponseBodyUnlimitedReads(true). SetResult(&x). 
Get(server.URL + "/test") assertNil(t, err) assertEqual(t, "null", resp.String()) assertNil(t, x, "expected result to be nil when response body is null") } func TestDecodeJSON(t *testing.T) { t.Run("single object", func(t *testing.T) { jsonData := `{"name": "John", "age": 30}` reader := bytes.NewReader([]byte(jsonData)) var result map[string]any err := decodeJSON(reader, &result) assertNil(t, err) assertEqual(t, "John", result["name"]) assertEqual(t, float64(30), result["age"]) }) t.Run("multiple objects", func(t *testing.T) { multipleJSON := `{"id": 1} {"id": 2} {"id": 3}` reader2 := bytes.NewReader([]byte(multipleJSON)) var result2 map[string]any err := decodeJSON(reader2, &result2) assertNil(t, err) assertEqual(t, float64(3), result2["id"]) }) t.Run("list of objects", func(t *testing.T) { multipleJSON := `[{"id": 1}, {"id": 2}, {"id": 3}]` reader2 := bytes.NewReader([]byte(multipleJSON)) var result2 []map[string]any err := decodeJSON(reader2, &result2) assertNil(t, err) assertEqual(t, float64(3), result2[2]["id"]) }) t.Run("malformed JSON", func(t *testing.T) { malformedJSON := `{"name": "John", "age":}` reader3 := bytes.NewReader([]byte(malformedJSON)) var result3 map[string]any err := decodeJSON(reader3, &result3) assertNotNil(t, err) }) t.Run("empty body", func(t *testing.T) { emptyJSON := `` reader4 := bytes.NewReader([]byte(emptyJSON)) var result4 map[string]any err := decodeJSON(reader4, &result4) assertNil(t, err) }) t.Run("exceeds maxDecodeObjects limit", func(t *testing.T) { preMaxDecodeObjects := maxDecodeObjects maxDecodeObjects = 51 // Set a lower limit for testing t.Cleanup(func() { maxDecodeObjects = preMaxDecodeObjects // Reset to original value after test }) // Build a reader that returns maxDecodeObjects+1 objects without EOF // by using a custom reader that signals no EOF until asked enough times. // Simplest approach: patch the limit via the loop by creating a reader // backed by a sufficient number of elements. 
We instead test the boundary // by constructing exactly that many elements with a streaming reader // built from io.MultiReader. elem := []byte(`{"key": "value"}`) readers := make([]io.Reader, maxDecodeObjects+1) for i := range readers { readers[i] = bytes.NewReader(elem) } r := io.MultiReader(readers...) var v map[string]any err := decodeJSON(r, &v) assertNotNil(t, err) assertEqual(t, "resty: JSON decode exceeded 51 objects without EOF", err.Error()) }) } func TestWrapCopyReadCloser(t *testing.T) { testData := "Hello, World!" r := &Response{ Body: io.NopCloser(bytes.NewReader([]byte(testData))), } // Before wrapping, bodyBytes should be empty assertEqual(t, 0, len(r.bodyBytes)) r.wrapCopyReadCloser() // Read data - should trigger copy mechanism and transform to nopReadCloser data, err := io.ReadAll(r.Body) assertNil(t, err) assertEqual(t, testData, string(data)) assertEqual(t, testData, string(r.bodyBytes)) // Should now be nopReadCloser for unlimited reads _, ok := r.Body.(*nopReadCloser) assertTrue(t, ok, "expected Body to be of type *nopReadCloser") // Test unlimited reads data2, err := io.ReadAll(r.Body) assertNil(t, err) assertEqual(t, testData, string(data2)) } func TestMultipleJSONObjectsSupport(t *testing.T) { // Test multiple JSON objects with wrapCopyReadCloser jsonData := `{"first": 1} {"second": 2} {"third": 3}` r := &Response{ Body: io.NopCloser(bytes.NewReader([]byte(jsonData))), } r.wrapCopyReadCloser() // Should process all objects and get the last one var result map[string]any err := decodeJSON(r.Body, &result) assertNil(t, err) assertEqual(t, float64(3), result["third"]) // Should support unlimited reads and decoding var result2 map[string]any err = decodeJSON(r.Body, &result2) assertNil(t, err) assertEqual(t, float64(3), result2["third"]) // Test direct nopReadCloser usage nopReader := &nopReadCloser{ r: bytes.NewReader([]byte(jsonData)), resetOnEOF: true, } var result3 map[string]any err = decodeJSON(nopReader, &result3) assertNil(t, err) 
assertEqual(t, float64(3), result3["third"]) } // Test case from GH-#1087 to ensure no panic occurs // with gzip.Reader on corrupted gzip data when multiple // concurrent requests are made. func TestGzipReaderPanicOnConcurrentCorruptedBody(t *testing.T) { writeHeaders := func(w http.ResponseWriter) { w.Header().Set(hdrContentEncodingKey, "gzip") w.Header().Set(hdrContentTypeKey, "application/json") w.WriteHeader(http.StatusOK) } ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { writeHeaders(w) // We want the Client to think it's reading Gzip, but fail immediately // upon processing these bytes. w.Write([]byte{0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x01}) }) defer ts.Close() client := NewWithTransportSettings(&TransportSettings{MaxIdleConns: 1000, MaxIdleConnsPerHost: 1000}). SetRetryCount(2). AddRetryConditions(func(r *Response, err error) bool { return err != nil }) totalRequests := 100 concurrencyLimit := 100 sem := make(chan struct{}, concurrencyLimit) panicChan := make(chan any, 1) doneChan := make(chan struct{}) go func() { var wg sync.WaitGroup defer close(doneChan) for range totalRequests { select { case <-panicChan: return default: } wg.Add(1) sem <- struct{}{} go func() { defer wg.Done() defer func() { <-sem }() defer func() { if r := recover(); r != nil { select { case panicChan <- r: default: } } }() var out map[string]any client.R(). SetRetryAllowNonIdempotent(true). SetResult(&out). Post(ts.URL) }() } wg.Wait() }() select { case r := <-panicChan: t.Errorf("Test Failed Immediately: Panic detected: %v", r) case <-doneChan: select { case r := <-panicChan: t.Errorf("Test Failed: Panic detected at end of run: %v", r) default: // If we get here, no panic occurred. 
} } // at the end the client should still be functional // and can make valid requests goodServer := createTestServer(func(w http.ResponseWriter, r *http.Request) { writeHeaders(w) gz := gzip.NewWriter(w) defer gz.Close() gz.Write([]byte(`{"status": "ok"}`)) }) defer goodServer.Close() var result map[string]string res, err := client.R(). SetResult(&result). Post(goodServer.URL) assertError(t, err) assertEqual(t, http.StatusOK, res.StatusCode()) assertEqual(t, "ok", result["status"], "expected to successfully decode valid gzip response") } func TestGzipReaderAcquireAndResetError(t *testing.T) { t.Run("invalid data", func(t *testing.T) { // Test the scenario where gzip.NewReader fails (pool empty path) invalidData := io.NopCloser(bytes.NewReader([]byte("not gzip data"))) // This should trigger the gzip.NewReader error path wrapper, err := acquireGzipReader(invalidData) assertNotNil(t, err) assertNil(t, wrapper) assertTrue(t, strings.Contains(err.Error(), "gzip") || strings.Contains(err.Error(), "header") || strings.Contains(err.Error(), "invalid"), "expected gzip-related error, got: "+err.Error()) }) t.Run("reset error", func(t *testing.T) { // Test the scenario where Reset fails (pool hit path) validData := io.NopCloser(bytes.NewReader(createGzipValidData())) // First acquire to populate the pool wrapper, err := acquireGzipReader(validData) assertNil(t, err) assertNotNil(t, wrapper) releaseGzipReader(wrapper) errorReader := &brokenReadCloser{} // Now acquire again with a broken reader to trigger Reset error on pool-hit path wrapper2, err := acquireGzipReader(errorReader) assertNotNil(t, err) assertNil(t, wrapper2) assertTrue(t, strings.Contains(err.Error(), "read error")) }) } func TestGzipReaderPoolConcurrentAccess(t *testing.T) { // Test concurrent pool access to ensure thread safety const numGoroutines = 10 const numOperations = 5 var wg sync.WaitGroup wg.Add(numGoroutines) for range numGoroutines { go func() { defer wg.Done() for range numOperations { // Create 
fresh data for each operation validData := io.NopCloser(bytes.NewReader(createGzipValidData())) wrapper, err := acquireGzipReader(validData) assertNil(t, err) assertNotNil(t, wrapper) // Use the reader briefly _, err = wrapper.gr.Read(make([]byte, 5)) assertNil(t, err) // Release back to pool releaseGzipReader(wrapper) } }() } wg.Wait() } // Helper functions for testing func createGzipValidData() []byte { var buf bytes.Buffer zw := gzip.NewWriter(&buf) zw.Write([]byte("test data")) zw.Close() return buf.Bytes() } func createDeflateValidData() []byte { var buf bytes.Buffer zw, _ := flate.NewWriter(&buf, flate.BestSpeed) zw.Write([]byte("test data")) zw.Close() return buf.Bytes() } // Test case to ensure no panic occurs with flate.Reader on corrupted deflate data // when multiple concurrent requests are made. func TestDeflateReaderPanicOnConcurrentCorruptedBody(t *testing.T) { writeHeaders := func(w http.ResponseWriter) { w.Header().Set(hdrContentEncodingKey, "deflate") w.Header().Set(hdrContentTypeKey, "application/json") w.WriteHeader(http.StatusOK) } ts := createTestServer(func(w http.ResponseWriter, r *http.Request) { writeHeaders(w) // Send bytes that are not valid deflate data to force a read error. w.Write([]byte{0xde, 0xad, 0xbe, 0xef, 0x00, 0x01, 0x02, 0x03}) }) defer ts.Close() client := NewWithTransportSettings(&TransportSettings{MaxIdleConns: 1000, MaxIdleConnsPerHost: 1000}). SetRetryCount(2). AddRetryConditions(func(r *Response, err error) bool { return err != nil }) totalRequests := 100 concurrencyLimit := 100 sem := make(chan struct{}, concurrencyLimit) panicChan := make(chan any, 1) doneChan := make(chan struct{}) go func() { var wg sync.WaitGroup defer close(doneChan) for range totalRequests { select { case <-panicChan: return default: } wg.Add(1) sem <- struct{}{} go func() { defer wg.Done() defer func() { <-sem }() defer func() { if r := recover(); r != nil { select { case panicChan <- r: default: } } }() var out map[string]any client.R(). 
SetRetryAllowNonIdempotent(true). SetResult(&out). Post(ts.URL) }() } wg.Wait() }() select { case r := <-panicChan: t.Errorf("Test Failed Immediately: Panic detected: %v", r) case <-doneChan: select { case r := <-panicChan: t.Errorf("Test Failed: Panic detected at end of run: %v", r) default: // If we get here, no panic occurred. } } // at the end the client should still be functional // and can make valid requests goodServer := createTestServer(func(w http.ResponseWriter, r *http.Request) { writeHeaders(w) zw, _ := flate.NewWriter(w, flate.BestSpeed) defer zw.Close() zw.Write([]byte(`{"status": "ok"}`)) }) defer goodServer.Close() var result map[string]string res, err := client.R(). SetResult(&result). Post(goodServer.URL) assertError(t, err) assertEqual(t, http.StatusOK, res.StatusCode()) assertEqual(t, "ok", result["status"], "expected to successfully decode valid deflate response") } func TestDeflateReaderPoolAcquireAndRead(t *testing.T) { // Test successful creation and read with valid deflate data validData := io.NopCloser(bytes.NewReader(createDeflateValidData())) wrapper, err := acquireDeflateReader(validData) assertNil(t, err) assertNotNil(t, wrapper) buf := make([]byte, 128) // flate.Reader may return (n, io.EOF) in the same call on the final read; ignore it. 
n, _ := wrapper.Read(buf) assertTrue(t, n > 0, "expected to read some bytes from valid deflate data") assertEqual(t, "test data", strings.TrimRight(string(buf[:n]), "\x00")) wrapper.Close() // Test that Read on a closed wrapper returns io.EOF _, err = wrapper.Read(buf) assertEqual(t, io.EOF, err) } func TestDeflateReaderPoolConcurrentAccess(t *testing.T) { // Test concurrent pool access to ensure thread safety const numGoroutines = 10 const numOperations = 5 var wg sync.WaitGroup wg.Add(numGoroutines) for range numGoroutines { go func() { defer wg.Done() for range numOperations { // Create fresh data for each operation validData := io.NopCloser(bytes.NewReader(createDeflateValidData())) wrapper, err := acquireDeflateReader(validData) assertNil(t, err) assertNotNil(t, wrapper) // Use the reader briefly _, err = wrapper.fr.Read(make([]byte, 5)) assertNil(t, err) // Release back to pool releaseDeflateReader(wrapper) } }() } wg.Wait() } func TestLimitCloserResetterInterface(t *testing.T) { testStr := "This is limit reset test" testStrLen := int64(len(testStr)) r := bytes.NewReader([]byte(testStr)) lc := &limitReadCloser{ r: r, l: testStrLen, f: func(total int64) {}, } assertEqual(t, testStrLen, lc.l) rc := nopReadCloser{r: lc, resetOnEOF: true} rc.Read(make([]byte, 25)) // read to reach total size assertEqual(t, testStrLen, lc.l) assertEqual(t, testStrLen, lc.t) rc.Reset() // reset should change the total to 0 assertEqual(t, int64(0), lc.t) } func TestDecodeXML(t *testing.T) { type Item struct { Name string `xml:"name"` } t.Run("single object", func(t *testing.T) { data := `foo` var v Item err := decodeXML(bytes.NewReader([]byte(data)), &v) assertNil(t, err) assertEqual(t, "foo", v.Name) }) t.Run("multiple objects - last one wins", func(t *testing.T) { data := `firstlast` var v Item err := decodeXML(bytes.NewReader([]byte(data)), &v) assertNil(t, err) assertEqual(t, "last", v.Name) }) t.Run("malformed XML returns error", func(t *testing.T) { data := `broken` var v Item 
err := decodeXML(bytes.NewReader([]byte(data)), &v) assertNotNil(t, err) }) t.Run("exceeds maxDecodeObjects limit", func(t *testing.T) { preMaxDecodeObjects := maxDecodeObjects maxDecodeObjects = 51 // Set a lower limit for testing t.Cleanup(func() { maxDecodeObjects = preMaxDecodeObjects // Reset to original value after test }) // Build a reader that returns maxDecodeObjects+1 objects without EOF // by using a custom reader that signals no EOF until asked enough times. // Simplest approach: patch the limit via the loop by creating a reader // backed by a sufficient number of elements. We instead test the boundary // by constructing exactly that many elements with a streaming reader // built from io.MultiReader. elem := []byte(`x`) readers := make([]io.Reader, maxDecodeObjects+1) for i := range readers { readers[i] = bytes.NewReader(elem) } r := io.MultiReader(readers...) var v Item err := decodeXML(r, &v) assertNotNil(t, err) assertEqual(t, "resty: XML decode exceeded 51 objects without EOF", err.Error()) }) } func TestStreamMisc(t *testing.T) { t.Run("wrapper gzip reader is nil", func(t *testing.T) { // Simulate a scenario where gzip.NewReader returns a wrapper with nil gr // due to an error, and ensure that Read on the wrapper does not panic // and returns an appropriate error instead. gzipReader := &gzipReaderWrapper{mu: new(sync.Mutex)} n, err := gzipReader.Read(make([]byte, 5)) assertNotNil(t, err) assertErrorIs(t, io.EOF, err) assertEqual(t, 0, n) }) } ================================================ FILE: trace.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. 
// SPDX-License-Identifier: MIT

package resty

import (
	"context"
	"crypto/tls"
	"fmt"
	"net/http/httptrace"
	"sync"
	"time"
)

// TraceInfo struct is used to provide request trace info such as DNS lookup
// duration, Connection obtain duration, Server processing duration, etc.
type TraceInfo struct {
	// DNSLookup is the duration the transport took to perform
	// the DNS lookup.
	DNSLookup time.Duration `json:"dns_lookup_time"`

	// ConnTime is the duration it took to obtain a successful connection.
	ConnTime time.Duration `json:"connection_time"`

	// TCPConnTime is the duration it took to obtain the TCP connection.
	TCPConnTime time.Duration `json:"tcp_connection_time"`

	// TLSHandshake is the duration of the TLS handshake.
	TLSHandshake time.Duration `json:"tls_handshake_time"`

	// ServerTime is the duration the server took to respond with the first byte.
	ServerTime time.Duration `json:"server_time"`

	// ResponseTime is the duration from the first response byte received
	// from the server until the request completed.
	ResponseTime time.Duration `json:"response_time"`

	// TotalTime is the total end-to-end duration of the request.
	TotalTime time.Duration `json:"total_time"`

	// IsConnReused reports whether this connection has been previously
	// used for another HTTP request.
	IsConnReused bool `json:"is_connection_reused"`

	// IsConnWasIdle reports whether this connection was obtained from an
	// idle pool.
	IsConnWasIdle bool `json:"is_connection_was_idle"`

	// ConnIdleTime is how long the connection was previously idle,
	// if IsConnWasIdle is true.
	ConnIdleTime time.Duration `json:"connection_idle_time"`

	// RequestAttempt is the attempt number of the request within a Resty
	// request execution flow, including retries.
	RequestAttempt int `json:"request_attempt"`

	// RemoteAddr is the remote network address.
	RemoteAddr string `json:"remote_address"`
}

// String method returns string representation of request trace information.
func (ti TraceInfo) String() string { return fmt.Sprintf(`TRACE INFO: DNSLookupTime : %v ConnTime : %v TCPConnTime : %v TLSHandshake : %v ServerTime : %v ResponseTime : %v TotalTime : %v IsConnReused : %v IsConnWasIdle : %v ConnIdleTime : %v RequestAttempt: %v RemoteAddr : %v`, ti.DNSLookup, ti.ConnTime, ti.TCPConnTime, ti.TLSHandshake, ti.ServerTime, ti.ResponseTime, ti.TotalTime, ti.IsConnReused, ti.IsConnWasIdle, ti.ConnIdleTime, ti.RequestAttempt, ti.RemoteAddr) } // JSON method returns the JSON string of request trace information func (ti TraceInfo) JSON() string { return toJSON(ti) } // Clone method returns the clone copy of [TraceInfo] func (ti TraceInfo) Clone() *TraceInfo { ti2 := new(TraceInfo) *ti2 = ti return ti2 } // clientTrace struct maps the [httptrace.ClientTrace] hooks into Fields // with the same naming for easy understanding. Plus additional insights // [Request]. type clientTrace struct { lock sync.RWMutex getConn time.Time dnsStart time.Time dnsDone time.Time connectDone time.Time tlsHandshakeStart time.Time tlsHandshakeDone time.Time gotConn time.Time gotFirstResponseByte time.Time endTime time.Time gotConnInfo httptrace.GotConnInfo } func (t *clientTrace) createContext(ctx context.Context) context.Context { return httptrace.WithClientTrace( ctx, &httptrace.ClientTrace{ DNSStart: func(_ httptrace.DNSStartInfo) { t.lock.Lock() t.dnsStart = time.Now() t.lock.Unlock() }, DNSDone: func(_ httptrace.DNSDoneInfo) { t.lock.Lock() t.dnsDone = time.Now() t.lock.Unlock() }, ConnectStart: func(_, _ string) { t.lock.Lock() if t.dnsDone.IsZero() { t.dnsDone = time.Now() } if t.dnsStart.IsZero() { t.dnsStart = t.dnsDone } t.lock.Unlock() }, ConnectDone: func(net, addr string, err error) { t.lock.Lock() t.connectDone = time.Now() t.lock.Unlock() }, GetConn: func(_ string) { t.lock.Lock() t.getConn = time.Now() t.lock.Unlock() }, GotConn: func(ci httptrace.GotConnInfo) { t.lock.Lock() t.gotConn = time.Now() t.gotConnInfo = ci t.lock.Unlock() }, 
GotFirstResponseByte: func() { t.lock.Lock() t.gotFirstResponseByte = time.Now() t.lock.Unlock() }, TLSHandshakeStart: func() { t.lock.Lock() t.tlsHandshakeStart = time.Now() t.lock.Unlock() }, TLSHandshakeDone: func(_ tls.ConnectionState, _ error) { t.lock.Lock() t.tlsHandshakeDone = time.Now() t.lock.Unlock() }, }, ) } ================================================ FILE: transport_dial.go ================================================ // Copyright 2021 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build !(js && wasm) // +build !js !wasm package resty import ( "context" "net" ) func transportDialContext(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error) { return dialer.DialContext } ================================================ FILE: transport_dial_wasm.go ================================================ // Copyright 2021 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. //go:build (js && wasm) || wasip1 // +build js,wasm wasip1 package resty import ( "context" "net" ) func transportDialContext(_ *net.Dialer) func(context.Context, string, string) (net.Conn, error) { return nil } ================================================ FILE: util.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. 
// SPDX-License-Identifier: MIT package resty import ( "bytes" "crypto/rand" "crypto/sha256" "encoding/binary" "encoding/hex" "encoding/json" "encoding/xml" "errors" "fmt" "io" "log" "net/http" "net/url" "os" "reflect" "runtime" "sort" "strconv" "strings" "sync/atomic" "time" ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Logger interface //_______________________________________________________________________ // Logger interface is to abstract the logging from Resty. Gives control to // the Resty users, choice of the logger. type Logger interface { Errorf(format string, v ...any) Warnf(format string, v ...any) Debugf(format string, v ...any) } func createLogger() *logger { l := &logger{l: log.New(os.Stderr, "", log.Ldate|log.Lmicroseconds)} return l } var _ Logger = (*logger)(nil) type logger struct { l *log.Logger } func (l *logger) Errorf(format string, v ...any) { l.output("ERROR RESTY "+format, v...) } func (l *logger) Warnf(format string, v ...any) { l.output("WARN RESTY "+format, v...) } func (l *logger) Debugf(format string, v ...any) { l.output("DEBUG RESTY "+format, v...) } func (l *logger) output(format string, v ...any) { if len(v) == 0 { l.l.Print(format) return } l.l.Printf(format, v...) } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // In Memory JSON & XML Marshal and Unmarshal using Go package //_____________________________________________________________ var ( // InMemoryJSONMarshal function performs the JSON marshalling completely in memory. // // c := resty.New() // defer c.Close() // // c.AddContentTypeEncoder("application/json", resty.InMemoryJSONMarshal) InMemoryJSONMarshal = func(w io.Writer, v any) error { jsonData, err := json.Marshal(v) if err != nil { return err } _, err = w.Write(jsonData) return err } // InMemoryJSONUnmarshal function performs the JSON unmarshalling completely in memory. 
// // c := resty.New() // defer c.Close() // // c.AddContentTypeDecoder("application/json", resty.InMemoryJSONUnmarshal) InMemoryJSONUnmarshal = func(r io.Reader, v any) error { byteData, err := io.ReadAll(r) if err != nil { return err } return json.Unmarshal(byteData, v) } // InMemoryXMLMarshal function performs the XML marshalling completely in memory. // // c := resty.New() // defer c.Close() // // c.AddContentTypeEncoder("application/xml", resty.InMemoryXMLMarshal) InMemoryXMLMarshal = func(w io.Writer, v any) error { xmlData, err := xml.Marshal(v) if err != nil { return err } _, err = w.Write(xmlData) return err } // InMemoryJSONUnmarshal function performs the XML unmarshalling completely in memory. // // c := resty.New() // defer c.Close() // // c.AddContentTypeDecoder("application/xml", resty.InMemoryXMLUnmarshal) InMemoryXMLUnmarshal = func(r io.Reader, v any) error { byteData, err := io.ReadAll(r) if err != nil { return err } return xml.Unmarshal(byteData, v) } ) // credentials type is to hold an username and password information type credentials struct { Username string `json:"username"` Password string `json:"password"` } // Clone method returns clone of c. 
func (c *credentials) Clone() *credentials { cc := new(credentials) *cc = *c return cc } // String method returns masked value of username and password func (c credentials) String() string { return "Username: **********, Password: **********" } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Package Helper methods //_______________________________________________________________________ // isStringEmpty method tells whether given string is empty or not func isStringEmpty(str string) bool { return len(strings.TrimSpace(str)) == 0 } // detectContentType method is used to figure out `Request.Body` content type for request header func detectContentType(body any) string { contentType := plainTextType kind := inferKind(body) switch kind { case reflect.Struct, reflect.Map: contentType = jsonContentType case reflect.String: contentType = plainTextType default: if b, ok := body.([]byte); ok { contentType = http.DetectContentType(b) } else if kind == reflect.Slice { // check slice here to differentiate between any slice vs byte slice contentType = jsonContentType } } return contentType } func isJSONContentType(ct string) bool { return strings.Contains(ct, jsonKey) } func isXMLContentType(ct string) bool { return strings.Contains(ct, xmlKey) } func inferContentTypeMapKey(v string) string { if isJSONContentType(v) { return jsonKey } else if isXMLContentType(v) { return xmlKey } return "" } func firstNonEmpty(v ...string) string { for _, s := range v { if !isStringEmpty(s) { return s } } return "" } var ( mkdirAll = os.MkdirAll createFile = os.Create ioCopy = io.Copy ) func createDirectory(dir string) (err error) { if _, err = os.Stat(dir); err != nil { if os.IsNotExist(err) { if err = mkdirAll(dir, 0755); err != nil { return } } } return } func getPointer(v any) any { if v == nil { return nil } vv := reflect.ValueOf(v) if vv.Kind() == reflect.Ptr { return v } return reflect.New(vv.Type()).Interface() } func inferType(v any) reflect.Type { return 
reflect.Indirect(reflect.ValueOf(v)).Type() } func inferKind(v any) reflect.Kind { return inferType(v).Kind() } func newInterface(v any) any { if v == nil { return nil } return reflect.New(inferType(v)).Interface() } func functionName(i any) string { return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() } func acquireBuffer() *bytes.Buffer { buf := bufPool.Get().(*bytes.Buffer) if buf.Len() == 0 { buf.Reset() return buf } bufPool.Put(buf) return new(bytes.Buffer) } func releaseBuffer(buf *bytes.Buffer) { if buf != nil { buf.Reset() bufPool.Put(buf) } } func backToBufPool(buf *bytes.Buffer) { if buf != nil { bufPool.Put(buf) } } func closeq(v any) { if c, ok := v.(io.Closer); ok { silently(c.Close()) } } func silently(_ ...any) {} var sanitizeHeaderToken = []string{ "authorization", "auth", "token", } func isSanitizeHeader(k string) bool { kk := strings.ToLower(k) for _, v := range sanitizeHeaderToken { if strings.Contains(kk, v) { return true } } return false } func sanitizeHeaders(hdr http.Header) http.Header { for k := range hdr { if isSanitizeHeader(k) { hdr[k] = []string{"********************"} } } return hdr } func composeHeaders(hdr http.Header) string { str := make([]string, 0, len(hdr)) for _, k := range sortHeaderKeys(hdr) { str = append(str, "\t"+strings.TrimSpace(fmt.Sprintf("%25s: %s", k, strings.Join(hdr[k], ", ")))) } return strings.Join(str, "\n") } func sortHeaderKeys(hdr http.Header) []string { keys := make([]string, 0, len(hdr)) for key := range hdr { keys = append(keys, key) } sort.Strings(keys) return keys } func wrapErrors(n error, inner error) error { if n == nil && inner == nil { return nil } if inner == nil { return n } if n == nil { return inner } return &restyError{ err: n, inner: inner, } } type restyError struct { err error inner error } func (e *restyError) Error() string { return e.err.Error() } func (e *restyError) Unwrap() error { return e.inner } // cloneURLValues is a helper function to deep copy url.Values. 
func cloneURLValues(v url.Values) url.Values { if v == nil { return nil } return url.Values(http.Header(v).Clone()) } func cloneCookie(c *http.Cookie) *http.Cookie { return &http.Cookie{ Name: c.Name, Value: c.Value, Path: c.Path, Domain: c.Domain, Expires: c.Expires, RawExpires: c.RawExpires, MaxAge: c.MaxAge, Secure: c.Secure, HttpOnly: c.HttpOnly, SameSite: c.SameSite, Raw: c.Raw, Unparsed: c.Unparsed, } } type invalidRequestError struct { Err error } func (ire *invalidRequestError) Error() string { return ire.Err.Error() } func drainBody(res *Response) { if res != nil && res.Body != nil { drainReadCloser(res.Body) } } func drainReadCloser(body io.ReadCloser) { if body != nil { defer closeq(body) _, _ = io.Copy(io.Discard, body) } } func toJSON(v any) string { buf := acquireBuffer() defer releaseBuffer(buf) _ = encodeJSON(buf, v) return buf.String() } // formatAnyToString converts various types of values to their string representation // based on predefined formatting rules. func formatAnyToString(value any) string { switch v := value.(type) { // Tier 1: most common URL types case string: return v case int: return strconv.Itoa(v) case bool: return strconv.FormatBool(v) case int64: return strconv.FormatInt(v, 10) case []string: return strings.Join(v, ",") // Tier 2: common stdlib types case time.Time: return v.Format(time.RFC3339) case []byte: return string(v) case float64: return strconv.FormatFloat(v, 'f', -1, 64) // Tier 3: less common integers (signed) case int32: return strconv.FormatInt(int64(v), 10) case int16: return strconv.FormatInt(int64(v), 10) case int8: return strconv.FormatInt(int64(v), 10) // Tier 4: less common integers (unsigned) case uint64: return strconv.FormatUint(v, 10) case uint32: return strconv.FormatUint(uint64(v), 10) case uint16: return strconv.FormatUint(uint64(v), 10) case uint8: return strconv.FormatUint(uint64(v), 10) case uint: return strconv.FormatUint(uint64(v), 10) // Tier 5: rare types and fallbacks case float32: return 
strconv.FormatFloat(float64(v), 'f', -1, 32) case fmt.Stringer: return v.String() default: return fmt.Sprint(v) } } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // GUID generation // Code inspired from mgo/bson ObjectId // Code obtained from https://github.com/go-aah/aah/blob/edge/essentials/guid.go //___________________________________ var ( // guidCounter is atomically incremented when generating a new GUID // using UniqueID() function. It's used as a counter part of an id. guidCounter = readRandomUint32() // machineID stores machine id generated once and used in subsequent calls // to UniqueId function. machineID = readMachineID() // processID is current Process Id processID = os.Getpid() ) // newGUID method returns a new Globally Unique Identifier (GUID). // // The 12-byte `UniqueId` consists of- // - 4-byte value representing the seconds since the Unix epoch, // - 3-byte machine identifier, // - 2-byte process id, and // - 3-byte counter, starting with a random value. // // Uses Mongo Object ID algorithm to generate globally unique ids - // https://docs.mongodb.com/manual/reference/method/ObjectId/ func newGUID() string { var b [12]byte // Timestamp, 4 bytes, big endian binary.BigEndian.PutUint32(b[:], uint32(time.Now().Unix())) // Machine, first 3 bytes of sha256.Sum256([]byte(hostname)) b[4], b[5], b[6] = machineID[0], machineID[1], machineID[2] // Pid, 2 bytes, specs don't specify endianness, but we use big endian. b[7], b[8] = byte(processID>>8), byte(processID) // Increment, 3 bytes, big endian i := atomic.AddUint32(&guidCounter, 1) b[9], b[10], b[11] = byte(i>>16), byte(i>>8), byte(i) return hex.EncodeToString(b[:]) } var ioReadFull = io.ReadFull // readRandomUint32 returns a random guidCounter. func readRandomUint32() uint32 { var b [4]byte if _, err := ioReadFull(rand.Reader, b[:]); err == nil { return (uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24) } // To initialize package unexported variable 'guidCounter'. 
// This panic would happen at program startup, so no worries at runtime panic. panic(errors.New("resty - guid: unable to generate random object id")) } var osHostname = os.Hostname // readMachineID generates and returns a machine id. // If this function fails to get the hostname it will cause a runtime error. func readMachineID() []byte { const idSize = 3 id := make([]byte, idSize) if hostname, err := osHostname(); err == nil { hash := sha256.Sum256([]byte(hostname)) copy(id, hash[:idSize]) return id } if _, err := ioReadFull(rand.Reader, id); err == nil { return id } // To initialize package unexported variable 'machineID'. // This panic would happen at program startup, so no worries at runtime panic. panic(errors.New("resty - guid: unable to get hostname and random bytes")) } ================================================ FILE: util_test.go ================================================ // Copyright (c) 2015-present Jeevanandam M (jeeva@myjeeva.com), All rights reserved. // resty source code and usage is governed by a MIT style // license that can be found in the LICENSE file. 
// SPDX-License-Identifier: MIT package resty import ( "bytes" "errors" "fmt" "io" "net/http" "net/url" "os" "path/filepath" "strings" "testing" "time" ) func TestIsJSONContentType(t *testing.T) { for _, test := range []struct { input string expect bool }{ {"application/json", true}, {"application/xml+json", true}, {"application/vnd.foo+json", true}, {"application/json; charset=utf-8", true}, {"application/vnd.foo+json; charset=utf-8", true}, {"text/json", true}, {"text/vnd.foo+json", true}, {"application/foo-json", true}, {"application/foo.json", true}, {"application/vnd.foo-json", true}, {"application/vnd.foo.json", true}, {"application/x-amz-json-1.1", true}, {"text/foo-json", true}, {"text/foo.json", true}, {"text/vnd.foo-json", true}, {"text/vnd.foo.json", true}, } { result := isJSONContentType(test.input) if result != test.expect { t.Errorf("failed on %q: want %v, got %v", test.input, test.expect, result) } } } func TestIsXMLContentType(t *testing.T) { for _, test := range []struct { input string expect bool }{ {"application/xml", true}, {"application/vnd.foo+xml", true}, {"application/xml; charset=utf-8", true}, {"application/vnd.foo+xml; charset=utf-8", true}, {"text/xml", true}, {"text/vnd.foo+xml", true}, {"application/foo-xml", true}, {"application/foo.xml", true}, {"application/vnd.foo-xml", true}, {"application/vnd.foo.xml", true}, {"text/foo-xml", true}, {"text/foo.xml", true}, {"text/vnd.foo-xml", true}, {"text/vnd.foo.xml", true}, } { result := isXMLContentType(test.input) if result != test.expect { t.Errorf("failed on %q: want %v, got %v", test.input, test.expect, result) } } } func TestCloneURLValues(t *testing.T) { v := url.Values{} v.Add("foo", "bar") v.Add("foo", "baz") v.Add("qux", "quux") c := cloneURLValues(v) nilUrl := cloneURLValues(nil) assertEqual(t, v, c) assertNil(t, nilUrl) } func TestRestyErrorFuncs(t *testing.T) { ne1 := errors.New("new error 1") nie1 := errors.New("inner error 1") assertNil(t, wrapErrors(nil, nil)) e := 
wrapErrors(ne1, nie1) assertEqual(t, "new error 1", e.Error()) assertEqual(t, "inner error 1", errors.Unwrap(e).Error()) e = wrapErrors(ne1, nil) assertEqual(t, "new error 1", e.Error()) e = wrapErrors(nil, nie1) assertEqual(t, "inner error 1", e.Error()) } func Test_createDirectory(t *testing.T) { errMsg := "test dir error" mkdirAll = func(path string, perm os.FileMode) error { return errors.New(errMsg) } t.Cleanup(func() { mkdirAll = os.MkdirAll }) tempDir := filepath.Join(t.TempDir(), "test-dir") err := createDirectory(tempDir) assertEqual(t, errMsg, err.Error()) } func TestUtil_readRandomUint32(t *testing.T) { defer func() { if r := recover(); r == nil { // panic: resty - guid: unable to generate random object id t.Errorf("The code did not panic") } }() errMsg := "read full error" ioReadFull = func(_ io.Reader, _ []byte) (int, error) { return 0, errors.New(errMsg) } t.Cleanup(func() { ioReadFull = io.ReadFull }) readRandomUint32() } func TestUtil_readMachineID(t *testing.T) { t.Run("hostname error", func(t *testing.T) { errHostMsg := "hostname error" osHostname = func() (string, error) { return "", errors.New(errHostMsg) } t.Cleanup(func() { osHostname = os.Hostname }) readMachineID() }) t.Run("hostname and read full error", func(t *testing.T) { defer func() { if r := recover(); r == nil { // panic: resty - guid: unable to get hostname and random bytes t.Errorf("The code did not panic") } }() errHostMsg := "hostname error" osHostname = func() (string, error) { return "", errors.New(errHostMsg) } errReadMsg := "read full error" ioReadFull = func(_ io.Reader, _ []byte) (int, error) { return 0, errors.New(errReadMsg) } t.Cleanup(func() { osHostname = os.Hostname ioReadFull = io.ReadFull }) readMachineID() }) } func TestInMemoryJSONMarshalUnmarshal(t *testing.T) { t.Run("json encoder", func(t *testing.T) { user := &credentials{Username: "testuser", Password: "testpass"} buf := acquireBuffer() defer releaseBuffer(buf) err := InMemoryJSONMarshal(buf, user) 
assertNil(t, err) assertEqual(t, `{"username":"testuser","password":"testpass"}`, buf.String()) }) t.Run("json encoder error", func(t *testing.T) { obj := &brokenMarshalJSON{} buf := acquireBuffer() defer releaseBuffer(buf) err := InMemoryJSONMarshal(buf, obj) assertNotNil(t, err) assertTrue(t, strings.Contains(err.Error(), "b0rk3d"), "broken marshal json error") }) t.Run("json decoder", func(t *testing.T) { byteData := []byte(`{"username":"testuser","password":"testpass"}`) cred := &credentials{} err := InMemoryJSONUnmarshal(bytes.NewReader(byteData), cred) assertNil(t, err) assertEqual(t, "testuser", cred.Username) assertEqual(t, "testpass", cred.Password) }) t.Run("json decoder read error", func(t *testing.T) { cred := &credentials{} err := InMemoryJSONUnmarshal(&brokenReadCloser{}, cred) assertNotNil(t, err) assertEqual(t, err.Error(), "read error") }) t.Run("json decoder error", func(t *testing.T) { byteData := []byte(`"username":"testuser","password":"testpass"}`) cred := &credentials{} err := InMemoryJSONUnmarshal(bytes.NewReader(byteData), cred) assertNotNil(t, err) assertTrue(t, strings.Contains(err.Error(), "invalid character ':' after top-level value"), "invalid json unmarshal error") }) } func TestInMemoryXMLMarshalUnmarshal(t *testing.T) { t.Run("xml encoder", func(t *testing.T) { user := &credentials{Username: "testuser", Password: "testpass"} buf := acquireBuffer() defer releaseBuffer(buf) err := InMemoryXMLMarshal(buf, user) assertNil(t, err) assertEqual(t, `testusertestpass`, buf.String()) }) t.Run("xml encoder error", func(t *testing.T) { obj := &brokenMarshalXML{} buf := acquireBuffer() defer releaseBuffer(buf) err := InMemoryXMLMarshal(buf, obj) assertNotNil(t, err) assertEqual(t, err.Error(), "b0rk3d") }) t.Run("xml decoder", func(t *testing.T) { byteData := []byte(`testusertestpass`) cred := &credentials{} err := InMemoryXMLUnmarshal(bytes.NewReader(byteData), cred) assertNil(t, err) assertEqual(t, "testuser", cred.Username) assertEqual(t, 
"testpass", cred.Password) }) t.Run("xml decoder read error", func(t *testing.T) { cred := &credentials{} err := InMemoryXMLUnmarshal(&brokenReadCloser{}, cred) assertNotNil(t, err) assertEqual(t, err.Error(), "read error") }) t.Run("xml decoder error", func(t *testing.T) { byteData := []byte(`testusertestpass`) cred := &credentials{} err := InMemoryJSONUnmarshal(bytes.NewReader(byteData), cred) fmt.Println(err) assertNotNil(t, err) assertEqual(t, err.Error(), "invalid character '<' looking for beginning of value") }) } func TestInMemoryJSONPost(t *testing.T) { ts := createPostServer(t) defer ts.Close() user := &credentials{Username: "testuser", Password: "testpass"} assertEqual(t, "Username: **********, Password: **********", user.String()) c := dcnl(). AddContentTypeEncoder(jsonContentType, InMemoryJSONMarshal). AddContentTypeDecoder("appLiCaTion/JSon", InMemoryJSONUnmarshal) r := c.R(). SetHeader(hdrContentTypeKey, jsonContentType). SetBody(user). SetResult(&AuthSuccess{}) resp, err := r.Post(ts.URL + "/login") authResp := resp.Result().(*AuthSuccess) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, int64(50), resp.Size()) assertEqual(t, authResp.ID, "success") assertEqual(t, authResp.Message, "login successful") } func TestInMemoryXMLPost(t *testing.T) { ts := createPostServer(t) defer ts.Close() xmlContentType := "application/xml" c := dcnl(). AddContentTypeEncoder(xmlContentType, InMemoryXMLMarshal). AddContentTypeDecoder(xmlContentType, InMemoryXMLUnmarshal) resp, err := c.R(). SetHeader(hdrContentTypeKey, xmlContentType). SetBody(credentials{Username: "testuser", Password: "testpass"}). SetResult(&AuthSuccess{}). 
Post(ts.URL + "/login") authResp := resp.Result().(*AuthSuccess) assertError(t, err) assertEqual(t, http.StatusOK, resp.StatusCode()) assertEqual(t, int64(116), resp.Size()) assertEqual(t, authResp.ID, "success") assertEqual(t, authResp.Message, "login successful") } // This test methods exist for test coverage purpose // to validate the getter and setter func TestUtilMiscTestCoverage(t *testing.T) { l := &limitReadCloser{r: strings.NewReader("hello test close for no io.Closer")} assertNil(t, l.Close()) r := ©ReadCloser{s: strings.NewReader("hello test close for no io.Closer")} assertNil(t, r.Close()) v := struct { ID string `json:"id"` Message string `json:"message"` }{} err := decodeJSON(bytes.NewReader([]byte(`{\" \": \"some value\"}`)), &v) assertEqual(t, "invalid character '\\\\' looking for beginning of object key string", err.Error()) ireErr := &invalidRequestError{Err: errors.New("test coverage")} assertEqual(t, "test coverage", ireErr.Error()) } // customStringer implements fmt.Stringer for testing type customStringer struct { value string } func (c customStringer) String() string { return c.value } func TestFormatAnyToString(t *testing.T) { fixedTime := time.Date(2024, 6, 15, 10, 30, 0, 0, time.UTC) for _, test := range []struct { name string input any expect string }{ // Tier 1: most common URL types {"string", "hello", "hello"}, {"empty string", "", ""}, {"int", 42, "42"}, {"int negative", -123, "-123"}, {"bool true", true, "true"}, {"bool false", false, "false"}, {"int64", int64(9223372036854775807), "9223372036854775807"}, {"int64 negative", int64(-9223372036854775808), "-9223372036854775808"}, {"[]string", []string{"a", "b", "c"}, "a,b,c"}, {"[]string single", []string{"only"}, "only"}, {"[]string empty", []string{}, ""}, // Tier 2: common stdlib types {"time.Time", fixedTime, "2024-06-15T10:30:00Z"}, {"[]byte", []byte("binary data"), "binary data"}, {"float64", 3.14159, "3.14159"}, {"float64 whole", float64(42), "42"}, {"float64 negative", -2.5, 
"-2.5"}, // Tier 3: less common integers (signed) {"int32", int32(2147483647), "2147483647"}, {"int16", int16(32767), "32767"}, {"int8", int8(127), "127"}, // Tier 4: less common integers (unsigned) {"uint64", uint64(18446744073709551615), "18446744073709551615"}, {"uint32", uint32(4294967295), "4294967295"}, {"uint16", uint16(65535), "65535"}, {"uint8", uint8(255), "255"}, {"uint", uint(12345), "12345"}, // Tier 5: rare types and fallbacks {"float32", float32(3.14), "3.14"}, {"fmt.Stringer", customStringer{value: "custom value"}, "custom value"}, {"default struct", struct{ Name string }{Name: "test"}, "{test}"}, {"nil", nil, ""}, } { t.Run(test.name, func(t *testing.T) { result := formatAnyToString(test.input) assertEqual(t, test.expect, result) }) } }