Repository: trinodb/trino-go-client Branch: master Commit: 0415d515f923 Files: 26 Total size: 325.7 KB Directory structure: gitextract_ms0oco_v/ ├── .github/ │ ├── release.yml │ └── workflows/ │ ├── ci.yml │ └── release.yml ├── .gitignore ├── .goreleaser.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── go.mod ├── go.sum └── trino/ ├── etc/ │ ├── catalog/ │ │ ├── hive.properties │ │ ├── memory.properties │ │ └── tpch.properties │ ├── config-pre-466version.properties │ ├── config-pre-477version.properties │ ├── config.properties │ ├── jvm.config │ ├── node.properties │ ├── password-authenticator.properties │ ├── secrets/ │ │ └── .gitignore │ └── spooling-manager.properties ├── integration_test.go ├── serial.go ├── serial_test.go ├── trino.go └── trino_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/release.yml ================================================ changelog: exclude: labels: - ignore-for-release categories: - title: Breaking changes labels: - breaking-change - title: Features labels: - enhancement - title: Bug fixes labels: - bug - title: Other changes labels: - "*" ================================================ FILE: .github/workflows/ci.yml ================================================ name: ci on: push: branches: - master pull_request: jobs: build: runs-on: ubuntu-latest strategy: fail-fast: false matrix: go: ['>=1.25','1.24.7'] trino: ['latest', '372'] steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v6 with: go-version: ${{ matrix.go }} - run: go test -v -race -timeout 2m ./... 
-trino_image_tag=${{ matrix.trino }} ================================================ FILE: .github/workflows/release.yml ================================================ name: release on: push: # run only against tags tags: - '*' permissions: contents: write jobs: release: runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Fetch all tags run: git fetch --force --tags - name: Set up Go uses: actions/setup-go@v6 with: go-version: "1.25" - name: Run GoReleaser uses: goreleaser/goreleaser-action@v6 with: distribution: goreleaser version: latest args: release --clean env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .gitignore ================================================ coverage.out .idea /dist ================================================ FILE: .goreleaser.yml ================================================ builds: - skip: true changelog: use: github-native ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to Trino ## Contributor License Agreement ("CLA") In order to accept your pull request, we need you to [submit a CLA](https://github.com/trinodb/cla). ## License By contributing to Trino, you agree that your contributions will be licensed under the [Apache License Version 2.0 (APLv2)](LICENSE). # Go Test Please run [go test](https://pkg.go.dev/testing) before creating a pull request: ```bash go test -v -race -timeout 1m ./... ``` # Releases To create a new release, a maintainer with repository write permissions needs to create and push a new git tag. ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. 
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # Trino Go client A [Trino](https://trino.io) client for the [Go](https://golang.org) programming language. It enables you to send SQL statements from your Go application to Trino, and receive the resulting data. 
[![Build Status](https://github.com/trinodb/trino-go-client/workflows/ci/badge.svg)](https://github.com/trinodb/trino-go-client/actions?query=workflow%3Aci+event%3Apush+branch%3Amaster) [![GoDoc](https://godoc.org/github.com/trinodb/trino-go-client?status.svg)](https://godoc.org/github.com/trinodb/trino-go-client) ## Features * Native Go implementation * Connections over HTTP or HTTPS * HTTP Basic, Kerberos, and JSON web token (JWT) authentication * Per-query user information for access control * Supports custom HTTP clients (tunable conn pools, timeouts, TLS) * Supports conversion from Trino to native Go data types * `string`, `sql.NullString` * `int64`, `sql.NullInt64` * `float64`, `sql.NullFloat64` * `map`, `trino.NullMap` * `time.Time`, `trino.NullTime` * Up to 3-dimensional arrays to Go slices, of any supported type ## Requirements * Go 1.24.7 or newer * Trino 372 or newer ## Installation You need a working environment with Go installed and $GOPATH set. Download and install the Trino `database/sql` driver: ```bash go get github.com/trinodb/trino-go-client/trino ``` Make sure you have Git installed and in your $PATH. ## Usage This Trino client is an implementation of Go's `database/sql/driver` interface. To use it, import the package and then use the [`database/sql`](https://golang.org/pkg/database/sql/) API. Use `trino` as `driverName` and a valid [DSN](#dsn-data-source-name) as the `dataSourceName`. Example: ```go import "database/sql" import _ "github.com/trinodb/trino-go-client/trino" dsn := "http://user@localhost:8080?catalog=default&schema=test" db, err := sql.Open("trino", dsn) ``` ### Authentication HTTP Basic, Kerberos, and JWT authentication are all supported. #### HTTP Basic authentication If the DSN contains a password, the client enables HTTP Basic authentication by setting the `Authorization` header in every request to Trino. HTTP Basic authentication **is only supported on encrypted connections over HTTPS**. 
#### Kerberos authentication This driver supports Kerberos authentication by setting up the Kerberos fields in the [Config](https://godoc.org/github.com/trinodb/trino-go-client/trino#Config) struct. Please refer to the [Coordinator Kerberos Authentication](https://trino.io/docs/current/security/server.html) documentation for server-side configuration. #### JSON web token authentication This driver supports JWT authentication by setting up the `AccessToken` field in the [Config](https://godoc.org/github.com/trinodb/trino-go-client/trino#Config) struct. Please refer to the [Coordinator JWT Authentication](https://trino.io/docs/current/security/jwt.html) documentation for server-side configuration. #### Authorization header forwarding This driver supports forwarding authorization headers by adding a [NamedArg](https://godoc.org/database/sql#NamedArg) with the name `accessToken` (e.g., `sql.Named("accessToken", "<your_access_token>")`) and setting the `ForwardAuthorizationHeader` field in the [Config](https://godoc.org/github.com/trinodb/trino-go-client/trino#Config) struct to `true`. When enabled, this configuration will override the `AccessToken` set in the `Config` struct. #### System access control and per-query user information It's possible to pass Trino user information that differs from the principal used to authenticate to the coordinator. See the [System Access Control](https://trino.io/docs/current/develop/system-access-control.html) documentation for details. To pass user information in queries to Trino, add a [NamedArg](https://godoc.org/database/sql#NamedArg) with the key `X-Trino-User` to the query parameters. This parameter is used by the driver to inform Trino about the user executing the query regardless of the authentication method for the actual connection, and its value is NOT passed to the query. 
Example: ```go db.Query("SELECT * FROM foobar WHERE id=?", 1, sql.Named("X-Trino-User", string("Alice"))) ``` The position of the X-Trino-User NamedArg is irrelevant and does not affect the query in any way. ### DSN (Data Source Name) The Data Source Name is a URL with a mandatory username, and optional query string parameters that are supported by this driver, in the following format: ``` http[s]://user[:pass]@host[:port][?parameters] ``` The easiest way to build your DSN is by using the [Config.FormatDSN](https://godoc.org/github.com/trinodb/trino-go-client/trino#Config.FormatDSN) helper function. The driver supports both HTTP and HTTPS. If you use HTTPS it's recommended that you also provide a custom `http.Client` that can validate (or skip) the security checks of the server certificate, and/or to configure TLS client authentication. #### Parameters *Parameters are case-sensitive* Refer to the [Trino Concepts](https://trino.io/docs/current/overview/concepts.html) documentation for more information. ##### `source` ``` Type: string Valid values: string describing the source of the connection to Trino Default: empty ``` The `source` parameter is optional, but if used, can help Trino admins troubleshoot queries and trace them back to the original client. ##### `catalog` ``` Type: string Valid values: the name of a catalog configured in the Trino server Default: empty ``` The `catalog` parameter defines the Trino catalog where schemas exist to organize tables. ##### `schema` ``` Type: string Valid values: the name of an existing schema in the catalog Default: empty ``` The `schema` parameter defines the Trino schema where tables exist. This is also known as namespace in some environments. ##### `session_properties` ``` Type: string Valid values: semicolon-separated list of key:value session properties Default: empty ``` The `session_properties` parameter must contain valid parameters accepted by the Trino server. Run `SHOW SESSION` in Trino to get the current list. 
##### `custom_client` ``` Type: string Valid values: the name of a client previously registered to the driver Default: empty (defaults to http.DefaultClient) ``` The `custom_client` parameter allows the use of custom `http.Client` for the communication with Trino. Register your custom client in the driver, then refer to it by name in the DSN, on the call to `sql.Open`: ```go foobarClient := &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, DualStack: true, }).DialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, TLSClientConfig: &tls.Config{ // your config here... }, }, } trino.RegisterCustomClient("foobar", foobarClient) db, err := sql.Open("trino", "https://user@localhost:8080?custom_client=foobar") ``` A custom client can also be used to add OpenTelemetry instrumentation. The [otelhttp](https://pkg.go.dev/go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp) package provides a transport wrapper that creates spans for HTTP requests and propagates the trace ID in HTTP headers: ```go otelClient := &http.Client{ Transport: otelhttp.NewTransport(http.DefaultTransport), } trino.RegisterCustomClient("otel", otelClient) db, err := sql.Open("trino", "https://user@localhost:8080?custom_client=otel") ``` ##### `query_timeout` ``` Type: time.Duration Valid values: duration string Default: nil ``` The `query_timeout` parameter sets a timeout for the query. If the query takes longer than the timeout, it will be cancelled. If it is not set the default context timeout will be used. ##### `explicitPrepare` ``` Type: string Valid values: "true", "false" Default: "true" ``` The `explicitPrepare` parameter controls how queries are sent to the Trino server. 
When set to `false`, the client uses `EXECUTE IMMEDIATE`, which sends the query text in the HTTP request body instead of HTTP headers. This allows sending large query text that would otherwise exceed HTTP header size limits. When set to `true` (default), queries use explicit prepared statements sent via HTTP headers. ##### `clientTags` ``` Type: string Valid values: comma-separated list of tags (e.g. tag1,tag2) Default: empty ``` The `clientTags` parameter is optional and is used to identify Trino resource groups. This helps with query tracking and resource management in Trino clusters. **DSN parameter example:** ``` clientTags=tag1,tag2 ``` **Config struct example:** ```go config := &Config{ ServerURI: "http://foobar@localhost:8080", ClientTags: []string{"tag1", "tag2", "tag3"}, } dsn, err := config.FormatDSN() ``` **Query parameter example (overrides DSN client tags):** ```go rows, err := db.Query(query, sql.Named("X-Trino-Client-Tags", "tag1,tag2,tag3")) ``` ##### `roles` ``` Type: string Format: roles=catalog1:role1;catalog2:role2 Valid values: A semicolon-separated list of catalog-to-role assignments, where each assignment maps a catalog to a role. Default: empty ``` The `roles` parameter defines authorization roles to assume for one or more catalogs during the Trino session. 
**Config struct example:** ```go c := &Config{ ServerURI: "https://foobar@localhost:8090", SessionProperties: map[string]string{"query_priority": "1"}, Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, } dsn, err := c.FormatDSN() ``` **Query parameter example (overrides DSN roles):** ```go rows, err := db.Query( query, sql.Named("X-Trino-Role", map[string]string{ "catalog1": "role1", "catalog2": "role2", }), ) ``` #### Examples ``` http://user@localhost:8080?source=hello&catalog=default&schema=foobar ``` ``` https://user@localhost:8443?session_properties=query_max_run_time:10m;query_priority:2 ``` ``` http://user@localhost:8080?source=hello&catalog=default&schema=foobar&roles=catalog1:role1;catalog2:role2 ``` ## Data types ### Query arguments When passing arguments to queries, the driver supports the following Go data types: * integers * `bool` * `string` * `[]byte` * slices * `trino.Numeric` - a string representation of a number * `time.Time` - passed to Trino as a timestamp with a time zone * the result of `trino.Date(year, month, day)` - passed to Trino as a date * the result of `trino.Time(hour, minute, second, nanosecond)` - passed to Trino as a time without a time zone * the result of `trino.TimeTz(hour, minute, second, nanosecond, location)` - passed to Trino as a time with a time zone * the result of `trino.Timestamp(year, month, day, hour, minute, second, nanosecond)` - passed to Trino as a timestamp without a time zone * `time.Duration` - passed to Trino as an interval day to second. Because Trino does not support nanosecond precision for intervals, if the nanosecond part of the value is not zero, an error will be returned. It's not yet possible to pass: * `float32` or `float64` * `byte` * `json.RawMessage` * maps To use the unsupported types, pass them as strings and use casts in the query, like so: ```sql SELECT * FROM table WHERE col_double = cast(? AS DOUBLE) OR col_timestamp = CAST(? 
AS TIMESTAMP) ``` ### Response rows When reading response rows, the driver supports most Trino data types, except: * time and timestamps with precision - all time types are returned as `time.Time`. All precisions up to nanoseconds (`TIMESTAMP(9)` or `TIME(9)`) are supported (since this is the maximum precision Go's `time.Time` supports). If a query returns columns defined with a greater precision, values are trimmed to 9 decimal digits. Use `CAST` to reduce the returned precision, or convert the value to a string that can then be parsed manually. * `DECIMAL` - returned as string * `IPADDRESS` - returned as string * `INTERVAL YEAR TO MONTH` and `INTERVAL DAY TO SECOND` - returned as string * `UUID` - returned as string Data types like `HyperLogLog`, `SetDigest`, `QDigest`, and `TDigest` are not supported and cannot be returned from a query. For reading nullable columns, use: * `trino.NullTime` * `trino.NullMap` - which stores a `map[string]interface{}` * similar structs from the `database/sql` package, like `sql.NullInt64` To read query results containing arrays or maps, pass one of the following structs to the `Scan()` function: * `trino.NullSliceBool` * `trino.NullSliceString` * `trino.NullSliceInt64` * `trino.NullSliceFloat64` * `trino.NullSliceTime` * `trino.NullSliceMap` For two- or three-dimensional arrays, use `trino.NullSlice2Bool` and `trino.NullSlice3Bool` or equivalents for other data types. To read `ROW` values, implement the `sql.Scanner` interface in a struct. Its `Scan()` function receives a `[]interface{}` slice, with values of the following types: * `bool` * `json.Number` for any numeric Trino types * `[]interface{}` for Trino arrays * `map[string]interface{}` for Trino maps * `string` for other Trino types, such as character, date, time, or timestamp. > [!NOTE] > `VARBINARY` columns are returned as base64-encoded strings when used within > `ROW`, `MAP`, or `ARRAY` values. 
## Spooling Protocol The client supports the [Trino spooling protocol](https://trino.io/docs/current/client/client-protocol.html#spooling-protocol), which enables efficient retrieval of large result sets by downloading data in segments, optionally in parallel and out-of-order. If the Trino server has the spooling protocol enabled, the client will use it by default with the `json` encoding. You can configure other encodings: - Supported encodings: `json`, `json+lz4`, `json+zstd` ```go rows, err := db.Query(query, sql.Named("encoding", "json+zstd")) ``` Or specify a list of supported encodings in order of preference: ```go rows, err := db.Query(query, sql.Named("encoding", "json+zstd, json+lz4, json")) ``` ### Configuration Options You can tune the spooling protocol using the following parameters, passed as `sql.Named` arguments to your query: - **Spooling Worker Count** `sql.Named("spooling_worker_count", "N")` Sets the number of parallel workers used to download spooled segments. **Default:** `5` **Considerations:** - Increasing this value can improve throughput for large result sets, especially on high-latency networks. - Higher values increase parallelism but may also increase memory usage. - **Max Out-of-Order Segments** `sql.Named("max_out_of_order_segments", "N")` Sets the maximum number of segments that can be downloaded and buffered out-of-order before blocking further downloads. **Default:** `10` **Considerations:** - Higher values increase the potential memory usage, but actual usage depends on download behavior and may be lower in practice. - Higher values reduce the chance that one slow or stalled segment will block the download of additional segments. - Lower values reduce memory usage but may limit parallelism and throughput. **Note:** It is **not allowed** to set `spooling_worker_count` higher than `max_out_of_order_segments` — doing so will result in an error. 
Each download worker must reserve a slot for the segment it fetches, and a slot is only released when that segment can be processed in order. The total number of slots corresponds to max_out_of_order_segments. If you configure more workers than allowed out-of-order segments, the extra workers would immediately block while waiting for a slot — defeating the purpose of parallelism and potentially wasting resources. #### Example: Customizing Spooling Parameters ```go rows, err := db.Query( query, sql.Named("encoding", "json+zstd"), sql.Named("spooling_worker_count", "8"), sql.Named("max_out_of_order_segments", "20"), ) ``` ## License Apache License V2.0, as described in the [LICENSE](./LICENSE) file. ## Build You can build the client code locally and run tests with the following command: ``` go test -v -race -timeout 2m ./... ``` ## Contributing For contributing, development, and release guidelines, see [CONTRIBUTING.md](./CONTRIBUTING.md). ================================================ FILE: go.mod ================================================ module github.com/trinodb/trino-go-client go 1.24.7 require ( github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26 github.com/aws/aws-sdk-go v1.55.8 github.com/aws/aws-sdk-go-v2/config v1.31.8 github.com/aws/aws-sdk-go-v2/credentials v1.18.12 github.com/aws/aws-sdk-go-v2/service/s3 v1.88.1 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/btree v1.1.3 github.com/jcmturner/gokrb5/v8 v8.4.4 github.com/klauspost/compress v1.18.1 github.com/ory/dockertest/v3 v3.12.0 github.com/pierrec/lz4 v2.6.1+incompatible github.com/stretchr/testify v1.11.1 ) require ( dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26 // indirect github.com/aws/aws-sdk-go-v2 v1.39.0 // indirect 
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.7 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.7 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.7 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.7 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.7 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.29.3 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.4 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.38.4 // indirect github.com/aws/smithy-go v1.23.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/containerd/continuity v0.4.5 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/docker/cli v28.4.0+incompatible // indirect github.com/docker/docker v28.4.0+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/frankban/quicktest v1.14.6 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/jcmturner/aescts/v2 v2.0.0 // indirect github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect github.com/jcmturner/gofork v1.7.6 // indirect github.com/jcmturner/goidentity/v6 v6.0.1 // indirect github.com/jcmturner/rpc/v2 v2.0.3 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/sys/user v0.4.0 // indirect github.com/moby/term v0.5.2 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect 
github.com/opencontainers/image-spec v1.1.1 // indirect github.com/opencontainers/runc v1.3.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect golang.org/x/crypto v0.45.0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sys v0.38.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) ================================================ FILE: go.sum ================================================ dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg= github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:pzStYMLAXM7CNQjS/Wn+zK9MUxDhSUNfVvnHsyQyjs0= github.com/ahmetalpbalkan/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ilK+u7u1HoqaDk0mjhh27QJB7PyWMreGffEvOCoEKiY= github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26 
h1:3YVZUqkoev4mL+aCwVOSWV4M7pN+NURHL38Z2zq5JKA= github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ymXt5bw5uSNu4jveerFxE0vNYxF8ncqbptntMaFMg3k= github.com/aws/aws-sdk-go v1.55.8 h1:JRmEUbU52aJQZ2AjX4q4Wu7t4uZjOu71uyNmaWlUkJQ= github.com/aws/aws-sdk-go v1.55.8/go.mod h1:ZkViS9AqA6otK+JBBNH2++sx1sgxrPKcSzPPvQkUtXk= github.com/aws/aws-sdk-go-v2 v1.39.0 h1:xm5WV/2L4emMRmMjHFykqiA4M/ra0DJVSWUkDyBjbg4= github.com/aws/aws-sdk-go-v2 v1.39.0/go.mod h1:sDioUELIUO9Znk23YVmIk86/9DOpkbyyVb1i/gUNFXY= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1 h1:i8p8P4diljCr60PpJp6qZXNlgX4m2yQFpYk+9ZT+J4E= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.1/go.mod h1:ddqbooRZYNoJ2dsTwOty16rM+/Aqmk/GOXrK8cg7V00= github.com/aws/aws-sdk-go-v2/config v1.31.8 h1:kQjtOLlTU4m4A64TsRcqwNChhGCwaPBt+zCQt/oWsHU= github.com/aws/aws-sdk-go-v2/config v1.31.8/go.mod h1:QPpc7IgljrKwH0+E6/KolCgr4WPLerURiU592AYzfSY= github.com/aws/aws-sdk-go-v2/credentials v1.18.12 h1:zmc9e1q90wMn8wQbjryy8IwA6Q4XlaL9Bx2zIqdNNbk= github.com/aws/aws-sdk-go-v2/credentials v1.18.12/go.mod h1:3VzdRDR5u3sSJRI4kYcOSIBbeYsgtVk7dG5R/U6qLWY= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.7 h1:Is2tPmieqGS2edBnmOJIbdvOA6Op+rRpaYR60iBAwXM= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.7/go.mod h1:F1i5V5421EGci570yABvpIXgRIBPb5JM+lSkHF6Dq5w= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.7 h1:UCxq0X9O3xrlENdKf1r9eRJoKz/b0AfGkpp3a7FPlhg= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.7/go.mod h1:rHRoJUNUASj5Z/0eqI4w32vKvC7atoWR0jC+IkmVH8k= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.7 h1:Y6DTZUn7ZUC4th9FMBbo8LVE+1fyq3ofw+tRwkUd3PY= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.7/go.mod h1:x3XE6vMnU9QvHN/Wrx2s44kwzV2o2g5x/siw4ZUJ9g8= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo= 
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.7 h1:BszAktdUo2xlzmYHjWMq70DqJ7cROM8iBd3f6hrpuMQ= github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.7/go.mod h1:XJ1yHki/P7ZPuG4fd3f0Pg/dSGA2cTQBCLw82MH2H48= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1 h1:oegbebPEMA/1Jny7kvwejowCaHz1FWZAQ94WXFNCyTM= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.1/go.mod h1:kemo5Myr9ac0U9JfSjMo9yHLtw+pECEHsFtJ9tqCEI8= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.7 h1:zmZ8qvtE9chfhBPuKB2aQFxW5F/rpwXUgmcVCgQzqRw= github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.8.7/go.mod h1:vVYfbpd2l+pKqlSIDIOgouxNsGu5il9uDp0ooWb0jys= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.7 h1:mLgc5QIgOy26qyh5bvW+nDoAppxgn3J2WV3m9ewq7+8= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.7/go.mod h1:wXb/eQnqt8mDQIQTTmcw58B5mYGxzLGZGK8PWNFZ0BA= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.7 h1:u3VbDKUCWarWiU+aIUK4gjTr/wQFXV17y3hgNno9fcA= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.7/go.mod h1:/OuMQwhSyRapYxq6ZNpPer8juGNrB4P5Oz8bZ2cgjQE= github.com/aws/aws-sdk-go-v2/service/s3 v1.88.1 h1:+RpGuaQ72qnU83qBKVwxkznewEdAGhIWo/PQCmkhhog= github.com/aws/aws-sdk-go-v2/service/s3 v1.88.1/go.mod h1:xajPTguLoeQMAOE44AAP2RQoUhF8ey1g5IFHARv71po= github.com/aws/aws-sdk-go-v2/service/sso v1.29.3 h1:7PKX3VYsZ8LUWceVRuv0+PU+E7OtQb1lgmi5vmUE9CM= github.com/aws/aws-sdk-go-v2/service/sso v1.29.3/go.mod h1:Ql6jE9kyyWI5JHn+61UT/Y5Z0oyVJGmgmJbZD5g4unY= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.4 h1:e0XBRn3AptQotkyBFrHAxFB8mDhAIOfsG+7KyJ0dg98= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.34.4/go.mod h1:XclEty74bsGBCr1s0VSaA11hQ4ZidK4viWK7rRfO88I= github.com/aws/aws-sdk-go-v2/service/sts v1.38.4 h1:PR00NXRYgY4FWHqOGx3fC3lhVKjsp1GdloDv2ynMSd8= github.com/aws/aws-sdk-go-v2/service/sts v1.38.4/go.mod h1:Z+Gd23v97pX9zK97+tX4ppAgqCt3Z2dIXB02CtBncK8= github.com/aws/smithy-go v1.23.0 
h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE= github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/containerd/continuity v0.4.5 h1:ZRoN1sXq9u7V6QoHMcVWGhOwDFqZ4B9i5H6un1Wh0x4= github.com/containerd/continuity v0.4.5/go.mod h1:/lNJvtJKUQStBzpVQ1+rasXO1LAWtUQssk28EZvJ3nE= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docker/cli v28.4.0+incompatible h1:RBcf3Kjw2pMtwui5V0DIMdyeab8glEw5QY0UUU4C9kY= github.com/docker/cli v28.4.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/docker v28.4.0+incompatible h1:KVC7bz5zJY/4AZe/78BIvCnPsLaC9T/zh72xnlrTTOk= github.com/docker/docker v28.4.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-sql-driver/mysql v1.8.1 
h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/dnsutils/v2 v2.0.0 
h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co= github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs= github.com/moby/sys/user v0.4.0/go.mod 
h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs= github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= github.com/opencontainers/runc v1.3.1 h1:c/yY0oh2wK7tzDuD56REnSxyU8ubh8hoAIOLGLrm4SM= github.com/opencontainers/runc v1.3.1/go.mod h1:9wbWt42gV+KRxKRVVugNP6D5+PQciRbenB4fLVsqGPs= github.com/ory/dockertest/v3 v3.12.0 h1:3oV9d0sDzlSQfHtIaB5k6ghUCVMVLpAY8hwrqoCyRCw= github.com/ory/dockertest/v3 v3.12.0/go.mod h1:aKNDTva3cp8dwOWwb9cWuX84aH5akkxXRvO7KCwWVjE= github.com/pierrec/lz4 v2.6.1+incompatible h1:9UY3+iC23yxF0UfGaYrGplQ+79Rg+h/q9FV9ix19jjM= github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod 
h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb h1:zGWFAtiMcyryUHoUjUJX0/lt1H2+i2Ka2n+D3DImSNo= github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 h1:EzJWgHovont7NscjpAxXsDA8S8BMYve8Y5+7cuRE7R0= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0 h1:LhYJRs+L4fBtjZUfuSZIKGeVu0QRy8e5Xi7D17UxZ74= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod 
h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= gotest.tools/v3 
v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= ================================================ FILE: trino/etc/catalog/hive.properties ================================================ connector.name=hive hive.metastore=file hive.metastore.catalog.dir=/tmp/metastore hive.security=sql-standard fs.hadoop.enabled=true ================================================ FILE: trino/etc/catalog/memory.properties ================================================ connector.name=memory ================================================ FILE: trino/etc/catalog/tpch.properties ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved connector.name=tpch ================================================ FILE: trino/etc/config-pre-466version.properties ================================================ coordinator=true node-scheduler.include-coordinator=true http-server.http.port=8080 discovery-server.enabled=true discovery.uri=http://localhost:8080 http-server.authentication.type=PASSWORD,JWT http-server.authentication.jwt.key-file=/etc/trino/secrets/public_key.pem http-server.https.enabled=true http-server.https.port=8443 http-server.authentication.allow-insecure-over-http=true http-server.https.keystore.path=/etc/trino/secrets/certificate_with_key.pem internal-communication.shared-secret=gotrino query.max-length=5000043 ================================================ FILE: trino/etc/config-pre-477version.properties ================================================ # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved coordinator=true node-scheduler.include-coordinator=true http-server.http.port=8080 discovery-server.enabled=true discovery.uri=http://localhost:8080 http-server.authentication.type=PASSWORD,JWT http-server.authentication.jwt.key-file=/etc/trino/secrets/public_key.pem http-server.https.enabled=true http-server.https.port=8443 http-server.authentication.allow-insecure-over-http=true http-server.https.keystore.path=/etc/trino/secrets/certificate_with_key.pem internal-communication.shared-secret=gotrino query.max-length=5000043 ## spooling protocol settings protocol.spooling.enabled=true protocol.spooling.shared-secret-key=jxTKysfCBuMZtFqUf8UJDQ1w9ez8rynEJsJqgJf66u0= protocol.spooling.retrieval-mode=coordinator_proxy # Max number of rows to inline per worker # If the number of rows exceeds this threshold, spooled segments will be returned. # If the number of rows is within this threshold and the max size is below the max-size threshold, # inline segments will be returned. protocol.spooling.inlining.max-rows=1000 # Max size of rows to inline per worker # If the total size of the rows exceeds this threshold, spooled segments will be returned. # If the total size of the rows is within this threshold and the row count is below the max-rows threshold, # inline segments will be returned. protocol.spooling.inlining.max-size=128kB ================================================ FILE: trino/etc/config.properties ================================================ # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved coordinator=true node-scheduler.include-coordinator=true http-server.http.port=8080 http-server.authentication.type=PASSWORD,JWT http-server.authentication.jwt.key-file=/etc/trino/secrets/public_key.pem http-server.https.enabled=true http-server.https.port=8443 http-server.authentication.allow-insecure-over-http=true http-server.https.keystore.path=/etc/trino/secrets/certificate_with_key.pem internal-communication.shared-secret=gotrino query.max-length=5000043 ## spooling protocol settings protocol.spooling.enabled=true protocol.spooling.shared-secret-key=jxTKysfCBuMZtFqUf8UJDQ1w9ez8rynEJsJqgJf66u0= protocol.spooling.retrieval-mode=coordinator_proxy # Max number of rows to inline per worker # If the number of rows exceeds this threshold, spooled segments will be returned. # If the number of rows is within this threshold and the max size is below the max-size threshold, # inline segments will be returned. protocol.spooling.inlining.max-rows=1000 # Max size of rows to inline per worker # If the total size of the rows exceeds this threshold, spooled segments will be returned. # If the total size of the rows is within this threshold and the row count is below the max-rows threshold, # inline segments will be returned. protocol.spooling.inlining.max-size=128kB ================================================ FILE: trino/etc/jvm.config ================================================ -Xmx4G -XX:+UseG1GC -XX:G1HeapRegionSize=32M -XX:+UseGCOverheadLimit -XX:+ExplicitGCInvokesConcurrent -XX:+ExitOnOutOfMemoryError -Djdk.attach.allowAttachSelf=true -Djdk.nio.maxCachedBufferSize=2000000 ================================================ FILE: trino/etc/node.properties ================================================ # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved node.environment=test node.id=test node.data-dir=/data/trino ================================================ FILE: trino/etc/password-authenticator.properties ================================================ password-authenticator.name=file file.password-file=/etc/trino/secrets/password.db ================================================ FILE: trino/etc/secrets/.gitignore ================================================ *.pem ================================================ FILE: trino/etc/spooling-manager.properties ================================================ spooling-manager.name=filesystem fs.s3.enabled=true fs.location=s3://spooling/ s3.endpoint=http://localstack:4566/ s3.region=us-east-1 s3.aws-access-key=test s3.aws-secret-key=test s3.path-style-access=true ================================================ FILE: trino/integration_test.go ================================================ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package trino import ( "bytes" "context" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "database/sql" "database/sql/driver" "encoding/json" "encoding/pem" "errors" "flag" "fmt" "io" "log" "math" "math/big" "net/http" "net/url" "os" "reflect" "strconv" "strings" "testing" "time" "github.com/ahmetb/dlog" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go/aws" "github.com/golang-jwt/jwt/v5" dt "github.com/ory/dockertest/v3" docker "github.com/ory/dockertest/v3/docker" "github.com/stretchr/testify/require" ) const ( DockerLocalStackName = "localstack" bucketName = "spooling" DockerTrinoName = "trino-go-client-tests" MAXRetries = 10 TrinoNetwork = "trino-network" ) var ( pool *dt.Pool trinoResource *dt.Resource localStackResource *dt.Resource spoolingProtocolSupported bool trinoImageTagFlag = flag.String( "trino_image_tag", os.Getenv("TRINO_IMAGE_TAG"), "Docker image tag used for the Trino server container", ) integrationServerFlag = flag.String( "trino_server_dsn", os.Getenv("TRINO_SERVER_DSN"), "dsn of the Trino server used for integration tests instead of starting a Docker container", ) integrationServerQueryTimeout = flag.Duration( "trino_query_timeout", 5*time.Second, "max duration for Trino queries to run before giving up", ) noCleanup = flag.Bool( "no_cleanup", false, "do not delete containers on exit", ) tlsServer = "" ) func TestMain(m *testing.M) { flag.Parse() DefaultQueryTimeout = *integrationServerQueryTimeout DefaultCancelQueryTimeout = *integrationServerQueryTimeout if *trinoImageTagFlag == "" { *trinoImageTagFlag = "latest" } if *trinoImageTagFlag == "latest" { spoolingProtocolSupported = true } else { version, err := strconv.Atoi(*trinoImageTagFlag) if err != nil { log.Fatalf("Invalid trino_image_tag: %s", *trinoImageTagFlag) } spoolingProtocolSupported = version >= 466 } var err error if *integrationServerFlag == "" && 
!testing.Short() { pool, err = dt.NewPool("") if err != nil { log.Fatalf("Could not connect to docker: %s", err) } pool.MaxWait = 1 * time.Minute networkID := getOrCreateNetwork(pool) wd, err := os.Getwd() if err != nil { log.Fatalf("Failed to get working directory: %s", err) } var ok bool if spoolingProtocolSupported { localStackResource = getOrCreateLocalStack(pool, networkID) } trinoResource, ok = pool.ContainerByName(DockerTrinoName) if !ok { err = generateCerts(wd + "/etc/secrets") if err != nil { log.Fatalf("Could not generate TLS certificates: %s", err) } mounts := []string{ wd + "/etc/secrets:/etc/trino/secrets", wd + "/etc/jvm.config:/etc/trino/jvm.config", wd + "/etc/node.properties:/etc/trino/node.properties", wd + "/etc/password-authenticator.properties:/etc/trino/password-authenticator.properties", wd + "/etc/catalog/memory.properties:/etc/trino/catalog/memory.properties", wd + "/etc/catalog/tpch.properties:/etc/trino/catalog/tpch.properties", } version, err := strconv.Atoi(*trinoImageTagFlag) if (err != nil && *trinoImageTagFlag == "latest") || (err == nil && version >= 458) { mounts = append(mounts, wd+"/etc/catalog/hive.properties:/etc/trino/catalog/hive.properties") } if spoolingProtocolSupported { version, err := strconv.Atoi(*trinoImageTagFlag) if (err != nil && *trinoImageTagFlag != "latest") || (err == nil && version < 477) { mounts = append(mounts, wd+"/etc/config-pre-477version.properties:/etc/trino/config.properties") } else { mounts = append(mounts, wd+"/etc/config.properties:/etc/trino/config.properties") } mounts = append(mounts, wd+"/etc/spooling-manager.properties:/etc/trino/spooling-manager.properties") } else { mounts = append(mounts, wd+"/etc/config-pre-466version.properties:/etc/trino/config.properties") } trinoResource, err = pool.RunWithOptions(&dt.RunOptions{ Name: DockerTrinoName, Repository: "trinodb/trino", Tag: *trinoImageTagFlag, Mounts: mounts, ExposedPorts: []string{ "8080/tcp", "8443/tcp", }, NetworkID: networkID, }, 
func(hc *docker.HostConfig) { hc.Ulimits = []docker.ULimit{ { Name: "nofile", Hard: 4096, Soft: 4096, }, } }) if err != nil { log.Fatalf("Could not start resource: %s", err) } } else if !trinoResource.Container.State.Running { pool.Client.StartContainer(trinoResource.Container.ID, nil) } waitForContainerHealth(trinoResource.Container.ID, "trino") err = grantAdminRoleToTestUser() if err != nil { log.Fatalf("Warning: Failed to grant admin role to test user: %s", err) } *integrationServerFlag = "http://test@localhost:" + trinoResource.GetPort("8080/tcp") tlsServer = "https://admin:admin@localhost:" + trinoResource.GetPort("8443/tcp") http.DefaultTransport.(*http.Transport).TLSClientConfig, err = getTLSConfig(wd + "/etc/secrets") if err != nil { log.Fatalf("Failed to set the default TLS config: %s", err) } } code := m.Run() if !*noCleanup && pool != nil { if trinoResource != nil { if err := pool.Purge(trinoResource); err != nil { log.Fatalf("Could not purge resource: %s", err) } } if localStackResource != nil { if err := pool.Purge(localStackResource); err != nil { log.Fatalf("Could not purge LocalStack resource: %s", err) } } networkExists, networkID, err := networkExists(pool, TrinoNetwork) if err == nil && networkExists { if err := pool.Client.RemoveNetwork(networkID); err != nil { log.Fatalf("Could not remove Docker network: %s", err) } } } os.Exit(code) } func grantAdminRoleToTestUser() error { grantSQL := "SET ROLE admin IN hive; GRANT admin TO USER test IN hive;" execCmd := []string{ "trino", "--user", "admin", "--execute", grantSQL, } exec, err := pool.Client.CreateExec(docker.CreateExecOptions{ Container: trinoResource.Container.ID, Cmd: execCmd, }) if err != nil { log.Printf("Warning: Failed to create exec for GRANT: %s", err) } else { var stdout, stderr bytes.Buffer err = pool.Client.StartExec(exec.ID, docker.StartExecOptions{ Detach: false, OutputStream: &stdout, ErrorStream: &stderr, }) if err != nil { log.Printf("Warning: Failed to execute GRANT: %s", 
err) } } return err } func getOrCreateLocalStack(pool *dt.Pool, networkID string) *dt.Resource { resource, ok := pool.ContainerByName(DockerLocalStackName) if ok { return resource } newResource, err := setupLocalStack(pool, networkID) if err != nil { log.Fatalf("Failed to start LocalStack: %s", err) } return newResource } func getOrCreateNetwork(pool *dt.Pool) string { networkExists, networkID, err := networkExists(pool, TrinoNetwork) if err != nil { log.Fatalf("Could not check if Docker network exists: %s", err) } if networkExists { return networkID } network, err := pool.Client.CreateNetwork(docker.CreateNetworkOptions{ Name: TrinoNetwork, }) if err != nil { log.Fatalf("Could not create Docker network: %s", err) } return network.ID } func networkExists(pool *dt.Pool, networkName string) (bool, string, error) { networks, err := pool.Client.ListNetworks() if err != nil { return false, "", fmt.Errorf("could not list Docker networks: %w", err) } for _, network := range networks { if network.Name == networkName { return true, network.ID, nil } } return false, "", nil } func setupLocalStack(pool *dt.Pool, networkID string) (*dt.Resource, error) { localstackResource, err := pool.RunWithOptions(&dt.RunOptions{ Name: DockerLocalStackName, Repository: "localstack/localstack", Tag: "latest", Env: []string{ "SERVICES=s3", "region_name=us-east-1", "AWS_ACCESS_KEY_ID=test", "AWS_SECRET_ACCESS_KEY=test", }, PortBindings: map[docker.Port][]docker.PortBinding{ "4566/tcp": {{HostIP: "0.0.0.0", HostPort: "4566"}}, "4571/tcp": {{HostIP: "0.0.0.0", HostPort: "4571"}}, }, NetworkID: networkID, }) if err != nil { return nil, fmt.Errorf("could not start LocalStack: %w", err) } localstackPort := localstackResource.GetPort("4566/tcp") s3Endpoint := "http://localhost:" + localstackPort log.Println("LocalStack started at:", s3Endpoint) waitForContainerHealth(localstackResource.Container.ID, "localstack") for retry := 0; retry < MAXRetries; retry++ { err := createS3Bucket(s3Endpoint, "test", 
"test", bucketName) if err == nil { log.Println("S3 bucket created successfully") return localstackResource, nil } log.Printf("Failed to create S3 bucket, retrying... (%d/%d)\n", retry+1, MAXRetries) time.Sleep(2 * time.Second) } return nil, fmt.Errorf("failed to create S3 bucket after multiple attempts: %w", err) } func createS3Bucket(endpoint, accessKey, secretKey, bucketName string) error { cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion("us-east-1"), config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")), ) if err != nil { return fmt.Errorf("failed to load AWS config: %w", err) } s3Client := s3.New(s3.Options{ Credentials: cfg.Credentials, Region: "us-east-1", BaseEndpoint: &endpoint, UsePathStyle: *aws.Bool(true), }) createBucketInput := &s3.CreateBucketInput{ Bucket: aws.String(bucketName), } _, err = s3Client.CreateBucket(context.TODO(), createBucketInput) if err != nil { return fmt.Errorf("failed to create S3 bucket: %w", err) } log.Printf("Bucket %s created successfully!", bucketName) return nil } func waitForContainerHealth(containerID, containerName string) { if err := pool.Retry(func() error { c, err := pool.Client.InspectContainer(containerID) if err != nil { log.Fatalf("Failed to inspect container %s: %s", containerID, err) } if !c.State.Running { log.Fatalf("Container %s is not running: %s\nContainer logs:\n%s", containerID, c.State.String(), getLogs(trinoResource.Container.ID)) } log.Printf("Waiting for %s container: %s\n", containerName, c.State.String()) if c.State.Health.Status != "healthy" { return errors.New("Not ready") } return nil }); err != nil { log.Fatalf("Timed out waiting for container %s to get ready: %s\nContainer logs:\n%s", containerName, err, getLogs(containerID)) } } func generateCerts(dir string) error { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return fmt.Errorf("failed to generate private key: %w", err) } serialNumberLimit := 
new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return fmt.Errorf("failed to generate serial number: %w", err) } template := x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{ Organization: []string{"Trino Software Foundation"}, }, DNSNames: []string{"localhost"}, NotBefore: time.Now(), NotAfter: time.Now().Add(1 * time.Hour), KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, } privBytes, err := x509.MarshalPKCS8PrivateKey(priv) if err != nil { return fmt.Errorf("unable to marshal private key: %w", err) } privBlock := &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes} err = writePEM(dir+"/private_key.pem", privBlock) if err != nil { return err } pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey) if err != nil { return fmt.Errorf("unable to marshal public key: %w", err) } pubBlock := &pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes} err = writePEM(dir+"/public_key.pem", pubBlock) if err != nil { return err } certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { return fmt.Errorf("failed to create certificate: %w", err) } certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: certBytes} err = writePEM(dir+"/certificate.pem", certBlock) if err != nil { return err } err = writePEM(dir+"/certificate_with_key.pem", certBlock, privBlock, pubBlock) if err != nil { return err } return nil } func writePEM(filename string, blocks ...*pem.Block) error { // all files are world-readable, so they can be read inside the Trino container out, err := os.Create(filename) if err != nil { return fmt.Errorf("failed to open %s for writing: %w", filename, err) } for _, block := range blocks { if err := pem.Encode(out, block); err != nil { return fmt.Errorf("failed to write %s data to %s: %w", block.Type, filename, err) } } if err := 
out.Close(); err != nil { return fmt.Errorf("error closing %s: %w", filename, err) } return nil }

// getTLSConfig builds a TLS client configuration that trusts the system root
// certificates plus the certificate stored in dir/certificate.pem.
func getTLSConfig(dir string) (*tls.Config, error) {
	certPool, err := x509.SystemCertPool()
	if err != nil {
		return nil, fmt.Errorf("failed to read the system cert pool: %w", err)
	}
	caCertPEM, err := os.ReadFile(dir + "/certificate.pem")
	if err != nil {
		return nil, fmt.Errorf("failed to read the certificate: %w", err)
	}
	// AppendCertsFromPEM only reports success or failure; there is no error
	// value to wrap here (err is nil at this point).
	if !certPool.AppendCertsFromPEM(caCertPEM) {
		return nil, errors.New("failed to parse the certificate: no valid PEM certificate found")
	}
	return &tls.Config{
		RootCAs: certPool,
	}, nil
}

// getLogs returns the combined stdout/stderr output of the container with
// the given id. It is used for diagnostics only, so a failure to fetch the
// logs is returned as the log text instead of being raised.
func getLogs(id string) []byte {
	var buf bytes.Buffer
	err := pool.Client.Logs(docker.LogsOptions{
		Container:    id,
		OutputStream: &buf,
		ErrorStream:  &buf,
		Stdout:       true,
		Stderr:       true,
		RawTerminal:  true,
	})
	if err != nil {
		return []byte(fmt.Sprintf("failed to fetch container logs: %s", err))
	}
	// Best effort: io.ReadAll returns whatever was decoded before an error,
	// which is still useful diagnostic output.
	logs, _ := io.ReadAll(dlog.NewReader(&buf))
	return logs
}

// integrationOpen opens a connection to the integration test server. The
// optional first dsn argument overrides the default server DSN. The calling
// test is skipped in short mode.
func integrationOpen(t *testing.T, dsn ...string) *sql.DB {
	t.Helper()
	if testing.Short() {
		t.Skip("Skipping test in short mode.")
	}
	target := *integrationServerFlag
	if len(dsn) > 0 {
		target = dsn[0]
	}
	db, err := sql.Open("trino", target)
	if err != nil {
		t.Fatal(err)
	}
	return db
}

// integration tests based on python tests:
// https://github.com/trinodb/trino-python-client/tree/master/integration_tests

// nodesRow holds one row of system.runtime.nodes.
type nodesRow struct {
	NodeID      string
	HTTPURI     string
	NodeVersion string
	Coordinator bool
	State       string
}

// TestIntegrationSelectQueryIterator iterates over system.runtime.nodes and
// expects at least one row, each with node_id "test".
func TestIntegrationSelectQueryIterator(t *testing.T) {
	db := integrationOpen(t)
	defer db.Close()
	rows, err := db.Query("SELECT * FROM system.runtime.nodes")
	if err != nil {
		t.Fatal(err)
	}
	defer rows.Close()
	count := 0
	for rows.Next() {
		count++
		var col nodesRow
		err = rows.Scan(
			&col.NodeID,
			&col.HTTPURI,
			&col.NodeVersion,
			&col.Coordinator,
			&col.State,
		)
		if err != nil {
			t.Fatal(err)
		}
		if col.NodeID != "test" {
			t.Errorf("Expected node_id == test but got %s", col.NodeID)
		}
	}
	if err = rows.Err(); err != nil {
		t.Fatal(err)
	}
	if count < 1 {
		t.Error("no rows returned")
	}
}

// TestIntegrationSelectQueryNoResult checks that scanning a single row from
// a query that matches nothing yields an error.
func TestIntegrationSelectQueryNoResult(t *testing.T) {
	db := integrationOpen(t)
	defer db.Close()
	row := db.QueryRow("SELECT * FROM system.runtime.nodes where false")
	var col nodesRow
	err := row.Scan(
		&col.NodeID,
		&col.HTTPURI,
		&col.NodeVersion,
		&col.Coordinator,
		&col.State,
	)
	if err == nil {
		t.Fatalf("unexpected query returning data: %+v", col)
	}
}

// TestIntegrationSelectFailedQuery queries a non-existent catalog and checks
// that the failure is an *ErrQueryFailed wrapping an *ErrTrino carrying the
// expected error metadata.
func TestIntegrationSelectFailedQuery(t *testing.T) {
	db := integrationOpen(t)
	defer db.Close()
	rows, err := db.Query("SELECT * FROM catalog.schema.do_not_exist")
	if err == nil {
		rows.Close()
		t.Fatal("query to invalid catalog succeeded")
	}
	queryFailed, ok := err.(*ErrQueryFailed)
	if !ok {
		t.Fatal("unexpected error:", err)
	}
	trinoErr, ok := errors.Unwrap(queryFailed).(*ErrTrino)
	if !ok {
		// trinoErr is nil when the assertion fails; report the real error.
		t.Fatal("unexpected error:", err)
	}
	expected := ErrTrino{
		Message:   "line 1:15: Catalog 'catalog'",
		SqlState:  "",
		ErrorCode: 44,
		ErrorName: "CATALOG_NOT_FOUND",
		ErrorType: "USER_ERROR",
		ErrorLocation: ErrorLocation{
			LineNumber:   1,
			ColumnNumber: 15,
		},
		FailureInfo: FailureInfo{
			Type:    "io.trino.spi.TrinoException",
			Message: "line 1:15: Catalog 'catalog'",
		},
	}
	// Messages are compared by prefix only; the rest of the text is not pinned.
	if !strings.HasPrefix(trinoErr.Message, expected.Message) {
		t.Fatalf("expected ErrTrino.Message to start with `%s`, got: %s", expected.Message, trinoErr.Message)
	}
	if trinoErr.SqlState != expected.SqlState {
		t.Fatalf("expected ErrTrino.SqlState to be `%s`, got: %s", expected.SqlState, trinoErr.SqlState)
	}
	if trinoErr.ErrorCode != expected.ErrorCode {
		t.Fatalf("expected ErrTrino.ErrorCode to be `%d`, got: %d", expected.ErrorCode, trinoErr.ErrorCode)
	}
	if trinoErr.ErrorName != expected.ErrorName {
		t.Fatalf("expected ErrTrino.ErrorName to be `%s`, got: %s", expected.ErrorName, trinoErr.ErrorName)
	}
	if trinoErr.ErrorType != expected.ErrorType {
		t.Fatalf("expected ErrTrino.ErrorType to be `%s`, got: %s", expected.ErrorType, trinoErr.ErrorType)
	}
	if trinoErr.ErrorLocation.LineNumber != expected.ErrorLocation.LineNumber {
		t.Fatalf("expected ErrTrino.ErrorLocation.LineNumber to be `%d`, got: %d", expected.ErrorLocation.LineNumber, trinoErr.ErrorLocation.LineNumber)
	}
	if trinoErr.ErrorLocation.ColumnNumber != expected.ErrorLocation.ColumnNumber {
		t.Fatalf("expected ErrTrino.ErrorLocation.ColumnNumber to be `%d`, got: %d", expected.ErrorLocation.ColumnNumber, trinoErr.ErrorLocation.ColumnNumber)
	}
	if trinoErr.FailureInfo.Type != expected.FailureInfo.Type {
		t.Fatalf("expected ErrTrino.FailureInfo.Type to be `%s`, got: %s", expected.FailureInfo.Type, trinoErr.FailureInfo.Type)
	}
	if !strings.HasPrefix(trinoErr.FailureInfo.Message, expected.FailureInfo.Message) {
		t.Fatalf("expected ErrTrino.FailureInfo.Message to start with `%s`, got: %s", expected.FailureInfo.Message, trinoErr.FailureInfo.Message)
	}
}

// tpchRow holds one row of tpch.sf1.customer.
type tpchRow struct {
	CustKey    int
	Name       string
	Address    string
	NationKey  int
	Phone      string
	AcctBal    float64
	MktSegment string
	Comment    string
}

// TestIntegrationSelectTpch1000 reads 1000 customer rows and checks that all
// of them are scanned without error.
func TestIntegrationSelectTpch1000(t *testing.T) {
	db := integrationOpen(t)
	defer db.Close()
	rows, err := db.Query("SELECT * FROM tpch.sf1.customer LIMIT 1000")
	if err != nil {
		t.Fatal(err)
	}
	defer rows.Close()
	count := 0
	for rows.Next() {
		count++
		var col tpchRow
		err = rows.Scan(
			&col.CustKey,
			&col.Name,
			&col.Address,
			&col.NationKey,
			&col.Phone,
			&col.AcctBal,
			&col.MktSegment,
			&col.Comment,
		)
		if err != nil {
			t.Fatal(err)
		}
		/*
			if col.CustKey == 1 && col.AcctBal != 711.56 {
				t.Fatal("unexpected acctbal for custkey=1:", col.AcctBal)
			}
		*/
	}
	// Report the iteration error itself, not the stale err left by the last
	// successful Scan.
	if err = rows.Err(); err != nil {
		t.Fatal(err)
	}
	if count != 1000 {
		t.Fatal("not enough rows returned:", count)
	}
}

// TestIntegrationSelectCancelQuery runs a query under a 200ms deadline and
// expects it to fail with an error mentioning "cancel" or "deadline".
func TestIntegrationSelectCancelQuery(t *testing.T) {
	db := integrationOpen(t)
	defer db.Close()
	deadline := time.Now().Add(200 * time.Millisecond)
	ctx, cancel := context.WithDeadline(context.Background(), deadline)
	defer cancel()
	rows, err := db.QueryContext(ctx, "SELECT * FROM tpch.sf1.customer")
	if err != nil {
		// The deadline may already fire while the query is being submitted.
		goto handleErr
	}
	defer rows.Close()
	for rows.Next() {
		var col tpchRow
		err = rows.Scan(
			&col.CustKey,
			&col.Name,
			&col.Address,
			&col.NationKey,
			&col.Phone,
			&col.AcctBal,
			&col.MktSegment,
			&col.Comment,
		)
		if err != nil {
			break
		}
	}
	if err = rows.Err(); err == nil {
		t.Fatal("unexpected query with deadline succeeded")
	}
handleErr:
	errmsg := err.Error()
	for _, msg := range []string{"cancel", "deadline"} {
		if strings.Contains(errmsg, msg) {
			return
		}
	}
	t.Fatal("unexpected error:", err)
}

// TestIntegrationSessionProperties passes session properties through the DSN
// and checks via SHOW SESSION that the server reports the values sent.
func TestIntegrationSessionProperties(t *testing.T) {
	dsn := *integrationServerFlag
	dsn += "?session_properties=query_max_run_time%3A10m%3Bquery_priority%3A2"
	db := integrationOpen(t, dsn)
	defer db.Close()
	rows, err := db.Query("SHOW SESSION")
	if err != nil {
		t.Fatal(err)
	}
	// Release the result set (and its connection) on every exit path.
	defer rows.Close()
	for rows.Next() {
		col := struct {
			Name        string
			Value       string
			Default     string
			Type        string
			Description string
		}{}
		err = rows.Scan(
			&col.Name,
			&col.Value,
			&col.Default,
			&col.Type,
			&col.Description,
		)
		if err != nil {
			t.Fatal(err)
		}
		switch {
		case col.Name == "query_max_run_time" && col.Value != "10m":
			t.Fatal("unexpected value for query_max_run_time:", col.Value)
		case col.Name == "query_priority" && col.Value != "2":
			t.Fatal("unexpected value for query_priority:", col.Value)
		}
	}
	if err = rows.Err(); err != nil {
		t.Fatal(err)
	}
}

func TestIntegrationTypeConversion(t *testing.T) { err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}}) if err != nil { t.Fatal(err) } dsn := *integrationServerFlag dsn += "?custom_client=uncompressed" db := integrationOpen(t, dsn) var ( goTime time.Time nullTime NullTime goBytes []byte nullBytes []byte goString string nullString sql.NullString nullStringSlice NullSliceString nullStringSlice2 NullSlice2String nullStringSlice3 NullSlice3String nullInt64Slice NullSliceInt64 nullInt64Slice2 NullSlice2Int64 nullInt64Slice3 NullSlice3Int64 nullFloat64Slice NullSliceFloat64 nullFloat64Slice2 NullSlice2Float64 nullFloat64Slice3 NullSlice3Float64 goMap map[string]interface{} nullMap NullMap goRow []interface{} ) err = db.QueryRow(` SELECT TIMESTAMP '2017-07-10 01:02:03.004 UTC', CAST(NULL AS TIMESTAMP),
CAST(X'FFFF0FFF3FFFFFFF' AS VARBINARY), CAST(NULL AS VARBINARY), CAST('string' AS VARCHAR), CAST(NULL AS VARCHAR), ARRAY['A', 'B', NULL], ARRAY[ARRAY['A'], NULL], ARRAY[ARRAY[ARRAY['A'], NULL], NULL], ARRAY[1, 2, NULL], ARRAY[ARRAY[1, 1, 1], NULL], ARRAY[ARRAY[ARRAY[1, 1, 1], NULL], NULL], ARRAY[1.0, 2.0, NULL], ARRAY[ARRAY[1.1, 1.1, 1.1], NULL], ARRAY[ARRAY[ARRAY[1.1, 1.1, 1.1], NULL], NULL], MAP(ARRAY['a', 'b'], ARRAY['c', 'd']), CAST(NULL AS MAP(ARRAY(INTEGER), ARRAY(INTEGER))), ROW(1, 'a', CAST('2017-07-10 01:02:03.004 UTC' AS TIMESTAMP(6) WITH TIME ZONE), ARRAY['c']) `).Scan( &goTime, &nullTime, &goBytes, &nullBytes, &goString, &nullString, &nullStringSlice, &nullStringSlice2, &nullStringSlice3, &nullInt64Slice, &nullInt64Slice2, &nullInt64Slice3, &nullFloat64Slice, &nullFloat64Slice2, &nullFloat64Slice3, &goMap, &nullMap, &goRow, ) if err != nil { t.Fatal(err) } // Compare the actual and expected values. expectedTime := time.Date(2017, 7, 10, 1, 2, 3, 4*1000000, time.UTC) if !goTime.Equal(expectedTime) { t.Errorf("expected GoTime to be %v, got %v", expectedTime, goTime) } expectedBytes := []byte{0xff, 0xff, 0x0f, 0xff, 0x3f, 0xff, 0xff, 0xff} if !bytes.Equal(goBytes, expectedBytes) { t.Errorf("expected GoBytes to be %v, got %v", expectedBytes, goBytes) } if nullBytes != nil { t.Errorf("expected NullBytes to be nil, got %v", nullBytes) } if goString != "string" { t.Errorf("expected GoString to be %q, got %q", "string", goString) } if nullString.Valid { t.Errorf("expected NullString.Valid to be false, got true") } if !reflect.DeepEqual(nullStringSlice.SliceString, []sql.NullString{{String: "A", Valid: true}, {String: "B", Valid: true}, {Valid: false}}) { t.Errorf("expected NullStringSlice.SliceString to be %v, got %v", []sql.NullString{{String: "A", Valid: true}, {String: "B", Valid: true}, {Valid: false}}, nullStringSlice.SliceString) } if !nullStringSlice.Valid { t.Errorf("expected NullStringSlice.Valid to be true, got false") } expectedSlice2String := 
[][]sql.NullString{{{String: "A", Valid: true}}, {}} if !reflect.DeepEqual(nullStringSlice2.Slice2String, expectedSlice2String) { t.Errorf("expected NullStringSlice2.Slice2String to be %v, got %v", expectedSlice2String, nullStringSlice2.Slice2String) } if !nullStringSlice2.Valid { t.Errorf("expected NullStringSlice2.Valid to be true, got false") } expectedSlice3String := [][][]sql.NullString{{{{String: "A", Valid: true}}, {}}, {}} if !reflect.DeepEqual(nullStringSlice3.Slice3String, expectedSlice3String) { t.Errorf("expected NullStringSlice3.Slice3String to be %v, got %v", expectedSlice3String, nullStringSlice3.Slice3String) } if !nullStringSlice3.Valid { t.Errorf("expected NullStringSlice3.Valid to be true, got false") } expectedSliceInt64 := []sql.NullInt64{{Int64: 1, Valid: true}, {Int64: 2, Valid: true}, {Valid: false}} if !reflect.DeepEqual(nullInt64Slice.SliceInt64, expectedSliceInt64) { t.Errorf("expected NullInt64Slice.SliceInt64 to be %v, got %v", expectedSliceInt64, nullInt64Slice.SliceInt64) } if !nullInt64Slice.Valid { t.Errorf("expected NullInt64Slice.Valid to be true, got false") } expectedSlice2Int64 := [][]sql.NullInt64{{{Int64: 1, Valid: true}, {Int64: 1, Valid: true}, {Int64: 1, Valid: true}}, {}} if !reflect.DeepEqual(nullInt64Slice2.Slice2Int64, expectedSlice2Int64) { t.Errorf("expected NullInt64Slice2.Slice2Int64 to be %v, got %v", expectedSlice2Int64, nullInt64Slice2.Slice2Int64) } if !nullInt64Slice2.Valid { t.Errorf("expected NullInt64Slice2.Valid to be true, got false") } expectedSlice3Int64 := [][][]sql.NullInt64{{{{Int64: 1, Valid: true}, {Int64: 1, Valid: true}, {Int64: 1, Valid: true}}, {}}, {}} if !reflect.DeepEqual(nullInt64Slice3.Slice3Int64, expectedSlice3Int64) { t.Errorf("expected NullInt64Slice3.Slice3Int64 to be %v, got %v", expectedSlice3Int64, nullInt64Slice3.Slice3Int64) } if !nullInt64Slice3.Valid { t.Errorf("expected NullInt64Slice3.Valid to be true, got false") } expectedSliceFloat64 := []sql.NullFloat64{{Float64: 1.0, 
Valid: true}, {Float64: 2.0, Valid: true}, {Valid: false}} if !reflect.DeepEqual(nullFloat64Slice.SliceFloat64, expectedSliceFloat64) { t.Errorf("expected NullFloat64Slice.SliceFloat64 to be %v, got %v", expectedSliceFloat64, nullFloat64Slice.SliceFloat64) } if !nullFloat64Slice.Valid { t.Errorf("expected NullFloat64Slice.Valid to be true, got false") } expectedSlice2Float64 := [][]sql.NullFloat64{{{Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}}, {}} if !reflect.DeepEqual(nullFloat64Slice2.Slice2Float64, expectedSlice2Float64) { t.Errorf("expected NullFloat64Slice2.Slice2Float64 to be %v, got %v", expectedSlice2Float64, nullFloat64Slice2.Slice2Float64) } if !nullFloat64Slice2.Valid { t.Errorf("expected NullFloat64Slice2.Valid to be true, got false") } expectedSlice3Float64 := [][][]sql.NullFloat64{{{{Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}, {Float64: 1.1, Valid: true}}, {}}, {}} if !reflect.DeepEqual(nullFloat64Slice3.Slice3Float64, expectedSlice3Float64) { t.Errorf("expected NullFloat64Slice3.Slice3Float64 to be %v, got %v", expectedSlice3Float64, nullFloat64Slice3.Slice3Float64) } if !nullFloat64Slice3.Valid { t.Errorf("expected NullFloat64Slice3.Valid to be true, got false") } expectedMap := map[string]interface{}{"a": "c", "b": "d"} if !reflect.DeepEqual(goMap, expectedMap) { t.Errorf("expected GoMap to be %v, got %v", expectedMap, goMap) } if nullMap.Valid { t.Errorf("expected NullMap.Valid to be false, got true") } expectedRow := []interface{}{json.Number("1"), "a", "2017-07-10 01:02:03.004000 UTC", []interface{}{"c"}} if !reflect.DeepEqual(goRow, expectedRow) { t.Errorf("expected GoRow to be %v, got %v", expectedRow, goRow) } } func TestComplexTypes(t *testing.T) { // This test has been created to showcase some issues with parsing // complex types. It is not intended to be a comprehensive test of // the parsing logic, but rather to provide a reference for future // changes to the parsing logic. 
// // The current implementation of the parsing logic reads the value // in the same format as the JSON response from Trino. This means // that we don't go further to parse values as their structured types. // For example, a row like `ROW(1, X'0000')` is read as // a list of a `json.Number(1)` and a base64-encoded string. t.Skip("skipping failing test") dsn := *integrationServerFlag db := integrationOpen(t, dsn) for _, tt := range []struct { name string query string expected interface{} }{ { name: "row containing scalar values", query: `SELECT ROW(1, 'a', X'0000')`, expected: []interface{}{1, "a", []byte{0x00, 0x00}}, }, { name: "nested row", query: `SELECT ROW(ROW(1, 'a'), ROW(2, 'b'))`, expected: []interface{}{[]interface{}{1, "a"}, []interface{}{2, "b"}}, }, { name: "map with scalar values", query: `SELECT MAP(ARRAY['a', 'b'], ARRAY[1, 2])`, expected: map[string]interface{}{"a": 1, "b": 2}, }, { name: "map with nested row", query: `SELECT MAP(ARRAY['a', 'b'], ARRAY[ROW(1, 'a'), ROW(2, 'b')])`, expected: map[string]interface{}{"a": []interface{}{1, "a"}, "b": []interface{}{2, "b"}}, }, } { t.Run(tt.name, func(t *testing.T) { var result interface{} err := db.QueryRow(tt.query).Scan(&result) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(result, tt.expected) { t.Errorf("expected %v, got %v", tt.expected, result) } }) } } func TestIntegrationArgsConversion(t *testing.T) { dsn := *integrationServerFlag db := integrationOpen(t, dsn) value := 0 err := db.QueryRow(` SELECT 1 FROM (VALUES ( CAST(1 AS TINYINT), CAST(1 AS SMALLINT), CAST(1 AS INTEGER), CAST(1 AS BIGINT), CAST(1 AS REAL), CAST(1 AS DOUBLE), TIMESTAMP '2017-07-10 01:02:03.004 UTC', CAST('string' AS VARCHAR), CAST(X'FFFF0FFF3FFFFFFF' AS VARBINARY), ARRAY['A', 'B'] )) AS t(col_tiny, col_small, col_int, col_big, col_real, col_double, col_ts, col_varchar, col_varbinary, col_array ) WHERE 1=1 AND col_tiny = ? AND col_small = ? AND col_int = ? AND col_big = ? AND col_real = cast(? 
as real) AND col_double = cast(? as double) AND col_ts = ? AND col_varchar = ? AND col_varbinary = ? AND col_array = ?`, int16(1), int16(1), int32(1), int64(1), Numeric("1"), Numeric("1"), time.Date(2017, 7, 10, 1, 2, 3, 4*1000000, time.UTC), "string", []byte{0xff, 0xff, 0x0f, 0xff, 0x3f, 0xff, 0xff, 0xff}, []string{"A", "B"}, ).Scan(&value) if err != nil { t.Fatal(err) } }

// TestIntegrationNoResults checks that a query with LIMIT 0 succeeds and
// yields no rows.
func TestIntegrationNoResults(t *testing.T) {
	db := integrationOpen(t)
	defer db.Close()
	rows, err := db.Query("SELECT 1 LIMIT 0")
	if err != nil {
		t.Fatal(err)
	}
	defer rows.Close()
	for rows.Next() {
		t.Fatal(errors.New("Rows returned"))
	}
	if err = rows.Err(); err != nil {
		t.Fatal(err)
	}
}

// TestRoleHeaderSupport exercises catalog roles supplied either through
// Config.Roles or through the "roles" DSN parameter. It is skipped when the
// image tag is a number below 458 or is neither numeric nor "latest".
func TestRoleHeaderSupport(t *testing.T) {
	version, err := strconv.Atoi(*trinoImageTagFlag)
	if (err != nil && *trinoImageTagFlag != "latest") || (err == nil && version < 458) {
		t.Skip("Skipping test when not using Trino 458 or later.")
	}

	// expectAdminRole asserts that the SHOW ROLES output contains "admin".
	expectAdminRole := func(t *testing.T, rows *sql.Rows) {
		t.Helper()
		foundAdmin := false
		for rows.Next() {
			var roleName string
			require.NoError(t, rows.Scan(&roleName))
			if roleName == "admin" {
				foundAdmin = true
			}
		}
		require.NoError(t, rows.Err())
		require.True(t, foundAdmin, "Expected to find 'admin' role in SHOW ROLES output")
	}

	tests := []struct {
		name         string
		config       Config
		rawDSN       string // used verbatim instead of config.FormatDSN() when set
		query        string
		expectError  bool
		errorSubstr  string
		validateRows func(t *testing.T, rows *sql.Rows)
	}{
		{
			name: "Valid hive admin role via Config",
			config: Config{
				ServerURI: *integrationServerFlag,
				Roles:     map[string]string{"hive": "admin"},
			},
			query:        "SHOW ROLES FROM hive",
			expectError:  false,
			validateRows: expectAdminRole,
		},
		{
			name: "Valid special roles via Config",
			config: Config{
				ServerURI: *integrationServerFlag,
				Roles:     map[string]string{"tpch": "NONE", "memory": "ALL"},
			},
			query:       "SELECT 1",
			expectError: false,
		},
		{
			name:         "Valid hive admin role via DSN, not encoded url",
			rawDSN:       *integrationServerFlag + "?roles=hive:admin",
			query:        "SHOW ROLES FROM hive",
			expectError:  false,
			validateRows: expectAdminRole,
		},
		{
			name: "Valid roles via DSN, url encoded",
			// The ':' separator is percent-encoded so that this case differs
			// from the unencoded one above.
			rawDSN:       *integrationServerFlag + "?roles=hive%3Aadmin",
			query:        "SHOW ROLES FROM hive",
			expectError:  false,
			validateRows: expectAdminRole,
		},
		{
			name: "No role - should fail to show roles",
			config: Config{
				ServerURI: *integrationServerFlag,
			},
			query:       "SHOW ROLES FROM hive",
			expectError: true,
			errorSubstr: "Access Denied",
		},
		{
			name: "Wrong role - should fail to show roles",
			config: Config{
				ServerURI: *integrationServerFlag,
				Roles:     map[string]string{"hive": "ALL"},
			},
			query:       "SHOW ROLES FROM hive",
			expectError: true,
			errorSubstr: "Access Denied",
		},
		{
			name: "Non-existent catalog role",
			config: Config{
				ServerURI: *integrationServerFlag,
				Roles:     map[string]string{"not-exist-catalog": "role1"},
			},
			query:       "SELECT 1",
			expectError: true,
			errorSubstr: "USER_ERROR: Catalog",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			dsn := tt.rawDSN
			if dsn == "" {
				var err error
				dsn, err = tt.config.FormatDSN()
				if err != nil {
					t.Fatal(err)
				}
			}

			db := integrationOpen(t, dsn)
			defer db.Close()

			rows, err := db.Query(tt.query)
			if tt.expectError {
				require.Error(t, err)
				if tt.errorSubstr != "" {
					require.Contains(t, err.Error(), tt.errorSubstr)
				}
				return
			}

			require.NoError(t, err)
			// Close the result set even when the case has no row validator.
			defer rows.Close()
			if tt.validateRows != nil {
				tt.validateRows(t, rows)
			}
		})
	}
}

// TestIntegrationQueryParametersSelect runs parameterised SELECT statements
// and checks either the number of rows returned or the exact error message.
func TestIntegrationQueryParametersSelect(t *testing.T) {
	scenarios := []struct {
		name          string
		query         string
		args          []interface{}
		expectedError error
		expectedRows  int
	}{
		{
			name:         "valid string as varchar",
			query:        "SELECT * FROM system.runtime.nodes WHERE system.runtime.nodes.node_id=?",
			args:         []interface{}{"test"},
			expectedRows: 1,
		},
		{
			name:         "valid int as bigint",
			query:        "SELECT * FROM tpch.sf1.customer WHERE custkey=? LIMIT 2",
			args:         []interface{}{int(1)},
			expectedRows: 1,
		},
		{
			name:          "invalid string as bigint",
			query:         "SELECT * FROM tpch.sf1.customer WHERE custkey=? LIMIT 2",
			args:          []interface{}{"1"},
			expectedError: errors.New(`trino: query failed (200 OK): "USER_ERROR: line 1:46: Cannot apply operator: bigint = varchar(1)"`),
		},
		{
			name:          "valid string as date",
			query:         "SELECT * FROM tpch.sf1.lineitem WHERE shipdate=? LIMIT 2",
			args:          []interface{}{"1995-01-27"},
			expectedError: errors.New(`trino: query failed (200 OK): "USER_ERROR: line 1:47: Cannot apply operator: date = varchar(10)"`),
		},
	}

	for i := range scenarios {
		scenario := scenarios[i]

		t.Run(scenario.name, func(t *testing.T) {
			db := integrationOpen(t)
			defer db.Close()

			rows, err := db.Query(scenario.query, scenario.args...)
			if err != nil {
				switch {
				case scenario.expectedError == nil:
					t.Errorf("Unexpected err: %s", err)
				case err.Error() != scenario.expectedError.Error():
					t.Errorf("Expected err to be %s but got %s", scenario.expectedError, err)
				}
				// Stop here: rows is not usable after a query error, and a
				// mismatch must not also be reported as a missing error.
				return
			}
			defer rows.Close()

			if scenario.expectedError != nil {
				t.Error("missing expected error")
				return
			}

			var count int
			for rows.Next() {
				count++
			}
			if err = rows.Err(); err != nil {
				t.Fatal(err)
			}
			if count != scenario.expectedRows {
				t.Errorf("expecting %d rows, got %d", scenario.expectedRows, count)
			}
		})
	}
}

func TestIntegrationQueryNextAfterClose(t *testing.T) { // NOTE: This is testing invalid behaviour. It ensures that we don't // panic if we call driverRows.Next after we closed the driverStmt.
ctx := context.Background() conn, err := (&Driver{}).Open(*integrationServerFlag) if err != nil { t.Fatalf("Failed to open connection: %v", err) } defer conn.Close() stmt, err := conn.(driver.ConnPrepareContext).PrepareContext(ctx, "SELECT 1") if err != nil { t.Fatalf("Failed preparing query: %v", err) } rows, err := stmt.(driver.StmtQueryContext).QueryContext(ctx, []driver.NamedValue{}) if err != nil { t.Fatalf("Failed running query: %v", err) } defer rows.Close() stmt.Close() // NOTE: the important bit. var result driver.Value if err := rows.Next([]driver.Value{result}); err != nil && !spoolingProtocolSupported { t.Fatalf("unexpected result: %+v, no error was expected", err) } if err := rows.Next([]driver.Value{result}); err != io.EOF { t.Fatalf("unexpected result: %+v, expected io.EOF", err) } } func TestIntegrationExec(t *testing.T) { db := integrationOpen(t) defer db.Close() _, err := db.Query(`SELECT count(*) FROM nation`) expected := "Schema must be specified when session schema is not set" if err == nil || !strings.Contains(err.Error(), expected) { t.Fatalf("Expected to fail to execute query with error: %v, got: %v", expected, err) } result, err := db.Exec("USE tpch.sf100") if err != nil { t.Fatal("Failed executing query:", err.Error()) } if result == nil { t.Fatal("Expected exec result to be not nil") } a, err := result.RowsAffected() if err != nil { t.Fatal("Expected RowsAffected not to return any error, got:", err) } if a != 0 { t.Fatal("Expected RowsAffected to be zero, got:", a) } rows, err := db.Query(`SELECT count(*) FROM nation`) if err != nil { t.Fatal("Failed executing query:", err.Error()) } if rows == nil || !rows.Next() { t.Fatal("Failed fetching results") } } func TestIntegrationUnsupportedHeader(t *testing.T) { dsn := *integrationServerFlag dsn += "?catalog=tpch&schema=sf10" db := integrationOpen(t, dsn) defer db.Close() cases := []struct { query string err error }{ { query: "SET ROLE dummy", err: errors.New(`trino: query failed (200 OK): 
"USER_ERROR: line 1:1: Role 'dummy' does not exist"`), }, { query: "SET PATH dummy", err: errors.New(`trino: query failed (200 OK): "USER_ERROR: SET PATH not supported by client"`), }, } for _, c := range cases { _, err := db.Query(c.query) if err == nil || err.Error() != c.err.Error() { t.Fatal("unexpected error:", err) } } } func TestSpoolingWorkersHigherThenAllowedOutOfOrderSegments(t *testing.T) { if !spoolingProtocolSupported { t.Skip("Skipping test when spooling protocol is not supported.") } db := integrationOpen(t) defer db.Close() expectedError := "spooling worker cannot be greater than max out of order segments allowed. spooling workers: 2, allowed out of order segments: 1" _, err := db.Query("SELECT 1", sql.Named(trinoEncoding, "json"), sql.Named(trinoSpoolingWorkerCount, "2"), sql.Named(trinoMaxOutOfOrdersSegments, "1")) if err == nil || err.Error() != expectedError { t.Fatal("unexpected error:", err) } } func TestIntegrationQueryContext(t *testing.T) { tests := []struct { name string timeout time.Duration expectedErrMsg string }{ { name: "Context Cancellation", timeout: 0, expectedErrMsg: "canceled", }, { name: "Context Deadline Exceeded", timeout: 3 * time.Second, expectedErrMsg: "context deadline exceeded", }, } if err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}}); err != nil { t.Fatal(err) } dsn := *integrationServerFlag + "?catalog=tpch&schema=sf100&source=cancel-test&custom_client=uncompressed" db := integrationOpen(t, dsn) defer db.Close() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var ctx context.Context var cancel context.CancelFunc if tt.timeout == 0 { ctx, cancel = context.WithCancel(context.Background()) } else { ctx, cancel = context.WithTimeout(context.Background(), tt.timeout) } defer cancel() errCh := make(chan error, 1) done := make(chan struct{}) longQuery := "SELECT COUNT(*) FROM lineitem" go func() { // query will complete in ~7s unless cancelled rows, 
err := db.QueryContext(ctx, longQuery) if err != nil { errCh <- err return } defer rows.Close() rows.Next() if err = rows.Err(); err != nil { errCh <- err return } close(done) }() // Poll system.runtime.queries to get the query ID var queryID string pollCtx, pollCancel := context.WithTimeout(context.Background(), 1*time.Second) defer pollCancel() for { row := db.QueryRowContext(pollCtx, "SELECT query_id FROM system.runtime.queries WHERE state = 'RUNNING' AND source = 'cancel-test' AND query = ?", longQuery) err := row.Scan(&queryID) if err == nil { break } if err != sql.ErrNoRows { t.Fatal("failed to read query ID:", err) } if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil { t.Fatal("query did not start in 1 second") } } if tt.timeout == 0 { cancel() } // Wait for the query to be canceled or completed select { case <-done: t.Fatal("unexpected query succeeded despite cancellation or deadline") case err := <-errCh: if !strings.Contains(err.Error(), tt.expectedErrMsg) { t.Fatalf("expected error containing %q, but got: %v", tt.expectedErrMsg, err) } } // Poll system.runtime.queries to verify the query was canceled pollCtx, pollCancel = context.WithTimeout(context.Background(), 2*time.Second) defer pollCancel() for { row := db.QueryRowContext(pollCtx, "SELECT state, error_code FROM system.runtime.queries WHERE query_id = ?", queryID) var state string var code *string err := row.Scan(&state, &code) if err != nil { t.Fatal("failed to read query state:", err) } if state == "FAILED" && code != nil && *code == "USER_CANCELED" { return } if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil { t.Fatalf("query was not canceled in 2 seconds; state: %s, code: %v, err: %v", state, code, err) } } }) } } func TestIntegrationAccessToken(t *testing.T) { if tlsServer == "" { t.Skip("Skipping access token test when using a custom integration server.") } accessToken, err := generateToken() if err != nil { t.Fatal(err) } dsn := tlsServer + "?accessToken=" + 
accessToken db := integrationOpen(t, dsn) defer db.Close() rows, err := db.Query("SHOW CATALOGS") if err != nil { t.Fatal(err) } defer rows.Close() count := 0 for rows.Next() { count++ } if count < 1 { t.Fatal("not enough rows returned:", count) } } func generateToken() (string, error) { privateKeyPEM, err := os.ReadFile("etc/secrets/private_key.pem") if err != nil { return "", fmt.Errorf("error reading private key file: %w", err) } privateKey, err := jwt.ParseRSAPrivateKeyFromPEM(privateKeyPEM) if err != nil { return "", fmt.Errorf("error parsing private key: %w", err) } // Subject must be 'test' claims := jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * 365 * time.Hour)), Issuer: "gotrino", Subject: "test", } token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) signedToken, err := token.SignedString(privateKey) if err != nil { return "", fmt.Errorf("error generating token: %w", err) } return signedToken, nil } func TestIntegrationTLS(t *testing.T) { if tlsServer == "" { t.Skip("Skipping TLS test when using a custom integration server.") } dsn := tlsServer db := integrationOpen(t, dsn) defer db.Close() row := db.QueryRow("SELECT 1") var count int if err := row.Scan(&count); err != nil { t.Fatal(err) } if count != 1 { t.Fatal("unexpected count=", count) } } func contextSleep(ctx context.Context, d time.Duration) error { timer := time.NewTimer(100 * time.Millisecond) select { case <-timer.C: return nil case <-ctx.Done(): if !timer.Stop() { <-timer.C } return ctx.Err() } } func TestIntegrationDayToHourIntervalMilliPrecision(t *testing.T) { db := integrationOpen(t) defer db.Close() tests := []struct { name string arg time.Duration wantErr bool }{ { name: "valid 1234567891s", arg: time.Duration(1234567891) * time.Second, wantErr: false, }, { name: "valid 123456789.1s", arg: time.Duration(123456789100) * time.Millisecond, wantErr: false, }, { name: "valid 12345678.91s", arg: time.Duration(12345678910) * time.Millisecond, wantErr: false, }, { 
name: "valid 1234567.891s", arg: time.Duration(1234567891) * time.Millisecond, wantErr: false, }, { name: "valid -1234567891s", arg: time.Duration(-1234567891) * time.Second, wantErr: false, }, { name: "valid -123456789.1s", arg: time.Duration(-123456789100) * time.Millisecond, wantErr: false, }, { name: "valid -12345678.91s", arg: time.Duration(-12345678910) * time.Millisecond, wantErr: false, }, { name: "valid -1234567.891s", arg: time.Duration(-1234567891) * time.Millisecond, wantErr: false, }, { name: "invalid 1234567891.2s", arg: time.Duration(1234567891200) * time.Millisecond, wantErr: true, }, { name: "invalid 123456789.12s", arg: time.Duration(123456789120) * time.Millisecond, wantErr: true, }, { name: "invalid 12345678.912s", arg: time.Duration(12345678912) * time.Millisecond, wantErr: true, }, { name: "invalid -1234567891.2s", arg: time.Duration(-1234567891200) * time.Millisecond, wantErr: true, }, { name: "invalid -123456789.12s", arg: time.Duration(-123456789120) * time.Millisecond, wantErr: true, }, { name: "invalid -12345678.912s", arg: time.Duration(-12345678912) * time.Millisecond, wantErr: true, }, { name: "invalid max seconds (9223372036)", arg: time.Duration(math.MaxInt64) / time.Second * time.Second, wantErr: true, }, { name: "invalid min seconds (-9223372036)", arg: time.Duration(math.MinInt64) / time.Second * time.Second, wantErr: true, }, { name: "valid max seconds (2147483647)", arg: math.MaxInt32 * time.Second, }, { name: "valid min seconds (-2147483647)", arg: -math.MaxInt32 * time.Second, }, { name: "valid max minutes (153722867)", arg: time.Duration(math.MaxInt64) / time.Minute * time.Minute, }, { name: "valid min minutes (-153722867)", arg: time.Duration(math.MinInt64) / time.Minute * time.Minute, }, { name: "valid max hours (2562047)", arg: time.Duration(math.MaxInt64) / time.Hour * time.Hour, }, { name: "valid min hours (-2562047)", arg: time.Duration(math.MinInt64) / time.Hour * time.Hour, }, } for _, test := range tests { 
t.Run(test.name, func(t *testing.T) {
			_, err := db.Exec("SELECT ?", test.arg)
			if (err != nil) != test.wantErr {
				t.Errorf("Exec() error = %v, wantErr %v", err, test.wantErr)
				return
			}
		})
	}
}

// TestIntegrationLargeQuery sends a statement containing a 5,000,000-character
// literal with explicitPrepare=false and expects exactly one row back. It is
// skipped unless the Trino image tag is "latest" or a version >= 418.
func TestIntegrationLargeQuery(t *testing.T) {
	version, err := strconv.Atoi(*trinoImageTagFlag)
	if (err != nil && *trinoImageTagFlag != "latest") || (err == nil && version < 418) {
		t.Skip("Skipping test when not using Trino 418 or later.")
	}

	dsn := *integrationServerFlag
	dsn += "?explicitPrepare=false"
	db := integrationOpen(t, dsn)
	defer db.Close()

	rows, err := db.Query("SELECT ?, '"+strings.Repeat("a", 5000000)+"'", 42)
	if err != nil {
		t.Fatal(err)
	}
	defer rows.Close()

	count := 0
	for rows.Next() {
		count++
	}
	// Report the iteration error itself; the outer err is nil at this point.
	if err := rows.Err(); err != nil {
		t.Fatal(err)
	}
	if count != 1 {
		t.Fatal("not enough rows returned:", count)
	}
}

// TestIntegrationTypeConversionSpoolingProtocolInlineJsonEncoder scans one row
// of mixed Trino types, requested with the "json" encoding, into the driver's
// Go and Null* destination types.
func TestIntegrationTypeConversionSpoolingProtocolInlineJsonEncoder(t *testing.T) {
	err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}})
	if err != nil {
		t.Fatal(err)
	}

	dsn := *integrationServerFlag
	dsn += "?custom_client=uncompressed"
	db := integrationOpen(t, dsn)
	// Close the pool like every other integration test in this file does.
	defer db.Close()

	var (
		goTime            time.Time
		nullTime          NullTime
		goString          string
		nullString        sql.NullString
		nullStringSlice   NullSliceString
		nullStringSlice2  NullSlice2String
		nullStringSlice3  NullSlice3String
		nullInt64Slice    NullSliceInt64
		nullInt64Slice2   NullSlice2Int64
		nullInt64Slice3   NullSlice3Int64
		nullFloat64Slice  NullSliceFloat64
		nullFloat64Slice2 NullSlice2Float64
		nullFloat64Slice3 NullSlice3Float64
		goMap             map[string]interface{}
		nullMap           NullMap
		goRow             []interface{}
	)

	err = db.QueryRow(` SELECT TIMESTAMP '2017-07-10 01:02:03.004 UTC', CAST(NULL AS TIMESTAMP), CAST('string' AS VARCHAR), CAST(NULL AS VARCHAR), ARRAY['A', 'B', NULL], ARRAY[ARRAY['A'], NULL], ARRAY[ARRAY[ARRAY['A'], NULL], NULL], ARRAY[1, 2, NULL], ARRAY[ARRAY[1, 1, 1], NULL], ARRAY[ARRAY[ARRAY[1, 1, 1], NULL], NULL], ARRAY[1.0, 2.0, NULL], ARRAY[ARRAY[1.1, 1.1, 1.1], NULL], ARRAY[ARRAY[ARRAY[1.1, 1.1, 1.1], NULL],
NULL], MAP(ARRAY['a', 'b'], ARRAY['c', 'd']), CAST(NULL AS MAP(ARRAY(INTEGER), ARRAY(INTEGER))), ROW(1, 'a', CAST('2017-07-10 01:02:03.004 UTC' AS TIMESTAMP(6) WITH TIME ZONE), ARRAY['c']) `, sql.Named(trinoEncoding, "json")).Scan( &goTime, &nullTime, &goString, &nullString, &nullStringSlice, &nullStringSlice2, &nullStringSlice3, &nullInt64Slice, &nullInt64Slice2, &nullInt64Slice3, &nullFloat64Slice, &nullFloat64Slice2, &nullFloat64Slice3, &goMap, &nullMap, &goRow, ) if err != nil { t.Fatal(err) } } func TestIntegrationSelectTpchSpoolingSegments(t *testing.T) { tests := []struct { name string query string encoding string expected int }{ // Testing with a LIMIT of 1001 rows. // Since we exceed the `protocol.spooling.inlining.max-rows` threshold (1000), // this query trigger spooling protocol with spooled segments. { name: "Spooled Segment JSON+ZSTD Encoded", query: "SELECT * FROM tpch.sf1.customer LIMIT 1001", encoding: "json+zstd", expected: 1001, }, { name: "Spooled Segment JSON Encoded", query: "SELECT * FROM tpch.sf1.customer LIMIT 1001", encoding: "json", expected: 1001, }, { name: "Spooled Segment JSON+LZ4 Encoded", query: "SELECT * FROM tpch.sf1.customer LIMIT 1001", encoding: "json+lz4", expected: 1001, }, // Testing with a LIMIT of 100 rows. 
// This should remain inline as it is below the `protocol.spooling.inlining.max-rows` (1000) and bellow `protocol.spooling.inlining.max-size` 128kb { name: "Inline Segment JSON+ZSTD Encoded", query: "SELECT * FROM tpch.sf1.customer LIMIT 100", encoding: "json+zstd", expected: 100, }, { name: "Inline Segment JSON+LZ4 Encoded", query: "SELECT * FROM tpch.sf1.customer LIMIT 100", encoding: "json+lz4", expected: 100, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db := integrationOpen(t) defer db.Close() rows, err := db.Query(tt.query, sql.Named(trinoEncoding, tt.encoding)) if err != nil { t.Fatalf("Query failed: %v", err) } defer rows.Close() count := 0 for rows.Next() { count++ var col tpchRow err = rows.Scan( &col.CustKey, &col.Name, &col.Address, &col.NationKey, &col.Phone, &col.AcctBal, &col.MktSegment, &col.Comment, ) if err != nil { t.Fatalf("Row scan failed: %v", err) } } if rows.Err() != nil { t.Fatalf("Rows iteration error: %v", rows.Err()) } if count != tt.expected { t.Fatalf("Expected %d rows, got %d", tt.expected, count) } }) } } func TestSpoolingIntegrationOrderedResults(t *testing.T) { if !spoolingProtocolSupported { t.Skip("Skipping test when spooling protocol is not supported.") } db := integrationOpen(t) defer db.Close() query := ` SELECT * FROM TABLE(sequence( start => 1, stop => 5000000 )) ORDER BY sequential_number ` rows, err := db.Query(query, sql.Named(trinoEncoding, "json")) if err != nil { t.Fatalf("Query failed: %v", err) } defer rows.Close() expected := 1 var actual int for rows.Next() { err = rows.Scan(&actual) if err != nil { t.Fatalf("Row scan failed: %v", err) } if actual != expected { t.Fatalf("Unexpected number at position %d: got %d, expected %d", expected, actual, expected) } expected++ } if rows.Err() != nil { t.Fatalf("Rows iteration error: %v", rows.Err()) } if expected != 5_000_001 { t.Fatalf("Expected 5,000,000 rows, got %d", expected-1) } } func TestDsnClientTags(t *testing.T) { tests := []struct { name 
string dsnSuffix string source string expectedTags []string }{ { name: "Single tag", dsnSuffix: "?clientTags=test&source=single-tag-test", source: "single-tag-test", expectedTags: []string{"test"}, }, { name: "Multiple tags with special characters", dsnSuffix: "?clientTags=foo+%2520%2Cbar%3Dtest%2Cbaz%23tag&source=multiple-tags-test-special-characters", source: "multiple-tags-test-special-characters", expectedTags: []string{"foo %20", "bar=test", "baz#tag"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dsn := *integrationServerFlag + tt.dsnSuffix db := integrationOpen(t, dsn) defer db.Close() query := "SELECT 1" rows, err := db.Query(query) if err != nil { t.Fatal(err) } defer rows.Close() if rows.Next() { } if err := rows.Err(); err != nil { t.Fatal(err) } var queryID string err = db.QueryRowContext(context.Background(), "SELECT query_id FROM system.runtime.queries WHERE source = ? AND query = ?", tt.source, query, ).Scan(&queryID) if err != nil { t.Fatal(err) } queryInfo, err := getQueryInfo(dsn, queryID) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(queryInfo.Session.ClientTags, tt.expectedTags) { t.Fatalf("Expected client tags %v, got %v", tt.expectedTags, queryInfo.Session.ClientTags) } }) } } func TestParametersClientTags(t *testing.T) { tests := []struct { name string dsnSuffix string Tags string source string expectedTags []string }{ { name: "Single tag", dsnSuffix: "?clientTags=query-parameter-single-tag-test&source=query-parameter-single-tag-test", Tags: "single-tag", source: "query-parameter-single-tag-test", expectedTags: []string{"single-tag"}, }, { name: "Multiple tags with special characters", dsnSuffix: "?clientTags=query-parameter-multiple-tags-test&source=query-parameter-multiple-tags-test", Tags: "foo %20,bar=test,baz#tag", source: "query-parameter-multiple-tags-test", expectedTags: []string{"foo %20", "bar=test", "baz#tag"}, }, { name: "Override dsn tags", dsnSuffix: 
"?clientTags=foo%2B%2520%3Bbar%3Dtest%3Bbaz%23tag&source=query-parameter-override-tags", Tags: "query-parameter-override-tag-test", source: "query-parameter-override-tags", expectedTags: []string{"query-parameter-override-tag-test"}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dsn := *integrationServerFlag + tt.dsnSuffix db := integrationOpen(t, dsn) defer db.Close() query := "SELECT 1" rows, err := db.Query(query, sql.Named(trinoTagsHeader, tt.Tags)) if err != nil { t.Fatal(err) } defer rows.Close() if rows.Next() { } if err := rows.Err(); err != nil { t.Fatal(err) } var queryID string err = db.QueryRowContext(context.Background(), "SELECT query_id FROM system.runtime.queries WHERE source = ? AND query = ?", tt.source, query, ).Scan(&queryID) if err != nil { t.Fatal(err) } queryInfo, err := getQueryInfo(dsn, queryID) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(queryInfo.Session.ClientTags, tt.expectedTags) { t.Fatalf("Expected client tags %v, got %v", tt.expectedTags, queryInfo.Session.ClientTags) } }) } } type QuerySession struct { ClientTags []string `json:"clientTags"` } type QueryInfo struct { Session QuerySession `json:"session"` } func getQueryInfo(dsn, queryId string) (QueryInfo, error) { serverURL, err := url.Parse(dsn) if err != nil { return QueryInfo{}, err } queryInfoURL := serverURL.Scheme + "://" + serverURL.Host + "/v1/query/" + url.PathEscape(queryId) req, err := http.NewRequest("GET", queryInfoURL, nil) if err != nil { return QueryInfo{}, err } req.Header.Set("X-Trino-User", serverURL.User.Username()) resp, err := http.DefaultClient.Do(req) if err != nil { return QueryInfo{}, err } defer resp.Body.Close() var queryInfo QueryInfo if err := json.NewDecoder(resp.Body).Decode(&queryInfo); err != nil { return QueryInfo{}, err } return queryInfo, nil } ================================================ FILE: trino/serial.go ================================================ // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package trino

import (
	"encoding/hex"
	"encoding/json"
	"fmt"
	"math"
	"reflect"
	"strconv"
	"strings"
	"time"
)

// UnsupportedArgError is returned by Serial for values whose Go type cannot be
// serialized; t holds the name of the offending type.
type UnsupportedArgError struct {
	t string
}

// Error implements the error interface.
func (e UnsupportedArgError) Error() string {
	return fmt.Sprintf("trino: unsupported arg type: %s", e.t)
}

// Numeric is a string representation of a number, such as "10", "5.5" or a
// number in scientific form.
// Serial returns an error for a Numeric that strconv.ParseFloat cannot parse.
type Numeric string

// trinoDate represents a Date type in Trino.
type trinoDate struct {
	year  int
	month time.Month
	day   int
}

// Date creates a representation of a Trino Date type.
func Date(year int, month time.Month, day int) trinoDate {
	return trinoDate{year, month, day}
}

// trinoTime represents a Time type in Trino.
type trinoTime struct {
	hour       int
	minute     int
	second     int
	nanosecond int
}

// Time creates a representation of a Trino Time type. To represent time with precision higher than nanoseconds, pass the value as a string and use a cast in the query.
func Time(hour int, minute int, second int, nanosecond int) trinoTime {
	return trinoTime{hour, minute, second, nanosecond}
}

// trinoTimeTz represents a Time(9) With Timezone type in Trino.
type trinoTimeTz time.Time

// TimeTz creates a representation of a Trino Time(9) With Timezone type.
// A nil location is treated as UTC.
func TimeTz(hour int, minute int, second int, nanosecond int, location *time.Location) trinoTimeTz {
	// When reading a time, a nil location indicates UTC.
	// However, passing nil to time.Date() panics.
	if location == nil {
		location = time.UTC
	}
	// The date fields are placeholders: Serial formats only the time of day
	// and the zone offset of a trinoTimeTz.
	return trinoTimeTz(time.Date(0, 0, 0, hour, minute, second, nanosecond, location))
}

// trinoTimestamp represents a Timestamp type WITHOUT a time zone in Trino,
// backed by a Go time.Time.
type trinoTimestamp time.Time

// Timestamp creates a representation of a Trino Timestamp(9) type.
func Timestamp(year int, month time.Month, day int, hour int, minute int, second int, nanosecond int) trinoTimestamp {
	return trinoTimestamp(time.Date(year, month, day, hour, minute, second, nanosecond, time.UTC))
}

// Serial converts any supported value to its equivalent string for use as a Trino parameter.
// See https://trino.io/docs/current/language/types.html
//
// nil is rendered as NULL, strings are single-quoted with embedded quotes
// doubled, []byte becomes a hexadecimal X'..' literal, and slices become
// ARRAY[...] with each element serialized recursively. Values of an
// unsupported type yield an UnsupportedArgError.
func Serial(v interface{}) (string, error) {
	switch x := v.(type) {
	case nil:
		return "NULL", nil

	// numbers convertible to int
	case int8:
		return strconv.Itoa(int(x)), nil
	case int16:
		return strconv.Itoa(int(x)), nil
	case int32:
		return strconv.Itoa(int(x)), nil
	case int:
		return strconv.Itoa(x), nil
	case uint16:
		return strconv.Itoa(int(x)), nil

	case int64:
		return strconv.FormatInt(x, 10), nil

	case uint32:
		return strconv.FormatUint(uint64(x), 10), nil
	case uint:
		return strconv.FormatUint(uint64(x), 10), nil
	case uint64:
		return strconv.FormatUint(x, 10), nil

	// float32, float64 not supported because digit precision will easily cause large problems
	case float32:
		return "", UnsupportedArgError{"float32"}
	case float64:
		return "", UnsupportedArgError{"float64"}

	case Numeric:
		// Validate only; the original text is passed through unchanged.
		if _, err := strconv.ParseFloat(string(x), 64); err != nil {
			return "", err
		}
		return string(x), nil

	// note byte and uint8 are not supported, this is because byte is an alias for uint8
	// if you were to use uint8 (as a number) it could be interpreted as a byte, so it is unsupported
	// use string instead of byte and any other uint/int type for uint8
	case byte:
		return "", UnsupportedArgError{"byte/uint8"}

	case bool:
		return strconv.FormatBool(x), nil

	case string:
		// Escape single quotes by doubling them.
		return "'" + strings.Replace(x, "'", "''", -1) + "'", nil

	case []byte:
		if x == nil {
			return "NULL", nil
		}
		return "X'" + hex.EncodeToString(x) + "'", nil

	case trinoDate:
		return fmt.Sprintf("DATE '%04d-%02d-%02d'", x.year, x.month, x.day), nil
	case trinoTime:
		return fmt.Sprintf("TIME '%02d:%02d:%02d.%09d'", x.hour, x.minute, x.second, x.nanosecond), nil
	case trinoTimeTz:
		return "TIME " + time.Time(x).Format("'15:04:05.999999999 Z07:00'"), nil
	case trinoTimestamp:
		return "TIMESTAMP " + time.Time(x).Format("'2006-01-02 15:04:05.999999999'"), nil
	case time.Time:
		return "TIMESTAMP " + time.Time(x).Format("'2006-01-02 15:04:05.999999999 Z07:00'"), nil

	case time.Duration:
		return serialDuration(x)

	// TODO - json.RawMessage should probably be matched to 'JSON' in Trino
	case json.RawMessage:
		return "", UnsupportedArgError{"json.RawMessage"}
	}

	// Slices not matched above are handled through reflection.
	if reflect.TypeOf(v).Kind() == reflect.Slice {
		x := reflect.ValueOf(v)
		if x.IsNil() {
			return "", UnsupportedArgError{"[]"}
		}

		slice := make([]interface{}, x.Len())

		for i := 0; i < x.Len(); i++ {
			slice[i] = x.Index(i).Interface()
		}

		return serialSlice(slice)
	}

	if reflect.TypeOf(v).Kind() == reflect.Map {
		// are Trino MAPs indifferent to order? Golang maps are, if Trino aren't then the two types can't be compatible
		return "", UnsupportedArgError{"map"}
	}

	// TODO - consider the remaining types in https://trino.io/docs/current/language/types.html (Row, IP, ...)
return "", UnsupportedArgError{fmt.Sprintf("%T", v)} } func serialSlice(v []interface{}) (string, error) { ss := make([]string, len(v)) for i, x := range v { s, err := Serial(x) if err != nil { return "", err } ss[i] = s } return "ARRAY[" + strings.Join(ss, ", ") + "]", nil } const ( // For seconds with milliseconds there is a maximum length of 10 digits // or 11 characters with the dot and 12 characters with the minus sign and dot maxIntervalStrLenWithDot = 11 // 123456789.1 and 12345678.91 are valid ) func serialDuration(dur time.Duration) (string, error) { switch { case dur%time.Hour == 0: return serialHoursInterval(dur), nil case dur%time.Minute == 0: return serialMinutesInterval(dur), nil case dur%time.Second == 0: return serialSecondsInterval(dur) case dur%time.Millisecond == 0: return serialMillisecondsInterval(dur) default: return "", fmt.Errorf("trino: duration %v is not a multiple of hours, minutes, seconds or milliseconds", dur) } } func serialHoursInterval(dur time.Duration) string { return "INTERVAL '" + strconv.Itoa(int(dur/time.Hour)) + "' HOUR" } func serialMinutesInterval(dur time.Duration) string { return "INTERVAL '" + strconv.Itoa(int(dur/time.Minute)) + "' MINUTE" } func serialSecondsInterval(dur time.Duration) (string, error) { seconds := int64(dur / time.Second) if seconds <= math.MinInt32 || seconds > math.MaxInt32 { return "", fmt.Errorf("trino: duration %v is out of range for interval of seconds type", dur) } return "INTERVAL '" + strconv.FormatInt(seconds, 10) + "' SECOND", nil } func serialMillisecondsInterval(dur time.Duration) (string, error) { seconds := int64(dur / time.Second) millisInSecond := dur.Abs().Milliseconds() % 1000 intervalNr := strings.TrimRight(fmt.Sprintf("%d.%03d", seconds, millisInSecond), "0") if seconds > 0 && len(intervalNr) > maxIntervalStrLenWithDot || seconds < 0 && len(intervalNr) > maxIntervalStrLenWithDot+1 { // +1 for the minus sign return "", fmt.Errorf("trino: duration %v is out of range for interval of 
seconds with millis type", dur) } return "INTERVAL '" + intervalNr + "' SECOND", nil } ================================================ FILE: trino/serial_test.go ================================================ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package trino import ( "math" "testing" "time" "github.com/stretchr/testify/require" ) func TestSerial(t *testing.T) { paris, err := time.LoadLocation("Europe/Paris") require.NoError(t, err) scenarios := []struct { name string value interface{} expectedError bool expectedSerial string }{ { name: "basic string", value: "hello world", expectedSerial: `'hello world'`, }, { name: "single quoted string", value: "hello world's", expectedSerial: `'hello world''s'`, }, { name: "double quoted string", value: `hello "world"`, expectedSerial: `'hello "world"'`, }, { name: "basic binary", value: []byte{0x01, 0x02, 0x03}, expectedSerial: `X'010203'`, }, { name: "empty binary", value: []byte{}, expectedSerial: `X''`, }, { name: "nil binary", value: []byte(nil), expectedSerial: `NULL`, }, { name: "int8", value: int8(100), expectedSerial: "100", }, { name: "int16", value: int16(100), expectedSerial: "100", }, { name: "int32", value: int32(100), expectedSerial: "100", }, { name: "int", value: int(100), expectedSerial: "100", }, { name: "int64", value: int64(100), expectedSerial: "100", }, { name: "uint8", value: uint8(100), expectedError: true, }, { name: 
"uint16", value: uint16(100), expectedSerial: "100", }, { name: "uint32", value: uint32(100), expectedSerial: "100", }, { name: "uint", value: uint(100), expectedSerial: "100", }, { name: "uint64", value: uint64(100), expectedSerial: "100", }, { name: "byte", value: byte('a'), expectedError: true, }, { name: "valid Numeric", value: Numeric("10"), expectedSerial: "10", }, { name: "invalid Numeric", value: Numeric("not-a-number"), expectedError: true, }, { name: "bool true", value: true, expectedSerial: "true", }, { name: "bool false", value: false, expectedSerial: "false", }, { name: "date", value: Date(2017, 7, 10), expectedSerial: "DATE '2017-07-10'", }, { name: "time without timezone", value: Time(11, 34, 25, 123456), expectedSerial: "TIME '11:34:25.000123456'", }, { name: "time with timezone", value: TimeTz(11, 34, 25, 123456, time.FixedZone("test zone", +2*3600)), expectedSerial: "TIME '11:34:25.000123456 +02:00'", }, { name: "time with timezone", value: TimeTz(11, 34, 25, 123456, nil), expectedSerial: "TIME '11:34:25.000123456 Z'", }, { name: "timestamp without timezone", value: Timestamp(2017, 7, 10, 11, 34, 25, 123456), expectedSerial: "TIMESTAMP '2017-07-10 11:34:25.000123456'", }, { name: "timestamp with time zone in Fixed Zone", value: time.Date(2017, 7, 10, 11, 34, 25, 123456, time.FixedZone("test zone", +2*3600)), expectedSerial: "TIMESTAMP '2017-07-10 11:34:25.000123456 +02:00'", }, { name: "timestamp with time zone in Named Zone", value: time.Date(2017, 7, 10, 11, 34, 25, 123456, paris), expectedSerial: "TIMESTAMP '2017-07-10 11:34:25.000123456 +02:00'", }, { name: "timestamp with time zone in UTC", value: time.Date(2017, 7, 10, 11, 34, 25, 123456, time.UTC), expectedSerial: "TIMESTAMP '2017-07-10 11:34:25.000123456 Z'", }, { name: "duration", value: 10*time.Second + 5*time.Millisecond, expectedSerial: "INTERVAL '10.005' SECOND", }, { name: "duration with negative value", value: -(10*time.Second + 5*time.Millisecond), expectedSerial: "INTERVAL 
'-10.005' SECOND", }, { name: "minute duration", value: 10 * time.Minute, expectedSerial: "INTERVAL '10' MINUTE", }, { name: "hour duration", value: 23 * time.Hour, expectedSerial: "INTERVAL '23' HOUR", }, { name: "max hour duration", value: (math.MaxInt64 / time.Hour) * time.Hour, expectedSerial: "INTERVAL '2562047' HOUR", }, { name: "min hour duration", value: (math.MinInt64 / time.Hour) * time.Hour, expectedSerial: "INTERVAL '-2562047' HOUR", }, { name: "max minute duration", value: (math.MaxInt64 / time.Minute) * time.Minute, expectedSerial: "INTERVAL '153722867' MINUTE", }, { name: "min minute duration", value: (math.MinInt64 / time.Minute) * time.Minute, expectedSerial: "INTERVAL '-153722867' MINUTE", }, { name: "too big second duration", value: (math.MaxInt64 / time.Second) * time.Second, expectedError: true, }, { name: "too small second duration", value: (math.MinInt64 / time.Second) * time.Second, expectedError: true, }, { name: "too big millisecond duration", value: time.Millisecond*912 + time.Second*12345678, expectedError: true, }, { name: "too small millisecond duration", value: -(time.Millisecond*910 + time.Second*123456789), expectedError: true, }, { name: "max allowed second duration", value: math.MaxInt32 * time.Second, expectedSerial: "INTERVAL '2147483647' SECOND", }, { name: "min allowed second duration", value: -math.MaxInt32 * time.Second, expectedSerial: "INTERVAL '-2147483647' SECOND", }, { name: "max allowed second with milliseconds duration", value: 999999999*time.Second + 900*time.Millisecond, expectedSerial: "INTERVAL '999999999.9' SECOND", }, { name: "min allowed second with milliseconds duration", value: -999999999*time.Second - 900*time.Millisecond, expectedSerial: "INTERVAL '-999999999.9' SECOND", }, { name: "nil", value: nil, expectedSerial: "NULL", }, { name: "slice typed nil", value: []interface{}(nil), expectedError: true, }, { name: "valid slice", value: []interface{}{1, 2}, expectedSerial: "ARRAY[1, 2]", }, { name: "valid 
empty", value: []interface{}{}, expectedSerial: "ARRAY[]", }, { name: "invalid slice contents", value: []interface{}{1, byte('a')}, expectedError: true, }, } for i := range scenarios { scenario := scenarios[i] t.Run(scenario.name, func(t *testing.T) { s, err := Serial(scenario.value) if err != nil { if scenario.expectedError { return } t.Fatal(err) } if scenario.expectedError { t.Fatal("missing an expected error") } if scenario.expectedSerial != s { t.Fatalf("mismatched serial, got %q expected %q", s, scenario.expectedSerial) } }) } } ================================================ FILE: trino/trino.go ================================================ // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // This file contains code that was borrowed from prestgo, mainly some // data type definitions. // // See https://github.com/avct/prestgo for copyright information. // // The MIT License (MIT) // // Copyright (c) 2015 Avocet Systems Ltd. 
// // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all // copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. // Package trino provides a database/sql driver for Trino. 
var (
	// DefaultQueryTimeout is the timeout applied to a query when the connection
	// has no query timeout configured and the caller's context carries no deadline.
	DefaultQueryTimeout = 10 * time.Hour

	// DefaultCancelQueryTimeout is the timeout for the request to cancel queries in Trino.
	DefaultCancelQueryTimeout = 30 * time.Second

	// ErrOperationNotSupported indicates that a database operation is not supported.
	// It is returned, for example, by Conn.Begin.
	ErrOperationNotSupported = errors.New("trino: operation not supported")

	// ErrQueryCancelled indicates that a query has been cancelled.
	ErrQueryCancelled = errors.New("trino: query cancelled")

	// ErrUnsupportedHeader indicates that the server response contains a header
	// listed in unsupportedResponseHeaders, which this driver cannot honor.
	ErrUnsupportedHeader = errors.New("trino: server response contains an unsupported header")

	// ErrInvalidResponseType indicates that the server returned an invalid type definition.
	ErrInvalidResponseType = errors.New("trino: server response contains an invalid type")

	// ErrInvalidProgressCallbackHeader indicates that only one of the progress
	// callback and progress callback period arguments was supplied; the two
	// must always be set together.
	ErrInvalidProgressCallbackHeader = errors.New("trino: both " + trinoProgressCallbackParam + " and " + trinoProgressCallbackPeriodParam + " must be set when using progress callback")
)
// Config is a configuration that can be encoded to a DSN string.
//
// FormatDSN turns a Config into a DSN and ParseDSN performs the reverse
// conversion. Map-valued fields are encoded as "key:value" entries joined with
// ";", and ClientTags is encoded as a single comma-joined value.
type Config struct {
	ServerURI                  string            // URI of the Trino server, e.g. http://user@localhost:8080
	Source                     string            // Source of the connection (optional, defaults to "trino-go-client")
	Catalog                    string            // Catalog (optional)
	Schema                     string            // Schema (optional)
	SessionProperties          map[string]string // Session properties (optional)
	ExtraCredentials           map[string]string // Extra credentials (optional)
	ClientTags                 []string          // Tags used to identify Trino resource groups; joined with commas in the DSN (optional)
	CustomClientName           string            // Key of an HTTP client registered with RegisterCustomClient (optional)
	KerberosEnabled            bool              // KerberosEnabled (optional, default is false)
	KerberosKeytabPath         string            // Kerberos Keytab Path (optional)
	KerberosPrincipal          string            // Kerberos Principal used to authenticate to KDC (optional)
	KerberosRemoteServiceName  string            // Trino coordinator Kerberos service name (optional, defaults to "trino" when Kerberos is enabled)
	KerberosRealm              string            // The Kerberos Realm (optional)
	KerberosConfigPath         string            // The krb5 config path (optional)
	SSLCertPath                string            // Path of the SSL cert file used for TLS verification; mutually exclusive with SSLCert (optional)
	SSLCert                    string            // The SSL cert, as a PEM string, used for TLS verification; mutually exclusive with SSLCertPath (optional)
	AccessToken                string            // An access token (JWT) for authentication, sent as a Bearer token (optional)
	DisableExplicitPrepare     bool              // Disable the use of explicit prepared statements; EXECUTE IMMEDIATE is used instead (optional, default is false)
	ForwardAuthorizationHeader bool              // Allow forwarding the `accessToken` named query parameter in the authorization header, overwriting the `AccessToken` option, if set (optional)
	QueryTimeout               *time.Duration    // Timeout applied to each query; nil means no connection-level timeout is configured (optional)
	Roles                      map[string]string // Roles, keyed by catalog (optional)
}
serverURI += serverURL.Host config.ServerURI = serverURI config.Source = query.Get("source") config.Catalog = query.Get("catalog") config.Schema = query.Get("schema") if sessionProps := query.Get("session_properties"); sessionProps != "" { var err error config.SessionProperties, err = parseMapParameter(sessionProps, "session property", mapEntrySeparator, mapKeySeparator) if err != nil { return nil, err } } if extraCreds := query.Get("extra_credentials"); extraCreds != "" { var err error config.ExtraCredentials, err = parseMapParameter(extraCreds, "extra credential", mapEntrySeparator, mapKeySeparator) if err != nil { return nil, err } } if roles := query.Get("roles"); roles != "" { var err error config.Roles, err = parseMapParameter(roles, "role", mapEntrySeparator, mapKeySeparator) if err != nil { return nil, err } } if clientTags := query.Get("clientTags"); clientTags != "" { config.ClientTags = strings.Split(clientTags, commaSeparator) } config.CustomClientName = query.Get("custom_client") config.AccessToken = query.Get(accessTokenConfig) if explicitPrepare := query.Get(explicitPrepareConfig); explicitPrepare != "" { explicitPrepareValue, err := strconv.ParseBool(explicitPrepare) if err != nil { return nil, fmt.Errorf("invalid boolean for %s: %q", explicitPrepareConfig, explicitPrepare) } config.DisableExplicitPrepare = !explicitPrepareValue } if forwardAuth := query.Get(forwardAuthorizationHeaderConfig); forwardAuth != "" { forwardAuthValue, err := strconv.ParseBool(forwardAuth) if err != nil { return nil, fmt.Errorf("invalid boolean for %s: %q", forwardAuthorizationHeaderConfig, forwardAuth) } config.ForwardAuthorizationHeader = forwardAuthValue } if queryTimeoutStr := query.Get("query_timeout"); queryTimeoutStr != "" { queryTimeout, err := time.ParseDuration(queryTimeoutStr) if err != nil { return nil, fmt.Errorf("trino: invalid timeout for query_timeout: %q", queryTimeoutStr) } config.QueryTimeout = &queryTimeout } if kerberosParam := 
// parseMapParameter decodes a DSN map parameter such as "k1:v1;k2:v2".
//
// value is split on entrySeparator and every resulting entry is split once on
// keyValueSeparator, so the value part may itself contain keyValueSeparator.
// An entry that does not contain keyValueSeparator is rejected with an error
// that names paramName and quotes the offending entry. When a key occurs more
// than once, the last occurrence wins.
func parseMapParameter(value, paramName, entrySeparator, keyValueSeparator string) (map[string]string, error) {
	entries := strings.Split(value, entrySeparator)
	parsed := make(map[string]string)
	for i := range entries {
		kv := strings.SplitN(entries[i], keyValueSeparator, 2)
		if len(kv) < 2 {
			return nil, fmt.Errorf("invalid %s entry: %q", paramName, entries[i])
		}
		key, val := kv[0], kv[1]
		parsed[key] = val
	}
	return parsed, nil
}
query.Add(forwardAuthorizationHeaderConfig, "true") } isSSL := serverURL.Scheme == "https" if c.DisableExplicitPrepare { query.Add(explicitPrepareConfig, "false") } if c.CustomClientName != "" { if c.SSLCert != "" || c.SSLCertPath != "" { return "", fmt.Errorf("trino: client configuration error, a custom client cannot be specific together with a custom SSL certificate") } } if c.SSLCertPath != "" { if !isSSL { return "", fmt.Errorf("trino: client configuration error, SSL must be enabled to specify a custom SSL certificate file") } if c.SSLCert != "" { return "", fmt.Errorf("trino: client configuration error, a custom SSL certificate file cannot be specified together with a certificate string") } query.Add(sslCertPathConfig, c.SSLCertPath) } if c.SSLCert != "" { if !isSSL { return "", fmt.Errorf("trino: client configuration error, SSL must be enabled to specify a custom SSL certificate") } if c.SSLCertPath != "" { return "", fmt.Errorf("trino: client configuration error, a custom SSL certificate string cannot be specified together with a certificate file") } query.Add(sslCertConfig, c.SSLCert) } if c.KerberosEnabled { if !isSSL { return "", fmt.Errorf("trino: client configuration error, SSL must be enabled for secure env") } query.Add(kerberosEnabledConfig, "true") query.Add(kerberosKeytabPathConfig, c.KerberosKeytabPath) query.Add(kerberosPrincipalConfig, c.KerberosPrincipal) query.Add(kerberosRealmConfig, c.KerberosRealm) query.Add(kerberosConfigPathConfig, c.KerberosConfigPath) query.Add(kerberosRemoteServiceNameConfig, c.KerberosRemoteServiceName) } // ensure consistent order of items sort.Strings(sessionkv) sort.Strings(credkv) sort.Strings(roles) if c.QueryTimeout != nil { query.Add("query_timeout", c.QueryTimeout.String()) } for k, v := range map[string]string{ "catalog": c.Catalog, "clientTags": strings.Join(c.ClientTags, commaSeparator), "schema": c.Schema, "session_properties": strings.Join(sessionkv, mapEntrySeparator), "extra_credentials": 
strings.Join(credkv, mapEntrySeparator), "custom_client": c.CustomClientName, accessTokenConfig: c.AccessToken, "roles": strings.Join(roles, mapEntrySeparator), } { if v != "" { query[k] = []string{v} } } serverURL.RawQuery = query.Encode() return serverURL.String(), nil } // Conn is a Trino connection. type Conn struct { baseURL string auth *url.Userinfo httpClient http.Client httpHeaders http.Header kerberosEnabled bool kerberosClient *client.Client kerberosRemoteServiceName string progressUpdater ProgressUpdater progressUpdaterPeriod queryProgressCallbackPeriod useExplicitPrepare bool forwardAuthorizationHeader bool queryTimeout *time.Duration } var ( _ driver.Conn = &Conn{} _ driver.ConnPrepareContext = &Conn{} ) // formatRolesFromMap formats roles from a map into the Trino header format func formatRolesFromMap(rolesMap map[string]string) string { var formattedRoles []string for catalog, role := range rolesMap { formattedRoles = append(formattedRoles, formatRoleEntry(catalog, role)) } sort.Strings(formattedRoles) return strings.Join(formattedRoles, commaSeparator) } // formatRoleEntry formats a single catalog role entry into Trino header format func formatRoleEntry(catalog, role string) string { if role == "ALL" || role == "NONE" { return fmt.Sprintf("%s=%s", catalog, role) } return fmt.Sprintf("%s=ROLE{%s}", catalog, role) } // formatHeaderValue converts a named argument value to a string suitable for HTTP headers. 
func formatHeaderValue(headerName string, value interface{}) (string, error) { if headerName == trinoRoleHeader { rolesMap, ok := value.(map[string]string) if !ok { return "", fmt.Errorf("%s must be a map[string]string, got %T", trinoRoleHeader, value) } return formatRolesFromMap(rolesMap), nil } headerValue, ok := value.(string) if !ok { return "", fmt.Errorf("%s must be a string, got %T", headerName, value) } return headerValue, nil } func newConn(dsn string) (*Conn, error) { conf, err := ParseDSN(dsn) if err != nil { return nil, err } var kerberosClient *client.Client if conf.KerberosEnabled { kt, err := keytab.Load(conf.KerberosKeytabPath) if err != nil { return nil, fmt.Errorf("trino: Error loading Keytab: %w", err) } confKerb, err := config.Load(conf.KerberosConfigPath) if err != nil { return nil, fmt.Errorf("trino: Error loading krb config: %w", err) } kerberosClient = client.NewWithKeytab(conf.KerberosPrincipal, conf.KerberosRealm, kt, confKerb) loginErr := kerberosClient.Login() if loginErr != nil { return nil, fmt.Errorf("trino: Error login to KDC: %v", loginErr) } } serverURL, err := url.Parse(conf.ServerURI) if err != nil { return nil, fmt.Errorf("trino: invalid server URL: %w", err) } var httpClient = http.DefaultClient if clientKey := conf.CustomClientName; clientKey != "" { httpClient = getCustomClient(clientKey) if httpClient == nil { return nil, fmt.Errorf("trino: custom client not registered: %q", clientKey) } } else if serverURL.Scheme == "https" { cert := []byte(conf.SSLCert) if certPath := conf.SSLCertPath; certPath != "" { cert, err = os.ReadFile(certPath) if err != nil { return nil, fmt.Errorf("trino: Error loading SSL Cert File: %w", err) } } if len(cert) != 0 { certPool := x509.NewCertPool() certPool.AppendCertsFromPEM(cert) httpClient = &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: certPool, }, }, } } } c := &Conn{ baseURL: serverURL.Scheme + "://" + serverURL.Host, httpClient: *httpClient, httpHeaders: 
make(http.Header), kerberosClient: kerberosClient, kerberosEnabled: conf.KerberosEnabled, kerberosRemoteServiceName: conf.KerberosRemoteServiceName, useExplicitPrepare: !conf.DisableExplicitPrepare, forwardAuthorizationHeader: conf.ForwardAuthorizationHeader, queryTimeout: conf.QueryTimeout, } var user string if serverURL.User != nil { user = serverURL.User.Username() pass, _ := serverURL.User.Password() if pass != "" && serverURL.Scheme == "https" { c.auth = serverURL.User } } if tags := conf.ClientTags; tags != nil { c.httpHeaders.Add(trinoTagsHeader, strings.Join(tags, commaSeparator)) } if conf.Roles != nil { rolesHeader := formatRolesFromMap(conf.Roles) if rolesHeader != "" { c.httpHeaders.Add(trinoRoleHeader, rolesHeader) } } for k, v := range map[string]string{ trinoUserHeader: user, trinoSourceHeader: conf.Source, trinoCatalogHeader: conf.Catalog, trinoSchemaHeader: conf.Schema, authorizationHeader: getAuthorization(conf.AccessToken), } { if v != "" { c.httpHeaders.Add(k, v) } } if conf.ExtraCredentials != nil { c.httpHeaders[trinoExtraCredentialHeader], err = decodeMapHeader("extra_credentials", conf.ExtraCredentials) if err != nil { return c, err } } if conf.SessionProperties != nil { c.httpHeaders[trinoSessionHeader], err = decodeMapHeader("session_properties", conf.SessionProperties) if err != nil { return c, err } } return c, nil } func decodeMapHeader(name string, m map[string]string) ([]string, error) { result := make([]string, 0, len(m)) for key, value := range m { if len(key) == 0 { return nil, fmt.Errorf("trino: %s key is empty", name) } if len(value) == 0 { return nil, fmt.Errorf("trino: %s value is empty", name) } if !isASCII(key) { return nil, fmt.Errorf("trino: %s key '%s' contains spaces or is not printable ASCII", name, key) } if !isASCII(value) { return nil, fmt.Errorf("trino: %s value for key '%s' contains spaces or is not printable ASCII", name, key) } result = append(result, key+"="+url.QueryEscape(value)) } return result, nil } func 
// customClientRegistry holds the HTTP clients registered through
// RegisterCustomClient, keyed by name. The embedded RWMutex guards Index.
var customClientRegistry = struct {
	sync.RWMutex
	Index map[string]http.Client
}{
	Index: make(map[string]http.Client),
}

// RegisterCustomClient associates a client to a key in the driver's registry.
//
// Register your custom client in the driver, then refer to it by name in the DSN, on the call to sql.Open:
//
//	foobarClient := &http.Client{
//		Transport: &http.Transport{
//			Proxy: http.ProxyFromEnvironment,
//			DialContext: (&net.Dialer{
//				Timeout:   30 * time.Second,
//				KeepAlive: 30 * time.Second,
//				DualStack: true,
//			}).DialContext,
//			MaxIdleConns:          100,
//			IdleConnTimeout:       90 * time.Second,
//			TLSHandshakeTimeout:   10 * time.Second,
//			ExpectContinueTimeout: 1 * time.Second,
//			TLSClientConfig:       &tls.Config{
//			// your config here...
//			},
//		},
//	}
//	trino.RegisterCustomClient("foobar", foobarClient)
//	db, err := sql.Open("trino", "https://user@localhost:8080?custom_client=foobar")
//
// The client is stored by value, so changes made to *client after
// registration are not seen by the driver. Registering under an existing key
// replaces the previous client.
//
// An error is returned when key parses as a boolean (such keys are reserved)
// or when client is nil.
func RegisterCustomClient(key string, client *http.Client) error {
	if _, err := strconv.ParseBool(key); err == nil {
		return fmt.Errorf("trino: custom client key %q is reserved", key)
	}
	// Dereferencing a nil client below would panic; report it instead.
	if client == nil {
		return fmt.Errorf("trino: custom client for key %q is nil", key)
	}
	customClientRegistry.Lock()
	defer customClientRegistry.Unlock()
	customClientRegistry.Index[key] = *client
	return nil
}

// DeregisterCustomClient removes the client associated to the key.
// Removing a key that was never registered is a no-op.
func DeregisterCustomClient(key string) {
	customClientRegistry.Lock()
	defer customClientRegistry.Unlock()
	delete(customClientRegistry.Index, key)
}

// getCustomClient returns a pointer to a copy of the client registered under
// key, or nil when no client is registered under that key.
func getCustomClient(key string) *http.Client {
	customClientRegistry.RLock()
	defer customClientRegistry.RUnlock()
	if client, ok := customClientRegistry.Index[key]; ok {
		return &client
	}
	return nil
}
func (c *Conn) Begin() (driver.Tx, error) { return nil, ErrOperationNotSupported } // Prepare implements the driver.Conn interface. func (c *Conn) Prepare(query string) (driver.Stmt, error) { return nil, driver.ErrSkip } // PrepareContext implements the driver.ConnPrepareContext interface. func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { return &driverStmt{conn: c, query: query}, nil } // Close implements the driver.Conn interface. func (c *Conn) Close() error { return nil } func (c *Conn) newRequest(ctx context.Context, method, url string, body io.Reader, hs http.Header) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { return nil, fmt.Errorf("trino: %w", err) } if c.kerberosEnabled { remoteServiceName := "trino" if c.kerberosRemoteServiceName != "" { remoteServiceName = c.kerberosRemoteServiceName } err = spnego.SetSPNEGOHeader(c.kerberosClient, req, remoteServiceName+"/"+req.URL.Hostname()) if err != nil { return nil, fmt.Errorf("error setting client SPNEGO header: %w", err) } } for k, v := range c.httpHeaders { req.Header[k] = v } for k, v := range hs { req.Header[k] = v } if c.auth != nil { pass, _ := c.auth.Password() req.SetBasicAuth(c.auth.Username(), pass) } return req, nil } func (c *Conn) roundTrip(ctx context.Context, req *http.Request) (*http.Response, error) { delay := 100 * time.Millisecond const maxDelayBetweenRequests = float64(15 * time.Second) timer := time.NewTimer(0) defer timer.Stop() for { select { case <-ctx.Done(): return nil, ctx.Err() case <-timer.C: resp, err := c.httpClient.Do(req) if err != nil { return nil, &ErrQueryFailed{Reason: err} } switch resp.StatusCode { case http.StatusOK: for src, dst := range responseToRequestHeaderMap { if v := resp.Header.Get(src); v != "" { c.httpHeaders.Set(dst, v) } } if v := resp.Header.Get(trinoAddedPrepareHeader); v != "" { c.httpHeaders.Add(preparedStatementHeader, v) } if v := 
resp.Header.Get(trinoDeallocatedPrepareHeader); v != "" { values := c.httpHeaders.Values(preparedStatementHeader) c.httpHeaders.Del(preparedStatementHeader) for _, v2 := range values { if !strings.HasPrefix(v2, v+"=") { c.httpHeaders.Add(preparedStatementHeader, v2) } } } if v := resp.Header.Get(trinoSetSessionHeader); v != "" { c.httpHeaders.Add(trinoSessionHeader, v) } if v := resp.Header.Get(trinoClearSessionHeader); v != "" { values := c.httpHeaders.Values(trinoSessionHeader) c.httpHeaders.Del(trinoSessionHeader) for _, v2 := range values { if !strings.HasPrefix(v2, v+"=") { c.httpHeaders.Add(trinoSessionHeader, v2) } } } for _, name := range unsupportedResponseHeaders { if v := resp.Header.Get(name); v != "" { return nil, ErrUnsupportedHeader } } return resp, nil case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: resp.Body.Close() timer.Reset(delay) delay = time.Duration(math.Min( float64(delay)*math.Phi, maxDelayBetweenRequests, )) continue default: return nil, newErrQueryFailedFromResponse(resp) } } } } // ErrQueryFailed indicates that a query to Trino failed. type ErrQueryFailed struct { StatusCode int Reason error } // Error implements the error interface. func (e *ErrQueryFailed) Error() string { return fmt.Sprintf("trino: query failed (%d %s): %q", e.StatusCode, http.StatusText(e.StatusCode), e.Reason) } // Unwrap implements the unwrap interface. func (e *ErrQueryFailed) Unwrap() error { return e.Reason } func newErrQueryFailedFromResponse(resp *http.Response) *ErrQueryFailed { const maxBytes = 8 * 1024 defer resp.Body.Close() qf := &ErrQueryFailed{StatusCode: resp.StatusCode} b, err := io.ReadAll(io.LimitReader(resp.Body, maxBytes)) if err != nil { qf.Reason = err return qf } reason := string(b) if resp.ContentLength > maxBytes { reason += "..." 
// Close closes statement just before releasing connection
//
// It signals shutdown by closing doneCh, drains the statement's channels until
// their producers close them, cancels and waits for the spooling download and
// decode workers, and finally closes nextURIs and errors. Calling Close on a
// statement that never executed a query, or calling it a second time, is a
// no-op because doneCh is nil in both cases. It always returns nil.
func (st *driverStmt) Close() error {
	// doneCh is created by exec and reset to nil at the end of Close.
	if st.doneCh == nil {
		return nil
	}
	// Goroutines selecting on doneCh observe the close and return.
	close(st.doneCh)
	if st.statsCh != nil {
		// Blocks until a value arrives on statsCh or the channel is closed.
		// NOTE(review): presumably this waits for the progress reporter to
		// finish; the sender is outside this block, so confirm.
		<-st.statsCh
		st.statsCh = nil
	}
	go func() {
		// drain errors chan to allow goroutines to write to it
		// The loop ends when errors is closed further down in this method.
		for range st.errors {
		}
	}()
	// Each range loop below discards pending values and only returns once the
	// channel has been closed by its producer.
	for range st.queryResponses {
	}
	for range st.httpResponses {
	}
	// The cancel functions and the channels checked below are nil unless they
	// were set elsewhere.
	// NOTE(review): presumably only the spooled protocol sets them; confirm.
	if st.cancelDownloadWorkers != nil {
		st.cancelDownloadWorkers()
	}
	if st.cancelDecodersWorkers != nil {
		st.cancelDecodersWorkers()
	}
	if st.spoolingRowsChannel != nil {
		for range st.spoolingRowsChannel {
		}
	}
	if st.decodedSegments != nil {
		for range st.decodedSegments {
		}
	}
	if st.spooledSegmentsToDecode != nil {
		for range st.spooledSegmentsToDecode {
		}
	}
	if st.spooledSegmentsMetadata != nil {
		for range st.spooledSegmentsMetadata {
		}
	}
	if st.segmentsToProccess != nil {
		for range st.segmentsToProccess {
		}
	}
	// Wait for the worker groups before closing the channels below.
	st.waitDownloadSegmentsWorkers.Wait()
	st.waitSegmentDecodersWorkers.Wait()
	close(st.nextURIs)
	// Closing errors also terminates the drain goroutine started above.
	close(st.errors)
	// Reset state so a later Close returns early and stale channels and cancel
	// functions are not reused.
	st.doneCh = nil
	st.cancelDownloadWorkers = nil
	st.spooledSegmentsMetadata = nil
	st.spooledSegmentsToDecode = nil
	st.cancelDecodersWorkers = nil
	st.segmentsToProccess = nil
	st.decodedSegments = nil
	st.spoolingRowsChannel = nil
	return nil
}
// ErrTrino is the error object decoded from the "error" field of a Trino
// statement response. It implements the error interface.
type ErrTrino struct {
	Message       string        `json:"message"`
	SqlState      string        `json:"sqlState"`
	ErrorCode     int           `json:"errorCode"`
	ErrorName     string        `json:"errorName"`
	ErrorType     string        `json:"errorType"`
	ErrorLocation ErrorLocation `json:"errorLocation"`
	FailureInfo   FailureInfo   `json:"failureInfo"`
}

// Error implements the error interface, rendering "<ErrorType>: <Message>".
func (e ErrTrino) Error() string {
	return fmt.Sprintf("%s: %s", e.ErrorType, e.Message)
}

// ErrorLocation is the line and column reported with an error.
type ErrorLocation struct {
	LineNumber   int `json:"lineNumber"`
	ColumnNumber int `json:"columnNumber"`
}

// FailureInfo is the failure detail decoded from the "failureInfo" field of an
// error. Cause and Suppressed nest further FailureInfo values.
type FailureInfo struct {
	Type          string        `json:"type"`
	Message       string        `json:"message"`
	Cause         *FailureInfo  `json:"cause"`
	Suppressed    []FailureInfo `json:"suppressed"`
	Stack         []string      `json:"stack"`
	ErrorInfo     ErrorInfo     `json:"errorInfo"`
	ErrorLocation ErrorLocation `json:"errorLocation"`
}

// ErrorInfo is the code, name and type decoded from the "errorInfo" field of a
// FailureInfo. It implements the error interface.
type ErrorInfo struct {
	Code int    `json:"code"`
	Name string `json:"name"`
	Type string `json:"type"`
}

// Error implements the error interface, rendering "<Type>: <Name> (<Code>)".
func (e ErrorInfo) Error() string {
	return e.Type + ": " + e.Name + " (" + strconv.Itoa(e.Code) + ")"
}
// jsonFloat64 is a float64 whose JSON decoding tolerates values of the wrong
// JSON type. It is used for the percentage fields of stmtStats.
type jsonFloat64 float64

// UnmarshalJSON decodes data as a JSON number into f.
//
// When data is valid JSON of another type (for instance a string such as
// "NaN"), f is set to 0 and no error is returned. Any other decoding error,
// such as malformed JSON, is returned unchanged and f is left untouched.
func (f *jsonFloat64) UnmarshalJSON(data []byte) error {
	var parsed float64
	if err := json.Unmarshal(data, &parsed); err != nil {
		var typeErr *json.UnmarshalTypeError
		if !errors.As(err, &typeErr) {
			return err
		}
		// Wrong JSON type: fall back to zero instead of failing the decode.
		if f != nil {
			*f = 0
		}
		return nil
	}
	*f = jsonFloat64(parsed)
	return nil
}

var _ json.Unmarshaler = new(jsonFloat64)
hs.Add("X-Trino-Client-Capabilities", "PARAMETRIC_DATETIME") if len(args) > 0 { var ss []string for _, arg := range args { if arg.Name == trinoProgressCallbackParam { st.conn.progressUpdater = arg.Value.(ProgressUpdater) continue } if arg.Name == trinoProgressCallbackPeriodParam { st.conn.progressUpdaterPeriod.Period = arg.Value.(time.Duration) continue } if st.conn.forwardAuthorizationHeader && arg.Name == accessTokenConfig { token := arg.Value.(string) hs.Add(authorizationHeader, getAuthorization(token)) continue } if arg.Name == trinoEncoding { hs.Add(trinoQueryDataEncodingHeader, arg.Value.(string)) continue } if arg.Name == trinoSpoolingWorkerCount { numberOfWorkers, err := strconv.Atoi(arg.Value.(string)) if err != nil { return nil, err } st.spoolingWorkerCount = numberOfWorkers continue } if arg.Name == trinoMaxOutOfOrdersSegments { maxSegmentsOutOfOrder, err := strconv.Atoi(arg.Value.(string)) if err != nil { return nil, err } st.spoolingMaxOutOfOrderSegments = maxSegmentsOutOfOrder continue } if strings.HasPrefix(arg.Name, trinoHeaderPrefix) { headerValue, err := formatHeaderValue(arg.Name, arg.Value) if err != nil { return nil, err } if arg.Name == trinoUserHeader { st.user = headerValue } if arg.Name == trinoRoleHeader { st.conn.httpHeaders.Set(trinoRoleHeader, headerValue) } hs.Add(arg.Name, headerValue) } else { s, err := Serial(arg.Value) if err != nil { return nil, err } if st.conn.useExplicitPrepare && hs.Get(preparedStatementHeader) == "" { for _, v := range st.conn.httpHeaders.Values(preparedStatementHeader) { hs.Add(preparedStatementHeader, v) } hs.Add(preparedStatementHeader, preparedStatementName+"="+url.QueryEscape(st.query)) } ss = append(ss, s) } } if (st.conn.progressUpdater != nil && st.conn.progressUpdaterPeriod.Period == 0) || (st.conn.progressUpdater == nil && st.conn.progressUpdaterPeriod.Period > 0) { return nil, ErrInvalidProgressCallbackHeader } if len(ss) > 0 { if st.conn.useExplicitPrepare { query = "EXECUTE " + 
preparedStatementName + " USING " + strings.Join(ss, ", ") } else { query = "EXECUTE IMMEDIATE " + formatStringLiteral(st.query) + " USING " + strings.Join(ss, ", ") } } } if st.spoolingWorkerCount > st.spoolingMaxOutOfOrderSegments { return nil, fmt.Errorf("spooling worker cannot be greater than max out of order segments allowed. spooling workers: %d, allowed out of order segments: %d", st.spoolingWorkerCount, st.spoolingMaxOutOfOrderSegments) } if hs.Get(trinoQueryDataEncodingHeader) == "" { hs.Add(trinoQueryDataEncodingHeader, defaulttrinoEncoding) } var cancel context.CancelFunc = func() {} if st.conn.queryTimeout != nil { ctx, cancel = context.WithTimeout(ctx, *st.conn.queryTimeout) } else if _, ok := ctx.Deadline(); !ok { ctx, cancel = context.WithTimeout(ctx, DefaultQueryTimeout) } req, err := st.conn.newRequest(ctx, "POST", st.conn.baseURL+"/v1/statement", strings.NewReader(query), hs) if err != nil { cancel() return nil, err } resp, err := st.conn.roundTrip(ctx, req) if err != nil { cancel() return nil, err } defer resp.Body.Close() var sr stmtResponse d := json.NewDecoder(resp.Body) d.UseNumber() err = d.Decode(&sr) if err != nil { cancel() return nil, fmt.Errorf("trino: %w", err) } st.doneCh = make(chan struct{}) st.nextURIs = make(chan string) st.httpResponses = make(chan *http.Response) st.queryResponses = make(chan queryResponse) st.errors = make(chan error) go func() { defer close(st.httpResponses) for { select { case nextURI := <-st.nextURIs: if nextURI == "" { return } hs := make(http.Header) hs.Add(trinoUserHeader, st.user) req, err := st.conn.newRequest(ctx, "GET", nextURI, nil, hs) if err != nil { if ctx.Err() == context.Canceled { st.errors <- context.Canceled return } st.errors <- err return } resp, err := st.conn.roundTrip(ctx, req) if err != nil { if ctx.Err() == context.Canceled { st.errors <- context.Canceled return } st.errors <- err return } select { case st.httpResponses <- resp: case <-st.doneCh: return } case <-st.doneCh: return } } 
}() go func() { defer close(st.queryResponses) defer cancel() for { select { case resp := <-st.httpResponses: if resp == nil { return } var qresp queryResponse d := json.NewDecoder(resp.Body) d.UseNumber() err = d.Decode(&qresp) if err != nil { st.errors <- fmt.Errorf("trino: %w", err) return } err = resp.Body.Close() if err != nil { st.errors <- err return } err = handleResponseError(resp.StatusCode, qresp.Error) if err != nil { st.errors <- err return } select { case st.nextURIs <- qresp.NextURI: case <-st.doneCh: return } select { case st.queryResponses <- qresp: case <-st.doneCh: return } case <-st.doneCh: return } } }() st.nextURIs <- sr.NextURI if st.conn.progressUpdater != nil { st.statsCh = make(chan QueryProgressInfo) // progress updater go func go func() { for { select { case stats := <-st.statsCh: st.conn.progressUpdater.Update(stats) case <-st.doneCh: close(st.statsCh) return } } }() // initial progress callback call srStats := QueryProgressInfo{ QueryId: sr.ID, QueryStats: sr.Stats, } select { case st.statsCh <- srStats: default: // ignore when can't send stats } st.conn.progressUpdaterPeriod.LastCallbackTime = time.Now() st.conn.progressUpdaterPeriod.LastQueryState = sr.Stats.State } return &sr, handleResponseError(resp.StatusCode, sr.Error) } type SegmentFetcher struct { ctx context.Context httpClient http.Client spooledMetadata spooledMetadata } func (sf *SegmentFetcher) roundTrip(req *http.Request) (*http.Response, error) { delay := 200 * time.Millisecond const maxRetries = 5 retries := 0 timer := time.NewTimer(0) defer timer.Stop() for { select { case <-timer.C: resp, err := sf.httpClient.Do(req) if err != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { retries++ if retries > maxRetries { return nil, &ErrQueryFailed{Reason: fmt.Errorf("max retries reached: %w", err)} } delay = time.Duration(float64(delay) * math.Phi) timer.Reset(delay) continue } return nil, &ErrQueryFailed{Reason: err} } switch resp.StatusCode { case 
http.StatusOK: return resp, nil case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: resp.Body.Close() retries++ if retries > maxRetries { return nil, &ErrQueryFailed{Reason: fmt.Errorf("max retries reached for status code %d", resp.StatusCode)} } delay = time.Duration(float64(delay) * math.Phi) timer.Reset(delay) continue default: return nil, newErrQueryFailedFromResponse(resp) } } } } func (sf *SegmentFetcher) fetchSegment() ([]byte, error) { req, err := http.NewRequestWithContext(sf.ctx, "GET", sf.spooledMetadata.uri, nil) if err != nil { return nil, err } for k, v := range sf.spooledMetadata.headers { headerSlice, ok := v.([]interface{}) if !ok { return nil, fmt.Errorf("unsupported header type %T", v) } if len(headerSlice) == 0 { continue } if len(headerSlice) > 1 { return nil, fmt.Errorf("multiple values for header %s", k) } header, ok := headerSlice[0].(string) if !ok { return nil, fmt.Errorf("unsupported header value type %T", headerSlice[0]) } req.Header.Add(k, header) } resp, err := sf.roundTrip(req) if err != nil { return nil, fmt.Errorf("error fetching segment from uri '%s': %v", sf.spooledMetadata.uri, err) } data, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("error reading response body: %v", err) } //acknowledge the segment read go func() { // TODO: handle ack erros ackReq, err := http.NewRequestWithContext(sf.ctx, "GET", sf.spooledMetadata.ackUri, nil) if err != nil { return } for k, values := range req.Header { for _, v := range values { ackReq.Header.Add(k, v) } } resp, err := sf.httpClient.Do(ackReq) if err != nil { return } resp.Body.Close() }() return data, nil } func formatStringLiteral(query string) string { return "'" + strings.ReplaceAll(query, "'", "''") + "'" } type driverRows struct { ctx context.Context stmt *driverStmt queryID string nextURI string err error rowindex int columns []string coltype []*typeConverter data []queryData rowsAffected int64 statsCh chan QueryProgressInfo 
doneCh chan struct{} } var _ driver.Rows = &driverRows{} var _ driver.Result = &driverRows{} var _ driver.RowsColumnTypeScanType = &driverRows{} var _ driver.RowsColumnTypeDatabaseTypeName = &driverRows{} var _ driver.RowsColumnTypeLength = &driverRows{} var _ driver.RowsColumnTypePrecisionScale = &driverRows{} // Close closes the rows iterator. func (qr *driverRows) Close() error { if qr.err == sql.ErrNoRows || qr.err == io.EOF { return nil } qr.err = io.EOF if !qr.stmt.usingSpooledProtocol { err := qr.fetch() if err != nil && err != io.EOF { return err } if qr.nextURI == "" { return nil } } else { select { case _, ok := <-qr.stmt.spoolingRowsChannel: if !ok { // channel is closed, all data has been consumed return nil } case <-time.NewTimer(100 * time.Millisecond).C: // no data is ready } } hs := make(http.Header) if qr.stmt.user != "" { hs.Add(trinoUserHeader, qr.stmt.user) } ctx, cancel := context.WithTimeout(context.WithoutCancel(qr.ctx), DefaultCancelQueryTimeout) defer cancel() req, err := qr.stmt.conn.newRequest(ctx, "DELETE", qr.stmt.conn.baseURL+"/v1/query/"+url.PathEscape(qr.queryID), nil, hs) if err != nil { return err } resp, err := qr.stmt.conn.roundTrip(ctx, req) if err != nil { qferr, ok := err.(*ErrQueryFailed) if ok && qferr.StatusCode == http.StatusNoContent { qr.nextURI = "" return nil } return err } resp.Body.Close() return qr.err } // Columns returns the names of the columns. 
func (qr *driverRows) Columns() []string { if qr.err != nil { return []string{} } if qr.columns == nil { if err := qr.fetch(); err != nil && err != io.EOF { qr.err = err return []string{} } } return qr.columns } func (qr *driverRows) ColumnTypeDatabaseTypeName(index int) string { typeName := qr.coltype[index].parsedType[0] if typeName == "map" || typeName == "array" || typeName == "row" { typeName = qr.coltype[index].typeName } return strings.ToUpper(typeName) } func (qr *driverRows) ColumnTypeScanType(index int) reflect.Type { return qr.coltype[index].scanType } func (qr *driverRows) ColumnTypeLength(index int) (int64, bool) { return qr.coltype[index].size.value, qr.coltype[index].size.hasValue } func (qr *driverRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { return qr.coltype[index].precision.value, qr.coltype[index].scale.value, qr.coltype[index].precision.hasValue } // Next is called to populate the next row of data into // the provided slice. The provided slice will be the same // size as the Columns() are wide. // // Next should return io.EOF when there are no more rows. func (qr *driverRows) Next(dest []driver.Value) error { if qr.err != nil { return qr.err } if !qr.stmt.usingSpooledProtocol && (qr.columns == nil || qr.rowindex >= len(qr.data)) { if qr.nextURI == "" { qr.err = io.EOF return qr.err } if err := qr.fetch(); err != nil { qr.err = err return err } } else if qr.stmt.usingSpooledProtocol && (qr.rowindex >= len(qr.data) || qr.data == nil) { var ok bool select { // The spoolingRowsChannel is initialized in startSpoolingProtocolWorkers, // which is called by fetch() when the first query response indicates // the spooling protocol (i.e., the response contains segments). // At that point, usingSpooledProtocol is set to true and the channel is created. 
case qr.data, ok = <-qr.stmt.spoolingRowsChannel: if !ok { qr.err = io.EOF return qr.err } qr.rowindex = 0 case err := <-qr.stmt.errors: if err == nil { // Channel was closed, which means the statement // or rows were closed. qr.err = io.EOF return qr.err } else if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { qr.Close() } qr.stmt.cancelDecodersWorkers() qr.stmt.cancelDownloadWorkers() qr.err = err return qr.err } } return qr.next(dest) } func (qr *driverRows) next(dest []driver.Value) error { if len(qr.coltype) == 0 { qr.err = sql.ErrNoRows return qr.err } for i, v := range qr.coltype { if i > len(dest)-1 { break } vv, err := v.ConvertValue(qr.data[qr.rowindex][i]) if err != nil { qr.err = err return err } dest[i] = vv } qr.rowindex++ return nil } // LastInsertId returns the database's auto-generated ID // after, for example, an INSERT into a table with primary // key. func (qr driverRows) LastInsertId() (int64, error) { return 0, ErrOperationNotSupported } // RowsAffected returns the number of rows affected by the query. 
func (qr driverRows) RowsAffected() (int64, error) { return qr.rowsAffected, nil } type queryResponse struct { ID string `json:"id"` InfoURI string `json:"infoUri"` PartialCancelURI string `json:"partialCancelUri"` NextURI string `json:"nextUri"` Columns []queryColumn `json:"columns"` Data interface{} `json:"data"` Stats stmtStats `json:"stats"` Error ErrTrino `json:"error"` UpdateType string `json:"updateType"` UpdateCount int64 `json:"updateCount"` } type segmentMetadata struct { rowOffset int64 rowsCount int64 segmentSize int64 uncompressedSize int64 } type spooledMetadata struct { uri string ackUri string encoding string headers map[string]interface{} metadata segmentMetadata } func parseSpooledMetadata(segment map[string]interface{}, segmentIndex int, segmentMetadata segmentMetadata, encoding string) (spooledMetadata, error) { result := spooledMetadata{ metadata: segmentMetadata, encoding: encoding, headers: make(map[string]interface{}), } var ok bool result.uri, ok = segment["uri"].(string) if !ok || result.uri == "" { return spooledMetadata{}, fmt.Errorf("missing or invalid 'uri' field in spooled segment at index %d", segmentIndex) } result.ackUri, ok = segment["ackUri"].(string) if !ok || result.ackUri == "" { return spooledMetadata{}, fmt.Errorf("missing or invalid 'ackUri' field in spooled segment at index %d", segmentIndex) } if rawHeaders, exists := segment["headers"]; exists { result.headers, ok = rawHeaders.(map[string]interface{}) if !ok { return spooledMetadata{}, fmt.Errorf("invalid 'headers' field in spooled segment at index %d: expected map[string]interface{}", segmentIndex) } } return result, nil } func parseSegmentMetadata(metadata map[string]interface{}) (segmentMetadata, error) { result := segmentMetadata{ rowOffset: 0, rowsCount: 0, segmentSize: 0, uncompressedSize: 0, } var err error // Mandatory field if result.rowOffset, err = getInt64(metadata, "rowOffset"); err != nil { return segmentMetadata{}, err } // Mandatory field if 
result.segmentSize, err = getInt64(metadata, "segmentSize"); err != nil { return segmentMetadata{}, err } if result.uncompressedSize, err = getOptionalInt64(metadata, "uncompressedSize"); err != nil { return segmentMetadata{}, err } // Bug: rowsCount was wrongly not enforced as a mandatory field on Trino response. Fixed on 475 release if result.rowsCount, err = getOptionalInt64(metadata, "rowsCount"); err != nil { return segmentMetadata{}, err } return result, nil } func getInt64(metadata map[string]interface{}, key string) (int64, error) { val, exists := metadata[key] if !exists { return 0, fmt.Errorf("%s is missing in segment metadata", key) } return parseInt64(val, key) } func getOptionalInt64(metadata map[string]interface{}, key string) (int64, error) { val, exists := metadata[key] if !exists { return 0, nil } return parseInt64(val, key) } func parseInt64(val interface{}, key string) (int64, error) { num, ok := val.(json.Number) if !ok { return 0, fmt.Errorf("invalid type for %s in segment metadata, expected json.Number, got %T", key, val) } n, err := num.Int64() if err != nil { return 0, fmt.Errorf("error converting %s to int64: %v", key, err) } return n, nil } func decodeSegment(data []byte, encoding string, metadata segmentMetadata) ([]queryData, error) { if int64(len(data)) != metadata.segmentSize { return nil, fmt.Errorf("segment size mismatch: expected %d bytes, got %d bytes", metadata.segmentSize, len(data)) } decompressedSegment, err := decompressSegment(data, encoding, metadata) if err != nil { return nil, err } var queryDataList = make([]queryData, metadata.rowsCount) decoder := json.NewDecoder(bytes.NewReader(decompressedSegment)) decoder.UseNumber() err = decoder.Decode(&queryDataList) if err != nil { return nil, fmt.Errorf("failed to decode segment into JSON at rowOffset %d: %v", metadata.rowOffset, err) } return queryDataList, nil } func decompressSegment(data []byte, encoding string, metadata segmentMetadata) ([]byte, error) { if 
metadata.uncompressedSize == 0 { return data, nil } var decompressedData []byte switch encoding { case "json+zstd": zstdDecoder, err := zstd.NewReader(nil) if err != nil { return nil, fmt.Errorf("error creating zstd reader: %w", err) } defer zstdDecoder.Close() dst := make([]byte, 0, metadata.uncompressedSize) decompressedData, err = zstdDecoder.DecodeAll(data, dst) if err != nil { return nil, fmt.Errorf("failed to decompress zstd segment at rowOffset %d: %v", metadata.rowOffset, err) } case "json+lz4": decompressedData = make([]byte, metadata.uncompressedSize) n, err := lz4.UncompressBlock(data, decompressedData) if err != nil { return nil, fmt.Errorf("failed to decompress LZ4 segment at rowOffset %d: %v", metadata.rowOffset, err) } decompressedData = decompressedData[:n] default: return nil, fmt.Errorf("unsupported segment encoder: %s", encoding) } if int64(len(decompressedData)) != metadata.uncompressedSize { return nil, fmt.Errorf("decompressed size mismatch: expected %d bytes, got %d bytes", metadata.uncompressedSize, len(decompressedData)) } return decompressedData, nil } type queryColumn struct { Name string `json:"name"` Type string `json:"type"` TypeSignature typeSignature `json:"typeSignature"` } type queryData []interface{} type namedTypeSignature struct { FieldName rowFieldName `json:"fieldName"` TypeSignature typeSignature `json:"typeSignature"` } type rowFieldName struct { Name string `json:"name"` } type typeSignature struct { RawType string `json:"rawType"` Arguments []typeArgument `json:"arguments"` } type typeKind string const ( KIND_TYPE = typeKind("TYPE") KIND_NAMED_TYPE = typeKind("NAMED_TYPE") KIND_LONG = typeKind("LONG") KIND_VARIABLE = typeKind("VARIABLE") ) type typeArgument struct { // Kind determines if the typeSignature, namedTypeSignature, or long field has a value Kind typeKind `json:"kind"` Value json.RawMessage `json:"value"` // typeSignature decoded from Value when Kind is TYPE typeSignature typeSignature // namedTypeSignature 
decoded from Value when Kind is NAMED_TYPE namedTypeSignature namedTypeSignature // long decoded from Value when Kind is LONG long int64 } func handleResponseError(status int, respErr ErrTrino) error { switch respErr.ErrorName { case "": return nil case "USER_CANCELLED": return ErrQueryCancelled default: return &ErrQueryFailed{ StatusCode: status, Reason: &respErr, } } } func (qr *driverRows) startOrderedSegmentStreamer() { go func() { defer close(qr.stmt.spoolingRowsChannel) defer close(qr.stmt.spoolingProcesserDone) consumed := 0 buffer := make([]decodedSegment, 0, qr.stmt.spoolingMaxOutOfOrderSegments) var nextExpectedOffset int64 = 0 for { select { case segment, ok := <-qr.stmt.decodedSegments: if !ok { return } buffer = append(buffer, segment) if nextExpectedOffset != segment.rowOffset { if len(buffer) >= qr.stmt.spoolingMaxOutOfOrderSegments { qr.stmt.errors <- fmt.Errorf( "all %d out-of-order segments buffered (limit: %d). This indicates a bug or inconsistency in the segments metadata response (e.g., missing, duplicate, or misordered segments, or row offsets not matching the expected sequence)", len(buffer), qr.stmt.spoolingMaxOutOfOrderSegments) } continue } consumed = 0 slices.SortFunc(buffer, func(a, b decodedSegment) int { if a.rowOffset < b.rowOffset { return -1 } if a.rowOffset > b.rowOffset { return 1 } return 0 }) for consumed < len(buffer) && buffer[consumed].rowOffset == nextExpectedOffset { select { case qr.stmt.spoolingRowsChannel <- buffer[consumed].queryData: case <-qr.doneCh: return } // release reserved slot select { case <-qr.stmt.segmentThrottleCh: case <-qr.doneCh: return } nextExpectedOffset += int64(len(buffer[consumed].queryData)) consumed++ } copy(buffer[0:], buffer[consumed:]) buffer = buffer[:len(buffer)-consumed] case <-qr.doneCh: return } } }() } func (qr *driverRows) fetch() error { var qresp queryResponse var err error for { select { case qresp = <-qr.stmt.queryResponses: if qresp.ID == "" { return io.EOF } err = 
qr.initColumns(&qresp) if err != nil { return err } qr.rowindex = 0 qr.nextURI = qresp.NextURI switch data := qresp.Data.(type) { case []interface{}: // direct protocol qr.data = make([]queryData, len(data)) for i, item := range data { if row, ok := item.([]interface{}); ok { qr.data[i] = row } else { return fmt.Errorf("unexpected data type for row at index %d: expected []interface{}, got %T", i, item) } } case map[string]interface{}: // spooling protocol qr.stmt.startSpoolingProtocolWorkers(qr.ctx) qr.startOrderedSegmentStreamer() err := qr.queueSpoolingSegments(data) qr.proccessSpollingSegments() return err case nil: qr.data = nil } qr.rowsAffected = qresp.UpdateCount qr.scheduleProgressUpdate(qresp.ID, qresp.Stats) if len(qr.data) != 0 { return nil } case err = <-qr.stmt.errors: if err == nil { // Channel was closed, which means the statement // or rows were closed. err = io.EOF } else if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { qr.Close() } qr.err = err return err } } } func (st *driverStmt) startSpoolingProtocolWorkers(ctx context.Context) { st.usingSpooledProtocol = true if st.spoolingWorkerCount == 0 { st.spoolingWorkerCount = defaultSpoolingDownloadWorkers } if st.spoolingMaxOutOfOrderSegments == 0 { st.spoolingMaxOutOfOrderSegments = defaultallowedOutOfOrder } downloadSegmentsCtx, cancelDownloadWorkers := context.WithCancel(context.WithoutCancel(ctx)) st.cancelDownloadWorkers = cancelDownloadWorkers decodeSegmentCtx, cancelDecodersWorkers := context.WithCancel(context.WithoutCancel(ctx)) st.cancelDecodersWorkers = cancelDecodersWorkers st.segmentsToProccess = make(chan segmentToProccess, 1000) st.spooledSegmentsMetadata = make(chan spooledMetadata, st.spoolingMaxOutOfOrderSegments) st.spooledSegmentsToDecode = make(chan segmentToDecode, st.spoolingMaxOutOfOrderSegments) st.spoolingRowsChannel = make(chan []queryData) st.spoolingProcesserDone = make(chan struct{}) st.segmentDispatcherDoneCh = make(chan struct{}) 
st.segmentThrottleCh = make(chan struct{}, st.spoolingMaxOutOfOrderSegments) st.decodedSegments = make(chan decodedSegment) st.startSegmentDispatcher() st.startDownloadSegmentsWorkers(downloadSegmentsCtx) st.startSegmentsDecodersWorkers(decodeSegmentCtx) } func (st *driverStmt) startSegmentDispatcher() { go func() { defer close(st.segmentDispatcherDoneCh) defer close(st.segmentThrottleCh) for { select { case segmentToProccess, ok := <-st.segmentsToProccess: if !ok { return } // segmentThrottleCh blocks if there are too many out-of-order segments. // Once all currently downloaded segments are downloaded, decoded, // and can be ordered, this channel will be drained. select { case st.segmentThrottleCh <- struct{}{}: case <-st.doneCh: return } segmentMetadata, exists := segmentToProccess.segment["metadata"] if !exists { st.errors <- fmt.Errorf("metadata is missing in segment at index %d", segmentToProccess.segmentIndex) } typedMetadata, ok := segmentMetadata.(map[string]interface{}) if !ok { st.errors <- fmt.Errorf("metadata is invalid or cannot be parsed as map[string]interface{} in segment at index %d", segmentToProccess.segmentIndex) } metadata, err := parseSegmentMetadata(typedMetadata) if err != nil { st.errors <- err } switch segmentToProccess.segment["type"] { case "inline": decodedBytes, err := base64.StdEncoding.DecodeString(segmentToProccess.segment["data"].(string)) if err != nil { st.errors <- fmt.Errorf("error decoding base64 data in inline segment at index %d: %v", segmentToProccess.segmentIndex, err) } st.spooledSegmentsToDecode <- segmentToDecode{ segmentIndex: 0, encoding: segmentToProccess.encoding, data: decodedBytes, metadata: metadata, } case "spooled": spooledMetadata, err := parseSpooledMetadata(segmentToProccess.segment, 0, metadata, segmentToProccess.encoding) if err != nil { st.errors <- err } st.spooledSegmentsMetadata <- spooledMetadata } case <-st.doneCh: return } } }() } func (st *driverStmt) startDownloadSegmentsWorkers(ctx 
context.Context) { st.waitDownloadSegmentsWorkers.Add(st.spoolingWorkerCount) for i := 0; i < st.spoolingWorkerCount; i++ { go func() { defer st.waitDownloadSegmentsWorkers.Done() for { select { case metadata, ok := <-st.spooledSegmentsMetadata: if !ok { return } segmentFetcher := &SegmentFetcher{ ctx: ctx, httpClient: st.conn.httpClient, spooledMetadata: metadata, } segment, err := segmentFetcher.fetchSegment() if err != nil { st.errors <- err return } select { case st.spooledSegmentsToDecode <- segmentToDecode{ encoding: metadata.encoding, data: segment, metadata: metadata.metadata, }: case <-st.doneCh: return case <-ctx.Done(): return } case <-st.doneCh: return case <-ctx.Done(): return } } }() } } func (st *driverStmt) startSegmentsDecodersWorkers(ctx context.Context) { st.waitSegmentDecodersWorkers.Add(st.spoolingWorkerCount) for i := 0; i < st.spoolingWorkerCount; i++ { go func() { defer st.waitSegmentDecodersWorkers.Done() for { select { case segmentToDecode, ok := <-st.spooledSegmentsToDecode: if !ok { return } segment, err := decodeSegment(segmentToDecode.data, segmentToDecode.encoding, segmentToDecode.metadata) if err != nil { st.cancelDecodersWorkers() st.errors <- fmt.Errorf("failed to decode spooled segment at index %d: %v", segmentToDecode.segmentIndex, err) return } select { case st.decodedSegments <- decodedSegment{ rowOffset: segmentToDecode.metadata.rowOffset, queryData: segment, }: case <-st.doneCh: return case <-ctx.Done(): return } case <-st.doneCh: return case <-ctx.Done(): return } } }() } } func (qr *driverRows) proccessSpollingSegments() { go func() { var qresp queryResponse var err error for { select { case qresp = <-qr.stmt.queryResponses: if qresp.ID == "" { qr.waitForAllSpoolingWorkersFinish() return } err = qr.initColumns(&qresp) if err != nil { qr.stmt.errors <- err } switch data := qresp.Data.(type) { case map[string]interface{}: if err := qr.queueSpoolingSegments(data); err != nil { qr.stmt.errors <- err } case nil: // do nothing: 
trino response without data (e.g only status information) default: qr.stmt.errors <- fmt.Errorf("unexpected data type for row at index %s: expected map[string]interface{}, got %T", qresp.ID, data) } qr.scheduleProgressUpdate(qresp.ID, qresp.Stats) } } }() } func (qr *driverRows) waitForAllSpoolingWorkersFinish() { close(qr.stmt.segmentsToProccess) <-qr.stmt.segmentDispatcherDoneCh close(qr.stmt.spooledSegmentsMetadata) qr.stmt.waitDownloadSegmentsWorkers.Wait() close(qr.stmt.spooledSegmentsToDecode) qr.stmt.waitSegmentDecodersWorkers.Wait() close(qr.stmt.decodedSegments) <-qr.stmt.spoolingProcesserDone } type segmentToProccess struct { segmentIndex int encoding string segment map[string]interface{} } func (qr *driverRows) queueSpoolingSegments(data map[string]interface{}) error { encoding, ok := data["encoding"].(string) if !ok { return fmt.Errorf("invalid or missing 'encoding' field on spooling protocol, expected string") } segments, ok := data["segments"].([]interface{}) if !ok { return fmt.Errorf("invalid or missing 'segments' field on spooling protocol, expected []interface{}") } for segmentIndex, segment := range segments { segment, ok := segment.(map[string]interface{}) if !ok { return fmt.Errorf("segment at index %d is invalid: expected map[string]interface{}, got %T", segmentIndex, segment) } qr.stmt.segmentsToProccess <- segmentToProccess{ segmentIndex: segmentIndex, encoding: encoding, segment: segment, } } return nil } func unmarshalArguments(signature *typeSignature) error { for i, argument := range signature.Arguments { var payload interface{} switch argument.Kind { case KIND_TYPE: payload = &(signature.Arguments[i].typeSignature) case KIND_NAMED_TYPE: payload = &(signature.Arguments[i].namedTypeSignature) case KIND_LONG: payload = &(signature.Arguments[i].long) } err := json.Unmarshal(argument.Value, payload) if err != nil { return err } switch argument.Kind { case KIND_TYPE: err = unmarshalArguments(&(signature.Arguments[i].typeSignature)) case 
KIND_NAMED_TYPE:
			err = unmarshalArguments(&(signature.Arguments[i].namedTypeSignature.TypeSignature))
		}
		if err != nil {
			return err
		}
	}
	return nil
}

// initColumns fills qr.columns and qr.coltype from the column metadata in
// qresp. It returns nil without touching qr when the columns were already
// initialised or when the response carries no columns.
//
// Every column type signature is passed through unmarshalArguments before a
// typeConverter is built for it; the first failure aborts initialisation.
func (qr *driverRows) initColumns(qresp *queryResponse) error {
	if qr.columns != nil || len(qresp.Columns) == 0 {
		return nil
	}
	var err error
	for i := range qresp.Columns {
		err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
		if err != nil {
			return fmt.Errorf("error decoding column type signature: %w", err)
		}
	}
	qr.columns = make([]string, len(qresp.Columns))
	qr.coltype = make([]*typeConverter, len(qresp.Columns))
	for i, col := range qresp.Columns {
		// NOTE(review): unmarshalArguments was already applied to every column
		// in the loop above; confirm whether this second call is intentional.
		err = unmarshalArguments(&(qresp.Columns[i].TypeSignature))
		if err != nil {
			return fmt.Errorf("error decoding column type signature: %w", err)
		}
		qr.columns[i] = col.Name
		qr.coltype[i], err = newTypeConverter(col.Type, col.TypeSignature)
		if err != nil {
			return err
		}
	}
	return nil
}

// scheduleProgressUpdate offers a progress snapshot for query id to
// qr.statsCh. It is a no-op when the connection has no progress updater.
//
// The snapshot is skipped while less than the configured period has elapsed
// since the last callback time and the query state is unchanged. The send is
// non-blocking, and the last callback time and state are recorded even when
// the channel could not accept the snapshot.
func (qr *driverRows) scheduleProgressUpdate(id string, stats stmtStats) {
	if qr.stmt.conn.progressUpdater == nil {
		return
	}
	qrStats := QueryProgressInfo{
		QueryId:    id,
		QueryStats: stats,
	}
	currentTime := time.Now()
	diff := currentTime.Sub(qr.stmt.conn.progressUpdaterPeriod.LastCallbackTime)
	period := qr.stmt.conn.progressUpdaterPeriod.Period
	// Check if period has not passed yet AND if query state did not change
	if diff < period && qr.stmt.conn.progressUpdaterPeriod.LastQueryState == qrStats.QueryStats.State {
		return
	}
	select {
	case qr.statsCh <- qrStats:
	default:
		// ignore when can't send stats
	}
	qr.stmt.conn.progressUpdaterPeriod.LastCallbackTime = currentTime
	qr.stmt.conn.progressUpdaterPeriod.LastQueryState = qrStats.QueryStats.State
}

// typeConverter describes how the values of one result column are converted.
type typeConverter struct {
	// typeName is the column type name as passed to newTypeConverter; it is
	// used in the "type not supported" error.
	typeName string
	// parsedType is the chain of raw type names produced by getNestedTypes,
	// outermost type first.
	parsedType []string
	// scanType is the Go type chosen by getScanType for parsedType.
	scanType reflect.Type
	// precision, scale and size are taken from the long-kind arguments of the
	// type signature when the raw type provides them (see newTypeConverter).
	precision optionalInt64
	scale     optionalInt64
	size      optionalInt64
}

// optionalInt64 is an int64 that records whether a value was set.
type optionalInt64 struct {
	value    int64
	hasValue bool
}

// newOptionalInt64 returns an optionalInt64 holding value with hasValue set.
func newOptionalInt64(value int64) optionalInt64 {
	return optionalInt64{value: value, hasValue: true}
}

// newTypeConverter builds the converter for a column with the given type name
// and signature.
//
// It derives the nested type chain and the scan type, then copies the leading
// signature arguments into size (char, varchar), precision and scale
// (decimal) or precision (time and timestamp variants). An argument in one of
// those positions that is not of KIND_LONG yields ErrInvalidResponseType.
func newTypeConverter(typeName string, signature typeSignature) (*typeConverter, error) {
	result := &typeConverter{
		typeName:   typeName,
		parsedType: getNestedTypes([]string{}, signature),
	}
	var err error
	result.scanType, err = getScanType(result.parsedType)
	if err != nil {
		return nil, err
	}
	switch signature.RawType {
	case "char", "varchar":
		if len(signature.Arguments) > 0 {
			if signature.Arguments[0].Kind != KIND_LONG {
				return nil, ErrInvalidResponseType
			}
			result.size = newOptionalInt64(signature.Arguments[0].long)
		}
	case "decimal":
		if len(signature.Arguments) > 0 {
			if signature.Arguments[0].Kind != KIND_LONG {
				return nil, ErrInvalidResponseType
			}
			result.precision = newOptionalInt64(signature.Arguments[0].long)
		}
		if len(signature.Arguments) > 1 {
			if signature.Arguments[1].Kind != KIND_LONG {
				return nil, ErrInvalidResponseType
			}
			result.scale = newOptionalInt64(signature.Arguments[1].long)
		}
	case "time", "time with time zone", "timestamp", "timestamp with time zone":
		if len(signature.Arguments) > 0 {
			if signature.Arguments[0].Kind != KIND_LONG {
				return nil, ErrInvalidResponseType
			}
			result.precision = newOptionalInt64(signature.Arguments[0].long)
		}
	}
	return result, nil
}

// getNestedTypes appends the raw type of signature to types and, when the
// signature has exactly one argument of kind KIND_TYPE or KIND_NAMED_TYPE,
// recurses into that argument. Signatures with zero or several arguments
// contribute only their own raw type.
func getNestedTypes(types []string, signature typeSignature) []string {
	types = append(types, signature.RawType)
	if len(signature.Arguments) == 1 {
		switch signature.Arguments[0].Kind {
		case KIND_TYPE:
			types = getNestedTypes(types, signature.Arguments[0].typeSignature)
		case KIND_NAMED_TYPE:
			types = getNestedTypes(types, signature.Arguments[0].namedTypeSignature.TypeSignature)
		}
	}
	return types
}

// getScanType maps a nested type chain (outermost type first) to the Go type
// reported as the column scan type.
//
// Scalars map to database/sql Null* types, varbinary to []byte and map to
// NullMap. Arrays are resolved up to three dimensions into the NullSlice*,
// NullSlice2* and NullSlice3* types; an "array" entry without a following
// element type yields ErrInvalidResponseType. Any chain that matches no case
// resolves to the empty interface type.
//
// typeNames[0] is read without a length check, so the slice must not be empty.
func getScanType(typeNames []string) (reflect.Type, error) {
	var v interface{}
	switch typeNames[0] {
	case "boolean":
		v = sql.NullBool{}
	case "json", "char", "varchar", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "unknown":
		v = sql.NullString{}
	case "varbinary":
		v = []byte{}
	case "tinyint", "smallint":
		v = sql.NullInt32{}
	case "integer":
		v = sql.NullInt32{}
	case "bigint":
		v = sql.NullInt64{}
	case "real", "double":
		v = sql.NullFloat64{}
	case "date", "time", "time with time zone", "timestamp", "timestamp with time zone":
		v = sql.NullTime{}
	case "map":
		v = NullMap{}
	case "array":
		if len(typeNames) <= 1 {
			return nil, ErrInvalidResponseType
		}
		switch typeNames[1] {
		case "boolean":
			v = NullSliceBool{}
		case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "unknown":
			v = NullSliceString{}
		case "tinyint", "smallint", "integer", "bigint":
			v = NullSliceInt64{}
		case "real", "double":
			v = NullSliceFloat64{}
		case "date", "time", "time with time zone", "timestamp", "timestamp with time zone":
			v = NullSliceTime{}
		case "map":
			v = NullSliceMap{}
		case "array":
			if len(typeNames) <= 2 {
				return nil, ErrInvalidResponseType
			}
			switch typeNames[2] {
			case "boolean":
				v = NullSlice2Bool{}
			case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "unknown":
				v = NullSlice2String{}
			case "tinyint", "smallint", "integer", "bigint":
				v = NullSlice2Int64{}
			case "real", "double":
				v = NullSlice2Float64{}
			case "date", "time", "time with time zone", "timestamp", "timestamp with time zone":
				v = NullSlice2Time{}
			case "map":
				v = NullSlice2Map{}
			case "array":
				if len(typeNames) <= 3 {
					return nil, ErrInvalidResponseType
				}
				switch typeNames[3] {
				case "boolean":
					v = NullSlice3Bool{}
				case "json", "char", "varchar", "varbinary", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "unknown":
					v = NullSlice3String{}
				case "tinyint", "smallint", "integer", "bigint":
					v = NullSlice3Int64{}
				case "real", "double":
					v = NullSlice3Float64{}
				case "date", "time", "time with time zone", "timestamp", "timestamp with time zone":
					v = NullSlice3Time{}
				case "map":
					v = NullSlice3Map{}
				}
				// if this is a 4 or more dimensional array, scan type will be an empty interface
			}
		}
	}
	if v == nil {
		return reflect.TypeOf(new(interface{})).Elem(), nil
	}
	return reflect.TypeOf(v), nil
}

//
ConvertValue implements the driver.ValueConverter interface. func (c *typeConverter) ConvertValue(v interface{}) (driver.Value, error) { switch c.parsedType[0] { case "boolean": vv, err := scanNullBool(v) if !vv.Valid { return nil, err } return vv.Bool, err case "json", "char", "varchar", "interval year to month", "interval day to second", "decimal", "ipaddress", "uuid", "Geometry", "SphericalGeography", "unknown": vv, err := scanNullString(v) if !vv.Valid { return nil, err } return vv.String, err case "varbinary": vv, err := scanNullBytes(v) if !vv.Valid { return nil, err } return vv.Bytes, err case "tinyint", "smallint", "integer", "bigint": vv, err := scanNullInt64(v) if !vv.Valid { return nil, err } return vv.Int64, err case "real", "double": vv, err := scanNullFloat64(v) if !vv.Valid { return nil, err } return vv.Float64, err case "date", "time", "time with time zone", "timestamp", "timestamp with time zone": vv, err := scanNullTime(v) if !vv.Valid { return nil, err } return vv.Time, err case "map": if err := validateMap(v); err != nil { return nil, err } return v, nil case "array": if err := validateSlice(v); err != nil { return nil, err } return v, nil case "row": if err := validateSlice(v); err != nil { return nil, err } return v, nil default: return nil, fmt.Errorf("type not supported: %q", c.typeName) } } func validateMap(v interface{}) error { if v == nil { return nil } if _, ok := v.(map[string]interface{}); !ok { return fmt.Errorf("cannot convert %v (%T) to map", v, v) } return nil } func validateSlice(v interface{}) error { if v == nil { return nil } if _, ok := v.([]interface{}); !ok { return fmt.Errorf("cannot convert %v (%T) to slice", v, v) } return nil } func scanNullBool(v interface{}) (sql.NullBool, error) { if v == nil { return sql.NullBool{}, nil } vv, ok := v.(bool) if !ok { return sql.NullBool{}, fmt.Errorf("cannot convert %v (%T) to bool", v, v) } return sql.NullBool{Valid: true, Bool: vv}, nil } // NullSliceBool represents a slice of bool 
that may be null. type NullSliceBool struct { SliceBool []sql.NullBool Valid bool } // Scan implements the sql.Scanner interface. func (s *NullSliceBool) Scan(value interface{}) error { if value == nil { s.SliceBool, s.Valid = []sql.NullBool{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to []bool", value, value) } slice := make([]sql.NullBool, len(vs)) for i := range vs { v, err := scanNullBool(vs[i]) if err != nil { return err } slice[i] = v } s.SliceBool = slice s.Valid = true return nil } // NullSlice2Bool represents a two-dimensional slice of bool that may be null. type NullSlice2Bool struct { Slice2Bool [][]sql.NullBool Valid bool } // Scan implements the sql.Scanner interface. func (s *NullSlice2Bool) Scan(value interface{}) error { if value == nil { s.Slice2Bool, s.Valid = [][]sql.NullBool{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to [][]bool", value, value) } slice := make([][]sql.NullBool, len(vs)) for i := range vs { var ss NullSliceBool if err := ss.Scan(vs[i]); err != nil { return err } slice[i] = ss.SliceBool } s.Slice2Bool = slice s.Valid = true return nil } // NullSlice3Bool implements a three-dimensional slice of bool that may be null. type NullSlice3Bool struct { Slice3Bool [][][]sql.NullBool Valid bool } // Scan implements the sql.Scanner interface. 
func (s *NullSlice3Bool) Scan(value interface{}) error { if value == nil { s.Slice3Bool, s.Valid = [][][]sql.NullBool{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to [][][]bool", value, value) } slice := make([][][]sql.NullBool, len(vs)) for i := range vs { var ss NullSlice2Bool if err := ss.Scan(vs[i]); err != nil { return err } slice[i] = ss.Slice2Bool } s.Slice3Bool = slice s.Valid = true return nil } func scanNullString(v interface{}) (sql.NullString, error) { if v == nil { return sql.NullString{}, nil } vv, ok := v.(string) if !ok { return sql.NullString{}, fmt.Errorf("cannot convert %v (%T) to string", v, v) } return sql.NullString{Valid: true, String: vv}, nil } // NullBinary represents a []byte that may be null. // This follows the same pattern as sql.NullString, sql.NullInt64, etc. type NullBinary struct { Bytes []byte Valid bool // Valid is true if Bytes is not NULL } func scanNullBytes(v interface{}) (NullBinary, error) { if v == nil { return NullBinary{}, nil // Valid: false, Bytes: nil } // VARBINARY values come back as a base64 encoded string. vv, ok := v.(string) if !ok { return NullBinary{}, fmt.Errorf("cannot convert %v (%T) to []byte", v, v) } // Decode the base64 encoded string into a []byte. decoded, err := base64.StdEncoding.DecodeString(vv) if err != nil { return NullBinary{}, fmt.Errorf("cannot decode base64 string into []byte: %w", err) } return NullBinary{Bytes: decoded, Valid: true}, nil } // NullSliceString represents a slice of string that may be null. type NullSliceString struct { SliceString []sql.NullString Valid bool } // Scan implements the sql.Scanner interface. 
func (s *NullSliceString) Scan(value interface{}) error { if value == nil { s.SliceString, s.Valid = []sql.NullString{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to []string", value, value) } slice := make([]sql.NullString, len(vs)) for i := range vs { v, err := scanNullString(vs[i]) if err != nil { return err } slice[i] = v } s.SliceString = slice s.Valid = true return nil } // NullSlice2String represents a two-dimensional slice of string that may be null. type NullSlice2String struct { Slice2String [][]sql.NullString Valid bool } // Scan implements the sql.Scanner interface. func (s *NullSlice2String) Scan(value interface{}) error { if value == nil { s.Slice2String, s.Valid = [][]sql.NullString{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to [][]string", value, value) } slice := make([][]sql.NullString, len(vs)) for i := range vs { var ss NullSliceString if err := ss.Scan(vs[i]); err != nil { return err } slice[i] = ss.SliceString } s.Slice2String = slice s.Valid = true return nil } // NullSlice3String implements a three-dimensional slice of string that may be null. type NullSlice3String struct { Slice3String [][][]sql.NullString Valid bool } // Scan implements the sql.Scanner interface. 
func (s *NullSlice3String) Scan(value interface{}) error { if value == nil { s.Slice3String, s.Valid = [][][]sql.NullString{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to [][][]string", value, value) } slice := make([][][]sql.NullString, len(vs)) for i := range vs { var ss NullSlice2String if err := ss.Scan(vs[i]); err != nil { return err } slice[i] = ss.Slice2String } s.Slice3String = slice s.Valid = true return nil } func scanNullInt64(v interface{}) (sql.NullInt64, error) { if v == nil { return sql.NullInt64{}, nil } vNumber, ok := v.(json.Number) if !ok { return sql.NullInt64{}, fmt.Errorf("cannot convert %v (%T) to int64", v, v) } vv, err := vNumber.Int64() if err != nil { return sql.NullInt64{}, fmt.Errorf("cannot convert %v (%T) to int64", v, v) } return sql.NullInt64{Valid: true, Int64: vv}, nil } // NullSliceInt64 represents a slice of int64 that may be null. type NullSliceInt64 struct { SliceInt64 []sql.NullInt64 Valid bool } // Scan implements the sql.Scanner interface. func (s *NullSliceInt64) Scan(value interface{}) error { if value == nil { s.SliceInt64, s.Valid = []sql.NullInt64{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to []int64", value, value) } slice := make([]sql.NullInt64, len(vs)) for i := range vs { v, err := scanNullInt64(vs[i]) if err != nil { return err } slice[i] = v } s.SliceInt64 = slice s.Valid = true return nil } // NullSlice2Int64 represents a two-dimensional slice of int64 that may be null. type NullSlice2Int64 struct { Slice2Int64 [][]sql.NullInt64 Valid bool } // Scan implements the sql.Scanner interface. 
func (s *NullSlice2Int64) Scan(value interface{}) error { if value == nil { s.Slice2Int64, s.Valid = [][]sql.NullInt64{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to [][]int64", value, value) } slice := make([][]sql.NullInt64, len(vs)) for i := range vs { var ss NullSliceInt64 if err := ss.Scan(vs[i]); err != nil { return err } slice[i] = ss.SliceInt64 } s.Slice2Int64 = slice s.Valid = true return nil } // NullSlice3Int64 implements a three-dimensional slice of int64 that may be null. type NullSlice3Int64 struct { Slice3Int64 [][][]sql.NullInt64 Valid bool } // Scan implements the sql.Scanner interface. func (s *NullSlice3Int64) Scan(value interface{}) error { if value == nil { s.Slice3Int64, s.Valid = [][][]sql.NullInt64{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to [][][]int64", value, value) } slice := make([][][]sql.NullInt64, len(vs)) for i := range vs { var ss NullSlice2Int64 if err := ss.Scan(vs[i]); err != nil { return err } slice[i] = ss.Slice2Int64 } s.Slice3Int64 = slice s.Valid = true return nil } func scanNullFloat64(v interface{}) (sql.NullFloat64, error) { if v == nil { return sql.NullFloat64{}, nil } vNumber, ok := v.(json.Number) if ok { vFloat, err := vNumber.Float64() if err != nil { return sql.NullFloat64{}, fmt.Errorf("cannot convert %v (%T) to float64: %w", vNumber, vNumber, err) } return sql.NullFloat64{Valid: true, Float64: vFloat}, nil } switch v { case "NaN": return sql.NullFloat64{Valid: true, Float64: math.NaN()}, nil case "Infinity": return sql.NullFloat64{Valid: true, Float64: math.Inf(+1)}, nil case "-Infinity": return sql.NullFloat64{Valid: true, Float64: math.Inf(-1)}, nil default: vString, ok := v.(string) if !ok { return sql.NullFloat64{}, fmt.Errorf("cannot convert %v (%T) to float64", v, v) } vFloat, err := strconv.ParseFloat(vString, 64) if err != nil { return sql.NullFloat64{}, 
fmt.Errorf("cannot convert %v (%T) to float64: %w", v, v, err) } return sql.NullFloat64{Valid: true, Float64: vFloat}, nil } } // NullSliceFloat64 represents a slice of float64 that may be null. type NullSliceFloat64 struct { SliceFloat64 []sql.NullFloat64 Valid bool } // Scan implements the sql.Scanner interface. func (s *NullSliceFloat64) Scan(value interface{}) error { if value == nil { s.SliceFloat64, s.Valid = []sql.NullFloat64{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to []float64", value, value) } slice := make([]sql.NullFloat64, len(vs)) for i := range vs { v, err := scanNullFloat64(vs[i]) if err != nil { return err } slice[i] = v } s.SliceFloat64 = slice s.Valid = true return nil } // NullSlice2Float64 represents a two-dimensional slice of float64 that may be null. type NullSlice2Float64 struct { Slice2Float64 [][]sql.NullFloat64 Valid bool } // Scan implements the sql.Scanner interface. func (s *NullSlice2Float64) Scan(value interface{}) error { if value == nil { s.Slice2Float64, s.Valid = [][]sql.NullFloat64{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to [][]float64", value, value) } slice := make([][]sql.NullFloat64, len(vs)) for i := range vs { var ss NullSliceFloat64 if err := ss.Scan(vs[i]); err != nil { return err } slice[i] = ss.SliceFloat64 } s.Slice2Float64 = slice s.Valid = true return nil } // NullSlice3Float64 represents a three-dimensional slice of float64 that may be null. type NullSlice3Float64 struct { Slice3Float64 [][][]sql.NullFloat64 Valid bool } // Scan implements the sql.Scanner interface. 
func (s *NullSlice3Float64) Scan(value interface{}) error { if value == nil { s.Slice3Float64, s.Valid = [][][]sql.NullFloat64{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to [][][]float64", value, value) } slice := make([][][]sql.NullFloat64, len(vs)) for i := range vs { var ss NullSlice2Float64 if err := ss.Scan(vs[i]); err != nil { return err } slice[i] = ss.Slice2Float64 } s.Slice3Float64 = slice s.Valid = true return nil } // Layout for time and timestamp WITHOUT time zone. // Trino can support up to 12 digits sub second precision, but Go only 9. // (Requires X-Trino-Client-Capabilities: PARAMETRIC_DATETIME) var timeLayouts = []string{ "2006-01-02", "15:04:05.999999999", "2006-01-02 15:04:05.999999999", } // Layout for time and timestamp WITH time zone. // Trino can support up to 12 digits sub second precision, but Go only 9. // (Requires X-Trino-Client-Capabilities: PARAMETRIC_DATETIME) var timeLayoutsTZ = []string{ "15:04:05.999999999 -07:00", "2006-01-02 15:04:05.999999999 -07:00", } func scanNullTime(v interface{}) (NullTime, error) { if v == nil { return NullTime{}, nil } vv, ok := v.(string) if !ok { return NullTime{}, fmt.Errorf("cannot convert %v (%T) to time string", v, v) } vparts := strings.Split(vv, " ") if len(vparts) > 1 && !unicode.IsDigit(rune(vparts[len(vparts)-1][0])) { return parseNullTimeWithLocation(vv) } // Time literals may not have spaces before the timezone. if strings.ContainsRune(vv, '+') { return parseNullTimeWithLocation(strings.Replace(vv, "+", " +", 1)) } hyphenCount := strings.Count(vv, "-") // We need to ensure we don't treat the hyphens in dates as the minus offset sign. // So if there's only one hyphen or more than 2, we have a negative offset. if hyphenCount == 1 || hyphenCount > 2 { // We add a space before the last hyphen to parse properly. 
i := strings.LastIndex(vv, "-") timestamp := vv[:i] + strings.Replace(vv[i:], "-", " -", 1) return parseNullTimeWithLocation(timestamp) } return parseNullTime(vv) } func parseNullTime(v string) (NullTime, error) { var t time.Time var err error for _, layout := range timeLayouts { t, err = time.ParseInLocation(layout, v, time.Local) if err == nil { return NullTime{Valid: true, Time: t}, nil } } return NullTime{}, err } func parseNullTimeWithLocation(v string) (NullTime, error) { idx := strings.LastIndex(v, " ") if idx == -1 { return NullTime{}, fmt.Errorf("cannot convert %v (%T) to time+zone", v, v) } stamp, location := v[:idx], v[idx+1:] var t time.Time var err error // Try offset timezones. if strings.HasPrefix(location, "+") || strings.HasPrefix(location, "-") { for _, layout := range timeLayoutsTZ { t, err = time.Parse(layout, v) if err == nil { return NullTime{Valid: true, Time: t}, nil } } return NullTime{}, err } loc, err := time.LoadLocation(location) // Not a named location. if err != nil { return NullTime{}, fmt.Errorf("cannot load timezone %q: %v", location, err) } for _, layout := range timeLayouts { t, err = time.ParseInLocation(layout, stamp, loc) if err == nil { return NullTime{Valid: true, Time: t}, nil } } return NullTime{}, err } // NullTime represents a time.Time value that can be null. // The NullTime supports Trino's Date, Time and Timestamp data types, // with or without time zone. type NullTime struct { Time time.Time Valid bool } // Scan implements the sql.Scanner interface. func (s *NullTime) Scan(value interface{}) error { if value == nil { s.Time, s.Valid = time.Time{}, false return nil } switch t := value.(type) { case time.Time: s.Time, s.Valid = t, true case NullTime: *s = t } return nil } // NullSliceTime represents a slice of time.Time that may be null. type NullSliceTime struct { SliceTime []NullTime Valid bool } // Scan implements the sql.Scanner interface. 
func (s *NullSliceTime) Scan(value interface{}) error { if value == nil { s.SliceTime, s.Valid = []NullTime{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to []time.Time", value, value) } slice := make([]NullTime, len(vs)) for i := range vs { v, err := scanNullTime(vs[i]) if err != nil { return err } slice[i] = v } s.SliceTime = slice s.Valid = true return nil } // NullSlice2Time represents a two-dimensional slice of time.Time that may be null. type NullSlice2Time struct { Slice2Time [][]NullTime Valid bool } // Scan implements the sql.Scanner interface. func (s *NullSlice2Time) Scan(value interface{}) error { if value == nil { s.Slice2Time, s.Valid = [][]NullTime{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to [][]time.Time", value, value) } slice := make([][]NullTime, len(vs)) for i := range vs { var ss NullSliceTime if err := ss.Scan(vs[i]); err != nil { return err } slice[i] = ss.SliceTime } s.Slice2Time = slice s.Valid = true return nil } // NullSlice3Time represents a three-dimensional slice of time.Time that may be null. type NullSlice3Time struct { Slice3Time [][][]NullTime Valid bool } // Scan implements the sql.Scanner interface. func (s *NullSlice3Time) Scan(value interface{}) error { if value == nil { s.Slice3Time, s.Valid = [][][]NullTime{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to [][][]time.Time", value, value) } slice := make([][][]NullTime, len(vs)) for i := range vs { var ss NullSlice2Time if err := ss.Scan(vs[i]); err != nil { return err } slice[i] = ss.Slice2Time } s.Slice3Time = slice s.Valid = true return nil } // NullMap represents a map type that may be null. type NullMap struct { Map map[string]interface{} Valid bool } // Scan implements the sql.Scanner interface. 
func (m *NullMap) Scan(v interface{}) error { if v == nil { m.Map, m.Valid = map[string]interface{}{}, false return nil } m.Map, m.Valid = v.(map[string]interface{}) return nil } // NullSliceMap represents a slice of NullMap that may be null. type NullSliceMap struct { SliceMap []NullMap Valid bool } // Scan implements the sql.Scanner interface. func (s *NullSliceMap) Scan(value interface{}) error { if value == nil { s.SliceMap, s.Valid = []NullMap{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to []NullMap", value, value) } slice := make([]NullMap, len(vs)) for i := range vs { if err := validateMap(vs[i]); err != nil { return fmt.Errorf("cannot convert %v (%T) to []NullMap", value, value) } m := NullMap{} // this scan can never fail _ = m.Scan(vs[i]) slice[i] = m } s.SliceMap = slice s.Valid = true return nil } // NullSlice2Map represents a two-dimensional slice of NullMap that may be null. type NullSlice2Map struct { Slice2Map [][]NullMap Valid bool } // Scan implements the sql.Scanner interface. func (s *NullSlice2Map) Scan(value interface{}) error { if value == nil { s.Slice2Map, s.Valid = [][]NullMap{}, false return nil } vs, ok := value.([]interface{}) if !ok { return fmt.Errorf("trino: cannot convert %v (%T) to [][]NullMap", value, value) } slice := make([][]NullMap, len(vs)) for i := range vs { var ss NullSliceMap if err := ss.Scan(vs[i]); err != nil { return err } slice[i] = ss.SliceMap } s.Slice2Map = slice s.Valid = true return nil } // NullSlice3Map represents a three-dimensional slice of NullMap that may be null. type NullSlice3Map struct { Slice3Map [][][]NullMap Valid bool } // Scan implements the sql.Scanner interface. 
func (s *NullSlice3Map) Scan(value interface{}) error {
	// A nil input marks the receiver as NULL with an empty, non-nil slice.
	if value == nil {
		s.Slice3Map, s.Valid = [][][]NullMap{}, false
		return nil
	}
	vs, ok := value.([]interface{})
	if !ok {
		return fmt.Errorf("trino: cannot convert %v (%T) to [][][]NullMap", value, value)
	}
	// Each element is scanned as a two-dimensional slice of NullMap.
	slice := make([][][]NullMap, len(vs))
	for i := range vs {
		var ss NullSlice2Map
		if err := ss.Scan(vs[i]); err != nil {
			return err
		}
		slice[i] = ss.Slice2Map
	}
	s.Slice3Map = slice
	s.Valid = true
	return nil
}

// QueryProgressInfo is the progress snapshot handed to
// ProgressUpdater.Update: the query identifier together with its statistics.
type QueryProgressInfo struct {
	QueryId    string
	QueryStats stmtStats
}

// queryProgressCallbackPeriod holds the throttling state for progress
// callbacks: the minimum period between callbacks, the time of the last
// callback and the query state seen at that time.
type queryProgressCallbackPeriod struct {
	Period           time.Duration
	LastCallbackTime time.Time
	LastQueryState   string
}

// ProgressUpdater receives query progress snapshots.
type ProgressUpdater interface {
	// Update the query progress, immediately when the query starts, when receiving data, and once when the query is finished.
	Update(QueryProgressInfo)
}

================================================
FILE: trino/trino_test.go
================================================
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package trino import ( "bytes" "context" "database/sql" "encoding/base64" "encoding/json" "fmt" "math" "net/http" "net/http/httptest" "net/url" "reflect" "runtime/debug" "sort" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestConfig(t *testing.T) { c := &Config{ ServerURI: "http://foobar@localhost:8080", SessionProperties: map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(t, err) want := "http://foobar@localhost:8080?session_properties=query_priority%3A1&source=trino-go-client" assert.Equal(t, want, dsn) } func TestPreserveExplicitPrepareQueryParameterConfig(t *testing.T) { c := &Config{ ServerURI: "https://foobar@localhost:8090", DisableExplicitPrepare: true, } dsn, err := c.FormatDSN() require.NoError(t, err) want := "https://foobar@localhost:8090?explicitPrepare=false&source=trino-go-client" assert.Equal(t, want, dsn) } func TestParseDSNToConfig(t *testing.T) { tests := []struct { name string config *Config }{ { name: "HTTP with custom client and full configuration", config: &Config{ ServerURI: "http://foobar@localhost:8080", Source: "trino-go-client", Catalog: "test_catalog", Schema: "test_schema", SessionProperties: map[string]string{"session_property_one": "1", "session_property_two": "2"}, ExtraCredentials: map[string]string{"extra_credential_one": "1", "extra_credential_two": "2"}, ClientTags: []string{"tag1", "tag2", "tag3"}, CustomClientName: "client_name", AccessToken: "token_test", DisableExplicitPrepare: true, ForwardAuthorizationHeader: true, QueryTimeout: &[]time.Duration{5 * time.Minute}[0], Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, }, }, { name: "HTTPS with Kerberos and SSL cert path", config: &Config{ ServerURI: "https://foobar@localhost:8080", Source: "trino-go-client", Catalog: "test_catalog", Schema: "test_schema", SessionProperties: map[string]string{"session_property_one": "1", "session_property_two": "2"}, ExtraCredentials: 
map[string]string{"extra_credential_one": "1", "extra_credential_two": "2"}, ClientTags: []string{"tag1", "tag2", "tag3"}, KerberosEnabled: true, // Requires HTTPS KerberosKeytabPath: "kerberos-path", KerberosPrincipal: "kerberos-pricipal", KerberosRemoteServiceName: "kerberos-remote-service-name", KerberosRealm: "kerberos-realm", KerberosConfigPath: "kerberos-config-path", SSLCertPath: "ssl-cert-path", AccessToken: "token_test", DisableExplicitPrepare: true, ForwardAuthorizationHeader: true, QueryTimeout: &[]time.Duration{5 * time.Minute}[0], }, }, { name: "HTTPS with SSL cert string (alternative to cert path)", config: &Config{ ServerURI: "https://localhost:8080", Source: "trino-go-client", SSLCert: "-----BEGIN CERTIFICATE-----\ntest-cert-data\n-----END CERTIFICATE-----", }, }, { name: "HTTP with explicit default boolean values", config: &Config{ ServerURI: "http://localhost:8080", Source: "trino-go-client", DisableExplicitPrepare: false, ForwardAuthorizationHeader: false, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dsn, err := tt.config.FormatDSN() require.NoError(t, err) got, err := ParseDSN(dsn) require.NoError(t, err) assert.Equal(t, tt.config, got) }) } } func TestParseDSNToConfigAllFieldsHandled(t *testing.T) { complexDSN := "https://user:pass@localhost:8080/?" 
+ "source=test-source&" + "catalog=test_catalog&" + "schema=test_schema&" + "session_properties=prop1%3Avalue1%3Bprop2%3Avalue2&" + "extra_credentials=cred1%3Asecret1%3Bcred2%3Asecret2&" + "clientTags=tag1%2Ctag2%2Ctag3&" + "custom_client=test_client&" + "KerberosEnabled=true&" + "KerberosKeytabPath=/path/to/keytab&" + "KerberosPrincipal=user%40REALM.COM&" + "KerberosRemoteServiceName=trino-service&" + "KerberosRealm=REALM.COM&" + "KerberosConfigPath=/etc/krb5.conf&" + "SSLCertPath=/path/to/cert.pem&" + "SSLCert=-----BEGIN%20CERTIFICATE-----test-cert-----END%20CERTIFICATE-----&" + "accessToken=jwt-token-here&" + "explicitPrepare=false&" + "forwardAuthorizationHeader=true&" + "query_timeout=5m30s&" + "roles=catalog1%3Arole1%3Bcatalog2%3Arole2" config, err := ParseDSN(complexDSN) require.NoError(t, err) require.NotNil(t, config) v := reflect.ValueOf(config).Elem() configType := v.Type() for i := 0; i < v.NumField(); i++ { field := v.Field(i) fieldName := configType.Field(i).Name fieldType := field.Type() switch fieldType.Kind() { case reflect.String: assert.NotEmpty(t, field.String(), "Field %s should not be empty - add it to the test DSN and ParseDSNToConfig", fieldName) case reflect.Slice: assert.Greater(t, field.Len(), 0, "Field %s should not be empty slice - add it to the test DSN and ParseDSNToConfig", fieldName) case reflect.Map: assert.Greater(t, field.Len(), 0, "Field %s should not be empty map - add it to the test DSN and ParseDSNToConfig", fieldName) case reflect.Bool: assert.True(t, field.Bool(), "Field %s should be true - add it to the test DSN and ParseDSNToConfig", fieldName) case reflect.Ptr: assert.NotNil(t, field.Interface(), "Field %s should not be nil - add it to the test DSN and ParseDSNToConfig", fieldName) } } assert.Equal(t, "https://user:pass@localhost:8080", config.ServerURI) assert.Equal(t, "test-source", config.Source) assert.Equal(t, "test_catalog", config.Catalog) assert.Equal(t, "test_schema", config.Schema) assert.Equal(t, 
map[string]string{"prop1": "value1", "prop2": "value2"}, config.SessionProperties) assert.Equal(t, map[string]string{"cred1": "secret1", "cred2": "secret2"}, config.ExtraCredentials) assert.Equal(t, []string{"tag1", "tag2", "tag3"}, config.ClientTags) assert.Equal(t, "test_client", config.CustomClientName) assert.Equal(t, true, config.KerberosEnabled) assert.Equal(t, "/path/to/keytab", config.KerberosKeytabPath) assert.Equal(t, "user@REALM.COM", config.KerberosPrincipal) assert.Equal(t, "trino-service", config.KerberosRemoteServiceName) assert.Equal(t, "REALM.COM", config.KerberosRealm) assert.Equal(t, "/etc/krb5.conf", config.KerberosConfigPath) assert.Equal(t, "/path/to/cert.pem", config.SSLCertPath) assert.Equal(t, "-----BEGIN CERTIFICATE-----test-cert-----END CERTIFICATE-----", config.SSLCert) assert.Equal(t, "jwt-token-here", config.AccessToken) assert.Equal(t, true, config.DisableExplicitPrepare) assert.Equal(t, true, config.ForwardAuthorizationHeader) assert.NotNil(t, config.QueryTimeout) assert.Equal(t, 5*time.Minute+30*time.Second, *config.QueryTimeout) assert.Equal(t, map[string]string{"catalog1": "role1", "catalog2": "role2"}, config.Roles) } func TestConfigFormatDSNTags(t *testing.T) { tests := []struct { name string config *Config want string }{ { name: "multiple tags", config: &Config{ ServerURI: "http://foobar@localhost:8080", SessionProperties: map[string]string{"query_priority": "1"}, ClientTags: []string{"test1", "test2", "test3"}, }, want: "http://foobar@localhost:8080?clientTags=test1%2Ctest2%2Ctest3&session_properties=query_priority%3A1&source=trino-go-client", }, { name: "single tag", config: &Config{ ServerURI: "http://foobar@localhost:8080", SessionProperties: map[string]string{"query_priority": "1"}, ClientTags: []string{"test1"}, }, want: "http://foobar@localhost:8080?clientTags=test1&session_properties=query_priority%3A1&source=trino-go-client", }, { name: "multiple tags with special characters", config: &Config{ ServerURI: 
"http://foobar@localhost:8080", SessionProperties: map[string]string{"query_priority": "1"}, ClientTags: []string{"foo %20", "bar=test", "baz#tag"}, }, want: "http://foobar@localhost:8080?clientTags=foo+%2520%2Cbar%3Dtest%2Cbaz%23tag&session_properties=query_priority%3A1&source=trino-go-client", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.config.FormatDSN() require.NoError(t, err) assert.Equal(t, tt.want, got) }) } } func TestConfigSSLCertPath(t *testing.T) { c := &Config{ ServerURI: "https://foobar@localhost:8080", SessionProperties: map[string]string{"query_priority": "1"}, SSLCertPath: "cert.pem", } dsn, err := c.FormatDSN() require.NoError(t, err) want := "https://foobar@localhost:8080?SSLCertPath=cert.pem&session_properties=query_priority%3A1&source=trino-go-client" assert.Equal(t, want, dsn) } func TestConfigSSLCert(t *testing.T) { sslCert := `-----BEGIN CERTIFICATE----- MIIFijCCA3ICCQDngXKCZFwSazANBgkqhkiG9w0BAQsFADCBhjELMAkGA1UEBhMC WFgxEjAQBgNVBAgMCVN0YXRlTmFtZTERMA8GA1UEBwwIQ2l0eU5hbWUxFDASBgNV BAoMC0NvbXBhbnlOYW1lMRswGQYDVQQLDBJDb21wYW55U2VjdGlvbk5hbWUxHTAb BgNVBAMMFENvbW1vbk5hbWVPckhvc3RuYW1lMB4XDTIzMDUxNzE2MzQ0MloXDTMz MDUxNDE2MzQ0MlowgYYxCzAJBgNVBAYTAlhYMRIwEAYDVQQIDAlTdGF0ZU5hbWUx ETAPBgNVBAcMCENpdHlOYW1lMRQwEgYDVQQKDAtDb21wYW55TmFtZTEbMBkGA1UE CwwSQ29tcGFueVNlY3Rpb25OYW1lMR0wGwYDVQQDDBRDb21tb25OYW1lT3JIb3N0 bmFtZTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAKzz/SIuOiHZbUAH xCWrMaiJybdHHHl0smCu50XKvl/ZkszO1c4aES8/Vohw44ttaE+GOknTSGPka356 NqwdPYMjnXN0d5HY5T5nOfgLxGD/1iCHACrT4gkd1asJ7eFaUgud0a+e9+oG53Vh Z3QV8+5JaWPuBMudJ8EOtrPMd0dJKVzeExTbpQLJ9HdIsHc6DXqshACd8Iy+ezqf OoYMYyJMAHO86MZrTs3t9AwUADlvntrwwObVrZ3v43IOKwJTRnpImmVlkouKrGn/ HKzRmJEJ6hJQXhuhqI/0rr61XR8aa8Gs0FqtTTMJ32+PciPPzFtFVLAeA417lYz+ uXZ6IpTLK4oDH8Q6gJY80GYqcGc+01ZY90W2L+odTz9P74vnTvsUgSjOcy7prJ0+ WxoeBNPvkLeetX9WDZW4XaR++HVO1qelNJQqeB6Nver9MJdKkXvR3OxT6iluqXfA l9JJ57tnzspSrttjWG4kwwiaGn/4xPqd95Hp0r1WAK8U0Cqtvz+Zw9jl341tC1Ya 
K1KFIErZYf0KX8ZiYvmkHaTRxYiCmFnnfLtGdrAWkacisLKMhjeb9LXwC/TVtvio a+ofiW2DX80pQptkfNJs9P19ZFEojPAEFHiZFpz5yZSxHglxIsdIhRsuy5xb/KTo zey3tsKQJaFIah+aHKjyn3uZx2IRAgMBAAEwDQYJKoZIhvcNAQELBQADggIBAIs5 sbCMB6bT0hcNFqFRCI/BL23m5jwdL9kNWDlEQxBvErtzTC+uStGrCqwV+qu49QAZ 64kUolbzFyq/hQFpHd+9EzNkZGbiOf5toWaBUP6jaZzqYPdfDW+AwIA7iPHcqwH1 iWX2zuAWAICy4H+S4oa/ShOPc8BrrnS8k5f1NpergOhd+wl+szuXJN9Tjli3wd/k L7f86xvZfOrEbss8YP4QE0+mKh6G71NLEVQ4SV7yIE2hCNLDFWS2ltGVRLv6CDaQ fXIQrZx2Khvpj+HI/hrwm1wV8Cg5w2IvB831YjTSepSoos0Cc/qYC78zqol/NbwL 7TdHtuZKukDrisRiCDdoKFmS1/IUVeVR2352CG8G3Zo0wwfzoKLxLUtunnrKMmmO r2jXykqP2hb1dApBNFM7FoaJ7a0j6EcURW8wYl4I+b9ymftPnnZ8mgrjwvLh5ETj RgGsIBychLZoc1WWTZWu62+mvmSJnzEIFfaiSeYZLaL6qFHm6kqsAUn4s1Looj8/ XoCNjMecchWbpHGCPwMFH1k2smxu7bKk/RJNuWSVn1IPUceJnOBHZGj92aJGZpjr 8j39T3dK9F2r5rHwjZpeEIhyhbLw6pYKif+lBgAWJD3waG0ycwURA02/POHN4CpT FKu5ZAlRfb2aYegr49DHhzoVAdInWQmP+5EZEUD1 -----END CERTIFICATE-----` c := &Config{ ServerURI: "https://foobar@localhost:8080", SessionProperties: map[string]string{"query_priority": "1"}, SSLCert: sslCert, } dsn, err := c.FormatDSN() require.NoError(t, err) want := "https://foobar@localhost:8080?SSLCert=" + url.QueryEscape(sslCert) + "&session_properties=query_priority%3A1&source=trino-go-client" assert.Equal(t, want, dsn) } func TestExtraCredentials(t *testing.T) { c := &Config{ ServerURI: "http://foobar@localhost:8080", ExtraCredentials: map[string]string{"token": "mYtOkEn", "otherToken": "oThErToKeN%*!#@special"}, } dsn, err := c.FormatDSN() require.NoError(t, err) want := "http://foobar@localhost:8080?extra_credentials=otherToken%3AoThErToKeN%25%2A%21%23%40special%3Btoken%3AmYtOkEn&source=trino-go-client" assert.Equal(t, want, dsn) } func TestInvalidExtraCredentials(t *testing.T) { testcases := []struct { Name string Credentials map[string]string Error string }{ { Name: "Empty key", Credentials: map[string]string{"": "emptyKey"}, Error: "trino: extra_credentials key is empty", }, { Name: "Empty value", Credentials: 
map[string]string{"valid": "a", "emptyValue": ""}, Error: "trino: extra_credentials value is empty", }, { Name: "Unprintable key", Credentials: map[string]string{"😊": "unprintableKey"}, Error: "trino: extra_credentials key '😊' contains spaces or is not printable ASCII", }, { Name: "Unprintable value", Credentials: map[string]string{"unprintableValue": "😊"}, Error: "trino: extra_credentials value for key 'unprintableValue' contains spaces or is not printable ASCII", }, } for _, tc := range testcases { t.Run(tc.Name, func(t *testing.T) { c := &Config{ ServerURI: "http://foobar@localhost:8080", ExtraCredentials: tc.Credentials, } dsn, err := c.FormatDSN() require.NoError(t, err) db, err := sql.Open("trino", dsn) require.NoError(t, err) err = db.Ping() assert.EqualError(t, err, tc.Error) }) } } func TestConfigWithoutSSLCertPath(t *testing.T) { c := &Config{ ServerURI: "https://foobar@localhost:8080", SessionProperties: map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(t, err) want := "https://foobar@localhost:8080?session_properties=query_priority%3A1&source=trino-go-client" assert.Equal(t, want, dsn) } func TestKerberosConfig(t *testing.T) { c := &Config{ ServerURI: "https://foobar@localhost:8090", SessionProperties: map[string]string{"query_priority": "1"}, KerberosEnabled: true, KerberosKeytabPath: "/opt/test.keytab", KerberosPrincipal: "trino/testhost", KerberosRealm: "example.com", KerberosConfigPath: "/etc/krb5.conf", KerberosRemoteServiceName: "service", SSLCertPath: "/tmp/test.cert", } dsn, err := c.FormatDSN() require.NoError(t, err) want := "https://foobar@localhost:8090?KerberosConfigPath=%2Fetc%2Fkrb5.conf&KerberosEnabled=true&KerberosKeytabPath=%2Fopt%2Ftest.keytab&KerberosPrincipal=trino%2Ftesthost&KerberosRealm=example.com&KerberosRemoteServiceName=service&SSLCertPath=%2Ftmp%2Ftest.cert&session_properties=query_priority%3A1&source=trino-go-client" assert.Equal(t, want, dsn) } func TestFormatDSNWithRoles(t *testing.T) { 
tests := []struct { name string config *Config wantDSN string expectError bool }{ { name: "Multiple catalog roles", config: &Config{ ServerURI: "https://foobar@localhost:8090", SessionProperties: map[string]string{"query_priority": "1"}, Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, }, wantDSN: "https://foobar@localhost:8090?roles=catalog1%3Arole1%3Bcatalog2%3Arole2&session_properties=query_priority%3A1&source=trino-go-client", }, { name: "Single catalog role", config: &Config{ ServerURI: "https://foobar@localhost:8090", SessionProperties: map[string]string{"query_priority": "1"}, Roles: map[string]string{"catalog1": "role1"}, }, wantDSN: "https://foobar@localhost:8090?roles=catalog1%3Arole1&session_properties=query_priority%3A1&source=trino-go-client", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dsn, err := tt.config.FormatDSN() if tt.expectError { require.Error(t, err) } else { require.NoError(t, err) assert.Equal(t, tt.wantDSN, dsn) } }) } } func TestInvalidKerberosConfig(t *testing.T) { c := &Config{ ServerURI: "http://foobar@localhost:8090", KerberosEnabled: true, } _, err := c.FormatDSN() assert.Error(t, err, "dsn generated from invalid secure url, since kerberos enabled must has SSL enabled") } func TestAccessTokenConfig(t *testing.T) { c := &Config{ ServerURI: "https://foobar@localhost:8090", AccessToken: "token", } dsn, err := c.FormatDSN() require.NoError(t, err) want := "https://foobar@localhost:8090?accessToken=token&source=trino-go-client" assert.Equal(t, want, dsn) } func TestConfigWithMalformedURL(t *testing.T) { _, err := (&Config{ServerURI: ":("}).FormatDSN() assert.Error(t, err, "dsn generated from malformed url") } func TestConnErrorDSN(t *testing.T) { testcases := []struct { Name string DSN string }{ {Name: "malformed", DSN: "://"}, {Name: "unknown_client", DSN: "http://localhost?custom_client=unknown"}, } for _, tc := range testcases { t.Run(tc.Name, func(t *testing.T) { db, err := sql.Open("trino", 
tc.DSN) require.NoError(t, err) _, err = db.Query("SELECT 1") assert.Errorf(t, err, "test dsn is supposed to fail: %s", tc.DSN) if err == nil { require.NoError(t, db.Close()) } }) } } func TestRegisterCustomClientReserved(t *testing.T) { for _, tc := range []string{"true", "false"} { t.Run(fmt.Sprintf("%v", tc), func(t *testing.T) { require.Errorf(t, RegisterCustomClient(tc, &http.Client{}), "client key name supposed to fail: %s", tc) }) } } func TestQueryTimeout(t *testing.T) { timeout := 10 * time.Second c := &Config{ ServerURI: "https://foobar@localhost:8090", QueryTimeout: &timeout, } dsn, err := c.FormatDSN() require.NoError(t, err) want := "https://foobar@localhost:8090?query_timeout=10s&source=trino-go-client" assert.Equal(t, want, dsn) } func TestRoundTripRetryQueryError(t *testing.T) { testcases := []struct { Name string HttpStatus int ExpectedErrorStatus string }{ { Name: "Test retry 502 Bad Gateway", HttpStatus: http.StatusBadGateway, ExpectedErrorStatus: "200 OK", }, { Name: "Test retry 503 Service Unavailable", HttpStatus: http.StatusServiceUnavailable, ExpectedErrorStatus: "200 OK", }, { Name: "Test retry 504 Gateway Timeout", HttpStatus: http.StatusGatewayTimeout, ExpectedErrorStatus: "200 OK", }, { Name: "Test no retry 404 Not Found", HttpStatus: http.StatusNotFound, ExpectedErrorStatus: "404 Not Found", }, } for _, tc := range testcases { t.Run(tc.Name, func(t *testing.T) { count := 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if count == 0 { count++ w.WriteHeader(tc.HttpStatus) return } w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(&stmtResponse{ Error: ErrTrino{ ErrorName: "TEST", }, }) })) t.Cleanup(ts.Close) db, err := sql.Open("trino", ts.URL) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) _, err = db.Query("SELECT 1") assert.ErrorContains(t, err, tc.ExpectedErrorStatus, "unexpected error: %w", err) }) } } func TestRoundTripBogusData(t *testing.T) { count := 0 
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if count == 0 { count++ w.WriteHeader(http.StatusServiceUnavailable) return } w.WriteHeader(http.StatusOK) // some invalid JSON w.Write([]byte(`{"stats": {"progressPercentage": ""}}`)) })) t.Cleanup(ts.Close) db, err := sql.Open("trino", ts.URL) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) rows, err := db.Query("SELECT 1") require.NoError(t, err) assert.False(t, rows.Next()) require.NoError(t, rows.Err()) } func TestRoundTripCancellation(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) })) t.Cleanup(ts.Close) db, err := sql.Open("trino", ts.URL) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) t.Cleanup(cancel) _, err = db.QueryContext(ctx, "SELECT 1") assert.Error(t, err, "unexpected query with cancelled context succeeded") } func TestAuthFailure(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) })) t.Cleanup(ts.Close) db, err := sql.Open("trino", ts.URL) require.NoError(t, err) assert.NoError(t, db.Close()) } func TestTokenAuth(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Authorization") != "Bearer token" { w.WriteHeader(http.StatusUnauthorized) } else { w.WriteHeader(http.StatusOK) } })) t.Cleanup(ts.Close) db, err := sql.Open("trino", ts.URL+"?accessToken=token") require.NoError(t, err) _, err = db.Query("SELECT 1") require.Error(t, err, "trino: EOF") assert.NoError(t, db.Close()) } func TestQueryForUsername(t *testing.T) { if testing.Short() { t.Skip("Skipping test in short mode.") } c := &Config{ ServerURI: *integrationServerFlag, SessionProperties: 
map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(t, err) db, err := sql.Open("trino", dsn) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) rows, err := db.Query("SELECT current_user", sql.Named("X-Trino-User", string("TestUser"))) require.NoError(t, err, "Failed executing query") assert.NotNil(t, rows) for rows.Next() { var user string require.NoError(t, rows.Scan(&user), "Failed scanning query result") assert.Equal(t, "TestUser", user, "Expected value does not equal result value") } } type TestQueryProgressCallback struct { progressMap map[time.Time]float64 statusMap map[time.Time]string } func (qpc *TestQueryProgressCallback) Update(qpi QueryProgressInfo) { qpc.progressMap[time.Now()] = float64(qpi.QueryStats.ProgressPercentage) qpc.statusMap[time.Now()] = qpi.QueryStats.State } func TestQueryProgressWithCallback(t *testing.T) { if testing.Short() { t.Skip("Skipping test in short mode.") } c := &Config{ ServerURI: *integrationServerFlag, SessionProperties: map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(t, err) db, err := sql.Open("trino", dsn) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) callback := &TestQueryProgressCallback{} _, err = db.Query("SELECT 2", sql.Named("X-Trino-Progress-Callback", callback)) assert.EqualError(t, err, ErrInvalidProgressCallbackHeader.Error(), "unexpected error") } func TestQueryProgressWithCallbackPeriod(t *testing.T) { if testing.Short() { t.Skip("Skipping test in short mode.") } c := &Config{ ServerURI: *integrationServerFlag, SessionProperties: map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(t, err) db, err := sql.Open("trino", dsn) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) progressMap := make(map[time.Time]float64) statusMap := make(map[time.Time]string) progressUpdater := &TestQueryProgressCallback{ progressMap: progressMap, 
statusMap: statusMap, } progressUpdaterPeriod, err := time.ParseDuration("1ms") require.NoError(t, err) rows, err := db.Query("SELECT 2", sql.Named("X-Trino-Progress-Callback", progressUpdater), sql.Named("X-Trino-Progress-Callback-Period", progressUpdaterPeriod), ) require.NoError(t, err, "Failed executing query") assert.NotNil(t, rows) for rows.Next() { var ts string require.NoError(t, rows.Scan(&ts), "Failed scanning query result") assert.Equal(t, "2", ts, "Expected value does not equal result value") } if err = rows.Err(); err != nil { t.Fatal(err) } if err = rows.Close(); err != nil { t.Fatal(err) } // sort time in order to calculate interval assert.NotEmpty(t, progressMap) assert.NotEmpty(t, statusMap) var keys []time.Time for k := range statusMap { keys = append(keys, k) } sort.Slice(keys, func(i, j int) bool { return keys[i].Before(keys[j]) }) for i, k := range keys { if i > 0 { assert.GreaterOrEqual(t, k.Sub(keys[i-1]), progressUpdaterPeriod) } assert.GreaterOrEqual(t, progressMap[k], 0.0) } } func TestQueryColumns(t *testing.T) { c := &Config{ ServerURI: *integrationServerFlag, SessionProperties: map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(t, err) db, err := sql.Open("trino", dsn) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) rows, err := db.Query(`SELECT true AS bool, cast(123 AS tinyint) AS tinyint, cast(456 AS smallint) AS smallint, cast(678 AS integer) AS integer, cast(1234 AS bigint) AS bigint, cast(1.23 AS real) AS real, cast(1.23 AS double) AS double, cast(1.23 as decimal(10,5)) AS decimal, cast('aaa' as varchar) AS vunbounded, cast('bbb' as varchar(10)) AS vbounded, cast('ccc' AS char) AS cunbounded, cast('ddd' as char(10)) AS cbounded, cast('ddd' as varbinary) AS varbinary, cast('{"aaa": 1}' as json) AS json, current_date AS date, cast(current_time as time) AS time, cast(current_time as time(6)) AS timep, cast(current_time as time with time zone) AS timetz, cast(current_time 
as timestamp) AS ts, cast(current_time as timestamp(6)) AS tsp, cast(current_time as timestamp with time zone) AS tstz, cast(current_time as timestamp(6) with time zone) AS tsptz, interval '3' month AS ytm, interval '2' day AS dts, array['a', 'b'] AS varray, array[array['a'], array['b']] AS v2array, array[array[array['a'], array['b']]] AS v3array, map(array['a'], array[1]) AS map, array[map(array['a'], array[1]), map(array['b'], array[2])] AS marray, row('a', 1) AS row, cast(row('a', 1.23) AS row(x varchar, y double)) AS named_row, ipaddress '10.0.0.1' AS ip, uuid '12151fd2-7586-11e9-8f9e-2a86e4085a59' AS uuid`) require.NoError(t, err, "Failed executing query") assert.NotNil(t, rows) columns, err := rows.Columns() require.NoError(t, err, "Failed reading result columns") assert.Equal(t, 33, len(columns), "Expected 33 result column") expectedNames := []string{ "bool", "tinyint", "smallint", "integer", "bigint", "real", "double", "decimal", "vunbounded", "vbounded", "cunbounded", "cbounded", "varbinary", "json", "date", "time", "timep", "timetz", "ts", "tsp", "tstz", "tsptz", "ytm", "dts", "varray", "v2array", "v3array", "map", "marray", "row", "named_row", "ip", "uuid", } assert.Equal(t, expectedNames, columns) columnTypes, err := rows.ColumnTypes() require.NoError(t, err, "Failed reading result column types") assert.Equal(t, 33, len(columnTypes), "Expected 33 result column type") type columnType struct { typeName string hasScale bool precision int64 scale int64 hasLength bool length int64 scanType reflect.Type } expectedTypes := []columnType{ { "BOOLEAN", false, 0, 0, false, 0, reflect.TypeOf(sql.NullBool{}), }, { "TINYINT", false, 0, 0, false, 0, reflect.TypeOf(sql.NullInt32{}), }, { "SMALLINT", false, 0, 0, false, 0, reflect.TypeOf(sql.NullInt32{}), }, { "INTEGER", false, 0, 0, false, 0, reflect.TypeOf(sql.NullInt32{}), }, { "BIGINT", false, 0, 0, false, 0, reflect.TypeOf(sql.NullInt64{}), }, { "REAL", false, 0, 0, false, 0, reflect.TypeOf(sql.NullFloat64{}), }, { 
"DOUBLE", false, 0, 0, false, 0, reflect.TypeOf(sql.NullFloat64{}), }, { "DECIMAL", true, 10, 5, false, 0, reflect.TypeOf(sql.NullString{}), }, { "VARCHAR", false, 0, 0, true, math.MaxInt32, reflect.TypeOf(sql.NullString{}), }, { "VARCHAR", false, 0, 0, true, 10, reflect.TypeOf(sql.NullString{}), }, { "CHAR", false, 0, 0, true, 1, reflect.TypeOf(sql.NullString{}), }, { "CHAR", false, 0, 0, true, 10, reflect.TypeOf(sql.NullString{}), }, { "VARBINARY", false, 0, 0, false, 0, reflect.TypeOf([]byte{}), }, { "JSON", false, 0, 0, false, 0, reflect.TypeOf(sql.NullString{}), }, { "DATE", false, 0, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, { "TIME", true, 3, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, { "TIME", true, 6, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, { "TIME WITH TIME ZONE", true, 3, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, { "TIMESTAMP", true, 3, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, { "TIMESTAMP", true, 6, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, { "TIMESTAMP WITH TIME ZONE", true, 3, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, { "TIMESTAMP WITH TIME ZONE", true, 6, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, { "INTERVAL YEAR TO MONTH", false, 0, 0, false, 0, reflect.TypeOf(sql.NullString{}), }, { "INTERVAL DAY TO SECOND", false, 0, 0, false, 0, reflect.TypeOf(sql.NullString{}), }, { "ARRAY(VARCHAR(1))", false, 0, 0, false, 0, reflect.TypeOf(NullSliceString{}), }, { "ARRAY(ARRAY(VARCHAR(1)))", false, 0, 0, false, 0, reflect.TypeOf(NullSlice2String{}), }, { "ARRAY(ARRAY(ARRAY(VARCHAR(1))))", false, 0, 0, false, 0, reflect.TypeOf(NullSlice3String{}), }, { "MAP(VARCHAR(1), INTEGER)", false, 0, 0, false, 0, reflect.TypeOf(NullMap{}), }, { "ARRAY(MAP(VARCHAR(1), INTEGER))", false, 0, 0, false, 0, reflect.TypeOf(NullSliceMap{}), }, { "ROW(VARCHAR(1), INTEGER)", false, 0, 0, false, 0, reflect.TypeOf(new(interface{})).Elem(), }, { "ROW(X VARCHAR, Y DOUBLE)", false, 0, 0, false, 0, reflect.TypeOf(new(interface{})).Elem(), 
}, { "IPADDRESS", false, 0, 0, false, 0, reflect.TypeOf(sql.NullString{}), }, { "UUID", false, 0, 0, false, 0, reflect.TypeOf(sql.NullString{}), }, } actualTypes := make([]columnType, 33) for i, column := range columnTypes { actualTypes[i].typeName = column.DatabaseTypeName() actualTypes[i].precision, actualTypes[i].scale, actualTypes[i].hasScale = column.DecimalSize() actualTypes[i].length, actualTypes[i].hasLength = column.Length() actualTypes[i].scanType = column.ScanType() } assert.Equal(t, actualTypes, expectedTypes) } func TestMaxGoPrecisionDateTime(t *testing.T) { c := &Config{ ServerURI: *integrationServerFlag, SessionProperties: map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(t, err) db, err := sql.Open("trino", dsn) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) rows, err := db.Query(`SELECT cast(current_time as time(9)) AS timep, cast(current_time as time(9) with time zone) AS timeptz, cast(current_time as timestamp(9)) AS tsp, cast(current_time as timestamp(9) with time zone) AS tsptz`) require.NoError(t, err, "Failed executing query") assert.NotNil(t, rows) columns, err := rows.Columns() require.NoError(t, err, "Failed reading result columns") assert.Equal(t, 4, len(columns), "Expected 4 result column") expectedNames := []string{ "timep", "timeptz", "tsp", "tsptz", } assert.Equal(t, expectedNames, columns) columnTypes, err := rows.ColumnTypes() require.NoError(t, err, "Failed reading result column types") assert.Equal(t, 4, len(columnTypes), "Expected 4 result column type") type columnType struct { typeName string hasScale bool precision int64 scale int64 hasLength bool length int64 scanType reflect.Type } expectedTypes := []columnType{ { "TIME", true, 9, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, { "TIME WITH TIME ZONE", true, 9, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, { "TIMESTAMP", true, 9, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, { "TIMESTAMP WITH TIME ZONE", 
true, 9, 0, false, 0, reflect.TypeOf(sql.NullTime{}), }, } actualTypes := make([]columnType, 4) for i, column := range columnTypes { actualTypes[i].typeName = column.DatabaseTypeName() actualTypes[i].precision, actualTypes[i].scale, actualTypes[i].hasScale = column.DecimalSize() actualTypes[i].length, actualTypes[i].hasLength = column.Length() actualTypes[i].scanType = column.ScanType() } assert.Equal(t, actualTypes, expectedTypes) assert.True(t, rows.Next()) require.NoError(t, rows.Err()) } func TestQueryCancellation(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(&stmtResponse{ Error: ErrTrino{ ErrorName: "USER_CANCELLED", }, }) })) t.Cleanup(ts.Close) db, err := sql.Open("trino", ts.URL) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) _, err = db.Query("SELECT 1") assert.EqualError(t, err, ErrQueryCancelled.Error(), "unexpected error") } func TestRoleHeader(t *testing.T) { tests := []struct { name string roles map[string]string namedArgRoles map[string]string expectedHeader string }{ { name: "Roles from config", roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, namedArgRoles: nil, expectedHeader: `catalog1=ROLE{role1},catalog2=ROLE{role2}`, }, { name: "Override dsn roles with named argument", roles: map[string]string{"catalog1": "role1"}, namedArgRoles: map[string]string{"catalog3": "role3", "catalog4": "role4", "catalog5": "ALL"}, expectedHeader: `catalog3=ROLE{role3},catalog4=ROLE{role4},catalog5=ALL`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var receivedHeader string var serverURL string ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { receivedHeader = r.Header.Get(trinoRoleHeader) w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"id":"1","nextUri":"` + serverURL + `/1"}`)) })) 
serverURL = ts.URL t.Cleanup(ts.Close) c := &Config{ ServerURI: ts.URL, Roles: tt.roles, } dsn, err := c.FormatDSN() require.NoError(t, err) db, err := sql.Open("trino", dsn) require.NoError(t, err) if tt.namedArgRoles != nil { _, _ = db.Query("SELECT 1", sql.Named("X-Trino-Role", tt.namedArgRoles)) } else { _, _ = db.Query("SELECT 1") } assert.Equal(t, tt.expectedHeader, receivedHeader, "expected X-Trino-Role header to match") }) } } func TestQueryFailure(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) t.Cleanup(ts.Close) db, err := sql.Open("trino", ts.URL) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) _, err = db.Query("SELECT 1") assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err) } // This test ensures that the fetch method is not generating stack overflow errors. // === RUN TestFetchNoStackOverflow // runtime: goroutine stack exceeds 1000000000-byte limit // runtime: sp=0x14037b00390 stack=[0x14037b00000, 0x14057b00000] // fatal error: stack overflow func TestFetchNoStackOverflow(t *testing.T) { previousSetting := debug.SetMaxStack(50 * 1024) defer debug.SetMaxStack(previousSetting) count := 0 var buf *bytes.Buffer var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if count <= 50 { if buf == nil { buf = new(bytes.Buffer) json.NewEncoder(buf).Encode(&stmtResponse{ ID: "fake-query", NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", }) } w.WriteHeader(http.StatusOK) w.Write(buf.Bytes()) count++ return } w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(&stmtResponse{ Error: ErrTrino{ ErrorName: "TEST", }, }) })) db, err := sql.Open("trino", ts.URL) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) _, err = db.Query("SELECT 1") assert.IsTypef(t, new(ErrQueryFailed), err, "unexpected error: %w", err) } 
func TestSpoolingProtocolSpooledSegmentDecoders(t *testing.T) { testcases := []struct { Name string Segments []map[string]interface{} ExpectedResult []int Encoding string DownloadedData []byte }{ { Name: "noCompression", Segments: []map[string]interface{}{ { "type": "spooled", "metadata": map[string]interface{}{"segmentSize": 16, "rowOffset": 0, "rowsCount": 2}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, Encoding: "json", ExpectedResult: []int{1000, 10001}, DownloadedData: []byte("[[1000],[10001]]"), }, { Name: "zstdCompression", Segments: []map[string]interface{}{ { "type": "spooled", "metadata": map[string]interface{}{"uncompressedSize": 16, "rowOffset": 0, "segmentSize": 29}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, Encoding: "json+zstd", ExpectedResult: []int{1000, 10001}, DownloadedData: mustDecodeBase64("KLUv/QQAgQAAW1sxMDAwXSxbMTAwMDFdXZfUttw="), }, { Name: "spooledSegmentWithoutHeadersOnReponse", // headers are optional Segments: []map[string]interface{}{ { "type": "spooled", "metadata": map[string]interface{}{"uncompressedSize": 16, "rowOffset": 0, "segmentSize": 29}, "ackUri": "test", }, }, Encoding: "json+zstd", ExpectedResult: []int{1000, 10001}, DownloadedData: mustDecodeBase64("KLUv/QQAgQAAW1sxMDAwXSxbMTAwMDFdXZfUttw="), }, { Name: "zlibCompression", Segments: []map[string]interface{}{ { "type": "spooled", "metadata": map[string]interface{}{"uncompressedSize": 16, "rowOffset": 0, "segmentSize": 18}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, Encoding: "json+lz4", ExpectedResult: []int{1000, 10001}, DownloadedData: mustDecodeBase64("8AFbWzEwMDBdLFsxMDAwMV1d"), }, } for _, tc := range testcases { t.Run(tc.Name, func(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/v1/statement" { 
json.NewEncoder(w).Encode(&stmtResponse{ ID: "fake-query", NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", }) return } if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/1" { json.NewEncoder(w).Encode(&queryResponse{ ID: "fake-query", Columns: []queryColumn{ { Name: "_col0", Type: "integer", TypeSignature: typeSignature{ RawType: "integer", Arguments: []typeArgument{}, }, }, }, Data: map[string]interface{}{ "encoding": tc.Encoding, "segments": tc.Segments, }, }) return } if r.URL.Path == "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=" { w.Write(tc.DownloadedData) return } w.WriteHeader(http.StatusInternalServerError) json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"}) })) defer ts.Close() tc.Segments[0]["uri"] = ts.URL + "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=" db, err := sql.Open("trino", ts.URL) require.NoError(t, err) defer db.Close() rows, err := db.Query("SELECT 1") require.NoError(t, err) var results []int for rows.Next() { var value int err := rows.Scan(&value) require.NoError(t, err) results = append(results, value) } require.NoError(t, rows.Err()) assert.Equal(t, tc.ExpectedResult, results, "Expected query results to match") }) } } func TestSpoolingProtocolToManyOutOfOrderSegmentDownload(t *testing.T) { segments := []map[string]interface{}{ { "type": "spooled", "metadata": map[string]interface{}{"segmentSize": 8, "rowOffset": 30, "rowsCount": 1}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, { "type": "spooled", "metadata": map[string]interface{}{"segmentSize": 8, "rowOffset": 20, "rowsCount": 1}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, { "type": "spooled", "metadata": map[string]interface{}{"segmentSize": 8, "rowOffset": 40, "rowsCount": 1}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, } var ts *httptest.Server ts = 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/v1/statement": json.NewEncoder(w).Encode(&stmtResponse{ ID: "fake-query", NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", }) return case "/v1/statement/20210817_140827_00000_arvdv/1": json.NewEncoder(w).Encode(&queryResponse{ ID: "fake-query", Columns: []queryColumn{ { Name: "_col0", Type: "integer", TypeSignature: typeSignature{ RawType: "integer", Arguments: []typeArgument{}, }, }, }, Data: map[string]interface{}{ "encoding": "json", "segments": segments, }, }) return case "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=": w.Write([]byte("[[1000]]")) return case "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc1=": w.Write([]byte("[[1001]]")) return case "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc2=": w.Write([]byte("[[1002]]")) return default: w.WriteHeader(http.StatusInternalServerError) json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"}) } })) defer ts.Close() // Inject segment URIs into the segment definitions segments[0]["uri"] = ts.URL + "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=" segments[1]["uri"] = ts.URL + "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc1=" segments[2]["uri"] = ts.URL + "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc2=" db, err := sql.Open("trino", ts.URL) require.NoError(t, err) defer db.Close() rows, err := db.Query("SELECT 1", sql.Named(trinoMaxOutOfOrdersSegments, "3"), sql.Named(trinoSpoolingWorkerCount, "2")) require.NoError(t, err) for rows.Next() { var value int err := rows.Scan(&value) require.NoError(t, err) } require.Error(t, rows.Err()) require.ErrorContains(t, rows.Err(), "all 3 out-of-order segments buffered (limit: 3). 
This indicates a bug or inconsistency in the segments metadata response (e.g., missing, duplicate, or misordered segments, or row offsets not matching the expected sequence)") } func TestSpoolingProtocolOutOfOrderSegment(t *testing.T) { // Define the segments segments := []map[string]interface{}{ { "type": "spooled", "metadata": map[string]interface{}{"segmentSize": 8, "rowOffset": 2, "rowsCount": 1}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, { "type": "spooled", "metadata": map[string]interface{}{"segmentSize": 8, "rowOffset": 1, "rowsCount": 1}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, { "type": "spooled", "metadata": map[string]interface{}{"segmentSize": 8, "rowOffset": 0, "rowsCount": 1}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, } var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/v1/statement": json.NewEncoder(w).Encode(&stmtResponse{ ID: "fake-query", NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", }) return case "/v1/statement/20210817_140827_00000_arvdv/1": json.NewEncoder(w).Encode(&queryResponse{ ID: "fake-query", Columns: []queryColumn{ { Name: "_col0", Type: "integer", TypeSignature: typeSignature{ RawType: "integer", Arguments: []typeArgument{}, }, }, }, Data: map[string]interface{}{ "encoding": "json", "segments": segments, }, }) return case "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=": w.Write([]byte("[[1000]]")) return case "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc1=": w.Write([]byte("[[1001]]")) return case "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc2=": w.Write([]byte("[[1002]]")) return default: w.WriteHeader(http.StatusInternalServerError) json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"}) } })) defer ts.Close() 
// Inject segment URIs into the segment definitions segments[2]["uri"] = ts.URL + "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=" segments[1]["uri"] = ts.URL + "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc1=" segments[0]["uri"] = ts.URL + "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc2=" db, err := sql.Open("trino", ts.URL) require.NoError(t, err) defer db.Close() rows, err := db.Query("SELECT 1", sql.Named(trinoMaxOutOfOrdersSegments, "3"), sql.Named(trinoSpoolingWorkerCount, "1")) require.NoError(t, err) var results []int for rows.Next() { var value int err := rows.Scan(&value) require.NoError(t, err) results = append(results, value) } require.NoError(t, rows.Err()) expected := []int{1000, 1001, 1002} assert.Equal(t, expected, results, "Expected query results to match") } func TestSpoolingProtocolSegmentDownloadRetryFails(t *testing.T) { testcases := []struct { Name string ExpectedErrorMsg string SimulateTimeout bool HttpStatusReponse int }{ { Name: "Test retry 502 Bad Gateway", HttpStatusReponse: http.StatusBadGateway, }, { Name: "Test retry 503 Service Unavailable", HttpStatusReponse: http.StatusServiceUnavailable, }, { Name: "Test retry 504 Gateway Timeout", HttpStatusReponse: http.StatusGatewayTimeout, }, } for _, tc := range testcases { t.Run(tc.Name, func(t *testing.T) { var ts *httptest.Server var failCounter int ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/v1/statement" { json.NewEncoder(w).Encode(&stmtResponse{ ID: "fake-query", NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", }) return } if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/1" { json.NewEncoder(w).Encode(&queryResponse{ ID: "fake-query", Columns: []queryColumn{ { Name: "_col0", Type: "integer", TypeSignature: typeSignature{ RawType: "integer", Arguments: []typeArgument{}, }, }, }, Data: map[string]interface{}{ "encoding": "json", "segments": 
[]map[string]interface{}{ { "uri": ts.URL + "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=", "type": "spooled", "metadata": map[string]interface{}{"segmentSize": 325, "rowOffset": 0, "rowsCount": 1}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, }, }) return } if r.URL.Path == "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=" { if failCounter < 2 { failCounter++ w.WriteHeader(tc.HttpStatusReponse) return } w.WriteHeader(http.StatusOK) w.Write([]byte("[[1000]]")) } w.WriteHeader(http.StatusInternalServerError) json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"}) })) defer ts.Close() db, err := sql.Open("trino", ts.URL) require.NoError(t, err) defer db.Close() rows, err := db.Query("SELECT 1") require.NoError(t, err) var results []int for rows.Next() { var value int err := rows.Scan(&value) require.NoError(t, err) results = append(results, value) } require.NoError(t, rows.Err()) assert.Equal(t, []int{1000}, results, "Expected query results to match") assert.Equal(t, 2, failCounter, "Expected segment download to fail exactly 2 times before succeeding") }) } } func TestSpoolingProtocolSegmentDownloadRetryMaxAttempts(t *testing.T) { var ts *httptest.Server failCounter := 0 maxRetries := 6 ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/v1/statement": json.NewEncoder(w).Encode(&stmtResponse{ ID: "fake-query", NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", }) case "/v1/statement/20210817_140827_00000_arvdv/1": json.NewEncoder(w).Encode(&queryResponse{ ID: "fake-query", Columns: []queryColumn{ { Name: "_col0", Type: "integer", TypeSignature: typeSignature{ RawType: "integer", Arguments: []typeArgument{}, }, }, }, Data: map[string]interface{}{ "encoding": "json", "segments": []map[string]interface{}{ { "uri": ts.URL + "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=", 
"type": "spooled", "metadata": map[string]interface{}{"segmentSize": 325, "rowOffset": 0, "rowsCount": 1}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, }, }) case "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=": if failCounter <= maxRetries { failCounter++ w.WriteHeader(http.StatusBadGateway) return } default: w.WriteHeader(http.StatusInternalServerError) json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"}) } })) defer ts.Close() db, err := sql.Open("trino", ts.URL) require.NoError(t, err) defer db.Close() rows, err := db.Query("SELECT 1") require.NoError(t, err) for rows.Next() { } require.Error(t, rows.Err()) require.ErrorContains(t, rows.Err(), "max retries reached for status code 502") assert.Equal(t, maxRetries, failCounter, "Expected segment download to fail exactly 5 times before succeeding") } func mustDecodeBase64(encoded string) []byte { data, err := base64.StdEncoding.DecodeString(encoded) if err != nil { panic(fmt.Sprintf("Failed to decode base64: %v", err)) } return data } func TestSpoolingProtocolOnlyWithInlineSegments(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/v1/statement" { json.NewEncoder(w).Encode(&stmtResponse{ ID: "fake-query", NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", }) return } if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/1" { json.NewEncoder(w).Encode(&queryResponse{ ID: "fake-query", Columns: []queryColumn{ { Name: "_col0", Type: "integer", TypeSignature: typeSignature{ RawType: "integer", Arguments: []typeArgument{}, }, }, }, Data: map[string]interface{}{ "encoding": "json", "segments": []map[string]interface{}{ { "type": "inline", "data": "W1sxMDAwXSwgWzEwMDAxXV0=", "metadata": map[string]interface{}{"segmentSize": 17, "rowOffset": 0}, }, { "type": "inline", "data": "W1sxMDAwXSwgWzEwMDAxXV0=", "metadata": 
map[string]interface{}{"segmentSize": 17, "rowOffset": 2},
						},
						{
							"type":     "inline",
							"data":     "W1sxMDAwXSwgWzEwMDAxXV0=",
							"metadata": map[string]interface{}{"segmentSize": 17, "rowOffset": 4},
						},
						{
							"type":     "inline",
							"data":     "W1sxMDAwXSwgWzEwMDAxXV0=",
							"metadata": map[string]interface{}{"segmentSize": 17, "rowOffset": 6},
						},
					},
				},
			})
			return
		}

		w.WriteHeader(http.StatusInternalServerError)
		json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"})
	}))
	defer ts.Close()

	db, err := sql.Open("trino", ts.URL)
	require.NoError(t, err)
	defer db.Close()

	rows, err := db.Query("SELECT 1", sql.Named(trinoSpoolingWorkerCount, "2"), sql.Named(trinoMaxOutOfOrdersSegments, "2"))
	require.NoError(t, err)

	var results []int
	for rows.Next() {
		var value int
		err := rows.Scan(&value)
		require.NoError(t, err)
		results = append(results, value)
	}
	require.NoError(t, rows.Err())

	assert.Equal(t, []int{1000, 10001, 1000, 10001, 1000, 10001, 1000, 10001}, results, "Expected query results to match")
}

// TestSpoolingProtocolInlineSegmentDecoders serves a single inline segment per
// case and checks that the rows 1000 and 10001 are returned for each supported
// payload encoding: plain "json", "json+zstd" and "json+lz4".
func TestSpoolingProtocolInlineSegmentDecoders(t *testing.T) {
	testcases := []struct {
		Name           string
		Segments       []map[string]interface{}
		ExpectedResult []int
		Encoding       string
	}{
		{
			Name: "noCompression",
			Segments: []map[string]interface{}{
				{
					"type":     "inline",
					"data":     "W1sxMDAwXSwgWzEwMDAxXV0=",
					"metadata": map[string]interface{}{"segmentSize": 17, "rowOffset": 0},
				},
			},
			Encoding:       "json",
			ExpectedResult: []int{1000, 10001},
		},
		{
			Name: "zstdCompression",
			Segments: []map[string]interface{}{
				{
					"type":     "inline",
					"data":     "KLUv/QQAgQAAW1sxMDAwXSxbMTAwMDFdXZfUttw=",
					"metadata": map[string]interface{}{"uncompressedSize": 16, "rowOffset": 0, "segmentSize": 29},
				},
			},
			Encoding:       "json+zstd",
			ExpectedResult: []int{1000, 10001},
		},
		{
			// The payload is a raw LZ4 block and the encoding is json+lz4, so
			// the case is named after the codec it actually exercises.
			Name: "lz4Compression",
			Segments: []map[string]interface{}{
				{
					"type":     "inline",
					"data":     "8AFbWzEwMDBdLFsxMDAwMV1d",
					"metadata": map[string]interface{}{"uncompressedSize": 16, "rowOffset": 0, "segmentSize": 18},
				},
			},
			Encoding:       "json+lz4",
			ExpectedResult: []int{1000, 10001},
		},
	}

	for _, tc := range testcases {
		t.Run(tc.Name, func(t *testing.T) {
			var ts *httptest.Server
			ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if r.URL.Path == "/v1/statement" {
					json.NewEncoder(w).Encode(&stmtResponse{
						ID:      "fake-query",
						NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1",
					})
					return
				}

				if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/1" {
					json.NewEncoder(w).Encode(&queryResponse{
						ID: "fake-query",
						Columns: []queryColumn{
							{
								Name: "_col0",
								Type: "integer",
								TypeSignature: typeSignature{
									RawType:   "integer",
									Arguments: []typeArgument{},
								},
							},
						},
						Data: map[string]interface{}{
							"encoding": tc.Encoding,
							"segments": tc.Segments,
						},
					})
					return
				}

				w.WriteHeader(http.StatusInternalServerError)
				json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"})
			}))
			defer ts.Close()

			db, err := sql.Open("trino", ts.URL)
			require.NoError(t, err)
			defer db.Close()

			rows, err := db.Query("SELECT 1")
			require.NoError(t, err)

			var results []int
			for rows.Next() {
				var value int
				err := rows.Scan(&value)
				require.NoError(t, err)
				results = append(results, value)
			}
			require.NoError(t, rows.Err())

			assert.Equal(t, tc.ExpectedResult, results, "Expected query results to match")
		})
	}
}

func TestSpoolingProtocolSpooledSegmentErrorHandling(t *testing.T) {
	testcases := []struct {
		name                          string
		segments                      []map[string]interface{}
		expectedError                 string
		downloadedData                []byte
		downloadedDataStatusCodeError bool
	}{
		{
			name: "MissingRowOffsetMetadata",
			segments: []map[string]interface{}{
				{
					"type":     "spooled",
					"metadata": map[string]interface{}{"uncompressedSize": 2, "segmentSize": 11},
					"ackUri":   "test",
					"headers": map[string]interface{}{
						"test": []interface{}{"test"},
					},
				},
			},
			expectedError: "rowOffset is missing in segment metadata",
		},
		{
			name: "WrongRowOffsetMetadataType",
			segments: []map[string]interface{}{
				{
					"type":     "spooled",
					"metadata": map[string]interface{}{"uncompressedSize": 2, "rowOffset": "2", "segmentSize": 11},
					"ackUri":   "test",
					"headers":
map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, expectedError: "invalid type for rowOffset in segment metadata, expected json.Number", }, { name: "MissingSegmentSizeMetadata", segments: []map[string]interface{}{ { "type": "spooled", "metadata": map[string]interface{}{"uncompressedSize": 2, "rowOffset": 2}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, expectedError: "segmentSize is missing in segment metadata", }, { name: "WrongSegmentSizeMetadataType", segments: []map[string]interface{}{ { "type": "spooled", "metadata": map[string]interface{}{"uncompressedSize": 2, "rowOffset": 2, "segmentSize": "11"}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, expectedError: "invalid type for segmentSize in segment metadata, expected json.Number", }, { name: "MissingMetadata", segments: []map[string]interface{}{ { "type": "spooled", "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, expectedError: "metadata is missing in segment at index 0", }, { name: "WrongMetadataType", segments: []map[string]interface{}{ { "type": "spooled", "metadata": "fake-metadata", "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, expectedError: "metadata is invalid or cannot be parsed as map[string]interface{} in segment at index 0", }, { name: "WrongUncompressSize", segments: []map[string]interface{}{ { "type": "spooled", "metadata": map[string]interface{}{"uncompressedSize": 2, "rowOffset": 2, "segmentSize": 11}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, expectedError: "failed to decode spooled segment at index 0: segment size mismatch: expected 11 bytes, got 29 byte", downloadedData: mustDecodeBase64("KLUv/QQAgQAAW1sxMDAwXSxbMTAwMDFdXZfUttw="), }, { name: "WrongCompresSize", segments: []map[string]interface{}{ { "type": "spooled", "metadata": 
map[string]interface{}{"uncompressedSize": 2, "rowOffset": 2, "segmentSize": 29}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, expectedError: "decompressed size mismatch: expected 2 bytes, got 16 bytes", downloadedData: mustDecodeBase64("KLUv/QQAgQAAW1sxMDAwXSxbMTAwMDFdXZfUttw="), }, { name: "MissingUri", segments: []map[string]interface{}{ { "type": "spooled", "data": "fake-data", "ackUri": "test", "metadata": map[string]interface{}{ "segmentSize": 3679, "uncompressedSize": 2, "rowOffset": 0, }, "headers": map[string][]interface{}{ "x-amz-server-side-encryption-customer-algorithm": {"AES256"}, "x-amz-server-side-encryption-customer-key": {"key"}, "x-amz-server-side-encryption-customer-key-md5": {"md5"}, }, }, }, expectedError: "missing or invalid 'uri' field in spooled segment at index 0", }, { name: "MissingUriAck", segments: []map[string]interface{}{ { "type": "spooled", "data": "fake-data", "uri": "fake-uri", "metadata": map[string]interface{}{ "segmentSize": 3679, "uncompressedSize": 2, "rowOffset": 0, }, "headers": map[string][]interface{}{ "x-amz-server-side-encryption-customer-algorithm": {"AES256"}, "x-amz-server-side-encryption-customer-key": {"key"}, "x-amz-server-side-encryption-customer-key-md5": {"md5"}, }, }, }, expectedError: "missing or invalid 'ackUri' field in spooled segment at index 0", }, { name: "wrongHeadersFormat", segments: []map[string]interface{}{ { "type": "spooled", "data": "fake-data", "uri": "fake-uri", "ackUri": "test", "metadata": map[string]interface{}{ "segmentSize": 3679, "uncompressedSize": 2, "rowOffset": 0, }, "headers": [][]string{ {"x-amz-server-side-encryption-customer-algorithm", "AES256"}, {"x-amz-server-side-encryption-customer-key", "key"}, }, }, }, expectedError: "invalid 'headers' field in spooled segment at index 0: expected map[string]interface{}", }, { name: "HeadersWithMultipleValues", segments: []map[string]interface{}{ { "type": "spooled", "data": "fake-data", 
"uri": "fake-uri", "ackUri": "test", "metadata": map[string]interface{}{ "segmentSize": 3679, "uncompressedSize": 2, "rowOffset": 0, }, "headers": map[string][]interface{}{ "x-amz-server-side-encryption-customer-algorithm": {"AES256"}, "x-amz-server-side-encryption-customer-key": {"key"}, "x-amz-server-side-encryption-customer-key-md5": {"md5", "md5"}, // wrong, more then one }, }, }, expectedError: "multiple values for header x-amz-server-side-encryption-customer-key-md5", }, { name: "HeaderValueWrongType", segments: []map[string]interface{}{ { "type": "spooled", "data": "fake-data", "uri": "fake-uri", "ackUri": "test", "metadata": map[string]interface{}{ "segmentSize": 3679, "uncompressedSize": 2, "rowOffset": 0, }, "headers": map[string]interface{}{ "x-amz-server-side-encryption-customer-algorithm": []interface{}{"AES256"}, "x-amz-server-side-encryption-customer-key": []interface{}{"key"}, "x-amz-server-side-encryption-customer-key-md5": []interface{}{123}, // Wrong type: integer instead of string }, }, }, expectedError: "unsupported header value type json.Number", }, { name: "HeaderTypeInvalid", segments: []map[string]interface{}{ { "type": "spooled", "data": "fake-data", "uri": "fake-uri", "ackUri": "test", "metadata": map[string]interface{}{ "segmentSize": 3679, "uncompressedSize": 2, "rowOffset": 0, }, "headers": map[string]interface{}{ "x-amz-server-side-encryption-customer-algorithm": "AES256", // Invalid type: string instead of []interface{} }, }, }, expectedError: "unsupported header type string", }, { name: "ErrorDownloadingSegment", segments: []map[string]interface{}{ { "type": "spooled", "metadata": map[string]interface{}{"uncompressedSize": 2, "rowOffset": 2, "segmentSize": 11}, "ackUri": "test", "headers": map[string]interface{}{ "test": []interface{}{"test"}, }, }, }, expectedError: "trino: query failed (500 Internal Server Error):", downloadedData: mustDecodeBase64("KLUv/QQAgQAAW1sxMDAwXSxbMTAwMDFdXZfUttw="), downloadedDataStatusCodeError: true, 
}, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/v1/statement" { json.NewEncoder(w).Encode(&stmtResponse{ ID: "fake-query", NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", }) return } if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/1" { json.NewEncoder(w).Encode(&queryResponse{ ID: "fake-query", Columns: []queryColumn{ { Name: "_col0", Type: "integer", TypeSignature: typeSignature{ RawType: "integer", Arguments: []typeArgument{}, }, }, }, Data: map[string]interface{}{ "encoding": "json+zstd", "segments": tc.segments, }, }) return } if r.URL.Path == "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=" { if tc.downloadedDataStatusCodeError { w.WriteHeader(http.StatusInternalServerError) } w.Write(tc.downloadedData) return } w.WriteHeader(http.StatusInternalServerError) json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"}) })) defer ts.Close() if tc.name != "MissingUri" { tc.segments[0]["uri"] = ts.URL + "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=" } db, err := sql.Open("trino", ts.URL) require.NoError(t, err) defer db.Close() rows, err := db.Query("SELECT 1") require.NoError(t, err) defer rows.Close() for rows.Next() { // force segment processing } err = rows.Err() require.Error(t, err) require.Contains(t, err.Error(), tc.expectedError) }) } } func TestSpoolingProtocolInlineSegmentErrorHandling(t *testing.T) { testcases := []struct { name string segments []map[string]interface{} expectedError string }{ { name: "WrongUncompressSize", segments: []map[string]interface{}{ { "type": "inline", "data": "KLUv/QQAgQAAW1sxMDAwXSxbMTAwMDFdXZfUttw=", "metadata": map[string]interface{}{"uncompressedSize": 1, "rowOffset": 2, "segmentSize": 29}, }, }, expectedError: "failed to decode spooled segment at index 0: decompressed size mismatch: expected 1 
bytes, got 16 bytes", }, { name: "WrongCompresSize", segments: []map[string]interface{}{ { "type": "inline", "data": "KLUv/QQAgQAAW1sxMDAwXSxbMTAwMDFdXZfUttw=", "metadata": map[string]interface{}{"uncompressedSize": 16, "rowOffset": 2, "segmentSize": 1}, }, }, expectedError: "failed to decode spooled segment at index 0: segment size mismatch: expected 1 bytes, got 29 bytes", }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/v1/statement" { json.NewEncoder(w).Encode(&stmtResponse{ ID: "fake-query", NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", }) return } if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/1" { json.NewEncoder(w).Encode(&queryResponse{ ID: "fake-query", Columns: []queryColumn{ { Name: "_col0", Type: "integer", TypeSignature: typeSignature{ RawType: "integer", Arguments: []typeArgument{}, }, }, }, Data: map[string]interface{}{ "encoding": "json+zstd", "segments": tc.segments, }, }) return } w.WriteHeader(http.StatusInternalServerError) json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"}) })) defer ts.Close() if tc.name != "MissingUri" { tc.segments[0]["uri"] = ts.URL + "/v1/spooled/download/jKaLK0aVkNp2ixl6BOuwGMJ0nRjbUVKLHW_f3-I-1Cc=" } db, err := sql.Open("trino", ts.URL) require.NoError(t, err) defer db.Close() rows, err := db.Query("SELECT 1") require.NoError(t, err) for rows.Next() { // force segment processing } err = rows.Err() require.Error(t, err) require.Contains(t, err.Error(), tc.expectedError) }) } } func TestProtocolErrorHandling(t *testing.T) { testcases := []struct { name string data interface{} expectedError string }{ { name: "DirectProtocolInvalidRowType", data: []interface{}{ 123, }, expectedError: "unexpected data type for row at index 0: expected []interface{}, got json.Number", }, { name: "SpoolingProtocolMissingEncoding", data: 
map[string]interface{}{ "segments": []interface{}{}, // Missing "encoding" field }, expectedError: "invalid or missing 'encoding' field on spooling protocol, expected string", }, { name: "SpoolingProtocolInvalidSegmentsType", data: map[string]interface{}{ "encoding": "json", "segments": "invalid", // Invalid type for "segments" }, expectedError: "nvalid or missing 'segments' field on spooling protocol, expected []interface{}", }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/v1/statement" { json.NewEncoder(w).Encode(&stmtResponse{ ID: "fake-query", NextURI: ts.URL + "/v1/statement/20210817_140827_00000_arvdv/1", }) return } if r.URL.Path == "/v1/statement/20210817_140827_00000_arvdv/1" { json.NewEncoder(w).Encode(&queryResponse{ ID: "fake-query", Columns: []queryColumn{ { Name: "_col0", Type: "integer", TypeSignature: typeSignature{ RawType: "integer", Arguments: []typeArgument{}, }, }, }, Data: tc.data, }) return } w.WriteHeader(http.StatusInternalServerError) json.NewEncoder(w).Encode(ErrTrino{ErrorName: "Unexpected request"}) })) defer ts.Close() db, err := sql.Open("trino", ts.URL) require.NoError(t, err) defer db.Close() _, err = db.Query("SELECT 1") require.Error(t, err) require.Contains(t, err.Error(), tc.expectedError) }) } } func TestSession(t *testing.T) { if testing.Short() { t.Skip("Skipping test in short mode.") } err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}}) if err != nil { t.Fatal(err) } c := &Config{ ServerURI: *integrationServerFlag + "?custom_client=uncompressed", SessionProperties: map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(t, err) db, err := sql.Open("trino", dsn) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) _, err = db.Exec("SET SESSION 
join_distribution_type='BROADCAST'") require.NoError(t, err, "Failed executing query") row := db.QueryRow("SHOW SESSION LIKE 'join_distribution_type'") var name string var value string var defaultValue string var typeName string var description string err = row.Scan(&name, &value, &defaultValue, &typeName, &description) require.NoError(t, err, "Failed executing query") assert.Equal(t, "BROADCAST", value) _, err = db.Exec("RESET SESSION join_distribution_type") require.NoError(t, err, "Failed executing query") row = db.QueryRow("SHOW SESSION LIKE 'join_distribution_type'") err = row.Scan(&name, &value, &defaultValue, &typeName, &description) require.NoError(t, err, "Failed executing query") assert.Equal(t, "AUTOMATIC", value) } func TestSetRoleHeader(t *testing.T) { var firstRoleHeader string var secondRoleHeader string var requestCount int var baseURL string ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { requestCount++ roleHeader := r.Header.Get(trinoRoleHeader) if r.URL.Path == "/v1/statement" { // Capture the initial role from DSN firstRoleHeader = roleHeader w.Header().Set(trinoSetRoleHeader, "ROLE%7Badmin%7D") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(&stmtResponse{ ID: "query1", NextURI: baseURL + "/v1/statement/query1/1", Stats: stmtStats{ State: "RUNNING", }, }) } else if r.URL.Path == "/v1/statement/query1/1" { // Capture the role in subsequent request(e.g after server set) secondRoleHeader = roleHeader w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(&queryResponse{ ID: "query1", Stats: stmtStats{ State: "FINISHED", }, Data: [][]interface{}{{1}}, Columns: []queryColumn{ { Name: "_col0", Type: "integer", TypeSignature: typeSignature{ RawType: "integer", Arguments: []typeArgument{}, }, }, }, }) } else if r.Method == "DELETE" && r.URL.Path == "/v1/query/query1" { w.WriteHeader(http.StatusNoContent) } else { w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(&queryResponse{ ID: "query1", Stats: 
stmtStats{
					State: "FINISHED",
				},
			})
		}
	}))
	baseURL = ts.URL
	t.Cleanup(ts.Close)

	db, err := sql.Open("trino", ts.URL+"?roles=catalog%3Auser")
	require.NoError(t, err)
	t.Cleanup(func() {
		assert.NoError(t, db.Close())
	})

	rows, err := db.Query("SELECT 1")
	require.NoError(t, err)
	require.NoError(t, rows.Close())

	assert.Equal(t, `catalog=ROLE{user}`, firstRoleHeader, "initial role from DSN should be sent in first request")
	assert.Equal(t, "ROLE%7Badmin%7D", secondRoleHeader, "server-set role should be sent in subsequent requests")
	assert.NotEqual(t, firstRoleHeader, secondRoleHeader, "role should have changed from DSN value to server-set value")
}

// TestUnsupportedHeader runs a query against a stub server that answers every
// request with the trinoSetPathHeader response header set, and expects
// db.Query to fail with an error whose text equals ErrUnsupportedHeader.
func TestUnsupportedHeader(t *testing.T) {
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set(trinoSetPathHeader, "foo.bar")
		w.WriteHeader(http.StatusOK)
	}))

	t.Cleanup(ts.Close)

	db, err := sql.Open("trino", ts.URL)
	require.NoError(t, err)

	t.Cleanup(func() {
		assert.NoError(t, db.Close())
	})

	_, err = db.Query("SELECT 1")
	assert.EqualError(t, err, ErrUnsupportedHeader.Error(), "unexpected error")
}

// TestSSLCertPath opens a DSN whose SSLCertPath parameter points at
// /tmp/invalid_test.cert and expects Ping to return an error containing
// "Error loading SSL Cert File".
func TestSSLCertPath(t *testing.T) {
	db, err := sql.Open("trino", "https://localhost:9?SSLCertPath=/tmp/invalid_test.cert")
	require.NoError(t, err)

	t.Cleanup(func() {
		assert.NoError(t, db.Close())
	})

	want := "Error loading SSL Cert File"
	err = db.Ping()
	require.Error(t, err)
	require.Contains(t, err.Error(), want)
}

// TestWithoutSSLCertPath opens an https DSN that carries no SSLCertPath
// parameter and expects Ping to return no error.
func TestWithoutSSLCertPath(t *testing.T) {
	db, err := sql.Open("trino", "https://localhost:9")
	require.NoError(t, err)

	t.Cleanup(func() {
		assert.NoError(t, db.Close())
	})

	assert.NoError(t, db.Ping())
}

// TestUnsupportedTransaction expects db.Begin to fail with an error containing
// "operation not supported".
func TestUnsupportedTransaction(t *testing.T) {
	db, err := sql.Open("trino", "http://localhost:9")
	require.NoError(t, err)

	t.Cleanup(func() {
		assert.NoError(t, db.Close())
	})

	_, err = db.Begin()
	require.Error(t, err, "unsupported transaction succeeded with no error")

	expected := "operation not supported"
	assert.Contains(t, err.Error(), expected)
}

func TestTypeConversion(t
*testing.T) { utc, err := time.LoadLocation("UTC") require.NoError(t, err) paris, err := time.LoadLocation("Europe/Paris") require.NoError(t, err) testcases := []struct { DataType string RawType string Arguments []typeArgument ResponseUnmarshalledSample interface{} ExpectedGoValue interface{} }{ { DataType: "boolean", RawType: "boolean", ResponseUnmarshalledSample: true, ExpectedGoValue: true, }, { DataType: "varchar(1)", RawType: "varchar", ResponseUnmarshalledSample: "hello", ExpectedGoValue: "hello", }, { DataType: "bigint", RawType: "bigint", ResponseUnmarshalledSample: json.Number("1234516165077230279"), ExpectedGoValue: int64(1234516165077230279), }, { DataType: "double", RawType: "double", ResponseUnmarshalledSample: json.Number("1.0"), ExpectedGoValue: float64(1), }, { DataType: "date", RawType: "date", ResponseUnmarshalledSample: "2017-07-10", ExpectedGoValue: time.Date(2017, 7, 10, 0, 0, 0, 0, time.Local), }, { DataType: "time", RawType: "time", ResponseUnmarshalledSample: "01:02:03.000", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 0, time.Local), }, { DataType: "time with time zone", RawType: "time with time zone", ResponseUnmarshalledSample: "01:02:03.000 UTC", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 0, utc), }, { DataType: "time with time zone", RawType: "time with time zone", ResponseUnmarshalledSample: "01:02:03.000 +03:00", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 0, time.FixedZone("", 3*3600)), }, { DataType: "time with time zone", RawType: "time with time zone", ResponseUnmarshalledSample: "01:02:03.000+03:00", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 0, time.FixedZone("", 3*3600)), }, { DataType: "time with time zone", RawType: "time with time zone", ResponseUnmarshalledSample: "01:02:03.000 -05:00", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 0, time.FixedZone("", -5*3600)), }, { DataType: "time with time zone", RawType: "time with time zone", ResponseUnmarshalledSample: "01:02:03.000-05:00", ExpectedGoValue: time.Date(0, 1, 1, 1, 
2, 3, 0, time.FixedZone("", -5*3600)), }, { DataType: "time", RawType: "time", ResponseUnmarshalledSample: "01:02:03.123456789", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 123456789, time.Local), }, { DataType: "time with time zone", RawType: "time with time zone", ResponseUnmarshalledSample: "01:02:03.123456789 UTC", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 123456789, utc), }, { DataType: "time with time zone", RawType: "time with time zone", ResponseUnmarshalledSample: "01:02:03.123456789 +03:00", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 123456789, time.FixedZone("", 3*3600)), }, { DataType: "time with time zone", RawType: "time with time zone", ResponseUnmarshalledSample: "01:02:03.123456789+03:00", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 123456789, time.FixedZone("", 3*3600)), }, { DataType: "time with time zone", RawType: "time with time zone", ResponseUnmarshalledSample: "01:02:03.123456789 -05:00", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 123456789, time.FixedZone("", -5*3600)), }, { DataType: "time with time zone", RawType: "time with time zone", ResponseUnmarshalledSample: "01:02:03.123456789-05:00", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 123456789, time.FixedZone("", -5*3600)), }, { DataType: "time with time zone", RawType: "time with time zone", ResponseUnmarshalledSample: "01:02:03.123456789 Europe/Paris", ExpectedGoValue: time.Date(0, 1, 1, 1, 2, 3, 123456789, paris), }, { DataType: "timestamp", RawType: "timestamp", ResponseUnmarshalledSample: "2017-07-10 01:02:03.000", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 0, time.Local), }, { DataType: "timestamp with time zone", RawType: "timestamp with time zone", ResponseUnmarshalledSample: "2017-07-10 01:02:03.000 UTC", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 0, utc), }, { DataType: "timestamp with time zone", RawType: "timestamp with time zone", ResponseUnmarshalledSample: "2017-07-10 01:02:03.000 +03:00", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 0, 
time.FixedZone("", 3*3600)), }, { DataType: "timestamp with time zone", RawType: "timestamp with time zone", ResponseUnmarshalledSample: "2017-07-10 01:02:03.000+03:00", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 0, time.FixedZone("", 3*3600)), }, { DataType: "timestamp with time zone", RawType: "timestamp with time zone", ResponseUnmarshalledSample: "2017-07-10 01:02:03.000 -04:00", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 0, time.FixedZone("", -4*3600)), }, { DataType: "timestamp with time zone", RawType: "timestamp with time zone", ResponseUnmarshalledSample: "2017-07-10 01:02:03.000-04:00", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 0, time.FixedZone("", -4*3600)), }, { DataType: "timestamp", RawType: "timestamp", ResponseUnmarshalledSample: "2017-07-10 01:02:03.123456789", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 123456789, time.Local), }, { DataType: "timestamp with time zone", RawType: "timestamp with time zone", ResponseUnmarshalledSample: "2017-07-10 01:02:03.123456789 UTC", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 123456789, utc), }, { DataType: "timestamp with time zone", RawType: "timestamp with time zone", ResponseUnmarshalledSample: "2017-07-10 01:02:03.123456789 +03:00", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 123456789, time.FixedZone("", 3*3600)), }, { DataType: "timestamp with time zone", RawType: "timestamp with time zone", ResponseUnmarshalledSample: "2017-07-10 01:02:03.123456789+03:00", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 123456789, time.FixedZone("", 3*3600)), }, { DataType: "timestamp with time zone", RawType: "timestamp with time zone", ResponseUnmarshalledSample: "2017-07-10 01:02:03.123456789 -04:00", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 123456789, time.FixedZone("", -4*3600)), }, { DataType: "timestamp with time zone", RawType: "timestamp with time zone", ResponseUnmarshalledSample: "2017-07-10 01:02:03.123456789-04:00", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 
123456789, time.FixedZone("", -4*3600)), }, { DataType: "timestamp with time zone", RawType: "timestamp with time zone", ResponseUnmarshalledSample: "2017-07-10 01:02:03.123456789 Europe/Paris", ExpectedGoValue: time.Date(2017, 7, 10, 1, 2, 3, 123456789, paris), }, { DataType: "map(varchar,varchar)", RawType: "map", Arguments: []typeArgument{ { Kind: "NAMED_TYPE", namedTypeSignature: namedTypeSignature{ TypeSignature: typeSignature{ RawType: "varchar", }, }, }, { Kind: "NAMED_TYPE", namedTypeSignature: namedTypeSignature{ TypeSignature: typeSignature{ RawType: "varchar", }, }, }, }, ResponseUnmarshalledSample: nil, ExpectedGoValue: nil, }, { // arrays return data as-is for slice scanners DataType: "array(varchar)", RawType: "array", Arguments: []typeArgument{ { Kind: "NAMED_TYPE", namedTypeSignature: namedTypeSignature{ TypeSignature: typeSignature{ RawType: "varchar", }, }, }, }, ResponseUnmarshalledSample: nil, ExpectedGoValue: nil, }, { // rows return data as-is for slice scanners DataType: "row(int, varchar(1), timestamp, array(varchar(1)))", RawType: "row", Arguments: []typeArgument{ { Kind: "NAMED_TYPE", namedTypeSignature: namedTypeSignature{ TypeSignature: typeSignature{ RawType: "integer", }, }, }, { Kind: "NAMED_TYPE", namedTypeSignature: namedTypeSignature{ TypeSignature: typeSignature{ RawType: "varchar", Arguments: []typeArgument{ { Kind: "LONG", long: 1, }, }, }, }, }, { Kind: "NAMED_TYPE", namedTypeSignature: namedTypeSignature{ TypeSignature: typeSignature{ RawType: "timestamp", }, }, }, { Kind: "NAMED_TYPE", namedTypeSignature: namedTypeSignature{ TypeSignature: typeSignature{ RawType: "array", Arguments: []typeArgument{ { Kind: "TYPE", typeSignature: typeSignature{ RawType: "varchar", Arguments: []typeArgument{ { Kind: "LONG", long: 1, }, }, }, }, }, }, }, }, }, ResponseUnmarshalledSample: []interface{}{ json.Number("1"), "a", "2017-07-10 01:02:03.000 UTC", []interface{}{"b"}, }, ExpectedGoValue: []interface{}{ json.Number("1"), "a", "2017-07-10 
01:02:03.000 UTC",
				[]interface{}{"b"},
			},
		},
		{
			DataType:                   "Geometry",
			RawType:                    "Geometry",
			ResponseUnmarshalledSample: "Point (0 0)",
			ExpectedGoValue:            "Point (0 0)",
		},
		{
			DataType:                   "SphericalGeography",
			RawType:                    "SphericalGeography",
			ResponseUnmarshalledSample: "Point (0 0)",
			ExpectedGoValue:            "Point (0 0)",
		},
	}

	for _, tc := range testcases {
		converter, err := newTypeConverter(tc.DataType, typeSignature{RawType: tc.RawType, Arguments: tc.Arguments})
		assert.NoError(t, err)

		t.Run(tc.DataType+":nil", func(t *testing.T) {
			_, err := converter.ConvertValue(nil)
			assert.NoError(t, err)
		})

		t.Run(tc.DataType+":bogus", func(t *testing.T) {
			_, err := converter.ConvertValue(struct{}{})
			assert.Error(t, err, "bogus data scanned with no error")
		})

		t.Run(tc.DataType+":sample", func(t *testing.T) {
			v, err := converter.ConvertValue(tc.ResponseUnmarshalledSample)
			require.NoError(t, err)
			require.Equal(t, v, tc.ExpectedGoValue, "unexpected data from sample:\nhave %+v\nwant %+v", v, tc.ExpectedGoValue)
		})
	}
}

// TestSliceTypeConversion exercises the one-dimensional NullSlice* scanners.
// For every scanner it checks that nil scans without error, that values of an
// unsupported type are rejected, and that the Valid flag is true after
// scanning the sample and false again after scanning nil.
func TestSliceTypeConversion(t *testing.T) {
	testcases := []struct {
		GoType                          string
		Scanner                         sql.Scanner
		TrinoResponseUnmarshalledSample interface{}
		TestScanner                     func(t *testing.T, s sql.Scanner, isValid bool)
	}{
		{
			GoType:                          "[]bool",
			Scanner:                         &NullSliceBool{},
			TrinoResponseUnmarshalledSample: []interface{}{true},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSliceBool)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[]string",
			Scanner:                         &NullSliceString{},
			TrinoResponseUnmarshalledSample: []interface{}{"hello"},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSliceString)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[]int64",
			Scanner:                         &NullSliceInt64{},
			TrinoResponseUnmarshalledSample: []interface{}{json.Number("1")},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSliceInt64)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[]float64",
			Scanner:                         &NullSliceFloat64{},
			TrinoResponseUnmarshalledSample: []interface{}{json.Number("1.0")},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSliceFloat64)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[]time.Time",
			Scanner:                         &NullSliceTime{},
			TrinoResponseUnmarshalledSample: []interface{}{"2017-07-01"},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSliceTime)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[]map[string]interface{}",
			Scanner:                         &NullSliceMap{},
			TrinoResponseUnmarshalledSample: []interface{}{map[string]interface{}{"hello": "world"}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSliceMap)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
	}

	for _, tc := range testcases {
		t.Run(tc.GoType+":nil", func(t *testing.T) {
			assert.NoError(t, tc.Scanner.Scan(nil))
		})

		t.Run(tc.GoType+":bogus", func(t *testing.T) {
			// Reject both a non-slice value and a slice holding an unsupported element.
			assert.Error(t, tc.Scanner.Scan(struct{}{}), "bogus data scanned with no error")
			assert.Error(t, tc.Scanner.Scan([]interface{}{struct{}{}}), "bogus data scanned with no error")
		})

		t.Run(tc.GoType+":sample", func(t *testing.T) {
			require.NoError(t, tc.Scanner.Scan(tc.TrinoResponseUnmarshalledSample))
			tc.TestScanner(t, tc.Scanner, true)
			// Scanning nil into the same scanner must reset Valid.
			require.NoError(t, tc.Scanner.Scan(nil))
			tc.TestScanner(t, tc.Scanner, false)
		})
	}
}

// TestSlice2TypeConversion exercises the two-dimensional NullSlice2* scanners.
// It mirrors TestSliceTypeConversion with samples nested one level deeper, and
// additionally checks that a nil inner slice scans without error.
func TestSlice2TypeConversion(t *testing.T) {
	testcases := []struct {
		GoType                          string
		Scanner                         sql.Scanner
		TrinoResponseUnmarshalledSample interface{}
		TestScanner                     func(t *testing.T, s sql.Scanner, isValid bool)
	}{
		{
			GoType:                          "[][]bool",
			Scanner:                         &NullSlice2Bool{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{true}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice2Bool)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[][]string",
			Scanner:                         &NullSlice2String{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{"hello"}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice2String)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[][]int64",
			Scanner:                         &NullSlice2Int64{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{json.Number("1")}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice2Int64)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[][]float64",
			Scanner:                         &NullSlice2Float64{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{json.Number("1.0")}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice2Float64)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[][]time.Time",
			Scanner:                         &NullSlice2Time{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{"2017-07-01"}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice2Time)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[][]map[string]interface{}",
			Scanner:                         &NullSlice2Map{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{map[string]interface{}{"hello": "world"}}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice2Map)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
	}

	for _, tc := range testcases {
		t.Run(tc.GoType+":nil", func(t *testing.T) {
			assert.NoError(t, tc.Scanner.Scan(nil))
			assert.NoError(t, tc.Scanner.Scan([]interface{}{nil}))
		})

		t.Run(tc.GoType+":bogus", func(t *testing.T) {
			// An unsupported value must be rejected at every nesting depth.
			assert.Error(t, tc.Scanner.Scan(struct{}{}), "bogus data scanned with no error")
			assert.Error(t, tc.Scanner.Scan([]interface{}{struct{}{}}), "bogus data scanned with no error")
			assert.Error(t, tc.Scanner.Scan([]interface{}{[]interface{}{struct{}{}}}), "bogus data scanned with no error")
		})

		t.Run(tc.GoType+":sample", func(t *testing.T) {
			require.NoError(t, tc.Scanner.Scan(tc.TrinoResponseUnmarshalledSample))
			tc.TestScanner(t, tc.Scanner, true)
			// Scanning nil into the same scanner must reset Valid.
			require.NoError(t, tc.Scanner.Scan(nil))
			tc.TestScanner(t, tc.Scanner, false)
		})
	}
}

// TestSlice3TypeConversion exercises the three-dimensional NullSlice3*
// scanners. It mirrors TestSlice2TypeConversion with samples nested one level
// deeper.
func TestSlice3TypeConversion(t *testing.T) {
	testcases := []struct {
		GoType                          string
		Scanner                         sql.Scanner
		TrinoResponseUnmarshalledSample interface{}
		TestScanner                     func(t *testing.T, s sql.Scanner, isValid bool)
	}{
		{
			GoType:                          "[][][]bool",
			Scanner:                         &NullSlice3Bool{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{true}}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice3Bool)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[][][]string",
			Scanner:                         &NullSlice3String{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{"hello"}}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice3String)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[][][]int64",
			Scanner:                         &NullSlice3Int64{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{json.Number("1")}}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice3Int64)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[][][]float64",
			Scanner:                         &NullSlice3Float64{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{json.Number("1.0")}}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice3Float64)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[][][]time.Time",
			Scanner:                         &NullSlice3Time{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{"2017-07-01"}}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice3Time)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
		{
			GoType:                          "[][][]map[string]interface{}",
			Scanner:                         &NullSlice3Map{},
			TrinoResponseUnmarshalledSample: []interface{}{[]interface{}{[]interface{}{map[string]interface{}{"hello": "world"}}}},
			TestScanner: func(t *testing.T, s sql.Scanner, isValid bool) {
				v, _ := s.(*NullSlice3Map)
				assert.Equal(t, isValid, v.Valid, "scanner failed")
			},
		},
	}

	for _, tc := range testcases {
		t.Run(tc.GoType+":nil", func(t *testing.T) {
			assert.NoError(t, tc.Scanner.Scan(nil))
			assert.NoError(t, tc.Scanner.Scan([]interface{}{[]interface{}{nil}}))
		})

		t.Run(tc.GoType+":bogus", func(t *testing.T) {
			assert.Error(t, tc.Scanner.Scan(struct{}{}), "bogus data scanned with no error")
			assert.Error(t, tc.Scanner.Scan([]interface{}{[]interface{}{struct{}{}}}), "bogus data scanned with no error")
			assert.Error(t, tc.Scanner.Scan([]interface{}{[]interface{}{[]interface{}{struct{}{}}}}), "bogus data scanned with no error")
		})

		t.Run(tc.GoType+":sample", func(t *testing.T) {
			require.NoError(t, tc.Scanner.Scan(tc.TrinoResponseUnmarshalledSample))
			tc.TestScanner(t, tc.Scanner, true)
			// Scanning nil into the same scanner must reset Valid.
			require.NoError(t, tc.Scanner.Scan(nil))
			tc.TestScanner(t, tc.Scanner, false)
		})
	}
}

// BenchmarkQuery measures running a large SELECT against the server named by
// the integration server flag and draining every row of the result. No
// encoding argument is passed with the query.
func BenchmarkQuery(b *testing.B) {
	c := &Config{
		ServerURI:         *integrationServerFlag,
		SessionProperties: map[string]string{"query_priority": "1"},
	}

	dsn, err := c.FormatDSN()
	require.NoError(b, err)

	db, err := sql.Open("trino", dsn)
	require.NoError(b, err)

	b.Cleanup(func() {
		assert.NoError(b, db.Close())
	})

	q := `SELECT * FROM tpch.sf1.orders LIMIT 10000000`
	for n := 0; n < b.N; n++ {
		rows, err := db.Query(q)
		require.NoError(b, err)
		// Drain the result set; the row contents are irrelevant here.
		for rows.Next() {
		}
		// Next returns false on failure as well as on exhaustion, so a failed
		// iteration must not be timed as if it were a successful run.
		require.NoError(b, rows.Err())
		require.NoError(b, rows.Close())
	}
}

// BenchmarkSpoolingProtocolSpooledSegmentlJsonZstdDecoderQuery benchmarks the performance of querying a large dataset
// from Trino with JSON encoding and Zstd compression, testing the spooling mechanism. The query retrieves a result set
// of 10 million rows, exceeding the default inline row limit of 1000 (defined by `protocol.spooling.inlining.max-rows`),
// triggering the spooling mechanism to handle the large data efficiently.
//
// **Session properties & headers:**
// - **`encoding: json+zstd`**: Specifies JSON encoding with Zstd compression for the query result.
// - **`protocol.spooling.inlining.max-rows`**: Default is 1000, determining when spooling is triggered to manage large result sets. func BenchmarkSpoolingProtocolSpooledSegmentlJsonZstdDecoderQuery(b *testing.B) { c := &Config{ ServerURI: *integrationServerFlag, SessionProperties: map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(b, err) db, err := sql.Open("trino", dsn) require.NoError(b, err) b.Cleanup(func() { assert.NoError(b, db.Close()) }) q := `SELECT * FROM tpch.sf1.orders LIMIT 10000000` for n := 0; n < b.N; n++ { rows, err := db.Query(q, sql.Named(trinoEncoding, "json+zstd")) require.NoError(b, err) for rows.Next() { } rows.Close() } } // BenchmarkSpoolingProtocolSpooledSegmentJsonLz4DecoderQuery benchmarks the performance of querying a large dataset // from Trino with JSON encoding and LZ4 compression, testing the spooling mechanism. The query retrieves a result set // of 10 million rows, exceeding the default inline row limit of 1000 (defined by `protocol.spooling.inlining.max-rows`), // triggering the spooling mechanism to handle the large data efficiently. // // **Session properties & headers:** // - **`encoding: json+lz4`**: Specifies JSON encoding with LZ4 compression for the query result. // - **`protocol.spooling.inlining.max-rows`**: Default is 1000, determining when spooling is triggered to manage large result sets. 
func BenchmarkSpoolingProtocolSpooledSegmentJsonLz4DecoderQuery(b *testing.B) { c := &Config{ ServerURI: *integrationServerFlag, SessionProperties: map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(b, err) db, err := sql.Open("trino", dsn) require.NoError(b, err) b.Cleanup(func() { assert.NoError(b, db.Close()) }) q := `SELECT * FROM tpch.sf1.orders LIMIT 10000000` for n := 0; n < b.N; n++ { rows, err := db.Query(q, sql.Named(trinoEncoding, "json+lz4")) require.NoError(b, err) for rows.Next() { } rows.Close() } } // BenchmarkSpoolingProtocolSpooledSegmentJsonDecoderQuery benchmarks the performance of querying a large dataset // from Trino with JSON encoding (without compression), testing the spooling mechanism. The query retrieves a result set // of 10 million rows, exceeding the default inline row limit of 1000 (defined by `protocol.spooling.inlining.max-rows`), // triggering the spooling mechanism to handle the large data efficiently. // // **Session properties & headers:** // - **`encoding: json`**: Specifies JSON encoding without compression for the query result. 
// - **`protocol.spooling.inlining.max-rows`**: Default is 1000, determining when spooling is triggered to manage large result sets func BenchmarkSpoolingProtocolSpooledSegmentJsonDecoderQuery(b *testing.B) { c := &Config{ ServerURI: *integrationServerFlag, SessionProperties: map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(b, err) db, err := sql.Open("trino", dsn) require.NoError(b, err) b.Cleanup(func() { assert.NoError(b, db.Close()) }) q := `SELECT * FROM tpch.sf1.orders LIMIT 10000000` for n := 0; n < b.N; n++ { rows, err := db.Query(q, sql.Named(trinoEncoding, "json")) require.NoError(b, err) for rows.Next() { } rows.Close() } } func TestExec(t *testing.T) { if testing.Short() { t.Skip("Skipping test in short mode.") } c := &Config{ ServerURI: *integrationServerFlag, SessionProperties: map[string]string{"query_priority": "1"}, } dsn, err := c.FormatDSN() require.NoError(t, err) db, err := sql.Open("trino", dsn) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) _, err = db.Exec("CREATE TABLE memory.default.test (id INTEGER, name VARCHAR, optional VARCHAR)") require.NoError(t, err, "Failed executing CREATE TABLE query") result, err := db.Exec("INSERT INTO memory.default.test (id, name, optional) VALUES (?, ?, ?), (?, ?, ?), (?, ?, ?)", 123, "abc", nil, 456, "def", "present", 789, "ghi", nil) require.NoError(t, err, "Failed executing INSERT query") _, err = result.LastInsertId() assert.Error(t, err, "trino: operation not supported") numRows, err := result.RowsAffected() require.NoError(t, err, "Failed checking rows affected") assert.Equal(t, numRows, int64(3)) rows, err := db.Query("SELECT * FROM memory.default.test") require.NoError(t, err, "Failed executing DELETE query") expectedIds := []int{123, 456, 789} expectedNames := []string{"abc", "def", "ghi"} expectedOptionals := []sql.NullString{ sql.NullString{Valid: false}, sql.NullString{String: "present", Valid: true}, sql.NullString{Valid: false}, 
} actualIds := []int{} actualNames := []string{} actualOptionals := []sql.NullString{} for rows.Next() { var id int var name string var optional sql.NullString require.NoError(t, rows.Scan(&id, &name, &optional), "Failed scanning query result") actualIds = append(actualIds, id) actualNames = append(actualNames, name) actualOptionals = append(actualOptionals, optional) } assert.Equal(t, expectedIds, actualIds) assert.Equal(t, expectedNames, actualNames) assert.Equal(t, expectedOptionals, actualOptionals) _, err = db.Exec("DROP TABLE memory.default.test") require.NoError(t, err, "Failed executing DROP TABLE query") } func TestForwardAuthorizationHeaderConfig(t *testing.T) { c := &Config{ ServerURI: "https://foobar@localhost:8090", ForwardAuthorizationHeader: true, } dsn, err := c.FormatDSN() require.NoError(t, err) want := "https://foobar@localhost:8090?forwardAuthorizationHeader=true&source=trino-go-client" assert.Equal(t, want, dsn) } func TestForwardAuthorizationHeader(t *testing.T) { var captureAuthHeader string ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Capture the Authorization header for later inspection captureAuthHeader = r.Header.Get("Authorization") })) t.Cleanup(ts.Close) db, err := sql.Open("trino", ts.URL+"?forwardAuthorizationHeader=true") require.NoError(t, err) _, _ = db.Query("SELECT 1", sql.Named("accessToken", string("token"))) // Ingore response to focus on header capture require.Equal(t, "Bearer token", captureAuthHeader, "Authorization header is incorrect") assert.NoError(t, db.Close()) } func TestQueryTimeoutDeadline(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { time.Sleep(200 * time.Millisecond) // Simulate slow response w.WriteHeader(http.StatusOK) })) defer ts.Close() testcases := []struct { name string queryTimeout string expectedError string }{ { name: "with timeout", queryTimeout: "100ms", expectedError: "context deadline 
exceeded", }, { name: "without timeout", queryTimeout: "10s", expectedError: "EOF", // Default server response }, { name: "bad timeout", queryTimeout: "abc", expectedError: "trino: invalid timeout", // Default server response }, } for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { println(ts.URL + "?query_timeout=" + tc.queryTimeout) db, err := sql.Open("trino", ts.URL+"?query_timeout="+tc.queryTimeout) require.NoError(t, err) defer db.Close() _, err = db.Query("SELECT 1") assert.ErrorContains(t, err, tc.expectedError) }) } }